├── .gitignore
├── README.md
├── architectures
│   └── mnist
│       ├── flatten.py
│       └── lift2d.py
├── datasets
│   ├── __init__.py
│   ├── mnist.py
│   └── utils_data.py
├── evaluate.py
├── evaluate_forward.py
├── examples
│   ├── gaussian2d.ipynb
│   ├── gaussian3d.ipynb
│   └── utils_example.py
├── functional.py
├── images
│   └── arch-redunet.jpg
├── load.py
├── loss.py
├── plot.py
├── redunet
│   ├── __init__.py
│   ├── layers
│   │   ├── fourier1d.py
│   │   ├── fourier2d.py
│   │   ├── redulayer.py
│   │   └── vector.py
│   ├── modules.py
│   ├── multichannel_weight.py
│   ├── projections
│   │   └── lift.py
│   └── redunet.py
├── requirements.txt
├── train_forward.py
└── utils.py
-------------------------------------------------------------------------------- /.gitignore: --------------------------------------------------------------------------------
1 | 
2 | # Created by https://www.toptal.com/developers/gitignore/api/python,jupyternotebooks,vim,emacs,sublimetext,visualstudiocode
3 | # Edit at https://www.toptal.com/developers/gitignore?templates=python,jupyternotebooks,vim,emacs,sublimetext,visualstudiocode
4 | 
5 | ### Emacs ###
6 | # -*- mode: gitignore; -*-
7 | *~
8 | \#*\#
9 | /.emacs.desktop
10 | /.emacs.desktop.lock
11 | *.elc
12 | auto-save-list
13 | tramp
14 | .\#*
15 | 
16 | # Org-mode
17 | .org-id-locations
18 | *_archive
19 | ltximg/**
20 | 
21 | # flymake-mode
22 | *_flymake.*
23 | 
24 | # eshell files
25 | /eshell/history
26 | /eshell/lastdir
27 | 
28 | # elpa packages
29 | /elpa/
30 | 
31 | # reftex files
32 | *.rel
33 | 
34 | # AUCTeX auto folder
35 | /auto/
36 | 
37 | # cask packages
38 | .cask/
39 | dist/
40 | 
41 | # Flycheck
42 | flycheck_*.el
43 | 
44 | # server auth directory
45 | /server/
46 | 
47 | # projectiles files
48 | .projectile
49 | 
50 | # directory configuration
51 | .dir-locals.el
52 | 
53 | # network security
54 | /network-security.data
55 | 
56 | 
57 | ### JupyterNotebooks ###
58 | # gitignore template for Jupyter Notebooks
59 | # website: http://jupyter.org/
60 | 
61 | .ipynb_checkpoints
62 | */.ipynb_checkpoints/*
63 | 
64 | # IPython
65 | profile_default/
66 | ipython_config.py
67 | 
68 | # Remove previous ipynb_checkpoints
69 | #   git rm -r .ipynb_checkpoints/
70 | 
71 | ### Python ###
72 | # Byte-compiled / optimized / DLL files
73 | __pycache__/
74 | *.py[cod]
75 | *$py.class
76 | 
77 | # C extensions
78 | *.so
79 | 
80 | # Distribution / packaging
81 | .Python
82 | build/
83 | develop-eggs/
84 | downloads/
85 | eggs/
86 | .eggs/
87 | lib/
88 | lib64/
89 | parts/
90 | sdist/
91 | var/
92 | wheels/
93 | pip-wheel-metadata/
94 | share/python-wheels/
95 | *.egg-info/
96 | .installed.cfg
97 | *.egg
98 | MANIFEST
99 | 
100 | # PyInstaller
101 | #  Usually these files are written by a python script from a template
102 | #  before PyInstaller builds the exe, so as to inject date/other infos into it. 
103 | *.manifest 104 | *.spec 105 | 106 | # Installer logs 107 | pip-log.txt 108 | pip-delete-this-directory.txt 109 | 110 | # Unit test / coverage reports 111 | htmlcov/ 112 | .tox/ 113 | .nox/ 114 | .coverage 115 | .coverage.* 116 | .cache 117 | nosetests.xml 118 | coverage.xml 119 | *.cover 120 | *.py,cover 121 | .hypothesis/ 122 | .pytest_cache/ 123 | pytestdebug.log 124 | 125 | # Translations 126 | *.mo 127 | *.pot 128 | 129 | # Django stuff: 130 | *.log 131 | local_settings.py 132 | db.sqlite3 133 | db.sqlite3-journal 134 | 135 | # Flask stuff: 136 | instance/ 137 | .webassets-cache 138 | 139 | # Scrapy stuff: 140 | .scrapy 141 | 142 | # Sphinx documentation 143 | docs/_build/ 144 | doc/_build/ 145 | 146 | # PyBuilder 147 | target/ 148 | 149 | # Jupyter Notebook 150 | 151 | # IPython 152 | 153 | # pyenv 154 | .python-version 155 | 156 | # pipenv 157 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 158 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 159 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 160 | # install all needed dependencies. 161 | #Pipfile.lock 162 | 163 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 164 | __pypackages__/ 165 | 166 | # Celery stuff 167 | celerybeat-schedule 168 | celerybeat.pid 169 | 170 | # SageMath parsed files 171 | *.sage.py 172 | 173 | # Environments 174 | .env 175 | .venv 176 | env/ 177 | venv/ 178 | ENV/ 179 | env.bak/ 180 | venv.bak/ 181 | pythonenv* 182 | 183 | # Spyder project settings 184 | .spyderproject 185 | .spyproject 186 | 187 | # Rope project settings 188 | .ropeproject 189 | 190 | # mkdocs documentation 191 | /site 192 | 193 | # mypy 194 | .mypy_cache/ 195 | .dmypy.json 196 | dmypy.json 197 | 198 | # Pyre type checker 199 | .pyre/ 200 | 201 | # pytype static type analyzer 202 | .pytype/ 203 | 204 | # profiling data 205 | .prof 206 | 207 | ### SublimeText ### 208 | # Cache files for Sublime Text 209 | *.tmlanguage.cache 210 | *.tmPreferences.cache 211 | *.stTheme.cache 212 | 213 | # Workspace files are user-specific 214 | *.sublime-workspace 215 | 216 | # Project files should be checked into the repository, unless a significant 217 | # proportion of contributors will probably not be using Sublime Text 218 | # *.sublime-project 219 | 220 | # SFTP configuration file 221 | sftp-config.json 222 | 223 | # Package control specific files 224 | Package Control.last-run 225 | Package Control.ca-list 226 | Package Control.ca-bundle 227 | Package Control.system-ca-bundle 228 | Package Control.cache/ 229 | Package Control.ca-certs/ 230 | Package Control.merged-ca-bundle 231 | Package Control.user-ca-bundle 232 | oscrypto-ca-bundle.crt 233 | bh_unicode_properties.cache 234 | 235 | # Sublime-github package stores a github token in this file 236 | # https://packagecontrol.io/packages/sublime-github 237 | GitHub.sublime-settings 238 | 239 | ### Vim ### 240 | # Swap 241 | [._]*.s[a-v][a-z] 242 | !*.svg # comment out if you don't need vector files 243 | [._]*.sw[a-p] 244 | [._]s[a-rt-v][a-z] 245 | [._]ss[a-gi-z] 246 | [._]sw[a-p] 247 | 248 | # Session 249 | Session.vim 250 | Sessionx.vim 251 | 252 | # Temporary 253 | .netrwhist 254 | # Auto-generated tag files 255 | tags 256 | # Persistent undo 257 | [._]*.un~ 258 | 259 | ### VisualStudioCode ### 260 | .vscode/* 261 | !.vscode/tasks.json 262 | !.vscode/launch.json 263 | *.code-workspace 264 | 265 | ### VisualStudioCode Patch ### 266 | # Ignore all local 
history of files
267 | .history
268 | .ionide
269 | 
270 | # End of https://www.toptal.com/developers/gitignore/api/python,jupyternotebooks,vim,emacs,sublimetext,visualstudiocode
271 | 
272 | 
273 | data/
274 | saved_models/
275 | transfer.sh
276 | transfer/
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # Deep Networks from the Principle of Rate Reduction
2 | This repository is the official PyTorch implementation of the paper [ReduNet: A White-box Deep Network from the Principle of Maximizing Rate Reduction](https://arxiv.org/abs/2105.10446) (2021)
3 | 
4 | by [Kwan Ho Ryan Chan](https://ryanchankh.github.io)* (UC Berkeley), [Yaodong Yu](https://yaodongyu.github.io/)* (UC Berkeley), [Chong You](https://sites.google.com/view/cyou)* (UC Berkeley), [Haozhi Qi](https://haozhi.io/) (UC Berkeley), [John Wright](http://www.columbia.edu/~jw2966/) (Columbia University), and [Yi Ma](http://people.eecs.berkeley.edu/~yima/) (UC Berkeley).
5 | 
6 | 
7 | 
8 | ## What is ReduNet?
9 | ReduNet is a deep neural network constructed naturally by deriving the gradients of the Maximal Coding Rate Reduction (MCR2) [1] objective. Every layer of this network can be interpreted based on its mathematical operations, and the network as a whole is trained in a purely feed-forward manner. In addition, by imposing shift-invariance on our network, the convolutional operator can be derived using only the data and the MCR2 objective function, hence making our network design principled and interpretable.
10 | 
11 | ![ReduNet layer architecture](./images/arch-redunet.jpg)
12 | 
13 | Figure: Weights and operations for one layer of ReduNet
14 | 
15 | 
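For reference, each layer's update is derived from the rate reduction objective ΔR = R − R^c. The sketch below writes the vector case to match the constants used in `loss.py` of this repository (there the precision parameter enters as ε; the paper states the objective with ε², so treat this as the code's convention), with Z ∈ R^{m×d} denoting the features and Z_j the m_j features of class j:

```latex
\Delta R(Z) =
    \underbrace{\tfrac{1}{2}\log\det\Big(I + \tfrac{d}{m\epsilon}\, Z^\top Z\Big)}_{R \text{ (expansion)}}
  - \underbrace{\sum_{j} \tfrac{m_j}{2m}\,\log\det\Big(I + \tfrac{d}{m_j\epsilon}\, Z_j^\top Z_j\Big)}_{R^c \text{ (compression)}}
```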
16 | 
17 | [1] Yu, Yaodong, Kwan Ho Ryan Chan, Chong You, Chaobing Song, and Yi Ma. "[Learning diverse and discriminative representations via the principle of maximal coding rate reduction](https://proceedings.neurips.cc/paper/2020/file/6ad4174eba19ecb5fed17411a34ff5e6-Paper.pdf)." Advances in Neural Information Processing Systems 33 (2020).
18 | 
19 | ## Requirements
20 | This codebase is written for `python3`. To install the necessary Python packages, run `conda create --name redunet_official --file requirements.txt`.
21 | 
22 | ## Demo
23 | For a quick demonstration of ReduNet on 2D or 3D Gaussian data, open the notebooks by running one of the following commands:
24 | 
25 | ```
26 | $ jupyter notebook ./examples/gaussian2d.ipynb
27 | $ jupyter notebook ./examples/gaussian3d.ipynb
28 | ```
29 | 
30 | ## Core Usage and Design
31 | This repository is designed to be easy to use and easy to integrate into your existing experiment framework, as long as it uses PyTorch. The `ReduNet` object inherits from `nn.Sequential`, and the layers (`ReduLayers`, such as `Vector`, `Fourier1D`, and `Fourier2D`) inherit from `nn.Module`. Loss functions are implemented in `loss.py`. Architecture and dataset options are located in `load.py`. Data objects and pre-set architectures live in the folders `datasets` and `architectures`; feel free to add more based on the experiments you want to run. We have provided basic experiment setups in `train_<setting>.py` and `evaluate_<setting>.py`, where `<setting>` is the type of experiment. For utility functions, please check out `functional.py` and `utils.py`. A minimal usage sketch is included near the end of this README. Feel free to email us if there are any issues or suggestions.
32 | 
33 | 
34 | ## Example: Forward Construction
35 | To train a ReduNet using forward construction, please check out `train_forward.py`. For evaluation, please check out `evaluate_forward.py`. For example, to train a 50-layer ReduNet on MNIST using 1000 samples per class, run:
36 | 
37 | ```
38 | $ python3 train_forward.py --data mnistvector --arch layers50 --samples 1000
39 | ```
40 | After training, you can evaluate the trained model with `evaluate_forward.py` by running:
41 | 
42 | ```
43 | $ python3 evaluate_forward.py --model_dir ./saved_models/forward/mnistvector+layers50/samples1000
44 | ```
45 | This evaluates using all available training and testing samples. For more training and testing options, please check out the files `train_forward.py` and `evaluate_forward.py`.
46 | 
47 | ### Experiments in Paper
48 | For the code used to generate the experimental results reported in our paper, please visit our other repository: [https://github.com/ryanchankh/redunet_paper](https://github.com/ryanchankh/redunet_paper)
49 | 
50 | ## Reference
51 | For technical details and full experimental results, please check the [paper](https://arxiv.org/abs/2105.10446). Please consider citing our work if you find it helpful:
52 | 
53 | ```
54 | @article{chan2021redunet,
55 |   title={ReduNet: A White-box Deep Network from the Principle of Maximizing Rate Reduction},
56 |   author={Chan, Kwan Ho Ryan and Yu, Yaodong and You, Chong and Qi, Haozhi and Wright, John and Ma, Yi},
57 |   journal={arXiv preprint arXiv:2105.10446},
58 |   year={2021}
59 | }
60 | ```
61 | 
62 | ## License and Contributing
63 | - This README is formatted based on [paperswithcode](https://github.com/paperswithcode/releasing-research-code).
64 | - Feel free to post issues via GitHub.
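## Minimal Usage Sketch
As mentioned in the Core Usage and Design section, the snippet below sketches how a vector-case ReduNet is composed and run. It mirrors `architectures/mnist/flatten.py` and the Gaussian demo notebooks; the data and hyperparameter values are illustrative only, and `ReduNet`/`Vector` are assumed importable from the `redunet` package (as the `from redunet import *` statements elsewhere in this repository suggest).

```python
import torch
from redunet import ReduNet, Vector

# Stack interpretable Vector layers, exactly as architectures/mnist/flatten.py does.
net = ReduNet(
    *[Vector(eta=0.5, eps=0.1, lmbda=500,
             num_classes=10, dimensions=784)
      for _ in range(20)]
)

# Unit-norm inputs, as in the demos; random data here, not from the repo.
X_train = torch.nn.functional.normalize(torch.randn(1000, 784))
y_train = torch.randint(0, 10, (1000,))

Z_train = net.init(X_train, y_train)    # forward construction, as in the demo notebooks
X_test = torch.nn.functional.normalize(torch.randn(100, 784))
Z_test = net(X_test).detach()           # run the constructed network on new data
```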
65 | 
66 | ## Contact
67 | Please contact [ryanchankh@berkeley.edu](mailto:ryanchankh@berkeley.edu) and [yyu@eecs.berkeley.edu](mailto:yyu@eecs.berkeley.edu) if you have any questions about the code.
68 | 
-------------------------------------------------------------------------------- /architectures/mnist/flatten.py: --------------------------------------------------------------------------------
1 | from redunet import *
2 | 
3 | 
4 | 
5 | def flatten(layers, num_classes):
6 |     net = ReduNet(
7 |         *[Vector(eta=0.5,
8 |                  eps=0.1,
9 |                  lmbda=500,
10 |                  num_classes=num_classes,
11 |                  dimensions=784
12 |                  ) for _ in range(layers)],
13 |     )
14 |     return net
-------------------------------------------------------------------------------- /architectures/mnist/lift2d.py: --------------------------------------------------------------------------------
1 | from redunet import *
2 | 
3 | 
4 | 
5 | def lift2d(channels, layers, num_classes, seed=0):
6 |     net = ReduNet(
7 |         Lift2D(1, channels, 9, seed=seed),
8 |         *[Fourier2D(eta=0.5,
9 |                     eps=0.1,
10 |                     lmbda=500,
11 |                     num_classes=num_classes,
12 |                     dimensions=(channels, 28, 28)
13 |                     ) for _ in range(layers)],
14 |     )
15 |     return net
16 | 
-------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ma-Lab-Berkeley/ReduNet/f67a348e5b54b9e60b783ceec6fe0e379bc3b96f/datasets/__init__.py
-------------------------------------------------------------------------------- /datasets/mnist.py: --------------------------------------------------------------------------------
1 | import torchvision.datasets as datasets
2 | import torchvision.transforms as transforms
3 | from torch.utils.data import DataLoader
4 | from .utils_data import filter_class
5 | 
6 | 
7 | 
8 | 
9 | 
10 | 
11 | def mnist2d_10class(data_dir):
12 |     transform = transforms.Compose([
13 |         transforms.ToTensor(),
14 |     ])
15 |     trainset = datasets.MNIST(data_dir, train=True, transform=transform, download=True)
16 |     testset = datasets.MNIST(data_dir, train=False, transform=transform, download=True)
17 |     num_classes = 10
18 |     return trainset, testset, num_classes
19 | 
20 | def mnist2d_5class(data_dir):
21 |     transform = transforms.Compose([
22 |         transforms.ToTensor(),
23 |     ])
24 |     trainset = datasets.MNIST(data_dir, train=True, transform=transform, download=True)
25 |     testset = datasets.MNIST(data_dir, train=False, transform=transform, download=True)
26 |     trainset, num_classes = filter_class(trainset, [0, 1, 2, 3, 4])
27 |     testset, _ = filter_class(testset, [0, 1, 2, 3, 4])
28 |     num_classes = 5
29 |     return trainset, testset, num_classes
30 | 
31 | def mnist2d_2class(data_dir):
32 |     transform = transforms.Compose([
33 |         transforms.ToTensor(),
34 |     ])
35 |     trainset = datasets.MNIST(data_dir, train=True, transform=transform, download=True)
36 |     testset = datasets.MNIST(data_dir, train=False, transform=transform, download=True)
37 |     trainset, num_classes = filter_class(trainset, [0, 1])
38 |     testset, _ = filter_class(testset, [0, 1])
39 |     return trainset, testset, num_classes
40 | 
41 | def mnistvector_10class(data_dir):
42 |     transform = transforms.Compose([
43 |         transforms.ToTensor(),
44 |         transforms.Lambda(lambda x: x.flatten())
45 |     ])
46 |     trainset = datasets.MNIST(data_dir, train=True, transform=transform, download=True)
47 |     testset = datasets.MNIST(data_dir, train=False, transform=transform, download=True)
48 |     num_classes = 10
49 |     return trainset, testset, num_classes
50 | 
51 | def 
mnistvector_5class(data_dir): 52 | transform = transforms.Compose([ 53 | transforms.ToTensor(), 54 | transforms.Lambda(lambda x: x.flatten()) 55 | ]) 56 | trainset = datasets.MNIST(data_dir, train=True, transform=transform, download=True) 57 | testset = datasets.MNIST(data_dir, train=False, transform=transform, download=True) 58 | trainset, num_classes = filter_class(trainset, [0, 1, 2, 3, 4]) 59 | testset, _ = filter_class(testset, [0, 1, 2, 3, 4]) 60 | return trainset, testset, num_classes 61 | 62 | def mnistvector_2class(data_dir): 63 | transform = transforms.Compose([ 64 | transforms.ToTensor(), 65 | transforms.Lambda(lambda x: x.flatten()) 66 | ]) 67 | trainset = datasets.MNIST(data_dir, train=True, transform=transform, download=True) 68 | testset = datasets.MNIST(data_dir, train=False, transform=transform, download=True) 69 | trainset, num_classes = filter_class(trainset, [0, 1]) 70 | testset, _ = filter_class(testset, [0, 1]) 71 | return trainset, testset, num_classes 72 | 73 | 74 | if __name__ == '__main__': 75 | trainset, testset, num_classes = mnist2d_2class('./data/') 76 | trainloader = DataLoader(trainset, batch_size=trainset.data.shape[0]) 77 | print(trainset) 78 | print(testset) 79 | print(num_classes) 80 | 81 | batch_imgs, batch_lbls = next(iter(trainloader)) 82 | print(batch_imgs.shape, batch_lbls.shape) 83 | print(batch_lbls.unique(return_counts=True)) 84 | -------------------------------------------------------------------------------- /datasets/utils_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | 6 | def filter_class(dataset, classes): 7 | data, labels = dataset.data, dataset.targets 8 | if type(labels) == list: 9 | labels = torch.tensor(labels) 10 | data_filter = [] 11 | labels_filter = [] 12 | for _class in classes: 13 | idx = labels == _class 14 | data_filter.append(data[idx]) 15 | labels_filter.append(labels[idx]) 16 | if type(dataset.data) == np.ndarray: 17 | dataset.data = np.vstack(data_filter) 18 | dataset.targets = np.hstack(labels_filter) 19 | elif type(dataset.data) == torch.Tensor: 20 | dataset.data = torch.cat(data_filter) 21 | dataset.targets = torch.cat(labels_filter) 22 | else: 23 | raise TypeError('dataset.data type neither np.ndarray nor torch.Tensor') 24 | return dataset, len(classes) -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.stats as sps 3 | import torch 4 | 5 | from sklearn.svm import LinearSVC 6 | from sklearn.decomposition import PCA 7 | from sklearn.decomposition import TruncatedSVD 8 | from sklearn.linear_model import SGDClassifier 9 | from sklearn.svm import LinearSVC, SVC 10 | from sklearn.tree import DecisionTreeClassifier 11 | from sklearn.ensemble import RandomForestClassifier 12 | 13 | import functional as F 14 | import utils 15 | 16 | 17 | 18 | def evaluate(eval_dir, method, train_features, train_labels, test_features, test_labels, **kwargs): 19 | if method == 'svm': 20 | acc_train, acc_test = svm(train_features, train_labels, test_features, test_labels) 21 | elif method == 'knn': 22 | acc_train, acc_test = knn(train_features, train_labels, test_features, test_labels, **kwargs) 23 | elif method == 'nearsub': 24 | acc_train, acc_test = nearsub(train_features, train_labels, test_features, test_labels, **kwargs) 25 | elif method == 'nearsub_pca': 26 | acc_train, acc_test = 
nearsub_pca(train_features, train_labels, train_features, train_labels, **kwargs), nearsub_pca(train_features, train_labels, test_features, test_labels, **kwargs)  # nearsub_pca scores one feature set at a time, so run it on train and test separately
27 |     acc_dict = {'train': acc_train, 'test': acc_test}
28 |     utils.save_params(eval_dir, acc_dict, name=f'acc_{method}')
29 | 
30 | def svm(train_features, train_labels, test_features, test_labels):
31 |     svm = LinearSVC(verbose=0, random_state=10)
32 |     svm.fit(train_features, train_labels)
33 |     acc_train = svm.score(train_features, train_labels)
34 |     acc_test = svm.score(test_features, test_labels)
35 |     print("SVM: {}, {}".format(acc_train, acc_test))
36 |     return acc_train, acc_test
37 | 
38 | # def knn(train_features, train_labels, test_features, test_labels, k=5):
39 | #     sim_mat = train_features @ train_features.T
40 | #     topk = torch.from_numpy(sim_mat).topk(k=k, dim=0)
41 | #     topk_pred = train_labels[topk.indices]
42 | #     test_pred = torch.tensor(topk_pred).mode(0).values.detach()
43 | #     acc_train = compute_accuracy(test_pred.numpy(), train_labels)
44 | 
45 | #     sim_mat = train_features @ test_features.T
46 | #     topk = torch.from_numpy(sim_mat).topk(k=k, dim=0)
47 | #     topk_pred = train_labels[topk.indices]
48 | #     test_pred = torch.tensor(topk_pred).mode(0).values.detach()
49 | #     acc_test = compute_accuracy(test_pred.numpy(), test_labels)
50 | #     print("kNN: {}, {}".format(acc_train, acc_test))
51 | #     return acc_train, acc_test
52 | 
53 | def knn(train_features, train_labels, test_features, test_labels, k=5):
54 |     sim_mat = train_features @ train_features.T
55 |     topk = sim_mat.topk(k=k, dim=0)
56 |     topk_pred = train_labels[topk.indices]
57 |     test_pred = topk_pred.mode(0).values.detach()
58 |     acc_train = compute_accuracy(test_pred, train_labels)
59 | 
60 |     sim_mat = train_features @ test_features.T
61 |     topk = sim_mat.topk(k=k, dim=0)
62 |     topk_pred = train_labels[topk.indices]
63 |     test_pred = topk_pred.mode(0).values.detach()
64 |     acc_test = compute_accuracy(test_pred, test_labels)
65 |     print("kNN: {}, {}".format(acc_train, acc_test))
66 |     return acc_train, acc_test
67 | 
68 | # # TODO: 1. implement pytorch version 2. 
support batches
69 | # def nearsub(train_features, train_labels, test_features, test_labels, num_classes, n_comp=10, return_pred=False):
70 | #     train_scores, test_scores = [], []
71 | #     classes = np.arange(num_classes)
72 | #     features_sort, _ = utils.sort_dataset(train_features, train_labels, 
73 | #                                           classes=classes, stack=False)
74 | #     fd = features_sort[0].shape[1]
75 | #     if n_comp >= fd:
76 | #         n_comp = fd - 1
77 | #     for j in classes:
78 | #         svd = TruncatedSVD(n_components=n_comp).fit(features_sort[j])
79 | #         subspace_j = np.eye(fd) - svd.components_.T @ svd.components_
80 | #         train_j = subspace_j @ train_features.T
81 | #         test_j = subspace_j @ test_features.T
82 | #         train_scores_j = np.linalg.norm(train_j, ord=2, axis=0)
83 | #         test_scores_j = np.linalg.norm(test_j, ord=2, axis=0)
84 | #         train_scores.append(train_scores_j)
85 | #         test_scores.append(test_scores_j)
86 | #     train_pred = np.argmin(train_scores, axis=0)
87 | #     test_pred = np.argmin(test_scores, axis=0)
88 | #     if return_pred:
89 | #         return train_pred.tolist(), test_pred.tolist()
90 | #     train_acc = compute_accuracy(classes[train_pred], train_labels)
91 | #     test_acc = compute_accuracy(classes[test_pred], test_labels)
92 | #     print('SVD: {}, {}'.format(train_acc, test_acc))
93 | #     return train_acc, test_acc
94 | 
95 | def nearsub(train_features, train_labels, test_features, test_labels, 
96 |             num_classes, n_comp=10, return_pred=False):
97 |     train_scores, test_scores = [], []
98 |     classes = np.arange(num_classes)
99 |     features_sort, _ = utils.sort_dataset(train_features, train_labels, 
100 |                                           classes=classes, stack=False)
101 |     fd = features_sort[0].shape[1]
102 |     for j in classes:
103 |         _, _, V = torch.svd(features_sort[j])
104 |         components = V[:, :n_comp].T
105 |         subspace_j = torch.eye(fd) - components.T @ components
106 |         train_j = subspace_j @ train_features.T
107 |         test_j = subspace_j @ test_features.T
108 |         train_scores_j = torch.linalg.norm(train_j, ord=2, dim=0)
109 |         test_scores_j = torch.linalg.norm(test_j, ord=2, dim=0)
110 |         train_scores.append(train_scores_j)
111 |         test_scores.append(test_scores_j)
112 |     train_pred = torch.stack(train_scores).argmin(0)
113 |     test_pred = torch.stack(test_scores).argmin(0)
114 |     if return_pred:
115 |         return train_pred.numpy(), test_pred.numpy()
116 |     train_acc = compute_accuracy(classes[train_pred.numpy()], train_labels.numpy())
117 |     test_acc = compute_accuracy(classes[test_pred.numpy()], test_labels.numpy())
118 |     print('SVD: {}, {}'.format(train_acc, test_acc))
119 |     return train_acc, test_acc
120 | 
121 | def nearsub_pca(train_features, train_labels, test_features, test_labels, num_classes, n_comp=10):
122 |     scores_pca = []
123 |     classes = np.arange(num_classes)
124 |     features_sort, _ = utils.sort_dataset(train_features, train_labels, classes=classes, stack=False)
125 |     fd = features_sort[0].shape[1]
126 |     if n_comp >= fd:
127 |         n_comp = fd - 1
128 |     for j in np.arange(len(classes)):
129 |         pca = PCA(n_components=n_comp).fit(features_sort[j])
130 |         pca_subspace = pca.components_.T
131 |         mean = np.mean(features_sort[j], axis=0)
132 |         pca_j = (np.eye(fd) - pca_subspace @ pca_subspace.T) \
133 |                 @ (test_features - mean).T
134 |         score_pca_j = np.linalg.norm(pca_j, ord=2, axis=0)
135 |         scores_pca.append(score_pca_j)
136 |     test_predict_pca = np.argmin(scores_pca, axis=0)
137 |     acc_pca = compute_accuracy(classes[test_predict_pca], test_labels)
138 |     print('PCA: {}'.format(acc_pca))
139 |     return acc_pca
140 | 
141 | def argmax(train_features, train_labels, test_features, test_labels):
142 |     train_pred = 
train_features.argmax(1)
143 |     train_acc = compute_accuracy(train_pred, train_labels)
144 |     test_pred = test_features.argmax(1)
145 |     test_acc = compute_accuracy(test_pred, test_labels)
146 |     return train_acc, test_acc
147 | 
148 | def compute_accuracy(y_pred, y_true):
149 |     """Compute accuracy by counting correct classification. """
150 |     assert y_pred.shape == y_true.shape
151 |     if type(y_pred) == torch.Tensor:
152 |         n_wrong = torch.count_nonzero(y_pred - y_true).item()
153 |     elif type(y_pred) == np.ndarray:
154 |         n_wrong = np.count_nonzero(y_pred - y_true)
155 |     else:
156 |         raise TypeError("Not Tensor nor Array type.")
157 |     n_samples = len(y_pred)
158 |     return 1 - n_wrong / n_samples
159 | 
160 | def baseline(train_features, train_labels, test_features, test_labels):
161 |     test_models = {'log_l2': SGDClassifier(loss='log', max_iter=10000, random_state=42),
162 |                    'SVM_linear': LinearSVC(max_iter=10000, random_state=42),
163 |                    'SVM_RBF': SVC(kernel='rbf', random_state=42),
164 |                    'DecisionTree': DecisionTreeClassifier(),
165 |                    'RandomForest': RandomForestClassifier()}
166 |     for model_name in test_models:
167 |         test_model = test_models[model_name]
168 |         test_model.fit(train_features, train_labels)
169 |         score = test_model.score(test_features, test_labels)
170 |         print(f"{model_name}: {score}")
171 | 
172 | def majority_vote(pred, true):
173 |     pred_majority = sps.mode(pred, axis=0)[0].squeeze()
174 |     return compute_accuracy(pred_majority, true)
175 | 
-------------------------------------------------------------------------------- /evaluate_forward.py: --------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | 
4 | import torch
5 | import torch.nn as nn
6 | 
7 | from redunet import *
8 | import evaluate
9 | import functional as F
10 | import load as L
11 | import utils
12 | import plot
13 | 
14 | 
15 | 
16 | parser = argparse.ArgumentParser()
17 | parser.add_argument('--model_dir', type=str, help='model directory')
18 | parser.add_argument('--loss', default=False, action='store_true', help='set to True to plot loss')
19 | parser.add_argument('--trainsamples', type=int, default=None, help="number of train samples in each class")
20 | parser.add_argument('--testsamples', type=int, default=None, help="number of test samples in each class")
21 | parser.add_argument('--translatetrain', default=False, action='store_true', help='set to True to translate train samples')
22 | parser.add_argument('--translatetest', default=False, action='store_true', help='set to True to translate test samples')
23 | parser.add_argument('--batch_size', type=int, default=100, help='batch size for evaluation')
24 | args = parser.parse_args()
25 | 
26 | ## CUDA
27 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
28 | 
29 | ## Setup
30 | eval_dir = os.path.join(args.model_dir, 
31 |                         f'trainsamples{args.trainsamples}'
32 |                         f'_testsamples{args.testsamples}'
33 |                         f'_translatetrain{args.translatetrain}'
34 |                         f'_translatetest{args.translatetest}')
35 | params = utils.load_params(args.model_dir)
36 | 
37 | ## Data
38 | trainset, testset, num_classes = L.load_dataset(params['data'], data_dir=params['data_dir'])
39 | X_train, y_train = F.get_samples(trainset, args.trainsamples)
40 | X_test, y_test = F.get_samples(testset, args.testsamples)
41 | if args.translatetrain:
42 |     X_train, y_train = F.translate(X_train, y_train, stride=7)
43 | if args.translatetest:
44 |     X_test, y_test = F.translate(X_test, y_test, stride=7)
45 | X_train, y_train = 
X_train.to(device), y_train.to(device) 46 | X_test, y_test = X_test.to(device), y_test.to(device) 47 | 48 | ## Architecture 49 | net = L.load_architecture(params['data'], params['arch']) 50 | net = utils.load_ckpt(args.model_dir, 'model', net) 51 | net = net.to(device) 52 | 53 | ## Forward 54 | with torch.no_grad(): 55 | print('train') 56 | Z_train = net.batch_forward(X_train, batch_size=args.batch_size, loss=args.loss, device=device) 57 | X_train, y_train, Z_train = F.to_cpu(X_train, y_train, Z_train) 58 | utils.save_loss(eval_dir, f'train', net.get_loss()) 59 | 60 | print('test') 61 | Z_test = net.batch_forward(X_test, batch_size=args.batch_size, loss=args.loss, device=device) 62 | X_test, y_test, Z_test = F.to_cpu(X_test, y_test, Z_test) 63 | utils.save_loss(eval_dir, f'test', net.get_loss()) 64 | 65 | ## Normalize 66 | X_train = F.normalize(X_train.flatten(1)) 67 | X_test = F.normalize(X_test.flatten(1)) 68 | Z_train = F.normalize(Z_train.flatten(1)) 69 | Z_test = F.normalize(Z_test.flatten(1)) 70 | 71 | # Evaluate 72 | evaluate.evaluate(eval_dir, 'knn', Z_train, y_train, Z_test, y_test) 73 | #evaluate.evaluate(eval_dir, 'nearsub', Z_train, y_train, Z_test, y_test, num_classes=num_classes, n_comp=10) 74 | 75 | # Plot 76 | plot.plot_loss_mcr(eval_dir, 'train') 77 | plot.plot_loss_mcr(eval_dir, 'test') 78 | plot.plot_heatmap(eval_dir, 'X_train', X_train, y_train, num_classes) 79 | plot.plot_heatmap(eval_dir, 'X_test', X_test, y_test, num_classes) 80 | plot.plot_heatmap(eval_dir, 'Z_train', Z_train, y_train, num_classes) 81 | plot.plot_heatmap(eval_dir, 'Z_test', Z_test, y_test, num_classes) 82 | 83 | -------------------------------------------------------------------------------- /examples/gaussian2d.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "os.chdir('../')" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "from redunet import ReduNetVector\n", 20 | "import utils_example as ue\n", 21 | "import plot" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "## Hyperparameters" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "## Data\n", 38 | "dataset = 1 # can be 1 to 8\n", 39 | "train_noise = 0.1\n", 40 | "test_noise = 0.1\n", 41 | "train_samples = 100\n", 42 | "test_samples = 100\n", 43 | "\n", 44 | "## Model\n", 45 | "num_layers = 200 # number of redunet layers\n", 46 | "eta = 0.5\n", 47 | "eps = 0.1\n", 48 | "lmbda = 200" 49 | ] 50 | }, 51 | { 52 | "cell_type": "markdown", 53 | "metadata": {}, 54 | "source": [ 55 | "## Data" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "metadata": { 62 | "scrolled": true 63 | }, 64 | "outputs": [], 65 | "source": [ 66 | "X_train, y_train, num_classes = ue.generate_2d(dataset, train_noise, train_samples) # train\n", 67 | "X_test, y_test, num_classes = ue.generate_2d(dataset, test_noise, test_samples) # test" 68 | ] 69 | }, 70 | { 71 | "cell_type": "markdown", 72 | "metadata": {}, 73 | "source": [ 74 | "## Model" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "net = ReduNetVector(num_classes, 
num_layers, X_train.shape[1], eta=eta, eps=eps, lmbda=lmbda)\n", 84 | "Z_train = net.init(X_train, y_train)" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": null, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "ue.plot_loss_mcr(net.get_loss())\n", 94 | "ue.plot_2d(X_train, y_train, Z_train) " 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "metadata": { 101 | "scrolled": true 102 | }, 103 | "outputs": [], 104 | "source": [ 105 | "Z_test = net(X_test).detach()\n", 106 | "ue.plot_2d(X_test, y_test, Z_test)" 107 | ] 108 | } 109 | ], 110 | "metadata": { 111 | "kernelspec": { 112 | "display_name": "redunet", 113 | "language": "python", 114 | "name": "redunet" 115 | }, 116 | "language_info": { 117 | "codemirror_mode": { 118 | "name": "ipython", 119 | "version": 3 120 | }, 121 | "file_extension": ".py", 122 | "mimetype": "text/x-python", 123 | "name": "python", 124 | "nbconvert_exporter": "python", 125 | "pygments_lexer": "ipython3", 126 | "version": "3.7.9" 127 | }, 128 | "latex_envs": { 129 | "LaTeX_envs_menu_present": true, 130 | "autoclose": false, 131 | "autocomplete": true, 132 | "bibliofile": "biblio.bib", 133 | "cite_by": "apalike", 134 | "current_citInitial": 1, 135 | "eqLabelWithNumbers": true, 136 | "eqNumInitial": 1, 137 | "hotkeys": { 138 | "equation": "Ctrl-E", 139 | "itemize": "Ctrl-I" 140 | }, 141 | "labels_anchors": false, 142 | "latex_user_defs": false, 143 | "report_style_numbering": false, 144 | "user_envs_cfg": false 145 | } 146 | }, 147 | "nbformat": 4, 148 | "nbformat_minor": 2 149 | } 150 | -------------------------------------------------------------------------------- /examples/gaussian3d.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "os.chdir('../')" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "from redunet import ReduNetVector\n", 20 | "import utils_example as ue\n", 21 | "import plot" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "## Hyperparameters" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "## Data\n", 38 | "dataset = 1 # can be 1 to 8\n", 39 | "train_noise = 0.1\n", 40 | "test_noise = 0.1\n", 41 | "train_samples = 100\n", 42 | "test_samples = 100\n", 43 | "\n", 44 | "## Model\n", 45 | "num_layers = 200 # number of redunet layers\n", 46 | "eta = 0.5\n", 47 | "eps = 0.1\n", 48 | "lmbda = 200" 49 | ] 50 | }, 51 | { 52 | "cell_type": "markdown", 53 | "metadata": {}, 54 | "source": [ 55 | "## Data" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "metadata": { 62 | "scrolled": true 63 | }, 64 | "outputs": [], 65 | "source": [ 66 | "X_train, y_train, num_classes = ue.generate_3d(dataset, train_noise, train_samples) # train\n", 67 | "X_test, y_test, num_classes = ue.generate_3d(dataset, test_noise, test_samples) # test" 68 | ] 69 | }, 70 | { 71 | "cell_type": "markdown", 72 | "metadata": {}, 73 | "source": [ 74 | "## Model" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "net = ReduNetVector(num_classes, num_layers, X_train.shape[1], 
eta=eta, eps=eps, lmbda=lmbda)\n",
84 |     "Z_train = net.init(X_train, y_train)"
85 |    ]
86 |   },
87 |   {
88 |    "cell_type": "code",
89 |    "execution_count": null,
90 |    "metadata": {
91 |     "scrolled": false
92 |    },
93 |    "outputs": [],
94 |    "source": [
95 |     "ue.plot_loss_mcr(net.get_loss())\n",
96 |     "ue.plot_3d(X_train, y_train, 'X_train') \n",
97 |     "ue.plot_3d(Z_train, y_train, 'Z_train') "
98 |    ]
99 |   },
100 |   {
101 |    "cell_type": "code",
102 |    "execution_count": null,
103 |    "metadata": {
104 |     "scrolled": false
105 |    },
106 |    "outputs": [],
107 |    "source": [
108 |     "Z_test = net(X_test).detach()\n",
109 |     "ue.plot_3d(X_test, y_test, 'X_test')\n",
110 |     "ue.plot_3d(Z_test, y_test, 'Z_test')"
111 |    ]
112 |   },
113 |   {
114 |    "cell_type": "code",
115 |    "execution_count": null,
116 |    "metadata": {},
117 |    "outputs": [],
118 |    "source": []
119 |   }
120 |  ],
121 |  "metadata": {
122 |   "kernelspec": {
123 |    "display_name": "redunet",
124 |    "language": "python",
125 |    "name": "redunet"
126 |   },
127 |   "language_info": {
128 |    "codemirror_mode": {
129 |     "name": "ipython",
130 |     "version": 3
131 |    },
132 |    "file_extension": ".py",
133 |    "mimetype": "text/x-python",
134 |    "name": "python",
135 |    "nbconvert_exporter": "python",
136 |    "pygments_lexer": "ipython3",
137 |    "version": "3.7.9"
138 |   },
139 |   "latex_envs": {
140 |    "LaTeX_envs_menu_present": true,
141 |    "autoclose": false,
142 |    "autocomplete": true,
143 |    "bibliofile": "biblio.bib",
144 |    "cite_by": "apalike",
145 |    "current_citInitial": 1,
146 |    "eqLabelWithNumbers": true,
147 |    "eqNumInitial": 1,
148 |    "hotkeys": {
149 |     "equation": "Ctrl-E",
150 |     "itemize": "Ctrl-I"
151 |    },
152 |    "labels_anchors": false,
153 |    "latex_user_defs": false,
154 |    "report_style_numbering": false,
155 |    "user_envs_cfg": false
156 |   }
157 |  },
158 |  "nbformat": 4,
159 |  "nbformat_minor": 2
160 | }
161 | 
-------------------------------------------------------------------------------- /examples/utils_example.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import matplotlib.pyplot as plt
4 | 
5 | def generate_2d(data, noise, samples, shuffle=False):
6 |     if data == 1:
7 |         centers = [(1, 0), (0, 1)]
8 |     elif data == 2:
9 |         centers = [(np.cos(np.pi/3), np.sin(np.pi/3)), (1 ,0)]
10 |     elif data == 3:
11 |         centers = [(np.cos(np.pi/4), np.sin(np.pi/4)), (1 ,0)]
12 |     elif data == 4:
13 |         centers = [(np.cos(3*np.pi/4), np.sin(3*np.pi/4)), (1 ,0)]
14 |     elif data == 5:
15 |         centers = [(np.cos(2*np.pi/3), np.sin(2*np.pi/3)), (1 ,0)]
16 |     elif data == 6:
17 |         centers = [(np.cos(3*np.pi/4), np.sin(3*np.pi/4)), (np.cos(4*np.pi/3), np.sin(4*np.pi/3)), (1 ,0)]
18 |     elif data == 7:
19 |         centers = [(np.cos(3*np.pi/4), np.sin(3*np.pi/4)), 
20 |                    (np.cos(4*np.pi/3), np.sin(4*np.pi/3)), 
21 |                    (np.cos(np.pi/4), np.sin(np.pi/4))]
22 |     elif data == 8:
23 |         centers = [(np.cos(np.pi/6), np.sin(np.pi/6)), 
24 |                    (np.cos(np.pi/2), np.sin(np.pi/2)), 
25 |                    (np.cos(3*np.pi/4), np.sin(3*np.pi/4)), 
26 |                    (np.cos(5*np.pi/4), np.sin(5*np.pi/4)), 
27 |                    (np.cos(7*np.pi/4), np.sin(7*np.pi/4)), 
28 |                    (np.cos(3*np.pi/2), np.sin(3*np.pi/2))]
29 |     else:
30 |         raise NameError('data not found.')
31 | 
32 |     data = []
33 |     targets = []
34 |     for lbl, center in enumerate(centers):
35 |         X = np.random.normal(loc=center, scale=noise, size=(samples, 2))
36 |         y = np.repeat(lbl, samples).tolist()
37 |         data.append(X)
38 |         targets += y
39 |     data = np.concatenate(data)
40 |     data = data / np.linalg.norm(data, axis=1, ord=2, keepdims=True)
41 |     targets = 
np.array(targets)
42 | 
43 |     if shuffle:
44 |         idx_arr = np.random.choice(np.arange(len(data)), len(data), replace=False)
45 |         data, targets = data[idx_arr], targets[idx_arr]
46 | 
47 |     data = torch.tensor(data).float()
48 |     targets = torch.tensor(targets).long()
49 |     return data, targets, len(centers)
50 | 
51 | def generate_3d(data, noise, samples, shuffle=False):
52 |     if data == 1:
53 |         centers = [(1, 0, 0), 
54 |                    (0, 1, 0), 
55 |                    (0, 0, 1)]
56 |     elif data == 2:
57 |         centers = [(np.cos(np.pi/4), np.sin(np.pi/4), 1), 
58 |                    (np.cos(2*np.pi/3), np.sin(2*np.pi/3), 1), 
59 |                    (np.cos(np.pi), np.sin(np.pi), 1)]
60 |     elif data == 3:
61 |         centers = [(np.cos(np.pi/4), np.sin(np.pi/4), 1), 
62 |                    (np.cos(2*np.pi/3), np.sin(2*np.pi/3), 1), 
63 |                    (np.cos(5*np.pi/6), np.sin(5*np.pi/6), 1)]
64 |     else:
65 |         raise NameError('Data not found.')
66 | 
67 |     X, Y = [], []
68 |     for c, center in enumerate(centers):
69 |         _X = np.random.normal(center, scale=(noise, noise, noise), size=(samples, 3))
70 |         _Y = np.ones(samples, dtype=np.int32) * c
71 |         X.append(_X)
72 |         Y.append(_Y)
73 |     X = np.vstack(X)
74 |     X = X / np.linalg.norm(X, axis=1, ord=2, keepdims=True)
75 |     Y = np.hstack(Y)
76 | 
77 |     if shuffle:
78 |         idx_arr = np.random.choice(np.arange(len(X)), len(X), replace=False)
79 |         X, Y = X[idx_arr], Y[idx_arr]
80 | 
81 |     X = torch.tensor(X).float()
82 |     Y = torch.tensor(Y).long()
83 |     return X, Y, 3
84 | 
85 | 
86 | def plot_2d(inputs, labels, outputs):
87 |     fig, ax = plt.subplots(ncols=2, figsize=(8, 4))
88 |     # scatter once, coloring each point by its class label
89 |     ax[0].scatter(inputs[:, 0], inputs[:, 1], c=labels)
90 |     ax[0].set_ylim([-1.1, 1.1])
91 |     ax[0].set_xlim([-1.1, 1.1])
92 |     ax[0].set_title('X')
93 |     ax[1].scatter(outputs[:, 0], outputs[:, 1], c=labels)
94 |     ax[1].set_ylim([-1.1, 1.1])
95 |     ax[1].set_xlim([-1.1, 1.1])
96 |     ax[1].set_title('Z')
97 |     fig.tight_layout()
98 |     plt.show()
99 |     plt.close()
100 | 
101 | 
102 | def plot_3d(Z, y, title=''):
103 |     colors = np.array(['green', 'blue', 'red'])
104 |     colors = np.array(['forestgreen', 'royalblue', 'brown'])
105 |     fig = plt.figure(figsize=(5, 5))
106 |     ax = fig.add_subplot(111, projection='3d')
107 |     ax.scatter(Z[:, 0], Z[:, 1], Z[:, 2], c=colors[y], cmap=plt.cm.Spectral, s=200.0)
108 |     # Z, _ = F.get_n_each(Z, y, 1)
109 |     # for c in np.unique(y):
110 |     #     ax.quiver(0.0, 0.0, 0.0, Z[c, 0], Z[c, 1], Z[c, 2], length=1.0, normalize=True, arrow_length_ratio=0.05, color='black')
111 |     u, v = np.mgrid[0:2*np.pi:20j, 0:np.pi:10j]
112 |     x = np.cos(u)*np.sin(v)
113 |     y = np.sin(u)*np.sin(v)
114 |     z = np.cos(v)
115 |     ax.plot_wireframe(x, y, z, color="gray", alpha=0.5)
116 |     ax.xaxis._axinfo["grid"]['color'] = (0,0,0,0.1)
117 |     ax.yaxis._axinfo["grid"]['color'] = (0,0,0,0.1)
118 |     ax.zaxis._axinfo["grid"]['color'] = (0,0,0,0.1)
119 |     ax.set_title(title)
120 |     # [tick.label.set_fontsize(24) for tick in ax.xaxis.get_major_ticks()]
121 |     # [tick.label.set_fontsize(24) for tick in ax.yaxis.get_major_ticks()]
122 |     # [tick.label.set_fontsize(24) for tick in ax.zaxis.get_major_ticks()]
123 |     ax.view_init(20, 15)
124 |     plt.tight_layout()
125 |     plt.show()
126 |     plt.close()
127 | 
128 | 
129 | 
130 | def plot_loss_mcr(data):
131 |     loss_total = data['loss_total']
132 |     loss_expd = data['loss_expd']
133 |     loss_comp = data['loss_comp']
134 |     num_iter = np.arange(len(loss_total))
135 |     fig, ax = plt.subplots(1, 1, figsize=(7, 5), sharey=True, sharex=True)
136 |     ax.plot(num_iter, loss_total, label=r'$\Delta R$', 
137 |             color='green', linewidth=1.0, alpha=0.8)
138 |     ax.plot(num_iter, loss_expd, label=r'$R$', 
139 | 
color='royalblue', linewidth=1.0, alpha=0.8)
140 |     ax.plot(num_iter, loss_comp, label=r'$R^c$', 
141 |             color='coral', linewidth=1.0, alpha=0.8)
142 |     ax.set_ylabel('Loss', fontsize=10)
143 |     ax.set_xlabel('Number of iterations', fontsize=10)
144 |     ax.legend(loc='lower right', prop={"size": 15}, ncol=3, framealpha=0.5)
145 |     fig.tight_layout()
146 |     plt.show()
147 |     plt.close()
-------------------------------------------------------------------------------- /functional.py: --------------------------------------------------------------------------------
1 | import time
2 | import os
3 | from tqdm import tqdm
4 | 
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torchvision.datasets as datasets
9 | import torch.distributions as distributions
10 | import torchvision.transforms as transforms
11 | import torchvision.transforms.functional as TF
12 | from torch.utils.data import DataLoader
13 | from torch.utils.data._utils.collate import default_collate
14 | 
15 | 
16 | 
17 | def get_features(net, trainloader, verbose=True, n=None, device='cpu'):
18 |     '''Extract all features out into one single batch. 
19 | 
20 |     Parameters:
21 |         net (torch.nn.Module): get features using this model
22 |         trainloader (torchvision.dataloader): dataloader for loading data
23 |         n (int): if given, stop once features for at least n samples are collected
24 |         device (str): device on which to run the network
25 |         verbose (bool): shows loading status bar
26 |     Returns:
27 |         features (torch.tensor): with dimension (num_samples, feature_dimension)
28 |         labels (torch.tensor): with dimension (num_samples, )
29 |     '''
30 |     features = []
31 |     labels = []
32 |     if verbose:
33 |         train_bar = tqdm(trainloader, desc="extracting all features from dataset")
34 |     else:
35 |         train_bar = trainloader
36 |     total = 0
37 |     with torch.no_grad():
38 |         for _, (batch_imgs, batch_lbls) in enumerate(train_bar):
39 |             batch_imgs = batch_imgs.to(device)
40 |             batch_features = net(batch_imgs)
41 |             features.append(batch_features.cpu().detach())
42 |             labels.append(batch_lbls)
43 |             total += len(batch_features)
44 |             if n is not None and total > n:
45 |                 break
46 |     return torch.cat(features)[:n], torch.cat(labels)[:n]
47 | 
48 | def get_samples(dataset, num_samples, shuffle=False, batch_idx=0, seed=0, method='uniform'):
49 |     if method == 'uniform':
50 |         np.random.seed(seed)
51 |         dataloader = DataLoader(dataset, batch_size=dataset.data.shape[0])
52 |         X, y = next(iter(dataloader))
53 |         if shuffle: # ensure you sample different samples
54 |             idx_arr = np.random.choice(X.shape[0], y.shape[0], replace=False)
55 |             X, y = X[idx_arr], y[idx_arr]
56 |         if num_samples is not None:
57 |             X, y = get_n_each(X, y, num_samples, batch_idx)
58 |         X, y = X.float(), y.long()
59 |         if len(X.shape) == 3:
60 |             X = X.unsqueeze(1)
61 |         return X.float(), y.long()
62 |     elif method == 'first':
63 |         num_classes = torch.unique(dataset.targets).size()[0]
64 |         return next(iter(DataLoader(dataset, batch_size=num_classes*num_samples)))
65 | 
66 | def normalize(X, p=2):
67 |     if isinstance(X, torch.Tensor):
68 |         norm = torch.linalg.norm(X.flatten(1), ord=p, dim=1)
69 |         norm = norm.clip(min=1e-8)
70 |         for _ in range(len(X.shape)-1):
71 |             norm = norm.unsqueeze(-1)
72 |         return X / norm
73 |     elif isinstance(X, np.ndarray):
74 |         norm = np.linalg.norm(X.reshape(X.shape[0], -1), ord=p, axis=1)
75 |         norm = np.clip(norm, a_min=1e-8, a_max=None)
76 |         for _ in range(len(X.shape)-1):
77 |             norm = np.expand_dims(norm, -1)
78 |         return X / norm
79 |     else:
80 |         raise TypeError('Input is neither a torch.Tensor nor an np.ndarray.')
81 | 
82 | def get_n_each(X, y, n=None, batch_idx=0):
83 |     classes = torch.unique(y)
84 |     _X, _y = [], []
85 |     for c in classes:
86 |         idx_class = (y == c)
87 |         X_class, y_class = X[idx_class], y[idx_class]
88 |         if n is not None:
89 |             X_class = torch.roll(X_class, -batch_idx*n, dims=0)
90 |             y_class = torch.roll(y_class, -batch_idx*n, dims=0)
91 |         _X.append(X_class[:n])
92 |         _y.append(y_class[:n])
93 |     return torch.cat(_X), torch.cat(_y)
94 | 
95 | def translate(data, labels, stride=7, n=None):
96 |     if len(data.shape) == 3:
97 |         return translate1d(data, labels, n=n, stride=stride)
98 |     if len(data.shape) == 4:
99 |         return translate2d(data, labels, n=n, stride=stride)
100 |     raise ValueError('translate not available.')
101 | 
102 | 
103 | def translate1d(data, labels, n=None, stride=1):
104 |     m, _, T = data.shape
105 |     data_new = []
106 |     if n is None:
107 |         shifts = range(0, T, stride)
108 |     else:
109 |         shifts = range(-n*stride, (n+1)*stride, stride)
110 |     for t in shifts:
111 |         data_new.append(torch.roll(data, t, dims=(2)))
112 |     nrepeats = len(shifts)  # number of shifted copies generated above
113 |     return (torch.cat(data_new), 
114 |             labels.repeat(nrepeats))
115 | 
116 | def translate2d(data, labels, n=None, stride=1):
117 |     m, _, H, W = data.shape
118 |     if n is None:
119 |         shifts_horizontal = range(0, H, stride)
120 |         shifts_vertical = range(0, W, stride)
121 |     else:
122 |         shifts_horizontal = range(-n*stride, (n+1)*stride, stride)
123 |         shifts_vertical = range(-n*stride, (n+1)*stride, stride)
124 |     data_new = []
125 |     for h in shifts_horizontal:
126 |         for w in shifts_vertical:
127 |             data_new.append(torch.roll(data, (h, w), dims=(2, 3)))
128 |     nrepeats = len(shifts_vertical) * len(shifts_horizontal)
129 |     return (torch.cat(data_new), 
130 |             labels.repeat(nrepeats))
131 | 
132 | def cart2polar(images_cart, channels, timesteps):
133 |     m, C, H, W = images_cart.shape
134 |     mid_pt = int(H // 2)
135 |     R = torch.linspace(0, mid_pt, channels).long()
136 |     thetas = torch.linspace(0, 360, timesteps).float()
137 |     images_polar = []
138 |     for theta in thetas:
139 |         image_rotated = TF.rotate(images_cart, theta.item())
140 |         images_polar.append(image_rotated[:, :, mid_pt, R])
141 |     return torch.cat(images_polar, axis=1).transpose(1, 2)
142 | 
143 | def step_lr(epochs, init, gamma, steps):
144 |     """learning rate decay
145 |     epochs: total number of epochs
146 |     gamma: multiplicative decay
147 |     steps: epochs at which decay is applied
148 |     init: initial learning rate
149 |     """
150 |     rates = np.ones(epochs) * init
151 |     for step in steps:
152 |         rates[step:] = rates[step:] * gamma
153 |     return rates
154 | 
155 | def corrupt_labels(trainset, num_classes, ratio, seed):
156 |     """Corrupt labels in trainset.
157 | 
158 |     Parameters:
159 |         trainset (torch.utils.data.Dataset): trainset where labels are stored
160 |         ratio (float): ratio of labels to be corrupted. 
0 to corrupt no labels; 
161 |             1 to corrupt all labels
162 |         seed (int): random seed for reproducibility
163 | 
164 |     Returns:
165 |         trainset (torch.utils.data.Dataset): trainset with updated corrupted labels
166 | 
167 |     """
168 | 
169 |     np.random.seed(seed)
170 |     train_labels = np.asarray(trainset.targets)
171 |     n_train = len(train_labels)
172 |     n_rand = int(len(trainset.data)*ratio)
173 |     randomize_indices = np.random.choice(range(n_train), size=n_rand, replace=False)
174 |     train_labels[randomize_indices] = np.random.choice(np.arange(num_classes), size=n_rand, replace=True)
175 |     trainset.targets = torch.tensor(train_labels).int()
176 |     return trainset
177 | 
178 | def to_cpu(*gpu_vars):
179 |     cpu_vars = []
180 |     for var in gpu_vars:
181 |         cpu_vars.append(var.detach().cpu())
182 |     return cpu_vars
183 | 
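The translation helpers above implement the augmentation behind the `--translatetrain` / `--translatetest` flags in `evaluate_forward.py`, which calls `F.translate(..., stride=7)`. A minimal sketch of what they produce — the shapes and data here are illustrative only, not part of the repository:

```python
import torch
import functional as F  # this file, imported the way the training scripts do

# Four fake 1-channel 28x28 images from two classes.
X = torch.randn(4, 1, 28, 28)
y = torch.tensor([0, 0, 1, 1])

# translate2d rolls each image by every (h, w) multiple of the stride, so a
# stride of 7 on 28x28 inputs yields (28/7) * (28/7) = 16 shifted copies each.
X_aug, y_aug = F.translate(X, y, stride=7)
print(X_aug.shape, y_aug.shape)  # torch.Size([64, 1, 28, 28]) torch.Size([64])

# normalize projects every sample onto the unit sphere (L2 norm over all pixels).
Z = F.normalize(X_aug)
print(torch.linalg.norm(Z.flatten(1), dim=1)[:3])  # ~tensor([1., 1., 1.])
```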
-------------------------------------------------------------------------------- /images/arch-redunet.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ma-Lab-Berkeley/ReduNet/f67a348e5b54b9e60b783ceec6fe0e379bc3b96f/images/arch-redunet.jpg
-------------------------------------------------------------------------------- /load.py: --------------------------------------------------------------------------------
1 | import time
2 | import os
3 | from os import listdir
4 | from os.path import join, isfile, isdir, expanduser
5 | from tqdm import tqdm
6 | 
7 | import pandas as pd
8 | 
9 | import numpy as np
10 | import torch
11 | import torch.nn as nn
12 | from PIL import Image
13 | from torch.utils.data import Dataset
14 | import torchvision.datasets as datasets
15 | import torchvision.transforms as transforms
16 | 
17 | import functional as F
18 | from redunet import *
19 | 
20 | 
21 | def load_architecture(data, arch, seed=0):
22 |     if data == 'mnist2d':
23 |         if arch == 'lift2d_channels35_layers5':
24 |             from architectures.mnist.lift2d import lift2d
25 |             return lift2d(channels=35, layers=5, num_classes=10, seed=seed)
26 |         if arch == 'lift2d_channels35_layers10':
27 |             from architectures.mnist.lift2d import lift2d
28 |             return lift2d(channels=35, layers=10, num_classes=10, seed=seed)
29 |         if arch == 'lift2d_channels35_layers20':
30 |             from architectures.mnist.lift2d import lift2d
31 |             return lift2d(channels=35, layers=20, num_classes=10, seed=seed)
32 |         if arch == 'lift2d_channels55_layers5':
33 |             from architectures.mnist.lift2d import lift2d
34 |             return lift2d(channels=55, layers=5, num_classes=10, seed=seed)
35 |         if arch == 'lift2d_channels55_layers10':
36 |             from architectures.mnist.lift2d import lift2d
37 |             return lift2d(channels=55, layers=10, num_classes=10, seed=seed)
38 |         if arch == 'lift2d_channels55_layers20':
39 |             from architectures.mnist.lift2d import lift2d
40 |             return lift2d(channels=55, layers=20, num_classes=10, seed=seed)
41 |     if data == 'mnist2d_2class':
42 |         if arch == 'lift2d_channels35_layers5':
43 |             from architectures.mnist.lift2d import lift2d
44 |             return lift2d(channels=35, layers=5, num_classes=2, seed=seed)
45 |         if arch == 'lift2d_channels35_layers10':
46 |             from architectures.mnist.lift2d import lift2d
47 |             return lift2d(channels=35, layers=10, num_classes=2, seed=seed)
48 |         if arch == 'lift2d_channels35_layers20':
49 |             from architectures.mnist.lift2d import lift2d
50 |             return lift2d(channels=35, layers=20, num_classes=2, seed=seed)
51 |         if arch == 'lift2d_channels55_layers5':
52 |             from architectures.mnist.lift2d import lift2d
53 |             return lift2d(channels=55, layers=5, num_classes=2, seed=seed)
54 |         if arch == 'lift2d_channels55_layers10':
55 |             from architectures.mnist.lift2d import lift2d
56 |             return lift2d(channels=55, layers=10, num_classes=2, seed=seed)
57 |         if arch == 'lift2d_channels55_layers20':
58 |             from architectures.mnist.lift2d import lift2d
59 |             return lift2d(channels=55, layers=20, num_classes=2, seed=seed)
60 |     if data == 'mnistvector':
61 |         if arch == 'layers50':
62 |             from architectures.mnist.flatten import flatten
63 |             return flatten(layers=50, num_classes=10)
64 |         if arch == 'layers20':
65 |             from architectures.mnist.flatten import flatten
66 |             return flatten(layers=20, num_classes=10)
67 |         if arch == 'layers10':
68 |             from architectures.mnist.flatten import flatten
69 |             return flatten(layers=10, num_classes=10)
70 |         if arch == 'layers5':
71 |             from architectures.mnist.flatten import flatten
72 |             return flatten(layers=5, num_classes=10)
73 |     if data == 'mnistvector_2class':
74 |         if arch == 'layers50':
75 |             from architectures.mnist.flatten import flatten
76 |             return flatten(layers=50, num_classes=2)
77 |         if arch == 'layers20':
78 |             from architectures.mnist.flatten import flatten
79 |             return flatten(layers=20, num_classes=2)
80 |         if arch == 'layers10':
81 |             from architectures.mnist.flatten import flatten
82 |             return flatten(layers=10, num_classes=2)
83 |         if arch == 'layers5':
84 |             from architectures.mnist.flatten import flatten
85 |             return flatten(layers=5, num_classes=2)
86 |     raise NameError('Cannot find architecture: {}.'.format(arch))
87 | 
88 | def load_dataset(choice, data_dir='./data/'):
89 |     if choice == 'mnist2d':
90 |         from datasets.mnist import mnist2d_10class
91 |         return mnist2d_10class(data_dir)
92 |     if choice == 'mnist2d_2class':
93 |         from datasets.mnist import mnist2d_2class
94 |         return mnist2d_2class(data_dir)
95 |     if choice == 'mnistvector':
96 |         from datasets.mnist import mnistvector_10class
97 |         return mnistvector_10class(data_dir)
98 |     raise NameError(f'Dataset {choice} not found.')
99 | 
100 | 
-------------------------------------------------------------------------------- /loss.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from opt_einsum import contract
4 | 
5 | 
6 | class MaximalCodingRateReduction(nn.Module): #TODO: support probabilistic Pi
7 |     def __init__(self, eps=0.1, gam=1.):
8 |         super(MaximalCodingRateReduction, self).__init__()
9 |         self.eps = eps
10 |         self.gam = gam
11 | 
12 |     def discrimn_loss(self, Z):
13 |         m, d = Z.shape
14 |         I = torch.eye(d).to(Z.device)
15 |         c = d / (m * self.eps)
16 |         return logdet(c * covariance(Z) + I) / 2.
17 | 
18 |     def compress_loss(self, Z, Pi):
19 |         m, d = Z.shape  # Pi holds the class label of each row of Z
20 |         loss_comp = 0.
21 |         for j in Pi.unique():
22 |             Z_j = Z[Pi.flatten() == int(j)]
23 |             m_j = Z_j.shape[0]
24 |             logdet_j = logdet(torch.eye(d).to(Z.device) + (d / (m_j * self.eps)) * Z_j.T @ Z_j)
25 |             loss_comp += logdet_j * m_j / (2 * m)
26 |         return loss_comp
27 | 
28 |     def forward(self, Z, y):
29 |         Pi = y # TODO: change this to prob distribution
30 |         loss_discrimn = self.discrimn_loss(Z)
31 |         loss_compress = self.compress_loss(Z, Pi)
32 |         return loss_discrimn - self.gam * loss_compress
33 | 
34 | 
35 | def compute_mcr2(Z, y, eps):
36 |     if len(Z.shape) == 2:
37 |         loss_func = compute_loss_vec
38 |     elif len(Z.shape) == 3:
39 |         loss_func = compute_loss_1d
40 |     elif len(Z.shape) == 4:
41 |         loss_func = compute_loss_2d
42 |     return loss_func(Z, y, eps)
43 | 
44 | def compute_loss_vec(Z, y, eps):
45 |     m, d = Z.shape
46 |     I = torch.eye(d).to(Z.device)
47 |     c = d / (m * eps)
48 |     loss_expd = logdet(c * covariance(Z) + I) / 2.
49 |     loss_comp = 0. 
50 | for j in y.unique(): 51 | Z_j = Z[(y == int(j))[:, 0]] 52 | m_j = Z_j.shape[0] 53 | c_j = d / (m_j * eps) 54 | logdet_j = logdet(I + c_j * Z_j.T @ Z_j) 55 | loss_comp += logdet_j * m_j / (2 * m) 56 | loss_expd, loss_comp = loss_expd.item(), loss_comp.item() 57 | return loss_expd - loss_comp, loss_expd, loss_comp 58 | 59 | def compute_loss_1d(V, y, eps): 60 | m, C, T = V.shape 61 | I = torch.eye(C).unsqueeze(-1).to(V.device) 62 | alpha = C / (m * eps) 63 | cov = alpha * covariance(V) + I 64 | loss_expd = logdet(cov.permute(2, 0, 1)).sum() / (2 * T) 65 | loss_comp = 0. 66 | for j in y.unique(): 67 | V_j = V[y==int(j)] 68 | m_j = V_j.shape[0] 69 | alpha_j = C / (m_j * eps) 70 | cov_j = alpha_j * covariance(V_j) + I 71 | loss_comp += m_j / m * logdet(cov_j.permute(2, 0, 1)).sum() / (2 * T) 72 | loss_expd, loss_comp = loss_expd.real.item(), loss_comp.real.item() 73 | return loss_expd - loss_comp, loss_expd, loss_comp 74 | 75 | def compute_loss_2d(V, y, eps): 76 | m, C, H, W = V.shape 77 | I = torch.eye(C).unsqueeze(-1).unsqueeze(-1).to(V.device) 78 | alpha = C / (m * eps) 79 | cov = alpha * covariance(V) + I 80 | loss_expd = logdet(cov.permute(2, 3, 0, 1)).sum() / (2 * H * W) 81 | loss_comp = 0. 82 | for j in y.unique(): 83 | V_j = V[(y==int(j))[:, 0]] 84 | m_j = V_j.shape[0] 85 | alpha_j = C / (m_j * eps) 86 | cov_j = alpha_j * covariance(V_j) + I 87 | loss_comp += m_j / m * logdet(cov_j.permute(2, 3, 0, 1)).sum() / (2 * H * W) 88 | loss_expd, loss_comp = loss_expd.real.item(), loss_comp.real.item() 89 | return loss_expd - loss_comp, loss_expd, loss_comp 90 | 91 | def covariance(X): 92 | return contract('ji...,jk...->ik...', X, X.conj()) 93 | 94 | def logdet(X): 95 | sgn, logdet = torch.linalg.slogdet(X) 96 | return sgn * logdet -------------------------------------------------------------------------------- /plot.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import pandas as pd 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | from torch.utils.data import DataLoader 8 | 9 | import functional as F 10 | import utils 11 | 12 | def plot_loss_mcr(model_dir, name): 13 | file_dir = os.path.join(model_dir, 'loss', f'{name}.csv') 14 | data = pd.read_csv(file_dir) 15 | loss_total = data['loss_total'].ravel() 16 | loss_expd = data['loss_expd'].ravel() 17 | loss_comp = data['loss_comp'].ravel() 18 | num_iter = np.arange(len(loss_total)) 19 | fig, ax = plt.subplots(1, 1, figsize=(7, 5), sharey=True, sharex=True) 20 | ax.plot(num_iter, loss_total, label=r'$\Delta R$', 21 | color='green', linewidth=1.0, alpha=0.8) 22 | ax.plot(num_iter, loss_expd, label=r'$R$', 23 | color='royalblue', linewidth=1.0, alpha=0.8) 24 | ax.plot(num_iter, loss_comp, label=r'$R^c$', 25 | color='coral', linewidth=1.0, alpha=0.8) 26 | ax.set_ylabel('Loss', fontsize=10) 27 | ax.set_xlabel('Number of iterations', fontsize=10) 28 | ax.legend(loc='lower right', prop={"size": 15}, ncol=3, framealpha=0.5) 29 | fig.tight_layout() 30 | 31 | loss_dir = os.path.join(model_dir, 'figures', 'loss_mcr') 32 | os.makedirs(loss_dir, exist_ok=True) 33 | file_name = os.path.join(loss_dir, f'{name}.png') 34 | plt.savefig(file_name, dpi=400) 35 | plt.close() 36 | print("Plot saved to: {}".format(file_name)) 37 | 38 | def plot_loss(model_dir): 39 | """Plot cross entropy loss. 
""" 40 | ## extract loss from csv 41 | file_dir = os.path.join(model_dir, 'losses.csv') 42 | data = pd.read_csv(file_dir) 43 | epochs = data['epoch'].ravel() 44 | loss = data['loss'].ravel() 45 | 46 | fig, ax = plt.subplots(1, 1, figsize=(7, 5), sharey=True, sharex=True, dpi=400) 47 | ax.plot(epochs, loss, #label=r'Loss', 48 | color='green', linewidth=1.0, alpha=0.8) 49 | ax.set_ylabel('Loss', fontsize=10) 50 | ax.set_xlabel('Number of iterations', fontsize=10) 51 | ax.legend(loc='lower right', prop={"size": 15}, ncol=3, framealpha=0.5) 52 | ax.set_title("Loss") 53 | ax.spines['top'].set_visible(False) 54 | ax.spines['right'].set_visible(False) 55 | plt.tight_layout() 56 | 57 | ## create saving directory 58 | loss_dir = os.path.join(model_dir, 'figures', 'loss') 59 | os.makedirs(loss_dir, exist_ok=True) 60 | file_name = os.path.join(loss_dir, 'loss.png') 61 | plt.savefig(file_name, dpi=400) 62 | print("Plot saved to: {}".format(file_name)) 63 | file_name = os.path.join(loss_dir, 'loss.pdf') 64 | plt.savefig(file_name, dpi=400) 65 | plt.close() 66 | print("Plot saved to: {}".format(file_name)) 67 | 68 | def plot_csv(model_dir, filename): 69 | df = pd.read_csv(os.path.join(model_dir, f'{filename}.csv')) 70 | colnames = df.columns 71 | 72 | fig, ax = plt.subplots(1, 1, figsize=(7, 5)) 73 | for colname in colnames[1:]: 74 | ax.plot(df[colnames[0]], df[colname], marker='x', label=colname) 75 | ax.set_xlabel(colnames[0]) 76 | ax.set_ylabel(filename) 77 | ax.legend() 78 | 79 | csv_dir = os.path.join(model_dir, 'figures', 'csv') 80 | os.makedirs(csv_dir, exist_ok=True) 81 | savepath = os.path.join(csv_dir, f'{filename}.png') 82 | fig.savefig(savepath) 83 | print('Plot saved to: {}'.format(savepath)) 84 | 85 | def plot_loss_ce(model_dir, filename='loss_ce'): 86 | df = pd.read_csv(os.path.join(model_dir, f'{filename}.csv')) 87 | colnames = df.columns 88 | 89 | fig, ax = plt.subplots(1, 1, figsize=(7, 5)) 90 | for colname in colnames[1:]: 91 | ax.plot(np.arange(df.shape[0]), df['loss_ce'], label=colname) 92 | ax.set_xlabel(colnames[0]) 93 | ax.set_ylabel(filename) 94 | ax.legend() 95 | 96 | csv_dir = os.path.join(model_dir, 'figures', 'loss_ce') 97 | os.makedirs(csv_dir, exist_ok=True) 98 | savepath = os.path.join(csv_dir, f'loss_ce.png') 99 | fig.savefig(savepath) 100 | print('Plot saved to: {}'.format(savepath)) 101 | 102 | 103 | def plot_acc(model_dir): 104 | """Plot training and testing accuracy""" 105 | ## extract loss from csv 106 | file_dir = os.path.join(model_dir, 'acc.csv') 107 | data = pd.read_csv(file_dir) 108 | epochs = data['epoch'].ravel() 109 | acc_train = data['acc_train'].ravel() 110 | acc_test = data['acc_test'].ravel() 111 | # epoch,acc_train,acc_test 112 | 113 | ## Theoretical Loss 114 | fig, ax = plt.subplots(1, 1, figsize=(7, 5), sharey=True, sharex=True, dpi=400) 115 | ax.plot(epochs, acc_train, label='train', color='green', alpha=0.8) 116 | ax.plot(epochs, acc_test, label='test', color='red', alpha=0.8) 117 | ax.set_ylabel('Accuracy', fontsize=10) 118 | ax.set_xlabel('Epoch', fontsize=10) 119 | ax.legend(loc='lower right', prop={"size": 15}, ncol=3, framealpha=0.5) 120 | ax.spines['top'].set_visible(False) 121 | ax.spines['right'].set_visible(False) 122 | plt.tight_layout() 123 | 124 | ## create saving directory 125 | acc_dir = os.path.join(model_dir, 'figures', 'acc') 126 | os.makedirs(acc_dir, exist_ok=True) 127 | file_name = os.path.join(acc_dir, 'accuracy.png') 128 | plt.savefig(file_name, dpi=400) 129 | print("Plot saved to: {}".format(file_name)) 130 | file_name = 
os.path.join(acc_dir, 'accuracy.pdf') 131 | plt.savefig(file_name, dpi=400) 132 | plt.close() 133 | print("Plot saved to: {}".format(file_name)) 134 | 135 | def plot_heatmap(model_dir, name, features, labels, num_classes): 136 | """Plot heatmap of cosine similarity for all features. """ 137 | features_sort, _ = utils.sort_dataset(features, labels, 138 | classes=num_classes, stack=False) 139 | features_sort_ = np.vstack(features_sort) 140 | sim_mat = np.abs(features_sort_ @ features_sort_.T) 141 | 142 | # plt.rc('text', usetex=False) 143 | # plt.rcParams['font.family'] = 'serif' 144 | # plt.rcParams['font.serif'] = ['Times New Roman'] #+ plt.rcParams['font.serif'] 145 | 146 | fig, ax = plt.subplots(figsize=(7, 5), sharey=True, sharex=True) 147 | im = ax.imshow(sim_mat, cmap='Blues') 148 | fig.colorbar(im, pad=0.02, drawedges=0, ticks=[0, 0.5, 1]) 149 | ax.set_xticks(np.linspace(0, len(labels), num_classes+1)) 150 | ax.set_yticks(np.linspace(0, len(labels), num_classes+1)) 151 | [tick.label.set_fontsize(10) for tick in ax.xaxis.get_major_ticks()] 152 | [tick.label.set_fontsize(10) for tick in ax.yaxis.get_major_ticks()] 153 | fig.tight_layout() 154 | 155 | save_dir = os.path.join(model_dir, 'figures', 'heatmaps') 156 | os.makedirs(save_dir, exist_ok=True) 157 | file_name = os.path.join(save_dir, f"{name}.png") 158 | fig.savefig(file_name) 159 | print("Plot saved to: {}".format(file_name)) 160 | plt.close() 161 | 162 | def plot_transform(model_dir, inputs, outputs, name): 163 | fig, ax = plt.subplots(ncols=2) 164 | inputs = inputs.permute(1, 2, 0) 165 | outputs = outputs.permute(1, 2, 0) 166 | outputs = (outputs - outputs.min()) / (outputs.max() - outputs.min()) 167 | ax[0].imshow(inputs) 168 | ax[0].set_title('inputs') 169 | ax[1].imshow(outputs) 170 | ax[1].set_title('outputs') 171 | save_dir = os.path.join(model_dir, 'figures', 'images') 172 | os.makedirs(save_dir, exist_ok=True) 173 | file_name = os.path.join(save_dir, f'{name}.png') 174 | fig.savefig(file_name) 175 | print("Plot saved to: {}".format(file_name)) 176 | plt.close() 177 | 178 | def plot_channel_image(model_dir, features, name): 179 | def normalize(x): 180 | out = x - x.min() 181 | out = out / (out.max() - out.min()) 182 | return out 183 | fig, ax = plt.subplots() 184 | ax.imshow(normalize(features), cmap='gray') 185 | save_dir = os.path.join(model_dir, 'figures', 'images') 186 | os.makedirs(save_dir, exist_ok=True) 187 | file_name = os.path.join(save_dir, f'{name}.png') 188 | fig.savefig(file_name) 189 | print("Plot saved to: {}".format(file_name)) 190 | plt.close() 191 | 192 | 193 | def plot_nearest_image(model_dir, image, nearest_images, values, name, grid_size=(4, 4)): 194 | fig, ax = plt.subplots(*grid_size, figsize=(10, 10)) 195 | idx = 1 196 | for i in range(grid_size[0]): 197 | for j in range(grid_size[1]): 198 | if i == 0 and j == 0: 199 | ax[i, j].imshow(image) 200 | else: 201 | ax[i, j].set_title(values[idx-1]) 202 | ax[i, j].imshow(nearest_images[idx-1]) 203 | idx += 1 204 | ax[i, j].set_xticks([]) 205 | ax[i, j].set_yticks([]) 206 | plt.setp(ax[0, 0].spines.values(), color='red', linewidth=2) 207 | fig.tight_layout() 208 | 209 | save_dir = os.path.join(model_dir, 'figures', 'nearest_image') 210 | os.makedirs(save_dir, exist_ok=True) 211 | save_path = os.path.join(save_dir, f'{name}.png') 212 | fig.savefig(save_path) 213 | print(f"Plot saved to: {save_path}") 214 | plt.close() 215 | 216 | 217 | def plot_image(model_dir, image, name): 218 | fig, ax = plt.subplots(1, 1, figsize=(10, 10)) 219 | if image.shape[2] == 
1: 220 | ax.imshow(image, cmap='gray') 221 | else: 222 | ax.imshow(image) 223 | ax.set_xticks([]) 224 | ax.set_yticks([]) 225 | fig.tight_layout() 226 | 227 | save_dir = os.path.join(model_dir, 'figures', 'image') 228 | os.makedirs(save_dir, exist_ok=True) 229 | save_path = os.path.join(save_dir, f'{name}.png') 230 | fig.savefig(save_path) 231 | print(f"Plot saved to: {save_path}") 232 | plt.close() 233 | 234 | def save_image(image, save_path): 235 | fig, ax = plt.subplots(1, 1, figsize=(10, 10)) 236 | if image.shape[2] == 1: 237 | ax.imshow(image, cmap='gray') 238 | else: 239 | ax.imshow(image) 240 | ax.set_xticks([]) 241 | ax.set_yticks([]) 242 | fig.tight_layout() 243 | fig.savefig(save_path) 244 | print(f"Plot saved to: {save_path}") 245 | plt.close() -------------------------------------------------------------------------------- /redunet/__init__.py: -------------------------------------------------------------------------------- 1 | # Layers 2 | from .layers.redulayer import ReduLayer 3 | from .layers.vector import Vector 4 | from .layers.fourier1d import Fourier1D 5 | from .layers.fourier2d import Fourier2D 6 | 7 | # Modules 8 | from .modules import ( 9 | ReduNetVector, 10 | ReduNet1D, 11 | ReduNet2D 12 | ) 13 | 14 | # Projections 15 | from .projections.lift import Lift1D 16 | from .projections.lift import Lift2D 17 | 18 | 19 | # Others 20 | from .redunet import ReduNet 21 | from .multichannel_weight import MultichannelWeight 22 | 23 | 24 | __all__ = [ 25 | 'Fourier1D', 26 | 'Fourier2D', 27 | 'Lift1D', 28 | 'Lift2D', 29 | 'MultichannelWeight', 30 | 'ReduNet', 31 | 'ReduLayer', 32 | 'ReduNetVector', 33 | 'ReduNet1D', 34 | 'ReduNet2D', 35 | 'Vector' 36 | ] -------------------------------------------------------------------------------- /redunet/layers/fourier1d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as tF 4 | from torch.fft import fft, ifft 5 | from opt_einsum import contract 6 | import numpy as np 7 | 8 | import functional as F 9 | from ..multichannel_weight import MultichannelWeight 10 | from .vector import Vector 11 | 12 | 13 | 14 | class Fourier1D(Vector): 15 | def __init__(self, eta, eps, lmbda, num_classes, dimensions): 16 | super(Fourier1D, self).__init__(eta, eps, lmbda, num_classes) 17 | assert len(dimensions) == 2, 'dimensions should have tensor dim = 2' 18 | self.channels, self.timesteps = dimensions 19 | self.gam = nn.Parameter(torch.ones(num_classes) / num_classes, requires_grad=False) 20 | self.E = MultichannelWeight(self.channels, self.timesteps, dtype=torch.complex64) 21 | self.Cs = nn.ModuleList([MultichannelWeight(self.channels, self.timesteps, dtype=torch.complex64) 22 | for _ in range(num_classes)]) 23 | 24 | def nonlinear(self, Bz): 25 | norm = torch.linalg.norm(Bz.reshape(Bz.shape[0], Bz.shape[1], -1), axis=2) 26 | norm = torch.clamp(norm, min=1e-8) 27 | pred = tF.softmax(-self.lmbda * norm, dim=0).unsqueeze(2) 28 | y = torch.argmax(pred, axis=0) #TODO: for non argmax case 29 | gam = self.gam.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) 30 | pred = pred.unsqueeze(-1) 31 | out = torch.sum(gam * Bz * pred, axis=0) 32 | return out, y 33 | 34 | def compute_E(self, V): 35 | m, C, T = V.shape 36 | alpha = C / (m * self.eps) 37 | I = torch.eye(C, device=V.device).unsqueeze(-1) 38 | pre_inv = I + alpha * contract('ji...,jk...->ik...', V, V.conj()) 39 | E = torch.empty_like(pre_inv) 40 | for t in range(T): 41 | E[:, :, t] = alpha * torch.inverse(pre_inv[:, 
:, t]) 42 | return E 43 | 44 | def compute_Cs(self, V, y): 45 | m, C, T = V.shape 46 | I = torch.eye(C, device=V.device).unsqueeze(-1) 47 | Cs = torch.empty((self.num_classes, C, C, T), dtype=torch.complex64) 48 | for j in range(self.num_classes): 49 | V_j = V[(y == int(j))] 50 | m_j = V_j.shape[0] 51 | alpha_j = C / (m_j * self.eps) 52 | pre_inv = I + alpha_j * contract('ji...,jk...->ik...', V_j, V_j.conj()) 53 | for t in range(T): 54 | Cs[j, :, :, t] = alpha_j * torch.inverse(pre_inv[:, :, t]) 55 | return Cs 56 | 57 | def preprocess(self, X): 58 | Z = F.normalize(X) 59 | return fft(Z, norm='ortho', dim=2) 60 | 61 | def postprocess(self, X): 62 | Z = ifft(X, norm='ortho', dim=2) 63 | return F.normalize(Z).real 64 | -------------------------------------------------------------------------------- /redunet/layers/fourier2d.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as tF 5 | from itertools import product 6 | from torch.fft import fft2, ifft2 7 | import numpy as np 8 | from opt_einsum import contract 9 | 10 | import functional as F 11 | from ..multichannel_weight import MultichannelWeight 12 | from .vector import Vector 13 | 14 | 15 | 16 | class Fourier2D(Vector): 17 | def __init__(self, eta, eps, lmbda, num_classes, dimensions): 18 | super(Fourier2D, self).__init__(eta, eps, lmbda, num_classes) 19 | assert len(dimensions) == 3, 'dimensions should have tensor dim = 3' 20 | self.channels, self.height, self.width = dimensions 21 | self.gam = nn.Parameter(torch.ones(num_classes) / num_classes, requires_grad=False) 22 | self.E = MultichannelWeight(*dimensions, dtype=torch.complex64) 23 | self.Cs = nn.ModuleList([MultichannelWeight(*dimensions, dtype=torch.complex64) 24 | for _ in range(num_classes)]) 25 | 26 | def nonlinear(self, Cz): 27 | norm = torch.linalg.norm(Cz.flatten(start_dim=2), axis=2).clamp(min=1e-8) 28 | pred = tF.softmax(-self.lmbda * norm, dim=0).unsqueeze(2) 29 | y = torch.argmax(pred, axis=0) #TODO: for non argmax case 30 | gam = self.gam.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) 31 | pred = pred.unsqueeze(-1).unsqueeze(-1) 32 | out = torch.sum(gam * Cz * pred, axis=0) 33 | return out, y 34 | 35 | def compute_E(self, V): 36 | m, C, H, W = V.shape 37 | alpha = C / (m * self.eps) 38 | I = torch.eye(C, device=V.device).unsqueeze(-1).unsqueeze(-1) 39 | pre_inv = I + alpha * contract('ji...,jk...->ik...', V, V.conj()) 40 | E = torch.empty_like(pre_inv, dtype=torch.complex64) 41 | for h, w in product(range(H), range(W)): 42 | E[:, :, h, w] = alpha * torch.inverse(pre_inv[:, :, h, w]) 43 | return E 44 | 45 | def compute_Cs(self, V, y): 46 | m, C, H, W = V.shape 47 | I = torch.eye(C, device=V.device).unsqueeze(-1).unsqueeze(-1) 48 | Cs = torch.empty((self.num_classes, C, C, H, W), dtype=torch.complex64) 49 | for j in range(self.num_classes): 50 | V_j = V[(y == int(j))] 51 | m_j = V_j.shape[0] 52 | alpha_j = C / (m_j * self.eps) 53 | pre_inv = I + alpha_j * contract('ji...,jk...->ik...', V_j, V_j.conj()) 54 | for h, w in product(range(H), range(W)): 55 | Cs[j, :, :, h, w] = alpha_j * torch.inverse(pre_inv[:, :, h, w]) 56 | return Cs 57 | 58 | def preprocess(self, X): 59 | Z = F.normalize(X) 60 | return fft2(Z, norm='ortho', dim=(2, 3)) 61 | 62 | def postprocess(self, X): 63 | Z = ifft2(X, norm='ortho', dim=(2, 3)) 64 | return F.normalize(Z).real 65 | -------------------------------------------------------------------------------- 
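Note: both Fourier layers above follow the same recipe: preprocess() maps inputs into the orthonormal Fourier domain, compute_E/compute_Cs invert a regularized second moment independently at every frequency, and the inherited Vector.forward applies one step of Z <- Z + eta * (E @ Z - sum_j gam_j * pi_j * Cs[j] @ Z) followed by normalization. Below is a minimal sketch of driving a single Fourier2D layer by hand; the shapes and the eta/eps/lmbda values are illustrative assumptions, not values prescribed by this repo, and it assumes the repo root is on PYTHONPATH so that redunet and functional are importable.

import torch
from redunet import Fourier2D

torch.manual_seed(0)
X = torch.randn(8, 3, 4, 4)               # (batch, channels, height, width)
y = torch.randint(0, 2, (8,))             # two classes, labels in {0, 1}

# eta/eps/lmbda here are assumptions; dimensions must match (C, H, W) of X.
layer = Fourier2D(eta=0.5, eps=0.1, lmbda=500, num_classes=2, dimensions=(3, 4, 4))
with torch.no_grad():
    V = layer.preprocess(X)               # normalize, then orthonormal 2-D FFT
    layer.init(V, y)                      # construct E and Cs from this batch
    Z, y_hat = layer(V, return_y=True)    # one expansion/compression update
    out = layer.postprocess(Z)            # inverse FFT, renormalize, take real part
print(out.shape)                          # expected: torch.Size([8, 3, 4, 4])

In the full pipeline this bookkeeping is handled by ReduNet.forward (see redunet/redunet.py below), which calls preprocess on entering a block of ReduLayers and postprocess on exit.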
/redunet/layers/redulayer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | 6 | class ReduLayer(nn.Module): 7 | def __init__(self): 8 | super(ReduLayer, self).__init__() 9 | 10 | def __name__(self): 11 | return "ReduNet" 12 | 13 | def forward(self, Z): 14 | raise NotImplementedError 15 | 16 | def zero(self): 17 | state_dict = self.state_dict() 18 | state_dict['E.weight'] = torch.zeros_like(self.E.weight) 19 | for j in range(self.num_classes): 20 | state_dict[f'Cs.{j}.weight'] = torch.zeros_like(self.Cs[j].weight) 21 | self.load_state_dict(state_dict) 22 | 23 | def init(self, X, y): 24 | gam = self.compute_gam(X, y) 25 | E = self.compute_E(X) 26 | Cs = self.compute_Cs(X, y) 27 | self.set_params(E, Cs, gam) 28 | 29 | def update_old(self, X, y, tau): 30 | E = self.compute_E(X).to(X.device) 31 | Cs = self.compute_Cs(X, y).to(X.device) 32 | state_dict = self.state_dict() 33 | ref_E = self.E.weight 34 | ref_Cs = [self.Cs[j].weight for j in range(self.num_classes)] 35 | new_E = ref_E + tau * (E - ref_E) 36 | new_Cs = [ref_Cs[j] + tau * (Cs[j] - ref_Cs[j]) for j in range(self.num_classes)] 37 | state_dict['E.weight'] = new_E 38 | for j in range(self.num_classes): 39 | state_dict[f'Cs.{j}.weight'] = new_Cs[j] 40 | self.load_state_dict(state_dict) 41 | 42 | def update(self, X, y, tau): 43 | E_ref, Cs_ref = self.get_params() 44 | # gam = self.init_gam(X, y) 45 | E_new = self.compute_E(X).to(X.device) 46 | Cs_new = self.compute_Cs(X, y).to(X.device) 47 | E_update = E_ref + tau * (E_new - E_ref) 48 | Cs_update = [Cs_ref[j] + tau * (Cs_new[j] - Cs_ref[j]) for j in range(self.num_classes)] 49 | self.set_params(E_update, Cs_update) 50 | 51 | def set_params(self, E, Cs, gam=None): 52 | state_dict = self.state_dict() 53 | assert self.E.weight.shape == E.shape, f'E shape does not match: {self.E.weight.shape} and {E.shape}' 54 | state_dict['E.weight'] = E 55 | for j in range(self.num_classes): 56 | assert self.Cs[j].weight.shape == Cs[j].shape, f'Cs[{j}] shape does not match: {self.Cs[j].weight.shape} and {Cs[j].shape}' 57 | state_dict[f'Cs.{j}.weight'] = Cs[j] 58 | if gam is not None: 59 | assert self.gam.shape == gam.shape, 'gam shape does not match' 60 | state_dict['gam'] = gam 61 | self.load_state_dict(state_dict) 62 | 63 | def get_params(self): 64 | E = self.E.weight 65 | Cs = [self.Cs[j].weight for j in range(self.num_classes)] 66 | return E, Cs 67 | 68 | -------------------------------------------------------------------------------- /redunet/layers/vector.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as tF 4 | 5 | import functional as F 6 | from .redulayer import ReduLayer 7 | 8 | 9 | class Vector(ReduLayer): 10 | def __init__(self, eta, eps, lmbda, num_classes, dimensions=None): 11 | super(Vector, self).__init__() 12 | self.eta = eta 13 | self.eps = eps 14 | self.lmbda = lmbda 15 | self.num_classes = num_classes 16 | self.d = dimensions 17 | 18 | if self.d is not None: #NOTE: initialized in child classes 19 | self.gam = nn.Parameter(torch.ones(num_classes) / num_classes) 20 | self.E = nn.Linear(self.d, self.d, bias=False) 21 | self.Cs = nn.ModuleList([nn.Linear(self.d, self.d, bias=False) for _ in range(num_classes)]) 22 | 23 | def forward(self, Z, return_y=False): 24 | expd = self.E(Z) 25 | comp = torch.stack([C(Z) for C in self.Cs]) 26 | clus, y_approx = self.nonlinear(comp) 27 | Z = Z + self.eta * (expd - clus) 28 | Z = F.normalize(Z) 29 | if 
return_y: 30 | return Z, y_approx 31 | return Z 32 | 33 | def nonlinear(self, Cz): 34 | norm = torch.linalg.norm(Cz.reshape(Cz.shape[0], Cz.shape[1], -1), axis=2) 35 | norm = torch.clamp(norm, min=1e-8) 36 | pred = tF.softmax(-self.lmbda * norm, dim=0).unsqueeze(2) 37 | y = torch.argmax(pred, axis=0) #TODO: for non argmax case 38 | gam = self.gam.unsqueeze(-1).unsqueeze(-1) 39 | out = torch.sum(gam * Cz * pred, axis=0) 40 | return out, y 41 | 42 | def compute_gam(self, X, y): 43 | m = X.shape[0] 44 | m_j = [torch.nonzero(y==j).size()[0] for j in range(self.num_classes)] 45 | gam = (torch.tensor(m_j).float() / m).flatten() 46 | return gam 47 | 48 | def compute_E(self, X): 49 | m, d = X.shape 50 | Z = X.T 51 | I = torch.eye(d, device=X.device) 52 | c = d / (m * self.eps) 53 | E = c * torch.inverse(I + c * Z @ Z.T) 54 | return E 55 | 56 | def compute_Cs(self, X, y): 57 | m, d = X.shape 58 | Z = X.T 59 | I = torch.eye(d, device=X.device) 60 | Cs = torch.zeros((self.num_classes, d, d)) 61 | for j in range(self.num_classes): 62 | idx = (y == int(j)) 63 | Z_j = Z[:, idx] 64 | m_j = Z_j.shape[1] 65 | c_j = d / (m_j * self.eps) 66 | Cs[j] = c_j * torch.inverse(I + c_j * Z_j @ Z_j.T) 67 | return Cs 68 | 69 | def preprocess(self, X): 70 | return F.normalize(X) 71 | 72 | def postprocess(self, X): 73 | return F.normalize(X) 74 | 75 | -------------------------------------------------------------------------------- /redunet/modules.py: -------------------------------------------------------------------------------- 1 | from .layers.vector import Vector 2 | from .layers.fourier1d import Fourier1D 3 | from .layers.fourier2d import Fourier2D 4 | from .redunet import ReduNet 5 | 6 | 7 | 8 | 9 | def ReduNetVector(num_classes, num_layers, d, eta, eps, lmbda): 10 | redunet = ReduNet( 11 | *[Vector(eta, eps, lmbda, num_classes, d) for _ in range(num_layers)] 12 | ) 13 | return redunet 14 | 15 | def ReduNet1D(num_classes, num_layers, channels, timesteps, eta, eps, lmbda): 16 | redunet = ReduNet( 17 | *[Fourier1D(eta, eps, lmbda, num_classes, (channels, timesteps)) for _ in range(num_layers)] 18 | ) 19 | return redunet 20 | 21 | def ReduNet2D(num_classes, num_layers, channels, height, width, eta, eps, lmbda): 22 | redunet = ReduNet( 23 | *[Fourier2D(eta, eps, lmbda, num_classes, (channels, height, width)) for _ in range(num_layers)] 24 | ) 25 | return redunet -------------------------------------------------------------------------------- /redunet/multichannel_weight.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import torch.nn as nn 4 | from opt_einsum import contract 5 | 6 | 7 | 8 | # A Weight matrix with multichannel capabilities. Inputs that multiply this weight 9 | # should have channel dimension at the end. 
There is no limit to the number of channels. 10 | class MultichannelWeight(nn.Module): 11 | def __init__(self, channels, *dimension, dtype=torch.complex64): 12 | super(MultichannelWeight, self).__init__() 13 | self.weight = nn.Parameter(torch.randn(channels, channels, *dimension, dtype=dtype)) 14 | self.shape = self.weight.shape 15 | self.dtype = dtype 16 | 17 | def __getitem__(self, item): 18 | return self.weight[item] 19 | 20 | def forward(self, V): 21 | return contract("bi...,ih...->bh...", V.type(self.dtype), self.weight.conj()) 22 | -------------------------------------------------------------------------------- /redunet/projections/lift.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.distributions as distributions 5 | 6 | 7 | 8 | class Lift(nn.Module): 9 | def __init__(self, in_channel, out_channel, kernel_size, init_mode='gaussian1.0', stride=1, trainable=False, relu=True, seed=0): 10 | super(Lift, self).__init__() 11 | self.in_channel = in_channel 12 | self.out_channel = out_channel 13 | self.kernel_size = kernel_size 14 | self.init_mode = init_mode 15 | self.stride = stride 16 | self.trainable = trainable 17 | self.relu = relu 18 | self.seed = seed 19 | 20 | def set_weight(self, init_mode, size, trainable): 21 | torch.manual_seed(self.seed) 22 | if init_mode == 'gaussian0.1': 23 | p = distributions.normal.Normal(loc=0, scale=0.1) 24 | elif init_mode == 'gaussian1.0': 25 | p = distributions.normal.Normal(loc=0, scale=1.) 26 | elif init_mode == 'gaussian5.0': 27 | p = distributions.normal.Normal(loc=0, scale=5.) 28 | elif init_mode == 'uniform0.1': 29 | p = distributions.uniform.Uniform(-0.1, 0.1) 30 | elif init_mode == 'uniform0.5': 31 | p = distributions.uniform.Uniform(-0.5, 0.5) 32 | else: 33 | raise NameError(f'No such kernel: {init_mode}') 34 | kernel = p.sample(size).type(torch.float) 35 | self.kernel = nn.Parameter(kernel, requires_grad=trainable) 36 | 37 | 38 | class Lift1D(Lift): 39 | def __init__(self, in_channel, out_channel, kernel_size, init_mode='gaussian1.0', stride=1, trainable=False, relu=True, seed=0): 40 | super(Lift1D, self).__init__(in_channel, out_channel, kernel_size, init_mode, stride, trainable, relu, seed) 41 | self.size = (out_channel, in_channel, kernel_size) 42 | self.set_weight(init_mode, self.size, trainable) 43 | 44 | def forward(self, Z): 45 | Z = F.pad(Z, (0, self.kernel_size-1), 'circular') 46 | out = F.conv1d(Z, self.kernel.to(Z.device), stride=self.stride) 47 | if self.relu: 48 | return F.relu(out) 49 | return out 50 | 51 | 52 | class Lift2D(Lift): 53 | def __init__(self, in_channel, out_channel, kernel_size, init_mode='gaussian1.0', stride=1, trainable=False, relu=True, seed=0): 54 | super(Lift2D, self).__init__(in_channel, out_channel, kernel_size, init_mode, stride, trainable, relu, seed) 55 | self.size = (out_channel, in_channel, kernel_size, kernel_size) 56 | self.set_weight(init_mode, self.size, trainable) 57 | 58 | def forward(self, Z): 59 | kernel = self.kernel.to(Z.device) 60 | Z = F.pad(Z, (0, self.kernel_size-1, 0, self.kernel_size-1), 'circular') 61 | out = F.conv2d(Z, kernel, stride=self.stride) 62 | if self.relu: 63 | return F.relu(out) 64 | return out 65 | -------------------------------------------------------------------------------- /redunet/redunet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from loss import compute_mcr2 4 
| from .layers.redulayer import ReduLayer 5 | 6 | 7 | 8 | class ReduNet(nn.Sequential): 9 | # ReduNet architecture. This class inherits from the nn.Sequential class, 10 | # hence can be used for stacking layers of torch.nn.Module objects. 11 | 12 | def __init__(self, *modules): 13 | super(ReduNet, self).__init__(*modules) 14 | self._init_loss() 15 | 16 | def init(self, inputs, labels): 17 | # Initialize the network. Using inputs and labels, it constructs 18 | # the parameters E and Cs throughout each ReduLayer. 19 | with torch.no_grad(): 20 | return self.forward(inputs, 21 | labels, 22 | init=True, 23 | loss=True) 24 | 25 | def update(self, inputs, labels, tau=0.1): 26 | # Update the network parameters E and Cs by 27 | # performing a moving average. 28 | with torch.no_grad(): 29 | return self.forward(inputs, 30 | labels, 31 | tau=tau, 32 | update=True, 33 | loss=True) 34 | 35 | def zero(self): 36 | # Set all network parameters E and Cs to zero matrices. 37 | with torch.no_grad(): 38 | for module in self: 39 | if isinstance(module, ReduLayer): 40 | module.zero() 41 | return self 42 | 43 | def batch_forward(self, inputs, batch_size=1000, loss=False, cuda=True, device=None): 44 | # Perform forward pass in batches. 45 | outputs = [] 46 | for i in range(0, inputs.shape[0], batch_size): 47 | print('batch:', i, end='\r') 48 | batch_inputs = inputs[i:i+batch_size] 49 | if device is not None: 50 | batch_inputs = batch_inputs.to(device) 51 | elif cuda: 52 | batch_inputs = batch_inputs.cuda() 53 | batch_outputs = self.forward(batch_inputs, loss=loss) 54 | outputs.append(batch_outputs.cpu()) 55 | return torch.cat(outputs) 56 | 57 | def forward(self, 58 | inputs, 59 | labels=None, 60 | tau=0.1, 61 | init=False, 62 | update=False, 63 | loss=False): 64 | 65 | self._init_loss() 66 | self._inReduBlock = False 67 | 68 | for layer_i, module in enumerate(self): 69 | # preprocess for redunet layers 70 | if self._isEnterReduBlock(layer_i, module): 71 | inputs = module.preprocess(inputs) 72 | self._inReduBlock = True 73 | 74 | # If init is set to True, then initialize the 75 | # layer using inputs and labels 76 | if init and self._isReduLayer(module): 77 | module.init(inputs, labels) 78 | 79 | # If update is set to True, then update the 80 | # layer parameters using inputs, labels, and tau 81 | if update and self._isReduLayer(module): 82 | module.update(inputs, labels, tau) 83 | 84 | # Perform a forward pass 85 | if self._isReduLayer(module): 86 | inputs, preds = module(inputs, return_y=True) 87 | else: 88 | inputs = module(inputs) 89 | 90 | # compute loss for redunet layer 91 | if loss and isinstance(module, ReduLayer): 92 | losses = compute_mcr2(inputs, preds, module.eps) 93 | self._append_loss(layer_i, *losses) 94 | 95 | # postprocess for redunet layers 96 | if self._isExitReduBlock(layer_i, module): 97 | inputs = module.postprocess(inputs) 98 | self._inReduBlock = False 99 | return inputs 100 | 101 | 102 | def get_loss(self): 103 | return self.losses 104 | 105 | def _init_loss(self): 106 | self.losses = {'layer': [], 'loss_total': [], 'loss_expd': [], 'loss_comp': []} 107 | 108 | def _append_loss(self, layer_i, loss_total, loss_expd, loss_comp): 109 | self.losses['layer'].append(layer_i) 110 | self.losses['loss_total'].append(loss_total) 111 | self.losses['loss_expd'].append(loss_expd) 112 | self.losses['loss_comp'].append(loss_comp) 113 | print(f"{layer_i} | {loss_total:.6f} {loss_expd:.6f} {loss_comp:.6f}") 114 | 115 | def _isReduLayer(self, module): 116 | return isinstance(module, ReduLayer) 117 | 118 | def _isEnterReduBlock(self, _, 
module): 119 | # my first encounter of ReduLayer 120 | if not self._inReduBlock and self._isReduLayer(module): 121 | return True 122 | return False 123 | 124 | def _isExitReduBlock(self, layer_i, _): 125 | # I am in ReduBlock and I am the last layer of the network 126 | if len(self) - 1 == layer_i and self._inReduBlock: 127 | return True 128 | # I am in ReduBlock and I am the last ReduLayer 129 | if self._inReduBlock and not self._isReduLayer(self[layer_i+1]): 130 | return True 131 | return False 132 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name <env> --file <this file> 3 | # platform: osx-64 4 | appnope=0.1.0=py37_0 5 | backcall=0.2.0=py_0 6 | blas=1.0=mkl 7 | brotlipy=0.7.0=py37h9ed2024_1003 8 | bzip2=1.0.8=h1de35cc_0 9 | ca-certificates=2021.1.19=hecd8cb5_1 10 | certifi=2020.12.5=py37hecd8cb5_0 11 | cffi=1.14.4=py37h2125817_0 12 | chardet=4.0.0=py37hecd8cb5_1003 13 | cryptography=3.3.1=py37hbcfaee0_0 14 | cycler=0.10.0=py37_0 15 | decorator=4.4.2=py_0 16 | ffmpeg=4.2.2=h97e5cf8_0 17 | freetype=2.10.4=ha233b18_0 18 | gettext=0.19.8.1=hb0f4f8b_2 19 | gmp=6.1.2=hb37e062_1 20 | gnutls=3.6.5=h91ad68e_1002 21 | idna=2.10=pyhd3eb1b0_0 22 | intel-openmp=2019.4=233 23 | ipykernel=5.3.4=py37h5ca1d4c_0 24 | ipython=7.18.1=py37h5ca1d4c_0 25 | ipython_genutils=0.2.0=py37_0 26 | jedi=0.17.2=py37_0 27 | joblib=0.17.0=py_0 28 | jpeg=9b=he5867d9_2 29 | jupyter_client=6.1.7=py_0 30 | jupyter_core=4.6.3=py37_0 31 | kiwisolver=1.2.0=py37h04f5b5a_0 32 | lame=3.100=h1de35cc_0 33 | lcms2=2.11=h92f6f08_0 34 | libcxx=10.0.0=1 35 | libedit=3.1.20191231=h1de35cc_1 36 | libffi=3.3=hb1e8313_2 37 | libgfortran=3.0.1=h93005f0_2 38 | libiconv=1.16=h1de35cc_0 39 | libopus=1.3.1=h1de35cc_0 40 | libpng=1.6.37=ha441bb4_0 41 | libsodium=1.0.18=h1de35cc_0 42 | libtiff=4.1.0=hcb84e12_1 43 | libuv=1.40.0=haf1e3a3_0 44 | libvpx=1.7.0=h378b8a2_0 45 | llvm-openmp=10.0.0=h28b9765_0 46 | lz4-c=1.9.2=h79c402e_3 47 | matplotlib=3.3.2=0 48 | matplotlib-base=3.3.2=py37h181983e_0 49 | mkl=2019.4=233 50 | mkl-service=2.3.0=py37hfbe908c_0 51 | mkl_fft=1.2.0=py37hc64f4ea_0 52 | mkl_random=1.1.1=py37h959d312_0 53 | ncurses=6.2=h0a44026_1 54 | nettle=3.4.1=h3018a27_0 55 | ninja=1.10.1=py37h879752b_0 56 | numpy=1.19.1=py37h3b9f5b6_0 57 | numpy-base=1.19.1=py37hcfb5961_0 58 | olefile=0.46=py37_0 59 | opencv-python=4.4.0.44=pypi_0 60 | openh264=2.1.0=hd9629dc_0 61 | openssl=1.1.1k=h9ed2024_0 62 | opt_einsum=3.1.0=py_0 63 | pandas=1.1.3=py37hb1e8313_0 64 | parso=0.7.0=py_0 65 | pexpect=4.8.0=py37_0 66 | pickleshare=0.7.5=py37_0 67 | pillow=8.0.0=py37h1a82f1a_0 68 | pip=20.2.4=py37_0 69 | prompt-toolkit=3.0.8=py_0 70 | ptyprocess=0.6.0=py37_0 71 | pycparser=2.20=py_2 72 | pygments=2.7.1=py_0 73 | pyopenssl=20.0.1=pyhd3eb1b0_1 74 | pyparsing=2.4.7=py_0 75 | pysocks=1.7.1=py37hecd8cb5_0 76 | python=3.7.9=h26836e1_0 77 | python-dateutil=2.8.1=py_0 78 | pytorch=1.8.0=py3.7_0 79 | pytz=2020.1=py_0 80 | pyzmq=19.0.2=py37hb1e8313_1 81 | readline=8.0=h1de35cc_0 82 | requests=2.25.1=pyhd3eb1b0_0 83 | scikit-learn=0.23.2=py37h959d312_0 84 | scipy=1.5.2=py37h912ce22_0 85 | setuptools=50.3.0=py37h0dc7051_1 86 | six=1.15.0=py_0 87 | sqlite=3.33.0=hffcf06c_0 88 | threadpoolctl=2.1.0=pyh5ca1d4c_0 89 | tk=8.6.10=hb0a8c7a_0 90 | torchaudio=0.8.0=py37 91 | torchvision=0.9.0=py37_cpu 92 | tornado=6.0.4=py37h1de35cc_1 93 | tqdm=4.50.2=py_0 94 | 
traitlets=5.0.5=py_0 95 | typing_extensions=3.7.4.3=py_0 96 | urllib3=1.26.3=pyhd3eb1b0_0 97 | wcwidth=0.2.5=py_0 98 | wheel=0.35.1=py_0 99 | x264=1!157.20191217=h1de35cc_0 100 | xlrd=1.2.0=py37_0 101 | xz=5.2.5=h1de35cc_0 102 | zeromq=4.3.3=hb1e8313_3 103 | zlib=1.2.11=h1de35cc_3 104 | zstd=1.4.5=h41d2c2f_0 105 | -------------------------------------------------------------------------------- /train_forward.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | from redunet import * 9 | import evaluate 10 | import load as L 11 | import functional as F 12 | import utils 13 | import plot 14 | 15 | 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--data', type=str, required=True, help='choice of dataset') 19 | parser.add_argument('--arch', type=str, required=True, help='choice of architecture') 20 | parser.add_argument('--samples', type=int, required=True, help="number of samples per update") 21 | parser.add_argument('--tail', type=str, default='', help='extra information to add to folder name') 22 | parser.add_argument('--save_dir', type=str, default='./saved_models/', help='base directory for saving.') 23 | parser.add_argument('--data_dir', type=str, default='./data/', help='base directory for data.') 24 | args = parser.parse_args() 25 | 26 | ## CUDA 27 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 28 | 29 | ## Model Directory 30 | model_dir = os.path.join(args.save_dir, 31 | 'forward', 32 | f'{args.data}+{args.arch}', 33 | f'samples{args.samples}' 34 | f'{args.tail}') 35 | os.makedirs(model_dir, exist_ok=True) 36 | utils.save_params(model_dir, vars(args)) 37 | print(model_dir) 38 | 39 | ## Data 40 | trainset, testset, num_classes = L.load_dataset(args.data, data_dir=args.data_dir) 41 | X_train, y_train = F.get_samples(trainset, args.samples) 42 | X_train, y_train = X_train.to(device), y_train.to(device) 43 | 44 | ## Architecture 45 | net = L.load_architecture(args.data, args.arch) 46 | net = net.to(device) 47 | 48 | ## Training 49 | with torch.no_grad(): 50 | Z_train = net.init(X_train, y_train) 51 | losses_train = net.get_loss() 52 | X_train, Z_train = F.to_cpu(X_train, Z_train) 53 | 54 | ## Saving 55 | utils.save_loss(model_dir, 'train', losses_train) 56 | utils.save_ckpt(model_dir, 'model', net) 57 | 58 | ## Plotting 59 | plot.plot_loss_mcr(model_dir, 'train') 60 | 61 | print(model_dir) 62 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | import pandas as pd 5 | import torch 6 | 7 | 8 | 9 | def sort_dataset(data, labels, classes, stack=False): 10 | """Sort dataset based on classes. 
11 | 12 | Parameters: 13 | data (np.ndarray): data array 14 | labels (np.ndarray): one dimensional array of class labels 15 | classes (int): number of classes 16 | stack (bool): combine sorted data into one numpy array 17 | 18 | Return: 19 | sorted data (np.ndarray), sorted_labels (np.ndarray) 20 | 21 | """ 22 | if isinstance(classes, int): 23 | classes = np.arange(classes) 24 | sorted_data = [] 25 | sorted_labels = [] 26 | for c in classes: 27 | idx = (labels == c) 28 | data_c = data[idx] 29 | labels_c = labels[idx] 30 | sorted_data.append(data_c) 31 | sorted_labels.append(labels_c) 32 | if stack: 33 | if isinstance(data, np.ndarray): 34 | sorted_data = np.vstack(sorted_data) 35 | sorted_labels = np.hstack(sorted_labels) 36 | else: 37 | sorted_data = torch.cat(sorted_data) 38 | sorted_labels = torch.cat(sorted_labels) 39 | return sorted_data, sorted_labels 40 | 41 | def save_params(model_dir, params, name='params', name_prefix=None): 42 | """Save params to a .json file. Params is a dictionary of parameters.""" 43 | if name_prefix: 44 | model_dir = os.path.join(model_dir, name_prefix) 45 | os.makedirs(model_dir, exist_ok=True) 46 | path = os.path.join(model_dir, f'{name}.json') 47 | with open(path, 'w') as f: 48 | json.dump(params, f, indent=2, sort_keys=True) 49 | 50 | def load_params(model_dir): 51 | """Load params.json file in model directory and return dictionary.""" 52 | path = os.path.join(model_dir, "params.json") 53 | with open(path, 'r') as f: 54 | _dict = json.load(f) 55 | return _dict 56 | 57 | def update_params(model_dir, dict_): 58 | params = load_params(model_dir) 59 | for key in dict_.keys(): 60 | params[key] = dict_[key] 61 | save_params(model_dir, params) 62 | return params 63 | 64 | def create_csv(model_dir, filename, headers): 65 | """Create .csv file with filename in model_dir, with headers as the first line 66 | of the csv. """ 67 | csv_path = os.path.join(model_dir, f'{filename}.csv') 68 | if os.path.exists(csv_path): 69 | os.remove(csv_path) 70 | with open(csv_path, 'w+') as f: 71 | f.write(','.join(map(str, headers))) 72 | return csv_path 73 | 74 | def append_csv(model_dir, filename, entries): 75 | """Save entries to csv. Entries is list of numbers. """ 76 | csv_path = os.path.join(model_dir, f'{filename}.csv') 77 | assert os.path.exists(csv_path), 'CSV file is missing in project directory.' 78 | with open(csv_path, 'a') as f: 79 | f.write('\n'+','.join(map(str, entries))) 80 | 81 | def save_loss(model_dir, name, loss_dict): 82 | save_dir = os.path.join(model_dir, "loss") 83 | os.makedirs(save_dir, exist_ok=True) 84 | file_path = os.path.join(save_dir, "{}.csv".format(name)) 85 | pd.DataFrame(loss_dict).to_csv(file_path) 86 | 87 | def save_features(model_dir, name, features, labels, layer=None): 88 | save_dir = os.path.join(model_dir, "features") 89 | os.makedirs(save_dir, exist_ok=True) 90 | np.save(os.path.join(save_dir, f"{name}_features.npy"), features) 91 | np.save(os.path.join(save_dir, f"{name}_labels.npy"), labels) 92 | 93 | def save_ckpt(model_dir, name, net): 94 | """Save PyTorch checkpoint to model_dir/checkpoints/ directory in model directory. """ 95 | os.makedirs(os.path.join(model_dir, 'checkpoints'), exist_ok=True) 96 | torch.save(net.state_dict(), os.path.join(model_dir, 'checkpoints', 97 | '{}.pt'.format(name))) 98 | 99 | def load_ckpt(model_dir, name, net, eval_=True): 100 | """Load checkpoint from model directory. Checkpoints should be stored in 101 | `model_dir/checkpoints/'. 
102 | """ 103 | ckpt_path = os.path.join(model_dir, 'checkpoints', f'{name}.pt') 104 | print('Loading checkpoint: {}'.format(ckpt_path)) 105 | state_dict = torch.load(ckpt_path) 106 | net.load_state_dict(state_dict) 107 | del state_dict 108 | if eval_: 109 | net.eval() 110 | return net 111 | 112 | --------------------------------------------------------------------------------
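A short usage sketch for the CSV helpers above (the directory name and metric values are hypothetical; note that create_csv assumes model_dir already exists, so it is created first):

import os
import utils
import plot

model_dir = './saved_models/demo'                    # hypothetical run directory
os.makedirs(model_dir, exist_ok=True)

# create_csv writes the header row once; append_csv adds one row per call
utils.create_csv(model_dir, 'acc', ['epoch', 'acc_train', 'acc_test'])
utils.append_csv(model_dir, 'acc', [0, 0.52, 0.49])
utils.append_csv(model_dir, 'acc', [1, 0.71, 0.68])

# plot_acc (defined in plot.py above) expects exactly these three columns;
# it reads model_dir/acc.csv and writes figures/acc/accuracy.png and .pdf
plot.plot_acc(model_dir)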