├── .gitignore
├── README.md
└── selective_back_propagation.py

/.gitignore:
--------------------------------------------------------------------------------
# Created by https://www.gitignore.io/api/python
# Edit at https://www.gitignore.io/?templates=python

### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# Mr Developer
.mr.developer.cfg
.project
.pydevproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# End of https://www.gitignore.io/api/python

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# editor, os cache directory
.vscode/
.idea/
__MACOSX/

# input data, saved log, checkpoints
data/
saved/

# custom
*/*archive*/
*weigths/

# Extensions
*.lprof
*.weights
*.conv*
*.pth

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Selective_Backpropagation

PyTorch implementation of selective backpropagation from the paper *Accelerating Deep Learning by Focusing on the Biggest Losers*:
https://arxiv.org/abs/1910.00762v1

## Code example:
### Without selective backpropagation:
```
...
criterion = nn.CrossEntropyLoss(reduction='none')
...
for x, y in data_loader:
    ...
    y_pred = model(x)
    loss = criterion(y_pred, y).mean()
    loss.backward()
    ...
```
### With selective backpropagation:
```
...
criterion = nn.CrossEntropyLoss(reduction='none')
selective_backprop = SelectiveBackPropagation(
    criterion,
    lambda loss : loss.mean().backward(),
    optimizer,
    model,
    batch_size,
    epoch_length=len(data_loader),
    loss_selection_threshold=False)
...
for x, y in data_loader:
    ...
    with torch.no_grad():
        y_pred = model(x)
    not_reduced_loss = criterion(y_pred, y)
    selective_backprop.selective_back_propagation(not_reduced_loss, x, y)
    ...
```
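## How images are selected:
`SelectiveBackPropagation` keeps a rolling history of the last `batch_size * epoch_length` per-image losses (roughly one epoch). For each new image, the probability of keeping it for backpropagation is the percentile of its loss within that history, squared, so high-loss images are almost always backpropagated and low-loss images rarely are. The snippet below is only a standalone illustration of that rule, using a simple vectorized percentile computation; the class itself ranks losses with its own `_percentiles` helper, so the exact numbers can differ slightly.
```
import numpy as np

loss_history = np.random.rand(10_000)   # recent per-image losses (about one epoch)
batch_losses = np.random.rand(32)       # losses of the current batch

# percentile of each current loss within the history, as a fraction in [0, 1]
percentiles = (loss_history[None, :] < batch_losses[:, None]).mean(axis=1)

selection_probabilities = percentiles ** 2
selected = selection_probabilities > np.random.random(batch_losses.shape)
```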
--------------------------------------------------------------------------------
/selective_back_propagation.py:
--------------------------------------------------------------------------------
import collections

import numpy as np
import torch
# from loguru import logger


class SelectiveBackPropagation:
    """
    Selective backpropagation from the paper Accelerating Deep Learning by Focusing on the Biggest Losers
    https://arxiv.org/abs/1910.00762v1

    Without:
        ...
        criterion = nn.CrossEntropyLoss(reduction='none')
        ...
        for x, y in data_loader:
            ...
            y_pred = model(x)
            loss = criterion(y_pred, y).mean()
            loss.backward()
            ...

    With:
        ...
        criterion = nn.CrossEntropyLoss(reduction='none')
        selective_backprop = SelectiveBackPropagation(
            criterion,
            lambda loss: loss.mean().backward(),
            optimizer,
            model,
            batch_size,
            epoch_length=len(data_loader),
            loss_selection_threshold=False)
        ...
        for x, y in data_loader:
            ...
            with torch.no_grad():
                y_pred = model(x)
            not_reduced_loss = criterion(y_pred, y)
            selective_backprop.selective_back_propagation(not_reduced_loss, x, y)
            ...
    """

    def __init__(self, compute_losses_func, update_weights_func, optimizer, model,
                 batch_size, epoch_length, loss_selection_threshold=False):
        """
        Usage:
        ```
        criterion = nn.CrossEntropyLoss(reduction='none')
        selective_backprop = SelectiveBackPropagation(
            criterion,
            lambda loss: loss.mean().backward(),
            optimizer,
            model,
            batch_size,
            epoch_length=len(data_loader),
            loss_selection_threshold=False)
        ```

        :param compute_losses_func: loss function which outputs a tensor of dim [batch_size] (no reduction applied).
            Example: `compute_losses_func = nn.CrossEntropyLoss(reduction='none')`
        :param update_weights_func: reduction of the loss followed by backpropagation.
            Example: `update_weights_func = lambda loss: loss.mean().backward()`
        :param optimizer: your optimizer object
        :param model: your model object
        :param batch_size: number of images per batch
        :param epoch_length: number of batches per epoch
        :param loss_selection_threshold: defaults to False. Set it to a float to always select images whose loss is
            higher than loss_selection_threshold; the behavior for losses below the threshold is unchanged.
        """

        self.loss_selection_threshold = loss_selection_threshold
        self.compute_losses_func = compute_losses_func
        self.update_weights_func = update_weights_func
        self.batch_size = batch_size
        self.optimizer = optimizer
        self.model = model

        # rolling loss history of roughly one epoch, used to rank new losses
        self.loss_hist = collections.deque([], maxlen=batch_size * epoch_length)
        self.selected_inputs, self.selected_targets = [], []

    def selective_back_propagation(self, loss_per_img, data, targets):
        effective_batch_loss = None

        cpu_losses = loss_per_img.detach().clone().cpu()
        self.loss_hist.extend(cpu_losses.tolist())
        np_cpu_losses = cpu_losses.numpy()

        # probability of keeping an image grows with the percentile of its loss in the recent history
        selection_probabilities = self._get_selection_probabilities(np_cpu_losses)
        selection = selection_probabilities > np.random.random(*selection_probabilities.shape)

        if self.loss_selection_threshold:
            # always keep images whose loss exceeds the threshold
            higher_thres = np_cpu_losses > self.loss_selection_threshold
            selection = np.logical_or(higher_thres, selection)

        selected_losses = []
        for idx in np.argwhere(selection).flatten():
            selected_losses.append(np_cpu_losses[idx])
            self.selected_inputs.append(data[idx, ...].detach().clone())
            self.selected_targets.append(targets[idx, ...].detach().clone())

            # once a full effective batch of selected images is accumulated, backpropagate on it
            if len(self.selected_targets) == self.batch_size:
                self.model.train()
                predictions = self.model(torch.stack(self.selected_inputs))
                effective_batch_loss = self.compute_losses_func(predictions,
                                                                torch.stack(self.selected_targets))
                self.update_weights_func(effective_batch_loss)
                effective_batch_loss = effective_batch_loss.mean()
                self.model.eval()
                self.selected_inputs = []
                self.selected_targets = []

        # logger.info("Mean of input loss {}".format(np.array(np_cpu_losses).mean()))
        # logger.info("Mean of loss history {}".format(np.array(self.loss_hist).mean()))
        # logger.info("Mean of selected loss {}".format(np.array(selected_losses).mean()))
        # logger.info("Mean of effective_batch_loss {}".format(effective_batch_loss))
        return effective_batch_loss

    def _get_selection_probabilities(self, loss):
        # selection probability = (percentile of the loss within the history) ** 2
        percentiles = self._percentiles(self.loss_hist, loss)
        return percentiles ** 2

    def _percentiles(self, hist_values, values_to_search):
        # TODO Speed up this again. There is still a visible overhead in training.
        # Returns, for each value, its approximate percentile rank in hist_values as a fraction in [0, 1).
        hist_values, values_to_search = np.asarray(hist_values), np.asarray(values_to_search)

        percentiles_values = np.percentile(hist_values, range(100))
        sorted_loss_idx = sorted(range(len(values_to_search)), key=lambda k: values_to_search[k])
        counter = 0
        percentiles_by_loss = [0] * len(values_to_search)
        for idx, percentiles_value in enumerate(percentiles_values):
            while values_to_search[sorted_loss_idx[counter]] < percentiles_value:
                percentiles_by_loss[sorted_loss_idx[counter]] = idx
                counter += 1
                if counter == len(values_to_search):
                    break
            if counter == len(values_to_search):
                break
        return np.array(percentiles_by_loss) / 100

--------------------------------------------------------------------------------
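For reference, the snippet below sketches how `SelectiveBackPropagation` might be wired into a complete toy training loop. The model, data, and hyper-parameters are placeholders chosen only to exercise the API, and the optimizer handling is an assumption: the class only calls `backward()` through `update_weights_func`, so this sketch steps and resets the optimizer whenever `selective_back_propagation` returns a non-None effective batch loss.
```
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from selective_back_propagation import SelectiveBackPropagation

# placeholder data and model, chosen only to exercise the API
inputs = torch.randn(256, 20)
labels = torch.randint(0, 4, (256,))
data_loader = DataLoader(TensorDataset(inputs, labels), batch_size=32)

model = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 4))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss(reduction='none')

selective_backprop = SelectiveBackPropagation(
    criterion,
    lambda loss: loss.mean().backward(),
    optimizer,
    model,
    batch_size=32,
    epoch_length=len(data_loader),
    loss_selection_threshold=False)

for epoch in range(2):
    for x, y in data_loader:
        # cheap forward pass without gradients, used only to rank the images by loss
        with torch.no_grad():
            y_pred = model(x)
        not_reduced_loss = criterion(y_pred, y)
        effective_batch_loss = selective_backprop.selective_back_propagation(not_reduced_loss, x, y)
        if effective_batch_loss is not None:
            # an effective batch was backpropagated inside the call; apply and reset the gradients
            optimizer.step()
            optimizer.zero_grad()
            print(f"epoch {epoch}: effective batch loss {effective_batch_loss.item():.3f}")
```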