├── apex ├── amp │ ├── lists │ │ ├── __init__.py │ │ ├── tensor_overrides.py │ │ ├── functional_overrides.py │ │ └── torch_overrides.py │ ├── __version__.py │ ├── __init__.py │ ├── compat.py │ ├── rnn_compat.py │ ├── _amp_state.py │ ├── README.md │ ├── opt.py │ ├── amp.py │ ├── utils.py │ ├── scaler.py │ ├── wrap.py │ ├── _initialize.py │ └── handle.py ├── requirements.txt └── parallel │ ├── multiproc.py │ ├── README.md │ ├── __init__.py │ ├── sync_batchnorm_kernel.py │ ├── LARC.py │ ├── optimized_sync_batchnorm.py │ ├── optimized_sync_batchnorm_kernel.py │ └── sync_batchnorm.py ├── model.png ├── table1.png ├── table2.png ├── table3.png ├── table4.png ├── utils ├── __pycache__ │ ├── utils.cpython-38.pyc │ ├── utils.cpython-39.pyc │ ├── dataset.cpython-39.pyc │ ├── datasets.cpython-38.pyc │ ├── datasets.cpython-39.pyc │ ├── autoaugment.cpython-39.pyc │ ├── data_utils.cpython-39.pyc │ ├── dist_util.cpython-39.pyc │ └── scheduler.cpython-39.pyc ├── dist_util.py ├── utils.py ├── stanford_dogs.py ├── datasets_split.py ├── scheduler.py ├── datasets.py ├── autoaugment.py └── data_utils.py ├── logs └── task_TransFG │ ├── events.out.tfevents.1662525191.gpu06.55524.0 │ ├── events.out.tfevents.1662525191.gpu06.55525.0 │ ├── events.out.tfevents.1662525191.gpu06.55526.0 │ ├── events.out.tfevents.1662525191.gpu06.55527.0 │ ├── events.out.tfevents.1681479653.gpu45.28047.0 │ └── events.out.tfevents.1681479653.gpu45.28048.0 ├── models ├── config.py ├── model_ViT.py └── model_TransFG.py ├── README.md └── train.py /apex/amp/lists/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GZU-SAMLab/AA-Trans/HEAD/model.png -------------------------------------------------------------------------------- /table1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GZU-SAMLab/AA-Trans/HEAD/table1.png -------------------------------------------------------------------------------- /table2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GZU-SAMLab/AA-Trans/HEAD/table2.png -------------------------------------------------------------------------------- /table3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GZU-SAMLab/AA-Trans/HEAD/table3.png -------------------------------------------------------------------------------- /table4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GZU-SAMLab/AA-Trans/HEAD/table4.png -------------------------------------------------------------------------------- /apex/amp/__version__.py: -------------------------------------------------------------------------------- 1 | VERSION = (0, 1, 0) 2 | __version__ = '.'.join(map(str, VERSION)) 3 | -------------------------------------------------------------------------------- /utils/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GZU-SAMLab/AA-Trans/HEAD/utils/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils.cpython-39.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/GZU-SAMLab/AA-Trans/HEAD/utils/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/dataset.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GZU-SAMLab/AA-Trans/HEAD/utils/__pycache__/dataset.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/datasets.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GZU-SAMLab/AA-Trans/HEAD/utils/__pycache__/datasets.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/datasets.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GZU-SAMLab/AA-Trans/HEAD/utils/__pycache__/datasets.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/autoaugment.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GZU-SAMLab/AA-Trans/HEAD/utils/__pycache__/autoaugment.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/data_utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GZU-SAMLab/AA-Trans/HEAD/utils/__pycache__/data_utils.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/dist_util.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GZU-SAMLab/AA-Trans/HEAD/utils/__pycache__/dist_util.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/scheduler.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GZU-SAMLab/AA-Trans/HEAD/utils/__pycache__/scheduler.cpython-39.pyc -------------------------------------------------------------------------------- /apex/requirements.txt: -------------------------------------------------------------------------------- 1 | cxxfilt>=0.2.0 2 | tqdm>=4.28.1 3 | numpy>=1.15.3 4 | PyYAML>=5.1 5 | pytest>=3.5.1 6 | packaging>=14.0 7 | flake8>=3.7.9 8 | Sphinx>=3.0.3 -------------------------------------------------------------------------------- /logs/task_TransFG/events.out.tfevents.1662525191.gpu06.55524.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GZU-SAMLab/AA-Trans/HEAD/logs/task_TransFG/events.out.tfevents.1662525191.gpu06.55524.0 -------------------------------------------------------------------------------- /logs/task_TransFG/events.out.tfevents.1662525191.gpu06.55525.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GZU-SAMLab/AA-Trans/HEAD/logs/task_TransFG/events.out.tfevents.1662525191.gpu06.55525.0 -------------------------------------------------------------------------------- /logs/task_TransFG/events.out.tfevents.1662525191.gpu06.55526.0: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/GZU-SAMLab/AA-Trans/HEAD/logs/task_TransFG/events.out.tfevents.1662525191.gpu06.55526.0 -------------------------------------------------------------------------------- /logs/task_TransFG/events.out.tfevents.1662525191.gpu06.55527.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GZU-SAMLab/AA-Trans/HEAD/logs/task_TransFG/events.out.tfevents.1662525191.gpu06.55527.0 -------------------------------------------------------------------------------- /logs/task_TransFG/events.out.tfevents.1681479653.gpu45.28047.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GZU-SAMLab/AA-Trans/HEAD/logs/task_TransFG/events.out.tfevents.1681479653.gpu45.28047.0 -------------------------------------------------------------------------------- /logs/task_TransFG/events.out.tfevents.1681479653.gpu45.28048.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GZU-SAMLab/AA-Trans/HEAD/logs/task_TransFG/events.out.tfevents.1681479653.gpu45.28048.0 -------------------------------------------------------------------------------- /apex/amp/__init__.py: -------------------------------------------------------------------------------- 1 | from .amp import init, half_function, float_function, promote_function,\ 2 | register_half_function, register_float_function, register_promote_function 3 | from .handle import scale_loss, disable_casts 4 | from .frontend import initialize, state_dict, load_state_dict 5 | from ._amp_state import master_params, _amp_state 6 | -------------------------------------------------------------------------------- /utils/dist_util.py: -------------------------------------------------------------------------------- 1 | import torch.distributed as dist 2 | 3 | 4 | def get_rank(): 5 | if not dist.is_available(): 6 | return 0 7 | if not dist.is_initialized(): 8 | return 0 9 | return dist.get_rank() 10 | 11 | 12 | def get_world_size(): 13 | if not dist.is_available(): 14 | return 1 15 | if not dist.is_initialized(): 16 | return 1 17 | return dist.get_world_size() 18 | 19 | 20 | def is_main_process(): 21 | return get_rank() == 0 22 | 23 | 24 | def format_step(step): 25 | if isinstance(step, str): 26 | return step 27 | s = "" 28 | if len(step) > 0: 29 | s += "Training Epoch: {} ".format(step[0]) 30 | if len(step) > 1: 31 | s += "Training Iteration: {} ".format(step[1]) 32 | if len(step) > 2: 33 | s += "Validation Iteration: {} ".format(step[2]) 34 | return s 35 | -------------------------------------------------------------------------------- /apex/parallel/multiproc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys 3 | import subprocess 4 | 5 | def docstring_hack(): 6 | """ 7 | Multiproc file which will launch a set of processes locally for multi-gpu 8 | usage: python -m apex.parallel.multiproc main.py ... 
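    Each worker process is launched with a distinct ``--rank``, and ``--world-size``
    (detected via ``torch.cuda.device_count()`` unless passed explicitly) is appended
    to its argument list.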
9 | """ 10 | pass 11 | 12 | argslist = list(sys.argv)[1:] 13 | world_size = torch.cuda.device_count() 14 | 15 | if '--world-size' in argslist: 16 | world_size = int(argslist[argslist.index('--world-size')+1]) 17 | else: 18 | argslist.append('--world-size') 19 | argslist.append(str(world_size)) 20 | 21 | workers = [] 22 | 23 | for i in range(world_size): 24 | if '--rank' in argslist: 25 | argslist[argslist.index('--rank')+1] = str(i) 26 | else: 27 | argslist.append('--rank') 28 | argslist.append(str(i)) 29 | stdout = None if i == 0 else open("GPU_"+str(i)+".log", "w") 30 | print(argslist) 31 | p = subprocess.Popen([str(sys.executable)]+argslist, stdout=stdout) 32 | workers.append(p) 33 | 34 | for p in workers: 35 | p.wait() 36 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import shutil 3 | 4 | 5 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 6 | torch.save(state, filename) 7 | if is_best: 8 | shutil.copyfile(filename, 'model_best.pth.tar') 9 | 10 | 11 | class AverageMeter(object): 12 | """ 13 | Keeps track of most recent, average, sum, and count of a metric. 14 | """ 15 | 16 | def __init__(self): 17 | self.reset() 18 | 19 | def reset(self): 20 | self.val = 0 21 | self.avg = 0 22 | self.sum = 0 23 | self.count = 0 24 | 25 | def update(self, val, n=1): 26 | self.val = val 27 | self.sum += val * n 28 | self.count += n 29 | self.avg = self.sum / self.count 30 | 31 | 32 | def accuracy(scores, targets, k): 33 | """ 34 | Computes top-k accuracy, from predicted and true labels. 35 | 36 | :param scores: scores from the model 37 | :param targets: true labels 38 | :param k: k in top-k accuracy 39 | :return: top-k accuracy 40 | """ 41 | 42 | batch_size = targets.size(0) 43 | _, ind = scores.topk(k, 1, True, True) 44 | correct = ind.eq(targets.view(-1, 1).expand_as(ind)) 45 | correct_total = correct.view(-1).float().sum() # 0D tensor 46 | return correct_total.item() * (100.0 / batch_size) 47 | -------------------------------------------------------------------------------- /apex/amp/compat.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | # True for post-0.4, when Variables/Tensors merged. 4 | def variable_is_tensor(): 5 | v = torch.autograd.Variable() 6 | return isinstance(v, torch.Tensor) 7 | 8 | def tensor_is_variable(): 9 | x = torch.Tensor() 10 | return type(x) == torch.autograd.Variable 11 | 12 | # False for post-0.4 13 | def tensor_is_float_tensor(): 14 | x = torch.Tensor() 15 | return type(x) == torch.FloatTensor 16 | 17 | # Akin to `torch.is_tensor`, but returns True for Variable 18 | # objects in pre-0.4. 19 | def is_tensor_like(x): 20 | return torch.is_tensor(x) or isinstance(x, torch.autograd.Variable) 21 | 22 | # Wraps `torch.is_floating_point` if present, otherwise checks 23 | # the suffix of `x.type()`. 
24 | def is_floating_point(x): 25 | if hasattr(torch, 'is_floating_point'): 26 | return torch.is_floating_point(x) 27 | try: 28 | torch_type = x.type() 29 | return torch_type.endswith('FloatTensor') or \ 30 | torch_type.endswith('HalfTensor') or \ 31 | torch_type.endswith('DoubleTensor') 32 | except AttributeError: 33 | return False 34 | 35 | def scalar_python_val(x): 36 | if hasattr(x, 'item'): 37 | return x.item() 38 | else: 39 | if isinstance(x, torch.autograd.Variable): 40 | return x.data[0] 41 | else: 42 | return x[0] 43 | 44 | # Accounts for the possibility that some ops may be removed from a namespace. 45 | def filter_attrs(module, attrs): 46 | return list(attrname for attrname in attrs if hasattr(module, attrname)) 47 | -------------------------------------------------------------------------------- /apex/amp/lists/tensor_overrides.py: -------------------------------------------------------------------------------- 1 | from .. import compat 2 | from . import torch_overrides 3 | 4 | import importlib 5 | 6 | import torch 7 | 8 | # if compat.variable_is_tensor() and not compat.tensor_is_variable(): 9 | MODULE = torch.Tensor 10 | # else: 11 | # MODULE = torch.autograd.Variable 12 | 13 | 14 | FP16_FUNCS = compat.filter_attrs(MODULE, [ 15 | '__matmul__', 16 | ]) 17 | 18 | FP32_FUNCS = compat.filter_attrs(MODULE, [ 19 | '__ipow__', 20 | '__pow__', 21 | '__rpow__', 22 | 23 | # Cast to fp32 before transfer to CPU 24 | 'cpu', 25 | ]) 26 | 27 | CASTS = compat.filter_attrs(MODULE, [ 28 | '__add__', 29 | '__div__', 30 | '__eq__', 31 | '__ge__', 32 | '__gt__', 33 | '__iadd__', 34 | '__idiv__', 35 | '__imul__', 36 | '__isub__', 37 | '__itruediv__', 38 | '__le__', 39 | '__lt__', 40 | '__mul__', 41 | '__ne__', 42 | '__radd__', 43 | '__rdiv__', 44 | '__rmul__', 45 | '__rsub__', 46 | '__rtruediv__', 47 | '__sub__', 48 | '__truediv__', 49 | ]) 50 | 51 | # None of these, but here to make code cleaner. 
52 | SEQUENCE_CASTS = [] 53 | 54 | # We need to grab all the methods from torch_overrides and add them to 55 | # the Tensor lists as well, as almost all methods are duplicated 56 | # between `torch` and `torch.Tensor` (and check with `hasattr`, 57 | # because a few random ones aren't defined on Tensor) 58 | _self_mod = importlib.import_module(__name__) 59 | for attrname in ['FP16_FUNCS', 'FP32_FUNCS', 'CASTS', 'SEQUENCE_CASTS']: 60 | lst = getattr(_self_mod, attrname) 61 | for fn in getattr(torch_overrides, attrname): 62 | if hasattr(MODULE, fn): 63 | lst.append(fn) 64 | -------------------------------------------------------------------------------- /utils/stanford_dogs.py: -------------------------------------------------------------------------------- 1 | import scipy.io as scio 2 | import collections 3 | import math 4 | import os 5 | import shutil 6 | 7 | 8 | def copyfile(filename, target_dir): 9 | """将文件复制到目标目录。""" 10 | os.makedirs(target_dir, exist_ok=True) 11 | shutil.copy(filename, target_dir) 12 | 13 | 14 | data_dir = '/home/samuel/datasets/Stanford_Dogs/' 15 | train_dir = '/home/samuel/datasets/Stanford_Dogs/train_list.mat' 16 | test_dir = '/home/samuel/datasets/Stanford_Dogs/test_list.mat' 17 | 18 | train_list = scio.loadmat(train_dir) 19 | test_list = scio.loadmat(test_dir) 20 | 21 | # print(train_list['file_list']) 22 | # print(train_list['labels']) 23 | 24 | train_data = train_list['file_list'] 25 | test_data = test_list['file_list'] 26 | train_label = train_list['labels'] 27 | test_label = test_list['labels'] 28 | 29 | # for i in range(len(train_data)): 30 | # pic = train_data[i][0].tolist()[0] 31 | # label = train_label[i][0].tolist() - 1 32 | # label = str(label) 33 | # # print(pic) 34 | # # print(label) 35 | # line = pic + ' ' + label + '\n' 36 | # train_file = open('/home/samuel/datasets/Stanford_Dogs/train.txt', 'a') 37 | # train_file.write(line) 38 | 39 | for j in range(len(test_data)): 40 | pic = test_data[j][0].tolist()[0] 41 | label = test_label[j][0].tolist() - 1 42 | label = str(label) 43 | # print(pic) 44 | # print(label) 45 | line = pic + ' ' + label + '\n' 46 | test_file = open('/home/samuel/datasets/Stanford_Dogs/val.txt', 'a') 47 | test_file.write(line) 48 | 49 | # print(train_data[0][0].tolist()[0]) 50 | # 51 | # for i in range(len(train_data)): 52 | # pic = train_data[i][0].tolist()[0] 53 | # root = os.path.join(data_dir, 'Images', pic) 54 | # label, img = pic.split('/') 55 | # # print(label) 56 | # copyfile(root, os.path.join(data_dir, 'train', label)) 57 | # 58 | # for j in range(len(test_data)): 59 | # pic = test_data[j][0].tolist()[0] 60 | # root = os.path.join(data_dir, 'Images', pic) 61 | # label, img = pic.split('/') 62 | # # print(root) 63 | # copyfile(root, os.path.join(data_dir, 'val', label)) 64 | -------------------------------------------------------------------------------- /apex/amp/rnn_compat.py: -------------------------------------------------------------------------------- 1 | from . import utils, wrap 2 | 3 | import torch 4 | _VF = torch._C._VariableFunctions 5 | RNN_NAMES = ['rnn_relu', 'rnn_tanh', 'gru', 'lstm'] 6 | 7 | def _gen_VF_wrapper(name): 8 | def wrapper(*args, **kwargs): 9 | return getattr(_VF, name)(*args, **kwargs) 10 | return wrapper 11 | 12 | # Some python magic to generate an object that has the rnn cell functions 13 | # defined on it, all of which call into corresponding _VF version. 14 | # Intended to patch torch.nn.modules.rnn._VF (aka, the ref named "_VF" 15 | # imported at module scope within torch.nn.modules.rnn). 
This should 16 | # not affect third-party importers of _VF.py. 17 | class VariableFunctionsShim(object): 18 | def __init__(self): 19 | for name in RNN_NAMES: 20 | for suffix in ['', '_cell']: 21 | fn_name = name + suffix 22 | setattr(self, fn_name, _gen_VF_wrapper(fn_name)) 23 | 24 | def has_old_rnns(): 25 | try: 26 | torch.nn.backends.thnn.backend.LSTMCell 27 | return True 28 | except: 29 | return False 30 | 31 | def whitelist_rnn_cells(handle, verbose): 32 | # Different module + function names in old/new RNN cases 33 | if has_old_rnns(): 34 | fn_names = ['RNNReLUCell', 'RNNTanhCell', 'LSTMCell', 'GRUCell'] 35 | mod = torch.nn.backends.thnn.backend 36 | else: 37 | fn_names = [x + '_cell' for x in RNN_NAMES] 38 | mod = torch.nn.modules.rnn._VF 39 | assert isinstance(mod, VariableFunctionsShim) 40 | 41 | # Insert casts on cell functions 42 | for fn in fn_names: 43 | wrap.cached_cast(mod, fn, utils.maybe_half, handle, 44 | try_caching=True, verbose=verbose) 45 | 46 | if has_old_rnns(): 47 | # Special handling of `backward` for fused gru / lstm: 48 | # The `backward` method calls Tensor.sum() (blacklist) internally, 49 | # and then the resulting grad_input has the wrong type. 50 | # TODO: where else is this a problem? 51 | for rnn_type in ['GRUFused', 'LSTMFused']: 52 | mod = getattr(torch.nn._functions.thnn.rnnFusedPointwise, rnn_type) 53 | wrap.disable_casts(mod, 'backward', handle) 54 | -------------------------------------------------------------------------------- /apex/amp/_amp_state.py: -------------------------------------------------------------------------------- 1 | # This is a "header object" that allows different amp modules to communicate. 2 | # I'm a C++ guy, not a python guy. I decided this approach because it seemed most C++-like. 3 | # But apparently it's ok: 4 | # http://effbot.org/pyfaq/how-do-i-share-global-variables-across-modules.htm 5 | import os 6 | import torch 7 | 8 | TORCH_MAJOR = int(torch.__version__.split('.')[0]) 9 | TORCH_MINOR = int(torch.__version__.split('.')[1]) 10 | 11 | 12 | if TORCH_MAJOR == 1 and TORCH_MINOR < 8: 13 | from torch._six import container_abcs 14 | else: 15 | import collections.abc as container_abcs 16 | 17 | 18 | class AmpState(object): 19 | def __init__(self): 20 | self.hard_override=False 21 | self.allow_incoming_model_not_fp32 = False 22 | self.verbosity=1 23 | 24 | 25 | # Attribute stash. Could also just stash things as global module attributes. 26 | _amp_state = AmpState() 27 | 28 | 29 | def warn_or_err(msg): 30 | if _amp_state.hard_override: 31 | print("Warning: " + msg) 32 | else: 33 | raise RuntimeError(msg) 34 | # I'm not sure if allowing hard_override is a good idea. 35 | # + " If you're sure you know what you're doing, supply " + 36 | # "hard_override=True to amp.initialize.") 37 | 38 | 39 | def maybe_print(msg, rank0=False): 40 | distributed = torch.distributed.is_available() and \ 41 | torch.distributed.is_initialized() and \ 42 | torch.distributed.get_world_size() > 1 43 | if _amp_state.verbosity > 0: 44 | if rank0: 45 | if distributed: 46 | if torch.distributed.get_rank() == 0: 47 | print(msg) 48 | else: 49 | print(msg) 50 | else: 51 | print(msg) 52 | 53 | 54 | # def iter_params(param_groups): 55 | # for group in param_groups: 56 | # for p in group['params']: 57 | # yield p 58 | 59 | 60 | def master_params(optimizer): 61 | """ 62 | Generator expression that iterates over the params owned by ``optimizer``. 63 | 64 | Args: 65 | optimizer: An optimizer previously returned from ``amp.initialize``. 
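        Example (an illustrative sketch added here, not part of the original docstring;
        it assumes ``optimizer`` was returned from ``amp.initialize`` and that the fp32
        master gradients should be clipped to a hypothetical ``max_norm`` of 1.0)::

            torch.nn.utils.clip_grad_norm_(master_params(optimizer), max_norm=1.0)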
66 | """ 67 | for group in optimizer.param_groups: 68 | for p in group['params']: 69 | yield p 70 | -------------------------------------------------------------------------------- /models/config.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | 4 | def get_b16_config(): 5 | """Returns the ViT-B/16 configuration.""" 6 | config = ml_collections.ConfigDict() 7 | config.patches = ml_collections.ConfigDict({'size': (16, 16)}) 8 | config.split = 'overlap' 9 | config.slide_step = 12 10 | config.hidden_size = 768 11 | config.transformer = ml_collections.ConfigDict() 12 | config.transformer.mlp_dim = 3072 13 | config.transformer.num_heads = 12 14 | config.transformer.num_layers = 12 15 | config.transformer.attention_dropout_rate = 0.0 16 | config.transformer.dropout_rate = 0.1 17 | config.classifier = 'token' 18 | config.representation_size = None 19 | return config 20 | 21 | 22 | def get_b32_config(): 23 | """Returns the ViT-B/32 configuration.""" 24 | config = get_b16_config() 25 | config.patches.size = (32, 32) 26 | return config 27 | 28 | 29 | def get_l16_config(): 30 | """Returns the ViT-L/16 configuration.""" 31 | config = ml_collections.ConfigDict() 32 | config.patches = ml_collections.ConfigDict({'size': (16, 16)}) 33 | config.hidden_size = 1024 34 | config.transformer = ml_collections.ConfigDict() 35 | config.transformer.mlp_dim = 4096 36 | config.transformer.num_heads = 16 37 | config.transformer.num_layers = 24 38 | config.transformer.attention_dropout_rate = 0.0 39 | config.transformer.dropout_rate = 0.1 40 | config.classifier = 'token' 41 | config.representation_size = None 42 | return config 43 | 44 | 45 | def get_l32_config(): 46 | """Returns the ViT-L/32 configuration.""" 47 | config = get_l16_config() 48 | config.patches.size = (32, 32) 49 | return config 50 | 51 | 52 | def get_h14_config(): 53 | """Returns the ViT-L/16 configuration.""" 54 | config = ml_collections.ConfigDict() 55 | config.patches = ml_collections.ConfigDict({'size': (14, 14)}) 56 | config.hidden_size = 1280 57 | config.transformer = ml_collections.ConfigDict() 58 | config.transformer.mlp_dim = 5120 59 | config.transformer.num_heads = 16 60 | config.transformer.num_layers = 32 61 | config.transformer.attention_dropout_rate = 0.0 62 | config.transformer.dropout_rate = 0.1 63 | config.classifier = 'token' 64 | config.representation_size = None 65 | return config 66 | -------------------------------------------------------------------------------- /apex/amp/README.md: -------------------------------------------------------------------------------- 1 | # amp: Automatic Mixed Precision 2 | 3 | ## Annotating User Functions 4 | 5 | Nearly all PyTorch user code needs nothing more than the two steps 6 | above to use amp. After all, custom layers are built out of simpler 7 | PyTorch components, and amp already can see those. 8 | 9 | However, any custom C++ or CUDA code is outside of amp's (default) 10 | view of things. For example, suppose I implemented a new recurrent 11 | cell called a "forgetful recurrent unit" that calls directly into a 12 | CUDA backend: 13 | 14 | ```python 15 | from backend import FRUBackend 16 | 17 | def fru(input, hidden, weight, bias): 18 | # call to CUDA code 19 | FRUBackend(input, hidden, weight, bias) 20 | ``` 21 | 22 | In this case, it is possible to get a runtime type mismatch. For 23 | example, you might have `input` in fp16, and `weight` in fp32, and amp 24 | doesn't have the visibility to insert an appropriate cast. 
25 | 26 | amp exposes two ways to handle "invisible" backend code: function 27 | annotations and explicit registration. 28 | 29 | #### Function annotation 30 | 31 | The first way to handle backend code is a set of function annotations: 32 | 33 | - `@amp.half_function` 34 | - `@amp.float_function` 35 | - `@amp.promote_function` 36 | 37 | These correspond to: 38 | 39 | - Cast all arguments to fp16 40 | - Cast all arguments to fp32 41 | - If there are any type mismatches, cast everything to the widest type 42 | 43 | In our example, we believe that the FRU unit is fp16-safe and will get 44 | performance gains from casting its arguments to fp16, so we write: 45 | 46 | ```python 47 | @amp.half_function 48 | def fru(input, hidden, weight, bias): 49 | #... 50 | ``` 51 | 52 | #### Explicit registration 53 | 54 | The other way to handle backend code is with explicit function 55 | registration: 56 | 57 | - `amp.register_half_function(module, function_name)` 58 | - `amp.register_float_function(module, function_name)` 59 | - `amp.register_promote_function(module, function_name)` 60 | 61 | When using this API, `module` is the containing class or module for 62 | the function, and `function_name` is the _string_ name of the 63 | function. Note that the function must be registered before the call to 64 | `amp.initialize()`. 65 | 66 | For our FRU unit, we can register the backend function directly: 67 | 68 | ```python 69 | import backend 70 | 71 | amp.register_half_function(backend, 'FRUBackend') 72 | ``` 73 | -------------------------------------------------------------------------------- /apex/amp/lists/functional_overrides.py: -------------------------------------------------------------------------------- 1 | 2 | # TODO: think about the following two. They do weird things. 3 | # - torch.nn.utils.clip_grad (but it should always be fp32 anyway) 4 | # - torch.nn.utils.weight_norm 5 | 6 | # Notes: 7 | # F.instance_norm uses batch_norm internally. Which correctly handles 8 | # fp16 in/out with fp32 weights. So we shouldn't do anything for 9 | # either of these. 10 | # F.normalize calls `input.norm()` internally, so it's redundant, but 11 | # kept here in case impl. changes. 12 | # F.cosine_similarity is same: calls `x.norm()` internally. 13 | 14 | import torch.nn.functional 15 | 16 | MODULE = torch.nn.functional 17 | 18 | FP16_FUNCS = [ 19 | 'conv1d', 20 | 'conv2d', 21 | 'conv3d', 22 | 'conv_transpose1d', 23 | 'conv_transpose2d', 24 | 'conv_transpose3d', 25 | 'conv_tbc', # Undocumented / maybe new? 26 | 'linear', 27 | ] 28 | 29 | FP32_FUNCS = [ 30 | 31 | # Interpolation/Upsampling TODO: Remove for 1.2 32 | 'interpolate', 33 | 'grid_sample', 34 | 35 | # Pointwise 36 | 'softplus', 37 | 'softmin', 38 | 'log_softmax', 39 | 'softmax', 40 | 'gelu', 41 | 42 | # Normalization 43 | 'layer_norm', 44 | 'group_norm', 45 | 'local_response_norm', 46 | 'normalize', 47 | 'cosine_similarity', 48 | 49 | # Loss functions 50 | # TODO: which of these can be fp16?
51 | 'poisson_nll_loss', 52 | 'cosine_embedding_loss', 53 | 'cross_entropy', 54 | 'hinge_embedding_loss', 55 | 'kl_div', 56 | 'l1_loss', 57 | 'mse_loss', 58 | 'margin_ranking_loss', 59 | 'multilabel_margin_loss', 60 | 'multilabel_soft_margin_loss', 61 | 'multi_margin_loss', 62 | 'nll_loss', 63 | 'binary_cross_entropy_with_logits', 64 | 'smooth_l1_loss', 65 | 'soft_margin_loss', 66 | 'triplet_margin_loss', 67 | 'ctc_loss' 68 | ] 69 | 70 | BANNED_FUNCS = [ 71 | ('binary_cross_entropy', 72 | ("\namp does not work out-of-the-box with `F.binary_cross_entropy` or `torch.nn.BCELoss.` " 73 | "It requires that the output of the previous function be already a FloatTensor. \n\n" 74 | "Most models have a Sigmoid right before BCELoss. In that case, you can use\n" 75 | " torch.nn.BCEWithLogitsLoss\nto combine Sigmoid+BCELoss into a single layer " 76 | "that is compatible with amp.\nAnother option is to add\n" 77 | " amp.register_float_function(torch, 'sigmoid')\nbefore calling `amp.init()`.\n" 78 | "If you _really_ know what you are doing, you can disable this warning by passing " 79 | "allow_banned=True to `amp.init()`.")) 80 | ] 81 | -------------------------------------------------------------------------------- /utils/datasets_split.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import math 3 | import os 4 | import shutil 5 | 6 | import pandas as pd 7 | 8 | 9 | def copyfile(filename, target_dir): 10 | """将文件复制到目标目录。""" 11 | os.makedirs(target_dir, exist_ok=True) 12 | shutil.copy(filename, target_dir) 13 | 14 | 15 | def reorg_train_valid(data_dir, labels, labels_counts, valid_ratio): 16 | # # 训练数据集中示例最少的类别中的示例数 17 | # n = collections.Counter(labels.values()).most_common()[-1][1] 18 | # 19 | # print(n) 20 | # # 验证集中每个类别的示例数 21 | # n_valid_per_label = max(1, math.floor(n * valid_ratio)) 22 | label_count = {} 23 | for train_file in os.listdir(os.path.join(data_dir, 'train_images')): 24 | label = labels[train_file] 25 | count = labels_counts[label] 26 | n_valid_per_label = math.floor(count * valid_ratio) 27 | fname = os.path.join(data_dir, 'train_images', train_file) 28 | 29 | if label not in label_count or label_count[label] < n_valid_per_label: 30 | copyfile(fname, os.path.join(data_dir, 31 | 'test', label)) 32 | label_count[label] = label_count.get(label, 0) + 1 33 | else: 34 | copyfile(fname, os.path.join(data_dir, 35 | 'train', label)) 36 | return n_valid_per_label 37 | 38 | 39 | def read_txt_labels(fname): 40 | """读取 `fname` 来给标签字典返回一个文件名。""" 41 | with open(fname, 'r') as f: 42 | # 跳过文件头行 (列名) 43 | lines = f.readlines()[1:] 44 | tokens = [l.rstrip().split() for l in lines] 45 | return list((name, label) for name, label in tokens) 46 | 47 | 48 | data_dir = '/home/samuel/datasets/CUB_200_2011/' 49 | 50 | train_dir = '/home/samuel/datasets/CUB_200_2011/train.txt' 51 | test_dir = '/home/samuel/datasets/CUB_200_2011/val.txt' 52 | labels1 = read_txt_labels(train_dir) 53 | labels2 = read_txt_labels(test_dir) 54 | 55 | for label in labels1: 56 | fname, img = label[0].split('/') 57 | root = os.path.join(data_dir, 'images', fname, img) 58 | copyfile(root, os.path.join(data_dir, 'train', fname)) 59 | 60 | for label in labels2: 61 | fname, img = label[0].split('/') 62 | root = os.path.join(data_dir, 'images', fname, img) 63 | copyfile(root, os.path.join(data_dir, 'val', fname)) 64 | 65 | 66 | # data_csv = pd.read_csv('data/train.csv', header=0) 67 | # labels_counts = data_csv['labels'].value_counts() 68 | # print(labels_counts) 69 | # 70 | # 
valid_ratio = 0.3 71 | # reorg_train_valid('./data/', labels, labels_counts, valid_ratio) 72 | -------------------------------------------------------------------------------- /apex/amp/lists/torch_overrides.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .. import utils 4 | 5 | MODULE = torch 6 | 7 | FP16_FUNCS = [ 8 | # Low level functions wrapped by torch.nn layers. 9 | # The wrapper layers contain the weights which are then passed in as a parameter 10 | # to these functions. 11 | 'conv1d', 12 | 'conv2d', 13 | 'conv3d', 14 | 'conv_transpose1d', 15 | 'conv_transpose2d', 16 | 'conv_transpose3d', 17 | 'conv_tbc', 18 | 'prelu', 19 | 20 | # BLAS 21 | 'addmm', 22 | 'addmv', 23 | 'addr', 24 | 'matmul', 25 | 'mm', 26 | 'mv', 27 | ] 28 | 29 | FP32_FUNCS = [ 30 | # Pointwise 31 | 'acos', 32 | 'asin', 33 | 'cosh', 34 | 'erfinv', 35 | 'exp', 36 | 'expm1', 37 | 'log', 38 | 'log10', 39 | 'log2', 40 | 'reciprocal', 41 | 'rsqrt', 42 | 'sinh', 43 | 'tan', 44 | 45 | # Other math 46 | 'pow', 47 | 48 | # Reduction 49 | 'cumprod', 50 | 'cumsum', 51 | 'dist', 52 | # 'mean', 53 | 'norm', 54 | 'prod', 55 | 'std', 56 | 'sum', 57 | 'var', 58 | 59 | # Misc 60 | 'renorm' 61 | ] 62 | 63 | version_strings = torch.__version__.split('.') 64 | version_major = version_strings[0] 65 | version_minor = version_strings[1] 66 | version_num = float(version_major + "." + version_minor) 67 | # Before torch 1.1, mean must be blacklisted. 68 | if version_num < 1.1: 69 | FP32_FUNCS.append('mean') 70 | 71 | # Before CUDA 9.1, batched matmul was missing fast FP16 kernels. We 72 | # check the CUDA version -- if at least 9.1, then put the bmm 73 | # functions on the fp16 list. Otherwise, put them on the fp32 list. 74 | _bmms = ['addbmm', 75 | 'baddbmm', 76 | 'bmm'] 77 | 78 | if utils.is_cuda_enabled(): 79 | # workaround https://github.com/facebookresearch/maskrcnn-benchmark/issues/802 80 | if utils.get_cuda_version() >= (9, 1, 0): 81 | FP16_FUNCS.extend(_bmms) 82 | else: 83 | FP32_FUNCS.extend(_bmms) 84 | 85 | # Multi-tensor fns that may need type promotion 86 | CASTS = [ 87 | # Multi-tensor math 88 | 'addcdiv', 89 | 'addcmul', 90 | 'atan2', 91 | 'cross', 92 | 'bilinear', 93 | 'dot', 94 | 95 | # Element-wise _or_ tensor-wise math 96 | 'add', 97 | 'div', 98 | 'mul', 99 | 100 | # Comparison 101 | 'eq', 102 | 'equal', 103 | 'ge', 104 | 'gt', 105 | 'le', 106 | 'lt', 107 | 'ne' 108 | ] 109 | 110 | # Functions that take sequence arguments. We need to inspect the whole 111 | # sequence and cast to the widest type. 112 | SEQUENCE_CASTS = [ 113 | 'cat', 114 | 'stack' 115 | ] 116 | -------------------------------------------------------------------------------- /apex/parallel/README.md: -------------------------------------------------------------------------------- 1 | ## Distributed Data Parallel 2 | 3 | distributed.py contains the source code for `apex.parallel.DistributedDataParallel`, a module wrapper that enables multi-process multi-GPU data parallel training optimized for NVIDIA's NCCL communication library. 4 | 5 | `apex.parallel.DistributedDataParallel` achieves high performance by overlapping communication with 6 | computation in the backward pass and bucketing smaller transfers to reduce the total number of 7 | transfers required. 8 | 9 | multiproc.py contains the source code for `apex.parallel.multiproc`, a launch utility that places one process on each of the node's available GPUs. 
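For example, to launch a training script on every GPU of a single node (a minimal sketch based on the usage string in multiproc.py; `main.py` and its flags are placeholders, and the launcher itself appends `--world-size` and `--rank` for each spawned process):

```
python -m apex.parallel.multiproc main.py --arg1 --arg2
```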
10 | 11 | #### [API Documentation](https://nvidia.github.io/apex/parallel.html) 12 | 13 | #### [Example/Walkthrough](https://github.com/NVIDIA/apex/tree/master/examples/distributed) 14 | 15 | #### [Imagenet example with Mixed Precision](https://github.com/NVIDIA/apex/tree/master/examples/imagenet) 16 | 17 | #### [Simple example with FP16_Optimizer](https://github.com/NVIDIA/apex/tree/master/examples/FP16_Optimizer_simple/distributed_apex) 18 | 19 | ### Synchronized Batch Normalization 20 | 21 | `apex.parallel.SyncBatchNorm` has a similar API to `torch.nn.BatchNorm*N*d`. 22 | It reduces stats on the first (channel) dimension of the Tensor and accepts 23 | arbitrary spatial dimensions. 24 | 25 | #### Installation 26 | 27 | Apex provides two sync BN implementations: 28 | 29 | 1. There is the Python-only implementation, which is the default 30 | when installing with `python setup.py install`. 31 | It uses PyTorch primitive operations and the distributed communication package from 32 | `torch.distributed`. 33 | 34 | - _The Python-only implementation requires the input tensor to be of the same data type as 35 | the layer_ 36 | 37 | 2. We also provide an implementation with custom kernels through a CUDA/C++ extension, with 38 | improved performance. We are experimenting with Welford's algorithm and Kahan summation for the reduction, 39 | hoping to get better accuracy. 40 | To use the kernel implementation, users need to install Apex with the CUDA extension 41 | enabled: `python setup.py install --cuda_ext`. 42 | 43 | - _The custom kernel implementation supports fp16 input with an fp32 layer, as cuDNN does. 44 | This is required to run the ImageNet example in fp16._ 45 | 46 | - _Currently the kernel implementation only supports GPU._ 47 | 48 | #### HowTo 49 | 50 | 1. Users can use `apex.parallel.SyncBatchNorm` by building their modules with 51 | the layer explicitly. 52 | 53 | ``` 54 | import apex 55 | input_t = torch.randn(3, 5, 20).cuda() 56 | sbn = apex.parallel.SyncBatchNorm(5).cuda() 57 | output_t = sbn(input_t) 58 | ``` 59 | 60 | 2. Users can also take a constructed `torch.nn.Module` and replace all of its `torch.nn.BatchNorm*N*d` modules with `apex.parallel.SyncBatchNorm` through the utility function `apex.parallel.convert_syncbn_model`. 61 | 62 | ``` 63 | # model is an instance of torch.nn.Module 64 | import apex 65 | sync_bn_model = apex.parallel.convert_syncbn_model(model) 66 | ``` 67 | -------------------------------------------------------------------------------- /utils/scheduler.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | 4 | from torch.optim.lr_scheduler import LambdaLR 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | class ConstantLRSchedule(LambdaLR): 9 | """ Constant learning rate schedule. 10 | """ 11 | def __init__(self, optimizer, last_epoch=-1): 12 | super(ConstantLRSchedule, self).__init__(optimizer, lambda _: 1.0, last_epoch=last_epoch) 13 | 14 | 15 | class WarmupConstantSchedule(LambdaLR): 16 | """ Linear warmup and then constant. 17 | Linearly increases learning rate schedule from 0 to 1 over `warmup_steps` training steps. 18 | Keeps learning rate schedule equal to 1. after warmup_steps. 19 | """ 20 | def __init__(self, optimizer, warmup_steps, last_epoch=-1): 21 | self.warmup_steps = warmup_steps 22 | super(WarmupConstantSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) 23 | 24 | def lr_lambda(self, step): 25 | if step < self.warmup_steps: 26 | return float(step) / float(max(1.0, self.warmup_steps)) 27 | return 1.
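# Illustrative usage sketch (not part of the original file): all of these schedules are
# LambdaLR subclasses, so they wrap a torch optimizer and are stepped once per training
# iteration; warmup_steps and t_total are counted in optimizer steps, and the values
# below are placeholders.
#
#   optimizer = torch.optim.SGD(model.parameters(), lr=3e-2, momentum=0.9)
#   scheduler = WarmupCosineSchedule(optimizer, warmup_steps=500, t_total=10000)
#   for batch in loader:
#       ...  # forward/backward
#       optimizer.step()
#       scheduler.step()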
28 | 29 | 30 | class WarmupLinearSchedule(LambdaLR): 31 | """ Linear warmup and then linear decay. 32 | Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps. 33 | Linearly decreases learning rate from 1. to 0. over remaining `t_total - warmup_steps` steps. 34 | """ 35 | def __init__(self, optimizer, warmup_steps, t_total, last_epoch=-1): 36 | self.warmup_steps = warmup_steps 37 | self.t_total = t_total 38 | super(WarmupLinearSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) 39 | 40 | def lr_lambda(self, step): 41 | if step < self.warmup_steps: 42 | return float(step) / float(max(1, self.warmup_steps)) 43 | return max(0.0, float(self.t_total - step) / float(max(1.0, self.t_total - self.warmup_steps))) 44 | 45 | 46 | class WarmupCosineSchedule(LambdaLR): 47 | """ Linear warmup and then cosine decay. 48 | Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps. 49 | Decreases learning rate from 1. to 0. over remaining `t_total - warmup_steps` steps following a cosine curve. 50 | If `cycles` (default=0.5) is different from default, learning rate follows cosine function after warmup. 51 | """ 52 | def __init__(self, optimizer, warmup_steps, t_total, cycles=.5, last_epoch=-1): 53 | self.warmup_steps = warmup_steps 54 | self.t_total = t_total 55 | self.cycles = cycles 56 | super(WarmupCosineSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) 57 | 58 | def lr_lambda(self, step): 59 | if step < self.warmup_steps: 60 | return float(step) / float(max(1.0, self.warmup_steps)) 61 | # progress after warmup 62 | progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps)) 63 | return max(0.0, 0.5 * (1. + math.cos(math.pi * float(self.cycles) * 2.0 * progress))) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AA-Trans: Core attention aggregating transformer with information entropy selector for fine-grained visual classification 2 | >The task of fine-grained visual classification (FGVC) is to distinguish targets from subordinate classifications. Since fine-grained images have the inherent characteristic of large inter-class variances and small intra-class variances, it is considered an extremely difficult task. To resolve this problem, we redesigned an attention aggregating transformer (AA-Trans) to better capture minor differences among images by improving the ViT structure. Extensive experiments show that our proposed model structure can achieve a new state-of-the-art performance on several mainstream datasets. 3 | ## Contributions 4 | >1. We propose a reasonable transformer model for FGVC that can dynamically detect distinguished regions and effectively exploit the global and local information of images. 5 | >2. We present a core attention aggregator, which effectively mitigates the common information loss problem existing in the transformer layer. 6 | >3. We design an efficient key token auto-selector based on information entropy, which can efficiently select the tokens that contain critical information without introducing additional parameters (a generic sketch of the idea follows below). 7 | >4. We validate the effectiveness of our method on four fine-grained visual classification benchmark datasets. The results of the experiments illustrate that our proposed method achieves a new state-of-the-art performance.
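The snippet below is only a generic illustration of the idea behind an information-entropy token selector; it is not the implementation used in this repository (see the code under `models/` for that). It assumes attention weights of shape `(batch, heads, tokens, tokens)` and, purely as an example, keeps the `k` tokens whose attention distributions are most concentrated (lowest entropy):
```python
import torch

def select_tokens_by_entropy(attn, k, eps=1e-8):
    """Rank tokens by the Shannon entropy of their attention distributions.

    attn: (batch, heads, tokens, tokens) attention weights that sum to 1 over
    the last dimension; k: number of token indices to keep per sample.
    """
    p = attn.mean(dim=1)                             # average over heads -> (batch, tokens, tokens)
    entropy = -(p * (p + eps).log()).sum(dim=-1)     # per-token entropy -> (batch, tokens)
    _, idx = entropy.topk(k, dim=-1, largest=False)  # keep the k most focused (lowest-entropy) tokens
    return idx
```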
8 | ## Model 9 | ![model](model.png) 10 | ## Pretrained Weights 11 | Download Link: https://pan.baidu.com/s/1yHWdEun9H9Uim9zzZDITQg?pwd=knjy 12 | ## Environment Requirements 13 | ``` 14 | cxxfilt>=0.2.0 15 | tqdm>=4.28.1 16 | numpy>=1.15.3 17 | PyYAML>=5.1 18 | pytest>=3.5.1 19 | packaging>=14.0 20 | flake8>=3.7.9 21 | Sphinx>=3.0.3 22 | ``` 23 | ## Datasets 24 | ### Caltech-UCSD Birds-200-2011 (CUB-200-2011) 25 | Caltech-UCSD Birds-200-2011 (CUB-200-2011) is an extended version of the CUB-200 dataset, with roughly double the number of images per class and new part location annotations.<br>
26 | Download Link: https://www.vision.caltech.edu/datasets/cub_200_2011/ 27 | >1. Number of categories: 200 28 | >2. Number of images: 11,788 29 | >3. Annotations per image: 15 Part Locations, 312 Binary Attributes, 1 Bounding Box 30 | ### Stanford Dogs Dataset 31 | The Stanford Dogs dataset contains images of 120 breeds of dogs from around the world. It has been built using images and annotations from ImageNet for the task of fine-grained image categorization.<br>
32 | Download Link: http://vision.stanford.edu/aditya86/ImageNetDogs/ 33 | >1. Number of categories: 120 34 | >2. Number of images: 20,580 35 | >3. Annotations: Class labels, Bounding boxes 36 | ### NABirds Dataset 37 | NABirds is a dataset for fine-grained recognition, featuring 400 species of North America's birds.<br>
38 | Download Link: https://paperswithcode.com/dataset/nabirds 39 | ### IP102 Dataset 40 | The IP102 dataset contains more than 75,000 images belonging to 102 categories.<br>
41 | Download Link: https://www.aliyundrive.com/s/c5G9scSGyak 42 | ## Experiment Results 43 | ![table1](table1.png) 44 | ![table2](table2.png) 45 | ![table3](table3.png) 46 | ![table4](table4.png) 47 | 48 |
49 | 50 | 51 | -------------------------------------------------------------------------------- /apex/amp/opt.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import warnings 3 | 4 | from .scaler import LossScaler, master_params 5 | from ._amp_state import maybe_print 6 | 7 | import numpy as np 8 | 9 | class OptimWrapper(object): 10 | def __init__(self, optimizer, amp_handle, num_loss): 11 | self._optimizer = optimizer 12 | self._amp_handle = amp_handle 13 | self._num_loss = num_loss 14 | self._loss_idx = 0 15 | self._skip_next = [False] * num_loss 16 | self._loss_scaler = [LossScaler('dynamic') for _ in range(num_loss)] 17 | 18 | @contextlib.contextmanager 19 | def scale_loss(self, loss): 20 | if not self._amp_handle.is_active(): 21 | yield loss 22 | return 23 | 24 | # When there are multiple losses per-optimizer, we need 25 | # to save out current grad accumulation, since we won't be 26 | # able to unscale this particulare loss once the grads are 27 | # all mixed together. 28 | cached_grads = [] 29 | if self._loss_idx > 0: 30 | for p in master_params(self._optimizer): 31 | if p.grad is not None: 32 | cached_grads.append(p.grad.data.detach().clone()) 33 | else: 34 | cached_grads.append(None) 35 | self._optimizer.zero_grad() 36 | 37 | loss_scale = self._cur_loss_scaler().loss_scale() 38 | yield loss * loss_scale 39 | 40 | self._cur_loss_scaler().clear_overflow_state() 41 | self._cur_loss_scaler().unscale( 42 | master_params(self._optimizer), 43 | master_params(self._optimizer), 44 | loss_scale) 45 | self._skip_next[self._loss_idx] = self._cur_loss_scaler().update_scale() 46 | self._loss_idx += 1 47 | 48 | if len(cached_grads) > 0: 49 | for p, cached_grad in zip(master_params(self._optimizer), 50 | cached_grads): 51 | if cached_grad is not None: 52 | p.grad.data.add_(cached_grad) 53 | cached_grads = [] 54 | 55 | def _cur_loss_scaler(self): 56 | assert 0 <= self._loss_idx < self._num_loss 57 | return self._loss_scaler[self._loss_idx] 58 | 59 | def step(self, closure=None): 60 | if not self._amp_handle.is_active(): 61 | return self._optimizer.step(closure=closure) 62 | 63 | self._loss_idx = 0 64 | 65 | for group in self._optimizer.param_groups: 66 | for p in group['params']: 67 | self._amp_handle.remove_cache(p) 68 | 69 | if closure is not None: 70 | raise NotImplementedError( 71 | 'The `closure` argument is unsupported by the amp ' + 72 | 'optimizer wrapper.') 73 | if any(self._skip_next): 74 | maybe_print('Gradient overflow, skipping update') 75 | self._skip_next = [False] * self._num_loss 76 | else: 77 | return self._optimizer.step(closure=closure) 78 | 79 | # Forward any attribute lookups 80 | def __getattr__(self, attr): 81 | return getattr(self._optimizer, attr) 82 | 83 | # Forward all torch.optim.Optimizer methods 84 | def __getstate__(self): 85 | return self._optimizer.__getstate__() 86 | 87 | def __setstate__(self): 88 | return self._optimizer.__setstate__() 89 | 90 | def __repr__(self): 91 | return self._optimizer.__repr__() 92 | 93 | def state_dict(self): 94 | return self._optimizer.state_dict() 95 | 96 | def load_state_dict(self, state_dict): 97 | return self._optimizer.load_state_dict(state_dict) 98 | 99 | def zero_grad(self): 100 | return self._optimizer.zero_grad() 101 | 102 | def add_param_group(self, param_group): 103 | return self._optimizer.add_param_group(param_group) 104 | -------------------------------------------------------------------------------- /apex/parallel/__init__.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | 3 | if hasattr(torch.distributed, 'ReduceOp'): 4 | ReduceOp = torch.distributed.ReduceOp 5 | elif hasattr(torch.distributed, 'reduce_op'): 6 | ReduceOp = torch.distributed.reduce_op 7 | else: 8 | ReduceOp = torch.distributed.deprecated.reduce_op 9 | 10 | from .distributed import DistributedDataParallel, Reducer 11 | # This is tricky because I'd like SyncBatchNorm to be exposed the same way 12 | # for both the cuda-enabled and python-fallback versions, and I don't want 13 | # to suppress the error information. 14 | try: 15 | import syncbn 16 | from .optimized_sync_batchnorm import SyncBatchNorm 17 | except ImportError as err: 18 | from .sync_batchnorm import SyncBatchNorm 19 | SyncBatchNorm.syncbn_import_error = err 20 | 21 | def convert_syncbn_model(module, process_group=None, channel_last=False): 22 | ''' 23 | Recursively traverse module and its children to replace all instances of 24 | ``torch.nn.modules.batchnorm._BatchNorm`` with :class:`apex.parallel.SyncBatchNorm`. 25 | 26 | All ``torch.nn.BatchNorm*N*d`` wrap around 27 | ``torch.nn.modules.batchnorm._BatchNorm``, so this function lets you easily switch 28 | to use sync BN. 29 | 30 | Args: 31 | module (torch.nn.Module): input module 32 | 33 | Example:: 34 | 35 | >>> # model is an instance of torch.nn.Module 36 | >>> import apex 37 | >>> sync_bn_model = apex.parallel.convert_syncbn_model(model) 38 | ''' 39 | mod = module 40 | if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm): 41 | return module 42 | if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): 43 | mod = SyncBatchNorm(module.num_features, module.eps, module.momentum, module.affine, module.track_running_stats, process_group, channel_last=channel_last) 44 | mod.running_mean = module.running_mean 45 | mod.running_var = module.running_var 46 | mod.num_batches_tracked = module.num_batches_tracked 47 | if module.affine: 48 | mod.weight.data = module.weight.data.clone().detach() 49 | mod.bias.data = module.bias.data.clone().detach() 50 | for name, child in module.named_children(): 51 | mod.add_module(name, convert_syncbn_model(child, 52 | process_group=process_group, 53 | channel_last=channel_last)) 54 | # TODO(jie) should I delete model explicitly? 55 | del module 56 | return mod 57 | 58 | def create_syncbn_process_group(group_size): 59 | ''' 60 | Creates process groups to be used for syncbn of a give ``group_size`` and returns 61 | process group that current GPU participates in. 62 | 63 | ``group_size`` must divide the total number of GPUs (world_size). 64 | 65 | ``group_size`` of 0 would be considered as =world_size. In this case ``None`` will be returned. 66 | 67 | ``group_size`` of 1 would be equivalent to using non-sync bn, but will still carry the overhead. 
68 | 69 | Args: 70 | group_size (int): number of GPU's to collaborate for sync bn 71 | 72 | Example:: 73 | 74 | >>> # model is an instance of torch.nn.Module 75 | >>> import apex 76 | >>> group = apex.parallel.create_syncbn_process_group(group_size) 77 | ''' 78 | 79 | if group_size==0: 80 | return None 81 | 82 | world_size = torch.distributed.get_world_size() 83 | assert(world_size >= group_size) 84 | assert(world_size % group_size == 0) 85 | 86 | group=None 87 | for group_num in (range(world_size//group_size)): 88 | group_ids = range(group_num*group_size, (group_num+1)*group_size) 89 | cur_group = torch.distributed.new_group(ranks=group_ids) 90 | if (torch.distributed.get_rank()//group_size == group_num): 91 | group = cur_group 92 | #can not drop out and return here, every process must go through creation of all subgroups 93 | 94 | assert(group is not None) 95 | return group 96 | -------------------------------------------------------------------------------- /apex/parallel/sync_batchnorm_kernel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd.function import Function 3 | 4 | from apex.parallel import ReduceOp 5 | 6 | 7 | class SyncBatchnormFunction(Function): 8 | 9 | @staticmethod 10 | def forward(ctx, input, weight, bias, running_mean, running_variance, eps, process_group, world_size): 11 | torch.cuda.nvtx.range_push("sync_BN_fw") 12 | # transpose it to channel last to support broadcasting for input with different rank 13 | c_last_input = input.transpose(1, -1).contiguous().clone() 14 | 15 | ctx.save_for_backward(c_last_input, weight, bias, 16 | running_mean, running_variance) 17 | ctx.eps = eps 18 | ctx.process_group = process_group 19 | ctx.world_size = world_size 20 | 21 | c_last_input = (c_last_input - running_mean) / \ 22 | torch.sqrt(running_variance + eps) 23 | 24 | if weight is not None: 25 | c_last_input = c_last_input * weight 26 | if bias is not None: 27 | c_last_input = c_last_input + bias 28 | 29 | torch.cuda.nvtx.range_pop() 30 | return c_last_input.transpose(1, -1).contiguous().clone() 31 | 32 | @staticmethod 33 | def backward(ctx, grad_output): 34 | torch.cuda.nvtx.range_push("sync_BN_bw") 35 | # mini batch mean & var are calculated by forward path. 36 | # mu = 1./N*np.sum(h, axis = 0) 37 | # var = 1./N*np.sum((h-mu)**2, axis = 0) 38 | c_last_input, weight, bias, running_mean, running_variance = ctx.saved_tensors 39 | 40 | eps = ctx.eps 41 | process_group = ctx.process_group 42 | world_size = ctx.world_size 43 | grad_input = grad_weight = grad_bias = None 44 | num_features = running_mean.size()[0] 45 | 46 | # transpose it to channel last to support broadcasting for input with different rank 47 | torch.cuda.nvtx.range_push("carilli field") 48 | c_last_grad = grad_output.transpose(1, -1).contiguous() 49 | # squash non-channel dimension so we can easily calculate mean 50 | c_grad = c_last_grad.view(-1, num_features).contiguous() 51 | torch.cuda.nvtx.range_pop() 52 | 53 | # calculate grad_input 54 | if ctx.needs_input_grad[0]: 55 | # dh = gamma * (var + eps)**(-1. / 2.) 
* (dy - np.mean(dy, axis=0) 56 | # - (h - mu) * (var + eps)**(-1.0) * np.mean(dy * (h - mu), axis=0)) 57 | mean_dy = c_grad.mean(0) 58 | mean_dy_xmu = (c_last_grad * (c_last_input - 59 | running_mean)).view(-1, num_features).mean(0) 60 | if torch.distributed.is_initialized(): 61 | torch.distributed.all_reduce( 62 | mean_dy, ReduceOp.SUM, process_group) 63 | mean_dy = mean_dy / world_size 64 | torch.distributed.all_reduce( 65 | mean_dy_xmu, ReduceOp.SUM, process_group) 66 | mean_dy_xmu = mean_dy_xmu / world_size 67 | c_last_grad_input = (c_last_grad - mean_dy - (c_last_input - running_mean) / ( 68 | running_variance + eps) * mean_dy_xmu) / torch.sqrt(running_variance + eps) 69 | if weight is not None: 70 | c_last_grad_input.mul_(weight) 71 | grad_input = c_last_grad_input.transpose(1, -1).contiguous() 72 | 73 | # calculate grad_weight 74 | grad_weight = None 75 | if weight is not None and ctx.needs_input_grad[1]: 76 | # dgamma = np.sum((h - mu) * (var + eps)**(-1. / 2.) * dy, axis=0) 77 | grad_weight = ((c_last_input - running_mean) / torch.sqrt( 78 | running_variance + eps) * c_last_grad).view(-1, num_features).sum(0) 79 | 80 | # calculate grad_bias 81 | grad_bias = None 82 | if bias is not None and ctx.needs_input_grad[2]: 83 | # dbeta = np.sum(dy, axis=0) 84 | grad_bias = c_grad.sum(0) 85 | 86 | torch.cuda.nvtx.range_pop() 87 | return grad_input, grad_weight, grad_bias, None, None, None, None, None 88 | -------------------------------------------------------------------------------- /apex/parallel/LARC.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn.parameter import Parameter 4 | 5 | class LARC(object): 6 | """ 7 | :class:`LARC` is a pytorch implementation of both the scaling and clipping variants of LARC, 8 | in which the ratio between gradient and parameter magnitudes is used to calculate an adaptive 9 | local learning rate for each individual parameter. The algorithm is designed to improve 10 | convergence of large batch training. 11 | 12 | See https://arxiv.org/abs/1708.03888 for calculation of the local learning rate. 13 | 14 | In practice it modifies the gradients of parameters as a proxy for modifying the learning rate 15 | of the parameters. This design allows it to be used as a wrapper around any torch.optim Optimizer. 16 | 17 | ``` 18 | model = ... 19 | optim = torch.optim.Adam(model.parameters(), lr=...) 20 | optim = LARC(optim) 21 | ``` 22 | 23 | It can even be used in conjunction with apex.fp16_utils.FP16_optimizer. 24 | 25 | ``` 26 | model = ... 27 | optim = torch.optim.Adam(model.parameters(), lr=...) 28 | optim = LARC(optim) 29 | optim = apex.fp16_utils.FP16_Optimizer(optim) 30 | ``` 31 | 32 | Args: 33 | optimizer: Pytorch optimizer to wrap and modify learning rate for. 34 | trust_coefficient: Trust coefficient for calculating the lr. See https://arxiv.org/abs/1708.03888 35 | clip: Decides between clipping or scaling mode of LARC. If `clip=True` the learning rate is set to `min(optimizer_lr, local_lr)` for each parameter. If `clip=False` the learning rate is set to `local_lr*optimizer_lr`. 
36 | eps: epsilon kludge to help with numerical stability while calculating adaptive_lr 37 | """ 38 | 39 | def __init__(self, optimizer, trust_coefficient=0.02, clip=True, eps=1e-8): 40 | self.optim = optimizer 41 | self.trust_coefficient = trust_coefficient 42 | self.eps = eps 43 | self.clip = clip 44 | 45 | def __getstate__(self): 46 | return self.optim.__getstate__() 47 | 48 | def __setstate__(self, state): 49 | self.optim.__setstate__(state) 50 | 51 | @property 52 | def state(self): 53 | return self.optim.state 54 | 55 | def __repr__(self): 56 | return self.optim.__repr__() 57 | 58 | @property 59 | def param_groups(self): 60 | return self.optim.param_groups 61 | 62 | @param_groups.setter 63 | def param_groups(self, value): 64 | self.optim.param_groups = value 65 | 66 | def state_dict(self): 67 | return self.optim.state_dict() 68 | 69 | def load_state_dict(self, state_dict): 70 | self.optim.load_state_dict(state_dict) 71 | 72 | def zero_grad(self): 73 | self.optim.zero_grad() 74 | 75 | def add_param_group(self, param_group): 76 | self.optim.add_param_group( param_group) 77 | 78 | def step(self): 79 | with torch.no_grad(): 80 | weight_decays = [] 81 | for group in self.optim.param_groups: 82 | # absorb weight decay control from optimizer 83 | weight_decay = group['weight_decay'] if 'weight_decay' in group else 0 84 | weight_decays.append(weight_decay) 85 | group['weight_decay'] = 0 86 | for p in group['params']: 87 | if p.grad is None: 88 | continue 89 | param_norm = torch.norm(p.data) 90 | grad_norm = torch.norm(p.grad.data) 91 | 92 | if param_norm != 0 and grad_norm != 0: 93 | # calculate adaptive lr + weight decay 94 | adaptive_lr = self.trust_coefficient * (param_norm) / (grad_norm + param_norm * weight_decay + self.eps) 95 | 96 | # clip learning rate for LARC 97 | if self.clip: 98 | # calculation of adaptive_lr so that when multiplied by lr it equals `min(adaptive_lr, lr)` 99 | adaptive_lr = min(adaptive_lr/group['lr'], 1) 100 | 101 | p.grad.data += weight_decay * p.data 102 | p.grad.data *= adaptive_lr 103 | 104 | self.optim.step() 105 | # return weight decay control to optimizer 106 | for i, group in enumerate(self.optim.param_groups): 107 | group['weight_decay'] = weight_decays[i] 108 | -------------------------------------------------------------------------------- /apex/parallel/optimized_sync_batchnorm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.modules.batchnorm import _BatchNorm 3 | from torch.nn import functional as F 4 | 5 | import syncbn 6 | from .optimized_sync_batchnorm_kernel import SyncBatchnormFunction 7 | 8 | 9 | class SyncBatchNorm(_BatchNorm): 10 | """ 11 | synchronized batch normalization module extented from `torch.nn.BatchNormNd` 12 | with the added stats reduction across multiple processes. 13 | :class:`apex.parallel.SyncBatchNorm` is designed to work with 14 | `DistributedDataParallel`. 15 | 16 | When running in training mode, the layer reduces stats across all processes 17 | to increase the effective batchsize for normalization layer. This is useful 18 | in applications where batch size is small on a given process that would 19 | diminish converged accuracy of the model. The model uses collective 20 | communication package from `torch.distributed`. 
21 | 22 | When running in evaluation mode, the layer falls back to 23 | `torch.nn.functional.batch_norm` 24 | 25 | Args: 26 | num_features: :math:`C` from an expected input of size 27 | :math:`(N, C, L)` or :math:`L` from input of size :math:`(N, L)` 28 | eps: a value added to the denominator for numerical stability. 29 | Default: 1e-5 30 | momentum: the value used for the running_mean and running_var 31 | computation. Can be set to ``None`` for cumulative moving average 32 | (i.e. simple average). Default: 0.1 33 | affine: a boolean value that when set to ``True``, this module has 34 | learnable affine parameters. Default: ``True`` 35 | track_running_stats: a boolean value that when set to ``True``, this 36 | module tracks the running mean and variance, and when set to ``False``, 37 | this module does not track such statistics and always uses batch 38 | statistics in both training and eval modes. Default: ``True`` 39 | process_group: pass in a process group within which the stats of the 40 | mini-batch is being synchronized. ``None`` for using default process 41 | group 42 | channel_last: a boolean value that when set to ``True``, this module 43 | take the last dimension of the input tensor to be the channel 44 | dimension. Default: False 45 | 46 | Examples:: 47 | >>> # channel first tensor 48 | >>> sbn = apex.parallel.SyncBatchNorm(100).cuda() 49 | >>> inp = torch.randn(10, 100, 14, 14).cuda() 50 | >>> out = sbn(inp) 51 | >>> inp = torch.randn(3, 100, 20).cuda() 52 | >>> out = sbn(inp) 53 | >>> # channel last tensor 54 | >>> sbn = apex.parallel.SyncBatchNorm(100, channel_last=True).cuda() 55 | >>> inp = torch.randn(10, 14, 14, 100).cuda() 56 | """ 57 | 58 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, process_group=None, channel_last=False, fuse_relu=False): 59 | super(SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats) 60 | self.process_group = process_group 61 | self.channel_last = channel_last 62 | self.fuse_relu = fuse_relu 63 | 64 | def _specify_process_group(self, process_group): 65 | self.process_group = process_group 66 | 67 | def _specify_channel_last(self, channel_last): 68 | self.channel_last = channel_last 69 | 70 | def forward(self, input, z = None): 71 | # if input.dim() == 2, we switch to channel_last for efficient memory accessing 72 | channel_last = self.channel_last if input.dim() != 2 else True 73 | 74 | if not self.training and self.track_running_stats and not channel_last and not self.fuse_relu and z == None: 75 | # fall back to pytorch implementation for inference 76 | return F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, False, 0.0, self.eps) 77 | else: 78 | exponential_average_factor = 0.0 79 | if self.training and self.track_running_stats: 80 | self.num_batches_tracked += 1 81 | if self.momentum is None: 82 | exponential_average_factor = 1.0 / float(self.num_batches_tracked) 83 | else: 84 | exponential_average_factor = self.momentum 85 | return SyncBatchnormFunction.apply(input, z, self.weight, self.bias, self.running_mean, self.running_var, self.eps, self.training or not self.track_running_stats, exponential_average_factor, self.process_group, channel_last, self.fuse_relu) 86 | -------------------------------------------------------------------------------- /utils/datasets.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from 
torch.utils.data import Dataset 4 | from torch.utils.data.sampler import BatchSampler 5 | from PIL import Image 6 | 7 | 8 | def default_loader(path): 9 | try: 10 | img = Image.open(path).convert('RGB') 11 | except: 12 | with open('read_error.txt', 'a') as fid: 13 | fid.write(path + '\n') 14 | return Image.new('RGB', (224, 224), 'white') 15 | return img 16 | 17 | 18 | class RandomDataset(Dataset): 19 | def __init__(self, transform=None, dataloader=default_loader): 20 | self.transform = transform 21 | self.dataloader = dataloader 22 | 23 | with open('/home/samuel/datasets/CUB_200_2011/val.txt', 'r') as fid: 24 | self.imglist = fid.readlines() 25 | 26 | def __getitem__(self, index): 27 | image_name, label = self.imglist[index].strip().split() 28 | image_path = '/home/samuel/datasets/CUB_200_2011/images/{}'.format(image_name) 29 | img = self.dataloader(image_path) 30 | img = self.transform(img) 31 | label = int(label) 32 | label = torch.LongTensor([label]) 33 | 34 | return [img, label] 35 | 36 | def __len__(self): 37 | return len(self.imglist) 38 | 39 | 40 | class BatchDataset(Dataset): 41 | def __init__(self, transform=None, dataloader=default_loader): 42 | self.transform = transform 43 | self.dataloader = dataloader 44 | 45 | # 打开train.txt, 读行 46 | with open('/home/samuel/datasets/CUB_200_2011/train.txt', 'r') as fid: 47 | self.imglist = fid.readlines() 48 | 49 | self.labels = [] 50 | # 每行读路径和标签 51 | for line in self.imglist: 52 | image_path, label = line.strip().split() 53 | self.labels.append(int(label)) 54 | self.labels = np.array(self.labels) 55 | self.labels = torch.LongTensor(self.labels) 56 | 57 | def __getitem__(self, index): 58 | # 每行读路径和标签 59 | image_name, label = self.imglist[index].strip().split() 60 | image_path = '/home/samuel/datasets/CUB_200_2011/images/{}'.format(image_name) 61 | img = self.dataloader(image_path) 62 | img = self.transform(img) 63 | label = int(label) 64 | label = torch.LongTensor([label]) 65 | 66 | return [img, label] 67 | 68 | def __len__(self): 69 | return len(self.imglist) 70 | 71 | 72 | class BalancedBatchSampler(BatchSampler): 73 | def __init__(self, dataset, n_classes, n_samples): 74 | # dataset这里指train_datasets, 75 | # labels:即之前求得的tensor([ 1, 1, 1, ..., 200, 200, 200]) 76 | self.labels = dataset.labels 77 | # 转化为list形式 78 | self.labels_set = list(set(self.labels.numpy())) 79 | # 这个操作大致是建立标签和序列对应的字典 80 | self.label_to_indices = {label: np.where(self.labels.numpy() == label)[0] 81 | for label in self.labels_set} 82 | print(self.label_to_indices) 83 | 84 | for l in self.labels_set: 85 | np.random.shuffle(self.label_to_indices[l]) 86 | 87 | self.used_label_indices_count = {label: 0 for label in self.labels_set} 88 | 89 | self.count = 0 90 | self.n_classes = n_classes 91 | self.n_samples = n_samples 92 | self.dataset = dataset 93 | # self.batch_size = 2 94 | self.batch_size = self.n_samples * self.n_classes 95 | 96 | def __iter__(self): 97 | self.count = 0 98 | while self.count + self.batch_size < len(self.dataset): 99 | # np.random.choice(a, size=None, replace=True, p=None) 100 | # 从a(只要是ndarray都可以,但必须是一维)中随机抽取数字,并组成指定大小(size)的数组 101 | # replace:True表示可以取相同数字,False表示不可以取相同数字 102 | # 数组p:与数组a相对,表示取数组a中每个元素的概率 103 | classes = np.random.choice(self.labels_set, self.n_classes, replace=False) 104 | indices = [] 105 | for class_ in classes: 106 | indices.extend(self.label_to_indices[class_][ 107 | self.used_label_indices_count[class_]:self.used_label_indices_count[ 108 | class_] + self.n_samples]) 109 | self.used_label_indices_count[class_] += self.n_samples 
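# if the next draw would overrun this class's index pool, the check below
# reshuffles the pool and resets its cursor so balanced sampling can continue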
110 | if self.used_label_indices_count[class_] + self.n_samples > len(self.label_to_indices[class_]): 111 | np.random.shuffle(self.label_to_indices[class_]) 112 | self.used_label_indices_count[class_] = 0 113 | yield indices 114 | self.count += self.n_classes * self.n_samples 115 | 116 | def __len__(self): 117 | return len(self.dataset) // self.batch_size 118 | -------------------------------------------------------------------------------- /apex/parallel/optimized_sync_batchnorm_kernel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd.function import Function 3 | 4 | import syncbn 5 | from apex.parallel import ReduceOp 6 | 7 | class SyncBatchnormFunction(Function): 8 | 9 | @staticmethod 10 | def forward(ctx, input, z, weight, bias, running_mean, running_variance, eps, track_running_stats = True, momentum = 1.0, process_group = None, channel_last = False, fuse_relu = False): 11 | input = input.contiguous() 12 | world_size = 0 13 | 14 | mean = None 15 | var_biased = None 16 | inv_std = None 17 | var = None 18 | out = None 19 | count = None 20 | if track_running_stats: 21 | if channel_last: 22 | count = int(input.numel()/input.size(-1)) 23 | mean, var_biased = syncbn.welford_mean_var_c_last(input) 24 | num_channels = input.size(-1) 25 | else: 26 | count = int(input.numel()/input.size(1)) 27 | mean, var_biased = syncbn.welford_mean_var(input) 28 | num_channels = input.size(1) 29 | 30 | if torch.distributed.is_initialized(): 31 | if not process_group: 32 | process_group = torch.distributed.group.WORLD 33 | device = mean.device 34 | world_size = torch.distributed.get_world_size(process_group) 35 | 36 | count_t = torch.empty(1, dtype=mean.dtype, device=mean.device).fill_(count) 37 | combined = torch.cat([mean.view(-1), var_biased.view(-1), count_t], dim=0) 38 | combined_list = [torch.empty_like(combined) for k in range(world_size)] 39 | torch.distributed.all_gather(combined_list, combined, process_group) 40 | combined = torch.stack(combined_list, dim=0) 41 | mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1) 42 | count_all = count_all.view(-1) 43 | mean, var, inv_std = syncbn.welford_parallel(mean_all, invstd_all, count_all.to(torch.int32), eps) 44 | else: 45 | device = mean.device 46 | count_all = torch.cuda.IntTensor([count], device=device) 47 | inv_std = 1.0 / torch.sqrt(var_biased + eps) 48 | var = var_biased * (count) / (count-1) 49 | 50 | if count == 1 and world_size < 2: 51 | raise ValueError('Expected more than 1 value per channel when training, got input size{}'.format(input.size())) 52 | 53 | r_m_inc = mean if running_mean.dtype != torch.float16 else mean.half() 54 | r_v_inc = var if running_variance.dtype != torch.float16 else var.half() 55 | running_mean.data = running_mean.data * (1-momentum) + momentum*r_m_inc 56 | running_variance.data = running_variance.data * (1-momentum) + momentum*r_v_inc 57 | else: 58 | mean = running_mean.data 59 | inv_std = 1.0 / torch.sqrt(running_variance.data + eps) 60 | 61 | ctx.save_for_backward(input, weight, mean, inv_std, z, bias, count_all.to(torch.int32)) 62 | ctx.process_group = process_group 63 | ctx.channel_last = channel_last 64 | ctx.world_size = world_size 65 | ctx.fuse_relu = fuse_relu 66 | 67 | if channel_last: 68 | out = syncbn.batchnorm_forward_c_last(input, z, mean, inv_std, weight, bias, fuse_relu) 69 | else: 70 | out = syncbn.batchnorm_forward(input, mean, inv_std, weight, bias) 71 | 72 | return out 73 | 74 | @staticmethod 75 | def 
backward(ctx, grad_output): 76 | grad_output = grad_output.contiguous() 77 | # mini batch mean & var are calculated by forward path. 78 | # mu = 1./N*np.sum(h, axis = 0) 79 | # var = 1./N*np.sum((h-mu)**2, axis = 0) 80 | saved_input, weight, mean, inv_std, z, bias, count = ctx.saved_tensors 81 | process_group = ctx.process_group 82 | channel_last = ctx.channel_last 83 | world_size = ctx.world_size 84 | fuse_relu = ctx.fuse_relu 85 | grad_input = grad_z = grad_weight = grad_bias = None 86 | 87 | if fuse_relu: 88 | grad_output = syncbn.relu_bw_c_last(grad_output, saved_input, z, mean, inv_std, weight, bias) 89 | if isinstance(z, torch.Tensor) and ctx.needs_input_grad[1]: 90 | grad_z = grad_output.clone() 91 | 92 | # TODO: update kernel to not pre_divide by item_num 93 | if channel_last: 94 | sum_dy, sum_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn_c_last(grad_output, saved_input, mean, inv_std, weight) 95 | else: 96 | sum_dy, sum_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output, saved_input, mean, inv_std, weight) 97 | 98 | # calculate grad_input 99 | if ctx.needs_input_grad[0]: 100 | 101 | if torch.distributed.is_initialized(): 102 | num_channels = sum_dy.shape[0] 103 | combined = torch.cat([sum_dy, sum_dy_xmu], dim=0) 104 | torch.distributed.all_reduce( 105 | combined, torch.distributed.ReduceOp.SUM, process_group, async_op=False) 106 | sum_dy, sum_dy_xmu = torch.split(combined, num_channels) 107 | 108 | if channel_last: 109 | grad_input = syncbn.batchnorm_backward_c_last(grad_output, saved_input, mean, inv_std, weight, sum_dy, sum_dy_xmu, count) 110 | else: 111 | grad_input = syncbn.batchnorm_backward(grad_output, saved_input, mean, inv_std, weight, sum_dy, sum_dy_xmu, count) 112 | 113 | if weight is None or not ctx.needs_input_grad[2]: 114 | grad_weight = None 115 | 116 | if weight is None or not ctx.needs_input_grad[3]: 117 | grad_bias = None 118 | 119 | return grad_input, grad_z, grad_weight, grad_bias, None, None, None, None, None, None, None, None 120 | -------------------------------------------------------------------------------- /apex/parallel/sync_batchnorm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.modules.batchnorm import _BatchNorm 3 | from torch.nn import functional as F 4 | 5 | from .sync_batchnorm_kernel import SyncBatchnormFunction 6 | from apex.parallel import ReduceOp 7 | 8 | 9 | class SyncBatchNorm(_BatchNorm): 10 | """ 11 | synchronized batch normalization module extented from ``torch.nn.BatchNormNd`` 12 | with the added stats reduction across multiple processes. 13 | :class:`apex.parallel.SyncBatchNorm` is designed to work with 14 | ``DistributedDataParallel``. 15 | 16 | When running in training mode, the layer reduces stats across all processes 17 | to increase the effective batchsize for normalization layer. This is useful 18 | in applications where batch size is small on a given process that would 19 | diminish converged accuracy of the model. The model uses collective 20 | communication package from ``torch.distributed``. 21 | 22 | When running in evaluation mode, the layer falls back to 23 | ``torch.nn.functional.batch_norm``. 24 | 25 | Args: 26 | num_features: :math:`C` from an expected input of size 27 | :math:`(N, C, L)` or :math:`L` from input of size :math:`(N, L)` 28 | eps: a value added to the denominator for numerical stability. 29 | Default: 1e-5 30 | momentum: the value used for the running_mean and running_var 31 | computation. 
Can be set to ``None`` for cumulative moving average 32 | (i.e. simple average). Default: 0.1 33 | affine: a boolean value that when set to ``True``, this module has 34 | learnable affine parameters. Default: ``True`` 35 | track_running_stats: a boolean value that when set to ``True``, this 36 | module tracks the running mean and variance, and when set to ``False``, 37 | this module does not track such statistics and always uses batch 38 | statistics in both training and eval modes. Default: ``True`` 39 | 40 | Example:: 41 | 42 | >>> sbn = apex.parallel.SyncBatchNorm(100).cuda() 43 | >>> inp = torch.randn(10, 100, 14, 14).cuda() 44 | >>> out = sbn(inp) 45 | >>> inp = torch.randn(3, 100, 20).cuda() 46 | >>> out = sbn(inp) 47 | """ 48 | 49 | warned = False 50 | 51 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, process_group=None, channel_last=False): 52 | if channel_last == True: 53 | raise AttributeError("channel_last is not supported by primitive SyncBatchNorm implementation. Try install apex with `--cuda_ext` if channel_last is desired.") 54 | 55 | if not SyncBatchNorm.warned: 56 | if hasattr(self, "syncbn_import_error"): 57 | print("Warning: using Python fallback for SyncBatchNorm, possibly because apex was installed without --cuda_ext. The exception raised when attempting to import the cuda backend was: ", self.syncbn_import_error) 58 | else: 59 | print("Warning: using Python fallback for SyncBatchNorm") 60 | SyncBatchNorm.warned = True 61 | 62 | super(SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats) 63 | self.process_group = process_group 64 | 65 | def _specify_process_group(self, process_group): 66 | self.process_group = process_group 67 | 68 | def forward(self, input): 69 | torch.cuda.nvtx.range_push("sync_bn_fw_with_mean_var") 70 | mean = None 71 | var = None 72 | cast = None 73 | out = None 74 | 75 | # casting to handle mismatch input type to layer type 76 | if self.running_mean is not None: 77 | if self.running_mean.dtype != input.dtype: 78 | input = input.to(self.running_mean.dtype) 79 | cast = input.dtype 80 | elif self.weight is not None: 81 | if self.weight.dtype != input.dtype: 82 | input = input.to(self.weight.dtype) 83 | cast = input.dtype 84 | 85 | if not self.training and self.track_running_stats: 86 | # fall back to pytorch implementation for inference 87 | torch.cuda.nvtx.range_pop() 88 | out = F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, False, 0.0, self.eps) 89 | else: 90 | process_group = self.process_group 91 | world_size = 1 92 | if not self.process_group: 93 | process_group = torch.distributed.group.WORLD 94 | self.num_batches_tracked += 1 95 | with torch.no_grad(): 96 | channel_first_input = input.transpose(0, 1).contiguous() 97 | squashed_input_tensor_view = channel_first_input.view( 98 | channel_first_input.size(0), -1) 99 | # total number of data points for each variance entry. 
Used to calculate unbiased variance estimate 100 | m = None 101 | local_m = float(squashed_input_tensor_view.size()[1]) 102 | local_mean = torch.mean(squashed_input_tensor_view, 1) 103 | local_sqr_mean = torch.pow( 104 | squashed_input_tensor_view, 2).mean(1) 105 | if torch.distributed.is_initialized(): 106 | world_size = torch.distributed.get_world_size(process_group) 107 | torch.distributed.all_reduce( 108 | local_mean, ReduceOp.SUM, process_group) 109 | mean = local_mean / world_size 110 | torch.distributed.all_reduce( 111 | local_sqr_mean, ReduceOp.SUM, process_group) 112 | sqr_mean = local_sqr_mean / world_size 113 | m = local_m * world_size 114 | else: 115 | m = local_m 116 | mean = local_mean 117 | sqr_mean = local_sqr_mean 118 | # var(x) = E (( x - mean_x ) ** 2) 119 | # = 1 / N * sum ( x - mean_x ) ** 2 120 | # = 1 / N * sum (x**2) - mean_x**2 121 | var = sqr_mean - mean.pow(2) 122 | 123 | if self.running_mean is not None: 124 | self.running_mean = self.momentum * mean + \ 125 | (1 - self.momentum) * self.running_mean 126 | if self.running_var is not None: 127 | # as noted by the paper, we used unbiased variance estimate of the mini-batch 128 | # Var[x] = m / (m-1) * Eb (sample_variance) 129 | self.running_var = m / \ 130 | (m-1) * self.momentum * var + \ 131 | (1 - self.momentum) * self.running_var 132 | torch.cuda.nvtx.range_pop() 133 | out = SyncBatchnormFunction.apply(input, self.weight, self.bias, mean, var, self.eps, process_group, world_size) 134 | return out.to(cast) 135 | -------------------------------------------------------------------------------- /apex/amp/amp.py: -------------------------------------------------------------------------------- 1 | from . import compat, rnn_compat, utils, wrap 2 | from .handle import AmpHandle, NoOpHandle 3 | from .lists import functional_overrides, torch_overrides, tensor_overrides 4 | from ._amp_state import _amp_state 5 | from .frontend import * 6 | 7 | import functools 8 | import itertools 9 | 10 | import torch 11 | 12 | 13 | _DECORATOR_HANDLE = None 14 | _USER_CAST_REGISTRY = set() 15 | _USER_PROMOTE_REGISTRY = set() 16 | 17 | 18 | def _decorator_helper(orig_fn, cast_fn, wrap_fn): 19 | def wrapper(*args, **kwargs): 20 | handle = _DECORATOR_HANDLE 21 | if handle is None or not handle.is_active(): 22 | return orig_fn(*args, **kwargs) 23 | inner_cast_fn = utils.verbosify(cast_fn, orig_fn.__name__, 24 | handle.verbose) 25 | return wrap_fn(orig_fn, inner_cast_fn, handle)(*args, **kwargs) 26 | return wrapper 27 | 28 | 29 | # Decorator form 30 | def half_function(fn): 31 | wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=True) 32 | return _decorator_helper(fn, utils.maybe_half, wrap_fn) 33 | 34 | 35 | def float_function(fn): 36 | wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=False) 37 | return _decorator_helper(fn, utils.maybe_float, wrap_fn) 38 | 39 | 40 | def promote_function(fn): 41 | wrap_fn = functools.partial(wrap.make_promote_wrapper) 42 | return _decorator_helper(fn, utils.maybe_float, wrap_fn) 43 | 44 | 45 | # Registry form 46 | def register_half_function(module, name): 47 | if not hasattr(module, name): 48 | raise ValueError('No function named {} in module {}.'.format( 49 | name, module)) 50 | _USER_CAST_REGISTRY.add((module, name, utils.maybe_half)) 51 | 52 | 53 | def register_float_function(module, name): 54 | if not hasattr(module, name): 55 | raise ValueError('No function named {} in module {}.'.format( 56 | name, module)) 57 | _USER_CAST_REGISTRY.add((module, name, utils.maybe_float)) 58 | 
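# Usage sketch (names are illustrative, not part of apex): user code can pin the
# precision of its own functions before calling amp.init(), e.g.
#   from apex import amp
#   amp.register_half_function(my_module, 'fused_op')          # always run in fp16
#   amp.register_float_function(my_module, 'sensitive_op')     # keep in fp32
# init() below then wraps everything in these registries along with the built-in
# FP16/FP32 function lists.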
59 | 60 | def register_promote_function(module, name): 61 | if not hasattr(module, name): 62 | raise ValueError('No function named {} in module {}.'.format( 63 | name, module)) 64 | _USER_PROMOTE_REGISTRY.add((module, name)) 65 | 66 | 67 | # Top-level function to insert _all_ the hooks. 68 | def init(enabled=True, loss_scale="dynamic", enable_caching=True, verbose=False, allow_banned=False): 69 | global _DECORATOR_HANDLE 70 | 71 | if not enabled: 72 | handle = NoOpHandle() 73 | _DECORATOR_HANDLE = handle 74 | return handle 75 | 76 | handle = AmpHandle(loss_scale, enable_caching, verbose) 77 | 78 | # 0) Force-{fp16, fp32} for user-annotated functions 79 | for mod, fn, cast_fn in _USER_CAST_REGISTRY: 80 | try_caching = (cast_fn == utils.maybe_half) 81 | wrap.cached_cast(mod, fn, cast_fn, handle, 82 | try_caching, verbose) 83 | _USER_CAST_REGISTRY.clear() 84 | 85 | # 0.5) Force-promote for user-annotated functions 86 | for mod, fn in _USER_PROMOTE_REGISTRY: 87 | wrap.promote(mod, fn, handle, verbose) 88 | _USER_PROMOTE_REGISTRY.clear() 89 | 90 | # 1) Force-{fp16, fp32} on white- / black-list functions 91 | override_modules = [functional_overrides, 92 | torch_overrides, 93 | tensor_overrides] 94 | cast_table = [('FP16_FUNCS', utils.maybe_half), 95 | ('FP32_FUNCS', utils.maybe_float)] 96 | for module, (list_name, cast_fn) in itertools.product(override_modules, 97 | cast_table): 98 | for fn in getattr(module, list_name): 99 | try_caching = (cast_fn == utils.maybe_half) 100 | wrap.cached_cast(module.MODULE, fn, cast_fn, handle, 101 | try_caching, verbose) 102 | 103 | # 1.5) Pre-0.4, put the blacklist methods on HalfTensor and whitelist 104 | # methods on FloatTensor, since they're distinct types. 105 | if compat.tensor_is_float_tensor(): 106 | for fn in tensor_overrides.FP16_FUNCS: 107 | wrap.cached_cast(torch.cuda.FloatTensor, fn, utils.maybe_half, 108 | handle, try_caching=True, verbose=verbose) 109 | for fn in tensor_overrides.FP32_FUNCS: 110 | wrap.cached_cast(torch.cuda.HalfTensor, fn, utils.maybe_float, 111 | handle, try_caching=False, verbose=verbose) 112 | 113 | # 2) Enable type-promotion on multi-arg functions and methods. 114 | # NB: special handling for sequence fns (e.g. `torch.cat`). 115 | promote_modules = [torch_overrides, tensor_overrides] 116 | promote_table = [('CASTS', wrap.promote), 117 | ('SEQUENCE_CASTS', wrap.sequence_promote)] 118 | for promote_mod, (list_name, promote_fn) in itertools.product(promote_modules, 119 | promote_table): 120 | for fn in getattr(promote_mod, list_name): 121 | promote_fn(promote_mod.MODULE, fn, handle, verbose) 122 | 123 | # 2.5) Pre-0.4, add blacklist methods directly to HalfTensor and FloatTensor types 124 | if compat.tensor_is_float_tensor(): 125 | for cls, (list_name, promote_fn) in itertools.product([torch.cuda.FloatTensor, 126 | torch.cuda.HalfTensor], 127 | promote_table): 128 | for fn in getattr(tensor_overrides, list_name): 129 | promote_fn(cls, fn, handle, verbose) 130 | 131 | # 3) For any in-place version of a blacklist function, error if any input is fp16. 132 | # NB: this is overly conservative. 
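# (utils.as_inplace, defined in amp/utils.py, simply appends a trailing underscore
#  to each function name, e.g. 'add' -> 'add_', matching torch's naming convention
#  for in-place variants.)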
133 | for fn in utils.as_inplace(torch_overrides.FP32_FUNCS): 134 | wrap.err_if_any_half(torch_overrides.MODULE, fn, handle) 135 | 136 | # 3.5) For any in-place blacklist method, error if called on fp16 tensor 137 | for fn in utils.as_inplace(tensor_overrides.FP32_FUNCS): 138 | wrap.err_if_arg0_half(tensor_overrides.MODULE, fn, handle, verbose) 139 | if compat.tensor_is_float_tensor(): 140 | wrap.err_if_arg0_half(torch.cuda.HalfTensor, fn, handle, verbose) 141 | 142 | # 4) For other in-place methods, match the type of self tensor 143 | for fn in utils.as_inplace(itertools.chain( 144 | tensor_overrides.FP16_FUNCS, 145 | tensor_overrides.CASTS)): 146 | wrap.promote_match_arg0(tensor_overrides.MODULE, fn, handle, verbose) 147 | if compat.tensor_is_float_tensor(): 148 | wrap.promote_match_arg0(torch.cuda.HalfTensor, fn, handle, verbose) 149 | wrap.promote_match_arg0(torch.cuda.FloatTensor, fn, handle, verbose) 150 | 151 | # 5) RNNs + RNN cells are whitelisted specially 152 | if rnn_compat.has_old_rnns(): 153 | wrap.rnn_cast(torch.nn.backends.thnn.backend, 'RNN', handle, verbose) 154 | if not rnn_compat.has_old_rnns(): 155 | # Patch in our own indirection of `_VF` in modules/rnn s.t. it is mutable. 156 | torch.nn.modules.rnn._VF = rnn_compat.VariableFunctionsShim() 157 | # Wrap all the rnns 158 | for x in rnn_compat.RNN_NAMES: 159 | wrap.new_rnn_cast(x.upper(), handle, verbose) 160 | 161 | # Wrap all the RNN cells 162 | rnn_compat.whitelist_rnn_cells(handle, verbose) 163 | 164 | # 6) Place error+print message on banned functions. 165 | # Or, if allow_banned, then cast to FP32. 166 | for fn, err_msg in functional_overrides.BANNED_FUNCS: 167 | if allow_banned: 168 | wrap.cached_cast(functional_overrides.MODULE, fn, utils.maybe_float, 169 | handle, try_caching=True, verbose=verbose) 170 | else: 171 | wrap.err_if_any_half(functional_overrides.MODULE, fn, handle, err_msg) 172 | 173 | _DECORATOR_HANDLE = handle 174 | 175 | _amp_state.handle = handle 176 | 177 | return handle 178 | -------------------------------------------------------------------------------- /apex/amp/utils.py: -------------------------------------------------------------------------------- 1 | from . 
import compat 2 | 3 | import functools 4 | import itertools 5 | 6 | import torch 7 | 8 | def is_cuda_enabled(): 9 | return torch.version.cuda is not None 10 | 11 | def get_cuda_version(): 12 | return tuple(int(x) for x in torch.version.cuda.split('.')) 13 | 14 | def is_fp_tensor(x): 15 | if is_nested(x): 16 | # Fast-fail version of all(is_fp_tensor) 17 | for y in x: 18 | if not is_fp_tensor(y): 19 | return False 20 | return True 21 | return compat.is_tensor_like(x) and compat.is_floating_point(x) 22 | 23 | def is_nested(x): 24 | return isinstance(x, tuple) or isinstance(x, list) 25 | 26 | def should_cache(x): 27 | if is_nested(x): 28 | # Fast-fail version of all(should_cache) 29 | for y in x: 30 | if not should_cache(y): 31 | return False 32 | return True 33 | return isinstance(x, torch.nn.parameter.Parameter) and \ 34 | type_string(x) == 'FloatTensor' 35 | 36 | def collect_fp_tensor_types(args, kwargs): 37 | def collect_types(x, types): 38 | if is_nested(x): 39 | for y in x: 40 | collect_types(y, types) 41 | else: 42 | types.add(type_string(x)) 43 | 44 | all_args = itertools.chain(args, kwargs.values()) 45 | types = set() 46 | for x in all_args: 47 | if is_fp_tensor(x): 48 | collect_types(x, types) 49 | return types 50 | 51 | def type_string(x): 52 | return x.type().split('.')[-1] 53 | 54 | def maybe_half(x, name='', verbose=False): 55 | if is_nested(x): 56 | return type(x)([maybe_half(y) for y in x]) 57 | 58 | if not x.is_cuda or type_string(x) == 'HalfTensor': 59 | return x 60 | else: 61 | if verbose: 62 | print('Float->Half ({})'.format(name)) 63 | return x.half() 64 | 65 | def maybe_float(x, name='', verbose=False): 66 | if is_nested(x): 67 | return type(x)([maybe_float(y) for y in x]) 68 | 69 | if not x.is_cuda or type_string(x) == 'FloatTensor': 70 | return x 71 | else: 72 | if verbose: 73 | print('Half->Float ({})'.format(name)) 74 | return x.float() 75 | 76 | # NB: returneds casted `args`, mutates `kwargs` in-place 77 | def casted_args(cast_fn, args, kwargs): 78 | new_args = [] 79 | for x in args: 80 | if is_fp_tensor(x): 81 | new_args.append(cast_fn(x)) 82 | else: 83 | new_args.append(x) 84 | for k in kwargs: 85 | val = kwargs[k] 86 | if is_fp_tensor(val): 87 | kwargs[k] = cast_fn(val) 88 | return new_args 89 | 90 | def cached_cast(cast_fn, x, cache): 91 | if is_nested(x): 92 | return type(x)([cached_cast(y) for y in x]) 93 | if x in cache: 94 | cached_x = cache[x] 95 | if x.requires_grad and cached_x.requires_grad: 96 | # Make sure x is actually cached_x's autograd parent. 97 | if cached_x.grad_fn.next_functions[1][0].variable is not x: 98 | raise RuntimeError("x and cache[x] both require grad, but x is not " 99 | "cache[x]'s parent. This is likely an error.") 100 | # During eval, it's possible to end up caching casted weights with 101 | # requires_grad=False. On the next training iter, if cached_x is found 102 | # and reused from the cache, it will not actually have x as its parent. 103 | # Therefore, we choose to invalidate the cache (and force refreshing the cast) 104 | # if x.requires_grad and cached_x.requires_grad do not match. 105 | # 106 | # During eval (i.e. running under with torch.no_grad()) the invalidation 107 | # check would cause the cached value to be dropped every time, because 108 | # cached_x would always be created with requires_grad=False, while x would 109 | # still have requires_grad=True. This would render the cache effectively 110 | # useless during eval. 
Therefore, if we are running under the no_grad() 111 | # context manager (torch.is_grad_enabled=False) we elide the invalidation 112 | # check, and use the cached value even though its requires_grad flag doesn't 113 | # match. During eval, we don't care that there's no autograd-graph 114 | # connection between x and cached_x. 115 | if torch.is_grad_enabled() and x.requires_grad != cached_x.requires_grad: 116 | del cache[x] 117 | else: 118 | return cached_x 119 | 120 | casted_x = cast_fn(x) 121 | cache[x] = casted_x 122 | return casted_x 123 | 124 | def verbosify(cast_fn, fn_name, verbose): 125 | if verbose: 126 | return functools.partial(cast_fn, name=fn_name, verbose=verbose) 127 | else: 128 | return cast_fn 129 | 130 | def as_inplace(fns): 131 | for x in fns: 132 | yield x + '_' 133 | 134 | def has_func(mod, fn): 135 | if isinstance(mod, dict): 136 | return fn in mod 137 | else: 138 | return hasattr(mod, fn) 139 | 140 | def get_func(mod, fn): 141 | if isinstance(mod, dict): 142 | return mod[fn] 143 | else: 144 | return getattr(mod, fn) 145 | 146 | def set_func(mod, fn, new_fn): 147 | if isinstance(mod, dict): 148 | mod[fn] = new_fn 149 | else: 150 | setattr(mod, fn, new_fn) 151 | 152 | def set_func_save(handle, mod, fn, new_fn): 153 | cur_fn = get_func(mod, fn) 154 | handle._save_func(mod, fn, cur_fn) 155 | set_func(mod, fn, new_fn) 156 | 157 | # A couple problems get solved here: 158 | # - The flat_weight buffer is disconnected from autograd graph, 159 | # so the fp16 weights need to be derived from the input weights 160 | # to this forward call, not the flat buffer. 161 | # - The ordering of weights in the flat buffer is...idiosyncratic. 162 | # First problem is solved with combination of set_ (to set up 163 | # correct storage) and copy_ (so the fp16 weight derives from the 164 | # fp32 one in autograd. 165 | # Second is solved by doing ptr arithmetic on the fp32 weights 166 | # to derive the correct offset. 167 | # 168 | # TODO: maybe this should actually use 169 | # `torch._cudnn_rnn_flatten_weight`? But then I need to call 170 | # on first iter and cache the right offsets. Ugh. 171 | def synthesize_flattened_rnn_weights(fp32_weights, 172 | fp16_flat_tensor, 173 | rnn_fn='', 174 | verbose=False): 175 | fp16_weights = [] 176 | fp32_base_ptr = fp32_weights[0][0].data_ptr() 177 | for layer_weights in fp32_weights: 178 | fp16_layer_weights = [] 179 | for w_fp32 in layer_weights: 180 | w_fp16 = w_fp32.new().half() 181 | offset = (w_fp32.data_ptr() - fp32_base_ptr) // w_fp32.element_size() 182 | w_fp16.set_(fp16_flat_tensor.storage(), 183 | offset, 184 | w_fp32.shape) 185 | w_fp16.copy_(w_fp32) 186 | if verbose: 187 | print('Float->Half ({})'.format(rnn_fn)) 188 | fp16_layer_weights.append(w_fp16) 189 | fp16_weights.append(fp16_layer_weights) 190 | return fp16_weights 191 | 192 | # Roughly same as above, just the `fp32_weights` aren't nested. 193 | # Code kept separate for readability. 
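# (Worked example of the offset arithmetic shared by both helpers: if a weight's
#  data_ptr() is 512 bytes past fp32_base_ptr and element_size() is 4 for fp32,
#  the element offset passed to set_() on the flat fp16 storage is 512 // 4 = 128.)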
194 | def new_synthesize_flattened_rnn_weights(fp32_weights, 195 | fp16_flat_tensor, 196 | rnn_fn='', 197 | verbose=False): 198 | fp16_weights = [] 199 | fp32_base_ptr = fp32_weights[0].data_ptr() 200 | for w_fp32 in fp32_weights: 201 | w_fp16 = w_fp32.new().half() 202 | offset = (w_fp32.data_ptr() - fp32_base_ptr) // w_fp32.element_size() 203 | w_fp16.set_(fp16_flat_tensor.storage(), 204 | offset, 205 | w_fp32.shape) 206 | w_fp16.copy_(w_fp32) 207 | if verbose: 208 | print('Float->Half ({})'.format(rnn_fn)) 209 | fp16_weights.append(w_fp16) 210 | return fp16_weights 211 | -------------------------------------------------------------------------------- /apex/amp/scaler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from ..multi_tensor_apply import multi_tensor_applier 3 | from ._amp_state import _amp_state, master_params, maybe_print 4 | from itertools import product 5 | 6 | def scale_check_overflow_python(model_grad, master_grad, scale, check_overflow=False): 7 | # Exception handling for 18.04 compatibility 8 | if check_overflow: 9 | cpu_sum = float(model_grad.float().sum()) 10 | if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum: 11 | return True 12 | 13 | if master_grad is not model_grad: # copy_ probably internally short-circuits this 14 | master_grad.copy_(model_grad) 15 | if scale != 1.0: 16 | master_grad.mul_(scale) 17 | return False 18 | 19 | def axpby_check_overflow_python(model_grad, stashed_grad, master_grad, a, b, check_overflow=False): 20 | # Exception handling for 18.04 compatibility 21 | if check_overflow: 22 | cpu_sum = float(model_grad.float().sum()) 23 | if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum: 24 | return True 25 | 26 | # if master_grad is not model_grad: # copy_ probably internally short-circuits this 27 | # master_grad.copy_(model_grad) 28 | assert stashed_grad.dtype == master_grad.dtype 29 | converted_model_grad = model_grad.data.to(master_grad.dtype) 30 | master_grad.data = a*converted_model_grad.data + b*stashed_grad.data 31 | return False 32 | 33 | class LossScaler(object): 34 | warned_no_fused_kernel = False 35 | warned_unscaling_non_fp32_grad = False 36 | has_fused_kernel = False 37 | 38 | def __init__(self, 39 | loss_scale, 40 | init_scale=2.**16, 41 | scale_factor=2., 42 | scale_window=2000, 43 | min_loss_scale=None, 44 | max_loss_scale=2.**24): 45 | if loss_scale == "dynamic": 46 | self.dynamic = True 47 | self._loss_scale = min(max_loss_scale, init_scale) 48 | else: 49 | self.dynamic = False 50 | self._loss_scale = loss_scale 51 | self._max_loss_scale = max_loss_scale 52 | self._min_loss_scale = min_loss_scale 53 | self._scale_seq_len = scale_window 54 | self._unskipped = 0 55 | self._has_overflow = False 56 | self._overflow_buf = torch.cuda.IntTensor([0]) 57 | if multi_tensor_applier.available: 58 | import amp_C 59 | LossScaler.has_fused_kernel = multi_tensor_applier.available 60 | LossScaler.multi_tensor_scale_cuda = amp_C.multi_tensor_scale 61 | LossScaler.multi_tensor_axpby_cuda = amp_C.multi_tensor_axpby 62 | else: 63 | if not LossScaler.warned_no_fused_kernel: 64 | maybe_print( 65 | "Warning: multi_tensor_applier fused unscale kernel is unavailable, " 66 | "possibly because apex was installed without --cuda_ext --cpp_ext. " 67 | "Using Python fallback. 
Original ImportError was: " + 68 | repr(multi_tensor_applier.import_err), 69 | True) 70 | LossScaler.has_fused_kernel = False 71 | LossScaler.warned_no_fused_kernel = True 72 | 73 | def loss_scale(self): 74 | return self._loss_scale 75 | 76 | def unscale_python(self, model_grads, master_grads, scale): 77 | for model, master in zip(model_grads, master_grads): 78 | if model is not None: 79 | if not LossScaler.warned_unscaling_non_fp32_grad: 80 | if master.dtype != torch.float32: 81 | maybe_print( 82 | "Attempting to unscale a grad with type {} ".format(master.type()) + 83 | "Unscaling non-fp32 grads may indicate an error. " 84 | "When using Amp, you don't need to call .half() on your model.") 85 | LossScaler.warned_unscaling_non_fp32_grad = True 86 | self._has_overflow = scale_check_overflow_python(model, 87 | master, 88 | 1./scale, 89 | self.dynamic) 90 | if self._has_overflow and self.dynamic: 91 | break 92 | 93 | # unused_scale keeps some of the old API alive for hopefully a short time. 94 | def unscale(self, model_grads, master_grads, unused_scale, models_are_masters=False, scale_override=None): 95 | if self._has_overflow: 96 | return 97 | 98 | scale = self._loss_scale 99 | if scale_override is not None: 100 | scale = scale_override 101 | 102 | if scale == 1.0 and models_are_masters and not self.dynamic: 103 | return 104 | 105 | if LossScaler.has_fused_kernel: 106 | # if (not LossScaler.warned_unscaling_non_fp32_grad 107 | # and master_grads[0].dtype == torch.float16): 108 | # print("Warning: unscaling grads that are not FP32. " 109 | # "Unscaling non-fp32 grads may indicate an error. " 110 | # "When using Amp, you don't need to call .half() on your model.") 111 | # # Setting this to True unconditionally allows the possibility of an escape 112 | # # if never-before-seen non-fp32 grads are created in some later iteration. 113 | # LossScaler.warned_unscaling_non_fp32_grad = True 114 | multi_tensor_applier(LossScaler.multi_tensor_scale_cuda, 115 | self._overflow_buf, 116 | [model_grads, master_grads], 117 | 1./scale) 118 | else: 119 | self.unscale_python(model_grads, master_grads, scale) 120 | 121 | # Defer to update_scale 122 | # If the fused kernel is available, we only need one D2H memcopy and sync. 123 | # if LossScaler.has_fused_kernel and self.dynamic and not self._has_overflow: 124 | # self._has_overflow = self._overflow_buf.item() 125 | 126 | def unscale_with_stashed_python(self, 127 | model_grads, 128 | stashed_master_grads, 129 | master_grads, 130 | a, 131 | b): 132 | for model, stashed, master in zip(model_grads, stashed_master_grads, master_grads): 133 | if model is None and stashed is None: 134 | continue 135 | else: 136 | if not LossScaler.warned_unscaling_non_fp32_grad: 137 | if master.dtype != torch.float32: 138 | maybe_print( 139 | "Attempting to unscale a grad with type {} ".format(master.type()) + 140 | "Unscaling non-fp32 grads may indicate an error. 
" 141 | "When using Amp, you don't need to call .half() on your model.") 142 | LossScaler.warned_unscaling_non_fp32_grad = True 143 | self._has_overflow = axpby_check_overflow_python(model, 144 | stashed, 145 | master, 146 | a, 147 | b, 148 | self.dynamic) 149 | if self._has_overflow and self.dynamic: 150 | break 151 | 152 | def unscale_with_stashed(self, 153 | model_grads, 154 | stashed_master_grads, 155 | master_grads, 156 | scale_override=None): 157 | if self._has_overflow: 158 | return 159 | 160 | grads_have_scale, stashed_have_scale, out_scale = self._loss_scale, 1.0, 1.0 161 | if scale_override is not None: 162 | grads_have_scale, stashed_have_scale, out_scale = scale_override 163 | 164 | if LossScaler.has_fused_kernel: 165 | if (not LossScaler.warned_unscaling_non_fp32_grad 166 | and master_grads[0].dtype == torch.float16): 167 | print("Warning: unscaling grads that are not FP32. " 168 | "Unscaling non-fp32 grads may indicate an error. " 169 | "When using Amp, you don't need to call .half() on your model.") 170 | # Setting this to True unconditionally allows the possibility of an escape 171 | # if never-before-seen non-fp32 grads are created in some later iteration. 172 | LossScaler.warned_unscaling_non_fp32_grad = True 173 | multi_tensor_applier(LossScaler.multi_tensor_axpby_cuda, 174 | self._overflow_buf, 175 | [model_grads, stashed_master_grads, master_grads], 176 | out_scale/grads_have_scale, # 1./scale, 177 | out_scale/stashed_have_scale, # 1.0, 178 | 0) # check only arg 0, aka the incoming model grads, for infs 179 | else: 180 | self.unscale_with_stashed_python(model_grads, 181 | stashed_master_grads, 182 | master_grads, 183 | out_scale/grads_have_scale, 184 | out_scale/stashed_have_scale) 185 | 186 | # Defer to update_scale 187 | # If the fused kernel is available, we only need one D2H memcopy and sync. 188 | # if LossScaler.has_fused_kernel and self.dynamic and not self._has_overflow: 189 | # self._has_overflow = self._overflow_buf.item() 190 | 191 | def clear_overflow_state(self): 192 | self._has_overflow = False 193 | if self.has_fused_kernel: 194 | self._overflow_buf.zero_() 195 | 196 | # Separate so unscale() can be called more that once before updating. 197 | def update_scale(self): 198 | # If the fused kernel is available, we only need one D2H memcopy and sync. 199 | if LossScaler.has_fused_kernel and self.dynamic and not self._has_overflow: 200 | self._has_overflow = self._overflow_buf.item() 201 | 202 | if self._has_overflow and self.dynamic: 203 | should_skip = True 204 | if(self._min_loss_scale): 205 | self._loss_scale = max(self._min_loss_scale, self._loss_scale/2.) 206 | else: 207 | self._loss_scale = self._loss_scale/2. 208 | self._unskipped = 0 209 | else: 210 | should_skip = False 211 | self._unskipped += 1 212 | 213 | if self._unskipped == self._scale_seq_len and self.dynamic: 214 | self._loss_scale = min(self._max_loss_scale, self._loss_scale*2.) 
215 | self._unskipped = 0 216 | 217 | return should_skip 218 | -------------------------------------------------------------------------------- /utils/autoaugment.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copy from https://github.com/DeepVoltaire/AutoAugment/blob/master/autoaugment.py 3 | """ 4 | 5 | from PIL import Image, ImageEnhance, ImageOps 6 | import numpy as np 7 | import random 8 | 9 | __all__ = ['AutoAugImageNetPolicy', 'AutoAugCIFAR10Policy', 'AutoAugSVHNPolicy'] 10 | 11 | 12 | class AutoAugImageNetPolicy(object): 13 | def __init__(self, fillcolor=(128, 128, 128)): 14 | self.policies = [ 15 | SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor), 16 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), 17 | SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor), 18 | SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor), 19 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), 20 | 21 | SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor), 22 | SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor), 23 | SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor), 24 | SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor), 25 | SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor), 26 | 27 | SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor), 28 | SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor), 29 | SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor), 30 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), 31 | SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor), 32 | 33 | SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor), 34 | SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor), 35 | SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor), 36 | SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor), 37 | SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor), 38 | 39 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), 40 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), 41 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), 42 | SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor) 43 | ] 44 | 45 | def __call__(self, img): 46 | policy_idx = random.randint(0, len(self.policies) - 1) 47 | return self.policies[policy_idx](img) 48 | 49 | def __repr__(self): 50 | return "AutoAugment ImageNet Policy" 51 | 52 | 53 | class AutoAugCIFAR10Policy(object): 54 | def __init__(self, fillcolor=(128, 128, 128)): 55 | self.policies = [ 56 | SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor), 57 | SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor), 58 | SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor), 59 | SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor), 60 | SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor), 61 | 62 | SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor), 63 | SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor), 64 | SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor), 65 | SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor), 66 | SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor), 67 | 68 | SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor), 69 | SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor), 70 | SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor), 71 | SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, 
fillcolor), 72 | SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor), 73 | 74 | SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor), 75 | SubPolicy(0.2, "equalize", 8, 0.8, "equalize", 4, fillcolor), 76 | SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor), 77 | SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor), 78 | SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor), 79 | 80 | SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor), 81 | SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor), 82 | SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor), 83 | SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor), 84 | SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor) 85 | ] 86 | 87 | def __call__(self, img): 88 | policy_idx = random.randint(0, len(self.policies) - 1) 89 | return self.policies[policy_idx](img) 90 | 91 | def __repr__(self): 92 | return "AutoAugment CIFAR10 Policy" 93 | 94 | 95 | class AutoAugSVHNPolicy(object): 96 | def __init__(self, fillcolor=(128, 128, 128)): 97 | self.policies = [ 98 | SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor), 99 | SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor), 100 | SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor), 101 | SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor), 102 | SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor), 103 | 104 | SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor), 105 | SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor), 106 | SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor), 107 | SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor), 108 | SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor), 109 | 110 | SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor), 111 | SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor), 112 | SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor), 113 | SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor), 114 | SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor), 115 | 116 | SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor), 117 | SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor), 118 | SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor), 119 | SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor), 120 | SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor), 121 | 122 | SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor), 123 | SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor), 124 | SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor), 125 | SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor), 126 | SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor) 127 | ] 128 | 129 | def __call__(self, img): 130 | policy_idx = random.randint(0, len(self.policies) - 1) 131 | return self.policies[policy_idx](img) 132 | 133 | def __repr__(self): 134 | return "AutoAugment SVHN Policy" 135 | 136 | 137 | class SubPolicy(object): 138 | def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)): 139 | ranges = { 140 | "shearX": np.linspace(0, 0.3, 10), 141 | "shearY": np.linspace(0, 0.3, 10), 142 | "translateX": np.linspace(0, 150 / 331, 10), 143 | "translateY": np.linspace(0, 150 / 331, 10), 144 | "rotate": np.linspace(0, 30, 10), 145 | "color": np.linspace(0.0, 0.9, 10), 146 | "posterize": np.round(np.linspace(8, 4, 10), 0).astype(np.int), 147 | "solarize": 
np.linspace(256, 0, 10), 148 | "contrast": np.linspace(0.0, 0.9, 10), 149 | "sharpness": np.linspace(0.0, 0.9, 10), 150 | "brightness": np.linspace(0.0, 0.9, 10), 151 | "autocontrast": [0] * 10, 152 | "equalize": [0] * 10, 153 | "invert": [0] * 10 154 | } 155 | 156 | def rotate_with_fill(img, magnitude): 157 | rot = img.convert("RGBA").rotate(magnitude) 158 | return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(img.mode) 159 | 160 | func = { 161 | "shearX": lambda img, magnitude: img.transform( 162 | img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0), 163 | Image.BICUBIC, fillcolor=fillcolor), 164 | "shearY": lambda img, magnitude: img.transform( 165 | img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0), 166 | Image.BICUBIC, fillcolor=fillcolor), 167 | "translateX": lambda img, magnitude: img.transform( 168 | img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0), 169 | fillcolor=fillcolor), 170 | "translateY": lambda img, magnitude: img.transform( 171 | img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])), 172 | fillcolor=fillcolor), 173 | "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude), 174 | # "rotate": lambda img, magnitude: img.rotate(magnitude * random.choice([-1, 1])), 175 | "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])), 176 | "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude), 177 | "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude), 178 | "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance( 179 | 1 + magnitude * random.choice([-1, 1])), 180 | "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance( 181 | 1 + magnitude * random.choice([-1, 1])), 182 | "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance( 183 | 1 + magnitude * random.choice([-1, 1])), 184 | "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img), 185 | "equalize": lambda img, magnitude: ImageOps.equalize(img), 186 | "invert": lambda img, magnitude: ImageOps.invert(img) 187 | } 188 | 189 | # self.name = "{}_{:.2f}_and_{}_{:.2f}".format( 190 | # operation1, ranges[operation1][magnitude_idx1], 191 | # operation2, ranges[operation2][magnitude_idx2]) 192 | self.p1 = p1 193 | self.operation1 = func[operation1] 194 | self.magnitude1 = ranges[operation1][magnitude_idx1] 195 | self.p2 = p2 196 | self.operation2 = func[operation2] 197 | self.magnitude2 = ranges[operation2][magnitude_idx2] 198 | 199 | def __call__(self, img): 200 | if random.random() < self.p1: 201 | img = self.operation1(img, self.magnitude1) 202 | if random.random() < self.p2: 203 | img = self.operation2(img, self.magnitude2) 204 | return img -------------------------------------------------------------------------------- /utils/data_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from PIL import Image 3 | import os 4 | 5 | import torch 6 | 7 | from torchvision import transforms 8 | from torch.utils.data import DataLoader, RandomSampler, DistributedSampler, SequentialSampler, BatchSampler 9 | 10 | from .dataset import CUB, CarsDataset, NABirds, dogs, INat2017, IP102 11 | from .autoaugment import AutoAugImageNetPolicy 12 | import numpy as np 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | # class BalancedBatchSampler(BatchSampler): 18 | # def __init__(self, 
dataset, n_classes, n_samples): 19 | # # dataset这里指train_datasets, 20 | # # labels:即之前求得的tensor([ 1, 1, 1, ..., 200, 200, 200]) 21 | # 22 | # # 转化为list形式 23 | # self.labels_set = dataset.train_label 24 | # 25 | # self.labels = torch.tensor(dataset.train_label) 26 | # 27 | # # 这个操作大致是建立标签和序列对应的字典 28 | # self.label_to_indices = {label: np.where(self.labels == label)[0] for label in self.labels_set} 29 | # # print(self.labels_set) 30 | # # print(self.labels) 31 | # 32 | # for l in self.labels_set: 33 | # np.random.shuffle(self.label_to_indices[l]) 34 | # 35 | # self.used_label_indices_count = {label: 0 for label in self.labels_set} 36 | # 37 | # self.count = 0 38 | # self.n_classes = n_classes 39 | # self.n_samples = n_samples 40 | # self.dataset = dataset 41 | # 42 | # self.batch_size = self.n_samples * self.n_classes 43 | # 44 | # def __iter__(self): 45 | # self.count = 0 46 | # while self.count + self.batch_size < len(self.dataset): 47 | # # np.random.choice(a, size=None, replace=True, p=None) 48 | # # 从a(只要是ndarray都可以,但必须是一维)中随机抽取数字,并组成指定大小(size)的数组 49 | # # replace:True表示可以取相同数字,False表示不可以取相同数字 50 | # # 数组p:与数组a相对,表示取数组a中每个元素的概率 51 | # classes = np.random.choice(self.labels_set, self.n_classes, replace=False) 52 | # indices = [] 53 | # for class_ in classes: 54 | # indices.extend(self.label_to_indices[class_][ 55 | # self.used_label_indices_count[class_]:self.used_label_indices_count[ 56 | # class_] + self.n_samples]) 57 | # self.used_label_indices_count[class_] += self.n_samples 58 | # if self.used_label_indices_count[class_] + self.n_samples > len(self.label_to_indices[class_]): 59 | # np.random.shuffle(self.label_to_indices[class_]) 60 | # self.used_label_indices_count[class_] = 0 61 | # yield indices 62 | # self.count += self.n_classes * self.n_samples 63 | # 64 | # def __len__(self): 65 | # return len(self.dataset) // self.batch_size 66 | 67 | 68 | def get_loader(args): 69 | if args.local_rank not in [-1, 0]: 70 | torch.distributed.barrier() 71 | 72 | if args.dataset == 'CUB_200_2011': 73 | train_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR), 74 | transforms.RandomCrop((448, 448)), 75 | transforms.RandomHorizontalFlip(), 76 | transforms.ToTensor(), 77 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 78 | test_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR), 79 | transforms.CenterCrop((448, 448)), 80 | transforms.ToTensor(), 81 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 82 | trainset = CUB(root=args.data_root, is_train=True, transform=train_transform) 83 | testset = CUB(root=args.data_root, is_train=False, transform=test_transform) 84 | elif args.dataset == 'car': 85 | trainset = CarsDataset(os.path.join(args.data_root, 'devkit/cars_train_annos.mat'), 86 | os.path.join(args.data_root, 'cars_train'), 87 | os.path.join(args.data_root, 'devkit/cars_meta.mat'), 88 | # cleaned=os.path.join(data_dir,'cleaned.dat'), 89 | transform=transforms.Compose([ 90 | transforms.Resize((600, 600), Image.BILINEAR), 91 | transforms.RandomCrop((448, 448)), 92 | transforms.RandomHorizontalFlip(), 93 | AutoAugImageNetPolicy(), 94 | transforms.ToTensor(), 95 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 96 | ) 97 | testset = CarsDataset(os.path.join(args.data_root, 'cars_test_annos_withlabels.mat'), 98 | os.path.join(args.data_root, 'cars_test'), 99 | os.path.join(args.data_root, 'devkit/cars_meta.mat'), 100 | # cleaned=os.path.join(data_dir,'cleaned_test.dat'), 101 | 
transform=transforms.Compose([ 102 | transforms.Resize((600, 600), Image.BILINEAR), 103 | transforms.CenterCrop((448, 448)), 104 | transforms.ToTensor(), 105 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 106 | ) 107 | elif args.dataset == 'dog': 108 | train_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR), 109 | transforms.RandomCrop((448, 448)), 110 | transforms.RandomHorizontalFlip(), 111 | transforms.ToTensor(), 112 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 113 | test_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR), 114 | transforms.CenterCrop((448, 448)), 115 | transforms.ToTensor(), 116 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 117 | trainset = dogs(root=args.data_root, 118 | train=True, 119 | cropped=False, 120 | transform=train_transform, 121 | download=False 122 | ) 123 | testset = dogs(root=args.data_root, 124 | train=False, 125 | cropped=False, 126 | transform=test_transform, 127 | download=False 128 | ) 129 | elif args.dataset == 'nabirds': 130 | train_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR), 131 | transforms.RandomCrop((448, 448)), 132 | transforms.RandomHorizontalFlip(), 133 | transforms.ToTensor(), 134 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 135 | test_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR), 136 | transforms.CenterCrop((448, 448)), 137 | transforms.ToTensor(), 138 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 139 | trainset = NABirds(root=args.data_root, train=True, transform=train_transform) 140 | testset = NABirds(root=args.data_root, train=False, transform=test_transform) 141 | elif args.dataset == 'INat2017': 142 | train_transform = transforms.Compose([transforms.Resize((400, 400), Image.BILINEAR), 143 | transforms.RandomCrop((304, 304)), 144 | transforms.RandomHorizontalFlip(), 145 | AutoAugImageNetPolicy(), 146 | transforms.ToTensor(), 147 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 148 | test_transform = transforms.Compose([transforms.Resize((400, 400), Image.BILINEAR), 149 | transforms.CenterCrop((304, 304)), 150 | transforms.ToTensor(), 151 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 152 | trainset = INat2017(args.data_root, 'train', train_transform) 153 | testset = INat2017(args.data_root, 'val', test_transform) 154 | elif args.dataset == 'ip102': 155 | train_transform = transforms.Compose([transforms.Resize((256, 256), Image.BILINEAR), 156 | transforms.RandomCrop((224, 224)), 157 | transforms.RandomHorizontalFlip(), 158 | transforms.ToTensor(), 159 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 160 | test_transform = transforms.Compose([transforms.Resize((256, 256), Image.BILINEAR), 161 | transforms.CenterCrop((224, 224)), 162 | transforms.ToTensor(), 163 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 164 | trainset = IP102(root=args.data_root, is_train=True, transform=train_transform) 165 | testset = IP102(root=args.data_root, is_train=False, transform=test_transform) 166 | 167 | if args.local_rank == 0: 168 | torch.distributed.barrier() 169 | 170 | train_sampler = RandomSampler(trainset) if args.local_rank == -1 else DistributedSampler(trainset) 171 | test_sampler = SequentialSampler(testset) if args.local_rank == -1 else DistributedSampler(testset) 172 | train_loader = DataLoader(trainset, 
173 | sampler=train_sampler, 174 | batch_size=args.train_batch_size, 175 | num_workers=8, 176 | drop_last=True, 177 | pin_memory=True, 178 | ) 179 | test_loader = DataLoader(testset, 180 | sampler=test_sampler, 181 | batch_size=args.eval_batch_size, 182 | num_workers=8, 183 | pin_memory=True) if testset is not None else None 184 | 185 | return train_loader, test_loader 186 | 187 | -------------------------------------------------------------------------------- /apex/amp/wrap.py: -------------------------------------------------------------------------------- 1 | from . import compat 2 | from . import utils 3 | from ._amp_state import _amp_state 4 | from . import rnn_compat 5 | 6 | import functools 7 | 8 | import torch 9 | 10 | def make_cast_wrapper(orig_fn, cast_fn, handle, 11 | try_caching=False): 12 | @functools.wraps(orig_fn) 13 | def wrapper(*args, **kwargs): 14 | if not handle.is_active(): 15 | return orig_fn(*args, **kwargs) 16 | 17 | if try_caching and handle.has_cache: 18 | args = list(args) 19 | for i in range(len(args)): 20 | if utils.should_cache(args[i]): 21 | args[i] = utils.cached_cast(cast_fn, args[i], handle.cache) 22 | for k in kwargs: 23 | if utils.should_cache(kwargs[k]): 24 | kwargs[k] = utils.cached_cast(cast_fn, kwargs[k], handle.cache) 25 | new_args = utils.casted_args(cast_fn, 26 | args, 27 | kwargs) 28 | return orig_fn(*new_args, **kwargs) 29 | return wrapper 30 | 31 | def cached_cast(mod, fn, cast_fn, handle, 32 | try_caching=False, verbose=False): 33 | if not utils.has_func(mod, fn): 34 | return 35 | 36 | orig_fn = utils.get_func(mod, fn) 37 | cast_fn = utils.verbosify(cast_fn, fn, verbose) 38 | wrapper = make_cast_wrapper(orig_fn, cast_fn, handle, try_caching) 39 | utils.set_func_save(handle, mod, fn, wrapper) 40 | 41 | # `handle` arg is unused, but simplifies API to make `make_cast_wrapper` 42 | # Annoyingly, make_promote_wrapper still uses the global handle. Once everyone 43 | # is on the new API and I am free to get rid of handle, I can clean this up. 
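# Added note (descriptive comment, not part of the upstream apex source): make_promote_wrapper
# handles patched ops that receive mixed-precision inputs. If the arguments contain exactly
# {HalfTensor, FloatTensor}, everything is promoted to fp32 (via maybe_float) before the
# original op is dispatched; a single dtype passes straight through.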
44 | def make_promote_wrapper(orig_fn, cast_fn, handle=None): 45 | @functools.wraps(orig_fn) 46 | def wrapper(*args, **kwargs): 47 | if not _amp_state.handle.is_active(): 48 | return orig_fn(*args, **kwargs) 49 | 50 | types = utils.collect_fp_tensor_types(args, kwargs) 51 | 52 | if len(types) <= 1: 53 | return orig_fn(*args, **kwargs) 54 | elif len(types) == 2 and types == set(['HalfTensor', 'FloatTensor']): 55 | new_args = utils.casted_args(cast_fn, 56 | args, 57 | kwargs) 58 | return orig_fn(*new_args, **kwargs) 59 | else: 60 | raise NotImplementedError('Do not know how to handle ' + 61 | 'these types to promote: {}' 62 | .format(types)) 63 | return wrapper 64 | 65 | def promote(mod, fn, handle, verbose=False): 66 | orig_fn = utils.get_func(mod, fn) 67 | maybe_float = utils.verbosify(utils.maybe_float, fn, verbose) 68 | wrapper = make_promote_wrapper(orig_fn, maybe_float) 69 | utils.set_func_save(handle, mod, fn, wrapper) 70 | 71 | def sequence_promote(mod, fn, handle, verbose=False): 72 | orig_fn = utils.get_func(mod, fn) 73 | maybe_float = utils.verbosify(utils.maybe_float, fn, verbose) 74 | @functools.wraps(orig_fn) 75 | def wrapper(seq, *args, **kwargs): 76 | if not _amp_state.handle.is_active(): 77 | return orig_fn(seq, *args, **kwargs) 78 | 79 | types = set([utils.type_string(x) for x in seq]) 80 | if len(types) <= 1: 81 | return orig_fn(seq, *args, **kwargs) 82 | elif types == set(['HalfTensor', 'FloatTensor']): 83 | cast_seq = utils.casted_args(maybe_float, 84 | seq, {}) 85 | return orig_fn(cast_seq, *args, **kwargs) 86 | else: 87 | # TODO: other mixed-type cases aren't due to amp. 88 | # Just pass through? 89 | return orig_fn(seq, *args, **kwargs) 90 | utils.set_func_save(handle, mod, fn, wrapper) 91 | 92 | def promote_match_arg0(mod, fn, handle, verbose=False): 93 | if not utils.has_func(mod, fn): 94 | return 95 | 96 | orig_fn = utils.get_func(mod, fn) 97 | @functools.wraps(orig_fn) 98 | def wrapper(arg0, *args, **kwargs): 99 | assert compat.is_tensor_like(arg0) 100 | if not _amp_state.handle.is_active(): 101 | return orig_fn(arg0, *args, **kwargs) 102 | 103 | if utils.type_string(arg0) == 'HalfTensor': 104 | cast_fn = utils.maybe_half 105 | elif utils.type_string(arg0) == 'FloatTensor': 106 | cast_fn = utils.maybe_float 107 | else: 108 | return orig_fn(arg0, *args, **kwargs) 109 | cast_fn = utils.verbosify(cast_fn, fn, verbose) 110 | new_args = utils.casted_args(cast_fn, args, kwargs) 111 | return orig_fn(arg0, *new_args, **kwargs) 112 | utils.set_func_save(handle, mod, fn, wrapper) 113 | 114 | def err_if_any_half(mod, fn, handle, custom_err_msg=None): 115 | if not utils.has_func(mod, fn): 116 | return 117 | 118 | orig_fn = utils.get_func(mod, fn) 119 | @functools.wraps(orig_fn) 120 | def wrapper(*args, **kwargs): 121 | types = utils.collect_fp_tensor_types(args, kwargs) 122 | if 'HalfTensor' in types: 123 | if custom_err_msg: 124 | raise NotImplementedError(custom_err_msg) 125 | else: 126 | raise NotImplementedError('Cannot call in-place function ' + 127 | '{} with fp16 arguments.'.format(fn)) 128 | else: 129 | return orig_fn(*args, **kwargs) 130 | utils.set_func_save(handle, mod, fn, wrapper) 131 | 132 | def err_if_arg0_half(mod, fn, handle, verbose=False): 133 | if not utils.has_func(mod, fn): 134 | return 135 | 136 | orig_fn = utils.get_func(mod, fn) 137 | @functools.wraps(orig_fn) 138 | def wrapper(arg0, *args, **kwargs): 139 | assert compat.is_tensor_like(arg0) 140 | if utils.type_string(arg0) == 'HalfTensor': 141 | raise NotImplementedError('Cannot call in-place method 
' + 142 | '{} on fp16 Tensors.'.format(fn)) 143 | else: 144 | cast_fn = utils.verbosify(utils.maybe_float, fn, verbose) 145 | new_args = utils.casted_args(cast_fn, args, kwargs) 146 | return orig_fn(arg0, *new_args, **kwargs) 147 | utils.set_func_save(handle, mod, fn, wrapper) 148 | 149 | # Current RNN approach: 150 | # - Wrap top-level `RNN` function in thnn backend 151 | # - Will call into either CudnnRNN or AutogradRNN 152 | # - Each of these are factory functions that return a per-iter 153 | # `forward` function 154 | # - We interpose on the factory function to: 155 | # 1) Interpose on the actual forward function and put in casts 156 | # 2) Insert an fp16 `flat_weight` if necessary 157 | def rnn_cast(backend, fn, handle, verbose=False): 158 | orig_rnn = utils.get_func(backend, fn) 159 | @functools.wraps(orig_rnn) 160 | def rnn_wrapper(*args, **kwargs): 161 | flat_weight = kwargs.get('flat_weight') 162 | if flat_weight is not None: 163 | # We replace `flat_weight` with an uninitialized fp16 164 | # Tensor. The "actual" weight tensors (provided in `forward`), 165 | # will then be set up as ptrs into the buffer and have the 166 | # corresponding fp32 values copied in. 167 | # We need to call `copy` on the "actual" weights so that the 168 | # autograd graph correctly backprops from the wgrads computed 169 | # inside cuDNN (on fp16 weights) into the fp32 weights. 170 | assert utils.type_string(flat_weight) == 'FloatTensor' 171 | if compat.tensor_is_float_tensor() or compat.tensor_is_variable(): 172 | # Pre-0.4. A little slower, since it zeros out memory. 173 | flat_weight_fp16 = flat_weight.new().half().resize_(flat_weight.shape) 174 | else: 175 | flat_weight_fp16 = torch.empty_like(flat_weight, 176 | dtype=torch.float16) 177 | kwargs['flat_weight'] = flat_weight_fp16 178 | else: 179 | flat_weight_fp16 = None 180 | 181 | forward = orig_rnn(*args, **kwargs) 182 | @functools.wraps(forward) 183 | def fwd_wrapper(*fargs, **fkwargs): 184 | assert len(fargs) == 3 or len(fargs) == 4 185 | inputs, weights, hiddens = fargs[:3] 186 | assert utils.is_fp_tensor(inputs) 187 | assert isinstance(weights, list) 188 | cast_fn = utils.verbosify(utils.maybe_half, 189 | fn, 190 | verbose) 191 | new_args = [] 192 | 193 | # 0) Inputs 194 | new_args.append(cast_fn(inputs)) 195 | 196 | # 1) Weights 197 | if flat_weight_fp16 is not None: 198 | fp16_weights = utils.synthesize_flattened_rnn_weights( 199 | weights, flat_weight_fp16, fn, verbose) 200 | else: 201 | fp16_weights = [[cast_fn(w) for w in layer] 202 | for layer in weights] 203 | new_args.append(fp16_weights) 204 | 205 | # 2) Inputs: either a tuple (for LSTM) or single tensor 206 | if isinstance(hiddens, tuple): 207 | new_args.append(tuple(cast_fn(x) for x in hiddens)) 208 | elif utils.is_fp_tensor(hiddens): 209 | new_args.append(cast_fn(hiddens)) 210 | else: 211 | # Hiddens can, in principle, be `None` -- pass through 212 | new_args.append(hiddens) 213 | 214 | # 3) Batch sizes (0.4 or later only) 215 | if len(fargs) == 4: 216 | new_args.append(fargs[3]) 217 | 218 | return forward(*new_args, **fkwargs) 219 | return fwd_wrapper 220 | utils.set_func_save(handle, backend, fn, rnn_wrapper) 221 | 222 | def new_rnn_cast(fn, handle, verbose=False): 223 | # Forward+backward compatibility around https://github.com/pytorch/pytorch/pull/15744 224 | # For rnn backend calls that route through _rnn_impls, we must patch the ref 225 | # that _rnn_impls stashed. For rnn backend calls that directly invoke 226 | # _VF., e.g. 
_VF.lstm, we can patch onto VariableFunctionsShim, 227 | # which in turn has patched the ref named "_VF" in torch.nn.modules.rnn. 228 | if utils.has_func(torch.nn.modules.rnn._rnn_impls, fn): 229 | mod = torch.nn.modules.rnn._rnn_impls 230 | else: 231 | mod = torch.nn.modules.rnn._VF 232 | assert isinstance(mod, rnn_compat.VariableFunctionsShim) 233 | fn = fn.lower() 234 | orig_fn = utils.get_func(mod, fn) 235 | cast_fn = utils.verbosify(utils.maybe_half, fn, verbose) 236 | @functools.wraps(orig_fn) 237 | def wrapper(*args, **kwargs): 238 | # Exact call signature from modules/rnn.py 239 | assert len(args) == 9 240 | assert len(kwargs) == 0 241 | 242 | if not _amp_state.handle.is_active(): 243 | return orig_fn(*args, **kwargs) 244 | 245 | if isinstance(args[6], bool): 246 | params_idx = 2 # Not PackedSequence case 247 | else: 248 | params_idx = 3 # PackedSequence case 249 | 250 | new_args = [] 251 | for i, arg in enumerate(args): 252 | if i == params_idx: 253 | num_params = sum([x.numel() for x in arg]) 254 | fp16_weight_buf = args[0].new_empty((num_params,), 255 | dtype=torch.half) 256 | casted_weights = utils.new_synthesize_flattened_rnn_weights( 257 | arg, fp16_weight_buf, fn, verbose) 258 | new_args.append(casted_weights) 259 | elif utils.is_fp_tensor(arg): 260 | new_args.append(cast_fn(arg)) 261 | else: 262 | new_args.append(arg) 263 | 264 | return orig_fn(*new_args) 265 | utils.set_func_save(handle, mod, fn, wrapper) 266 | 267 | def disable_casts(mod, fn, handle): 268 | if not utils.has_func(mod, fn): 269 | return 270 | 271 | orig_fn = utils.get_func(mod, fn) 272 | @functools.wraps(orig_fn) 273 | def wrapper(*args, **kwargs): 274 | with handle._disable_casts(): 275 | return orig_fn(*args, **kwargs) 276 | utils.set_func_save(handle, mod, fn, wrapper) 277 | -------------------------------------------------------------------------------- /apex/amp/_initialize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch._six import string_classes 3 | import functools 4 | import numpy as np 5 | import sys 6 | from types import MethodType 7 | import warnings 8 | from ._amp_state import _amp_state, warn_or_err, container_abcs 9 | from .handle import disable_casts 10 | from .scaler import LossScaler 11 | from ._process_optimizer import _process_optimizer 12 | from apex.fp16_utils import convert_network 13 | from ..fp16_utils import FP16_Optimizer as FP16_Optimizer_general 14 | from ..contrib.optimizers import FP16_Optimizer as FP16_Optimizer_for_fused 15 | 16 | if torch.distributed.is_available(): 17 | from ..parallel import DistributedDataParallel as apex_DDP 18 | from ..parallel.LARC import LARC 19 | 20 | 21 | def to_type(dtype, t): 22 | if isinstance(t, torch.Tensor): 23 | if not t.is_cuda: 24 | # This should not be a hard error, since it may be legitimate. 25 | warnings.warn("An input tensor was not cuda.") 26 | # GANs require this. 27 | # if t.requires_grad: 28 | # warn_or_err("input data requires grad. Since input data is not a model parameter,\n" 29 | # "its gradients will not be properly allreduced by DDP.") 30 | if t.is_floating_point(): 31 | return t.to(dtype) 32 | return t 33 | else: 34 | # Trust the user's custom batch type, that's all I can do here. 35 | return t.to(dtype) 36 | 37 | 38 | # Modified from torch.optim.optimizer.py. This is a bit more general than casted_args in utils.py. 
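# Added note (descriptive comment): applier recursively walks an arbitrary input structure
# (tensors, mappings, iterables, or custom batch classes exposing .to()) and applies fn to
# every torch.Tensor it finds, passing strings, numpy arrays, and unrecognized objects
# through unchanged.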
39 | def applier(value, fn): 40 | if isinstance(value, torch.Tensor): 41 | return fn(value) 42 | elif isinstance(value, string_classes): 43 | return value 44 | elif isinstance(value, np.ndarray): 45 | return value 46 | elif hasattr(value, "to"): # Allow handling of custom batch classes 47 | return fn(value) 48 | elif isinstance(value, container_abcs.Mapping): 49 | return {applier(k, fn) : applier(v, fn) for k, v in value.items()} 50 | elif isinstance(value, container_abcs.Iterable): 51 | return type(value)(applier(v, fn) for v in value) 52 | else: 53 | # Do I want this to fire off even if someone chooses to pass something ordinary like 54 | # an int or float? May be more annoying than it's worth. 55 | # print("Warning: unrecognized type in applier. If your input data is a custom class, " 56 | # "provide it with a .to(dtype) method which converts its floating-point Tensors to dtype. " 57 | # "Amp will check for your custom to() and invoke it to cast the batch's " 58 | # "floating-point Tensors to the appropriate type. " 59 | # "Also, if your data is a custom class, it is your responsibility to ensure that " 60 | # "any Tensors you want to be cuda are already cuda." 61 | return value 62 | 63 | 64 | def check_models(models): 65 | for model in models: 66 | parallel_type = None 67 | if isinstance(model, torch.nn.parallel.DistributedDataParallel): 68 | parallel_type = "torch.nn.parallel.DistributedDataParallel" 69 | if ('apex_DDP' in sys.modules) and isinstance(model, apex_DDP): 70 | parallel_type = "apex.parallel.DistributedDataParallel" 71 | if isinstance(model, torch.nn.parallel.DataParallel): 72 | parallel_type = "torch.nn.parallel.DataParallel" 73 | if parallel_type is not None: 74 | raise RuntimeError("Incoming model is an instance of {}. ".format(parallel_type) + 75 | "Parallel wrappers should only be applied to the model(s) AFTER \n" 76 | "the model(s) have been returned from amp.initialize.") 77 | 78 | 79 | def check_params_fp32(models): 80 | for model in models: 81 | for name, param in model.named_parameters(): 82 | if param.is_floating_point(): 83 | if 'Half' in param.type(): 84 | warn_or_err("Found param {} with type {}, expected torch.cuda.FloatTensor.\n" 85 | "When using amp.initialize, you do not need to call .half() on your model\n" 86 | "before passing it, no matter what optimization level you choose.".format( 87 | name, param.type())) 88 | elif not param.is_cuda: 89 | warn_or_err("Found param {} with type {}, expected torch.cuda.FloatTensor.\n" 90 | "When using amp.initialize, you need to provide a model with parameters\n" 91 | "located on a CUDA device before passing it no matter what optimization level\n" 92 | "you chose. 
Use model.to('cuda') to use the default device.".format( 93 | name, param.type())) 94 | 95 | # Backward compatibility for PyTorch 0.4 96 | if hasattr(model, 'named_buffers'): 97 | buf_iter = model.named_buffers() 98 | else: 99 | buf_iter = model._buffers 100 | for obj in buf_iter: 101 | if type(obj)==tuple: 102 | name, buf = obj 103 | else: 104 | name, buf = obj, buf_iter[obj] 105 | if buf.is_floating_point(): 106 | if 'Half' in buf.type(): 107 | warn_or_err("Found buffer {} with type {}, expected torch.cuda.FloatTensor.\n" 108 | "When using amp.initialize, you do not need to call .half() on your model\n" 109 | "before passing it, no matter what optimization level you choose.".format( 110 | name, buf.type())) 111 | elif not buf.is_cuda: 112 | warn_or_err("Found buffer {} with type {}, expected torch.cuda.FloatTensor.\n" 113 | "When using amp.initialize, you need to provide a model with buffers\n" 114 | "located on a CUDA device before passing it no matter what optimization level\n" 115 | "you chose. Use model.to('cuda') to use the default device.".format( 116 | name, buf.type())) 117 | 118 | 119 | def check_optimizers(optimizers): 120 | for optim in optimizers: 121 | bad_optim_type = None 122 | if isinstance(optim, FP16_Optimizer_general): 123 | bad_optim_type = "apex.fp16_utils.FP16_Optimizer" 124 | if isinstance(optim, FP16_Optimizer_for_fused): 125 | bad_optim_type = "apex.optimizers.FP16_Optimizer" 126 | if bad_optim_type is not None: 127 | raise RuntimeError("An incoming optimizer is an instance of {}. ".format(bad_optim_type) + 128 | "The optimizer(s) passed to amp.initialize() must be bare \n" 129 | "instances of either ordinary Pytorch optimizers, or Apex fused \n" 130 | "optimizers.\n") 131 | 132 | 133 | class O2StateDictHook(object): 134 | def __init__(self, fn): 135 | self.fn = fn 136 | 137 | def __call__(self, module, state_dict, prefix, local_metadata): 138 | for key in state_dict: 139 | param = state_dict[key] 140 | if 'Half' in param.type(): 141 | param = param.to(torch.float32) 142 | state_dict[key] = param 143 | 144 | 145 | def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs=None): 146 | from .amp import init as amp_init 147 | 148 | optimizers_was_list = False 149 | if isinstance(optimizers, torch.optim.Optimizer) or ('LARC' in globals() and isinstance(optimizers, LARC)): 150 | optimizers = [optimizers] 151 | elif optimizers is None: 152 | optimizers = [] 153 | elif isinstance(optimizers, list): 154 | optimizers_was_list = True 155 | check_optimizers(optimizers) 156 | else: 157 | check_optimizers([optimizers]) 158 | raise TypeError("optimizers must be either a single optimizer or a list of optimizers.") 159 | 160 | if isinstance(models, torch.nn.Module): 161 | models_was_list = False 162 | models = [models] 163 | elif isinstance(models, list): 164 | models_was_list = True 165 | else: 166 | raise TypeError("models must be either a single model or a list of models.") 167 | 168 | check_models(models) 169 | 170 | if not _amp_state.allow_incoming_model_not_fp32: 171 | check_params_fp32(models) 172 | 173 | # In the future, when FP16_Optimizer can be deprecated and master weights can 174 | # become an attribute, remember to stash master weights before casting the model. 
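    # Added note (descriptive comment): when cast_model_type is set, the model itself is cast
    # to that dtype (optionally keeping batchnorm in fp32), forward() is patched to cast inputs
    # on the way in and outputs back to fp32 (or cast_model_outputs) on the way out, and any
    # preexisting optimizer state is recast via the load_state_dict round-trip below.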
175 | 176 | if properties.cast_model_type: 177 | if properties.keep_batchnorm_fp32: 178 | for model in models: 179 | convert_network(model, properties.cast_model_type) 180 | else: 181 | for model in models: 182 | model.to(properties.cast_model_type) 183 | 184 | input_caster = functools.partial(to_type, properties.cast_model_type) 185 | if cast_model_outputs is not None: 186 | output_caster = functools.partial(to_type, cast_model_outputs) 187 | else: 188 | output_caster = functools.partial(to_type, torch.float32) 189 | 190 | for model in models: 191 | # Patch the forward method to cast incoming data to the correct type, and 192 | # outgoing data to float32, so "the user never needs to call .half()." 193 | # I like writing things explicitly more than decorators. 194 | def patch_forward(old_fwd): 195 | def new_fwd(*args, **kwargs): 196 | output = old_fwd(*applier(args, input_caster), 197 | **applier(kwargs, input_caster)) 198 | return applier(output, output_caster) 199 | return new_fwd 200 | 201 | model.forward = patch_forward(model.forward) 202 | 203 | # State dict trick to recast any preexisting per-param state tensors 204 | for optimizer in optimizers: 205 | optimizer.load_state_dict(optimizer.state_dict()) 206 | 207 | # patch model.state_dict() to return float32 params 208 | for model in models: 209 | for module in model.modules(): 210 | module._register_state_dict_hook(O2StateDictHook(functools.partial(to_type, torch.float32))) 211 | 212 | elif cast_model_outputs is not None: 213 | output_caster = functools.partial(to_type, cast_model_outputs) 214 | 215 | for model in models: 216 | def patch_forward(old_fwd): 217 | def new_fwd(*args, **kwargs): 218 | output = old_fwd(*args, **kwargs) 219 | return applier(output, output_caster) 220 | return new_fwd 221 | 222 | model.forward = patch_forward(model.forward) 223 | 224 | for i, optimizer in enumerate(optimizers): 225 | optimizers[i] = _process_optimizer(optimizer, properties) 226 | 227 | _amp_state.loss_scalers = [] 228 | for _ in range(num_losses): 229 | _amp_state.loss_scalers.append(LossScaler(properties.loss_scale, 230 | min_loss_scale=_amp_state.min_loss_scale, 231 | max_loss_scale=_amp_state.max_loss_scale)) 232 | 233 | if properties.patch_torch_functions: 234 | # handle is unused here. It's accessible later through a global value anyway. 235 | handle = amp_init(loss_scale=properties.loss_scale, verbose=(_amp_state.verbosity == 2)) 236 | for optimizer in optimizers: 237 | # Disable Amp casting for the optimizer step, because it should only be 238 | # applied to FP32 master params anyway. 239 | def patch_step(old_step): 240 | def new_step(self, *args, **kwargs): 241 | with disable_casts(): 242 | output = old_step(*args, **kwargs) 243 | return output 244 | return new_step 245 | 246 | optimizer.step = MethodType(patch_step(optimizer.step), optimizer) 247 | 248 | if optimizers_was_list: 249 | if models_was_list: 250 | return models, optimizers 251 | else: 252 | return models[0], optimizers 253 | else: 254 | if models_was_list: 255 | if len(optimizers) == 0: 256 | return models 257 | else: 258 | return models, optimizers[0] 259 | else: 260 | if len(optimizers) == 0: 261 | return models[0] 262 | else: 263 | return models[0], optimizers[0] 264 | -------------------------------------------------------------------------------- /apex/amp/handle.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import warnings 3 | import sys 4 | import torch 5 | 6 | from . 
import utils 7 | from .opt import OptimWrapper 8 | from .scaler import LossScaler 9 | from ._amp_state import _amp_state, master_params, maybe_print 10 | 11 | if torch.distributed.is_available(): 12 | from ..parallel.LARC import LARC 13 | 14 | 15 | # There's no reason to expose the notion of a "handle". Everything can happen through amp.* calls. 16 | @contextlib.contextmanager 17 | def scale_loss(loss, 18 | optimizers, 19 | loss_id=0, 20 | model=None, 21 | delay_unscale=False, 22 | delay_overflow_check=False): 23 | """ 24 | On context manager entrance, creates ``scaled_loss = (loss.float())*current loss scale``. 25 | ``scaled_loss`` is yielded so that the user can call ``scaled_loss.backward()``:: 26 | 27 | with amp.scale_loss(loss, optimizer) as scaled_loss: 28 | scaled_loss.backward() 29 | 30 | On context manager exit (if ``delay_unscale=False``), the gradients are checked for infs/NaNs 31 | and unscaled, so that ``optimizer.step()`` can be called. 32 | 33 | .. note:: 34 | If Amp is using explicit FP32 master params (which is the default for ``opt_level=O2``, and 35 | can also be manually enabled by supplying ``master_weights=True`` to ``amp.initialize``) 36 | any FP16 gradients are copied to FP32 master gradients before being unscaled. 37 | ``optimizer.step()`` will then apply the unscaled master gradients to the master params. 38 | 39 | .. warning:: 40 | If Amp is using explicit FP32 master params, only the FP32 master gradients will be 41 | unscaled. The direct ``.grad`` attributes of any FP16 42 | model params will remain scaled after context manager exit. 43 | This subtlety affects gradient clipping. See "Gradient clipping" under 44 | `Advanced Amp Usage`_ for best practices. 45 | 46 | Args: 47 | loss(Tensor): Typically a scalar Tensor. The ``scaled_loss`` that the context 48 | manager yields is simply ``loss.float()*loss_scale``, so in principle 49 | ``loss`` could have more than one element, as long as you call 50 | ``backward()`` on ``scaled_loss`` appropriately within the context manager body. 51 | optimizers: All optimizer(s) for which the current backward pass is creating gradients. 52 | Must be an optimizer or list of optimizers returned from an earlier call 53 | to ``amp.initialize``. For example use with multiple optimizers, see 54 | "Multiple models/optimizers/losses" under `Advanced Amp Usage`_. 55 | loss_id(int, optional, default=0): When used in conjunction with the ``num_losses`` argument 56 | to ``amp.initialize``, enables Amp to use a different loss scale per loss. ``loss_id`` 57 | must be an integer between 0 and ``num_losses`` that tells Amp which loss is 58 | being used for the current backward pass. See "Multiple models/optimizers/losses" 59 | under `Advanced Amp Usage`_ for examples. If ``loss_id`` is left unspecified, Amp 60 | will use the default global loss scaler for this backward pass. 61 | model(torch.nn.Module, optional, default=None): Currently unused, reserved to enable future 62 | optimizations. 63 | delay_unscale(bool, optional, default=False): ``delay_unscale`` is never necessary, and 64 | the default value of ``False`` is strongly recommended. 65 | If ``True``, Amp will not unscale the gradients or perform model->master 66 | gradient copies on context manager exit. 67 | ``delay_unscale=True`` is a minor ninja performance optimization and can result 68 | in weird gotchas (especially with multiple models/optimizers/losses), 69 | so only use it if you know what you're doing. 
70 | "Gradient accumulation across iterations" under `Advanced Amp Usage`_ 71 | illustrates a situation where this CAN (but does not need to) be used. 72 | 73 | .. warning:: 74 | If ``delay_unscale`` is ``True`` for a given backward pass, ``optimizer.step()`` cannot be 75 | called yet after context manager exit, and must wait for another, later backward context 76 | manager invocation with ``delay_unscale`` left to False. 77 | 78 | .. _`Advanced Amp Usage`: 79 | https://nvidia.github.io/apex/advanced.html 80 | """ 81 | if not hasattr(_amp_state, "opt_properties"): 82 | raise RuntimeError("Invoked 'with amp.scale_loss`, but internal Amp state has not been initialized. " 83 | "model, optimizer = amp.initialize(model, optimizer, opt_level=...) must be called " 84 | "before `with amp.scale_loss`.") 85 | 86 | if not _amp_state.opt_properties.enabled: 87 | yield loss 88 | return 89 | 90 | if isinstance(optimizers, torch.optim.Optimizer) or ('LARC' in globals() and isinstance(optimizers, LARC)): 91 | optimizers = [optimizers] 92 | 93 | loss_scaler = _amp_state.loss_scalers[loss_id] 94 | loss_scale = loss_scaler.loss_scale() 95 | 96 | if ((not _amp_state.opt_properties.master_weights) 97 | and (not loss_scaler.dynamic) 98 | and loss_scale == 1.0): 99 | yield loss.float() 100 | # Needing to drop the cache here as well is an ugly gotcha. 101 | # But for now I think it's necessary to short-circuit. 102 | # Probably ok to skip this if not delay_unscale 103 | if _amp_state.opt_properties.patch_torch_functions: 104 | _amp_state.handle._clear_cache() 105 | return 106 | 107 | if not delay_unscale: 108 | if isinstance(optimizers, list): 109 | for optimizer in optimizers: 110 | if not optimizer._amp_stash.params_have_scaled_gradients: 111 | optimizer._prepare_amp_backward() 112 | 113 | yield (loss.float())*loss_scale 114 | 115 | if delay_unscale: 116 | for optimizer in optimizers: 117 | optimizer._amp_stash.params_have_scaled_gradients = True 118 | else: 119 | # FusedSGD may take care of unscaling as part of their step() methods. 120 | # if not isinstance(optimizers, FP16_Optimizer_for_fused): 121 | loss_scaler.clear_overflow_state() 122 | for optimizer in optimizers: 123 | optimizer._post_amp_backward(loss_scaler) 124 | optimizer._amp_stash.params_have_scaled_gradients = False 125 | # For future fused optimizers that enable sync-free dynamic loss scaling, 126 | # should_skip will always be False. 127 | should_skip = False if delay_overflow_check else loss_scaler.update_scale() 128 | if should_skip: 129 | for optimizer in optimizers: 130 | if not optimizer._amp_stash.already_patched: 131 | # Close on loss_scaler and loss_id as well, to be safe. Probably not 132 | # necessary because amp.scale_loss is already creating a temporary scope. 133 | def patch_step(opt, loss_scaler, loss_id): 134 | opt_step = opt.step 135 | def skip_step(closure=None): 136 | if closure is not None: 137 | raise RuntimeError("Currently, Amp does not support closure use with optimizers.") 138 | maybe_print(("Gradient overflow. Skipping step, loss scaler " + 139 | "{} reducing loss scale to {}").format(loss_id, 140 | loss_scaler.loss_scale())) 141 | # TODO: I don't like the special casing for different optimizer implementations. 142 | # Maybe skip should delegate to a method owned by the optimizers themselves. 
143 | if hasattr(opt._amp_stash, "all_fp32_from_fp16_params"): 144 | # Clear the master grads that wouldn't be zeroed by model.zero_grad() 145 | for param in opt._amp_stash.all_fp32_from_fp16_params: 146 | param.grad = None 147 | if hasattr(opt, "most_recent_scale"): 148 | opt.most_recent_scale = 1.0 149 | opt.scale_set_by_backward = False 150 | opt.step = opt_step 151 | opt._amp_stash.already_patched = False 152 | return skip_step 153 | optimizer.step = patch_step(optimizer, loss_scaler, loss_id) 154 | optimizer._amp_stash.already_patched = True 155 | 156 | # Probably ok to skip this if not delay_unscale 157 | if _amp_state.opt_properties.patch_torch_functions: 158 | _amp_state.handle._clear_cache() 159 | 160 | 161 | # Free function version of AmpHandle.disable_casts, another step on the 162 | # path to removing the concept of "AmpHandle" 163 | @contextlib.contextmanager 164 | def disable_casts(): 165 | _amp_state.handle._is_active = False 166 | yield 167 | _amp_state.handle._is_active = True 168 | 169 | 170 | class AmpHandle(object): 171 | def __init__(self, loss_scale="dynamic", enable_caching=True, verbose=False): 172 | self._enable_caching = enable_caching 173 | self._verbose = verbose 174 | self._cache = dict() 175 | self._default_scaler = LossScaler(loss_scale) 176 | self._is_active = True 177 | self._all_wrappers = [] 178 | 179 | def is_active(self): 180 | return self._is_active 181 | 182 | @contextlib.contextmanager 183 | def _disable_casts(self): 184 | self._is_active = False 185 | yield 186 | self._is_active = True 187 | 188 | def wrap_optimizer(self, optimizer, num_loss=1): 189 | self._default_scaler = None 190 | return OptimWrapper(optimizer, self, num_loss) 191 | 192 | @contextlib.contextmanager 193 | def scale_loss(self, loss, optimizer): 194 | raise RuntimeError("The old Amp API is no longer supported. Please move to the new API, " 195 | "documented here: https://nvidia.github.io/apex/amp.html. Transition guide: " 196 | "https://nvidia.github.io/apex/amp.html#transition-guide-for-old-api-users") 197 | 198 | if not self.is_active(): 199 | yield loss 200 | return 201 | 202 | if self._default_scaler is None: 203 | raise RuntimeError( 204 | 'After calling `handle.wrap_optimizer()`, you must explicitly ' + 205 | 'use `optimizer.scale_loss(loss)`.') 206 | 207 | # TODO: this code block is duplicated here and `opt.py`. Unify. 
208 | loss_scale = self._default_scaler.loss_scale() 209 | yield loss * loss_scale 210 | 211 | self._default_scaler.clear_overflow_state() 212 | self._default_scaler.unscale( 213 | master_params(optimizer), 214 | master_params(optimizer), 215 | loss_scale) 216 | should_skip = self._default_scaler.update_scale() 217 | if should_skip: 218 | optimizer_step = optimizer.step 219 | def skip_step(): 220 | maybe_print('Gradient overflow, skipping update') 221 | optimizer.step = optimizer_step 222 | optimizer.step = skip_step 223 | 224 | self._clear_cache() 225 | 226 | def _clear_cache(self): 227 | self._cache.clear() 228 | 229 | # Experimental support for saving / restoring uncasted versions of functions 230 | def _save_func(self, mod, fn, func): 231 | self._all_wrappers.append((mod, fn, func)) 232 | 233 | def _deactivate(self): 234 | for mod, fn, func in self._all_wrappers: 235 | utils.set_func(mod, fn, func) 236 | self._all_wrappers = [] 237 | 238 | @property 239 | def has_cache(self): 240 | return self._enable_caching 241 | 242 | @property 243 | def cache(self): 244 | return self._cache 245 | 246 | def remove_cache(self, param): 247 | if self.has_cache and param in self.cache: 248 | del self.cache[param] 249 | 250 | @property 251 | def verbose(self): 252 | return self._verbose 253 | 254 | class NoOpHandle(object): 255 | def is_active(self): 256 | return False 257 | 258 | @contextlib.contextmanager 259 | def _disable_casts(self): 260 | yield 261 | 262 | def wrap_optimizer(self, optimizer, num_loss=1): 263 | return OptimWrapper(optimizer, self, num_loss) 264 | 265 | @contextlib.contextmanager 266 | def scale_loss(self, loss, optimizer): 267 | yield loss 268 | 269 | @property 270 | def has_cache(self): 271 | return False 272 | 273 | @property 274 | def verbose(self): 275 | return False 276 | 277 | def _clear_cache(self): 278 | pass 279 | 280 | def _deactivate(self): 281 | pass 282 | -------------------------------------------------------------------------------- /models/model_ViT.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import copy 7 | import logging 8 | import math 9 | 10 | from os.path import join as pjoin 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | import numpy as np 16 | 17 | from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm 18 | from torch.nn.modules.utils import _pair 19 | from scipy import ndimage 20 | 21 | import models.config as config 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | ATTENTION_Q = "MultiHeadDotProductAttention_1/query" 26 | ATTENTION_K = "MultiHeadDotProductAttention_1/key" 27 | ATTENTION_V = "MultiHeadDotProductAttention_1/value" 28 | ATTENTION_OUT = "MultiHeadDotProductAttention_1/out" 29 | FC_0 = "MlpBlock_3/Dense_0" 30 | FC_1 = "MlpBlock_3/Dense_1" 31 | ATTENTION_NORM = "LayerNorm_0" 32 | MLP_NORM = "LayerNorm_2" 33 | 34 | 35 | def np2th(weights, conv=False): 36 | """Possibly convert HWIO to OIHW.""" 37 | if conv: 38 | weights = weights.transpose([3, 2, 0, 1]) 39 | return torch.from_numpy(weights) 40 | 41 | 42 | def swish(x): 43 | return x * torch.sigmoid(x) 44 | 45 | 46 | ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish} 47 | 48 | 49 | class Attention(nn.Module): 50 | def __init__(self, config): 51 | super(Attention, self).__init__() 52 | 
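        # Added note (descriptive comment): standard multi-head self-attention. hidden_size is
        # split evenly across the heads, so each head operates on a hidden_size // num_heads subspace.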
self.num_attention_heads = config.transformer["num_heads"] 53 | self.attention_head_size = int(config.hidden_size / self.num_attention_heads) 54 | self.all_head_size = self.num_attention_heads * self.attention_head_size 55 | 56 | self.query = Linear(config.hidden_size, self.all_head_size) 57 | self.key = Linear(config.hidden_size, self.all_head_size) 58 | self.value = Linear(config.hidden_size, self.all_head_size) 59 | 60 | self.out = Linear(config.hidden_size, config.hidden_size) 61 | self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"]) 62 | self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"]) 63 | 64 | self.softmax = Softmax(dim=-1) 65 | 66 | def transpose_for_scores(self, x): 67 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 68 | x = x.view(*new_x_shape) 69 | return x.permute(0, 2, 1, 3) 70 | 71 | def forward(self, hidden_states): 72 | mixed_query_layer = self.query(hidden_states) 73 | mixed_key_layer = self.key(hidden_states) 74 | mixed_value_layer = self.value(hidden_states) 75 | 76 | query_layer = self.transpose_for_scores(mixed_query_layer) 77 | key_layer = self.transpose_for_scores(mixed_key_layer) 78 | value_layer = self.transpose_for_scores(mixed_value_layer) 79 | 80 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 81 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 82 | attention_probs = self.softmax(attention_scores) 83 | weights = attention_probs 84 | attention_probs = self.attn_dropout(attention_probs) 85 | 86 | context_layer = torch.matmul(attention_probs, value_layer) 87 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 88 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 89 | context_layer = context_layer.view(*new_context_layer_shape) 90 | attention_output = self.out(context_layer) 91 | attention_output = self.proj_dropout(attention_output) 92 | return attention_output, weights 93 | 94 | 95 | class Mlp(nn.Module): 96 | def __init__(self, config): 97 | super(Mlp, self).__init__() 98 | self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"]) 99 | self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size) 100 | self.act_fn = ACT2FN["gelu"] 101 | self.dropout = Dropout(config.transformer["dropout_rate"]) 102 | 103 | self._init_weights() 104 | 105 | def _init_weights(self): 106 | nn.init.xavier_uniform_(self.fc1.weight) 107 | nn.init.xavier_uniform_(self.fc2.weight) 108 | nn.init.normal_(self.fc1.bias, std=1e-6) 109 | nn.init.normal_(self.fc2.bias, std=1e-6) 110 | 111 | def forward(self, x): 112 | x = self.fc1(x) 113 | x = self.act_fn(x) 114 | x = self.dropout(x) 115 | x = self.fc2(x) 116 | x = self.dropout(x) 117 | return x 118 | 119 | 120 | class Embeddings(nn.Module): 121 | """Construct the embeddings from patch, position embeddings. 
122 | """ 123 | 124 | def __init__(self, config, img_size, in_channels=3): 125 | super(Embeddings, self).__init__() 126 | self.hybrid = None 127 | img_size = _pair(img_size) 128 | 129 | if config.patches.get("grid") is not None: 130 | grid_size = config.patches["grid"] 131 | patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1]) 132 | n_patches = (img_size[0] // 16) * (img_size[1] // 16) 133 | self.hybrid = True 134 | else: 135 | patch_size = _pair(config.patches["size"]) 136 | n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1]) 137 | self.hybrid = False 138 | 139 | if self.hybrid: 140 | self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers, 141 | width_factor=config.resnet.width_factor) 142 | in_channels = self.hybrid_model.width * 16 143 | self.patch_embeddings = Conv2d(in_channels=in_channels, 144 | out_channels=config.hidden_size, 145 | kernel_size=patch_size, 146 | stride=patch_size) 147 | self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches + 1, config.hidden_size)) 148 | self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) 149 | 150 | self.dropout = Dropout(config.transformer["dropout_rate"]) 151 | 152 | def forward(self, x): 153 | B = x.shape[0] 154 | cls_tokens = self.cls_token.expand(B, -1, -1) 155 | 156 | if self.hybrid: 157 | x = self.hybrid_model(x) 158 | x = self.patch_embeddings(x) 159 | x = x.flatten(2) 160 | x = x.transpose(-1, -2) 161 | x = torch.cat((cls_tokens, x), dim=1) 162 | 163 | embeddings = x + self.position_embeddings 164 | embeddings = self.dropout(embeddings) 165 | return embeddings 166 | 167 | 168 | class Block(nn.Module): 169 | def __init__(self, config): 170 | super(Block, self).__init__() 171 | self.hidden_size = config.hidden_size 172 | self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6) 173 | self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6) 174 | self.ffn = Mlp(config) 175 | self.attn = Attention(config) 176 | 177 | def forward(self, x): 178 | h = x 179 | x = self.attention_norm(x) 180 | x, weights = self.attn(x) 181 | x = x + h 182 | 183 | h = x 184 | x = self.ffn_norm(x) 185 | x = self.ffn(x) 186 | x = x + h 187 | return x, weights 188 | 189 | def load_from(self, weights, n_block): 190 | ROOT = f"Transformer/encoderblock_{n_block}" 191 | with torch.no_grad(): 192 | query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, 193 | self.hidden_size).t() 194 | key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t() 195 | value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, 196 | self.hidden_size).t() 197 | out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, 198 | self.hidden_size).t() 199 | 200 | query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1) 201 | key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1) 202 | value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1) 203 | out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1) 204 | 205 | self.attn.query.weight.copy_(query_weight) 206 | self.attn.key.weight.copy_(key_weight) 207 | self.attn.value.weight.copy_(value_weight) 208 | self.attn.out.weight.copy_(out_weight) 209 | self.attn.query.bias.copy_(query_bias) 210 | self.attn.key.bias.copy_(key_bias) 211 | self.attn.value.bias.copy_(value_bias) 212 | self.attn.out.bias.copy_(out_bias) 213 | 214 | mlp_weight_0 = 
np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t() 215 | mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t() 216 | mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t() 217 | mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t() 218 | 219 | self.ffn.fc1.weight.copy_(mlp_weight_0) 220 | self.ffn.fc2.weight.copy_(mlp_weight_1) 221 | self.ffn.fc1.bias.copy_(mlp_bias_0) 222 | self.ffn.fc2.bias.copy_(mlp_bias_1) 223 | 224 | self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")])) 225 | self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")])) 226 | self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")])) 227 | self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")])) 228 | 229 | 230 | class Encoder(nn.Module): 231 | def __init__(self, config): 232 | super(Encoder, self).__init__() 233 | self.layer = nn.ModuleList() 234 | self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6) 235 | for _ in range(config.transformer["num_layers"]): 236 | layer = Block(config) 237 | self.layer.append(copy.deepcopy(layer)) 238 | 239 | def forward(self, hidden_states): 240 | attn_weights = [] 241 | for layer_block in self.layer: 242 | hidden_states, weights = layer_block(hidden_states) 243 | attn_weights.append(weights) 244 | 245 | encoded = self.encoder_norm(hidden_states) 246 | return encoded, attn_weights 247 | 248 | 249 | class Transformer(nn.Module): 250 | def __init__(self, config, img_size): 251 | super(Transformer, self).__init__() 252 | self.embeddings = Embeddings(config, img_size=img_size) 253 | self.encoder = Encoder(config) 254 | 255 | def forward(self, input_ids): 256 | embedding_output = self.embeddings(input_ids) 257 | encoded, attn_weights = self.encoder(embedding_output) 258 | return encoded, attn_weights 259 | 260 | 261 | class VisionTransformer(nn.Module): 262 | def __init__(self, config, img_size=224, num_classes=21843, zero_head=False): 263 | super(VisionTransformer, self).__init__() 264 | self.num_classes = num_classes 265 | self.zero_head = zero_head 266 | self.classifier = config.classifier 267 | 268 | self.transformer = Transformer(config, img_size) 269 | self.head = Linear(config.hidden_size, num_classes) 270 | 271 | def forward(self, x, labels=None): 272 | x, attn_weights = self.transformer(x) 273 | logits = self.head(x[:, 0]) 274 | 275 | if labels is not None: 276 | loss_fct = CrossEntropyLoss() 277 | loss = loss_fct(logits.view(-1, self.num_classes), labels.view(-1)) 278 | return loss, logits 279 | else: 280 | return logits, attn_weights 281 | 282 | def load_from(self, weights): 283 | with torch.no_grad(): 284 | if self.zero_head: 285 | nn.init.zeros_(self.head.weight) 286 | nn.init.zeros_(self.head.bias) 287 | else: 288 | self.head.weight.copy_(np2th(weights["head/kernel"]).t()) 289 | self.head.bias.copy_(np2th(weights["head/bias"]).t()) 290 | 291 | self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True)) 292 | self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"])) 293 | self.transformer.embeddings.cls_token.copy_(np2th(weights["cls"])) 294 | self.transformer.encoder.encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"])) 295 | self.transformer.encoder.encoder_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"])) 296 | 297 | posemb = np2th(weights["Transformer/posembed_input/pos_embedding"]) 298 | posemb_new = self.transformer.embeddings.position_embeddings 299 | 
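            # Added note (descriptive comment): if the checkpoint was pretrained at a different
            # resolution, the grid of patch position embeddings is interpolated (ndimage.zoom,
            # order=1) below to match the new number of patches; the class-token embedding is kept as-is.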
if posemb.size() == posemb_new.size(): 300 | self.transformer.embeddings.position_embeddings.copy_(posemb) 301 | else: 302 | logger.info("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size())) 303 | ntok_new = posemb_new.size(1) 304 | 305 | if self.classifier == "token": 306 | posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:] 307 | ntok_new -= 1 308 | else: 309 | posemb_tok, posemb_grid = posemb[:, :0], posemb[0] 310 | 311 | gs_old = int(np.sqrt(len(posemb_grid))) 312 | gs_new = int(np.sqrt(ntok_new)) 313 | print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new)) 314 | posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1) 315 | 316 | zoom = (gs_new / gs_old, gs_new / gs_old, 1) 317 | posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1) 318 | posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1) 319 | posemb = np.concatenate([posemb_tok, posemb_grid], axis=1) 320 | self.transformer.embeddings.position_embeddings.copy_(np2th(posemb)) 321 | 322 | for bname, block in self.transformer.encoder.named_children(): 323 | for uname, unit in block.named_children(): 324 | unit.load_from(weights, n_block=uname) 325 | 326 | if self.transformer.embeddings.hybrid: 327 | self.transformer.embeddings.hybrid_model.root.conv.weight.copy_( 328 | np2th(weights["conv_root/kernel"], conv=True)) 329 | gn_weight = np2th(weights["gn_root/scale"]).view(-1) 330 | gn_bias = np2th(weights["gn_root/bias"]).view(-1) 331 | self.transformer.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight) 332 | self.transformer.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias) 333 | 334 | for bname, block in self.transformer.embeddings.hybrid_model.body.named_children(): 335 | for uname, unit in block.named_children(): 336 | unit.load_from(weights, n_block=bname, n_unit=uname) 337 | 338 | 339 | CONFIGS = { 340 | 'ViT-B_16': config.get_b16_config(), 341 | 'ViT-B_32': config.get_b32_config(), 342 | 'ViT-L_16': config.get_l16_config(), 343 | 'ViT-L_32': config.get_l32_config(), 344 | 'ViT-H_14': config.get_h14_config(), 345 | } 346 | -------------------------------------------------------------------------------- /models/model_TransFG.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import copy 7 | import logging 8 | import math 9 | 10 | from os.path import join as pjoin 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | import numpy as np 16 | 17 | from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm 18 | from torch.nn.modules.utils import _pair 19 | from scipy import ndimage 20 | 21 | import models.config as configs 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | ATTENTION_Q = "MultiHeadDotProductAttention_1/query" 26 | ATTENTION_K = "MultiHeadDotProductAttention_1/key" 27 | ATTENTION_V = "MultiHeadDotProductAttention_1/value" 28 | ATTENTION_OUT = "MultiHeadDotProductAttention_1/out" 29 | FC_0 = "MlpBlock_3/Dense_0" 30 | FC_1 = "MlpBlock_3/Dense_1" 31 | ATTENTION_NORM = "LayerNorm_0" 32 | MLP_NORM = "LayerNorm_2" 33 | 34 | 35 | def np2th(weights, conv=False): 36 | """Possibly convert HWIO to OIHW.""" 37 | if conv: 38 | weights = weights.transpose([3, 2, 0, 1]) 39 | return torch.from_numpy(weights) 40 | 41 | 42 | def swish(x): 43 | return x * torch.sigmoid(x) 44 | 45 | 46 | ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": 
torch.nn.functional.relu, "swish": swish} 47 | 48 | 49 | class LabelSmoothing(nn.Module): 50 | """ 51 | NLL loss with label smoothing. 52 | """ 53 | 54 | def __init__(self, smoothing=0.0): 55 | """ 56 | Constructor for the LabelSmoothing module. 57 | :param smoothing: label smoothing factor 58 | """ 59 | super(LabelSmoothing, self).__init__() 60 | self.confidence = 1.0 - smoothing 61 | self.smoothing = smoothing 62 | 63 | def forward(self, x, target): 64 | logprobs = torch.nn.functional.log_softmax(x, dim=-1) 65 | 66 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) 67 | nll_loss = nll_loss.squeeze(1) 68 | smooth_loss = -logprobs.mean(dim=-1) 69 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss 70 | return loss.mean() 71 | 72 | 73 | class Attention(nn.Module): 74 | def __init__(self, config): 75 | super(Attention, self).__init__() 76 | self.num_attention_heads = config.transformer["num_heads"] 77 | self.attention_head_size = int(config.hidden_size / self.num_attention_heads) 78 | self.all_head_size = self.num_attention_heads * self.attention_head_size 79 | 80 | self.query = Linear(config.hidden_size, self.all_head_size) 81 | self.key = Linear(config.hidden_size, self.all_head_size) 82 | self.value = Linear(config.hidden_size, self.all_head_size) 83 | 84 | self.out = Linear(config.hidden_size, config.hidden_size) 85 | self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"]) 86 | self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"]) 87 | 88 | self.softmax = Softmax(dim=-1) 89 | 90 | def transpose_for_scores(self, x): 91 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 92 | x = x.view(*new_x_shape) 93 | return x.permute(0, 2, 1, 3) 94 | 95 | def forward(self, hidden_states): 96 | mixed_query_layer = self.query(hidden_states) 97 | mixed_key_layer = self.key(hidden_states) 98 | mixed_value_layer = self.value(hidden_states) 99 | 100 | query_layer = self.transpose_for_scores(mixed_query_layer) 101 | key_layer = self.transpose_for_scores(mixed_key_layer) 102 | value_layer = self.transpose_for_scores(mixed_value_layer) 103 | 104 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 105 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 106 | attention_probs = self.softmax(attention_scores) 107 | weights = attention_probs 108 | attention_probs = self.attn_dropout(attention_probs) 109 | 110 | context_layer = torch.matmul(attention_probs, value_layer) 111 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 112 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 113 | context_layer = context_layer.view(*new_context_layer_shape) 114 | attention_output = self.out(context_layer) 115 | attention_output = self.proj_dropout(attention_output) 116 | return attention_output, weights 117 | 118 | 119 | class Mlp(nn.Module): 120 | def __init__(self, config): 121 | super(Mlp, self).__init__() 122 | self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"]) 123 | self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size) 124 | self.act_fn = ACT2FN["gelu"] 125 | self.dropout = Dropout(config.transformer["dropout_rate"]) 126 | 127 | self._init_weights() 128 | 129 | def _init_weights(self): 130 | nn.init.xavier_uniform_(self.fc1.weight) 131 | nn.init.xavier_uniform_(self.fc2.weight) 132 | nn.init.normal_(self.fc1.bias, std=1e-6) 133 | nn.init.normal_(self.fc2.bias, std=1e-6) 134 | 135 | def forward(self, x): 136 
| x = self.fc1(x) 137 | x = self.act_fn(x) 138 | x = self.dropout(x) 139 | x = self.fc2(x) 140 | x = self.dropout(x) 141 | return x 142 | 143 | 144 | class Embeddings(nn.Module): 145 | """Construct the embeddings from patch, position embeddings. 146 | """ 147 | 148 | def __init__(self, config, img_size, in_channels=3): 149 | super(Embeddings, self).__init__() 150 | self.hybrid = None 151 | img_size = _pair(img_size) 152 | 153 | patch_size = _pair(config.patches["size"]) 154 | if config.split == 'non-overlap': 155 | n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1]) 156 | self.patch_embeddings = Conv2d(in_channels=in_channels, 157 | out_channels=config.hidden_size, 158 | kernel_size=patch_size, 159 | stride=patch_size) 160 | elif config.split == 'overlap': 161 | n_patches = ((img_size[0] - patch_size[0]) // config.slide_step + 1) * ( 162 | (img_size[1] - patch_size[1]) // config.slide_step + 1) 163 | self.patch_embeddings = Conv2d(in_channels=in_channels, 164 | out_channels=config.hidden_size, 165 | kernel_size=patch_size, 166 | stride=(config.slide_step, config.slide_step)) 167 | self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches + 1, config.hidden_size)) 168 | self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) 169 | 170 | self.dropout = Dropout(config.transformer["dropout_rate"]) 171 | 172 | def forward(self, x): 173 | B = x.shape[0] 174 | cls_tokens = self.cls_token.expand(B, -1, -1) 175 | 176 | if self.hybrid: 177 | x = self.hybrid_model(x) 178 | x = self.patch_embeddings(x) 179 | x = x.flatten(2) 180 | x = x.transpose(-1, -2) 181 | x = torch.cat((cls_tokens, x), dim=1) 182 | 183 | embeddings = x + self.position_embeddings 184 | embeddings = self.dropout(embeddings) 185 | return embeddings 186 | 187 | 188 | class Block(nn.Module): 189 | def __init__(self, config): 190 | super(Block, self).__init__() 191 | self.hidden_size = config.hidden_size 192 | self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6) 193 | self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6) 194 | self.ffn = Mlp(config) 195 | self.attn = Attention(config) 196 | 197 | def forward(self, x): 198 | h = x 199 | x = self.attention_norm(x) 200 | x, weights = self.attn(x) 201 | x = x + h 202 | 203 | h = x 204 | x = self.ffn_norm(x) 205 | x = self.ffn(x) 206 | x = x + h 207 | return x, weights 208 | 209 | def load_from(self, weights, n_block): 210 | ROOT = f"Transformer/encoderblock_{n_block}" 211 | with torch.no_grad(): 212 | query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, 213 | self.hidden_size).t() 214 | key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t() 215 | value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, 216 | self.hidden_size).t() 217 | out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, 218 | self.hidden_size).t() 219 | 220 | query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1) 221 | key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1) 222 | value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1) 223 | out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1) 224 | 225 | self.attn.query.weight.copy_(query_weight) 226 | self.attn.key.weight.copy_(key_weight) 227 | self.attn.value.weight.copy_(value_weight) 228 | self.attn.out.weight.copy_(out_weight) 229 | self.attn.query.bias.copy_(query_bias) 230 | 
            self.attn.key.bias.copy_(key_bias)
            self.attn.value.bias.copy_(value_bias)
            self.attn.out.bias.copy_(out_bias)

            mlp_weight_0 = np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t()
            mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t()
            mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t()
            mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t()

            self.ffn.fc1.weight.copy_(mlp_weight_0)
            self.ffn.fc2.weight.copy_(mlp_weight_1)
            self.ffn.fc1.bias.copy_(mlp_bias_0)
            self.ffn.fc2.bias.copy_(mlp_bias_1)

            self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")]))
            self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")]))
            self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")]))
            self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")]))


class Part_Attention(nn.Module):
    def __init__(self):
        super(Part_Attention, self).__init__()

    def forward(self, x):
        length = len(x)
        last_map = x[0]
        for i in range(1, length):
            last_map = torch.matmul(x[i], last_map)
        last_map = last_map[:, :, 0, 1:]

        _, max_inx = last_map.max(2)
        return _, max_inx


class Encoder(nn.Module):
    def __init__(self, config):
        super(Encoder, self).__init__()
        self.layer = nn.ModuleList()
        for _ in range(config.transformer["num_layers"] - 1):
            layer = Block(config)
            self.layer.append(copy.deepcopy(layer))
        self.part_select = Part_Attention()
        self.part_layer = Block(config)
        self.part_norm = LayerNorm(config.hidden_size, eps=1e-6)

    def forward(self, hidden_states):
        attn_weights = []
        for layer in self.layer:
            hidden_states, weights = layer(hidden_states)
            attn_weights.append(weights)
        part_num, part_inx = self.part_select(attn_weights)
        part_inx = part_inx + 1
        parts = []
        B, num = part_inx.shape
        for i in range(B):
            parts.append(hidden_states[i, part_inx[i, :]])
        parts = torch.stack(parts).squeeze(1)
        concat = torch.cat((hidden_states[:, 0].unsqueeze(1), parts), dim=1)
        part_states, part_weights = self.part_layer(concat)
        part_encoded = self.part_norm(part_states)

        return part_encoded
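
# How the part selection above works, for reference: Part_Attention receives the
# per-layer attention maps (each of shape (B, H, N, N)), multiplies them layer by
# layer to aggregate attention back towards the input tokens, and keeps the CLS
# row over the patch tokens, giving a (B, H, N-1) map. max(2) then yields, for
# every head, the index of its most-attended patch. Encoder.forward shifts those
# indices by +1 to skip the CLS position, gathers the corresponding hidden states
# (H tokens per image), concatenates them with the CLS token, and passes the
# result through the final part_layer Block followed by LayerNorm.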


class Transformer(nn.Module):
    def __init__(self, config, img_size):
        super(Transformer, self).__init__()
        self.embeddings = Embeddings(config, img_size=img_size)
        self.encoder = Encoder(config)

    def forward(self, input_ids):
        embedding_output = self.embeddings(input_ids)
        part_encoded = self.encoder(embedding_output)
        return part_encoded


class VisionTransformer(nn.Module):
    def __init__(self, config, img_size=224, num_classes=21843, smoothing_value=0, zero_head=False):
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        self.smoothing_value = smoothing_value
        self.zero_head = zero_head
        self.classifier = config.classifier
        self.transformer = Transformer(config, img_size)
        self.part_head = Linear(config.hidden_size, num_classes)

    def forward(self, x, labels=None):
        part_tokens = self.transformer(x)
        part_logits = self.part_head(part_tokens[:, 0])

        if labels is not None:
            if self.smoothing_value == 0:
                loss_fct = CrossEntropyLoss()
            else:
                loss_fct = LabelSmoothing(self.smoothing_value)
            part_loss = loss_fct(part_logits.view(-1, self.num_classes), labels.view(-1))
            contrast_loss = con_loss(part_tokens[:, 0], labels.view(-1))
            loss = part_loss + contrast_loss
            return loss, part_logits
        else:
            return part_logits

    def load_from(self, weights):
        with torch.no_grad():
            self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True))
            self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"]))
            self.transformer.embeddings.cls_token.copy_(np2th(weights["cls"]))
            self.transformer.encoder.part_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"]))
            self.transformer.encoder.part_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"]))

            posemb = np2th(weights["Transformer/posembed_input/pos_embedding"])
            posemb_new = self.transformer.embeddings.position_embeddings
            if posemb.size() == posemb_new.size():
                self.transformer.embeddings.position_embeddings.copy_(posemb)
            else:
                logger.info("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size()))
                ntok_new = posemb_new.size(1)

                if self.classifier == "token":
                    posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:]
                    ntok_new -= 1
                else:
                    posemb_tok, posemb_grid = posemb[:, :0], posemb[0]

                gs_old = int(np.sqrt(len(posemb_grid)))
                gs_new = int(np.sqrt(ntok_new))
                print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new))
                posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1)

                zoom = (gs_new / gs_old, gs_new / gs_old, 1)
                posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1)
                posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1)
                posemb = np.concatenate([posemb_tok, posemb_grid], axis=1)
                self.transformer.embeddings.position_embeddings.copy_(np2th(posemb))

            for bname, block in self.transformer.encoder.named_children():
                if not bname.startswith('part'):
                    for uname, unit in block.named_children():
                        unit.load_from(weights, n_block=uname)

            if self.transformer.embeddings.hybrid:
                self.transformer.embeddings.hybrid_model.root.conv.weight.copy_(
                    np2th(weights["conv_root/kernel"], conv=True))
                gn_weight = np2th(weights["gn_root/scale"]).view(-1)
                gn_bias = np2th(weights["gn_root/bias"]).view(-1)
                self.transformer.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight)
                self.transformer.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias)

                for bname, block in self.transformer.embeddings.hybrid_model.body.named_children():
                    for uname, unit in block.named_children():
                        unit.load_from(weights, n_block=bname, n_unit=uname)


def con_loss(features, labels):
    B, _ = features.shape
    features = F.normalize(features)
    cos_matrix = features.mm(features.t())
    pos_label_matrix = torch.stack([labels == labels[i] for i in range(B)]).float()
    neg_label_matrix = 1 - pos_label_matrix
    pos_cos_matrix = 1 - cos_matrix
    neg_cos_matrix = cos_matrix - 0.4
    neg_cos_matrix[neg_cos_matrix < 0] = 0
    loss = (pos_cos_matrix * pos_label_matrix).sum() + (neg_cos_matrix * neg_label_matrix).sum()
    loss /= (B * B)
    return loss
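
# con_loss usage sketch (the shapes and values below are illustrative, not taken
# from the training code): given CLS features of shape (B, hidden_size) and integer
# labels of shape (B,), same-label pairs contribute (1 - cos) and different-label
# pairs contribute max(cos - 0.4, 0), averaged over all B * B pairs, e.g.
#   feats = torch.randn(4, 768)            # normalized inside con_loss
#   labels = torch.tensor([0, 0, 1, 2])
#   reg = con_loss(feats, labels)          # scalar tensor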


CONFIGS = {
    'ViT-B_16': configs.get_b16_config(),
    'ViT-B_32': configs.get_b32_config(),
    'ViT-L_16': configs.get_l16_config(),
    'ViT-L_32': configs.get_l32_config(),
    'ViT-H_14': configs.get_h14_config(),
}
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
# coding=utf-8
from __future__ import absolute_import, division, print_function

import logging
import argparse
import os
import random
import numpy as np
import time

from datetime import timedelta

import torch
import torch.distributed as dist

from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from apex import amp
from apex.parallel import DistributedDataParallel as DDP

from models.modeling import VisionTransformer, CONFIGS
from utils.scheduler import WarmupLinearSchedule, WarmupCosineSchedule
from utils.data_utils import get_loader
from utils.dist_util import get_world_size

logger = logging.getLogger(__name__)


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def simple_accuracy(preds, labels):
    return (preds == labels).mean()


def reduce_mean(tensor, nprocs):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= nprocs
    return rt


def save_model(args, model):
    model_to_save = model.module if hasattr(model, 'module') else model
    model_checkpoint = os.path.join(args.output_dir, "%s_checkpoint.bin" % args.name)
    if args.fp16:
        checkpoint = {
            'model': model_to_save.state_dict(),
            'amp': amp.state_dict()
        }
    else:
        checkpoint = {
            'model': model_to_save.state_dict(),
        }
    torch.save(checkpoint, model_checkpoint)
    logger.info("Saved model checkpoint to [DIR: %s]", args.output_dir)


def setup(args):
    # Prepare model
    config = CONFIGS[args.model_type]

    if args.dataset == "CUB_200_2011":
        num_classes = 200
    elif args.dataset == "car":
        num_classes = 196
    elif args.dataset == "nabirds":
        num_classes = 555
    elif args.dataset == "dog":
        num_classes = 120
    elif args.dataset == "INat2017":
        num_classes = 5089
    elif args.dataset == "ip102":
        num_classes = 102

    model = VisionTransformer(config, args.img_size, zero_head=True, num_classes=num_classes)

    model.load_from(np.load(args.pretrained_dir))
    if args.pretrained_model is not None:
        pretrained_model = torch.load(args.pretrained_model)['model']
        model.load_state_dict(pretrained_model)
    model.to(args.device)
    num_params = count_parameters(model)

    logger.info("{}".format(config))
    logger.info("Training parameters %s", args)
    logger.info("Total Parameter: \t%2.1fM" % num_params)
    return args, model


def count_parameters(model):
    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return params / 1000000


def set_seed(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)


def valid(args, model, writer, test_loader, global_step):
    # Validation!
    eval_losses = AverageMeter()

    logger.info("\n")
    logger.info("***** Running Validation *****")
    logger.info(" Num steps = %d", len(test_loader))
    logger.info(" Batch size = %d", args.eval_batch_size)

    model.eval()
    all_preds, all_label = [], []
    epoch_iterator = tqdm(test_loader,
                          desc="Validating... (loss=X.X)",
                          bar_format="{l_bar}{r_bar}",
                          dynamic_ncols=True,
                          disable=args.local_rank not in [-1, 0])
    loss_fct = torch.nn.CrossEntropyLoss()
    for step, batch in enumerate(epoch_iterator):
        batch = tuple(t.to(args.device) for t in batch)
        x, y = batch
        with torch.no_grad():
            logits = model(x)

            eval_loss = loss_fct(logits, y)
            eval_loss = eval_loss.mean()
            eval_losses.update(eval_loss.item())

            preds = torch.argmax(logits, dim=-1)

        if len(all_preds) == 0:
            all_preds.append(preds.detach().cpu().numpy())
            all_label.append(y.detach().cpu().numpy())
        else:
            all_preds[0] = np.append(
                all_preds[0], preds.detach().cpu().numpy(), axis=0
            )
            all_label[0] = np.append(
                all_label[0], y.detach().cpu().numpy(), axis=0
            )
        epoch_iterator.set_description("Validating... (loss=%2.5f)" % eval_losses.val)

    all_preds, all_label = all_preds[0], all_label[0]
    accuracy = simple_accuracy(all_preds, all_label)
    accuracy = torch.tensor(accuracy).to(args.device)
    dist.barrier()
    val_accuracy = reduce_mean(accuracy, args.nprocs)
    val_accuracy = val_accuracy.detach().cpu().numpy()

    logger.info("\n")
    logger.info("Validation Results")
    logger.info("Global Steps: %d" % global_step)
    logger.info("Valid Loss: %2.5f" % eval_losses.avg)
    logger.info("Valid Accuracy: %2.5f" % val_accuracy)
    if args.local_rank in [-1, 0]:
        writer.add_scalar("test/accuracy", scalar_value=val_accuracy, global_step=global_step)

    return val_accuracy


def train(args, model):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        os.makedirs(args.output_dir, exist_ok=True)
        writer = SummaryWriter(log_dir=os.path.join("logs", args.name))

    train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    # Prepare dataset
    train_loader, test_loader = get_loader(args)

    # Prepare optimizer and scheduler
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.learning_rate,
                                momentum=0.9,
                                weight_decay=args.weight_decay)
    t_total = args.num_steps
    if args.decay_type == "cosine":
        scheduler = WarmupCosineSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
    else:
        scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)

    if args.fp16:
        model, optimizer = amp.initialize(models=model,
                                          optimizers=optimizer,
                                          opt_level=args.fp16_opt_level)
        amp._amp_state.loss_scalers[0]._loss_scale = 2 ** 20

    # Distributed training
    if args.local_rank != -1:
        model = DDP(model, message_size=250000000, gradient_predivide_factor=get_world_size())

    # Train!
    logger.info("***** Running training *****")
    logger.info(" Total optimization steps = %d", args.num_steps)
    logger.info(" Instantaneous batch size per GPU = %d", train_batch_size)
    logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
                train_batch_size * args.gradient_accumulation_steps * (
                    torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)

    model.zero_grad()
    set_seed(args)  # Added here for reproducibility (even between python 2 and 3)
    losses = AverageMeter()
    global_step, best_acc = 0, 0
    start_time = time.time()
    while True:
        model.train()
        epoch_iterator = tqdm(train_loader,
                              desc="Training (X / X Steps) (loss=X.X)",
                              bar_format="{l_bar}{r_bar}",
                              dynamic_ncols=True,
                              disable=args.local_rank not in [-1, 0])
        all_preds, all_label = [], []

        for step, batch in enumerate(epoch_iterator):
            batch = tuple(t.to(args.device) for t in batch)
            x, y = batch

            loss, logits = model(x, y)
            loss = loss.mean()

            preds = torch.argmax(logits, dim=-1)

            if len(all_preds) == 0:
                all_preds.append(preds.detach().cpu().numpy())
                all_label.append(y.detach().cpu().numpy())
            else:
                all_preds[0] = np.append(
                    all_preds[0], preds.detach().cpu().numpy(), axis=0
                )
                all_label[0] = np.append(
                    all_label[0], y.detach().cpu().numpy(), axis=0
                )

            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            if (step + 1) % args.gradient_accumulation_steps == 0:
                losses.update(loss.item() * args.gradient_accumulation_steps)
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                scheduler.step()
                optimizer.step()
                optimizer.zero_grad()
                global_step += 1

                epoch_iterator.set_description(
                    "Training (%d / %d Steps) (loss=%2.5f)" % (global_step, t_total, losses.val)
                )
                if args.local_rank in [-1, 0]:
                    writer.add_scalar("train/loss", scalar_value=losses.val, global_step=global_step)
                    writer.add_scalar("train/lr", scalar_value=scheduler.get_lr()[0], global_step=global_step)
                if global_step % args.eval_every == 0:
                    with torch.no_grad():
                        accuracy = valid(args, model, writer, test_loader, global_step)
                    if args.local_rank in [-1, 0]:
                        if best_acc < accuracy:
                            save_model(args, model)
                            best_acc = accuracy
                        logger.info("best accuracy so far: %f" % best_acc)
                        logger.info("\n")
                    model.train()

                if global_step % t_total == 0:
                    break
        all_preds, all_label = all_preds[0], all_label[0]
        accuracy = simple_accuracy(all_preds, all_label)
        accuracy = torch.tensor(accuracy).to(args.device)
        dist.barrier()
        train_accuracy = reduce_mean(accuracy, args.nprocs)
        train_accuracy = train_accuracy.detach().cpu().numpy()
        logger.info("train accuracy so far: %f" % train_accuracy)
        losses.reset()
        if global_step % t_total == 0:
            break

    writer.close()
    logger.info("Best Accuracy: \t%f" % best_acc)
    logger.info("End Training!")
    end_time = time.time()
    logger.info("Total Training Time: \t%f hours" % ((end_time - start_time) / 3600))


def main():
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument("--name", required=True,
                        help="Name of this run. Used for monitoring.")
    parser.add_argument("--dataset", choices=["CUB_200_2011", "car", "dog", "nabirds", "INat2017", "ip102"],
                        default="CUB_200_2011",
                        help="Which dataset.")
    parser.add_argument('--data_root', type=str, default='/home/samuel/datasets')
    parser.add_argument("--model_type", choices=["ViT-B_16", "ViT-B_32", "ViT-L_16",
                                                 "ViT-L_32", "ViT-H_14"],
                        default="ViT-B_16",
                        help="Which variant to use.")
    parser.add_argument("--pretrained_dir", type=str,
                        default="/home/samuel/projects/MyNet/pretrained_weights/vit_base_patch16_448_in21k.npz",
                        help="Where to search for pretrained ViT models.")
    parser.add_argument("--pretrained_model", type=str, default=None,
                        help="load pretrained model")
    parser.add_argument("--output_dir", default="/home/samuel/projects/MyNet/pretrained_weights/output", type=str,
                        help="The output directory where checkpoints will be written.")
    parser.add_argument("--img_size", default=448, type=int,
                        help="Resolution size")
    parser.add_argument("--train_batch_size", default=16, type=int,
                        help="Total batch size for train.")
    parser.add_argument("--eval_batch_size", default=16, type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--eval_every", default=100, type=int,
                        help="Run prediction on validation set every so many steps. "
                             "Will always run one evaluation at the end of training.")

    parser.add_argument("--learning_rate", default=3e-3, type=float,
                        help="The initial learning rate for SGD.")
    parser.add_argument("--weight_decay", default=0, type=float,
                        help="Weight decay if we apply some.")
    parser.add_argument("--num_steps", default=20000, type=int,
                        help="Total number of training steps to perform.")
    parser.add_argument("--decay_type", choices=["cosine", "linear"], default="cosine",
                        help="How to decay the learning rate.")
    parser.add_argument("--warmup_steps", default=500, type=int,
                        help="Steps of training to perform learning rate warmup for.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float,
                        help="Max gradient norm.")

    parser.add_argument("--local_rank", type=int, default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of update steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--fp16_opt_level', type=str, default='O2',
                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
                             "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument('--loss_scale', type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--smoothing_value', type=float, default=0.0,
                        help="Label smoothing value\n")

    args = parser.parse_args()

    # if args.fp16 and args.smoothing_value != 0:
    #     raise NotImplementedError("label smoothing not supported for fp16 training now")
    args.data_root = '{}/{}'.format(args.data_root, args.dataset)
    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        args.n_gpu = torch.cuda.device_count()
    else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             timeout=timedelta(minutes=60))
        args.n_gpu = 1
    args.device = device
    args.nprocs = torch.cuda.device_count()

    # Setup logging
    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s" %
                   (args.local_rank, args.device, args.n_gpu, bool(args.local_rank != -1), args.fp16))

    # Set seed
    set_seed(args)

    # Model & Tokenizer Setup
    args, model = setup(args)
    # Training
    train(args, model)


if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
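
A minimal usage sketch of the model classes above, mirroring how train.py builds the network. The import path below follows train.py; adjust it to wherever VisionTransformer and CONFIGS actually live in this repository (e.g. models/model_ViT.py). The checkpoint filename and label values are illustrative placeholders:

    import numpy as np
    import torch

    from models.modeling import VisionTransformer, CONFIGS  # import path as used in train.py

    config = CONFIGS["ViT-B_16"]
    model = VisionTransformer(config, img_size=448, num_classes=200,
                              smoothing_value=0.0, zero_head=True)
    model.load_from(np.load("vit_base_patch16_448_in21k.npz"))  # placeholder .npz path
    model.eval()

    images = torch.randn(2, 3, 448, 448)   # dummy batch at the training resolution
    with torch.no_grad():
        logits = model(images)             # no labels: logits of shape (2, 200)

    labels = torch.tensor([3, 7])
    loss, logits = model(images, labels)   # with labels: (CE/label-smoothing + con_loss, logits)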