├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE.txt
├── README.md
├── datasets
│   └── sort.py
├── main.py
└── models
    ├── approximator.py
    └── dab.py

/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Code of Conduct
2 | 
3 | ## Our Pledge
4 | 
5 | In the interest of fostering an open and welcoming environment, we as
6 | contributors and maintainers pledge to making participation in our project and
7 | our community a harassment-free experience for everyone, regardless of age, body
8 | size, disability, ethnicity, sex characteristics, gender identity and expression,
9 | level of experience, education, socio-economic status, nationality, personal
10 | appearance, race, religion, or sexual identity and orientation.
11 | 
12 | ## Our Standards
13 | 
14 | Examples of behavior that contributes to creating a positive environment
15 | include:
16 | 
17 | * Using welcoming and inclusive language
18 | * Being respectful of differing viewpoints and experiences
19 | * Gracefully accepting constructive criticism
20 | * Focusing on what is best for the community
21 | * Showing empathy towards other community members
22 | 
23 | Examples of unacceptable behavior by participants include:
24 | 
25 | * The use of sexualized language or imagery and unwelcome sexual attention or
26 |   advances
27 | * Trolling, insulting/derogatory comments, and personal or political attacks
28 | * Public or private harassment
29 | * Publishing others' private information, such as a physical or electronic
30 |   address, without explicit permission
31 | * Other conduct which could reasonably be considered inappropriate in a
32 |   professional setting
33 | 
34 | ## Our Responsibilities
35 | 
36 | Project maintainers are responsible for clarifying the standards of acceptable
37 | behavior and are expected to take appropriate and fair corrective action in
38 | response to any instances of unacceptable behavior.
39 | 
40 | Project maintainers have the right and responsibility to remove, edit, or
41 | reject comments, commits, code, wiki edits, issues, and other contributions
42 | that are not aligned to this Code of Conduct, or to ban temporarily or
43 | permanently any contributor for other behaviors that they deem inappropriate,
44 | threatening, offensive, or harmful.
45 | 
46 | ## Scope
47 | 
48 | This Code of Conduct applies within all project spaces, and it also applies when
49 | an individual is representing the project or its community in public spaces.
50 | Examples of representing a project or community include using an official
51 | project e-mail address, posting via an official social media account, or acting
52 | as an appointed representative at an online or offline event. Representation of
53 | a project may be further defined and clarified by project maintainers.
54 | 
55 | ## Enforcement
56 | 
57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
58 | reported by contacting the open source team at [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com). All
59 | complaints will be reviewed and investigated and will result in a response that
60 | is deemed necessary and appropriate to the circumstances. The project team is
61 | obligated to maintain confidentiality with regard to the reporter of an incident.
62 | Further details of specific enforcement policies may be posted separately.
63 | 
64 | Project maintainers who do not follow or enforce the Code of Conduct in good
65 | faith may face temporary or permanent repercussions as determined by other
66 | members of the project's leadership.
67 | 
68 | ## Attribution
69 | 
70 | This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 1.4,
71 | available at [https://www.contributor-covenant.org/version/1/4/code-of-conduct.html](https://www.contributor-covenant.org/version/1/4/code-of-conduct.html)
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contribution Guide
2 | 
3 | Thanks for your interest in contributing. This project was released to accompany a research paper for purposes of reproducibility, and beyond its publication there are limited plans for future development of the repository.
4 | 
5 | While we welcome new pull requests and issues, please note that our response may be limited. Forks and out-of-tree improvements are strongly encouraged.
6 | 
7 | ## Before you get started
8 | 
9 | By submitting a pull request, you represent that you have the right to license your contribution to Apple and the community, and agree by submitting the patch that your contributions are licensed under the [LICENSE](LICENSE.txt).
10 | 
11 | We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md).
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | Copyright (C) 2020 Apple Inc. All Rights Reserved.
2 | 
3 | IMPORTANT: This Apple software is supplied to you by Apple
4 | Inc. ("Apple") in consideration of your agreement to the following
5 | terms, and your use, installation, modification or redistribution of
6 | this Apple software constitutes acceptance of these terms. If you do
7 | not agree with these terms, please do not use, install, modify or
8 | redistribute this Apple software.
9 | 
10 | In consideration of your agreement to abide by the following terms, and
11 | subject to these terms, Apple grants you a personal, non-exclusive
12 | license, under Apple's copyrights in this original Apple software (the
13 | "Apple Software"), to use, reproduce, modify and redistribute the Apple
14 | Software, with or without modifications, in source and/or binary forms;
15 | provided that if you redistribute the Apple Software in its entirety and
16 | without modifications, you must retain this notice and the following
17 | text and disclaimers in all such redistributions of the Apple Software.
18 | Neither the name, trademarks, service marks or logos of Apple Inc. may
19 | be used to endorse or promote products derived from the Apple Software
20 | without specific prior written permission from Apple. Except as
21 | expressly stated in this notice, no other rights or licenses, express or
22 | implied, are granted by Apple herein, including but not limited to any
23 | patent rights that may be infringed by your derivative works or by other
24 | works in which the Apple Software may be incorporated.
25 | 
26 | The Apple Software is provided by Apple on an "AS IS" basis. APPLE
27 | MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION
28 | THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS
29 | FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND
30 | OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS.
31 | 
32 | IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL
33 | OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
34 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
35 | INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION,
36 | MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED
37 | AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE),
38 | STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE
39 | POSSIBILITY OF SUCH DAMAGE.
40 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DAB: Differentiable Approximation Bridges
2 | 
3 | A simplified example demonstrating the DAB network presented in [Improving Discrete Latent Representations With Differentiable Approximation Bridges](https://arxiv.org/abs/1905.03658).
4 | 
5 | ### Usage
6 | 
7 | The only dependency for this demo is [pytorch](https://pytorch.org/get-started/locally/).
8 | To run the 10-sort signum-dense problem described in section 4.1 of the [paper](https://arxiv.org/abs/1905.03658), simply run:
9 | 
10 | ```bash
11 | python main.py
12 | ```
13 | 
14 | This should produce output like the following, which corroborates the paper's reported test accuracy of 94.2%:
15 | 
16 | ```bash
17 | train[Epoch 2168][1999872.0 samples][7.79 sec]: Loss: 79.2356 DABLoss: 7.9058 Accuracy: 95.5683
18 | …
19 | test[Epoch 2168][399360.0 samples][0.91 sec]: Loss: 79.2329 DABLoss: 7.9012 Accuracy: 94.6424
20 | ```
21 | 
22 | ### Create a DAB for a custom non-differentiable function
23 | 
24 | 1. Create a suitable approximation neural network.
25 | 2. Implement a custom hard function similar to SignumWithMargin in models/dab.py.
26 | 3. Stack a DAB module in your neural network pipeline.
27 | 4. Add the DAB loss to the normal loss (a worked sketch of these steps appears at the end of this document, after models/dab.py).
28 | 
29 | 
30 | ### Cite
31 | 
32 | ```
33 | @article{
34 | dabimprovingdiscreterepr2020,
35 | title={Improving Discrete Latent Representations With Differentiable Approximation Bridges},
36 | author={Ramapuram, Jason and Webb, Russ},
37 | journal={IEEE WCCI},
38 | year={2020}
39 | }
40 | ```
--------------------------------------------------------------------------------
/datasets/sort.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE.txt file.
3 | # Copyright (C) 2019-2020 Apple Inc. All Rights Reserved.
4 | #
5 | import torch
6 | import torch.utils.data
7 | import numpy as np
8 | 
9 | 
10 | def generate_samples(num_samples, seq_len, max_digit):
11 |     """ Helper to generate samples between 0 and max_digit
12 | 
13 |     :param num_samples: the total number of samples to generate
14 |     :param seq_len: length of each sequence
15 |     :param max_digit: the upper bound in the uniform distribution
16 |     :returns: data [num_samples, seq_len, 1] and argsort labels [num_samples, seq_len]
17 |     :rtype: np.ndarray, np.ndarray
18 | 
19 |     """
20 |     data = np.random.uniform(0, max_digit, size=[num_samples, seq_len])
21 |     labels = np.argsort(data, axis=-1)
22 |     data = data.reshape(num_samples, seq_len, 1)
23 |     labels = labels.reshape(num_samples, seq_len)
24 |     print('[debug] labels = ', labels.shape, " | data = ", data.shape)
25 |     return [data.astype(np.float32), labels]
26 | 
27 | 
28 | class SortDataset(torch.utils.data.Dataset):
29 |     def __init__(self, upper_bound_unif, sequence_length, split='train',
30 |                  transform=None, target_transform=None, **kwargs):
31 |         self.split = split
32 |         self.transform = transform
33 |         self.target_transform = target_transform
34 |         self.max_digit = upper_bound_unif        # max sorting range U ~ [0, max_digit]
35 |         self.sequence_length = sequence_length   # length of each sequence to sort
36 | 
37 |         # set the number of samples to 2 million by default
38 |         train_samples = kwargs.get('num_samples', 2000000)
39 |         self.num_samples = train_samples if split == 'train' else int(train_samples*0.2)
40 | 
41 |         # load the sort dataset and labels
42 |         self.data, self.labels = generate_samples(self.num_samples,
43 |                                                   self.sequence_length,
44 |                                                   self.max_digit)
45 |         print("[{}] {} samples\n".format(split, len(self.labels)))
46 | 
47 |     def __getitem__(self, index):
48 |         """ Returns a single element based on the index.
49 |         Called by the pytorch DataLoader to assemble minibatches.
50 | 
51 |         :param index: the single sample id
52 |         :returns: an unsorted vector and the correct sorted class target.
53 |         :rtype: torch.Tensor, torch.Tensor
54 | 
55 |         """
56 |         target = self.labels[index]
57 |         data = self.data[index]
58 | 
59 |         if self.transform is not None:
60 |             data = self.transform(data)
61 | 
62 |         if self.target_transform is not None:
63 |             target = self.target_transform(target)
64 | 
65 |         # sanity conversions in case the data has not yet been
66 |         # converted to a torch.Tensor
67 |         if not isinstance(data, torch.Tensor):
68 |             data = torch.from_numpy(data)
69 | 
70 |         if not isinstance(target, torch.Tensor):
71 |             target = torch.from_numpy(target)
72 | 
73 |         return data, target
74 | 
75 |     def __len__(self):
76 |         return len(self.labels)
77 | 
78 | 
79 | class SortLoader(object):
80 |     def __init__(self, batch_size, upper_bound_unif, sequence_length, transform=None, target_transform=None, **kwargs):
81 |         """ A container class that houses a train and test loader for the sort problem.
82 | 
83 |         :param batch_size: the minibatch size
84 |         :param upper_bound_unif: the upper bound in U(0, upper_bound_unif)
85 |         :param sequence_length: the number of elements in each input sequence
86 |         :param transform: torchvision transforms if needed
87 |         :param target_transform: torchvision target label transform
88 |         :returns: SortLoader object with .train_loader and .test_loader to iterate corresponding datasets
89 |         :rtype: object
90 | 
91 |         """
92 |         # build the datasets that implement __getitem__
93 |         train_dataset = SortDataset(upper_bound_unif, sequence_length, split='train',
94 |                                     transform=transform, target_transform=target_transform, **kwargs)
95 |         test_dataset = SortDataset(upper_bound_unif, sequence_length, split='test',
96 |                                    transform=transform, target_transform=target_transform, **kwargs)
97 | 
98 |         # build the dataloaders that wrap the dataset
99 |         loader_args = {'num_workers': 4, 'pin_memory': True, 'batch_size': batch_size, 'drop_last': True}
100 |         self.train_loader = torch.utils.data.DataLoader(
101 |             train_dataset, shuffle=True, **loader_args
102 |         )
103 |         self.test_loader = torch.utils.data.DataLoader(
104 |             test_dataset, shuffle=True, **loader_args
105 |         )
106 | 
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE.txt file.
3 | # Copyright (C) 2019-2020 Apple Inc. All Rights Reserved.
4 | #
5 | import time
6 | import torch
7 | import pprint
8 | import argparse
9 | import contextlib
10 | import numpy as np
11 | import torch.nn as nn
12 | import torch.optim as optim
13 | import torch.nn.functional as F
14 | 
15 | 
16 | from models.approximator import IndependentVectorizedApproximator
17 | from models.dab import DAB, SignumWithMargin, View
18 | from datasets.sort import SortLoader
19 | 
20 | 
21 | parser = argparse.ArgumentParser(description='DAB Sort Dense Example')
22 | 
23 | # Task parameters
24 | parser.add_argument('--batch-size', type=int, default=1024, metavar='N',
25 |                     help='input batch size for training (default: 1024)')
26 | parser.add_argument('--epochs', type=int, default=15000,
27 |                     help='minimum number of epochs to train (default: 15000)')
28 | parser.add_argument('--sequence-length', type=int, default=10,
29 |                     help='size of each sequence to use for sorting (default: 10)')
30 | 
31 | # Model related
32 | parser.add_argument('--latent-size', type=int, default=256,
33 |                     help='latent layer size (default: 256)')
34 | parser.add_argument('--dab-gamma', type=float, default=10.0,
35 |                     help='the weighting for the DAB loss (default: 10)')
36 | parser.add_argument('--approximator-type', type=str, default='batch',
37 |                     help='batch [does all at once] or independent [elem-by-elem] approximator (default: batch)')
38 | 
39 | # Optimization related
40 | parser.add_argument('--lr', type=float, default=1e-4, metavar='LR',
41 |                     help='learning rate (default: 1e-4)')
42 | parser.add_argument('--optimizer', type=str, default="adam",
43 |                     help="specify optimizer (default: adam)")
44 | 
45 | # Device / debug stuff
46 | parser.add_argument('--debug-step', action='store_true', default=False,
47 |                     help='only does one step of the execute_graph function per call instead of all minibatches')
48 | parser.add_argument('--seed', type=int, default=None,
49 |                     help='seed for numpy and pytorch (default: None)')
50 | parser.add_argument('--ngpu', type=int, default=1,
51 |                     help='number of gpus available (default: 1)')
52 | parser.add_argument('--no-cuda', action='store_true', default=False,
53 |                     help='disables CUDA training')
54 | args = parser.parse_args()
55 | args.cuda = not args.no_cuda and torch.cuda.is_available()
56 | if
args.cuda: 57 | torch.backends.cudnn.benchmark = True 58 | 59 | # set a fixed seed for GPUs and CPU 60 | if args.seed is not None: 61 | print("setting seed %d" % args.seed) 62 | np.random.seed(args.seed) 63 | torch.manual_seed(args.seed) 64 | if args.cuda: 65 | torch.cuda.manual_seed_all(args.seed) 66 | 67 | 68 | def all_or_none_accuracy(preds, targets, dim=-1): 69 | """ Gets the accuracy of the predicted sequence. 70 | 71 | :param preds: model predictions 72 | :param targets: the true targets 73 | :param dim: dimension to operate over 74 | :returns: scalar value for all-or-none accuracy 75 | :rtype: float32 76 | 77 | """ 78 | preds_max = preds.data.max(dim=dim)[1] # get the index of the max log-probability 79 | assert targets.shape == preds_max.shape, \ 80 | "target[{}] shape does not match preds[{}]".format(targets.shape, preds_max.shape) 81 | targ = targets.data 82 | return torch.mean(preds_max.eq(targ).cpu().all(dim=dim).type(torch.float32)) 83 | 84 | 85 | def build_optimizer(model, args): 86 | """ helper to build the optimizer and wrap model 87 | 88 | :param model: the model to wrap 89 | :returns: optimizer wrapping model provided 90 | :rtype: nn.Optim 91 | 92 | """ 93 | optim_map = { 94 | "rmsprop": optim.RMSprop, 95 | "adam": optim.Adam, 96 | "adadelta": optim.Adadelta, 97 | "sgd": optim.SGD, 98 | "lbfgs": optim.LBFGS 99 | } 100 | return optim_map[args.optimizer.lower().strip()]( 101 | model.parameters(), lr=args.lr 102 | ) 103 | 104 | 105 | def build_model(args): 106 | """ Builds the approximator and the model with the DAB 107 | 108 | :param args: argparse 109 | :returns: model 110 | :rtype: nn.Sequential 111 | 112 | """ 113 | if args.approximator_type == 'batch': 114 | approximator = nn.Sequential( # moar layers 115 | nn.Linear(args.latent_size, args.latent_size // 2), 116 | nn.Tanh(), 117 | nn.Linear(args.latent_size // 2, args.latent_size // 2), 118 | nn.Tanh(), 119 | nn.Linear(args.latent_size // 2, args.latent_size) 120 | ) 121 | elif args.approximator_type == 'independent': 122 | approximator = IndependentVectorizedApproximator(args.latent_size, 123 | activation=nn.Tanh) 124 | else: 125 | raise Exception("unknown approximator type, specify independent or batch") 126 | 127 | model = nn.Sequential( # even moar layers 128 | View([-1, args.sequence_length]), 129 | nn.Linear(args.sequence_length, args.latent_size // 2), 130 | nn.Tanh(), 131 | nn.Linear(args.latent_size // 2, args.latent_size // 2), 132 | nn.Tanh(), 133 | nn.Linear(args.latent_size // 2, args.latent_size), 134 | nn.Tanh(), 135 | DAB(approximator=approximator, 136 | hard_layer=SignumWithMargin()), 137 | nn.Linear(args.latent_size, args.sequence_length * args.sequence_length), 138 | View([-1, args.sequence_length, args.sequence_length]) 139 | ) 140 | print(model) 141 | 142 | return model.cuda() if args.cuda else model 143 | 144 | 145 | def build_dataloader(args): 146 | """ Helper to build the data dataloader that houses 147 | both the train and test pytorch Dataloaders 148 | 149 | :param args: argparse 150 | :returns: SortLoader 151 | :rtype: object 152 | 153 | """ 154 | return SortLoader(batch_size=args.batch_size, 155 | upper_bound_unif=1, 156 | sequence_length=args.sequence_length, 157 | transform=None, 158 | target_transform=None, 159 | num_samples=2000000) 160 | 161 | 162 | def get_dab_loss(model): 163 | """ Simple helper to iterate a model and return the DAB loss. 
164 | 165 | :param model: the full nn.Sequential or nn.Modulelist 166 | :returns: dab loss 167 | :rtype: torch.Tensor 168 | 169 | """ 170 | dab_loss, dab_count = None, 0 171 | for layer in model: 172 | if isinstance(layer, DAB): 173 | dab_count += 1 174 | if dab_loss is None: 175 | dab_loss = layer.loss_function() 176 | else: 177 | dab_loss += layer.loss_function() 178 | 179 | dab_count = 1 if dab_count == 0 else dab_count 180 | dab_loss = torch.zeros(args.batch_size) if dab_loss is None else dab_loss 181 | dab_loss = dab_loss.cuda() if args.cuda else dab_loss 182 | return dab_loss / dab_count 183 | 184 | 185 | @contextlib.contextmanager 186 | def dummy_context(): 187 | """ Simple helper to create a fake context scope. 188 | 189 | :returns: None 190 | :rtype: Scope 191 | 192 | """ 193 | yield None 194 | 195 | 196 | def execute_graph(epoch, model, loader, optimizer=None, prefix='test'): 197 | """ execute the graph; when 'train' is in the name the model runs the optimizer 198 | 199 | :param epoch: the current epoch number 200 | :param model: the torch model 201 | :param loader: the train or **TEST** loader 202 | :param optimizer: the optimizer 203 | :param prefix: 'train', 'test' or 'valid' 204 | :returns: loss scalar 205 | :rtype: float32 206 | 207 | """ 208 | start_time = time.time() 209 | model.eval() if prefix == 'test' else model.train() 210 | assert optimizer is not None if 'train' in prefix or 'valid' in prefix else optimizer is None 211 | accuracy, loss, dab_loss, num_samples = 0., 0., 0., 0. 212 | 213 | # iterate over train and valid data 214 | for minibatch, labels in loader: 215 | minibatch = minibatch.cuda() if args.cuda else minibatch 216 | labels = labels.cuda() if args.cuda else labels 217 | if 'train' in prefix: 218 | optimizer.zero_grad() # zero gradients 219 | 220 | with torch.no_grad() if prefix == 'test' else dummy_context(): 221 | pred_logits = model(minibatch) # get model predictions 222 | 223 | # classification + DAB loss 224 | dab_loss_t = get_dab_loss(model) 225 | classification_loss_t = torch.sum(F.cross_entropy(input=pred_logits, target=labels, reduction='none'), -1) 226 | loss_t = torch.mean(classification_loss_t + args.dab_gamma * dab_loss_t) 227 | 228 | loss += loss_t.item() # add to aggregate loss 229 | dab_loss += torch.mean(dab_loss_t).item() 230 | accuracy += all_or_none_accuracy(preds=F.softmax(pred_logits, dim=1), # get accuracy value 231 | targets=labels, dim=1) 232 | num_samples += minibatch.size(0) 233 | 234 | if 'train' in prefix: # compute bp and optimize 235 | loss_t.backward() 236 | optimizer.step() 237 | 238 | if args.debug_step: # for testing purposes 239 | break 240 | 241 | # debug prints for a ** SINGLE ** sample, loss above is calculated over entire minibatch 242 | print('preds[0]\t =\t ', F.softmax(pred_logits[0], dim=1).max(dim=1)[1]) 243 | print('targets[0]\t =\t ', labels[0]) 244 | print('inputs[0]\t = ', minibatch[0]) 245 | 246 | # reduce by the number of minibatches completed 247 | num_minibatches_completed = num_samples / minibatch.size(0) 248 | loss /= num_minibatches_completed 249 | dab_loss /= num_minibatches_completed 250 | accuracy /= num_minibatches_completed 251 | 252 | # print out verbose loggin 253 | print('{}[Epoch {}][{} samples][{:.2f} sec]: Loss: {:.4f}\tDABLoss: {:.4f}\tAccuracy: {:.4f}'.format( 254 | prefix, epoch, num_samples, time.time() - start_time, 255 | loss, dab_loss, accuracy * 100.0)) 256 | 257 | # return this for early stopping if used 258 | return loss 259 | 260 | 261 | def train(epoch, model, optimizer, 
train_loader, prefix='train'): 262 | """ Helper to run execute-graph for the train dataset 263 | 264 | :param epoch: the current epoch 265 | :param model: the model 266 | :param test_loader: the train data-loader 267 | :param prefix: the default prefix; useful if we have multiple training types 268 | :returns: mean loss value 269 | :rtype: float32 270 | 271 | """ 272 | return execute_graph(epoch, model, train_loader, optimizer, prefix='train') 273 | 274 | 275 | def test(epoch, model, test_loader, prefix='test'): 276 | """ Helper to run execute-graph for the test dataset 277 | 278 | :param epoch: the current epoch 279 | :param model: the model 280 | :param test_loader: the test data-loader 281 | :param prefix: the default prefix; useful if we have multiple test types 282 | :returns: mean loss value 283 | :rtype: float32 284 | 285 | """ 286 | return execute_graph(epoch, model, test_loader, prefix='test') 287 | 288 | 289 | def run(args): 290 | """ Main entry-point into the program 291 | 292 | :param args: argparse 293 | :returns: None 294 | :rtype: None 295 | 296 | """ 297 | loader = build_dataloader(args) # houses train and test loader 298 | model = build_model(args) # the model itself 299 | optimizer = build_optimizer(model, args) # the optimizer for the model 300 | 301 | # main training loop 302 | for epoch in range(1, args.epochs + 1): 303 | train(epoch, model, optimizer, loader.train_loader) 304 | test(epoch, model, loader.test_loader) 305 | 306 | 307 | if __name__ == "__main__": 308 | print(pprint.PrettyPrinter(indent=4).pformat(vars(args))) 309 | run(args) 310 | -------------------------------------------------------------------------------- /models/approximator.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE.txt file. 3 | # Copyright (C) 2019-2020 Apple Inc. All Rights Reserved. 4 | # 5 | import torch.nn as nn 6 | 7 | 8 | class IndependentVectorizedApproximator(nn.Module): 9 | def __init__(self, latent_size, activation=nn.Tanh): 10 | """ A vectorized approximator that takes each input (feature-elem by feature-elem) and produces an approximation. 11 | This allows for a network that shares parameters and enables an easier approximation. 12 | 13 | :param latent_size: latent size for model 14 | :returns: IndependentVectorizedApproximator 15 | :rtype: nn.Module 16 | 17 | """ 18 | super(IndependentVectorizedApproximator, self).__init__() 19 | 20 | # the actual model 21 | self.approximator = nn.Sequential( 22 | nn.Linear(1, latent_size), 23 | activation(), 24 | nn.Linear(latent_size, latent_size), 25 | activation(), 26 | nn.Linear(latent_size, 1) 27 | ) 28 | 29 | def forward(self, x): 30 | x = x.unsqueeze(-1) 31 | return self.approximator(x).squeeze(-1) 32 | -------------------------------------------------------------------------------- /models/dab.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE.txt file. 3 | # Copyright (C) 2019-2020 Apple Inc. All Rights Reserved. 
4 | # 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class View(nn.Module): 11 | def __init__(self, shape): 12 | """ Simple helper module to reshape a tensor 13 | 14 | :param shape: the desired shape, -1 for ignored dimensions 15 | :returns: reshaped tensor 16 | :rtype: torch.Tensor 17 | 18 | """ 19 | super(View, self).__init__() 20 | self.shape = shape 21 | 22 | def forward(self, input): 23 | return input.contiguous().view(*self.shape) 24 | 25 | 26 | class DAB(nn.Module): 27 | def __init__(self, approximator, hard_layer): 28 | """ DAB layer simply accepts an approximator model, a hard layer 29 | and adds syntatic sugar to return the hard output while caching 30 | the soft version. It also adds a helper fn loss_function() to 31 | return the DAB loss. 32 | 33 | :param approximator: the approximator nn.Module 34 | :param hard_layer: the hard layer nn.Module 35 | :returns: DAB Object 36 | :rtype: nn.Module 37 | 38 | """ 39 | super(DAB, self).__init__() 40 | self.loss_fn = F.mse_loss 41 | self.hard_layer = hard_layer.apply 42 | self.approximator = approximator 43 | 44 | def loss_function(self): 45 | """ Simple helper to return the cached loss 46 | 47 | :returns: loss reduced across feature dimension 48 | :rtype: torch.Tensor 49 | 50 | """ 51 | assert self.true_output.shape[0] == self.approximator_output.shape[0], "batch mismatch" 52 | batch_size = self.true_output.shape[0] 53 | return torch.sum(self.loss_fn(self.approximator_output.view(batch_size, -1), 54 | self.true_output.view(batch_size, -1), 55 | reduction='none'), dim=-1) 56 | 57 | def forward(self, x, **kwargs): 58 | """ DAB layer simply caches the true and approximator outputs 59 | and returns the hard output. 60 | 61 | :param x: the input to the DAB / hard fn 62 | :returns: hard output 63 | :rtype: torch.Tensor 64 | 65 | """ 66 | self.approximator_output = self.approximator(x, **kwargs) 67 | self.true_output = self.hard_layer(x, self.approximator_output) 68 | 69 | # sanity check and return 70 | assert self.approximator_output.shape == self.true_output.shape, \ 71 | "proxy output {} doesn't match size of hard output [{}]".format( 72 | self.approximator_output.shape, self.true_output.shape 73 | ) 74 | 75 | return self.true_output 76 | 77 | 78 | class BaseHardFn(torch.autograd.Function): 79 | @staticmethod 80 | def forward(ctx, x, soft_y, hard_fn, *args): 81 | """ Runs the hard function for forward, cache the output and returns. 82 | All hard functions should inherit from this, it implements the autograd override. 83 | 84 | :param ctx: pytorch context, automatically passed in. 85 | :param x: input tensor. 86 | :param soft_y: forward pass output (logits) of DAB approximator network. 87 | :param hard_fn: to be passed in from derived class. 88 | :param args: list of args to pass to hard function. 89 | :returns: hard_fn(tensor), backward pass using DAB. 90 | :rtype: torch.Tensor 91 | 92 | """ 93 | hard = hard_fn(x, *args) 94 | saveable_args = list([a for a in args if isinstance(a, torch.Tensor)]) 95 | ctx.save_for_backward(x, soft_y, *saveable_args) 96 | return hard 97 | 98 | @staticmethod 99 | def _hard_fn(x, *args): 100 | raise NotImplementedError("implement _hard_fn in derived class") 101 | 102 | @staticmethod 103 | def backward(ctx, grad_out): 104 | """ Returns DAB derivative. 105 | 106 | :param ctx: pytorch context, automatically passed in. 
107 | :param grad_out: grads coming into layer 108 | :returns: dab_grad(tensor) 109 | :rtype: torch.Tensor 110 | 111 | """ 112 | x, soft_y, *args = ctx.saved_tensors 113 | with torch.enable_grad(): 114 | grad = torch.autograd.grad(outputs=soft_y, inputs=x, 115 | grad_outputs=grad_out, 116 | # allow_unused=True, 117 | retain_graph=True) 118 | return grad[0], None, None, None 119 | 120 | 121 | class SignumWithMargin(BaseHardFn): 122 | @staticmethod 123 | def _hard_fn(x, *args): 124 | """ x[x < -eps] = -1 125 | x[x > +eps] = 1 126 | else x = 0 127 | 128 | :param x: input tensor 129 | :param args: list of args with 0th element being eps 130 | :returns: signum(tensor) 131 | :rtype: torch.Tensor 132 | 133 | """ 134 | eps = args[0] if len(args) > 0 else 0.5 135 | sig = torch.zeros_like(x) 136 | sig[x < -eps] = -1 137 | sig[x > eps] = 1 138 | return sig 139 | 140 | @staticmethod 141 | def forward(ctx, x, soft_y, *args): 142 | return BaseHardFn.forward(ctx, x, soft_y, SignumWithMargin._hard_fn, *args) 143 | --------------------------------------------------------------------------------
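The README's four-step recipe for wrapping a custom non-differentiable function maps directly onto the classes in models/dab.py. Below is a minimal, hypothetical sketch of those steps; the `RoundWithDAB` function, the layer sizes, and the `dab_gamma` value are illustrative assumptions rather than part of this repository, and the script assumes it is run from the repository root so that `models.dab` is importable.

```python
# Hypothetical end-to-end sketch of the README's four steps, reusing DAB and
# BaseHardFn from models/dab.py. RoundWithDAB, the layer sizes and dab_gamma
# are illustrative assumptions, not part of this repository.
import torch
import torch.nn as nn
import torch.nn.functional as F

from models.dab import DAB, BaseHardFn


class RoundWithDAB(BaseHardFn):
    """Toy non-differentiable op: element-wise rounding (step 2 of the README)."""

    @staticmethod
    def _hard_fn(x, *args):
        return torch.round(x)

    @staticmethod
    def forward(ctx, x, soft_y, *args):
        # Delegate to BaseHardFn so the backward pass is routed through the approximator.
        return BaseHardFn.forward(ctx, x, soft_y, RoundWithDAB._hard_fn, *args)


def build_toy_model(input_size=16, latent_size=64):
    # Step 1: a small smooth network that learns to mimic the hard op.
    approximator = nn.Sequential(
        nn.Linear(latent_size, latent_size),
        nn.Tanh(),
        nn.Linear(latent_size, latent_size),
    )

    # Step 3: stack the DAB (hard op + its approximator) inside an ordinary pipeline.
    return nn.Sequential(
        nn.Linear(input_size, latent_size),
        nn.Tanh(),
        DAB(approximator=approximator, hard_layer=RoundWithDAB()),
        nn.Linear(latent_size, 1),
    )


if __name__ == "__main__":
    model = build_toy_model()
    x = torch.randn(8, 16)
    y = torch.randn(8, 1)

    prediction = model(x)

    # Step 4: add the gamma-weighted DAB loss to the normal task loss.
    dab_gamma = 10.0                            # mirrors --dab-gamma in main.py
    dab_loss = model[2].loss_function().mean()  # model[2] is the DAB layer
    task_loss = F.mse_loss(prediction, y)
    total_loss = task_loss + dab_gamma * dab_loss
    total_loss.backward()
    print(task_loss.item(), dab_loss.item())
```

main.py follows the same pattern, with SignumWithMargin as the hard layer and get_dab_loss() collecting the per-layer DAB losses across the model.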
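For reference, here is a tiny, hypothetical illustration of the data produced by `generate_samples` in datasets/sort.py and of how the sort task is framed as per-position classification; the literal numbers in the comments are made up for illustration.

```python
# Hypothetical illustration of the sort data format; assumes the repository
# root is on the path so that datasets.sort is importable.
from datasets.sort import generate_samples

data, labels = generate_samples(num_samples=2, seq_len=4, max_digit=1)
print(data.shape)    # (2, 4, 1): unsorted float32 values drawn from U(0, 1)
print(labels.shape)  # (2, 4):    argsort indices, one class target per position

# For example, an input row [0.9, 0.1, 0.4, 0.7] yields the label [1, 2, 3, 0]:
# position 0 holds the index of the smallest element, and so on. The model in
# main.py therefore predicts one of sequence_length classes per position, and
# all_or_none_accuracy counts a sequence as correct only when every position matches.
```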