├── .gitignore ├── LICENSE ├── README.md ├── compute_codes.py ├── data ├── MNIST_train.npy └── MNIST_val.npy ├── eval_classification.py ├── eval_denoising.py ├── imagenet_LCN_patches.py ├── main.py ├── models ├── linear_classifier.py ├── linear_dictionary.py ├── lista_classifier.py ├── lista_encoder.py └── one_hidden_decoder.py ├── scripts ├── ImageNet_SDL-NL.sh ├── ImageNet_SDL.sh ├── ImageNet_VDL-NL.sh ├── ImageNet_VDL.sh ├── MNIST_SDL-NL.sh ├── MNIST_SDL.sh ├── MNIST_VDL-NL.sh ├── MNIST_VDL.sh └── build_ImageNet_LCN.sh └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | __pycache__ 3 | .idea 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Katrina Evtimova 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Sparse Coding with Multi-Layer Decoders using Variance Regularization 2 | This is a PyTorch implementation for the setup described in 3 | [Sparse Coding with Multi-Layer Decoders using Variance Regularization](https://arxiv.org/abs/2112.09214). 4 | 5 | ### Requirements 6 | 7 | - Python 3.7 8 | - [PyTorch](https://pytorch.org/get-started/previous-versions/) 1.6.0 with torchvision 0.7.0 9 | - Other dependencies: numpy, tensorboardX 10 | 11 | ### Datasets 12 | 13 | In our experiments, we use: 14 | - the [MNIST](http://yann.lecun.com/exdb/mnist/) dataset. We provide the train and validation splits in 15 | ```data/MNIST_train.npy``` and ```data/MNIST_val.npy```. 16 | - a custom dataset with 200,000 gray-scale natural image patches of size 28x28 extracted from 17 | [ImageNet](https://www.image-net.org/index.php). The script to generate it is 18 | [build_imagenet_LCN.sh](https://github.com/kevtimova/deep-sparse/blob/main/scripts/build_ImageNet_LCN.sh). 19 | 20 | ### Training 21 | 22 | The scripts below can be used to train sparse autoencoders with our different setups. 
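Each script is a thin wrapper around ```main.py```. For illustration, a training run looks roughly like the sketch below (the flag values here are placeholders; the exact hyperparameters for each setup are in the linked scripts):

```bash
python main.py --dataset MNIST --datadir ./data --data_splits ./data \
  --decoder linear_dictionary --train_decoder \
  --encoder lista_encoder --train_encoder \
  --sparsity_reg 1e-3 --n_steps_inf 200 --FISTA --epochs 100 --cuda
```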
23 | 24 | | dataset | model | script | 25 | |------------------|----------|--------| 26 | | MNIST | SDL | [link](https://github.com/kevtimova/deep-sparse/blob/main/scripts/MNIST_SDL.sh) | 27 | | MNIST | SDL-NL | [link](https://github.com/kevtimova/deep-sparse/blob/main/scripts/MNIST_SDL-NL.sh) | 28 | | MNIST | VDL | [link](https://github.com/kevtimova/deep-sparse/blob/main/scripts/MNIST_VDL.sh) | 29 | | MNIST | VDL-NL | [link](https://github.com/kevtimova/deep-sparse/blob/main/scripts/MNIST_VDL-NL.sh) | 30 | | ImageNet_patches | SDL | [link](https://github.com/kevtimova/deep-sparse/blob/main/scripts/ImageNet_SDL.sh) | 31 | | ImageNet_patches | SDL-NL | [link](https://github.com/kevtimova/deep-sparse/blob/main/scripts/ImageNet_SDL-NL.sh) | 32 | | ImageNet_patches | VDL | [link](https://github.com/kevtimova/deep-sparse/blob/main/scripts/ImageNet_VDL.sh) | 33 | | ImageNet_patches | VDL-NL | [link](https://github.com/kevtimova/deep-sparse/blob/main/scripts/ImageNet_VDL-NL.sh) | 34 | 35 | ### Evaluation 36 | 37 | We evaluate our pre-trained sparse autoencoders on the downstream tasks of denoising (for MNIST and our custom 38 | ImageNet patches dataset) and classification in the low-data regime (for MNIST only). 39 | 40 | #### Denoising 41 | 42 | The denoising performance on the test set can be measured at the end of training by passing a list of noise levels 43 | (standard deviations of the Gaussian noise added to the input images, e.g. ```--noise '[0.1,0.3,0.5]'```) via the ```noise``` argument 44 | of ```main.py```. 45 | 46 | Alternatively, ```eval_denoising.py``` can be run directly with a pre-trained autoencoder. 47 | 48 | #### Classification 49 | 50 | To evaluate the linear separability of the codes obtained from the sparse autoencoders, we provide the steps below. 51 | 52 | Step 1: Given a pre-trained encoder, ```compute_codes.py``` can be used to create a dataset containing the codes 53 | for each MNIST image. 54 | 55 | Step 2: Using the dataset from the previous step, ```eval_classification.py``` can be used to measure classification 56 | performance with a set number of training samples per class. 57 | 58 | There are two options for the classifier: a linear classifier 59 | (located in ```models/linear_classifier.py```) and a classifier which uses a randomly initialized LISTA encoder module 60 | followed by a linear classification layer (located in ```models/lista_classifier.py```). A minimal example of this two-step workflow is shown below.
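A minimal sketch of the two-step workflow, assuming MNIST codes and a linear classifier (the ```<experiment>``` placeholders stand for the name of the pre-trained experiment, and the hyperparameter values are illustrative):

```bash
# Step 1: compute codes for the MNIST train and test splits with a pre-trained encoder
python compute_codes.py --dataset MNIST --datadir ./data \
  --encoder lista_encoder --num_iter_LISTA 3 --code_dim 128 \
  --pretrained_path_enc ./results/checkpoints/<experiment>_ENC_best.pth

# Step 2: train a classifier on the saved codes
python eval_classification.py --dataset codes \
  --datadir ./results/codes/<experiment>_ENC_best_codes \
  --classifier linear_classifier --n_training_samples_per_class 100
```

Note that ```eval_classification.py``` loads its train and validation splits from ```--data_splits``` as ```<dataset>_train.npy``` and ```<dataset>_val.npy``` index files, so for the ```codes``` dataset the corresponding ```codes_train.npy``` and ```codes_val.npy``` files need to be provided there.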
61 | -------------------------------------------------------------------------------- /compute_codes.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import random 5 | import subprocess 6 | from importlib import import_module 7 | 8 | import torch 9 | from torch.utils.data import DataLoader 10 | 11 | torch.backends.cudnn.benchmark = True 12 | 13 | from utils import set_random_seed, get_dataset, np, my_iterator_val 14 | 15 | def define_args(): 16 | # Define arguments 17 | parser = argparse.ArgumentParser(description='Computing codes from pre-trained encoder.') 18 | parser.add_argument('--name', type=str, default='', 19 | help='Name of experiment.') 20 | parser.add_argument('--seed', type=int, default=11, metavar='S', 21 | help='Random seed.') 22 | parser.add_argument('--outdir', default='./results/', type=str, 23 | help='Path to the directory that contains the outputs.') 24 | parser.add_argument('--dataset', default='MNIST', type=str, 25 | help='Name of the dataset (options: MNIST | imagenet_LCN)') 26 | parser.add_argument('--datadir', default='../sample_data', type=str, 27 | help='Path to the directory that contains the data.') 28 | parser.add_argument('--batch_size', type=int, default=250, metavar='N', 29 | help='Batch size for training.') 30 | parser.add_argument('--num_workers', type=int, default=0, metavar='N', 31 | help='Number of workers.') 32 | parser.add_argument('--code_dim', type=int, default=128, 33 | help='Code dimension.') 34 | parser.add_argument('--im_size', type=int, default=28, 35 | help='Image input size.') 36 | parser.add_argument('--patch_size', type=int, default=0, 37 | help='Patch size to sample after rescaling to im_size (0 if no patch sampling).') 38 | parser.add_argument('--cuda', action='store_true', default=False, 39 | help='Whether to run code on GPU (default: run on CPU).') 40 | parser.add_argument('--encoder', default='lista_encoder', type=str, 41 | help='Encoder architecture.') 42 | parser.add_argument('--num_iter_LISTA', type=int, default=3, 43 | help='Number of LISTA iterations.') 44 | parser.add_argument('--pretrained_path_enc', default='', type=str, 45 | help='Path to the pre-trained encoder or decoder.') 46 | 47 | # Get arguments 48 | args = parser.parse_args() 49 | return args 50 | 51 | def compute_codes(args): 52 | # Logistics 53 | args.use_encoder = len(args.pretrained_path_enc) > 0 54 | 55 | # Experiment name 56 | if args.name == '': 57 | model_name = args.pretrained_path_enc.split('/')[-1].split('.pth')[0] 58 | args.name = f'{model_name}_codes' 59 | print('\nComputing: {}\n'.format(args.name)) 60 | print(json.dumps(args.__dict__, sort_keys=True, indent=4) + '\n') 61 | device = torch.device("cuda" if args.cuda else "cpu") 62 | 63 | # Working directory 64 | outdir = lambda dirname: os.path.join(args.outdir, dirname) 65 | if not os.path.exists(outdir('codes')): 66 | os.mkdir(outdir('codes')) 67 | 68 | # Experiment directories 69 | out_dir = os.path.join(outdir('codes'), args.name) 70 | if not os.path.exists(out_dir): 71 | os.mkdir(out_dir) 72 | 73 | # Random seed 74 | set_random_seed(args.seed, torch, np, random, args.cuda) 75 | 76 | # Training and test data 77 | dataset_train = get_dataset(args.dataset, args.datadir, 78 | train=True, im_size=args.im_size, 79 | patch_size=args.patch_size) 80 | dataset_test = get_dataset(args.dataset, args.datadir, 81 | train=False, im_size=args.im_size, 82 | patch_size=args.patch_size) 83 | 84 | # Fix order of training and test 
data 85 | data_train = DataLoader(dataset_train, batch_size=args.batch_size, num_workers=args.num_workers, 86 | pin_memory=args.cuda, shuffle=False, drop_last=False) 87 | data_test = DataLoader(dataset_test, batch_size=args.batch_size, num_workers=args.num_workers, 88 | pin_memory=args.cuda, shuffle=False, drop_last=False) 89 | 90 | # Get data information 91 | args.n_channels = dataset_test.__getitem__(0)[0].shape[0] 92 | 93 | # Load encoder 94 | encoder = getattr(import_module('models.{}'.format(args.encoder)), 'Encoder')(args).to(device) 95 | encoder.load_pretrained(args.pretrained_path_enc, freeze=True) 96 | encoder.eval() 97 | 98 | for data in [{'batches': data_train, 'split': 'train'}, 99 | {'batches': data_test, 'split': 'test'}]: 100 | split = data['split'] 101 | codes_to_save = None 102 | targets = None 103 | 104 | for batch, batch_info, should in my_iterator_val(args=args, data=data['batches'], log_interval=1): 105 | 106 | # Logistics 107 | y = batch['X'].to(device) 108 | target = batch['target'] 109 | 110 | # Compute the codes using amortized inference 111 | Zs = encoder(y) 112 | 113 | # Concatenate codes 114 | if codes_to_save is None: 115 | targets = target 116 | codes_to_save = Zs 117 | else: 118 | codes_to_save = torch.cat([codes_to_save, Zs], dim=0) 119 | targets = torch.cat([targets, target], dim=0) 120 | 121 | # Save codes and targets 122 | np.save(os.path.join(out_dir, f'{args.dataset}_{split}_codes.npy'), codes_to_save.cpu().numpy()) 123 | np.save(os.path.join(out_dir, f'{args.dataset}_{split}_targets.npy'), targets.cpu().numpy()) 124 | if 'train' in split: 125 | # Compute the mean and std of the codes for the training data 126 | np.save(os.path.join(out_dir, f'{args.dataset}_codes_mean.npy'), codes_to_save.mean().cpu().numpy()) 127 | np.save(os.path.join(out_dir, f'{args.dataset}_codes_std.npy'), codes_to_save.std().cpu().numpy()) 128 | 129 | # Final message 130 | final_msg = f'Finished computing {split} codes.' 
131 | print(final_msg) 132 | 133 | 134 | if __name__ == '__main__': 135 | # Get arguments 136 | args = define_args() 137 | 138 | # Save git info 139 | args.git_head = subprocess.check_output(['git', 'rev-parse', '--abbrev-ref', 'HEAD']).decode('ascii').strip() 140 | args.git_sha = subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD']).decode('ascii').strip() 141 | 142 | compute_codes(args) 143 | -------------------------------------------------------------------------------- /data/MNIST_train.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevtimova/deep-sparse/70562a3bc788f99117392b5c41d36224adee4178/data/MNIST_train.npy -------------------------------------------------------------------------------- /data/MNIST_val.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevtimova/deep-sparse/70562a3bc788f99117392b5c41d36224adee4178/data/MNIST_val.npy -------------------------------------------------------------------------------- /eval_classification.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import random 5 | import subprocess 6 | import time 7 | from importlib import import_module 8 | 9 | import torch 10 | from tensorboardX import SummaryWriter 11 | from torch.utils.data import DataLoader 12 | from torch.utils.data.sampler import SubsetRandomSampler 13 | 14 | torch.backends.cudnn.benchmark = True 15 | 16 | from utils import set_random_seed, get_dataset, FixedSubsetSampler, np 17 | 18 | def define_args(): 19 | # Define arguments 20 | parser = argparse.ArgumentParser(description='Evaluation: classification.') 21 | parser.add_argument('--name', type=str, default='', 22 | help='Name of experiment.') 23 | parser.add_argument('--seed', type=int, default=11, metavar='S', 24 | help='Random seed.') 25 | parser.add_argument('--outdir', default='./results/', type=str, 26 | help='Path to the directory that stores the evaluation results.') 27 | parser.add_argument('--dataset', default='MNIST', type=str, 28 | help='name of the dataset (options: MNIST | codes)') 29 | parser.add_argument('--datadir', default='./data', type=str, 30 | help='Path to the directory that contains the data.') 31 | parser.add_argument('--data_splits', default='./data', type=str, 32 | help='Path to the directory that contains the data splits.') 33 | parser.add_argument('--n_classes', type=int, default=10, 34 | help='Number of classes.') 35 | parser.add_argument('--code_dim', type=int, default=128, 36 | help='Code dimension (for LISTA classifier).') 37 | parser.add_argument('--n_training_samples_per_class', type=int, default=100, metavar='N', 38 | help='Number of training samples for the model.') 39 | parser.add_argument('--n_val_samples', type=int, default=5000, metavar='N', 40 | help='Number of validation samples for the model.') 41 | parser.add_argument('--n_test_samples', type=int, default=10000, metavar='N', 42 | help='Number of test samples for the model.') 43 | parser.add_argument('--batch_size', type=int, default=250, metavar='N', 44 | help='Batch size for training.') 45 | parser.add_argument('--num_workers', type=int, default=0, metavar='N', 46 | help='Number of workers.') 47 | parser.add_argument('--classifier', default='linear_classifier', type=str, 48 | help='Architecture of classifier.') 49 | parser.add_argument('--im_size', type=int, default=28, 50 | help='Image input size.') 
51 | parser.add_argument('--patch_size', type=int, default=0, 52 | help='Patch size to sample before rescaling to im_size (0 if no patch sampling).') 53 | parser.add_argument('--cuda', action='store_true', default=False, 54 | help='Whether to run code on GPU (default: run on CPU).') 55 | parser.add_argument('--n_batches_to_log', type=int, default=20, metavar='N', 56 | help='How many batches to log.') 57 | parser.add_argument('--epochs', type=int, default=100, 58 | help='Number of epochs.') 59 | parser.add_argument('--lr', type=float, default=1e-3, 60 | help='Learning rate for classifier.') 61 | parser.add_argument('--top_k', type=int, default=3, 62 | help='Track top k classification.') 63 | parser.add_argument('--L1_reg', type=float, default=0, 64 | help='Level of L1 regularization for the classifier\'s weights.') 65 | parser.add_argument('--L2_reg', type=float, default=0, 66 | help='Level of L2 regularization for the classifier\'s weights.') 67 | 68 | # Parse arguments 69 | args = parser.parse_args() 70 | return args 71 | 72 | def eval_class(args): 73 | # Logistics 74 | if args.dataset == 'codes': 75 | try: 76 | data_source = '-'.join(args.datadir.split('/')[-1].split('_s')[-1].split('_')[1:4]) 77 | except: 78 | data_source = 'codes' 79 | else: 80 | data_source = args.dataset 81 | if 'lista' in args.classifier: 82 | data_source += '-' + '-'.join(args.classifier.split('_')) 83 | if len(args.name) == 0: 84 | timestamp = str(int(time.time())) 85 | args.name = f'{data_source}_{timestamp}_s_{args.seed}_' \ 86 | f'ntr_{args.n_training_samples_per_class}_lr_{args.lr}_' \ 87 | f'L1_{args.L1_reg}_L2_{args.L2_reg}' 88 | else: 89 | timestamp = args.name.split('_s_')[0].split('_')[-1] 90 | print('\nClassification experiment: {}\n'.format(args.name)) 91 | print(json.dumps(args.__dict__, sort_keys=True, indent=4) + '\n') 92 | device = torch.device("cuda" if args.cuda else "cpu") 93 | 94 | # Experiment directories 95 | outdir = lambda dirname: os.path.join(args.outdir, dirname) 96 | if not os.path.exists(outdir('classify')): 97 | os.mkdir(outdir('classify')) 98 | results_file = os.path.join(outdir('classify'), 'classif_results.tsv') 99 | model_loc = outdir('checkpoints') + '/{}'.format(args.name) + '.pth' 100 | 101 | # Tensorboard support.
To run: tensorboard --logdir /logs 102 | logs_dir = outdir('classify') + '/{}'.format(args.name) 103 | os.mkdir(logs_dir) 104 | writer = SummaryWriter(log_dir=logs_dir) 105 | 106 | # Random seed 107 | set_random_seed(args.seed, torch, np, random, args.cuda) 108 | 109 | # Read training data 110 | dataset_train = get_dataset(args.dataset, args.datadir, 111 | train=True, im_size=args.im_size, 112 | patch_size=args.patch_size) 113 | train_indices = np.load(os.path.join(args.data_splits, f'{args.dataset}_train.npy')) 114 | 115 | # Select pre-set number of training elements per class 116 | train_indices_selected = [] 117 | for class_id in range(args.n_classes): 118 | selected = dataset_train.targets[train_indices] == class_id 119 | train_indices_selected = train_indices_selected + list(train_indices[selected][:args.n_training_samples_per_class]) 120 | train_sampler = SubsetRandomSampler(train_indices_selected) 121 | data_train = DataLoader(dataset_train, batch_size=args.batch_size, num_workers=args.num_workers, 122 | pin_memory=args.cuda, shuffle=False, sampler=train_sampler, drop_last=False) 123 | 124 | # Read validation data 125 | val_indices = list(np.load(os.path.join(args.data_splits, f'{args.dataset}_val.npy')))[:args.n_val_samples] 126 | val_sampler = FixedSubsetSampler(val_indices) 127 | data_val = DataLoader(dataset_train, batch_size=args.batch_size, num_workers=args.num_workers, 128 | pin_memory=args.cuda, shuffle=False, sampler=val_sampler, drop_last=False) 129 | 130 | # Read test data 131 | dataset_test = get_dataset(args.dataset, args.datadir, 132 | train=False, im_size=args.im_size, 133 | patch_size=args.patch_size) 134 | test_indices = [*range(len(dataset_test))][:args.n_test_samples] 135 | test_sampler = FixedSubsetSampler(test_indices) 136 | data_test = DataLoader(dataset_test, batch_size=args.batch_size, num_workers=args.num_workers, 137 | pin_memory=args.cuda, shuffle=False, sampler=test_sampler, drop_last=False) 138 | 139 | # Get data information 140 | args.n_batches = len(data_train) 141 | args.log_interval = max(args.n_batches // args.n_batches_to_log, 1) if args.n_batches_to_log > 0 else 0 142 | args.input_dim = dataset_train.__getitem__(0)[0].flatten().shape[0] 143 | 144 | # Classifier 145 | classifier = getattr(import_module('models.{}'.format(args.classifier)), 'Classifier')(args).to(device) 146 | 147 | # Classifier training loop 148 | best_train = {'loss': None, 'ep': None, 'top1': None, f'top{args.top_k}': None} 149 | best_val = {'loss': None, 'ep': None, 'top1': None, f'top{args.top_k}': None} 150 | 151 | # Training 152 | for epoch in range(args.epochs): 153 | # Train classifier 154 | train_stats = classifier.train(data_train, epoch, args.top_k) 155 | print('Epoch [{}/{}] Train Loss: {:.6f} Top1: {:.3f} Top{}: {:.3f} N: {}'.format( 156 | epoch + 1, args.epochs, train_stats['loss'], train_stats['top1'], args.top_k, train_stats[f'top{args.top_k}'], 157 | train_stats['n_samples'])) 158 | writer.add_scalar('loss_train', train_stats['loss'], epoch) 159 | writer.add_scalar('top1_train', train_stats['top1'], epoch) 160 | writer.add_scalar(f'top{args.top_k}_train', train_stats[f'top{args.top_k}'], epoch) 161 | 162 | # Track best training 163 | if best_train['loss'] is not None: 164 | if best_train['loss'] > train_stats['loss']: 165 | best_train['loss'] = train_stats['loss'] 166 | best_train['top1'] = train_stats['top1'] 167 | best_train[f'top{args.top_k}'] = train_stats[f'top{args.top_k}'] 168 | best_train['ep'] = epoch 169 | else: 170 | best_train['loss'] = 
train_stats['loss'] 171 | best_train['top1'] = train_stats['top1'] 172 | best_train[f'top{args.top_k}'] = train_stats[f'top{args.top_k}'] 173 | best_train['ep'] = epoch 174 | 175 | # Validation 176 | val_stats = classifier.test(data_val, args.top_k) 177 | print('Epoch [{}/{}] Valid Loss: {:.6f} Top1: {:.3f} Top{}: {:.3f} N: {}'.format( 178 | epoch + 1, args.epochs, val_stats['loss'], val_stats['top1'], args.top_k, val_stats[f'top{args.top_k}'], 179 | val_stats['n_samples'])) 180 | writer.add_scalar('loss_val', val_stats['loss'], epoch) 181 | writer.add_scalar('top1_val', val_stats['top1'], epoch) 182 | writer.add_scalar(f'top{args.top_k}_val', val_stats[f'top{args.top_k}'], epoch) 183 | 184 | # Track best validation 185 | if best_val['loss'] is not None: 186 | if best_val['loss'] > val_stats['loss']: 187 | best_val['loss'] = val_stats['loss'] 188 | best_val['top1'] = val_stats['top1'] 189 | best_val[f'top{args.top_k}'] = val_stats[f'top{args.top_k}'] 190 | best_val['ep'] = epoch 191 | 192 | # Save best model 193 | best_val_cl = classifier 194 | torch.save(classifier.state_dict(), model_loc) 195 | else: 196 | best_val['loss'] = val_stats['loss'] 197 | best_val['top1'] = val_stats['top1'] 198 | best_val[f'top{args.top_k}'] = val_stats[f'top{args.top_k}'] 199 | best_val['ep'] = epoch 200 | best_val_cl = classifier 201 | 202 | # Test performance 203 | test_stats = best_val_cl.test(data_test, args.top_k) 204 | 205 | # Final message 206 | tr_msg = '{}\tBEST Training\tLoss: {:.6f}\tTop1: {:.6f}\tTop{}: {:.6f} ep: {}\n'.format( 207 | args.name, best_train['loss'], best_train['top1'], 208 | args.top_k, best_train[f'top{args.top_k}'], best_train['ep'] + 1) 209 | val_msg = '{}\tBEST Validation\tLoss: {:.6f}\tTop1: {:.6f}\tTop{}: {:.6f} ep: {}\n'.format( 210 | args.name, best_val['loss'], best_val['top1'], 211 | args.top_k, best_val[f'top{args.top_k}'], best_val['ep'] + 1) 212 | test_msg = '{}\tBEST Val TEST\tLoss: {:.6f}\tTop1: {:.6f}\tTop{}: {:.6f} ep: {}\n'.format( 213 | args.name, test_stats['loss'], test_stats['top1'], 214 | args.top_k, test_stats[f'top{args.top_k}'], best_val['ep'] + 1) 215 | final_msg = '\n' + tr_msg + val_msg + test_msg 216 | print(final_msg) 217 | 218 | # Save results 219 | final = open(results_file, 'a') 220 | head = '{}\t{}\t{}\t{}\t{}\t{}\t{}\t'.format(data_source, timestamp, args.seed, 221 | args.n_training_samples_per_class, 222 | args.lr, args.L1_reg, args.L2_reg) 223 | train_row = head + 'train\t{:.6f}\t{:.6f}\t{:.6f}\t{}\n'.format(best_train['loss'], best_train['top1'], 224 | best_train[f'top{args.top_k}'], best_train['ep'] + 1) 225 | final.write(train_row) 226 | val_row = head + 'valid\t{:.6f}\t{:.6f}\t{:.6f}\t{}\n'.format(best_val['loss'], best_val['top1'], 227 | best_val[f'top{args.top_k}'], best_val['ep'] + 1) 228 | final.write(val_row) 229 | test_row = head + 'test\t{:.6f}\t{:.6f}\t{:.6f}\t{}\n'.format(test_stats['loss'], test_stats['top1'], 230 | test_stats[f'top{args.top_k}'], best_val['ep'] + 1) 231 | final.write(test_row) 232 | 233 | final.close() 234 | writer.close() 235 | 236 | 237 | if __name__ == '__main__': 238 | # Get arguments 239 | args = define_args() 240 | 241 | # Save git info 242 | args.git_head = subprocess.check_output(['git', 'rev-parse', '--abbrev-ref', 'HEAD']).decode('ascii').strip() 243 | args.git_sha = subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD']).decode('ascii').strip() 244 | 245 | eval_class(args) 246 | -------------------------------------------------------------------------------- /eval_denoising.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import random 5 | import subprocess 6 | import time 7 | from importlib import import_module 8 | 9 | import torch 10 | from tensorboardX import SummaryWriter 11 | from torch.utils.data import DataLoader 12 | 13 | torch.backends.cudnn.benchmark = True 14 | 15 | from utils import get_dataset, set_random_seed, FixedSubsetSampler, my_iterator_val, np, \ 16 | add_noise_to_img, PSNR, L0, dewhiten, inverse_transform, save_img, img_grid 17 | 18 | def define_args(): 19 | # Define arguments 20 | parser = argparse.ArgumentParser(description='Evaluation: denoising.') 21 | parser.add_argument('--name', type=str, default='', 22 | help='Name of experiment.') 23 | parser.add_argument('--seed', type=int, default=11, metavar='S', 24 | help='Random seed.') 25 | parser.add_argument('--outdir', default='./results/', type=str, 26 | help='Path to the directory that contains the outputs.') 27 | parser.add_argument('--dataset', default='MNIST', type=str, 28 | help='Name of the dataset (options: MNIST | imagenet_LCN, default: MNIST).') 29 | parser.add_argument('--datadir', default='../sample_data', type=str, 30 | help='Path to the directory that contains the data.') 31 | parser.add_argument('--n_test_samples', type=int, default=10000, metavar='N', 32 | help='Number of test samples for the model.') 33 | parser.add_argument('--batch_size', type=int, default=250, metavar='N', 34 | help='Batch size for evaluation.') 35 | parser.add_argument('--num_workers', type=int, default=0, metavar='N', 36 | help='Number of workers.') 37 | parser.add_argument('--additive_noise', type=float, default=0, 38 | help='Level of noise for input.') 39 | parser.add_argument('--decoder', default='linear_dictionary', type=str, 40 | help='Architecture of pre-trained decoder.') 41 | parser.add_argument('--code_dim', type=int, default=128, 42 | help='Code dimension.') 43 | parser.add_argument('--pretrained_path_dec', default='', type=str, 44 | help='Path to the pre-trained decoder.') 45 | parser.add_argument('--im_size', type=int, default=28, 46 | help='Image input size.') 47 | parser.add_argument('--patch_size', type=int, default=0, 48 | help='Patch size to sample before rescaling to im_size (0 if no patch sampling).') 49 | parser.add_argument('--cuda', action='store_true', default=False, 50 | help='Whether to run code on GPU (default: run on CPU).') 51 | parser.add_argument('--encoder', default='fc_encoder', type=str, 52 | help='Encoder architecture.') 53 | parser.add_argument('--num_iter_LISTA', type=int, default=0, 54 | help='Number of LISTA iterations.') 55 | parser.add_argument('--pretrained_path_enc', default='', type=str, 56 | help='Path to the pre-trained encoder.') 57 | parser.add_argument('--hidden_dim', type=int, default=128, metavar='N', 58 | help='Hidden dimension for multi-layer decoder.') 59 | 60 | # Get arguments 61 | args = parser.parse_args() 62 | return args 63 | 64 | def eval_denoising(args): 65 | # Logistics 66 | whitening = args.dataset == 'imagenet_LCN' 67 | dataset_eval_indices = args.__dict__.pop('dataset_test_indices', None) 68 | 69 | # Experiment name 70 | if args.name == '': 71 | model_name = args.pretrained_path_enc.split('/')[-1].split('.pth')[0] 72 | args.name = f'{model_name}_seed_{args.seed}_n{args.additive_noise}' 73 | args.name = '{}_{}'.format(args.name, str(int(time.time()))) 74 | print('\nEval experiment: {}\n'.format(args.name)) 75 |
print(json.dumps(args.__dict__, sort_keys=True, indent=4) + '\n') 76 | device = torch.device("cuda" if args.cuda else "cpu") 77 | 78 | # Working directory 79 | outdir = lambda dirname: os.path.join(args.outdir, dirname) 80 | 81 | # Experiment directories 82 | img_dir = os.path.join(outdir('imgs'), args.name) 83 | os.mkdir(img_dir) 84 | 85 | # Tensorboard support. To run: tensorboard --logdir /logs 86 | experiment_logs_dir = outdir('logs') + '/{}'.format(args.name) 87 | os.mkdir(experiment_logs_dir) 88 | writer = SummaryWriter(log_dir=experiment_logs_dir) 89 | 90 | # Random seed 91 | set_random_seed(args.seed, torch, np, random, args.cuda) 92 | 93 | # Evaluation data 94 | dataset_eval = get_dataset(args.dataset, args.datadir, 95 | train=False, im_size=args.im_size, 96 | patch_size=args.patch_size) 97 | if dataset_eval_indices is None: 98 | dataset_eval_indices = [*range(len(dataset_eval))][:args.n_test_samples] 99 | 100 | # Order of elements in the evaluation set is fixed 101 | dataset_eval_sampler = FixedSubsetSampler(dataset_eval_indices) 102 | data_eval = DataLoader(dataset_eval, batch_size=args.batch_size, num_workers=args.num_workers, 103 | pin_memory=args.cuda, shuffle=False, sampler=dataset_eval_sampler, drop_last=True) 104 | 105 | # Get data information 106 | n_batches_val = len(data_eval) 107 | log_viz_interval = 1 108 | args.n_channels = dataset_eval.__getitem__(0)[0].shape[0] 109 | 110 | # Load decoder 111 | decoder = getattr(import_module('models.{}'.format(args.decoder)), 'Decoder')(args).to(device) 112 | decoder.load_pretrained(args.pretrained_path_dec, freeze=True) 113 | decoder.eval() 114 | 115 | # Load encoder 116 | encoder = getattr(import_module('models.{}'.format(args.encoder)), 'Encoder')(args).to(device) 117 | encoder.load_pretrained(args.pretrained_path_enc, freeze=True) 118 | encoder.eval() 119 | 120 | # Evaluation data loop 121 | psnr_orig_aggr = 0 122 | psnr_noisy_img_aggr = 0 123 | psnr_noisy_rec_aggr = 0 124 | L0_orig_aggr = 0 125 | L0_noisy_aggr = 0 126 | n_samples = 0 127 | args.epoch = 0 128 | 129 | for batch, batch_info, should in my_iterator_val(args, data_eval, log_viz_interval): 130 | 131 | # Logistics 132 | batch_id = batch['batch_id'] 133 | y = batch['X'].to(device) 134 | batch_size = batch_info['size'] 135 | 136 | # Generate noise for inputs 137 | y_noisy = add_noise_to_img(y, args.additive_noise, torch) 138 | 139 | # Whitening 140 | y_orig_mean, y_orig_std = None, None 141 | y_noisy_mean, y_noisy_std = None, None 142 | if args.dataset == 'imagenet_LCN': 143 | y_orig_mean, y_orig_std = batch['extra'] 144 | y_orig_mean, y_orig_std = y_orig_mean.to(device), y_orig_std.to(device) 145 | y_noisy_mean, y_noisy_std = y_orig_mean, y_orig_std 146 | 147 | # Compute PSNR between original image and noisy image 148 | psnr_noisy_img = PSNR(y, y_noisy, args.dataset, 149 | tar_sample_mean=y_orig_mean, tar_sample_std=y_orig_std, 150 | pred_sample_mean=y_noisy_mean, pred_sample_std=y_noisy_std) 151 | 152 | 153 | # Compute the Zs for the original and noisy data using amortized inference 154 | Zs_orig = encoder(y) 155 | Zs_noisy = encoder(y_noisy) 156 | 157 | # L0 of codes 158 | l0_orig = L0(Zs_orig) 159 | l0_noisy = L0(Zs_noisy) 160 | 161 | # Reconstructions 162 | y_hat_orig = decoder(Zs_orig) 163 | y_hat_noisy = decoder(Zs_noisy) 164 | 165 | # PSNR 166 | psnr_orig = PSNR(y, y_hat_orig, args.dataset, 167 | tar_sample_mean=y_orig_mean, tar_sample_std=y_orig_std) 168 | psnr_noisy_rec = PSNR(y, y_hat_noisy, args.dataset, 169 | tar_sample_mean=y_orig_mean,
tar_sample_std=y_orig_std, 170 | pred_sample_mean=y_noisy_mean, pred_sample_std=y_noisy_std) 171 | 172 | # De-whiten 173 | if whitening: 174 | y = dewhiten(y, y_orig_mean, y_orig_std) 175 | y_hat_orig = dewhiten(y_hat_orig, y_orig_mean, y_orig_std) 176 | y_noisy = dewhiten(y_noisy, y_noisy_mean, y_noisy_std) 177 | y_hat_noisy = dewhiten(y_hat_noisy, y_noisy_mean, y_noisy_std) 178 | 179 | # Log PSNR 180 | writer.add_scalar('psnr_orig', psnr_orig.item(), batch_id) 181 | writer.add_scalar('psnr_noisy_img', psnr_noisy_img.item(), batch_id) 182 | writer.add_scalar('psnr_noisy_rec', psnr_noisy_rec.item(), batch_id) 183 | writer.add_scalar('L0_orig', l0_orig.item(), batch_id) 184 | writer.add_scalar('L0_noisy', l0_noisy.item(), batch_id) 185 | psnr_orig_aggr += psnr_orig * batch_size 186 | psnr_noisy_img_aggr += psnr_noisy_img * batch_size 187 | psnr_noisy_rec_aggr += psnr_noisy_rec * batch_size 188 | L0_orig_aggr += l0_orig * batch_size 189 | L0_noisy_aggr += l0_noisy * batch_size 190 | n_samples += batch_size 191 | 192 | # Log targets and reconstructions 193 | if should['log_val_imgs']: 194 | y = inverse_transform(y[:16], args.dataset).clamp_(min=0, max=1) 195 | y_noisy = inverse_transform(y_noisy[:16], args.dataset).clamp_(min=0, max=1) 196 | y_hat_orig = inverse_transform(y_hat_orig[:16], args.dataset).clamp_(min=0, max=1) 197 | y_hat_noisy = inverse_transform(y_hat_noisy[:16], args.dataset).clamp_(min=0, max=1) 198 | save_img(y, f'{img_dir}/y_{(batch_id+1):04d}.png', norm=False) 199 | save_img(y_noisy, f'{img_dir}/y_noisy_{(batch_id+1):04d}.png', norm=False) 200 | save_img(y_hat_orig, f'{img_dir}/y_hat_orig_{(batch_id+1):04d}.png', norm=False) 201 | save_img(y_hat_noisy, f'{img_dir}/y_hat_noisy_{(batch_id+1):04d}.png', norm=False) 202 | writer.add_image(f'y', img_grid(y), batch_id) 203 | writer.add_image(f'y_noisy', img_grid(y_noisy), batch_id) 204 | writer.add_image(f'y_hat_orig', img_grid(y_hat_orig), batch_id) 205 | writer.add_image(f'y_hat_noisy', img_grid(y_hat_noisy), batch_id) 206 | 207 | # Log columns of linear decoder 208 | n_cols = min(128, args.code_dim) 209 | cols = decoder.viz_columns(n_cols) 210 | save_img(cols, f'{img_dir}/lin_dec_cols.png', norm=False, n_rows=int(2 ** (np.log2(n_cols) // 2))) 211 | writer.add_image(f'lin_dec_cols', img_grid(cols, norm=False), 0) 212 | 213 | # Aggregate PSNR 214 | args.psnr_orig_aggr = psnr_orig_aggr / n_samples 215 | args.psnr_noisy_img_aggr = psnr_noisy_img_aggr / n_samples 216 | args.psnr_noisy_rec_aggr = psnr_noisy_rec_aggr / n_samples 217 | args.L0_orig_aggr = L0_orig_aggr / n_samples 218 | args.L0_noisy_aggr = L0_noisy_aggr / n_samples 219 | final_msg = 'noise {}\tPSNR_orig: {:.3f}\tPSNR_noisy {:.3f}\t' \ 220 | 'PSNR_noisy_rec {:.3f}\tL0_orig {:.3f}\tL0_noisy {:.3f}'.format(args.additive_noise, 221 | args.psnr_orig_aggr.item(), 222 | args.psnr_noisy_img_aggr.item(), 223 | args.psnr_noisy_rec_aggr.item(), 224 | args.L0_orig_aggr.item(), 225 | args.L0_noisy_aggr.item()) 226 | print(final_msg) 227 | 228 | writer.close() 229 | 230 | 231 | if __name__ == '__main__': 232 | # Get arguments 233 | args = define_args() 234 | 235 | # Save git info 236 | args.git_head = subprocess.check_output(['git', 'rev-parse', '--abbrev-ref', 'HEAD']).decode('ascii').strip() 237 | args.git_sha = subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD']).decode('ascii').strip() 238 | 239 | eval_denoising(args) 240 | -------------------------------------------------------------------------------- /imagenet_LCN_patches.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import subprocess 5 | import time 6 | 7 | import numpy as np 8 | import torch 9 | from torch.utils.data import DataLoader 10 | 11 | from utils import get_dataset, FixedSubsetSampler, get_gaussian_filter, LocalContrastNorm 12 | 13 | def define_args(): 14 | # Arguments 15 | parser = argparse.ArgumentParser(description='Generate locally contrast normalized ImageNet patches.') 16 | parser.add_argument('--datadir', type=str, default='', 17 | help='Path to original Imagenet dataset.') 18 | parser.add_argument('--outdir', type=str, default='', 19 | help='Path for new LCN dataset.') 20 | parser.add_argument('--seed', type=int, default=29, 21 | help='random seed (default: 29)') 22 | parser.add_argument('--n_train', type=int, default=200000, 23 | help='Number of training samples.') 24 | parser.add_argument('--n_val', type=int, default=20000, 25 | help='Number of validation samples.') 26 | parser.add_argument('--n_test', type=int, default=20000, 27 | help='Number of test samples.') 28 | parser.add_argument('--im_size', type=int, default=256, 29 | help='Image input size.') 30 | parser.add_argument('--patch_size', type=int, default=52, 31 | help='Patch size to sample after rescaling to im_size.') 32 | parser.add_argument('--patch_type', type=str, default='random', 33 | help='Random or center patch.') 34 | parser.add_argument('--batch_size', type=int, default=250, 35 | help='Batch size for reading data.') 36 | parser.add_argument('--num_workers', type=int, default=0, 37 | help='Number of workers.') 38 | parser.add_argument('--cuda', action='store_true', default=False, 39 | help='Whether to run code on GPU (default: run on CPU).') 40 | parser.add_argument('--gaussian_filter_radius', type=int, default=13, 41 | help='Size of Gaussian Filter.') 42 | parser.add_argument('--gaussian_filter_sigma', type=float, default=5, 43 | help='Std of Gaussian Filter.') 44 | parser.add_argument('--std_threshold', type=float, default=0.5, 45 | help='Threshold for std of selected patches.') 46 | parser.add_argument('--epochs', type=int, default=10, 47 | help='Number of passes through the data.') 48 | # Parse arguments 49 | args = parser.parse_args() 50 | return args 51 | 52 | def generate_LCN_patches(data, gaussian_filter, args): 53 | 54 | # Compute the LCN train patches 55 | lcn_patches = None 56 | lcn_patches_mean = None 57 | lcn_patches_std = None 58 | n_patches = 0 59 | break_loops = False 60 | for epoch in range(args.epochs): 61 | for img, _ in data: 62 | img = img.to(device) 63 | 64 | # Whitening 65 | img, img_mean, img_std = LocalContrastNorm(img, gaussian_filter) 66 | selected_patches = img.std((1, 2, 3)) >= args.std_threshold 67 | 68 | if lcn_patches is None: 69 | lcn_patches = img[selected_patches].cpu().numpy() 70 | lcn_patches_mean = img_mean[selected_patches].cpu().numpy() 71 | lcn_patches_std = img_std[selected_patches].cpu().numpy() 72 | else: 73 | lcn_patches = np.append(lcn_patches, img[selected_patches].cpu().numpy(), axis=0) 74 | lcn_patches_mean = np.append(lcn_patches_mean, img_mean[selected_patches].cpu().numpy(), axis=0) 75 | lcn_patches_std = np.append(lcn_patches_std, img_std[selected_patches].cpu().numpy(), axis=0) 76 | 77 | n_patches += sum(selected_patches) 78 | print(f'Epoch {epoch} selected: {sum(selected_patches)}. 
Total: {n_patches}.') 79 | if n_patches >= args.n_train + args.n_val + args.n_test: 80 | break_loops = True 81 | print(f'Got all the patches during epoch {epoch + 0}!') 82 | break 83 | 84 | if break_loops: 85 | break 86 | 87 | # Save train patches 88 | training = args.n_train + args.n_val 89 | np.save(os.path.join(args.outdir, f'imagenet_LCN_patches_train.npy'), lcn_patches[:training]) 90 | np.save(os.path.join(args.outdir, f'imagenet_LCN_patches_train_mean.npy'), lcn_patches_mean[:training]) 91 | np.save(os.path.join(args.outdir, f'imagenet_LCN_patches_train_std.npy'), lcn_patches_std[:training]) 92 | print(f'Saved {len(lcn_patches[:training])} train LCN patches.') 93 | 94 | # Save training and validation splits 95 | np.save(os.path.join(args.outdir, f'imagenet_LCN_train.npy'), [*range(args.n_train)]) 96 | np.save(os.path.join(args.outdir, f'imagenet_LCN_val.npy'), [*range(args.n_train, args.n_train + args.n_val)]) 97 | 98 | # Save test patches 99 | np.save(os.path.join(args.outdir, f'imagenet_LCN_patches_test.npy'), lcn_patches[training:(training + args.n_test)]) 100 | np.save(os.path.join(args.outdir, f'imagenet_LCN_patches_test_mean.npy'), lcn_patches_mean[training:(training + args.n_test)]) 101 | np.save(os.path.join(args.outdir, f'imagenet_LCN_patches_test_std.npy'), lcn_patches_std[training:(training + args.n_test)]) 102 | print(f'Saved {len(lcn_patches[training:(training + args.n_test)])} test LCN patches.') 103 | 104 | 105 | if __name__ == '__main__': 106 | # Get arguments 107 | args = define_args() 108 | 109 | # Save git info 110 | args.git_head = subprocess.check_output(['git', 'rev-parse', '--abbrev-ref', 'HEAD']).decode('ascii').strip() 111 | args.git_sha = subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD']).decode('ascii').strip() 112 | 113 | print(json.dumps(args.__dict__, sort_keys=True, indent=4) + '\n') 114 | 115 | # Logistics 116 | device = torch.device("cuda" if args.cuda else "cpu") 117 | if not os.path.exists(args.outdir): 118 | os.mkdir(args.outdir) 119 | 120 | # Set random seed 121 | np.random.seed(args.seed) 122 | 123 | # Read training data 124 | dataset_start_time = time.time() 125 | dataset = get_dataset(dataset_name='imagenet', datadir=args.datadir, 126 | train=True, im_size=args.im_size, 127 | patch_size=args.patch_size, patch_type=args.patch_type) 128 | n_channels = dataset.__getitem__(0)[0].shape[0] 129 | 130 | # Training and test splits 131 | dataset_permutation = np.random.permutation(len(dataset)).tolist() 132 | 133 | # Load data 134 | dataset_sampler = FixedSubsetSampler(dataset_permutation) 135 | data = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers, 136 | pin_memory=args.cuda, shuffle=False, sampler=dataset_sampler, drop_last=False) 137 | dataset_end_time = time.time() 138 | print(f"Dataset loading time: {dataset_end_time - dataset_start_time:.1f} \n") 139 | 140 | # Gaussian filter 141 | gaussian_filter = get_gaussian_filter(n_channels, device, 142 | radius=args.gaussian_filter_radius, 143 | sigma=args.gaussian_filter_sigma) 144 | 145 | # Generate patches 146 | generate_LCN_patches(data, gaussian_filter, args) 147 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import collections 3 | import json 4 | import os 5 | import random 6 | import string 7 | import subprocess 8 | import time 9 | from importlib import import_module 10 | 11 | import torch 12 | from 
tensorboardX import SummaryWriter 13 | from torch.utils.data import DataLoader 14 | from torch.utils.data.sampler import SubsetRandomSampler 15 | 16 | torch.backends.cudnn.benchmark = True 17 | 18 | from utils import get_dataset, PSNR, MSE, normalize_kernels, L0, FixedSubsetSampler,\ 19 | set_random_seed, np, Gauge, my_iterator, my_iterator_val, ISTA, hinge, compute_energy, \ 20 | sqrt_var, dewhiten, log_viz, anneal_learning_rate, \ 21 | print_final_training_msg, print_final_eval_msg 22 | 23 | from eval_denoising import eval_denoising 24 | 25 | def define_args(): 26 | # Define arguments 27 | parser = argparse.ArgumentParser(description='Multi-layer Sparse Coding with Variance Regularization.') 28 | 29 | # Experiment details 30 | parser.add_argument('--name', type=str, default='', 31 | help='Name of experiment.') 32 | parser.add_argument('--seed', type=int, default=11, metavar='S', 33 | help='Random seed.') 34 | parser.add_argument('--outdir', default='./results/', type=str, 35 | help='Path to the directory that contains the outputs.') 36 | parser.add_argument('--cuda', action='store_true', default=False, 37 | help='Whether to run code on GPU (default: run on CPU).') 38 | parser.add_argument('--num_workers', type=int, default=0, metavar='N', 39 | help='Number of workers.') 40 | # Data processing 41 | parser.add_argument('--batch_size', type=int, default=250, metavar='N', 42 | help='Mini-batch size.') 43 | parser.add_argument('--datadir', default='./data', type=str, 44 | help='Path to the directory that contains the data.') 45 | parser.add_argument('--dataset', default='MNIST', type=str, 46 | help='Name of the dataset (options: MNIST | imagenet_LCN).') 47 | parser.add_argument('--data_splits', default='./data', type=str, 48 | help='Path to the directory that contains the data splits.') 49 | parser.add_argument('--n_training_samples', type=int, default=55000, metavar='N', 50 | help='Number of training samples for the model.') 51 | parser.add_argument('--n_val_samples', type=int, default=5000, metavar='N', 52 | help='Number of validation samples for the model.') 53 | parser.add_argument('--n_test_samples', type=int, default=10000, metavar='N', 54 | help='Number of test samples for evaluating the model.') 55 | parser.add_argument('--im_size', type=int, default=28, 56 | help='Image input size.') 57 | parser.add_argument('--patch_size', type=int, default=0, 58 | help='Patch size to sample after rescaling to im_size (0 if no patch sampling).') 59 | parser.add_argument('--epochs', type=int, default=1, metavar='N', 60 | help='Number of times the model passes through all the training data.') 61 | # Decoder arguments 62 | parser.add_argument('--decoder', default='linear_dictionary', type=str, 63 | help='Decoder architecture.') 64 | parser.add_argument('--pretrained_path_dec', default='', type=str, 65 | help='Path to the state_dict of a pre-trained decoder.') 66 | parser.add_argument('--train_decoder', action='store_true', default=False, 67 | help='Whether to train a decoder (default is not).') 68 | parser.add_argument('--code_dim', type=int, default=128, metavar='N', 69 | help='Code dimension.') 70 | parser.add_argument('--hidden_dim', type=int, default=128, metavar='N', 71 | help='Hidden dimension for multi-layer decoder.') 72 | parser.add_argument('--norm_decoder', type=float, default=0, 73 | help='Radius of the sphere the decoder\'s columns are projected to. 
Default: no normalization.') 74 | parser.add_argument('--lrt_D', type=float, default=1e-4, 75 | help='Learning rate for the decoder weights.') 76 | parser.add_argument('--weight_decay_D', type=float, default=0, 77 | help='Weight decay for decoder optimizer.') 78 | parser.add_argument('--weight_decay_D_bias', type=float, default=0, 79 | help='Weight decay to use on bias term in decoder.') 80 | parser.add_argument('--anneal_lr_D_freq', type=int, default=0, 81 | help='How frequently to anneal the decoder\'s learning rate.') 82 | parser.add_argument('--anneal_lr_D_mult', type=float, default=0, 83 | help='Multiplier for annealing the decoder\'s learning rate.') 84 | # Encoder arguments 85 | parser.add_argument('--encoder', default='lista_encoder', type=str, 86 | help='Encoder architecture.') 87 | parser.add_argument('--pretrained_path_enc', default='', type=str, 88 | help='Path to the state_dict of a pre-trained encoder.') 89 | parser.add_argument('--train_encoder', action='store_true', default=False, 90 | help='Whether to train an encoder to predict codes from inference (default is not).') 91 | parser.add_argument('--num_iter_LISTA', type=int, default=0, 92 | help='Number of LISTA iterations.') 93 | parser.add_argument('--lrt_E', type=float, default=1e-4, 94 | help='Learning rate for the encoder weights.') 95 | parser.add_argument('--weight_decay_E', type=float, default=0, 96 | help='Weight decay for encoder\'s parameters.') 97 | parser.add_argument('--weight_decay_E_bias', type=float, default=0, 98 | help='Weight decay for encoder\'s bias.') 99 | parser.add_argument('--anneal_lr_E_freq', type=int, default=0, 100 | help='How frequently to anneal the encoder\'s learning rate.') 101 | parser.add_argument('--anneal_lr_E_mult', type=float, default=0, 102 | help='Multiplier for annealing the encoder\'s learning rate.') 103 | # Inference arguments 104 | parser.add_argument('--sparsity_reg', type=float, default=1e-3, 105 | help='Sparsity term for codes during training.') 106 | parser.add_argument('--lrt_Z', type=float, default=1, 107 | help='Learning rate for sparse codes calculation.') 108 | parser.add_argument('--positive_ISTA', action='store_true', default=False, 109 | help='Whether to constrain ISTA to positive values.') 110 | parser.add_argument('--FISTA', action='store_true', default=False, 111 | help='Whether to use a faster version of ISTA.') 112 | parser.add_argument('--n_steps_inf', type=int, default=200, metavar='N', 113 | help='Number of inference iterations for computing each code.') 114 | parser.add_argument('--stop_early', type=float, default=1e-3, 115 | help='Tolerance for early stopping during (F)ISTA.') 116 | parser.add_argument('--use_Zs_enc_as_init', action='store_true', default=False, 117 | help='Whether to use the encoder\'s predictions as initial values for ISTA.') 118 | parser.add_argument('--Zs_init_val', type=float, default=0, 119 | help='Initial value for the codes in ISTA for the non-linear model.') 120 | parser.add_argument('--variance_reg', type=float, default=0, 121 | help='Weight of regularization term: squared hinge on std of latent components.') 122 | parser.add_argument('--hinge_threshold', type=float, default=0.5, 123 | help='Threshold in the hinge loss.') 124 | parser.add_argument('--code_reg', type=float, default=0, 125 | help='Coefficient for energy coming from distance to the encoder\'s predictions.') 126 | # Evaluation 127 | parser.add_argument('--noise', type=str, default='[]', 128 | help='List with levels of noise for denoising evaluation.')
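# For example: --noise '[0.1, 0.3, 0.5]' (illustrative values); the string is parsed into a Python list with eval() in main() below.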
129 | # Parse arguments 130 | args = parser.parse_args() 131 | return args 132 | 133 | def train_decoder_step(decoder, y, y_mean, y_std, Zs, optimizer_dec, args): 134 | # Decoder 135 | decoder.train() 136 | 137 | # Reconstruction loss 138 | y_hat = decoder(Zs) 139 | rec_loss_y = MSE(y, y_hat, reduction='mean') 140 | 141 | # Backward pass 142 | optimizer_dec.zero_grad() 143 | decoder.zero_grad() 144 | rec_loss_y.backward() 145 | optimizer_dec.step() 146 | 147 | # Compute PSNR 148 | psnr = PSNR(y, y_hat, args.dataset, y_mean, y_std) 149 | 150 | # Normalize decoder weights 151 | if args.norm_decoder > 0: 152 | normalize_kernels(decoder, args.norm_decoder) 153 | 154 | # Output dictionary 155 | output = {'rec_loss_y': rec_loss_y.detach(), 156 | 'y_hat': y_hat.detach(), 157 | 'psnr': psnr} 158 | return output 159 | 160 | def train_encoder_step(encoder, optimizer_enc, Zs_enc, Zs_inf): 161 | # Loss from codes 162 | rec_loss_code = MSE(Zs_enc, Zs_inf, reduction='mean') 163 | 164 | # Update encoder 165 | encoder.zero_grad() 166 | optimizer_enc.zero_grad() 167 | rec_loss_code.backward() 168 | optimizer_enc.step() 169 | 170 | output = {'rec_loss_code': rec_loss_code.detach()} 171 | return output 172 | 173 | def main(): 174 | # Get arguments 175 | args = define_args() 176 | 177 | # Save git info 178 | args.git_head = subprocess.check_output(['git', 'rev-parse', '--abbrev-ref', 'HEAD']).decode('ascii').strip() 179 | args.git_sha = subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD']).decode('ascii').strip() 180 | 181 | # Create directory structure 182 | outdir = lambda dirname: os.path.join(args.outdir, dirname) 183 | if not os.path.exists(args.outdir): 184 | os.mkdir(args.outdir) 185 | if not os.path.exists(outdir('checkpoints')): 186 | os.mkdir(outdir('checkpoints')) 187 | if not os.path.exists(outdir('logs')): 188 | os.mkdir(outdir('logs')) 189 | if not os.path.exists(outdir('imgs')): 190 | os.mkdir(outdir('imgs')) 191 | if not os.path.exists(outdir('final')): 192 | os.mkdir(outdir('final')) 193 | 194 | # Experiment name 195 | if len(args.name) == 0: 196 | timestamp = str(int(time.time())) 197 | args.unq = ''.join(random.choices(string.ascii_uppercase + string.digits, k=3)) 198 | args.name = '{}_{}_fn_{}_lrtZ_{}_lrtD_{}_ns_{}_sp_{}_s{}_{}_{}'.format( 199 | args.decoder, 200 | args.dataset, 201 | args.norm_decoder, 202 | args.lrt_Z, 203 | args.lrt_D, 204 | args.n_steps_inf, 205 | args.sparsity_reg, 206 | args.seed, 207 | timestamp, 208 | args.unq) 209 | else: 210 | timestamp = args.name.split('_')[-2] 211 | args.unq = args.name.split('_')[-1] 212 | print('\nExperiment: {}\n'.format(args.name)) 213 | 214 | # Experiment directory for saving visualizations 215 | img_dir = os.path.join(outdir('imgs'), args.name) 216 | os.mkdir(img_dir) 217 | 218 | # Print and save experiment specs 219 | print(json.dumps(args.__dict__, sort_keys=True, indent=4) + '\n') 220 | json_file = open(os.path.join(outdir('final'), timestamp + '_' + args.unq + '.json'), "w") 221 | json_file.write(json.dumps(args.__dict__, sort_keys=True, indent=4)) 222 | json_file.close() 223 | 224 | # More logistics 225 | device = torch.device("cuda" if args.cuda else "cpu") 226 | whitening = args.dataset == 'imagenet_LCN' 227 | args.noise = eval(args.noise) 228 | head = f"m\ttime\tunq\tenc\tdata\ttr_D\ttr_E\tFISTA\tmd\ts\tcd\thd\tsp\tvar\tht\tnd\twd_B\t" \ 229 | f"init\tlrt_Z\tlrt_D\tanf\tanr\twd_D\tlrt_E\titer\tuse_enc\twdE\twdE_b\tcd_reg\tsteps\t" \ 230 | f"L0_Z\tL0_H\tPSNR\tep\tev\tL0_orig\torig\tnoisy_im\tL0_noisy\tnoisy_rec" 231 | 
msg_pre = f"{args.decoder}\t{timestamp}\t{args.unq}\t{args.encoder}\t{args.dataset}\t" \ 232 | f"{int(args.train_decoder)}\t{int(args.train_encoder)}\t{int(args.FISTA)}" 233 | msg_post = f"{args.seed}\t{args.code_dim}\t{args.hidden_dim}\t{args.sparsity_reg:.1e}\t{args.variance_reg}\t" \ 234 | f"{args.hinge_threshold}\t{args.norm_decoder}\t{args.weight_decay_D_bias}\t{args.Zs_init_val}\t" \ 235 | f"{args.lrt_Z}\t{args.lrt_D}\t{args.anneal_lr_D_freq}\t{args.anneal_lr_D_mult}\t{args.weight_decay_D}\t" \ 236 | f"{args.lrt_E}\t{args.num_iter_LISTA}\t{int(args.use_Zs_enc_as_init)}\t{args.weight_decay_E}\t" \ 237 | f"{args.weight_decay_E_bias}\t{args.code_reg}\t" 238 | 239 | # Tensorboard support. To run: tensorboard --logdir /logs 240 | experiment_logs_dir = outdir('logs') + '/{}'.format(args.name) 241 | os.mkdir(experiment_logs_dir) 242 | writer = SummaryWriter(log_dir=experiment_logs_dir) 243 | 244 | # Random seed 245 | set_random_seed(args.seed, torch, np, random, args.cuda) 246 | 247 | # Load training and validation data for the model 248 | dataset_start_time = time.time() 249 | dataset_train = get_dataset(args.dataset, args.datadir, 250 | train=True, im_size=args.im_size, patch_size=args.patch_size) 251 | dataset_train_indices = list(np.load(os.path.join(args.data_splits, f'{args.dataset}_train.npy')))[:args.n_training_samples] 252 | dataset_val_indices = list(np.load(os.path.join(args.data_splits, f'{args.dataset}_val.npy')))[:args.n_val_samples] 253 | 254 | # Shuffle training samples 255 | dataset_train_sampler = SubsetRandomSampler(dataset_train_indices) 256 | data_train = DataLoader(dataset_train, batch_size=args.batch_size, num_workers=args.num_workers, 257 | pin_memory=args.cuda, shuffle=False, sampler=dataset_train_sampler, drop_last=True) 258 | 259 | # Fix validation samples 260 | dataset_val_sampler = FixedSubsetSampler(dataset_val_indices) 261 | data_val = DataLoader(dataset_train, batch_size=args.batch_size, num_workers=args.num_workers, 262 | pin_memory=args.cuda, shuffle=False, sampler=dataset_val_sampler, drop_last=True) 263 | dataset_end_time = time.time() 264 | print(f"Dataset loading time: {dataset_end_time - dataset_start_time:.1f} \n") 265 | 266 | # Logistics: logging 267 | n_epochs_to_log_imgs = min(10, args.epochs) 268 | log_viz_interval = int(args.epochs / n_epochs_to_log_imgs) 269 | 270 | # Logistics: data 271 | args.n_channels = dataset_train.__getitem__(0)[0].shape[0] 272 | 273 | # Logistics: keep track of best training and validation models 274 | best_perf_tr = collections.defaultdict(lambda: None) 275 | best_perf_val = collections.defaultdict(lambda: None) 276 | idn = "_".join(args.name.split('_')[-2:]) 277 | results_file = os.path.join(outdir('final'), idn + '.csv') 278 | if args.train_encoder + args.train_decoder == 0: 279 | results_file = os.path.join(outdir('final'), 'EVAL-only_' + idn + '.csv') 280 | 281 | # Decoder 282 | decoder = getattr(import_module('models.{}'.format(args.decoder)), 'Decoder')(args).to(device) 283 | if len(args.pretrained_path_dec) > 0: 284 | # Load pretrained decoder, turn off gradients if not training it 285 | decoder.load_pretrained(args.pretrained_path_dec, freeze=not(args.train_decoder)) 286 | if args.train_decoder: 287 | # Normalize the decoder's columns, if randomly initialized 288 | if args.norm_decoder > 0 and len(args.pretrained_path_dec) == 0: 289 | normalize_kernels(decoder, args.norm_decoder) 290 | else: 291 | # If not training the decoder, put it in eval() mode and remove gradient tracking 292 | decoder.eval() 293 | 
decoder.requires_grad_(False) 294 | # If not training the decoder and there is no pre-trained decoder, save the random decoder (for eval) 295 | if len(args.pretrained_path_dec) == 0: 296 | args.pretrained_path_dec = outdir('checkpoints') + f'/{args.name}_DEC_random.pth' 297 | torch.save(decoder.state_dict(), args.pretrained_path_dec) 298 | 299 | # Encoder 300 | encoder = getattr(import_module('models.{}'.format(args.encoder)), 'Encoder')(args).to(device) 301 | if len(args.pretrained_path_enc) > 0: 302 | # Load pretrained encoder, turn off gradients if not training it 303 | encoder.load_pretrained(args.pretrained_path_enc, freeze=not(args.train_encoder)) 304 | if not(args.train_encoder): 305 | # If not training the encoder, put it in eval() mode and remove gradient tracking 306 | encoder.eval() 307 | encoder.requires_grad_(False) 308 | # If not training the encoder and there is no pre-trained encoder, save the random encoder (for eval) 309 | if len(args.pretrained_path_enc) == 0: 310 | # Save the random encoder (for eval) 311 | args.pretrained_path_enc = outdir('checkpoints') + f'/{args.name}_ENC_random.pth' 312 | torch.save(encoder.state_dict(), args.pretrained_path_enc) 313 | 314 | def train(args): 315 | # Optimizer for decoder 316 | if args.train_decoder: 317 | optimizer_dec = torch.optim.Adam(decoder.parameters(), lr=args.lrt_D, weight_decay=args.weight_decay_D) 318 | if args.decoder in ['one_hidden_decoder']: 319 | param_groups = [{'params': decoder.layer1.bias, 'weight_decay': args.weight_decay_D_bias}, 320 | {'params': decoder.layer1.weight}, 321 | {'params': decoder.layer2.weight}] 322 | optimizer_dec = torch.optim.Adam(param_groups, lr=args.lrt_D, weight_decay=args.weight_decay_D) 323 | 324 | # Optimizer for encoder 325 | if args.train_encoder: 326 | optimizer_enc = torch.optim.Adam(encoder.parameters(), lr=args.lrt_E, weight_decay=args.weight_decay_E) 327 | if args.weight_decay_E_bias > 0: 328 | param_groups = [{'params': encoder.W.bias, 'weight_decay': args.weight_decay_E_bias}, 329 | {'params': encoder.W.weight}, 330 | {'params': encoder.S.weight}] 331 | optimizer_enc = torch.optim.Adam(param_groups, lr=args.lrt_E, weight_decay=args.weight_decay_E) 332 | 333 | # Gauge data 334 | gauge = Gauge() 335 | 336 | # Training loop 337 | train_iterator = my_iterator(args, data_train, log_viz_interval) 338 | for batch, batch_info, should in train_iterator: 339 | training = True 340 | epoch = batch_info['epoch'] 341 | y = batch['X'].to(device) 342 | y_mean, y_std = None, None # whitening 343 | if args.dataset == 'imagenet_LCN': 344 | y_mean, y_std = batch['extra'] 345 | y_mean, y_std = y_mean.to(device), y_std.to(device) 346 | 347 | # Track active code components 348 | if should['epoch_start']: 349 | Zs_comp_use = 0. 
350 | 351 | # Encoder predictions 352 | if args.train_encoder: 353 | encoder.train() 354 | Zs_enc = encoder(y) 355 | 356 | # Inference of the codes 357 | if args.n_steps_inf > 0: 358 | # Perform inference with (F)ISTA 359 | inference_output = ISTA(decoder, y, args.positive_ISTA, args.FISTA, 360 | args.sparsity_reg, args.n_steps_inf, args.lrt_Z, 361 | args.use_Zs_enc_as_init, Zs_enc, 362 | args.variance_reg, args.hinge_threshold, args.code_reg, 363 | args.stop_early, training, args.train_decoder) 364 | Zs = inference_output['Zs'] 365 | else: 366 | # Amortized inference using the encoder's predictions 367 | Zs = Zs_enc.detach() 368 | 369 | # Decoder update 370 | if args.train_decoder: 371 | decoder_output = train_decoder_step(decoder, y, y_mean, y_std, Zs, optimizer_dec, args) 372 | y_hat = decoder_output['y_hat'] 373 | rec_loss_y = decoder_output['rec_loss_y'] 374 | psnr = decoder_output['psnr'] 375 | else: 376 | with torch.no_grad(): 377 | y_hat = decoder(Zs) 378 | rec_loss_y = MSE(y, y_hat, reduction='mean') 379 | psnr = PSNR(y, y_hat, args.dataset, y_mean, y_std) 380 | 381 | # Encoder update 382 | if args.train_encoder and args.n_steps_inf > 0: 383 | encoder_output = train_encoder_step(encoder, optimizer_enc, Zs_enc, Zs) 384 | rec_loss_code = encoder_output['rec_loss_code'] 385 | else: 386 | # Encoder will not be updated if encoder is not trained or there is no inference 387 | pass 388 | 389 | # Decoder stats 390 | gauge.add('rec_loss_y', rec_loss_y) 391 | gauge.add('PSNR', psnr) 392 | if args.decoder in ['linear_dictionary']: 393 | gauge.add('avg_col_norm', decoder.decoder.weight.data.norm(dim=0, p=2).mean()) 394 | if args.decoder in ['one_hidden_decoder']: 395 | gauge.add('layer1_avg_col_norm', decoder.layer1.weight.data.norm(dim=0, p=2).mean()) 396 | gauge.add('layer2_avg_col_norm', decoder.layer2.weight.data.norm(dim=0, p=2).mean()) 397 | gauge.add('bias_norm_D', decoder.layer1.bias.data.norm()) 398 | gauge.add('frac_0s_hidden_pre', decoder.frac_0s_hidden_pre_relu) 399 | gauge.add('frac_0s_hidden_post', decoder.frac_0s_hidden_post_relu) 400 | 401 | # Encoder stats 402 | if args.train_encoder and args.n_steps_inf > 0: 403 | gauge.add('rec_loss_Z', rec_loss_code) 404 | if args.encoder in ['lista_encoder']: 405 | gauge.add('bias_norm_E', encoder.W.bias.data.norm()) 406 | 407 | # Code stats 408 | Zs_comp_use += (Zs.detach().abs() > 0).float().mean(0) 409 | hinge_loss = hinge(input=sqrt_var(Zs.detach()), threshold=args.hinge_threshold, reduction='mean') 410 | energy = compute_energy(y, y_hat, Zs, 411 | args.sparsity_reg, args.variance_reg, args.hinge_threshold, args.code_reg, 412 | Zs_enc) 413 | gauge.add('hinge_loss_Z', hinge_loss) 414 | gauge.add('energy', energy) 415 | gauge.add('frac_0s_Z', L0(Zs)) 416 | gauge.add('max_Z', Zs.detach().abs().max()) 417 | if args.n_steps_inf > 0: 418 | gauge.add('inf_steps', inference_output['inf_steps']) 419 | gauge.add('inf_time', inference_output['inference_time']) 420 | 421 | 422 | # End of epoch 423 | if should['epoch_end']: 424 | # Compute training metrics for the epoch 425 | train_stats = {} 426 | keys = list(gauge.cache.keys()) 427 | for k in keys: 428 | vals = torch.stack(gauge.get(k, clear=True)) 429 | v = torch.max(vals) if 'max' in k else torch.mean(vals) 430 | writer.add_scalar(f'epoch_stats_training/{k}', v, epoch) 431 | train_stats[k] = v 432 | train_stats['Zs_comp_use'] = Zs_comp_use / len(data_train) 433 | 434 | # Track best training performance (energy) 435 | if best_perf_tr['energy'] is None or best_perf_tr['energy'] > 
train_stats['energy']: 436 | best_perf_tr['energy'] = train_stats['energy'] 437 | best_perf_tr['PSNR'] = train_stats['PSNR'] 438 | best_perf_tr['epoch'] = epoch 439 | best_perf_tr['L0_Z'] = train_stats['frac_0s_Z'] 440 | if args.decoder in ['one_hidden_decoder']: 441 | best_perf_tr['L0_H'] = train_stats['frac_0s_hidden_post'] 442 | if args.n_steps_inf > 0: 443 | best_perf_tr['inf_steps'] = train_stats['inf_steps'] 444 | 445 | # Validation 446 | run_validation(encoder, decoder, epoch, args) 447 | 448 | # De-whiten 449 | if whitening: 450 | y = dewhiten(y, y_mean, y_std) 451 | y_hat = dewhiten(y_hat.detach(), y_mean, y_std) 452 | 453 | # Log visualizations 454 | n_samples = 64 455 | log_viz(decoder, writer, n_samples, y, y_hat, Zs, train_stats, img_dir, args.decoder, args.dataset, 456 | f'ep_{epoch}_train', log_all=True) 457 | 458 | # Save best model & useful viz 459 | if best_perf_val['epoch'] == epoch: 460 | # Save encoder and decoder 461 | if args.train_encoder: 462 | torch.save(encoder.state_dict(), outdir('checkpoints') + f'/{args.name}_ENC_best.pth') 463 | if args.train_decoder: 464 | torch.save(decoder.state_dict(), outdir('checkpoints') + f'/{args.name}_DEC_best.pth') 465 | 466 | # Anneal lrt_D 467 | if args.train_decoder and args.anneal_lr_D_freq > 0: 468 | anneal_learning_rate(optimizer_dec, epoch + 1, args.lrt_D, args.anneal_lr_D_mult, 469 | args.anneal_lr_D_freq) 470 | # Anneal lrt_E 471 | if args.train_encoder and args.anneal_lr_E_freq > 0: 472 | anneal_learning_rate(optimizer_enc, epoch + 1, args.lrt_E, args.anneal_lr_E_mult, 473 | args.anneal_lr_E_freq) 474 | 475 | # Clean up memory 476 | del Zs, y, y_hat 477 | 478 | # Message (to help with analysis) 479 | print_final_training_msg(results_file, head, msg_pre, msg_post, 480 | args.noise, best_perf_tr, best_perf_val) 481 | 482 | writer.close() 483 | 484 | def run_validation(encoder, decoder, epoch, args): 485 | # The encoder's predictions are used as the Zs during validation 486 | training = False 487 | encoder.eval() 488 | decoder.eval() 489 | 490 | # Track metrics and logistics 491 | gauge = Gauge() 492 | Zs_comp_use = 0. 
493 | 494 | # Data loop 495 | val_iterator = my_iterator_val(args, data_val, log_viz_interval, epoch) 496 | for batch, batch_info, should in val_iterator: 497 | y = batch['X'].to(device) 498 | y_mean, y_std = None, None # whitening 499 | if args.dataset == 'imagenet_LCN': 500 | y_mean, y_std = batch['extra'] 501 | y_mean, y_std = y_mean.to(device), y_std.to(device) 502 | 503 | # Encoder predictions 504 | with torch.no_grad(): 505 | Zs_enc = encoder(y) 506 | 507 | # Compute the Zs from inference 508 | if args.n_steps_inf > 0: 509 | inference_output = ISTA(decoder, y, args.positive_ISTA, args.FISTA, 510 | args.sparsity_reg, args.n_steps_inf, args.lrt_Z, 511 | args.use_Zs_enc_as_init, Zs_enc, 512 | args.variance_reg, args.hinge_threshold, args.code_reg, 513 | args.stop_early, training, args.train_decoder) 514 | Zs_inf = inference_output['Zs'] 515 | else: 516 | Zs_inf = Zs_enc 517 | 518 | # Validation stats 519 | with torch.no_grad(): 520 | # Decoder stats 521 | y_hat = decoder(Zs_enc) 522 | rec_loss_y = MSE(y, y_hat, reduction='mean') 523 | psnr = PSNR(y, y_hat, args.dataset, y_mean, y_std) 524 | 525 | 526 | # Encoder stats 527 | rec_loss_code = MSE(Zs_inf, Zs_enc, reduction='mean') 528 | 529 | # Code stats 530 | frac_0s = L0(Zs_enc) 531 | Zs_comp_use += (Zs_enc.detach().abs() > 0).float().mean(0) 532 | Zs_max = Zs_enc.detach().abs().max() 533 | energy = compute_energy(y, y_hat, Zs_enc, 534 | args.sparsity_reg, args.variance_reg, args.hinge_threshold, args.code_reg, 535 | Zs_enc) 536 | 537 | # Track stats 538 | gauge.add('rec_loss_Z', rec_loss_code) 539 | gauge.add('rec_loss_y', rec_loss_y) 540 | gauge.add('PSNR', psnr) 541 | gauge.add('frac_0s_Z', frac_0s) 542 | gauge.add('max_Z', Zs_max) 543 | gauge.add('energy', energy) 544 | if args.decoder in ['one_hidden_decoder']: 545 | gauge.add(f'frac_0s_hidden_pre', decoder.frac_0s_hidden_pre_relu) 546 | gauge.add(f'frac_0s_hidden_post', decoder.frac_0s_hidden_post_relu) 547 | 548 | # Log aggregate validation stats 549 | valid_stats = {} 550 | keys = list(gauge.cache.keys()) 551 | for k in keys: 552 | vals = torch.stack(gauge.get(k, clear=True)) 553 | v = torch.max(vals) if 'max' in k else torch.mean(vals) 554 | writer.add_scalar(f'epoch_stats_validation/{k}', v, epoch) 555 | valid_stats[k] = v 556 | valid_stats['Zs_comp_use'] = Zs_comp_use / len(data_val) 557 | 558 | # Track best validation performance 559 | if best_perf_val['energy'] is None or best_perf_val['energy'] > valid_stats['energy']: 560 | best_perf_val['energy'] = valid_stats['energy'] 561 | best_perf_val['PSNR'] = valid_stats['PSNR'] 562 | best_perf_val['epoch'] = epoch 563 | best_perf_val['L0_Z'] = valid_stats['frac_0s_Z'] 564 | if args.decoder in ['one_hidden_decoder']: 565 | best_perf_val[f'L0_H'] = valid_stats['frac_0s_hidden_post'] 566 | 567 | # De-whiten 568 | if whitening: 569 | y = dewhiten(y, y_mean, y_std) 570 | y_hat = dewhiten(y_hat.detach(), y_mean, y_std) 571 | 572 | # Log visualizations 573 | n_samples = 64 574 | if should['log_val_imgs']: 575 | log_viz(decoder, writer, n_samples, y, y_hat, Zs_enc, valid_stats, img_dir, args.decoder, args.dataset, 576 | viz_type=f'ep_{epoch}_val', log_all=False) 577 | if best_perf_val[f'epoch'] == epoch: 578 | log_viz(decoder, writer, n_samples, y, y_hat, Zs_enc, valid_stats, img_dir, args.decoder, args.dataset, 579 | viz_type='BEST_VAL', log_all=True) 580 | 581 | def run_eval_denoising(args): 582 | # Experiment identification 583 | idn = "_".join(args.name.split('_')[-2:]) 584 | 585 | # Generate the arguments for the eval experiment 586 
| for noise in args.noise: 587 | args_copy = argparse.Namespace(**vars(args)) 588 | args_copy.additive_noise = noise 589 | args_copy.name = f'ENC_{idn}_den_{noise}_{str(int(time.time()))}' 590 | # Path to encoder 591 | if args.train_encoder: 592 | args_copy.pretrained_path_enc = outdir('checkpoints') + f'/{args.name}_ENC_best.pth' 593 | # Path to decoder 594 | if args.train_decoder: 595 | args_copy.pretrained_path_dec = outdir('checkpoints') + f'/{args.name}_DEC_best.pth' 596 | eval_denoising(args_copy) 597 | print_final_eval_msg(results_file, msg_pre, msg_post, args_copy, best_perf_val) 598 | 599 | # Training 600 | if args.train_decoder or args.train_encoder: 601 | train(args) 602 | 603 | # Evaluation 604 | run_eval_denoising(args) 605 | 606 | if __name__ == '__main__': 607 | main() 608 | -------------------------------------------------------------------------------- /models/linear_classifier.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | 5 | class Classifier(nn.Module): 6 | def __init__(self, args): 7 | super().__init__() 8 | 9 | # Model 10 | self.input_dim = args.input_dim 11 | self.output_dim = args.n_classes 12 | self.log_interval = args.log_interval 13 | self.cuda = args.cuda 14 | self.device = torch.device("cuda" if self.cuda else "cpu") 15 | 16 | self.classifier = nn.Linear(self.input_dim, self.output_dim, bias=True) 17 | self.criterion = nn.CrossEntropyLoss().to(self.device) 18 | self.lr = args.lr 19 | self.L1_reg = args.L1_reg 20 | self.L2_reg = args.L2_reg 21 | self.optimizer = optim.Adam(params=self.classifier.parameters(), 22 | lr=self.lr, weight_decay=self.L2_reg) 23 | 24 | 25 | def forward(self, input): 26 | output = self.classifier(input) 27 | return output 28 | 29 | def train(self, data_loader, epoch, k=3): 30 | self.classifier.train() 31 | loss_aggr = 0 32 | top1_acc_aggr = 0 33 | topk_acc_aggr = 0 34 | n_samples = 0 35 | for batch_id, (input, target) in enumerate(data_loader): 36 | # Data 37 | input, target = input.to(self.device), target.to(self.device) 38 | batch_size = input.shape[0] 39 | input = input.view(batch_size, -1) 40 | n_samples += batch_size 41 | total_loss = 0 42 | 43 | # Forward 44 | self.optimizer.zero_grad() 45 | self.classifier.zero_grad() 46 | output = self.classifier(input) 47 | 48 | # Loss and update 49 | loss_CE = self.criterion(output, target) 50 | total_loss += loss_CE 51 | if self.L1_reg > 0: 52 | l1_norm = self.L1_norm() 53 | total_loss += self.L1_reg * l1_norm 54 | total_loss.backward() 55 | self.optimizer.step() 56 | loss_aggr += loss_CE.detach() * batch_size 57 | 58 | # Predictions 59 | top1_pred = output.detach().max(1, keepdim=True)[1] # get the index of the max log-probability 60 | top1_pred = top1_pred.eq(target.view_as(top1_pred)).sum().float() 61 | top1_acc_aggr += top1_pred 62 | topk_pred = output.detach().topk(k)[1] 63 | topk_pred = topk_pred.eq(target.unsqueeze(1).expand(-1, k)).sum().float() 64 | topk_acc_aggr += topk_pred 65 | if self.log_interval > 0 and batch_id % self.log_interval == 0: 66 | print('Train Epoch: {} [{}/{}] Loss: {:.6f} Top1: {:.2f} Top{}: {:.2f}'.format( 67 | epoch + 1, batch_id + 1, len(data_loader), 68 | loss_CE.detach(), top1_pred / batch_size, k, topk_pred / batch_size)) 69 | 70 | return {'loss': loss_aggr / n_samples, 71 | 'top1': top1_acc_aggr / n_samples, 72 | f'top{k}': topk_acc_aggr / n_samples, 73 | 'n_samples': n_samples} 74 | 75 | 76 | def test(self, data_loader, k=3): 77 | 
self.classifier.eval() 78 | with torch.no_grad(): 79 | loss_aggr = 0 80 | top1_acc_aggr = 0 81 | topk_acc_aggr = 0 82 | n_samples = 0 83 | for batch_id, (input, target) in enumerate(data_loader): 84 | # Data 85 | input, target = input.to(self.device), target.to(self.device) 86 | batch_size = input.shape[0] 87 | input = input.view(batch_size, -1) 88 | n_samples += batch_size 89 | 90 | # Forward 91 | output = self.classifier(input) 92 | 93 | # Loss 94 | loss = self.criterion(output, target) 95 | loss_aggr += loss.detach() * batch_size 96 | 97 | # Predictions 98 | top1_pred = output.detach().max(1, keepdim=True)[1] # get the index of the max log-probability 99 | top1_pred = top1_pred.eq(target.view_as(top1_pred)).sum().float() 100 | top1_acc_aggr += top1_pred 101 | topk_pred = output.detach().topk(k)[1] 102 | topk_pred = topk_pred.eq(target.unsqueeze(1).expand(-1, k)).sum().float() 103 | topk_acc_aggr += topk_pred 104 | 105 | # Logging 106 | loss_aggr = loss_aggr / n_samples 107 | top1_acc_aggr = top1_acc_aggr / n_samples 108 | topk_acc_aggr = topk_acc_aggr / n_samples 109 | 110 | return {'loss': loss_aggr, 111 | 'top1': top1_acc_aggr, 112 | f'top{k}': topk_acc_aggr, 113 | 'n_samples': n_samples} 114 | 115 | def L1_norm(self): 116 | # Compute the l1 norm of the classifier's weights 117 | l1_norm = torch.norm(self.classifier.weight, p=1) + torch.norm(self.classifier.bias, p=1) 118 | return l1_norm 119 | 120 | def load_pretrained(self, path, freeze=False): 121 | # Load pretrained model 122 | pretrained_model = torch.load(f=path, map_location="cuda" if self.cuda else "cpu") 123 | msg = self.load_state_dict(pretrained_model) 124 | print(msg) 125 | 126 | # Freeze pretrained parameters 127 | if freeze: 128 | for p in self.parameters(): 129 | p.requires_grad = False 130 | -------------------------------------------------------------------------------- /models/linear_dictionary.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class Decoder(nn.Module): 5 | def __init__(self, args): 6 | super().__init__() 7 | 8 | # Model 9 | self.patch_size = args.patch_size 10 | self.n_channels = args.n_channels 11 | self.code_dim = args.code_dim 12 | self.im_size = min(args.im_size, args.patch_size) if args.patch_size > 0 else args.im_size 13 | assert self.im_size > 0 14 | self.output_dim = (self.im_size ** 2) * self.n_channels 15 | self.decoder = nn.Linear(self.code_dim, self.output_dim, bias=False) 16 | self.cuda = args.cuda 17 | self.device = torch.device("cuda" if self.cuda else "cpu") 18 | 19 | def forward(self, code): 20 | output = self.decoder(code) 21 | return output.view(output.shape[0], self.n_channels, self.im_size, -1) 22 | 23 | def initZs(self, batch_size): 24 | Zs = torch.zeros(size=(batch_size, self.code_dim), device=self.device) 25 | return Zs 26 | 27 | def viz_columns(self, n_samples=24, norm_each=False): 28 | # Visualize columns of linear decoder 29 | cols = [] 30 | W = self.decoder.weight.data 31 | max_abs = W.abs().max() 32 | # Iterate over columns 33 | for c in range(n_samples): 34 | column = W[:, c].detach().clone() 35 | if norm_each: 36 | max_abs = column.abs().max() 37 | # Map values to (-0.5, 0.5) interval 38 | if max_abs > 0: 39 | column /= (2 * max_abs) 40 | # Map 0 to gray (0.5) 41 | column += 0.5 42 | # Reshape column to output shape 43 | column = column.view(self.n_channels, self.im_size, -1) 44 | cols.append(column) 45 | cols = torch.stack(cols) 46 | return cols 47 | 48 | def load_pretrained(self, 
path, freeze=True): 49 | # Load pretrained model 50 | pretrained_model = torch.load(f=path, map_location="cuda" if self.cuda else "cpu") 51 | msg = self.load_state_dict(pretrained_model) 52 | print(msg) 53 | 54 | # Freeze pretrained parameters 55 | if freeze: 56 | for p in self.parameters(): 57 | p.requires_grad = False 58 | -------------------------------------------------------------------------------- /models/lista_classifier.py: -------------------------------------------------------------------------------- 1 | from importlib import import_module 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | 7 | class Classifier(nn.Module): 8 | def __init__(self, args): 9 | super().__init__() 10 | 11 | # Model 12 | self.code_dim = args.code_dim 13 | self.output_dim = args.n_classes 14 | self.log_interval = args.log_interval 15 | self.cuda = args.cuda 16 | self.device = torch.device("cuda" if self.cuda else "cpu") 17 | 18 | # LISTA encoder 19 | args.n_channels = 1 20 | self.lista = getattr(import_module('models.{}'.format('lista_encoder')), 'Encoder')(args).to(self.device) 21 | 22 | # Linear classifier on top of codes 23 | self.classifier = nn.Linear(self.code_dim, self.output_dim, bias=True) 24 | 25 | # Optimization 26 | self.lr = args.lr 27 | self.L1_reg = args.L1_reg 28 | self.L2_reg = args.L2_reg 29 | self.criterion = nn.CrossEntropyLoss().to(self.device) 30 | param_groups = [{'params': self.lista.parameters()}, 31 | {'params': self.classifier.parameters()}] 32 | self.optimizer = optim.Adam(params=param_groups, lr=self.lr, weight_decay=self.L2_reg) 33 | 34 | def forward(self, input): 35 | code = self.lista(input) 36 | output = self.classifier(code) 37 | return output 38 | 39 | def train(self, data_loader, epoch, k=3): 40 | self.lista.train() 41 | self.classifier.train() 42 | loss_aggr = 0 43 | top1_acc_aggr = 0 44 | topk_acc_aggr = 0 45 | n_samples = 0 46 | for batch_id, (input, target) in enumerate(data_loader): 47 | # Data 48 | input, target = input.to(self.device), target.to(self.device) 49 | batch_size = input.shape[0] 50 | input = input.view(batch_size, -1) 51 | n_samples += batch_size 52 | total_loss = 0 53 | 54 | # Forward 55 | self.optimizer.zero_grad() 56 | self.lista.zero_grad() 57 | self.classifier.zero_grad() 58 | output = self.classifier(self.lista(input)) 59 | 60 | # Loss and update 61 | loss_CE = self.criterion(output, target) 62 | total_loss += loss_CE 63 | if self.L1_reg > 0: 64 | l1_norm = self.L1_norm() 65 | total_loss += self.L1_reg * l1_norm 66 | total_loss.backward() 67 | self.optimizer.step() 68 | loss_aggr += loss_CE.detach() * batch_size 69 | 70 | # Predictions 71 | top1_pred = output.detach().max(1, keepdim=True)[1] # get the index of the max log-probability 72 | top1_pred = top1_pred.eq(target.view_as(top1_pred)).sum().float() 73 | top1_acc_aggr += top1_pred 74 | topk_pred = output.detach().topk(k)[1] 75 | topk_pred = topk_pred.eq(target.unsqueeze(1).expand(-1, k)).sum().float() 76 | topk_acc_aggr += topk_pred 77 | if self.log_interval > 0 and batch_id % self.log_interval == 0: 78 | print('Train Epoch: {} [{}/{}] Loss: {:.6f} Top1: {:.2f} Top{}: {:.2f}'.format( 79 | epoch + 1, batch_id + 1, len(data_loader), 80 | loss_CE.detach(), top1_pred / batch_size, k, topk_pred / batch_size)) 81 | 82 | return {'loss': loss_aggr / n_samples, 83 | 'top1': top1_acc_aggr / n_samples, 84 | f'top{k}': topk_acc_aggr / n_samples, 85 | 'n_samples': n_samples} 86 | 87 | 88 | def test(self, data_loader, k=3): 89 | self.lista.eval() 90 | 
self.classifier.eval() 91 | with torch.no_grad(): 92 | loss_aggr = 0 93 | top1_acc_aggr = 0 94 | topk_acc_aggr = 0 95 | n_samples = 0 96 | for batch_id, (input, target) in enumerate(data_loader): 97 | # Data 98 | input, target = input.to(self.device), target.to(self.device) 99 | batch_size = input.shape[0] 100 | input = input.view(batch_size, -1) 101 | n_samples += batch_size 102 | 103 | # Forward 104 | output = self.classifier(self.lista(input)) 105 | 106 | # Loss 107 | loss = self.criterion(output, target) 108 | loss_aggr += loss.detach() * batch_size 109 | 110 | # Predictions 111 | top1_pred = output.detach().max(1, keepdim=True)[1] # get the index of the max log-probability 112 | top1_pred = top1_pred.eq(target.view_as(top1_pred)).sum().float() 113 | top1_acc_aggr += top1_pred 114 | topk_pred = output.detach().topk(k)[1] 115 | topk_pred = topk_pred.eq(target.unsqueeze(1).expand(-1, k)).sum().float() 116 | topk_acc_aggr += topk_pred 117 | 118 | # Logging 119 | loss_aggr = loss_aggr / n_samples 120 | top1_acc_aggr = top1_acc_aggr / n_samples 121 | topk_acc_aggr = topk_acc_aggr / n_samples 122 | 123 | return {'loss': loss_aggr, 124 | 'top1': top1_acc_aggr, 125 | f'top{k}': topk_acc_aggr, 126 | 'n_samples': n_samples} 127 | 128 | def L1_norm(self): 129 | # Compute the l1 norm of the classifier's weights 130 | l1_norm = torch.norm(self.classifier.weight, p=1) + torch.norm(self.classifier.bias, p=1) 131 | l1_norm += torch.norm(self.lista.W.weight, p=1) + torch.norm(self.lista.W.bias, p=1) + torch.norm(self.lista.S.weight, p=1) 132 | return l1_norm 133 | 134 | def load_pretrained(self, path, freeze=False): 135 | # Load pretrained model 136 | pretrained_model = torch.load(f=path, map_location="cuda" if self.cuda else "cpu") 137 | msg = self.load_state_dict(pretrained_model) 138 | print(msg) 139 | 140 | # Freeze pretrained parameters 141 | if freeze: 142 | for p in self.parameters(): 143 | p.requires_grad = False 144 | -------------------------------------------------------------------------------- /models/lista_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class Encoder(nn.Module): 5 | def __init__(self, args): 6 | super().__init__() 7 | 8 | # Logistics 9 | self.im_size = min(args.im_size, args.patch_size) if args.patch_size > 0 else args.im_size 10 | assert self.im_size > 0 11 | self.n_channels = args.n_channels 12 | self.code_dim = args.code_dim 13 | self.cuda = args.cuda 14 | self.device = torch.device("cuda" if self.cuda else "cpu") 15 | 16 | # Architecture 17 | self.T = args.num_iter_LISTA 18 | self.W = nn.Linear(self.n_channels * self.im_size ** 2, self.code_dim, bias=True) 19 | self.W.bias.data.fill_(0) 20 | self.S = nn.Linear(self.code_dim, self.code_dim, bias=False) # no bias in second layer 21 | self.relu = nn.ReLU() 22 | 23 | def forward(self, y): 24 | B = self.W(y.view(y.shape[0], -1)) 25 | Z = self.relu(B) 26 | # LISTA loop 27 | for t in range(self.T): 28 | C = B + self.S(Z) 29 | Z = self.relu(C) 30 | return Z.view(Z.shape[0], -1) 31 | 32 | def load_pretrained(self, path, freeze=False): 33 | # Load pretrained model 34 | pretrained_model = torch.load(f=path, map_location="cuda" if self.cuda else "cpu") 35 | msg = self.load_state_dict(pretrained_model) 36 | print(msg) 37 | 38 | # Freeze pretrained parameters 39 | if freeze: 40 | for p in self.parameters(): 41 | p.requires_grad = False 42 | -------------------------------------------------------------------------------- 
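The LISTA encoder above implements the recurrence Z_0 = ReLU(W y + b), then Z_{t+1} = ReLU(W y + b + S Z_t) for T = num_iter_LISTA steps. A minimal calling sketch (an illustrative addition, not part of the repository); the argparse.Namespace below only mimics the fields Encoder actually reads, and all values are assumptions, not the paper's settings:

    import argparse, torch
    from models.lista_encoder import Encoder  # when run from the repository root

    args = argparse.Namespace(im_size=28, patch_size=0, n_channels=1,
                              code_dim=128, cuda=False, num_iter_LISTA=3)
    encoder = Encoder(args)
    y = torch.randn(4, 1, 28, 28)   # a batch of four 28x28 gray-scale images
    Z = encoder(y)                  # non-negative codes of shape (4, 128)
    assert Z.shape == (4, 128) and (Z >= 0).all()
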
/models/one_hidden_decoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | def L0(z, reduction='mean'):
5 |     """
6 |     :param z: (B, C) or (B, C, W, H) tensor
7 |     :return: average of proportion of zero elements in each element in batch
8 |     """
9 |     with torch.no_grad():
10 |         assert (len(z.shape) == 2 or len(z.shape) == 4)
11 |         dims = 1 if len(z.shape) == 2 else (1, 2, 3)
12 |         prop_0s_each_sample = (z.abs() == 0).float().mean(dims)
13 |         if reduction == 'sum':
14 |             return prop_0s_each_sample.sum()
15 |         if reduction == 'mean':
16 |             return prop_0s_each_sample.mean()
17 | 
18 | class Decoder(nn.Module):
19 |     def __init__(self, args):
20 |         super().__init__()
21 | 
22 |         # Model
23 |         self.im_size = min(args.im_size, args.patch_size) if args.patch_size > 0 else args.im_size
24 |         self.n_channels = args.n_channels
25 |         self.code_dim = args.code_dim
26 |         self.output_dim = (self.im_size ** 2) * self.n_channels
27 |         self.hidden_dim = args.hidden_dim
28 |         self.layer1 = nn.Linear(self.code_dim, self.hidden_dim, bias=True)
29 |         self.layer1.bias.data.fill_(0) # Fill bias with 0s
30 |         self.layer2 = nn.Linear(self.hidden_dim, self.output_dim, bias=False)
31 |         self.relu = nn.ReLU()
32 |         self.Zs_init_val = args.Zs_init_val
33 |         self.cuda = args.cuda
34 |         self.device = torch.device("cuda" if self.cuda else "cpu")
35 | 
36 |     def forward(self, code):
37 |         output = self.layer1(code)
38 |         with torch.no_grad():
39 |             self.frac_0s_hidden_pre_relu = L0(output, 'mean')
40 |         output = self.relu(output)
41 |         with torch.no_grad():
42 |             self.frac_0s_hidden_post_relu = L0(output, 'mean')
43 |         output = self.layer2(output)
44 |         return output.view(output.shape[0], self.n_channels, self.im_size, -1)
45 | 
46 |     def initZs(self, batch_size):
47 |         Zs = torch.zeros(size=(batch_size, self.code_dim), device=self.device).fill_(self.Zs_init_val)
48 |         return Zs
49 | 
50 |     def viz_columns(self, n_samples=24, norm_each=False):
51 |         # Visualize columns of the linear layer closest to reconstruction
52 |         cols = []
53 |         W = self.layer2.weight.data
54 |         max_abs = W.abs().max()
55 |         # Iterate over columns
56 |         for c in range(n_samples):
57 |             column = W[:, c].clone().detach()
58 |             if norm_each:
59 |                 max_abs = column.abs().max()
60 |             # Map values to (-0.5, 0.5) interval
61 |             if max_abs > 0:
62 |                 column /= (2 * max_abs)
63 |             # Map 0 to gray (0.5)
64 |             column += 0.5
65 |             # Reshape column to output shape
66 |             column = column.view(self.n_channels, self.im_size, -1)
67 |             cols.append(column)
68 |         cols = torch.stack(cols)
69 |         return cols
70 | 
71 |     def viz_codes(self, fill_vals, n_samples=24):
72 |         # Visualize reconstructions from a single active code component
73 |         codes = torch.zeros(n_samples, self.code_dim).to(self.device)
74 |         # Visualize reconstructions from the bias (if there is one)
75 |         with torch.no_grad():
76 |             recs_bias = self.forward(codes)
77 |         # Generate codes with a single active component
78 |         for c in range(n_samples):
79 |             codes[c, c] = fill_vals[c]
80 |         # Reconstructions from codes with a single active component
81 |         with torch.no_grad():
82 |             recs = self.forward(codes)
83 |         return recs - recs_bias
84 | 
85 |     def load_pretrained(self, path, freeze=False):
86 |         # Load pretrained model
87 |         pretrained_model = torch.load(f=path, map_location="cuda" if self.cuda else "cpu")
88 |         msg = self.load_state_dict(pretrained_model)
89 |         print(msg)
90 | 
91 |         # Freeze pretrained parameters
92 |         if freeze:
93 |             for p in self.parameters():
94 |                 p.requires_grad = False
95 | 
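# Illustrative usage sketch -- an addition to this listing, not part of the
# original repository. It shows how a batch of codes flows through the Decoder
# above; the argparse.Namespace only mimics the fields Decoder actually reads,
# and all values are assumptions for illustration, not the paper's settings.
if __name__ == '__main__':
    import argparse
    args_demo = argparse.Namespace(im_size=28, patch_size=0, n_channels=1,
                                   code_dim=128, hidden_dim=1024,
                                   Zs_init_val=0.0, cuda=False)
    decoder_demo = Decoder(args_demo)
    Zs_demo = decoder_demo.initZs(batch_size=4)  # (4, 128) codes, filled with Zs_init_val
    y_hat_demo = decoder_demo(Zs_demo)           # reconstructions of shape (4, 1, 28, 28)
    print(y_hat_demo.shape, decoder_demo.frac_0s_hidden_post_relu)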
--------------------------------------------------------------------------------
/scripts/ImageNet_SDL-NL.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH --nodes=1
3 | #SBATCH --cpus-per-task=4
4 | #SBATCH --time=24:00:00
5 | #SBATCH --mem=20GB
6 | #SBATCH --gres=gpu:1
7 | #SBATCH --output=ImageNet_SDL-NL.out
8 | 
9 | DATASET_PATH="/path/to/data/directory"
10 | DATASET_SPLITS_PATH="/path/to/data/train-val-splits"
11 | OUTPUT_PATH="/path/to/output/directory"
12 | 
13 | python -u main.py \
14 |   --batch_size "250" \
15 |   --code_dim "256" \
16 |   --code_reg "1" \
17 |   --cuda \
18 |   --data_splits $DATASET_SPLITS_PATH \
19 |   --datadir $DATASET_PATH \
20 |   --dataset "imagenet_LCN" \
21 |   --decoder "one_hidden_decoder" \
22 |   --encoder "lista_encoder" \
23 |   --epochs "100" \
24 |   --FISTA \
25 |   --hidden_dim "2048" \
26 |   --hinge_threshold "0.5" \
27 |   --im_size "28" \
28 |   --lrt_D "0.001" \
29 |   --lrt_E "0.0001" \
30 |   --lrt_Z "0.5" \
31 |   --n_steps_inf "200" \
32 |   --n_test_samples "20000" \
33 |   --n_training_samples "200000" \
34 |   --n_val_samples "20000" \
35 |   --noise "[1]" \
36 |   --norm_decoder "1" \
37 |   --num_iter_LISTA "3" \
38 |   --num_workers "4" \
39 |   --outdir $OUTPUT_PATH \
40 |   --patch_size "0" \
41 |   --positive_ISTA \
42 |   --seed "31" \
43 |   --sparsity_reg "0.01" \
44 |   --stop_early "0.001" \
45 |   --train_decoder \
46 |   --train_encoder \
47 |   --use_Zs_enc_as_init \
48 |   --variance_reg "0" \
49 |   --weight_decay_D "0" \
50 |   --weight_decay_E "0" \
51 |   --weight_decay_D_bias "0.01" \
52 |   --weight_decay_E_bias "0.01"
53 | 
--------------------------------------------------------------------------------
/scripts/ImageNet_SDL.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH --nodes=1
3 | #SBATCH --cpus-per-task=4
4 | #SBATCH --time=24:00:00
5 | #SBATCH --mem=20GB
6 | #SBATCH --gres=gpu:1
7 | #SBATCH --output=ImageNet_SDL.out
8 | 
9 | DATASET_PATH="/path/to/data/directory"
10 | DATASET_SPLITS_PATH="/path/to/data/train-val-splits"
11 | OUTPUT_PATH="/path/to/output/directory"
12 | 
13 | python -u main.py \
14 |   --batch_size "250" \
15 |   --code_dim "256" \
16 |   --code_reg "1" \
17 |   --cuda \
18 |   --data_splits $DATASET_SPLITS_PATH \
19 |   --datadir $DATASET_PATH \
20 |   --dataset "imagenet_LCN" \
21 |   --decoder "linear_dictionary" \
22 |   --encoder "lista_encoder" \
23 |   --epochs "100" \
24 |   --FISTA \
25 |   --hidden_dim "0" \
26 |   --hinge_threshold "0.5" \
27 |   --im_size "28" \
28 |   --lrt_D "0.001" \
29 |   --lrt_E "0.0001" \
30 |   --lrt_Z "0.5" \
31 |   --n_steps_inf "200" \
32 |   --n_test_samples "20000" \
33 |   --n_training_samples "200000" \
34 |   --n_val_samples "20000" \
35 |   --noise "[1]" \
36 |   --norm_decoder "1" \
37 |   --num_iter_LISTA "3" \
38 |   --num_workers "4" \
39 |   --outdir $OUTPUT_PATH \
40 |   --patch_size "0" \
41 |   --positive_ISTA \
42 |   --seed "31" \
43 |   --sparsity_reg "0.005" \
44 |   --stop_early "0.001" \
45 |   --train_decoder \
46 |   --train_encoder \
47 |   --use_Zs_enc_as_init \
48 |   --variance_reg "0" \
49 |   --weight_decay_D "0" \
50 |   --weight_decay_E "0" \
51 |   --weight_decay_D_bias "0" \
52 |   --weight_decay_E_bias "0.01"
53 | 
--------------------------------------------------------------------------------
/scripts/ImageNet_VDL-NL.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH --nodes=1
3 | #SBATCH --cpus-per-task=4
4 | #SBATCH --time=24:00:00
5 | #SBATCH --mem=20GB
6 | #SBATCH --gres=gpu:1
7 | #SBATCH --output=ImageNet_VDL-NL.out
8 | 
9 | DATASET_PATH="/path/to/data/directory"
10 | DATASET_SPLITS_PATH="/path/to/data/train-val-splits"
11 | OUTPUT_PATH="/path/to/output/directory"
12 | 
13 | python -u main.py \
14 |   --anneal_lr_D_freq "30" \
15 |   --anneal_lr_D_mult "0.5" \
16 |   --batch_size "250" \
17 |   --code_dim "256" \
18 |   --code_reg "40" \
19 |   --cuda \
20 |   --data_splits $DATASET_SPLITS_PATH \
21 |   --datadir $DATASET_PATH \
22 |   --dataset "imagenet_LCN" \
23 |   --decoder "one_hidden_decoder" \
24 |   --encoder "lista_encoder" \
25 |   --epochs "100" \
26 |   --FISTA \
27 |   --hidden_dim "2048" \
28 |   --hinge_threshold "0.5" \
29 |   --im_size "28" \
30 |   --lrt_D "5e-05" \
31 |   --lrt_E "0.0001" \
32 |   --lrt_Z "0.5" \
33 |   --n_steps_inf "200" \
34 |   --n_test_samples "20000" \
35 |   --n_training_samples "200000" \
36 |   --n_val_samples "20000" \
37 |   --noise "[1]" \
38 |   --norm_decoder "0" \
39 |   --num_iter_LISTA "3" \
40 |   --num_workers "4" \
41 |   --outdir $OUTPUT_PATH \
42 |   --patch_size "0" \
43 |   --positive_ISTA \
44 |   --seed "31" \
45 |   --sparsity_reg "0.01" \
46 |   --stop_early "0.001" \
47 |   --train_decoder \
48 |   --train_encoder \
49 |   --use_Zs_enc_as_init \
50 |   --variance_reg "10" \
51 |   --weight_decay_D "0" \
52 |   --weight_decay_E "0" \
53 |   --weight_decay_D_bias "0.1" \
54 |   --weight_decay_E_bias "0.01"
55 | 
--------------------------------------------------------------------------------
/scripts/ImageNet_VDL.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH --nodes=1
3 | #SBATCH --cpus-per-task=4
4 | #SBATCH --time=24:00:00
5 | #SBATCH --mem=20GB
6 | #SBATCH --gres=gpu:1
7 | #SBATCH --output=ImageNet_VDL.out
8 | 
9 | DATASET_PATH="/path/to/data/directory"
10 | DATASET_SPLITS_PATH="/path/to/data/train-val-splits"
11 | OUTPUT_PATH="/path/to/output/directory"
12 | 
13 | python -u main.py \
14 |   --batch_size "250" \
15 |   --code_dim "256" \
16 |   --code_reg "5" \
17 |   --cuda \
18 |   --data_splits $DATASET_SPLITS_PATH \
19 |   --datadir $DATASET_PATH \
20 |   --dataset "imagenet_LCN" \
21 |   --decoder "linear_dictionary" \
22 |   --encoder "lista_encoder" \
23 |   --epochs "100" \
24 |   --FISTA \
25 |   --hidden_dim "0" \
26 |   --hinge_threshold "0.5" \
27 |   --im_size "28" \
28 |   --lrt_D "0.0003" \
29 |   --lrt_E "0.0001" \
30 |   --lrt_Z "0.5" \
31 |   --n_steps_inf "200" \
32 |   --n_test_samples "20000" \
33 |   --n_training_samples "200000" \
34 |   --n_val_samples "20000" \
35 |   --noise "[1]" \
36 |   --norm_decoder "0" \
37 |   --num_iter_LISTA "3" \
38 |   --num_workers "4" \
39 |   --outdir $OUTPUT_PATH \
40 |   --patch_size "0" \
41 |   --positive_ISTA \
42 |   --seed "31" \
43 |   --sparsity_reg "0.01" \
44 |   --stop_early "0.001" \
45 |   --train_decoder \
46 |   --train_encoder \
47 |   --use_Zs_enc_as_init \
48 |   --variance_reg "10" \
49 |   --weight_decay_D "0" \
50 |   --weight_decay_E "0" \
51 |   --weight_decay_D_bias "0" \
52 |   --weight_decay_E_bias "0.01"
53 | 
--------------------------------------------------------------------------------
/scripts/MNIST_SDL-NL.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH --nodes=1
3 | #SBATCH --cpus-per-task=4
4 | #SBATCH --time=24:00:00
5 | #SBATCH --mem=20GB
6 | #SBATCH --gres=gpu:1
7 | #SBATCH --output=MNIST_SDL-NL.out
8 | 
9 | DATASET_PATH="/path/to/data/directory"
10 | DATASET_SPLITS_PATH="/path/to/data/train-val-splits"
11 | OUTPUT_PATH="/path/to/output/directory"
12 | 
13 | python -u main.py \
14 |   --batch_size "250" \
15 |   --code_dim "128" \
16 |   --code_reg
"1" \ 17 | --cuda \ 18 | --data_splits $DATASET_SPLITS_PATH \ 19 | --datadir $DATASET_PATH \ 20 | --dataset "MNIST" \ 21 | --decoder "one_hidden_decoder" \ 22 | --encoder "lista_encoder" \ 23 | --epochs "200" \ 24 | --FISTA \ 25 | --hidden_dim "1024" \ 26 | --hinge_threshold "0.5" \ 27 | --im_size "28" \ 28 | --lrt_D "0.001" \ 29 | --lrt_E "0.0001" \ 30 | --lrt_Z "1" \ 31 | --n_steps_inf "200" \ 32 | --n_test_samples "10000" \ 33 | --n_training_samples "55000" \ 34 | --n_val_samples "5000" \ 35 | --noise "[1]" \ 36 | --norm_decoder "1" \ 37 | --num_iter_LISTA "3" \ 38 | --num_workers "4" \ 39 | --outdir $OUTPUT_PATH \ 40 | --patch_size "0" \ 41 | --positive_ISTA \ 42 | --seed "31" \ 43 | --sparsity_reg "0.01" \ 44 | --stop_early "0.001" \ 45 | --train_decoder \ 46 | --train_encoder \ 47 | --use_Zs_enc_as_init \ 48 | --variance_reg "0" \ 49 | --weight_decay_D "0" \ 50 | --weight_decay_E "0" \ 51 | --weight_decay_D_bias "0.001" \ 52 | --weight_decay_E_bias "0" 53 | -------------------------------------------------------------------------------- /scripts/MNIST_SDL.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --nodes=1 3 | #SBATCH --cpus-per-task=4 4 | #SBATCH --time=24:00:00 5 | #SBATCH --mem=20GB 6 | #SBATCH --gres=gpu:1 7 | #SBATCH --output=MNIST_SDL.out 8 | 9 | DATASET_PATH="/path/to/data/directory" 10 | DATASET_SPLITS_PATH="/path/to/data/train-val-splits" 11 | OUTPUT_PATH="/path/to/output/directory" 12 | 13 | python -u main.py \ 14 | --batch_size "250" \ 15 | --code_dim "128" \ 16 | --code_reg "1" \ 17 | --cuda \ 18 | --data_splits $DATASET_SPLITS_PATH \ 19 | --datadir $DATASET_PATH \ 20 | --dataset "MNIST" \ 21 | --decoder "linear_dictionary" \ 22 | --encoder "lista_encoder" \ 23 | --epochs "200" \ 24 | --FISTA \ 25 | --hidden_dim "0" \ 26 | --hinge_threshold "0.5" \ 27 | --im_size "28" \ 28 | --lrt_D "0.001" \ 29 | --lrt_E "0.0003" \ 30 | --lrt_Z "1" \ 31 | --n_steps_inf "200" \ 32 | --n_test_samples "10000" \ 33 | --n_training_samples "55000" \ 34 | --n_val_samples "5000" \ 35 | --noise "[1]" \ 36 | --norm_decoder "1" \ 37 | --num_iter_LISTA "3" \ 38 | --num_workers "4" \ 39 | --outdir $OUTPUT_PATH \ 40 | --patch_size "0" \ 41 | --positive_ISTA \ 42 | --seed "31" \ 43 | --sparsity_reg "0.005" \ 44 | --stop_early "0.001" \ 45 | --train_decoder \ 46 | --train_encoder \ 47 | --use_Zs_enc_as_init \ 48 | --variance_reg "0" \ 49 | --weight_decay_D "0" \ 50 | --weight_decay_E "0" \ 51 | --weight_decay_E_bias "0" \ 52 | --weight_decay_D_bias "0" 53 | -------------------------------------------------------------------------------- /scripts/MNIST_VDL-NL.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --nodes=1 3 | #SBATCH --cpus-per-task=4 4 | #SBATCH --time=24:00:00 5 | #SBATCH --mem=20GB 6 | #SBATCH --gres=gpu:1 7 | #SBATCH --output=MNIST_VDL-NL.out 8 | 9 | DATASET_PATH="/path/to/data/directory" 10 | DATASET_SPLITS_PATH="/path/to/data/train-val-splits" 11 | OUTPUT_PATH="/path/to/output/directory" 12 | 13 | python -u main.py \ 14 | --anneal_lr_D_freq "30" \ 15 | --anneal_lr_D_mult "0.5" \ 16 | --batch_size "250" \ 17 | --code_dim "128" \ 18 | --code_reg "100" \ 19 | --cuda \ 20 | --data_splits $DATASET_SPLITS_PATH \ 21 | --datadir $DATASET_PATH \ 22 | --dataset "MNIST" \ 23 | --decoder "one_hidden_decoder" \ 24 | --encoder "lista_encoder" \ 25 | --epochs "200" \ 26 | --FISTA \ 27 | --hidden_dim "1024" \ 28 | --hinge_threshold "0.5" \ 29 | --im_size "28" \ 30 | 
--lrt_D "0.0003" \ 31 | --lrt_E "0.0001" \ 32 | --lrt_Z "0.5" \ 33 | --n_steps_inf "200" \ 34 | --n_test_samples "10000" \ 35 | --n_training_samples "55000" \ 36 | --n_val_samples "5000" \ 37 | --noise "[1]" \ 38 | --norm_decoder "0" \ 39 | --num_iter_LISTA "3" \ 40 | --num_workers "4" \ 41 | --outdir $OUTPUT_PATH \ 42 | --patch_size "0" \ 43 | --positive_ISTA \ 44 | --seed "31" \ 45 | --sparsity_reg "0.01" \ 46 | --stop_early "0.001" \ 47 | --train_decoder \ 48 | --train_encoder \ 49 | --use_Zs_enc_as_init \ 50 | --variance_reg "10" \ 51 | --weight_decay_D "0" \ 52 | --weight_decay_E "0" \ 53 | --weight_decay_D_bias "0.001" \ 54 | --weight_decay_E_bias "0" 55 | -------------------------------------------------------------------------------- /scripts/MNIST_VDL.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --nodes=1 3 | #SBATCH --cpus-per-task=4 4 | #SBATCH --time=24:00:00 5 | #SBATCH --mem=20GB 6 | #SBATCH --gres=gpu:1 7 | #SBATCH --output=MNIST_VDL.out 8 | 9 | DATASET_PATH="/path/to/data/directory" 10 | DATASET_SPLITS_PATH="/path/to/data/train-val-splits" 11 | OUTPUT_PATH="/path/to/output/directory" 12 | 13 | python -u main.py \ 14 | --batch_size "250" \ 15 | --code_dim "128" \ 16 | --code_reg "5" \ 17 | --cuda \ 18 | --data_splits $DATASET_SPLITS_PATH \ 19 | --datadir $DATASET_PATH \ 20 | --dataset "MNIST" \ 21 | --decoder "linear_dictionary" \ 22 | --encoder "lista_encoder" \ 23 | --epochs "200" \ 24 | --FISTA \ 25 | --hidden_dim "0" \ 26 | --hinge_threshold "0.5" \ 27 | --im_size "28" \ 28 | --lrt_D "0.0003" \ 29 | --lrt_E "0.0001" \ 30 | --lrt_Z "0.5" \ 31 | --n_steps_inf "200" \ 32 | --n_test_samples "10000" \ 33 | --n_training_samples "55000" \ 34 | --n_val_samples "5000" \ 35 | --noise "[1]" \ 36 | --norm_decoder "0" \ 37 | --num_iter_LISTA "3" \ 38 | --num_workers "4" \ 39 | --outdir $OUTPUT_PATH \ 40 | --patch_size "0" \ 41 | --positive_ISTA \ 42 | --seed "31" \ 43 | --sparsity_reg "0.02" \ 44 | --stop_early "0.001" \ 45 | --train_decoder \ 46 | --train_encoder \ 47 | --use_Zs_enc_as_init \ 48 | --variance_reg "10" \ 49 | --weight_decay_D "0" \ 50 | --weight_decay_E "0" \ 51 | --weight_decay_D_bias "0" \ 52 | --weight_decay_E_bias "0" 53 | -------------------------------------------------------------------------------- /scripts/build_ImageNet_LCN.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --nodes=1 3 | #SBATCH --cpus-per-task=4 4 | #SBATCH --time=12:00:00 5 | #SBATCH --mem=20GB 6 | #SBATCH --gres=gpu:1 7 | #SBATCH --output=build_imagenet_LCN.out 8 | 9 | INPUT_DATASET_PATH="/path/to/imagenet/directory" 10 | OUTPUT_DATASET_PATH="/path/to/output/directory" 11 | 12 | python -u imagenet_LCN_patches.py \ 13 | --batch_size "250" \ 14 | --cuda \ 15 | --datadir $INPUT_DATASET_PATH \ 16 | --im_size "256" \ 17 | --n_test "20000" \ 18 | --n_train "200000" \ 19 | --n_val "20000" \ 20 | --num_workers "4" \ 21 | --outdir $OUTPUT_DATASET_PATH \ 22 | --patch_size "52" \ 23 | --patch_type "random" \ 24 | --seed "31" \ 25 | --std_threshold "0.5" 26 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import math 3 | import os 4 | import time 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torch.utils.data import Sampler, Dataset 11 | from torchvision 
import datasets as datasets_torch 12 | from torchvision.transforms import ToTensor, Resize, Compose, Normalize, RandomCrop, Grayscale, CenterCrop 13 | from torchvision.utils import save_image, make_grid 14 | 15 | MEAN_MNIST = [0.1307] 16 | STD_MNIST = [0.3081] 17 | MEAN_IMAGENET = [0.485, 0.456, 0.406] 18 | STD_IMAGENET = [0.229, 0.224, 0.225] 19 | MEAN_IMAGENET_GRAY = [0.457] 20 | STD_IMAGENET_GRAY = [0.259] 21 | 22 | class Gauge: 23 | def __init__(self): 24 | self.cache = collections.defaultdict(list) 25 | 26 | def add(self, k, v): 27 | self.cache[k].append(v) 28 | 29 | def get(self, k, clear=False): 30 | # Get values for key k and delete them 31 | res = self.cache[k] 32 | if clear: 33 | del self.cache[k] 34 | return res 35 | 36 | class FixedSubsetSampler(Sampler): 37 | r"""Gives a sampler that yields the same set of indices. 38 | 39 | Arguments: 40 | indices (sequence): a sequence of indices 41 | """ 42 | def __init__(self, indices): 43 | self.idx = indices 44 | 45 | def __iter__(self): 46 | return iter(self.idx) 47 | 48 | def __len__(self): 49 | return len(self.idx) 50 | 51 | def get_dataset(dataset_name, datadir, train, im_size, patch_size, patch_type='random'): 52 | 53 | # Transformations 54 | transforms = [] 55 | 56 | # Resize & patch 57 | if dataset_name not in ['codes']: 58 | transforms.append(Resize(im_size)) 59 | if patch_size > 0: 60 | if 'random' in patch_type: 61 | transforms.append(RandomCrop(patch_size)) 62 | elif 'center' in patch_type: 63 | transforms.append(CenterCrop(patch_size)) 64 | else: 65 | raise NotImplementedError 66 | 67 | # Normalize 68 | if dataset_name == 'MNIST': 69 | transforms.append(ToTensor()) 70 | transforms.append(Normalize(mean=MEAN_MNIST, std=STD_MNIST)) 71 | elif dataset_name == 'imagenet': 72 | transforms.append(Grayscale()) 73 | transforms.append(ToTensor()) 74 | transforms.append(Normalize(mean=MEAN_IMAGENET_GRAY, std=STD_IMAGENET_GRAY)) 75 | elif dataset_name == 'imagenet_LCN': 76 | pass 77 | elif dataset_name == 'codes': 78 | pass 79 | else: 80 | raise NotImplementedError 81 | 82 | # Compose transformations 83 | transforms = Compose(transforms) 84 | 85 | # Read dataset 86 | if dataset_name == 'MNIST': 87 | dataset = getattr(datasets_torch, dataset_name) 88 | dataset = dataset(root=datadir, train=train, download=True, transform=transforms) 89 | elif dataset_name == 'imagenet': 90 | split = 'train' 91 | dataset = datasets_torch.ImageFolder(root=f'{datadir}/{split}', transform=transforms) 92 | elif dataset_name == 'imagenet_LCN': 93 | # Read dataset 94 | split = 'train' if train else 'test' 95 | img = np.load(os.path.join(datadir, f'imagenet_LCN_patches_{split}.npy')) 96 | img_mean = np.load(os.path.join(datadir, f'imagenet_LCN_patches_{split}_mean.npy')) 97 | img_std = np.load(os.path.join(datadir, f'imagenet_LCN_patches_{split}_std.npy')) 98 | dataset = ImageNetLCN((img, img_mean, img_std)) 99 | elif dataset_name == 'codes': 100 | # Read dataset 101 | split = 'train' if train else 'test' 102 | codes = np.load(os.path.join(datadir, f'MNIST_{split}_codes.npy')) 103 | targets = np.load(os.path.join(datadir, f'MNIST_{split}_targets.npy')) 104 | dataset = Codes((codes, targets)) 105 | else: 106 | raise NotImplementedError 107 | return dataset 108 | 109 | def inverse_transform(X, dataset_name): 110 | if dataset_name == 'MNIST': 111 | mean = MEAN_MNIST[0] 112 | std = STD_MNIST[0] 113 | elif dataset_name == 'imagenet_LCN': 114 | mean = MEAN_IMAGENET_GRAY[0] 115 | std = STD_IMAGENET_GRAY[0] 116 | else: 117 | raise NotImplementedError 118 | return X * std 
+ mean 119 | 120 | def ISTA_step(x, alpha, step_size, positive, stop_early): 121 | z_prox = x.detach().clone() 122 | # ISTA gradient step followed by a shrinkage step 123 | with torch.no_grad(): 124 | z_prox.data = soft_threshold(x.detach() - (1 - stop_early) * step_size * x.grad.data, 125 | threshold=(1 - stop_early) * alpha * step_size, positive=positive) 126 | return nn.Parameter(z_prox) 127 | 128 | def soft_threshold(x, threshold, positive): 129 | # Function which shrinks input by a given threshold 130 | result = x.sign() * F.relu(x.abs() - threshold, inplace=True) 131 | if positive: 132 | return F.relu(result) 133 | return result 134 | 135 | def sqrt_var(x): 136 | # Computes the unbiased sample variances of input samples of shape (N, d) across the N dimension 137 | mean_x = x.mean(0) 138 | v = torch.norm(x - mean_x, p=2, dim=0) / ((x.shape[0] - 1) ** 0.5) 139 | return v 140 | 141 | def ISTA(decoder, y, positive_ISTA, FISTA, 142 | sparsity_reg, n_steps_inf, lrt_Z, 143 | use_Zs_enc_as_init, Zs_enc, 144 | variance_reg, hinge_threshold, code_reg, 145 | tolerance, training, train_decoder): 146 | # Housekeeping 147 | start_time = time.time() 148 | batch_size = y.shape[0] 149 | 150 | # Turn off gradient for decoder 151 | decoder.requires_grad_(False) 152 | decoder.eval() 153 | 154 | # Generate codes 155 | if use_Zs_enc_as_init: 156 | Zs = nn.Parameter(Zs_enc.detach().clone()) 157 | else: 158 | Zs = nn.Parameter(decoder.initZs(batch_size)) 159 | if FISTA: 160 | aux = nn.Parameter(Zs.detach().clone()) 161 | t_old = 1 162 | 163 | # Auxiliary variables for early stopping 164 | stop_early_dummies = torch.zeros((batch_size, 1), device=Zs.device) 165 | stop_early_step = torch.zeros((batch_size, 1), device=Zs.device) 166 | 167 | # Inference iterations 168 | for step in range(n_steps_inf): 169 | trainable_param = aux if FISTA else Zs 170 | loss_dict = loss_f(trainable_param, decoder, y, 171 | variance_reg, hinge_threshold, 172 | code_reg, Zs_enc) 173 | total_loss = loss_dict['total_loss'] 174 | 175 | # Gradient computation for the codes 176 | trainable_param.grad = None 177 | total_loss.backward() 178 | 179 | # Keep track of the codes from the previous iteration 180 | Zs_old = Zs.detach().clone() 181 | 182 | # Gradient and shrinkage step 183 | Zs = ISTA_step(x=trainable_param, 184 | alpha=sparsity_reg, 185 | step_size=lrt_Z, 186 | positive=positive_ISTA, 187 | stop_early=stop_early_dummies) 188 | 189 | # FISTA 190 | if FISTA: 191 | t_new = 0.5 * (1 + np.sqrt(1 + 4 * t_old ** 2)) 192 | aux = nn.Parameter(Zs.detach() + ((t_old - 1) / t_new) * (Zs.detach() - Zs_old)) 193 | t_old = t_new 194 | 195 | # Log metrics 196 | stop_early_dummies = stop_early(Zs_old, Zs.detach(), tolerance) 197 | 198 | # Stop early 199 | stop_early_step += (1 - stop_early_dummies) 200 | if step < n_steps_inf - 1: 201 | if stop_early_dummies.sum() == batch_size: 202 | break 203 | # Track time 204 | elapsed_time = time.time() - start_time 205 | 206 | # Count number of total steps 207 | Zs_steps_mean = stop_early_step.mean() 208 | 209 | # Remove gradient 210 | Zs = Zs.detach() 211 | Zs.requires_grad = False 212 | 213 | # Turn on gradient for decoder 214 | if training and train_decoder: 215 | decoder.requires_grad_(True) 216 | decoder.train() 217 | 218 | output = {'Zs': Zs, 219 | 'inf_steps': torch.FloatTensor([Zs_steps_mean]).to(Zs.device), 220 | 'inference_time': torch.FloatTensor([elapsed_time]).to(Zs.device)} 221 | return output 222 | 223 | def loss_f(Zs, decoder, y, variance_reg, hinge_threshold, code_reg, Zs_enc): 224 | 
total_loss = 0 225 | 226 | # Reconstruction loss 227 | y_hat = decoder(Zs) 228 | rec_loss = MSE(y, y_hat, reduction='sum') 229 | total_loss += rec_loss 230 | 231 | # Hinge regularization 232 | if variance_reg > 0: 233 | hinge_loss = hinge(input=sqrt_var(Zs), threshold=hinge_threshold, reduction='sum') 234 | total_loss += variance_reg * hinge_loss 235 | else: 236 | hinge_loss = hinge(input=sqrt_var(Zs.detach()), threshold=hinge_threshold, reduction='sum') 237 | 238 | # Distance to the encoder's predictions 239 | code_loss = None 240 | if code_reg > 0: 241 | code_loss = MSE(Zs, Zs_enc.detach(), reduction='sum') 242 | total_loss += code_reg * code_loss 243 | 244 | output = {'total_loss': total_loss, 'rec_loss': rec_loss.detach(), 'y_hat': y_hat.detach(), 245 | 'hinge_loss': hinge_loss.detach()} 246 | if code_loss is not None: 247 | output['code_loss'] = code_loss.detach() 248 | 249 | return output 250 | 251 | def MSE(target, pred, reduction='sum'): 252 | assert target.shape == pred.shape 253 | dims = (1, 2, 3) if len(target.shape) == 4 else 1 254 | mean_sq_diff = ((target - pred) ** 2).mean(dims) 255 | if reduction == 'sum': 256 | return mean_sq_diff.sum() 257 | elif reduction == 'mean': 258 | return mean_sq_diff.mean() 259 | elif reduction == 'none': 260 | return mean_sq_diff 261 | 262 | def PSNR(target, pred, dataset, tar_sample_mean=None, tar_sample_std=None, pred_sample_mean=None, pred_sample_std=None, 263 | R=1, dummy=1e-4, reduction='mean'): 264 | with torch.no_grad(): 265 | # Map inputs back to image space 266 | if tar_sample_mean is not None: 267 | target = (target * tar_sample_std) + tar_sample_mean 268 | if pred_sample_mean is not None: 269 | # Prediction comes from sample different from the target (e.g. in the case of denoising) 270 | pred = (pred * pred_sample_std) + pred_sample_mean 271 | else: 272 | pred = (pred * tar_sample_std) + tar_sample_mean 273 | target = inverse_transform(target, dataset) 274 | pred = inverse_transform(pred, dataset) 275 | 276 | # Compute the PSNR 277 | dims = (1, 2, 3) if len(target.shape) == 4 else 1 278 | mean_sq_err = ((target - pred)**2).mean(dims) 279 | mean_sq_err = mean_sq_err + (mean_sq_err == 0).float() * dummy # if 0, fill with dummy -> PSNR of 40 by default 280 | output = 10*torch.log10(R**2/mean_sq_err) 281 | if reduction == 'mean': 282 | return output.mean() 283 | elif reduction == 'none': 284 | return output 285 | 286 | def L0(z, reduction='mean', grad=False): 287 | """ 288 | :param z: (B, C) or (B, C, W, H) tensor 289 | :return: average of proportion of zero elements in each element in batch 290 | """ 291 | if not(grad): 292 | z = z.detach() 293 | assert (len(z.shape) == 2 or len(z.shape) == 4) 294 | dims = 1 if len(z.shape) == 2 else (1, 2, 3) 295 | prop_0s_each_sample = (z.abs() == 0).float().mean(dims) 296 | if reduction == 'sum': 297 | return prop_0s_each_sample.sum() 298 | if reduction == 'mean': 299 | return prop_0s_each_sample.mean() 300 | 301 | def L1(z, reduction='mean', grad=False): 302 | if not(grad): 303 | z = z.detach() 304 | if reduction == 'sum': 305 | return torch.norm(z, p=1, dim=1).sum() 306 | elif reduction == 'mean': 307 | return torch.norm(z, p=1, dim=1).mean() 308 | 309 | def stop_early(z_old, z_new, tolerance, absolute=False): 310 | if tolerance == 0: 311 | device = torch.device("cuda" if z_old.is_cuda else "cpu") 312 | shape = (z_old.shape[0], 1) if len(z_old.shape) == 2 else (z_old.shape[0], 1, 1, 1) 313 | return torch.zeros(size=shape, device=device) 314 | with torch.no_grad(): 315 | code_dim = 1 if 
len(z_old.shape) == 2 else (1, 2, 3) 316 | if absolute: 317 | diff = torch.norm(z_old - z_new, p=2, dim=code_dim) / z_old[0].numel() 318 | else: 319 | diff = torch.norm(z_old - z_new, p=2, dim=code_dim) / torch.norm(z_old, p=2, dim=code_dim) 320 | if len(z_old.shape) == 2: 321 | return (diff < tolerance).float().unsqueeze(-1) 322 | else: 323 | return (diff < tolerance).float().unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) 324 | 325 | def set_random_seed(seed, torch, np, random, cuda): 326 | torch.manual_seed(seed) 327 | np.random.seed(seed) 328 | random.seed(seed) 329 | if cuda: 330 | torch.cuda.manual_seed_all(seed) 331 | 332 | def normalize_kernels(net, radius): 333 | # Set kernels to have fixed norm equal to radius 334 | if radius > 0 and net.training: 335 | for _, module in net.named_modules(): 336 | if type(module) == nn.Linear: 337 | with torch.no_grad(): 338 | W = module.weight.data 339 | norms = W.norm(p=2, dim=0) 340 | mask = norms / radius 341 | module.weight.data /= mask 342 | if type(module) == nn.Conv2d: 343 | with torch.no_grad(): 344 | W = module.weight.data 345 | norm = W.norm(p=2, dim=[0, 2, 3]) 346 | module.weight.data /= norm.unsqueeze(0).unsqueeze(2).unsqueeze(2) * (1 / radius) 347 | 348 | def save_img(tensor, name, norm, n_rows=16, scale_each=False): 349 | save_image(tensor, name, nrow=n_rows, padding=5, normalize=norm, pad_value=1, scale_each=scale_each) 350 | 351 | def img_grid(tensor, norm=False, scale_each=False, n_rows=16): 352 | return make_grid(tensor, nrow=n_rows, padding=5, normalize=norm, range=None, scale_each=scale_each, pad_value=1) 353 | 354 | def my_iterator(args, data, log_interval): 355 | for epoch in range(args.epochs): 356 | for batch_id, (X, extra) in enumerate(data): 357 | batch = {} 358 | batch['batch_id'] = batch_id 359 | batch['X'] = X 360 | if args.dataset == 'imagenet_LCN': 361 | batch['extra'] = extra 362 | if args.dataset == 'MNIST': 363 | batch['target'] = extra 364 | 365 | batch_info = {} 366 | batch_info['epoch'] = epoch 367 | batch_info['size'] = X.shape[0] 368 | 369 | should = {} 370 | should['epoch_start'] = batch_id == 0 # first batch of epoch 371 | should['epoch_end'] = batch_id == len(data) - 1 # last batch of epoch 372 | should['log_train_imgs'] = epoch % log_interval == 0 or epoch == args.epochs - 1 373 | yield batch, batch_info, should 374 | 375 | def my_iterator_val(args, data, log_interval, epoch=0): 376 | for batch_id, (X, extra) in enumerate(data): 377 | batch = {} 378 | batch['batch_id'] = batch_id 379 | batch['X'] = X 380 | if args.dataset == 'imagenet_LCN': 381 | batch['extra'] = extra 382 | if args.dataset == 'MNIST': 383 | batch['target'] = extra 384 | 385 | batch_info = {} 386 | batch_info['size'] = X.shape[0] 387 | 388 | should = {} 389 | should['val_start'] = batch_id == 0 # first batch of epoch 390 | should['val_end'] = batch_id == len(data) - 1 # last batch of epoch 391 | should['log_val_imgs'] = epoch % log_interval == 0 or epoch == args.epochs - 1 392 | yield batch, batch_info, should 393 | 394 | def hinge(input, threshold=1.0, reduction='sum'): 395 | # Hinge loss implementation 396 | diff = F.relu(threshold - input) 397 | diff = diff**2 398 | if reduction == 'sum': 399 | loss = diff.sum() 400 | elif reduction == 'mean': 401 | loss = diff.mean() 402 | return loss 403 | 404 | def add_noise_to_img(y, noise_level, torch): 405 | # Noise 406 | noise = noise_level * torch.randn(y.shape, device=y.device) 407 | 408 | # Add noise to input 409 | y_noisy = y + noise 410 | 411 | # Normalize noisy image 412 | return y_noisy 413 | 
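# A small worked example of the shrinkage used by ISTA_step above (an
# illustrative addition, not part of the original file). soft_threshold computes
# sign(x) * max(|x| - threshold, 0), so small entries are zeroed out exactly --
# this is what makes the inferred codes sparse. For x = [1.0, -0.5, 0.2, -0.1]:
#   soft_threshold(x, threshold=0.3, positive=False) -> [0.7, -0.2, 0.0, 0.0]
#   soft_threshold(x, threshold=0.3, positive=True)  -> [0.7,  0.0, 0.0, 0.0]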
414 | def log_viz(decoder, writer, n_samples, y, y_hat, Zs, stats, 415 | img_dir, decoder_arch, dataset, viz_type, log_all=False): 416 | # Log target 417 | y_img = inverse_transform(y[:n_samples], dataset) 418 | save_img(y_img, f'{img_dir}/{viz_type}_X.png', norm=False) 419 | writer.add_image(f'{viz_type}/X', img_grid(y_img)) 420 | 421 | # Log reconstructions 422 | y_hat_img = inverse_transform(y_hat[:n_samples], dataset).clamp_(min=0, max=1) 423 | save_img(y_hat_img, f'{img_dir}/{viz_type}_X_rec.png', norm=False) 424 | writer.add_image(f'{viz_type}/X_rec', img_grid(y_hat_img)) 425 | 426 | if log_all: 427 | n_samples = min(256, decoder.code_dim) 428 | # Log decoder columns 429 | cols = decoder.viz_columns(n_samples, norm_each=True) 430 | save_img(cols, f'{img_dir}/{viz_type}_top_layer_norm_each.png', 431 | norm=False, n_rows=int(2 ** (np.log2(n_samples) // 2))) 432 | writer.add_image(f'{viz_type}/top_layer_norm_each', img_grid(cols)) 433 | cols = decoder.viz_columns(n_samples, norm_each=False) 434 | save_img(cols, f'{img_dir}/{viz_type}_top_layer_norm_all.png', 435 | norm=False, n_rows=int(2 ** (np.log2(n_samples) // 2))) 436 | writer.add_image(f'{viz_type}/top_layer_norm_all', img_grid(cols)) 437 | 438 | # Log code activations 439 | if decoder_arch in ['one_hidden_decoder']: 440 | recs = inverse_transform(decoder.viz_codes(Zs.detach().max(0)[0], n_samples), 441 | dataset).clamp_(min=0, max=1) 442 | save_img(recs, f'{img_dir}/{viz_type}_code_act.png', 443 | norm=False, n_rows=int(2 ** (np.log2(n_samples) // 2))) 444 | writer.add_image(f'{viz_type}/code_act', img_grid(recs)) 445 | 446 | # Save codes for histogram 447 | np.save(f'{img_dir}/{viz_type}_codes.npy', Zs.detach().cpu().numpy()) 448 | np.save(f'{img_dir}/{viz_type}_Zs_comp_use.npy', stats['Zs_comp_use'].cpu().numpy()) 449 | 450 | def anneal_learning_rate(optimizer, epoch, lrt, ratio=0.9, frequency=2): 451 | """Sets the learning rate to the initial LR multiplied by {ratio} every {frequency} epochs""" 452 | lrt = lrt * (ratio ** (epoch // frequency)) # adjusted lrt 453 | for param_group in optimizer.param_groups: 454 | param_group['lr'] = lrt 455 | 456 | def print_final_training_msg(results_file, head, msg_pre, msg_post, noise, 457 | best_perf_tr, best_perf_val): 458 | # Save results to file 459 | final_file = open(results_file, 'w') 460 | final_file.write(head) 461 | msg_eval = f"{str(noise)}\t" \ 462 | f"NA\tNA\tNA\tNA\tNA" 463 | best_tr = f"{msg_pre}\tBEST TRAIN\t{msg_post}" \ 464 | f"{best_perf_tr.get('inf_steps', -1):.0f}\t" \ 465 | f"{best_perf_tr.get('L0_Z', -1):.3f}\t" \ 466 | f"{best_perf_tr.get('L0_H', -1):.3f}\t" \ 467 | f"{best_perf_tr.get('PSNR', -1):.3f}\t" \ 468 | f"{best_perf_tr.get('epoch', -1)}\t" \ 469 | f"{msg_eval}" 470 | final_file.write(best_tr + '\n') 471 | best_val = f"{msg_pre}\tBEST VAL\t{msg_post}" \ 472 | f"{best_perf_val.get('inf_steps', -1):.0f}\t" \ 473 | f"{best_perf_val.get('L0_Z', -1):.3f}\t" \ 474 | f"{best_perf_val.get('L0_H', -1):.3f}\t" \ 475 | f"{best_perf_val.get('PSNR', -1):.3f}\t" \ 476 | f"{best_perf_val.get('epoch', -1)}\t" \ 477 | f"{msg_eval}" 478 | final_file.write(best_val + '\n') 479 | final_file.close() 480 | 481 | # Print final message 482 | final_msg_trn = f"BEST TRAIN\t" \ 483 | f"inf_steps: {best_perf_tr.get('inf_steps', -1):.0f}\t" \ 484 | f"L0_Z: {best_perf_tr.get('L0_Z', -1):.2f}\t" \ 485 | f"L0_H: {best_perf_tr.get('L0_H', -1):.2f}\t" \ 486 | f"PSNR: {best_perf_tr.get('PSNR', -1):.2f}\t" \ 487 | f"epoch: {best_perf_tr.get('epoch', -1)}" 488 | final_msg_val = f"BEST VALID\t" \ 
489 |                     f"inf_steps: {best_perf_val.get('inf_steps', -1):.0f}\t" \
490 |                     f"L0_Z: {best_perf_val.get('L0_Z', -1):.2f}\t" \
491 |                     f"L0_H: {best_perf_val.get('L0_H', -1):.2f}\t" \
492 |                     f"PSNR: {best_perf_val.get('PSNR', -1):.2f}\t" \
493 |                     f"epoch: {best_perf_val.get('epoch', -1)}"
494 |     print(final_msg_trn + '\n' + final_msg_val)
495 | 
496 | def print_final_eval_msg(results_file, msg_pre, msg_post, args_eval, best_perf_val):
497 |     # Save results to file
498 |     final_file = open(results_file, 'a')
499 |     eval_stats = f"{args_eval.additive_noise}\t" \
500 |                  f"{args_eval.L0_orig_aggr:.3f}\t" \
501 |                  f"{args_eval.psnr_orig_aggr:.3f}\t" \
502 |                  f"{args_eval.psnr_noisy_img_aggr:.3f}\t" \
503 |                  f"{args_eval.L0_noisy_aggr:.3f}\t" \
504 |                  f"{args_eval.psnr_noisy_rec_aggr:.3f}"
505 |     msg_eval = f"{msg_pre}\tFINAL denoising \t{msg_post}" \
506 |                f"{best_perf_val.get('inf_steps', -1):.0f}\t" \
507 |                f"{best_perf_val.get('L0_Z', -1):.3f}\t" \
508 |                f"{best_perf_val.get('L0_H', -1):.3f}\t" \
509 |                f"{best_perf_val.get('PSNR', -1):.3f}\t" \
510 |                f"{best_perf_val.get('epoch', -1)}\t" \
511 |                f"{eval_stats}"
512 |     final_file.write(msg_eval + '\n')
513 |     final_file.close()
514 | 
515 | def dewhiten(y, y_mean, y_std):
516 |     return y * y_std + y_mean
517 | 
518 | def compute_energy(y, y_hat, Zs, sparsity_reg, variance_reg, hinge_threshold, code_reg, Zs_enc):
519 |     # Function computing the energy minimized during inference
520 |     with torch.no_grad():
521 |         # Reconstruction + L1 norm energy
522 |         energy = MSE(y, y_hat, reduction='sum') + sparsity_reg * L1(Zs, reduction='sum')
523 |         # Variance regularization energy
524 |         if variance_reg > 0:
525 |             variance_term = hinge(input=sqrt_var(Zs.detach()), threshold=hinge_threshold, reduction='sum')
526 |             energy += variance_reg * variance_term
527 |         # Encoder code regularization energy
528 |         if code_reg > 0:
529 |             enc_code_term = MSE(Zs, Zs_enc.detach(), reduction='sum')
530 |             energy += code_reg * enc_code_term
531 |     return energy
532 | 
533 | def get_gaussian_filter(channels, device, radius, sigma, dim=2):
534 |     radius = [radius] * dim
535 |     sigma = [sigma] * dim
536 |     kernel = 1
537 |     meshgrids = torch.meshgrid(
538 |         [
539 |             torch.arange(size, device=device)
540 |             for size in radius
541 |         ]
542 |     )
543 |     for size, std, mgrid in zip(radius, sigma, meshgrids):
544 |         mean = (size - 1) / 2
545 |         kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \
546 |                   torch.exp(-((mgrid - mean) / std) ** 2 / 2)
547 | 
548 |     # Make sure sum of values in gaussian kernel equals 1.
549 |     kernel = kernel / torch.sum(kernel)
550 | 
551 |     # Reshape to depthwise convolutional weight
552 |     kernel = kernel.view(1, 1, *kernel.size())
553 |     kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
554 | 
555 |     return kernel
556 | 
557 | def LocalContrastNorm(image, gaussian_filter, padding=False):
558 |     """
559 |     INPUTS
560 |     image: torch.Tensor of shape (N, ch, h, w)
561 |     gaussian_filter: Gaussian filter of shape (ch, 1, radius, radius); radius should be odd
562 |     (an even radius is incremented by one below)
563 |     OUTPUT
564 |     locally contrast normalized images of shape (N, ch, h - 2*(radius - 1), w - 2*(radius - 1)) or (N, ch, h, w),
565 |     depending on whether padding is used
566 |     Modified from: https://github.com/dibyadas/Visualize-Normalizations/blob/master/LocalContrastNorm.ipynb
567 |     """
568 |     _, _, radius, _ = gaussian_filter.shape
569 |     if radius % 2 == 0:
570 |         radius = radius + 1
571 |     pad = radius // 2
572 | 
573 |     # Apply Gaussian filter to original patch
574 |     if padding:
575 |         # (N, ch, h, w)
576 |         filter_out = F.conv2d(input=image, weight=gaussian_filter, padding=radius - 1)[:, :, pad:-pad, pad:-pad]
577 |     else:
578 |         # (N, ch, h - r + 1, w - r + 1)
579 |         filter_out = F.conv2d(input=image, weight=gaussian_filter, padding=0)
580 | 
581 |     # Center
582 |     if padding:
583 |         # (N, ch, h, w)
584 |         centered_image = image - filter_out
585 |     else:
586 |         # (N, ch, h - r + 1, w - r + 1)
587 |         centered_image = image[:, :, pad:-pad, pad:-pad] - filter_out
588 | 
589 |     # Variance
590 |     if padding:
591 |         var = F.conv2d(centered_image.pow(2), gaussian_filter, padding=radius - 1)[:, :, pad:-pad, pad:-pad]
592 |     else:
593 |         # (N, ch, h - 2*(r - 1), w - 2*(r - 1))
594 |         var = F.conv2d(centered_image.pow(2), gaussian_filter, padding=0)
595 |     # Clamp tiny negative values caused by floating point error
596 |     var = var.clamp(min=0)
597 | 
598 |     # Standard deviation
599 |     st_dev = var.sqrt()
600 |     st_dev_mean = st_dev.mean()
601 |     # Floor the local std at its mean so low-contrast regions are not over-amplified
602 |     st_dev = torch.max(st_dev, st_dev_mean)
603 |     # Floor at 1e-4 to avoid division by (near-)zero
604 |     st_dev = st_dev.clamp(min=1e-4)
605 | 
606 |     # Divide by std
607 |     if padding:
608 |         new_image = centered_image / st_dev
609 |     else:
610 |         new_image = centered_image[:, :, pad:-pad, pad:-pad] / st_dev
611 | 
612 |     # Return normalized input and stats
613 |     if padding:
614 |         return new_image, filter_out, st_dev
615 |     else:
616 |         return new_image, filter_out[:, :, pad:-pad, pad:-pad], st_dev
617 | 
618 | class ImageNetLCN(Dataset):
619 | 
620 |     def __init__(self, dataset):
621 |         self.img, self.mean, self.std = dataset
622 | 
623 |     def __getitem__(self, index):
624 |         return self.img[index], (self.mean[index], self.std[index])
625 | 
626 |     def __len__(self):
627 |         return self.img.shape[0]
628 | 
629 | class Codes(Dataset):
630 | 
631 |     def __init__(self, dataset):
632 |         self.codes, self.targets = dataset
633 | 
634 |     def __getitem__(self, index):
635 |         return self.codes[index], self.targets[index]
636 | 
637 |     def __len__(self):
638 |         return self.codes.shape[0]
639 | 
--------------------------------------------------------------------------------
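The schedule implemented by anneal_learning_rate above is a plain step decay: the rate applied at a given epoch is lrt * ratio ** (epoch // frequency), recomputed from the initial rate on every call rather than compounded, so calling it once per epoch is safe in any order. A minimal sketch of how it behaves (not part of the repo), using a throwaway SGD optimizer with illustrative values:

import torch
from utils import anneal_learning_rate  # assumes utils.py is on the path

base_lr = 1e-3
params = [torch.nn.Parameter(torch.zeros(1))]  # throwaway parameter
optimizer = torch.optim.SGD(params, lr=base_lr)
for epoch in range(6):
    anneal_learning_rate(optimizer, epoch, base_lr, ratio=0.9, frequency=2)
    print(epoch, optimizer.param_groups[0]['lr'])
    # epochs 0-1: 1.0e-3, epochs 2-3: 0.9e-3, epochs 4-5: 0.81e-3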
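compute_energy evaluates, without gradients, the objective minimized during inference: reconstruction MSE plus an L1 sparsity penalty, with optional variance (hinge-on-std) and encoder-code terms. The sketch below (not part of the repo) calls it with both optional terms switched off, so hinge_threshold and Zs_enc are never used; all tensors are random stand-ins and the sparsity weight is arbitrary. It assumes the MSE/L1 helpers referenced by compute_energy are defined earlier in utils.py:

import torch
from utils import compute_energy

y     = torch.rand(16, 1, 28, 28)  # targets (stand-in)
y_hat = torch.rand(16, 1, 28, 28)  # reconstructions (stand-in)
Zs    = torch.rand(16, 128)        # codes (stand-in)
energy = compute_energy(y, y_hat, Zs, sparsity_reg=0.05, variance_reg=0.0,
                        hinge_threshold=0.5, code_reg=0.0, Zs_enc=None)
print(energy.item())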
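To see get_gaussian_filter and LocalContrastNorm working together, the sketch below (not part of the repo) normalizes a random batch of gray-scale patches; the batch size, radius, and sigma are arbitrary illustrative choices. With padding=True the spatial size is preserved; with padding=False each dimension shrinks by 2*(radius - 1), as noted in the docstring:

import torch
from utils import get_gaussian_filter, LocalContrastNorm

images = torch.rand(8, 1, 28, 28)  # hypothetical gray-scale batch
gf = get_gaussian_filter(channels=1, device=torch.device('cpu'), radius=9, sigma=2.0)  # (1, 1, 9, 9)
normed, local_mean, local_std = LocalContrastNorm(images, gf, padding=True)
print(normed.shape)  # torch.Size([8, 1, 28, 28])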
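Because ImageNetLCN.__getitem__ returns a nested tuple, the default DataLoader collation yields batches of the form (img, (mean, std)), which pairs directly with dewhiten to undo the per-patch normalization. A minimal sketch (not part of the repo) with random stand-ins for the tensors that the LCN preprocessing would produce:

import torch
from torch.utils.data import DataLoader
from utils import ImageNetLCN, dewhiten

img  = torch.rand(100, 1, 28, 28)  # LCN-normalized patches (stand-in)
mean = torch.rand(100, 1, 28, 28)  # per-patch local means (stand-in)
std  = torch.rand(100, 1, 28, 28)  # per-patch local stds (stand-in)
loader = DataLoader(ImageNetLCN((img, mean, std)), batch_size=32, shuffle=True)
for y, (y_mean, y_std) in loader:
    y_raw = dewhiten(y, y_mean, y_std)  # y * y_std + y_mean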