├── .gitignore
├── README.md
├── architectures
│   └── mnist
│       ├── flatten.py
│       └── lift2d.py
├── datasets
│   ├── __init__.py
│   ├── mnist.py
│   └── utils_data.py
├── evaluate.py
├── evaluate_forward.py
├── examples
│   ├── gaussian2d.ipynb
│   ├── gaussian3d.ipynb
│   └── utils_example.py
├── functional.py
├── images
│   └── arch-redunet.jpg
├── load.py
├── loss.py
├── plot.py
├── redunet
│   ├── __init__.py
│   ├── layers
│   │   ├── fourier1d.py
│   │   ├── fourier2d.py
│   │   ├── redulayer.py
│   │   └── vector.py
│   ├── modules.py
│   ├── multichannel_weight.py
│   ├── projections
│   │   └── lift.py
│   └── redunet.py
├── requirements.txt
├── train_forward.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | # Created by https://www.toptal.com/developers/gitignore/api/python,jupyternotebooks,vim,emacs,sublimetext,visualstudiocode
3 | # Edit at https://www.toptal.com/developers/gitignore?templates=python,jupyternotebooks,vim,emacs,sublimetext,visualstudiocode
4 |
5 | ### Emacs ###
6 | # -*- mode: gitignore; -*-
7 | *~
8 | \#*\#
9 | /.emacs.desktop
10 | /.emacs.desktop.lock
11 | *.elc
12 | auto-save-list
13 | tramp
14 | .\#*
15 |
16 | # Org-mode
17 | .org-id-locations
18 | *_archive
19 | ltximg/**
20 |
21 | # flymake-mode
22 | *_flymake.*
23 |
24 | # eshell files
25 | /eshell/history
26 | /eshell/lastdir
27 |
28 | # elpa packages
29 | /elpa/
30 |
31 | # reftex files
32 | *.rel
33 |
34 | # AUCTeX auto folder
35 | /auto/
36 |
37 | # cask packages
38 | .cask/
39 | dist/
40 |
41 | # Flycheck
42 | flycheck_*.el
43 |
44 | # server auth directory
45 | /server/
46 |
47 | # projectiles files
48 | .projectile
49 |
50 | # directory configuration
51 | .dir-locals.el
52 |
53 | # network security
54 | /network-security.data
55 |
56 |
57 | ### JupyterNotebooks ###
58 | # gitignore template for Jupyter Notebooks
59 | # website: http://jupyter.org/
60 |
61 | .ipynb_checkpoints
62 | */.ipynb_checkpoints/*
63 |
64 | # IPython
65 | profile_default/
66 | ipython_config.py
67 |
68 | # Remove previous ipynb_checkpoints
69 | # git rm -r .ipynb_checkpoints/
70 |
71 | ### Python ###
72 | # Byte-compiled / optimized / DLL files
73 | __pycache__/
74 | *.py[cod]
75 | *$py.class
76 |
77 | # C extensions
78 | *.so
79 |
80 | # Distribution / packaging
81 | .Python
82 | build/
83 | develop-eggs/
84 | downloads/
85 | eggs/
86 | .eggs/
87 | lib/
88 | lib64/
89 | parts/
90 | sdist/
91 | var/
92 | wheels/
93 | pip-wheel-metadata/
94 | share/python-wheels/
95 | *.egg-info/
96 | .installed.cfg
97 | *.egg
98 | MANIFEST
99 |
100 | # PyInstaller
101 | # Usually these files are written by a python script from a template
102 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
103 | *.manifest
104 | *.spec
105 |
106 | # Installer logs
107 | pip-log.txt
108 | pip-delete-this-directory.txt
109 |
110 | # Unit test / coverage reports
111 | htmlcov/
112 | .tox/
113 | .nox/
114 | .coverage
115 | .coverage.*
116 | .cache
117 | nosetests.xml
118 | coverage.xml
119 | *.cover
120 | *.py,cover
121 | .hypothesis/
122 | .pytest_cache/
123 | pytestdebug.log
124 |
125 | # Translations
126 | *.mo
127 | *.pot
128 |
129 | # Django stuff:
130 | *.log
131 | local_settings.py
132 | db.sqlite3
133 | db.sqlite3-journal
134 |
135 | # Flask stuff:
136 | instance/
137 | .webassets-cache
138 |
139 | # Scrapy stuff:
140 | .scrapy
141 |
142 | # Sphinx documentation
143 | docs/_build/
144 | doc/_build/
145 |
146 | # PyBuilder
147 | target/
148 |
149 | # Jupyter Notebook
150 |
151 | # IPython
152 |
153 | # pyenv
154 | .python-version
155 |
156 | # pipenv
157 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
158 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
159 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
160 | # install all needed dependencies.
161 | #Pipfile.lock
162 |
163 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
164 | __pypackages__/
165 |
166 | # Celery stuff
167 | celerybeat-schedule
168 | celerybeat.pid
169 |
170 | # SageMath parsed files
171 | *.sage.py
172 |
173 | # Environments
174 | .env
175 | .venv
176 | env/
177 | venv/
178 | ENV/
179 | env.bak/
180 | venv.bak/
181 | pythonenv*
182 |
183 | # Spyder project settings
184 | .spyderproject
185 | .spyproject
186 |
187 | # Rope project settings
188 | .ropeproject
189 |
190 | # mkdocs documentation
191 | /site
192 |
193 | # mypy
194 | .mypy_cache/
195 | .dmypy.json
196 | dmypy.json
197 |
198 | # Pyre type checker
199 | .pyre/
200 |
201 | # pytype static type analyzer
202 | .pytype/
203 |
204 | # profiling data
205 | .prof
206 |
207 | ### SublimeText ###
208 | # Cache files for Sublime Text
209 | *.tmlanguage.cache
210 | *.tmPreferences.cache
211 | *.stTheme.cache
212 |
213 | # Workspace files are user-specific
214 | *.sublime-workspace
215 |
216 | # Project files should be checked into the repository, unless a significant
217 | # proportion of contributors will probably not be using Sublime Text
218 | # *.sublime-project
219 |
220 | # SFTP configuration file
221 | sftp-config.json
222 |
223 | # Package control specific files
224 | Package Control.last-run
225 | Package Control.ca-list
226 | Package Control.ca-bundle
227 | Package Control.system-ca-bundle
228 | Package Control.cache/
229 | Package Control.ca-certs/
230 | Package Control.merged-ca-bundle
231 | Package Control.user-ca-bundle
232 | oscrypto-ca-bundle.crt
233 | bh_unicode_properties.cache
234 |
235 | # Sublime-github package stores a github token in this file
236 | # https://packagecontrol.io/packages/sublime-github
237 | GitHub.sublime-settings
238 |
239 | ### Vim ###
240 | # Swap
241 | [._]*.s[a-v][a-z]
242 | !*.svg # comment out if you don't need vector files
243 | [._]*.sw[a-p]
244 | [._]s[a-rt-v][a-z]
245 | [._]ss[a-gi-z]
246 | [._]sw[a-p]
247 |
248 | # Session
249 | Session.vim
250 | Sessionx.vim
251 |
252 | # Temporary
253 | .netrwhist
254 | # Auto-generated tag files
255 | tags
256 | # Persistent undo
257 | [._]*.un~
258 |
259 | ### VisualStudioCode ###
260 | .vscode/*
261 | !.vscode/tasks.json
262 | !.vscode/launch.json
263 | *.code-workspace
264 |
265 | ### VisualStudioCode Patch ###
266 | # Ignore all local history of files
267 | .history
268 | .ionide
269 |
270 | # End of https://www.toptal.com/developers/gitignore/api/python,jupyternotebooks,vim,emacs,sublimetext,visualstudiocode
271 |
272 |
273 | data/
274 | saved_models/
275 | transfer.sh
276 | transfer/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Deep Networks from the Principle of Rate Reduction
2 | This repository is the official PyTorch implementation of the paper [ReduNet: A White-box Deep Network from the Principle of Maximizing Rate Reduction](https://arxiv.org/abs/2105.10446) (2021),
3 |
4 | by [Kwan Ho Ryan Chan](https://ryanchankh.github.io)* (UC Berkeley), [Yaodong Yu](https://yaodongyu.github.io/)* (UC Berkeley), [Chong You](https://sites.google.com/view/cyou)* (UC Berkeley), [Haozhi Qi](https://haozhi.io/) (UC Berkeley), [John Wright](http://www.columbia.edu/~jw2966/) (Columbia University), and [Yi Ma](http://people.eecs.berkeley.edu/~yima/) (UC Berkeley).
5 |
6 |
7 |
8 | ## What is ReduNet?
9 | ReduNet is a deep neural network constructed naturally by deriving the gradients of the Maximal Coding Rate Reduction (MCR2) [1] objective. Every layer of this network can be interpreted in terms of its mathematical operations, and the network as a whole is trained in a purely feed-forward manner. In addition, by imposing shift-invariance on our network, the convolutional operators can be derived using only the data and the MCR2 objective function, making our network design principled and interpretable.
10 |
11 |
12 | 
13 | Figure: Weights and operations for one layer of ReduNet
14 |
15 |
16 |
17 | [1] Yu, Yaodong, Kwan Ho Ryan Chan, Chong You, Chaobing Song, and Yi Ma. "[Learning diverse and discriminative representations via the principle of maximal coding rate reduction](https://proceedings.neurips.cc/paper/2020/file/6ad4174eba19ecb5fed17411a34ff5e6-Paper.pdf)" Advances in Neural Information Processing Systems 33 (2020).
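   |
   | Concretely, the coding rate quantities that ReduNet maximizes, as implemented in `loss.py` (see `compute_loss_vec`), are
   |
   | $$\Delta R(Z, \Pi, \epsilon) = \underbrace{\frac{1}{2}\log\det\Big(I + \frac{d}{m\epsilon} Z^\top Z\Big)}_{R} - \underbrace{\sum_{j}\frac{m_j}{2m}\log\det\Big(I + \frac{d}{m_j\epsilon} Z_j^\top Z_j\Big)}_{R^c}$$
   |
   | where $Z \in \mathbb{R}^{m \times d}$ are the $m$ features, the membership $\Pi$ partitions them into the $m_j$ samples $Z_j$ of each class $j$, and each ReduNet layer performs one gradient ascent step on $\Delta R$.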
18 |
19 | ## Requirements
20 | This codebase is written for `python3`. To install necessary python packages, run `conda create --name redunet_official --file requirements.txt`.
21 |
22 | ## Demo
23 | For a quick demonstration of ReduNet on 2D or 3D Gaussian data, open one of the notebooks by running one of the two commands:
24 |
25 | ```
26 | $ jupyter notebook ./examples/gaussian2d.ipynb
27 | $ jupyter notebook ./examples/gaussian3d.ipynb
28 | ```
29 |
30 | ## Core Usage and Design
31 | The design of this repository aims to be easy to use and easy to integrate into your current experiment framework, as long as it uses PyTorch. The `ReduNet` object inherits from `nn.Sequential`, and the `ReduLayers`, such as `Vector`, `Fourier1D` and `Fourier2D`, inherit from `nn.Module`; a minimal usage sketch is shown below. Loss functions are implemented in `loss.py`. Architecture and dataset options are located in `load.py`. Data objects and pre-set architectures are loaded from the folders `datasets` and `architectures`. Feel free to add more based on the experiments you want to run. We have provided basic experiment setups, located in `train_<mode>.py` and `evaluate_<mode>.py`, where `<mode>` is the type of experiment. For utility functions, please check out `functional.py` or `utils.py`. Feel free to email us if there are any issues or suggestions.
32 |
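   | A minimal sketch of this composition (mirroring `architectures/mnist/flatten.py` and the Gaussian notebooks; the layer count and the `X_train`/`y_train`/`X_test` tensors are illustrative placeholders):
   |
   | ```
   | from redunet import ReduNet, Vector
   |
   | # Stack fully-connected (Vector) ReduLayers for 10 classes of
   | # 784-dimensional inputs, as in architectures/mnist/flatten.py.
   | net = ReduNet(*[Vector(eta=0.5, eps=0.1, lmbda=500,
   |                        num_classes=10, dimensions=784)
   |                 for _ in range(20)])
   |
   | # Forward construction: build the layer weights from labeled data,
   | # then embed new samples with an ordinary forward pass.
   | # Z_train = net.init(X_train, y_train)
   | # Z_test = net(X_test).detach()
   | ```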
33 |
34 | ## Example: Forward Construction
35 | To train a ReduNet using forward construction, please check out `train_forward.py`. For evaluation, please check out `evaluate_forward.py`. For example, to train a 50-layer ReduNet on MNIST using 1000 samples per class, run:
36 |
37 | ```
38 | $ python3 train_forward.py --data mnistvector --arch layers50 --samples 1000
39 | ```
40 | After training, you can evaluate the trained model using `evaluate_forward.py`, by running:
41 |
42 | ```
43 | $ python3 evaluate_forward.py --model_dir ./saved_models/forward/mnistvector+layers50/samples1000
44 | ```
45 | This will evaluate using all available training and testing samples. For more training and testing options, please check out `train_forward.py` and `evaluate_forward.py`.
46 |
47 | ### Experiments in Paper
48 | For code used to generate experimental empirical results listed in our paper, please visit our other repository: [https://github.com/ryanchankh/redunet_paper](https://github.com/ryanchankh/redunet_paper)
49 |
50 | ## Reference
51 | For technical details and full experimental results, please check the [paper](https://arxiv.org/abs/2105.10446). Please consider citing our work if you find it helpful to yours:
52 |
53 | ```
54 | @article{chan2021redunet,
55 | title={ReduNet: A White-box Deep Network from the Principle of Maximizing Rate Reduction},
56 | author={Chan, Kwan Ho Ryan and Yu, Yaodong and You, Chong and Qi, Haozhi and Wright, John and Ma, Yi},
57 | journal={arXiv preprint arXiv:2105.10446},
58 | year={2021}
59 | }
60 | ```
61 |
62 | ## License and Contributing
63 | - This README is formatted based on [paperswithcode](https://github.com/paperswithcode/releasing-research-code).
64 | - Feel free to post issues via Github.
65 |
66 | ## Contact
67 | Please contact [ryanchankh@berkeley.edu](mailto:ryanchankh@berkeley.edu) and [yyu@eecs.berkeley.edu](mailto:yyu@eecs.berkeley.edu) if you have any questions about the code.
68 |
--------------------------------------------------------------------------------
/architectures/mnist/flatten.py:
--------------------------------------------------------------------------------
1 | from redunet import *
2 |
3 |
4 |
5 | def flatten(layers, num_classes):
6 | net = ReduNet(
7 | *[Vector(eta=0.5,
8 | eps=0.1,
9 | lmbda=500,
10 | num_classes=num_classes,
11 | dimensions=784
12 | ) for _ in range(layers)],
13 | )
14 | return net
--------------------------------------------------------------------------------
/architectures/mnist/lift2d.py:
--------------------------------------------------------------------------------
1 | from redunet import *
2 |
3 |
4 |
5 | def lift2d(channels, layers, num_classes, seed=0):
6 | net = ReduNet(
7 | Lift2D(1, channels, 9, seed=seed),
8 | *[Fourier2D(eta=0.5,
9 | eps=0.1,
10 | lmbda=500,
11 | num_classes=num_classes,
12 | dimensions=(channels, 28, 28)
13 | ) for _ in range(layers)],
14 | )
15 | return net
16 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ma-Lab-Berkeley/ReduNet/f67a348e5b54b9e60b783ceec6fe0e379bc3b96f/datasets/__init__.py
--------------------------------------------------------------------------------
/datasets/mnist.py:
--------------------------------------------------------------------------------
1 | import torchvision.datasets as datasets
2 | import torchvision.transforms as transforms
3 | from torch.utils.data import DataLoader
4 | from .utils_data import filter_class
5 |
6 |
7 |
8 |
9 |
10 |
11 | def mnist2d_10class(data_dir):
12 | transform = transforms.Compose([
13 | transforms.ToTensor(),
14 | ])
15 | trainset = datasets.MNIST(data_dir, train=True, transform=transform, download=True)
16 | testset = datasets.MNIST(data_dir, train=False, transform=transform, download=True)
17 | num_classes = 10
18 | return trainset, testset, num_classes
19 |
20 | def mnist2d_5class(data_dir):
21 | transform = transforms.Compose([
22 | transforms.ToTensor(),
23 | ])
24 | trainset = datasets.MNIST(data_dir, train=True, transform=transform, download=True)
25 | testset = datasets.MNIST(data_dir, train=False, transform=transform, download=True)
26 | trainset, num_classes = filter_class(trainset, [0, 1, 2, 3, 4])
27 | testset, _ = filter_class(testset, [0, 1, 2, 3, 4])
28 | num_classes = 5
29 | return trainset, testset, num_classes
30 |
31 | def mnist2d_2class(data_dir):
32 | transform = transforms.Compose([
33 | transforms.ToTensor(),
34 | ])
35 | trainset = datasets.MNIST(data_dir, train=True, transform=transform, download=True)
36 | testset = datasets.MNIST(data_dir, train=False, transform=transform, download=True)
37 | trainset, num_classes = filter_class(trainset, [0, 1])
38 | testset, _ = filter_class(testset, [0, 1])
39 | return trainset, testset, num_classes
40 |
41 | def mnistvector_10class(data_dir):
42 | transform = transforms.Compose([
43 | transforms.ToTensor(),
44 | transforms.Lambda(lambda x: x.flatten())
45 | ])
46 | trainset = datasets.MNIST(data_dir, train=True, transform=transform, download=True)
47 | testset = datasets.MNIST(data_dir, train=False, transform=transform, download=True)
48 | num_classes = 10
49 | return trainset, testset, num_classes
50 |
51 | def mnistvector_5class(data_dir):
52 | transform = transforms.Compose([
53 | transforms.ToTensor(),
54 | transforms.Lambda(lambda x: x.flatten())
55 | ])
56 | trainset = datasets.MNIST(data_dir, train=True, transform=transform, download=True)
57 | testset = datasets.MNIST(data_dir, train=False, transform=transform, download=True)
58 | trainset, num_classes = filter_class(trainset, [0, 1, 2, 3, 4])
59 | testset, _ = filter_class(testset, [0, 1, 2, 3, 4])
60 | return trainset, testset, num_classes
61 |
62 | def mnistvector_2class(data_dir):
63 | transform = transforms.Compose([
64 | transforms.ToTensor(),
65 | transforms.Lambda(lambda x: x.flatten())
66 | ])
67 | trainset = datasets.MNIST(data_dir, train=True, transform=transform, download=True)
68 | testset = datasets.MNIST(data_dir, train=False, transform=transform, download=True)
69 | trainset, num_classes = filter_class(trainset, [0, 1])
70 | testset, _ = filter_class(testset, [0, 1])
71 | return trainset, testset, num_classes
72 |
73 |
74 | if __name__ == '__main__':
75 | trainset, testset, num_classes = mnist2d_2class('./data/')
76 | trainloader = DataLoader(trainset, batch_size=trainset.data.shape[0])
77 | print(trainset)
78 | print(testset)
79 | print(num_classes)
80 |
81 | batch_imgs, batch_lbls = next(iter(trainloader))
82 | print(batch_imgs.shape, batch_lbls.shape)
83 | print(batch_lbls.unique(return_counts=True))
84 |
--------------------------------------------------------------------------------
/datasets/utils_data.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 |
6 | def filter_class(dataset, classes):
7 | data, labels = dataset.data, dataset.targets
8 | if type(labels) == list:
9 | labels = torch.tensor(labels)
10 | data_filter = []
11 | labels_filter = []
12 | for _class in classes:
13 | idx = labels == _class
14 | data_filter.append(data[idx])
15 | labels_filter.append(labels[idx])
16 | if type(dataset.data) == np.ndarray:
17 | dataset.data = np.vstack(data_filter)
18 | dataset.targets = np.hstack(labels_filter)
19 | elif type(dataset.data) == torch.Tensor:
20 | dataset.data = torch.cat(data_filter)
21 | dataset.targets = torch.cat(labels_filter)
22 | else:
23 | raise TypeError('dataset.data type neither np.ndarray nor torch.Tensor')
24 | return dataset, len(classes)
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.stats as sps
3 | import torch
4 |
5 | from sklearn.svm import LinearSVC
6 | from sklearn.decomposition import PCA
7 | from sklearn.decomposition import TruncatedSVD
8 | from sklearn.linear_model import SGDClassifier
9 | from sklearn.svm import LinearSVC, SVC
10 | from sklearn.tree import DecisionTreeClassifier
11 | from sklearn.ensemble import RandomForestClassifier
12 |
13 | import functional as F
14 | import utils
15 |
16 |
17 |
18 | def evaluate(eval_dir, method, train_features, train_labels, test_features, test_labels, **kwargs):
19 | if method == 'svm':
20 | acc_train, acc_test = svm(train_features, train_labels, test_features, test_labels)
21 | elif method == 'knn':
22 | acc_train, acc_test = knn(train_features, train_labels, test_features, test_labels, **kwargs)
23 | elif method == 'nearsub':
24 | acc_train, acc_test = nearsub(train_features, train_labels, test_features, test_labels, **kwargs)
25 |     elif method == 'nearsub_pca':
26 |         acc_train, acc_test = None, nearsub_pca(train_features, train_labels, test_features, test_labels, **kwargs)
27 | acc_dict = {'train': acc_train, 'test': acc_test}
28 | utils.save_params(eval_dir, acc_dict, name=f'acc_{method}')
29 |
30 | def svm(train_features, train_labels, test_features, test_labels):
31 | svm = LinearSVC(verbose=0, random_state=10)
32 | svm.fit(train_features, train_labels)
33 | acc_train = svm.score(train_features, train_labels)
34 | acc_test = svm.score(test_features, test_labels)
35 | print("SVM: {}, {}".format(acc_train, acc_test))
36 | return acc_train, acc_test
37 |
38 | # def knn(train_features, train_labels, test_features, test_labels, k=5):
39 | # sim_mat = train_features @ train_features.T
40 | # topk = torch.from_numpy(sim_mat).topk(k=k, dim=0)
41 | # topk_pred = train_labels[topk.indices]
42 | # test_pred = torch.tensor(topk_pred).mode(0).values.detach()
43 | # acc_train = compute_accuracy(test_pred.numpy(), train_labels)
44 |
45 | # sim_mat = train_features @ test_features.T
46 | # topk = torch.from_numpy(sim_mat).topk(k=k, dim=0)
47 | # topk_pred = train_labels[topk.indices]
48 | # test_pred = torch.tensor(topk_pred).mode(0).values.detach()
49 | # acc_test = compute_accuracy(test_pred.numpy(), test_labels)
50 | # print("kNN: {}, {}".format(acc_train, acc_test))
51 | # return acc_train, acc_test
52 |
53 | def knn(train_features, train_labels, test_features, test_labels, k=5):
54 | sim_mat = train_features @ train_features.T
55 | topk = sim_mat.topk(k=k, dim=0)
56 | topk_pred = train_labels[topk.indices]
57 | test_pred = topk_pred.mode(0).values.detach()
58 | acc_train = compute_accuracy(test_pred, train_labels)
59 |
60 | sim_mat = train_features @ test_features.T
61 | topk = sim_mat.topk(k=k, dim=0)
62 | topk_pred = train_labels[topk.indices]
63 | test_pred = topk_pred.mode(0).values.detach()
64 | acc_test = compute_accuracy(test_pred, test_labels)
65 | print("kNN: {}, {}".format(acc_train, acc_test))
66 | return acc_train, acc_test
67 |
68 | # # TODO: 1. implement pytorch version 2. suport batches
69 | # def nearsub(train_features, train_labels, test_features, test_labels, num_classes, n_comp=10, return_pred=False):
70 | # train_scores, test_scores = [], []
71 | # classes = np.arange(num_classes)
72 | # features_sort, _ = utils.sort_dataset(train_features, train_labels,
73 | # classes=classes, stack=False)
74 | # fd = features_sort[0].shape[1]
75 | # if n_comp >= fd:
76 | # n_comp = fd - 1
77 | # for j in classes:
78 | # svd = TruncatedSVD(n_components=n_comp).fit(features_sort[j])
79 | # subspace_j = np.eye(fd) - svd.components_.T @ svd.components_
80 | # train_j = subspace_j @ train_features.T
81 | # test_j = subspace_j @ test_features.T
82 | # train_scores_j = np.linalg.norm(train_j, ord=2, axis=0)
83 | # test_scores_j = np.linalg.norm(test_j, ord=2, axis=0)
84 | # train_scores.append(train_scores_j)
85 | # test_scores.append(test_scores_j)
86 | # train_pred = np.argmin(train_scores, axis=0)
87 | # test_pred = np.argmin(test_scores, axis=0)
88 | # if return_pred:
89 | # return train_pred.tolist(), test_pred.tolist()
90 | # train_acc = compute_accuracy(classes[train_pred], train_labels)
91 | # test_acc = compute_accuracy(classes[test_pred], test_labels)
92 | # print('SVD: {}, {}'.format(train_acc, test_acc))
93 | # return train_acc, test_acc
94 |
95 | def nearsub(train_features, train_labels, test_features, test_labels,
96 | num_classes, n_comp=10, return_pred=False):
97 | train_scores, test_scores = [], []
98 | classes = np.arange(num_classes)
99 | features_sort, _ = utils.sort_dataset(train_features, train_labels,
100 | classes=classes, stack=False)
101 | fd = features_sort[0].shape[1]
102 | for j in classes:
103 | _, _, V = torch.svd(features_sort[j])
104 | components = V[:, :n_comp].T
105 | subspace_j = torch.eye(fd) - components.T @ components
106 | train_j = subspace_j @ train_features.T
107 | test_j = subspace_j @ test_features.T
108 |         train_scores_j = torch.linalg.norm(train_j, ord=2, dim=0)
109 |         test_scores_j = torch.linalg.norm(test_j, ord=2, dim=0)
110 | train_scores.append(train_scores_j)
111 | test_scores.append(test_scores_j)
112 | train_pred = torch.stack(train_scores).argmin(0)
113 | test_pred = torch.stack(test_scores).argmin(0)
114 | if return_pred:
115 | return train_pred.numpy(), test_pred.numpy()
116 | train_acc = compute_accuracy(classes[train_pred], train_labels.numpy())
117 | test_acc = compute_accuracy(classes[test_pred], test_labels.numpy())
118 | print('SVD: {}, {}'.format(train_acc, test_acc))
119 | return train_acc, test_acc
120 |
121 | def nearsub_pca(train_features, train_labels, test_features, test_labels, num_classes, n_comp=10):
122 | scores_pca = []
123 | classes = np.arange(num_classes)
124 | features_sort, _ = utils.sort_dataset(train_features, train_labels, classes=classes, stack=False)
125 | fd = features_sort[0].shape[1]
126 | if n_comp >= fd:
127 | n_comp = fd - 1
128 | for j in np.arange(len(classes)):
129 | pca = PCA(n_components=n_comp).fit(features_sort[j])
130 | pca_subspace = pca.components_.T
131 | mean = np.mean(features_sort[j], axis=0)
132 | pca_j = (np.eye(fd) - pca_subspace @ pca_subspace.T) \
133 | @ (test_features - mean).T
134 | score_pca_j = np.linalg.norm(pca_j, ord=2, axis=0)
135 | scores_pca.append(score_pca_j)
136 | test_predict_pca = np.argmin(scores_pca, axis=0)
137 | acc_pca = compute_accuracy(classes[test_predict_pca], test_labels)
138 | print('PCA: {}'.format(acc_pca))
139 | return acc_pca
140 |
141 | def argmax(train_features, train_labels, test_features, test_labels):
142 | train_pred = train_features.argmax(1)
143 | train_acc = compute_accuracy(train_pred, train_labels)
144 | test_pred = test_features.argmax(1)
145 | test_acc = compute_accuracy(test_pred, test_labels)
146 | return train_acc, test_acc
147 |
148 | def compute_accuracy(y_pred, y_true):
149 | """Compute accuracy by counting correct classification. """
150 | assert y_pred.shape == y_true.shape
151 | if type(y_pred) == torch.Tensor:
152 | n_wrong = torch.count_nonzero(y_pred - y_true).item()
153 | elif type(y_pred) == np.ndarray:
154 | n_wrong = np.count_nonzero(y_pred - y_true)
155 | else:
156 | raise TypeError("Not Tensor nor Array type.")
157 | n_samples = len(y_pred)
158 | return 1 - n_wrong / n_samples
159 |
160 | def baseline(train_features, train_labels, test_features, test_labels):
161 | test_models = {'log_l2': SGDClassifier(loss='log', max_iter=10000, random_state=42),
162 | 'SVM_linear': LinearSVC(max_iter=10000, random_state=42),
163 | 'SVM_RBF': SVC(kernel='rbf', random_state=42),
164 | 'DecisionTree': DecisionTreeClassifier(),
165 |                    'RandomForest': RandomForestClassifier()}
166 | for model_name in test_models:
167 | test_model = test_models[model_name]
168 | test_model.fit(train_features, train_labels)
169 | score = test_model.score(test_features, test_labels)
170 | print(f"{model_name}: {score}")
171 |
172 | def majority_vote(pred, true):
173 | pred_majority = sps.mode(pred, axis=0)[0].squeeze()
174 | return compute_accuracy(pred_majority, true)
175 |
--------------------------------------------------------------------------------
/evaluate_forward.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 | from redunet import *
8 | import evaluate
9 | import functional as F
10 | import load as L
11 | import utils
12 | import plot
13 |
14 |
15 |
16 | parser = argparse.ArgumentParser()
17 | parser.add_argument('--model_dir', type=str, help='model directory')
18 | parser.add_argument('--loss', default=False, action='store_true', help='set to True to plot loss')
19 | parser.add_argument('--trainsamples', type=int, default=None, help="number of train samples in each class")
20 | parser.add_argument('--testsamples', type=int, default=None, help="number of test samples in each class")
21 | parser.add_argument('--translatetrain', default=False, action='store_true', help='set to True to translate train samples')
22 | parser.add_argument('--translatetest', default=False, action='store_true', help='set to True to translate test samples')
23 | parser.add_argument('--batch_size', type=int, default=100, help='batch size for evaluation')
24 | args = parser.parse_args()
25 |
26 | ## CUDA
27 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
28 |
29 | ## Setup
30 | eval_dir = os.path.join(args.model_dir,
31 | f'trainsamples{args.trainsamples}'
32 | f'_testsamples{args.testsamples}'
33 | f'_translatetrain{args.translatetrain}'
34 | f'_translatetest{args.translatetest}')
35 | params = utils.load_params(args.model_dir)
36 |
37 | ## Data
38 | trainset, testset, num_classes = L.load_dataset(params['data'], data_dir=params['data_dir'])
39 | X_train, y_train = F.get_samples(trainset, args.trainsamples)
40 | X_test, y_test = F.get_samples(testset, args.testsamples)
41 | if args.translatetrain:
42 | X_train, y_train = F.translate(X_train, y_train, stride=7)
43 | if args.translatetest:
44 | X_test, y_test = F.translate(X_test, y_test, stride=7)
45 | X_train, y_train = X_train.to(device), y_train.to(device)
46 | X_test, y_test = X_test.to(device), y_test.to(device)
47 |
48 | ## Architecture
49 | net = L.load_architecture(params['data'], params['arch'])
50 | net = utils.load_ckpt(args.model_dir, 'model', net)
51 | net = net.to(device)
52 |
53 | ## Forward
54 | with torch.no_grad():
55 | print('train')
56 | Z_train = net.batch_forward(X_train, batch_size=args.batch_size, loss=args.loss, device=device)
57 | X_train, y_train, Z_train = F.to_cpu(X_train, y_train, Z_train)
58 | utils.save_loss(eval_dir, f'train', net.get_loss())
59 |
60 | print('test')
61 | Z_test = net.batch_forward(X_test, batch_size=args.batch_size, loss=args.loss, device=device)
62 | X_test, y_test, Z_test = F.to_cpu(X_test, y_test, Z_test)
63 | utils.save_loss(eval_dir, f'test', net.get_loss())
64 |
65 | ## Normalize
66 | X_train = F.normalize(X_train.flatten(1))
67 | X_test = F.normalize(X_test.flatten(1))
68 | Z_train = F.normalize(Z_train.flatten(1))
69 | Z_test = F.normalize(Z_test.flatten(1))
70 |
71 | # Evaluate
72 | evaluate.evaluate(eval_dir, 'knn', Z_train, y_train, Z_test, y_test)
73 | #evaluate.evaluate(eval_dir, 'nearsub', Z_train, y_train, Z_test, y_test, num_classes=num_classes, n_comp=10)
74 |
75 | # Plot
76 | plot.plot_loss_mcr(eval_dir, 'train')
77 | plot.plot_loss_mcr(eval_dir, 'test')
78 | plot.plot_heatmap(eval_dir, 'X_train', X_train, y_train, num_classes)
79 | plot.plot_heatmap(eval_dir, 'X_test', X_test, y_test, num_classes)
80 | plot.plot_heatmap(eval_dir, 'Z_train', Z_train, y_train, num_classes)
81 | plot.plot_heatmap(eval_dir, 'Z_test', Z_test, y_test, num_classes)
82 |
83 |
--------------------------------------------------------------------------------
/examples/gaussian2d.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import os\n",
10 | "os.chdir('../')"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": null,
16 | "metadata": {},
17 | "outputs": [],
18 | "source": [
19 | "from redunet import ReduNetVector\n",
20 | "import utils_example as ue\n",
21 | "import plot"
22 | ]
23 | },
24 | {
25 | "cell_type": "markdown",
26 | "metadata": {},
27 | "source": [
28 | "## Hyperparameters"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": null,
34 | "metadata": {},
35 | "outputs": [],
36 | "source": [
37 | "## Data\n",
38 | "dataset = 1 # can be 1 to 8\n",
39 | "train_noise = 0.1\n",
40 | "test_noise = 0.1\n",
41 | "train_samples = 100\n",
42 | "test_samples = 100\n",
43 | "\n",
44 | "## Model\n",
45 | "num_layers = 200 # number of redunet layers\n",
46 | "eta = 0.5\n",
47 | "eps = 0.1\n",
48 | "lmbda = 200"
49 | ]
50 | },
51 | {
52 | "cell_type": "markdown",
53 | "metadata": {},
54 | "source": [
55 | "## Data"
56 | ]
57 | },
58 | {
59 | "cell_type": "code",
60 | "execution_count": null,
61 | "metadata": {
62 | "scrolled": true
63 | },
64 | "outputs": [],
65 | "source": [
66 | "X_train, y_train, num_classes = ue.generate_2d(dataset, train_noise, train_samples) # train\n",
67 | "X_test, y_test, num_classes = ue.generate_2d(dataset, test_noise, test_samples) # test"
68 | ]
69 | },
70 | {
71 | "cell_type": "markdown",
72 | "metadata": {},
73 | "source": [
74 | "## Model"
75 | ]
76 | },
77 | {
78 | "cell_type": "code",
79 | "execution_count": null,
80 | "metadata": {},
81 | "outputs": [],
82 | "source": [
83 | "net = ReduNetVector(num_classes, num_layers, X_train.shape[1], eta=eta, eps=eps, lmbda=lmbda)\n",
84 | "Z_train = net.init(X_train, y_train)"
85 | ]
86 | },
87 | {
88 | "cell_type": "code",
89 | "execution_count": null,
90 | "metadata": {},
91 | "outputs": [],
92 | "source": [
93 | "ue.plot_loss_mcr(net.get_loss())\n",
94 | "ue.plot_2d(X_train, y_train, Z_train) "
95 | ]
96 | },
97 | {
98 | "cell_type": "code",
99 | "execution_count": null,
100 | "metadata": {
101 | "scrolled": true
102 | },
103 | "outputs": [],
104 | "source": [
105 | "Z_test = net(X_test).detach()\n",
106 | "ue.plot_2d(X_test, y_test, Z_test)"
107 | ]
108 | }
109 | ],
110 | "metadata": {
111 | "kernelspec": {
112 | "display_name": "redunet",
113 | "language": "python",
114 | "name": "redunet"
115 | },
116 | "language_info": {
117 | "codemirror_mode": {
118 | "name": "ipython",
119 | "version": 3
120 | },
121 | "file_extension": ".py",
122 | "mimetype": "text/x-python",
123 | "name": "python",
124 | "nbconvert_exporter": "python",
125 | "pygments_lexer": "ipython3",
126 | "version": "3.7.9"
127 | },
128 | "latex_envs": {
129 | "LaTeX_envs_menu_present": true,
130 | "autoclose": false,
131 | "autocomplete": true,
132 | "bibliofile": "biblio.bib",
133 | "cite_by": "apalike",
134 | "current_citInitial": 1,
135 | "eqLabelWithNumbers": true,
136 | "eqNumInitial": 1,
137 | "hotkeys": {
138 | "equation": "Ctrl-E",
139 | "itemize": "Ctrl-I"
140 | },
141 | "labels_anchors": false,
142 | "latex_user_defs": false,
143 | "report_style_numbering": false,
144 | "user_envs_cfg": false
145 | }
146 | },
147 | "nbformat": 4,
148 | "nbformat_minor": 2
149 | }
150 |
--------------------------------------------------------------------------------
/examples/gaussian3d.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import os\n",
10 | "os.chdir('../')"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": null,
16 | "metadata": {},
17 | "outputs": [],
18 | "source": [
19 | "from redunet import ReduNetVector\n",
20 | "import utils_example as ue\n",
21 | "import plot"
22 | ]
23 | },
24 | {
25 | "cell_type": "markdown",
26 | "metadata": {},
27 | "source": [
28 | "## Hyperparameters"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": null,
34 | "metadata": {},
35 | "outputs": [],
36 | "source": [
37 | "## Data\n",
38 | "dataset = 1 # can be 1 to 3\n",
39 | "train_noise = 0.1\n",
40 | "test_noise = 0.1\n",
41 | "train_samples = 100\n",
42 | "test_samples = 100\n",
43 | "\n",
44 | "## Model\n",
45 | "num_layers = 200 # number of redunet layers\n",
46 | "eta = 0.5\n",
47 | "eps = 0.1\n",
48 | "lmbda = 200"
49 | ]
50 | },
51 | {
52 | "cell_type": "markdown",
53 | "metadata": {},
54 | "source": [
55 | "## Data"
56 | ]
57 | },
58 | {
59 | "cell_type": "code",
60 | "execution_count": null,
61 | "metadata": {
62 | "scrolled": true
63 | },
64 | "outputs": [],
65 | "source": [
66 | "X_train, y_train, num_classes = ue.generate_3d(dataset, train_noise, train_samples) # train\n",
67 | "X_test, y_test, num_classes = ue.generate_3d(dataset, test_noise, test_samples) # test"
68 | ]
69 | },
70 | {
71 | "cell_type": "markdown",
72 | "metadata": {},
73 | "source": [
74 | "## Model"
75 | ]
76 | },
77 | {
78 | "cell_type": "code",
79 | "execution_count": null,
80 | "metadata": {},
81 | "outputs": [],
82 | "source": [
83 | "net = ReduNetVector(num_classes, num_layers, X_train.shape[1], eta=eta, eps=eps, lmbda=lmbda)\n",
84 | "Z_train = net.init(X_train, y_train)"
85 | ]
86 | },
87 | {
88 | "cell_type": "code",
89 | "execution_count": null,
90 | "metadata": {
91 | "scrolled": false
92 | },
93 | "outputs": [],
94 | "source": [
95 | "ue.plot_loss_mcr(net.get_loss())\n",
96 | "ue.plot_3d(X_train, y_train, 'X_train') \n",
97 | "ue.plot_3d(Z_train, y_train, 'Z_train') "
98 | ]
99 | },
100 | {
101 | "cell_type": "code",
102 | "execution_count": null,
103 | "metadata": {
104 | "scrolled": false
105 | },
106 | "outputs": [],
107 | "source": [
108 | "Z_test = net(X_test).detach()\n",
109 | "ue.plot_3d(X_test, y_test, 'X_test')\n",
110 | "ue.plot_3d(Z_test, y_test, 'Z_test')"
111 | ]
112 | },
113 | {
114 | "cell_type": "code",
115 | "execution_count": null,
116 | "metadata": {},
117 | "outputs": [],
118 | "source": []
119 | }
120 | ],
121 | "metadata": {
122 | "kernelspec": {
123 | "display_name": "redunet",
124 | "language": "python",
125 | "name": "redunet"
126 | },
127 | "language_info": {
128 | "codemirror_mode": {
129 | "name": "ipython",
130 | "version": 3
131 | },
132 | "file_extension": ".py",
133 | "mimetype": "text/x-python",
134 | "name": "python",
135 | "nbconvert_exporter": "python",
136 | "pygments_lexer": "ipython3",
137 | "version": "3.7.9"
138 | },
139 | "latex_envs": {
140 | "LaTeX_envs_menu_present": true,
141 | "autoclose": false,
142 | "autocomplete": true,
143 | "bibliofile": "biblio.bib",
144 | "cite_by": "apalike",
145 | "current_citInitial": 1,
146 | "eqLabelWithNumbers": true,
147 | "eqNumInitial": 1,
148 | "hotkeys": {
149 | "equation": "Ctrl-E",
150 | "itemize": "Ctrl-I"
151 | },
152 | "labels_anchors": false,
153 | "latex_user_defs": false,
154 | "report_style_numbering": false,
155 | "user_envs_cfg": false
156 | }
157 | },
158 | "nbformat": 4,
159 | "nbformat_minor": 2
160 | }
161 |
--------------------------------------------------------------------------------
/examples/utils_example.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import matplotlib.pyplot as plt
4 |
5 | def generate_2d(data, noise, samples, shuffle=False):
6 | if data == 1:
7 | centers = [(1, 0), (0, 1)]
8 | elif data == 2:
9 | centers = [(np.cos(np.pi/3), np.sin(np.pi/3)), (1 ,0)]
10 | elif data == 3:
11 | centers = [(np.cos(np.pi/4), np.sin(np.pi/4)), (1 ,0)]
12 | elif data == 4:
13 | centers = [(np.cos(3*np.pi/4), np.sin(3*np.pi/4)), (1 ,0)]
14 | elif data == 5:
15 | centers = [(np.cos(2*np.pi/3), np.sin(2*np.pi/3)), (1 ,0)]
16 | elif data == 6:
17 | centers = [(np.cos(3*np.pi/4), np.sin(3*np.pi/4)), (np.cos(4*np.pi/3), np.sin(4*np.pi/3)), (1 ,0)]
18 | elif data == 7:
19 | centers = [(np.cos(3*np.pi/4), np.sin(3*np.pi/4)),
20 | (np.cos(4*np.pi/3), np.sin(4*np.pi/3)),
21 | (np.cos(np.pi/4), np.sin(np.pi/4))]
22 | elif data == 8:
23 | centers = [(np.cos(np.pi/6), np.sin(np.pi/6)),
24 | (np.cos(np.pi/2), np.sin(np.pi/2)),
25 | (np.cos(3*np.pi/4), np.sin(3*np.pi/4)),
26 | (np.cos(5*np.pi/4), np.sin(5*np.pi/4)),
27 | (np.cos(7*np.pi/4), np.sin(7*np.pi/4)),
28 | (np.cos(3*np.pi/2), np.sin(3*np.pi/2))]
29 | else:
30 | raise NameError('data not found.')
31 |
32 | data = []
33 | targets = []
34 | for lbl, center in enumerate(centers):
35 | X = np.random.normal(loc=center, scale=noise, size=(samples, 2))
36 | y = np.repeat(lbl, samples).tolist()
37 | data.append(X)
38 | targets += y
39 | data = np.concatenate(data)
40 | data = data / np.linalg.norm(data, axis=1, ord=2, keepdims=True)
41 | targets = np.array(targets)
42 |
43 | if shuffle:
44 | idx_arr = np.random.choice(np.arange(len(data)), len(data), replace=False)
45 | data, targets = data[idx_arr], targets[idx_arr]
46 |
47 | data = torch.tensor(data).float()
48 | targets = torch.tensor(targets).long()
49 | return data, targets, len(centers)
50 |
51 | def generate_3d(data, noise, samples, shuffle=False):
52 | if data == 1:
53 | centers = [(1, 0, 0),
54 | (0, 1, 0),
55 | (0, 0, 1)]
56 | elif data == 2:
57 | centers = [(np.cos(np.pi/4), np.sin(np.pi/4), 1),
58 | (np.cos(2*np.pi/3), np.sin(2*np.pi/3), 1),
59 | (np.cos(np.pi), np.sin(np.pi), 1)]
60 | elif data == 3:
61 | centers = [(np.cos(np.pi/4), np.sin(np.pi/4), 1),
62 | (np.cos(2*np.pi/3), np.sin(2*np.pi/3), 1),
63 |                    (np.cos(5*np.pi/6), np.sin(5*np.pi/6), 1)]
64 | else:
65 | raise NameError('Data not found.')
66 |
67 | X, Y = [], []
68 | for c, center in enumerate(centers):
69 | _X = np.random.normal(center, scale=(noise, noise, noise), size=(samples, 3))
70 | _Y = np.ones(samples, dtype=np.int32) * c
71 | X.append(_X)
72 | Y.append(_Y)
73 | X = np.vstack(X)
74 | X = X / np.linalg.norm(X, axis=1, ord=2, keepdims=True)
75 | Y = np.hstack(Y)
76 |
77 | if shuffle:
78 | idx_arr = np.random.choice(np.arange(len(X)), len(X), replace=False)
79 | X, Y = X[idx_arr], Y[idx_arr]
80 |
81 | X = torch.tensor(X).float()
82 | Y = torch.tensor(Y).long()
83 | return X, Y, 3
84 |
85 |
86 | def plot_2d(inputs, labels, outputs):
87 | fig, ax = plt.subplots(ncols=2, figsize=(8, 4))
88 |     ax[0].scatter(inputs[:, 0], inputs[:, 1], c=labels)
90 | ax[0].set_ylim([-1.1, 1.1])
91 | ax[0].set_xlim([-1.1, 1.1])
92 | ax[0].set_title('X')
93 | ax[1].scatter(outputs[:, 0], outputs[:, 1], c=labels)
94 | ax[1].set_ylim([-1.1, 1.1])
95 | ax[1].set_xlim([-1.1, 1.1])
96 | ax[1].set_title('Z')
97 | fig.tight_layout()
98 | plt.show()
99 | plt.close()
100 |
101 |
102 | def plot_3d(Z, y, title=''):
103 |     colors = np.array(['forestgreen', 'royalblue', 'brown'])
105 | fig = plt.figure(figsize=(5, 5))
106 | ax = fig.add_subplot(111, projection='3d')
107 | ax.scatter(Z[:, 0], Z[:, 1], Z[:, 2], c=colors[y], cmap=plt.cm.Spectral, s=200.0)
108 | # Z, _ = F.get_n_each(Z, y, 1)
109 | # for c in np.unique(y):
110 | # ax.quiver(0.0, 0.0, 0.0, Z[c, 0], Z[c, 1], Z[c, 2], length=1.0, normalize=True, arrow_length_ratio=0.05, color='black')
111 | u, v = np.mgrid[0:2*np.pi:20j, 0:np.pi:10j]
112 | x = np.cos(u)*np.sin(v)
113 | y = np.sin(u)*np.sin(v)
114 | z = np.cos(v)
115 | ax.plot_wireframe(x, y, z, color="gray", alpha=0.5)
116 | ax.xaxis._axinfo["grid"]['color'] = (0,0,0,0.1)
117 | ax.yaxis._axinfo["grid"]['color'] = (0,0,0,0.1)
118 | ax.zaxis._axinfo["grid"]['color'] = (0,0,0,0.1)
119 | ax.set_title(title)
120 | # [tick.label.set_fontsize(24) for tick in ax.xaxis.get_major_ticks()]
121 | # [tick.label.set_fontsize(24) for tick in ax.yaxis.get_major_ticks()]
122 | # [tick.label.set_fontsize(24) for tick in ax.zaxis.get_major_ticks()]
123 | ax.view_init(20, 15)
124 | plt.tight_layout()
125 | plt.show()
126 | plt.close()
127 |
128 |
129 |
130 | def plot_loss_mcr(data):
131 | loss_total = data['loss_total']
132 | loss_expd = data['loss_expd']
133 | loss_comp = data['loss_comp']
134 | num_iter = np.arange(len(loss_total))
135 | fig, ax = plt.subplots(1, 1, figsize=(7, 5), sharey=True, sharex=True)
136 | ax.plot(num_iter, loss_total, label=r'$\Delta R$',
137 | color='green', linewidth=1.0, alpha=0.8)
138 | ax.plot(num_iter, loss_expd, label=r'$R$',
139 | color='royalblue', linewidth=1.0, alpha=0.8)
140 | ax.plot(num_iter, loss_comp, label=r'$R^c$',
141 | color='coral', linewidth=1.0, alpha=0.8)
142 | ax.set_ylabel('Loss', fontsize=10)
143 | ax.set_xlabel('Number of iterations', fontsize=10)
144 | ax.legend(loc='lower right', prop={"size": 15}, ncol=3, framealpha=0.5)
145 | fig.tight_layout()
146 | plt.show()
147 | plt.close()
--------------------------------------------------------------------------------
/functional.py:
--------------------------------------------------------------------------------
1 | import time
2 | import os
3 | from tqdm import tqdm
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torchvision.datasets as datasets
9 | import torch.distributions as distributions
10 | import torchvision.transforms as transforms
11 | import torchvision.transforms.functional as TF
12 | from torch.utils.data import DataLoader
13 | from torch.utils.data._utils.collate import default_collate
14 |
15 |
16 |
17 | def get_features(net, trainloader, verbose=True, n=None, device='cpu'):
18 | '''Extract all features out into one single batch.
19 |
20 | Parameters:
21 | net (torch.nn.Module): get features using this model
22 | trainloader (torchvision.dataloader): dataloader for loading data
23 |         verbose (bool): show loading status bar
24 |         n (int): stop after collecting at least n samples (None for all)
25 |         device (str): device on which to run the forward pass
26 | Returns:
27 | features (torch.tensor): with dimension (num_samples, feature_dimension)
28 | labels (torch.tensor): with dimension (num_samples, )
29 | '''
30 | features = []
31 | labels = []
32 | if verbose:
33 | train_bar = tqdm(trainloader, desc="extracting all features from dataset")
34 | else:
35 | train_bar = trainloader
36 | total = 0
37 | with torch.no_grad():
38 | for _, (batch_imgs, batch_lbls) in enumerate(train_bar):
39 | batch_imgs = batch_imgs.to(device)
40 | batch_features = net(batch_imgs)
41 | features.append(batch_features.cpu().detach())
42 | labels.append(batch_lbls)
43 | total += len(batch_features)
44 | if n is not None and total > n:
45 | break
46 | return torch.cat(features)[:n], torch.cat(labels)[:n]
47 |
48 | def get_samples(dataset, num_samples, shuffle=False, batch_idx=0, seed=0, method='uniform'):
49 | if method == 'uniform':
50 | np.random.seed(seed)
51 | dataloader = DataLoader(dataset, batch_size=dataset.data.shape[0])
52 | X, y = next(iter(dataloader))
53 | if shuffle: # ensure you sample different samples
54 | idx_arr = np.random.choice(X.shape[0], y.shape[0], replace=False)
55 | X, y = X[idx_arr], y[idx_arr]
56 | if num_samples is not None:
57 | X, y = get_n_each(X, y, num_samples, batch_idx)
58 | X, y = X.float(), y.long()
59 | if len(X.shape) == 3:
60 | X = X.unsqueeze(1)
61 | return X.float(), y.long()
62 | elif method == 'first':
63 | num_classes = torch.unique(dataset.targets).size()[0]
64 | return next(iter(DataLoader(dataset, batch_size=num_classes*num_samples)))
65 |
66 | def normalize(X, p=2):
67 | if isinstance(X, torch.Tensor):
68 |         norm = torch.linalg.norm(X.flatten(1), ord=p, dim=1)
69 | norm = norm.clip(min=1e-8)
70 | for _ in range(len(X.shape)-1):
71 | norm = norm.unsqueeze(-1)
72 | return X / norm
73 | elif isinstance(X, np.ndarray):
74 | norm = np.linalg.norm(X.reshape(X.shape[0], -1), ord=p, axis=1)
75 | norm = np.clip(norm, a_min=1e-8, a_max=None)
76 | for _ in range(len(X.shape)-1):
77 | norm = np.expand_dims(norm, -1)
78 | return X / norm
79 | else:
80 | raise TypeError('Input array not instances of torch.Tensor or np.ndarray')
81 |
82 | def get_n_each(X, y, n=None, batch_idx=0):
83 | classes = torch.unique(y)
84 | _X, _y = [], []
85 | for c in classes:
86 | idx_class = (y == c)
87 | X_class, y_class = X[idx_class], y[idx_class]
88 | if n is not None:
89 | X_class = torch.roll(X_class, -batch_idx*n, dims=0)
90 | y_class = torch.roll(y_class, -batch_idx*n, dims=0)
91 | _X.append(X_class[:n])
92 | _y.append(y_class[:n])
93 | return torch.cat(_X), torch.cat(_y)
94 |
95 | def translate(data, labels, stride=7, n=None):
96 | if len(data.shape) == 3:
97 | return translate1d(data, labels, n=n, stride=stride)
98 | if len(data.shape) == 4:
99 | return translate2d(data, labels, n=n, stride=stride)
100 | raise ValueError('translate not available.')
101 |
102 |
103 | def translate1d(data, labels, n=None, stride=1):
104 | m, _, T = data.shape
105 | data_new = []
106 | if n is None:
107 | shifts = range(0, T, stride)
108 | else:
109 | shifts = range(-n*stride, (n+1)*stride, stride)
110 | for t in shifts:
111 | data_new.append(torch.roll(data, t, dims=(2)))
112 |     nrepeats = len(shifts)
113 | return (torch.cat(data_new),
114 | labels.repeat(nrepeats))
115 |
116 | def translate2d(data, labels, n=None, stride=1):
117 | m, _, H, W = data.shape
118 | if n is None:
119 | shifts_horizontal = range(0, H, stride)
120 |         shifts_vertical = range(0, W, stride)
121 | else:
122 | shifts_horizontal = range(-n*stride, (n+1)*stride, stride)
123 | shifts_vertical = range(-n*stride, (n+1)*stride, stride)
124 | data_new = []
125 | for h in shifts_horizontal:
126 | for w in shifts_vertical:
127 | data_new.append(torch.roll(data, (h, w), dims=(2, 3)))
128 | nrepeats = len(shifts_vertical) * len(shifts_horizontal)
129 | return (torch.cat(data_new),
130 | labels.repeat(nrepeats))
131 |
132 | def cart2polar(images_cart, channels, timesteps):
133 | m, C, H, W = images_cart.shape
134 | mid_pt = int(H // 2)
135 | R = torch.linspace(0, mid_pt, channels).long()
136 | thetas = torch.linspace(0, 360, timesteps).float()
137 | images_polar = []
138 | for theta in thetas:
139 | image_rotated = TF.rotate(images_cart, theta.item())
140 | images_polar.append(image_rotated[:, :, mid_pt, R])
141 | return torch.cat(images_polar, axis=1).transpose(1, 2)
142 |
143 | def step_lr(epochs, init, gamma, steps):
144 | """learning rate decay
145 | epochs: total number of epochs
146 | gamma: multiplicative decay
147 | step: decay at which steps
148 | init: initial learning rate
149 | """
150 | rates = np.ones(epochs) * init
151 | for step in steps:
152 | rates[step:] = rates[step:] * gamma
153 | return rates
154 |
155 | def corrupt_labels(trainset, num_classes, ratio, seed):
156 | """Corrupt labels in trainset.
157 |
158 | Parameters:
159 | trainset (torch.data.dataset): trainset where labels is stored
160 | ratio (float): ratio of labels to be corrupted. 0 to corrupt no labels;
161 | 1 to corrupt all labels
162 | seed (int): random seed for reproducibility
163 |
164 | Returns:
165 | trainset (torch.data.dataset): trainset with updated corrupted labels
166 |
167 | """
168 |
169 | np.random.seed(seed)
170 | train_labels = np.asarray(trainset.targets)
171 | n_train = len(train_labels)
172 | n_rand = int(len(trainset.data)*ratio)
173 | randomize_indices = np.random.choice(range(n_train), size=n_rand, replace=False)
174 | train_labels[randomize_indices] = np.random.choice(np.arange(num_classes), size=n_rand, replace=True)
175 | trainset.targets = torch.tensor(train_labels).int()
176 | return trainset
177 |
178 | def to_cpu(*gpu_vars):
179 | cpu_vars = []
180 | for var in gpu_vars:
181 | cpu_vars.append(var.detach().cpu())
182 | return cpu_vars
183 |
--------------------------------------------------------------------------------
/images/arch-redunet.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ma-Lab-Berkeley/ReduNet/f67a348e5b54b9e60b783ceec6fe0e379bc3b96f/images/arch-redunet.jpg
--------------------------------------------------------------------------------
/load.py:
--------------------------------------------------------------------------------
1 | import time
2 | import os
3 | from os import listdir
4 | from os.path import join, isfile, isdir, expanduser
5 | from tqdm import tqdm
6 |
7 | import pandas as pd
8 |
9 | import numpy as np
10 | import torch
11 | import torch.nn as nn
12 | from PIL import Image
13 | from torch.utils.data import Dataset
14 | import torchvision.datasets as datasets
15 | import torchvision.transforms as transforms
16 |
17 | import functional as F
18 | from redunet import *
19 |
20 |
21 | def load_architecture(data, arch, seed=0):
22 | if data == 'mnist2d':
23 | if arch == 'lift2d_channels35_layers5':
24 | from architectures.mnist.lift2d import lift2d
25 | return lift2d(channels=35, layers=5, num_classes=10, seed=seed)
26 | if arch == 'lift2d_channels35_layers10':
27 | from architectures.mnist.lift2d import lift2d
28 |             return lift2d(channels=35, layers=10, num_classes=10, seed=seed)
29 | if arch == 'lift2d_channels35_layers20':
30 | from architectures.mnist.lift2d import lift2d
31 | return lift2d(channels=35, layers=20, num_classes=10, seed=seed)
32 | if arch == 'lift2d_channels55_layers5':
33 | from architectures.mnist.lift2d import lift2d
34 | return lift2d(channels=55, layers=5, num_classes=10, seed=seed)
35 | if arch == 'lift2d_channels55_layers10':
36 | from architectures.mnist.lift2d import lift2d
37 |             return lift2d(channels=55, layers=10, num_classes=10, seed=seed)
38 | if arch == 'lift2d_channels55_layers20':
39 | from architectures.mnist.lift2d import lift2d
40 | return lift2d(channels=55, layers=20, num_classes=10, seed=seed)
41 | if data == 'mnist2d+2class':
42 | if arch == 'lift2d_channels35_layers5':
43 | from architectures.mnist.lift2d import lift2d
44 | return lift2d(channels=35, layers=5, num_classes=2, seed=seed)
45 | if arch == 'lift2d_channels35_layers10':
46 | from architectures.mnist.lift2d import lift2d
47 |             return lift2d(channels=35, layers=10, num_classes=2, seed=seed)
48 | if arch == 'lift2d_channels35_layers20':
49 | from architectures.mnist.lift2d import lift2d
50 | return lift2d(channels=35, layers=20, num_classes=2, seed=seed)
51 | if arch == 'lift2d_channels55_layers5':
52 | from architectures.mnist.lift2d import lift2d
53 | return lift2d(channels=55, layers=5, num_classes=2, seed=seed)
54 | if arch == 'lift2d_channels55_layers10':
55 | from architectures.mnist.lift2d import lift2d
56 |             return lift2d(channels=55, layers=10, num_classes=2, seed=seed)
57 | if arch == 'lift2d_channels55_layers20':
58 | from architectures.mnist.lift2d import lift2d
59 | return lift2d(channels=55, layers=20, num_classes=2, seed=seed)
60 | if data == 'mnistvector':
61 | if arch == 'layers50':
62 | from architectures.mnist.flatten import flatten
63 | return flatten(layers=50, num_classes=10)
64 | if arch == 'layers20':
65 | from architectures.mnist.flatten import flatten
66 | return flatten(layers=20, num_classes=10)
67 | if arch == 'layers10':
68 | from architectures.mnist.flatten import flatten
69 | return flatten(layers=10, num_classes=10)
70 | if arch == 'layers5':
71 | from architectures.mnist.flatten import flatten
72 | return flatten(layers=5, num_classes=10)
73 | if data == 'mnistvector_2class':
74 | if arch == 'layers50':
75 | from architectures.mnist.flatten import flatten
76 | return flatten(layers=50, num_classes=2)
77 | if arch == 'layers20':
78 | from architectures.mnist.flatten import flatten
79 | return flatten(layers=20, num_classes=2)
80 | if arch == 'layers10':
81 | from architectures.mnist.flatten import flatten
82 | return flatten(layers=10, num_classes=2)
83 | if arch == 'layers5':
84 | from architectures.mnist.flatten import flatten
85 | return flatten(layers=5, num_classes=2)
86 |     raise NameError(f'Cannot find architecture: {arch}.')
87 |
88 | def load_dataset(choice, data_dir='./data/'):
89 | if choice == 'mnist2d':
90 | from datasets.mnist import mnist2d_10class
91 | return mnist2d_10class(data_dir)
92 | if choice == 'mnist2d_2class':
93 | from datasets.mnist import mnist2d_2class
94 | return mnist2d_2class(data_dir)
95 | if choice =='mnistvector':
96 | from datasets.mnist import mnistvector_10class
97 | return mnistvector_10class(data_dir)
98 | raise NameError(f'Dataset {choice} not found.')
99 |
100 |
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from opt_einsum import contract
4 |
5 |
6 | class MaximalCodingRateReduction(nn.Module): #TODO: fix this
7 | def __init__(self, eps=0.1, gam=1.):
8 | super(MaximalCodingRateReduction, self).__init__()
9 | self.eps = eps
10 | self.gam = gam
11 |
12 | def discrimn_loss(self, Z):
13 | m, d = Z.shape
14 | I = torch.eye(d).to(Z.device)
15 | c = d / (m * self.eps)
16 | return logdet(c * covariance(Z) + I) / 2.
17 |
18 |     def compress_loss(self, Z, Pi):
19 |         m, d = Z.shape
20 |         I = torch.eye(d).to(Z.device)
21 |         loss_comp = 0.
22 |         for j in Pi.unique():
23 |             Z_j = Z[Pi == int(j)]
24 |             c_j = d / (Z_j.shape[0] * self.eps)
25 |             loss_comp += logdet(I + c_j * Z_j.T @ Z_j) * Z_j.shape[0] / (2 * m)
26 |         return loss_comp
27 |
28 | def forward(self, Z, y):
29 | Pi = y # TODO: change this to prob distribution
30 | loss_discrimn = self.discrimn_loss(Z)
31 | loss_compress = self.compress_loss(Z, Pi)
32 | return loss_discrimn - self.gam * loss_compress
33 |
34 |
35 | def compute_mcr2(Z, y, eps):
36 | if len(Z.shape) == 2:
37 | loss_func = compute_loss_vec
38 | elif len(Z.shape) == 3:
39 | loss_func = compute_loss_1d
40 | elif len(Z.shape) == 4:
41 | loss_func = compute_loss_2d
42 | return loss_func(Z, y, eps)
43 |
44 | def compute_loss_vec(Z, y, eps):
45 | m, d = Z.shape
46 | I = torch.eye(d).to(Z.device)
47 | c = d / (m * eps)
48 | loss_expd = logdet(c * covariance(Z) + I) / 2.
49 | loss_comp = 0.
50 | for j in y.unique():
51 | Z_j = Z[(y == int(j))[:, 0]]
52 | m_j = Z_j.shape[0]
53 | c_j = d / (m_j * eps)
54 | logdet_j = logdet(I + c_j * Z_j.T @ Z_j)
55 | loss_comp += logdet_j * m_j / (2 * m)
56 | loss_expd, loss_comp = loss_expd.item(), loss_comp.item()
57 | return loss_expd - loss_comp, loss_expd, loss_comp
58 |
59 | def compute_loss_1d(V, y, eps):
60 | m, C, T = V.shape
61 | I = torch.eye(C).unsqueeze(-1).to(V.device)
62 | alpha = C / (m * eps)
63 | cov = alpha * covariance(V) + I
64 | loss_expd = logdet(cov.permute(2, 0, 1)).sum() / (2 * T)
65 | loss_comp = 0.
66 | for j in y.unique():
67 | V_j = V[y==int(j)]
68 | m_j = V_j.shape[0]
69 | alpha_j = C / (m_j * eps)
70 | cov_j = alpha_j * covariance(V_j) + I
71 | loss_comp += m_j / m * logdet(cov_j.permute(2, 0, 1)).sum() / (2 * T)
72 | loss_expd, loss_comp = loss_expd.real.item(), loss_comp.real.item()
73 | return loss_expd - loss_comp, loss_expd, loss_comp
74 |
75 | def compute_loss_2d(V, y, eps):
76 | m, C, H, W = V.shape
77 | I = torch.eye(C).unsqueeze(-1).unsqueeze(-1).to(V.device)
78 | alpha = C / (m * eps)
79 | cov = alpha * covariance(V) + I
80 | loss_expd = logdet(cov.permute(2, 3, 0, 1)).sum() / (2 * H * W)
81 | loss_comp = 0.
82 | for j in y.unique():
83 | V_j = V[(y==int(j))[:, 0]]
84 | m_j = V_j.shape[0]
85 | alpha_j = C / (m_j * eps)
86 | cov_j = alpha_j * covariance(V_j) + I
87 | loss_comp += m_j / m * logdet(cov_j.permute(2, 3, 0, 1)).sum() / (2 * H * W)
88 | loss_expd, loss_comp = loss_expd.real.item(), loss_comp.real.item()
89 | return loss_expd - loss_comp, loss_expd, loss_comp
90 |
91 | def covariance(X):
92 | return contract('ji...,jk...->ik...', X, X.conj())
93 |
94 | def logdet(X):
95 | sgn, logdet = torch.linalg.slogdet(X)
96 | return sgn * logdet
--------------------------------------------------------------------------------
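
A minimal usage sketch for the loss utilities above (a hedged example, run from the repository root; the sizes and the (m, 1) label layout are assumptions inferred from the (y == j)[:, 0] indexing in compute_loss_vec):

import torch
from loss import compute_mcr2

# 128 unit-norm samples in 32 dimensions; labels stored as an (m, 1) column
# to match the (y == j)[:, 0] indexing used in compute_loss_vec.
Z = torch.nn.functional.normalize(torch.randn(128, 32), dim=1)
y = torch.randint(0, 10, (128, 1))

loss_total, loss_expd, loss_comp = compute_mcr2(Z, y, eps=0.1)
print(loss_total, loss_expd, loss_comp)  # Delta R, R, and R^c respectively
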
/plot.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import pandas as pd
5 | import numpy as np
6 | import matplotlib.pyplot as plt
7 | from torch.utils.data import DataLoader
8 |
9 | import functional as F
10 | import utils
11 |
12 | def plot_loss_mcr(model_dir, name):
13 | file_dir = os.path.join(model_dir, 'loss', f'{name}.csv')
14 | data = pd.read_csv(file_dir)
15 | loss_total = data['loss_total'].ravel()
16 | loss_expd = data['loss_expd'].ravel()
17 | loss_comp = data['loss_comp'].ravel()
18 | num_iter = np.arange(len(loss_total))
19 | fig, ax = plt.subplots(1, 1, figsize=(7, 5), sharey=True, sharex=True)
20 | ax.plot(num_iter, loss_total, label=r'$\Delta R$',
21 | color='green', linewidth=1.0, alpha=0.8)
22 | ax.plot(num_iter, loss_expd, label=r'$R$',
23 | color='royalblue', linewidth=1.0, alpha=0.8)
24 | ax.plot(num_iter, loss_comp, label=r'$R^c$',
25 | color='coral', linewidth=1.0, alpha=0.8)
26 | ax.set_ylabel('Loss', fontsize=10)
27 | ax.set_xlabel('Number of iterations', fontsize=10)
28 | ax.legend(loc='lower right', prop={"size": 15}, ncol=3, framealpha=0.5)
29 | fig.tight_layout()
30 |
31 | loss_dir = os.path.join(model_dir, 'figures', 'loss_mcr')
32 | os.makedirs(loss_dir, exist_ok=True)
33 | file_name = os.path.join(loss_dir, f'{name}.png')
34 | plt.savefig(file_name, dpi=400)
35 | plt.close()
36 | print("Plot saved to: {}".format(file_name))
37 |
38 | def plot_loss(model_dir):
39 | """Plot cross entropy loss. """
40 | ## extract loss from csv
41 | file_dir = os.path.join(model_dir, 'losses.csv')
42 | data = pd.read_csv(file_dir)
43 | epochs = data['epoch'].ravel()
44 | loss = data['loss'].ravel()
45 |
46 | fig, ax = plt.subplots(1, 1, figsize=(7, 5), sharey=True, sharex=True, dpi=400)
47 |     ax.plot(epochs, loss, label=r'Loss',
48 | color='green', linewidth=1.0, alpha=0.8)
49 | ax.set_ylabel('Loss', fontsize=10)
50 | ax.set_xlabel('Number of iterations', fontsize=10)
51 | ax.legend(loc='lower right', prop={"size": 15}, ncol=3, framealpha=0.5)
52 | ax.set_title("Loss")
53 | ax.spines['top'].set_visible(False)
54 | ax.spines['right'].set_visible(False)
55 | plt.tight_layout()
56 |
57 | ## create saving directory
58 | loss_dir = os.path.join(model_dir, 'figures', 'loss')
59 | os.makedirs(loss_dir, exist_ok=True)
60 | file_name = os.path.join(loss_dir, 'loss.png')
61 | plt.savefig(file_name, dpi=400)
62 | print("Plot saved to: {}".format(file_name))
63 | file_name = os.path.join(loss_dir, 'loss.pdf')
64 | plt.savefig(file_name, dpi=400)
65 | plt.close()
66 | print("Plot saved to: {}".format(file_name))
67 |
68 | def plot_csv(model_dir, filename):
69 | df = pd.read_csv(os.path.join(model_dir, f'{filename}.csv'))
70 | colnames = df.columns
71 |
72 | fig, ax = plt.subplots(1, 1, figsize=(7, 5))
73 | for colname in colnames[1:]:
74 | ax.plot(df[colnames[0]], df[colname], marker='x', label=colname)
75 | ax.set_xlabel(colnames[0])
76 | ax.set_ylabel(filename)
77 | ax.legend()
78 |
79 | csv_dir = os.path.join(model_dir, 'figures', 'csv')
80 | os.makedirs(csv_dir, exist_ok=True)
81 | savepath = os.path.join(csv_dir, f'{filename}.png')
82 | fig.savefig(savepath)
83 | print('Plot saved to: {}'.format(savepath))
84 |
85 | def plot_loss_ce(model_dir, filename='loss_ce'):
86 | df = pd.read_csv(os.path.join(model_dir, f'{filename}.csv'))
87 | colnames = df.columns
88 |
89 | fig, ax = plt.subplots(1, 1, figsize=(7, 5))
90 | for colname in colnames[1:]:
91 |         ax.plot(np.arange(df.shape[0]), df[colname], label=colname)
92 | ax.set_xlabel(colnames[0])
93 | ax.set_ylabel(filename)
94 | ax.legend()
95 |
96 | csv_dir = os.path.join(model_dir, 'figures', 'loss_ce')
97 | os.makedirs(csv_dir, exist_ok=True)
98 |     savepath = os.path.join(csv_dir, f'{filename}.png')
99 | fig.savefig(savepath)
100 | print('Plot saved to: {}'.format(savepath))
101 |
102 |
103 | def plot_acc(model_dir):
104 | """Plot training and testing accuracy"""
105 | ## extract loss from csv
106 | file_dir = os.path.join(model_dir, 'acc.csv')
107 | data = pd.read_csv(file_dir)
108 | epochs = data['epoch'].ravel()
109 | acc_train = data['acc_train'].ravel()
110 | acc_test = data['acc_test'].ravel()
111 | # epoch,acc_train,acc_test
112 |
113 | ## Theoretical Loss
114 | fig, ax = plt.subplots(1, 1, figsize=(7, 5), sharey=True, sharex=True, dpi=400)
115 | ax.plot(epochs, acc_train, label='train', color='green', alpha=0.8)
116 | ax.plot(epochs, acc_test, label='test', color='red', alpha=0.8)
117 | ax.set_ylabel('Accuracy', fontsize=10)
118 | ax.set_xlabel('Epoch', fontsize=10)
119 | ax.legend(loc='lower right', prop={"size": 15}, ncol=3, framealpha=0.5)
120 | ax.spines['top'].set_visible(False)
121 | ax.spines['right'].set_visible(False)
122 | plt.tight_layout()
123 |
124 | ## create saving directory
125 | acc_dir = os.path.join(model_dir, 'figures', 'acc')
126 | os.makedirs(acc_dir, exist_ok=True)
127 | file_name = os.path.join(acc_dir, 'accuracy.png')
128 | plt.savefig(file_name, dpi=400)
129 | print("Plot saved to: {}".format(file_name))
130 | file_name = os.path.join(acc_dir, 'accuracy.pdf')
131 | plt.savefig(file_name, dpi=400)
132 | plt.close()
133 | print("Plot saved to: {}".format(file_name))
134 |
135 | def plot_heatmap(model_dir, name, features, labels, num_classes):
136 | """Plot heatmap of cosine simliarity for all features. """
137 | features_sort, _ = utils.sort_dataset(features, labels,
138 | classes=num_classes, stack=False)
139 | features_sort_ = np.vstack(features_sort)
140 | sim_mat = np.abs(features_sort_ @ features_sort_.T)
141 |
142 | # plt.rc('text', usetex=False)
143 | # plt.rcParams['font.family'] = 'serif'
144 | # plt.rcParams['font.serif'] = ['Times New Roman'] #+ plt.rcParams['font.serif']
145 |
146 | fig, ax = plt.subplots(figsize=(7, 5), sharey=True, sharex=True)
147 | im = ax.imshow(sim_mat, cmap='Blues')
148 | fig.colorbar(im, pad=0.02, drawedges=0, ticks=[0, 0.5, 1])
149 | ax.set_xticks(np.linspace(0, len(labels), num_classes+1))
150 | ax.set_yticks(np.linspace(0, len(labels), num_classes+1))
151 | [tick.label.set_fontsize(10) for tick in ax.xaxis.get_major_ticks()]
152 | [tick.label.set_fontsize(10) for tick in ax.yaxis.get_major_ticks()]
153 | fig.tight_layout()
154 |
155 | save_dir = os.path.join(model_dir, 'figures', 'heatmaps')
156 | os.makedirs(save_dir, exist_ok=True)
157 | file_name = os.path.join(save_dir, f"{name}.png")
158 | fig.savefig(file_name)
159 | print("Plot saved to: {}".format(file_name))
160 | plt.close()
161 |
162 | def plot_transform(model_dir, inputs, outputs, name):
163 | fig, ax = plt.subplots(ncols=2)
164 | inputs = inputs.permute(1, 2, 0)
165 | outputs = outputs.permute(1, 2, 0)
166 | outputs = (outputs - outputs.min()) / (outputs.max() - outputs.min())
167 | ax[0].imshow(inputs)
168 | ax[0].set_title('inputs')
169 | ax[1].imshow(outputs)
170 | ax[1].set_title('outputs')
171 | save_dir = os.path.join(model_dir, 'figures', 'images')
172 | os.makedirs(save_dir, exist_ok=True)
173 | file_name = os.path.join(save_dir, f'{name}.png')
174 | fig.savefig(file_name)
175 | print("Plot saved to: {}".format(file_name))
176 | plt.close()
177 |
178 | def plot_channel_image(model_dir, features, name):
179 | def normalize(x):
180 | out = x - x.min()
181 | out = out / (out.max() - out.min())
182 | return out
183 | fig, ax = plt.subplots()
184 | ax.imshow(normalize(features), cmap='gray')
185 | save_dir = os.path.join(model_dir, 'figures', 'images')
186 | os.makedirs(save_dir, exist_ok=True)
187 | file_name = os.path.join(save_dir, f'{name}.png')
188 | fig.savefig(file_name)
189 | print("Plot saved to: {}".format(file_name))
190 | plt.close()
191 |
192 |
193 | def plot_nearest_image(model_dir, image, nearest_images, values, name, grid_size=(4, 4)):
194 | fig, ax = plt.subplots(*grid_size, figsize=(10, 10))
195 | idx = 1
196 | for i in range(grid_size[0]):
197 | for j in range(grid_size[1]):
198 | if i == 0 and j == 0:
199 | ax[i, j].imshow(image)
200 | else:
201 | ax[i, j].set_title(values[idx-1])
202 | ax[i, j].imshow(nearest_images[idx-1])
203 | idx += 1
204 | ax[i, j].set_xticks([])
205 | ax[i, j].set_yticks([])
206 | plt.setp(ax[0, 0].spines.values(), color='red', linewidth=2)
207 | fig.tight_layout()
208 |
209 | save_dir = os.path.join(model_dir, 'figures', 'nearest_image')
210 | os.makedirs(save_dir, exist_ok=True)
211 | save_path = os.path.join(save_dir, f'{name}.png')
212 | fig.savefig(save_path)
213 | print(f"Plot saved to: {save_path}")
214 | plt.close()
215 |
216 |
217 | def plot_image(model_dir, image, name):
218 | fig, ax = plt.subplots(1, 1, figsize=(10, 10))
219 | if image.shape[2] == 1:
220 | ax.imshow(image, cmap='gray')
221 | else:
222 | ax.imshow(image)
223 | ax.set_xticks([])
224 | ax.set_yticks([])
225 | fig.tight_layout()
226 |
227 | save_dir = os.path.join(model_dir, 'figures', 'image')
228 | os.makedirs(save_dir, exist_ok=True)
229 | save_path = os.path.join(save_dir, f'{name}.png')
230 | fig.savefig(save_path)
231 | print(f"Plot saved to: {save_path}")
232 | plt.close()
233 |
234 | def save_image(image, save_path):
235 | fig, ax = plt.subplots(1, 1, figsize=(10, 10))
236 | if image.shape[2] == 1:
237 | ax.imshow(image, cmap='gray')
238 | else:
239 | ax.imshow(image)
240 | ax.set_xticks([])
241 | ax.set_yticks([])
242 | fig.tight_layout()
243 | fig.savefig(save_path)
244 | print(f"Plot saved to: {save_path}")
245 | plt.close()
--------------------------------------------------------------------------------
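
A short sketch of how these plotting helpers pair with utils.save_loss (a hedged example; the ./saved_models/demo directory and loss values are hypothetical):

import utils
import plot

model_dir = './saved_models/demo'  # hypothetical output directory
losses = {'layer': [0, 1, 2],
          'loss_total': [0.4, 0.7, 0.9],
          'loss_expd': [1.0, 1.2, 1.3],
          'loss_comp': [0.6, 0.5, 0.4]}
utils.save_loss(model_dir, 'train', losses)  # writes loss/train.csv
plot.plot_loss_mcr(model_dir, 'train')       # reads it back, saves figures/loss_mcr/train.png
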
/redunet/__init__.py:
--------------------------------------------------------------------------------
1 | # Layers
2 | from .layers.redulayer import ReduLayer
3 | from .layers.vector import Vector
4 | from .layers.fourier1d import Fourier1D
5 | from .layers.fourier2d import Fourier2D
6 |
7 | # Modules
8 | from .modules import (
9 | ReduNetVector,
10 | ReduNet1D,
11 | ReduNet2D
12 | )
13 |
14 | # Projections
15 | from .projections.lift import Lift1D
16 | from .projections.lift import Lift2D
17 |
18 |
19 | # Others
20 | from .redunet import ReduNet
21 | from .multichannel_weight import MultichannelWeight
22 |
23 |
24 | __all__ = [
25 | 'Fourier1D',
26 | 'Fourier2D',
27 | 'Lift1D',
28 | 'Lift2D',
29 | 'MultichannelWeight',
30 | 'ReduNet',
31 | 'ReduLayer',
32 | 'ReduNetVector',
33 | 'ReduNet1D',
34 | 'ReduNet2D',
35 | 'Vector'
36 | ]
--------------------------------------------------------------------------------
/redunet/layers/fourier1d.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as tF
4 | from torch.fft import fft, ifft
5 | from opt_einsum import contract
6 | import numpy as np
7 |
8 | import functional as F
9 | from ..multichannel_weight import MultichannelWeight
10 | from .vector import Vector
11 |
12 |
13 |
14 | class Fourier1D(Vector):
15 | def __init__(self, eta, eps, lmbda, num_classes, dimensions):
16 | super(Fourier1D, self).__init__(eta, eps, lmbda, num_classes)
17 | assert len(dimensions) == 2, 'dimensions should have tensor dim = 2'
18 | self.channels, self.timesteps = dimensions
19 | self.gam = nn.Parameter(torch.ones(num_classes) / num_classes, requires_grad=False)
20 | self.E = MultichannelWeight(self.channels, self.timesteps, dtype=torch.complex64)
21 | self.Cs = nn.ModuleList([MultichannelWeight(self.channels, self.timesteps, dtype=torch.complex64)
22 | for _ in range(num_classes)])
23 |
24 | def nonlinear(self, Bz):
25 | norm = torch.linalg.norm(Bz.reshape(Bz.shape[0], Bz.shape[1], -1), axis=2)
26 | norm = torch.clamp(norm, min=1e-8)
27 | pred = tF.softmax(-self.lmbda * norm, dim=0).unsqueeze(2)
28 | y = torch.argmax(pred, axis=0) #TODO: for non argmax case
29 | gam = self.gam.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
30 | pred = pred.unsqueeze(-1)
31 | out = torch.sum(gam * Bz * pred, axis=0)
32 | return out, y
33 |
34 | def compute_E(self, V):
35 | m, C, T = V.shape
36 | alpha = C / (m * self.eps)
37 | I = torch.eye(C, device=V.device).unsqueeze(-1)
38 | pre_inv = I + alpha * contract('ji...,jk...->ik...', V, V.conj())
39 | E = torch.empty_like(pre_inv)
40 | for t in range(T):
41 | E[:, :, t] = alpha * torch.inverse(pre_inv[:, :, t])
42 | return E
43 |
44 | def compute_Cs(self, V, y):
45 | m, C, T = V.shape
46 | I = torch.eye(C, device=V.device).unsqueeze(-1)
47 | Cs = torch.empty((self.num_classes, C, C, T), dtype=torch.complex64)
48 | for j in range(self.num_classes):
49 | V_j = V[(y == int(j))]
50 | m_j = V_j.shape[0]
51 | alpha_j = C / (m_j * self.eps)
52 | pre_inv = I + alpha_j * contract('ji...,jk...->ik...', V_j, V_j.conj())
53 | for t in range(T):
54 | Cs[j, :, :, t] = alpha_j * torch.inverse(pre_inv[:, :, t])
55 | return Cs
56 |
57 | def preprocess(self, X):
58 | Z = F.normalize(X)
59 | return fft(Z, norm='ortho', dim=2)
60 |
61 | def postprocess(self, X):
62 | Z = ifft(X, norm='ortho', dim=2)
63 | return F.normalize(Z).real
64 |
--------------------------------------------------------------------------------
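
Since the covariance is block-diagonal in the Fourier domain, compute_E above inverts one small C x C matrix per frequency bin (scaled by alpha), and compute_Cs does the same per class. A shape sketch, assuming hypothetical sizes and the repository's functional.normalize, run from the repository root:

import torch
from redunet import Fourier1D

# 8-channel signals of length 16, two classes (all sizes assumed)
layer = Fourier1D(eta=0.5, eps=0.1, lmbda=500, num_classes=2, dimensions=(8, 16))

X = torch.randn(32, 8, 16)   # (batch, channels, timesteps)
y = torch.randint(0, 2, (32,))
V = layer.preprocess(X)      # normalize, then FFT along the time axis
E = layer.compute_E(V)       # one 8 x 8 complex inverse per frequency bin
Cs = layer.compute_Cs(V, y)  # same, but restricted to each class
print(E.shape, Cs.shape)     # (8, 8, 16) and (2, 8, 8, 16)
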
/redunet/layers/fourier2d.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as tF
5 | from itertools import product
6 | from torch.fft import fft2, ifft2
7 | import numpy as np
8 | from opt_einsum import contract
9 |
10 | import functional as F
11 | from ..multichannel_weight import MultichannelWeight
12 | from .vector import Vector
13 |
14 |
15 |
16 | class Fourier2D(Vector):
17 | def __init__(self, eta, eps, lmbda, num_classes, dimensions):
18 | super(Fourier2D, self).__init__(eta, eps, lmbda, num_classes)
19 | assert len(dimensions) == 3, 'dimensions should have tensor dim = 3'
20 | self.channels, self.height, self.width = dimensions
21 | self.gam = nn.Parameter(torch.ones(num_classes) / num_classes, requires_grad=False)
22 | self.E = MultichannelWeight(*dimensions, dtype=torch.complex64)
23 | self.Cs = nn.ModuleList([MultichannelWeight(*dimensions, dtype=torch.complex64)
24 | for _ in range(num_classes)])
25 |
26 | def nonlinear(self, Cz):
27 | norm = torch.linalg.norm(Cz.flatten(start_dim=2), axis=2).clamp(min=1e-8)
28 | pred = tF.softmax(-self.lmbda * norm, dim=0).unsqueeze(2)
29 | y = torch.argmax(pred, axis=0) #TODO: for non argmax case
30 | gam = self.gam.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
31 | pred = pred.unsqueeze(-1).unsqueeze(-1)
32 | out = torch.sum(gam * Cz * pred, axis=0)
33 | return out, y
34 |
35 | def compute_E(self, V):
36 | m, C, H, W = V.shape
37 | alpha = C / (m * self.eps)
38 | I = torch.eye(C, device=V.device).unsqueeze(-1).unsqueeze(-1)
39 | pre_inv = I + alpha * contract('ji...,jk...->ik...', V, V.conj())
40 | E = torch.empty_like(pre_inv, dtype=torch.complex64)
41 | for h, w in product(range(H), range(W)):
42 | E[:, :, h, w] = alpha * torch.inverse(pre_inv[:, :, h, w])
43 | return E
44 |
45 | def compute_Cs(self, V, y):
46 | m, C, H, W = V.shape
47 | I = torch.eye(C, device=V.device).unsqueeze(-1).unsqueeze(-1)
48 | Cs = torch.empty((self.num_classes, C, C, H, W), dtype=torch.complex64)
49 | for j in range(self.num_classes):
50 | V_j = V[(y == int(j))]
51 | m_j = V_j.shape[0]
52 | alpha_j = C / (m_j * self.eps)
53 | pre_inv = I + alpha_j * contract('ji...,jk...->ik...', V_j, V_j.conj())
54 | for h, w in product(range(H), range(W)):
55 | Cs[j, :, :, h, w] = alpha_j * torch.inverse(pre_inv[:, :, h, w])
56 | return Cs
57 |
58 | def preprocess(self, X):
59 | Z = F.normalize(X)
60 | return fft2(Z, norm='ortho', dim=(2, 3))
61 |
62 | def postprocess(self, X):
63 | Z = ifft2(X, norm='ortho', dim=(2, 3))
64 | return F.normalize(Z).real
65 |
--------------------------------------------------------------------------------
/redunet/layers/redulayer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 |
6 | class ReduLayer(nn.Module):
7 | def __init__(self):
8 | super(ReduLayer, self).__init__()
9 |
10 | def __name__(self):
11 | return "ReduNet"
12 |
13 | def forward(self, Z):
14 | raise NotImplementedError
15 |
16 | def zero(self):
17 | state_dict = self.state_dict()
18 | state_dict['E.weight'] = torch.zeros_like(self.E.weight)
19 | for j in range(self.num_classes):
20 | state_dict[f'Cs.{j}.weight'] = torch.zeros_like(self.Cs[j].weight)
21 | self.load_state_dict(state_dict)
22 |
23 | def init(self, X, y):
24 | gam = self.compute_gam(X, y)
25 | E = self.compute_E(X)
26 | Cs = self.compute_Cs(X, y)
27 | self.set_params(E, Cs, gam)
28 |
29 | def update_old(self, X, y, tau):
30 | E = self.compute_E(X).to(X.device)
31 | Cs = self.compute_Cs(X, y).to(X.device)
32 | state_dict = self.state_dict()
33 | ref_E = self.E.weight
34 | ref_Cs = [self.Cs[j].weight for j in range(self.num_classes)]
35 | new_E = ref_E + tau * (E - ref_E)
36 | new_Cs = [ref_Cs[j] + tau * (Cs[j] - ref_Cs[j]) for j in range(self.num_classes)]
37 | state_dict['E.weight'] = new_E
38 | for j in range(self.num_classes):
39 | state_dict[f'Cs.{j}.weight'] = new_Cs[j]
40 | self.load_state_dict(state_dict)
41 |
42 | def update(self, X, y, tau):
43 | E_ref, Cs_ref = self.get_params()
44 | # gam = self.init_gam(X, y)
45 | E_new = self.compute_E(X).to(X.device)
46 | Cs_new = self.compute_Cs(X, y).to(X.device)
47 | E_update = E_ref + tau * (E_new - E_ref)
48 | Cs_update = [Cs_ref[j] + tau * (Cs_new[j] - Cs_ref[j]) for j in range(self.num_classes)]
49 | self.set_params(E_update, Cs_update)
50 |
51 | def set_params(self, E, Cs, gam=None):
52 | state_dict = self.state_dict()
53 | assert self.E.weight.shape == E.shape, f'E shape does not match: {self.E.weight.shape} and {E.shape}'
54 | state_dict['E.weight'] = E
55 | for j in range(self.num_classes):
56 |             assert self.Cs[j].weight.shape == Cs[j].shape, f'Cs[{j}] shape does not match: {self.Cs[j].weight.shape} and {Cs[j].shape}'
57 | state_dict[f'Cs.{j}.weight'] = Cs[j]
58 | if gam is not None:
59 | assert self.gam.shape == gam.shape, 'gam shape does not match'
60 | state_dict['gam'] = gam
61 | self.load_state_dict(state_dict)
62 |
63 | def get_params(self):
64 | E = self.E.weight
65 | Cs = [self.Cs[j].weight for j in range(self.num_classes)]
66 | return E, Cs
67 |
68 |
--------------------------------------------------------------------------------
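
The update method above is just a convex interpolation between the stored parameters and ones recomputed from a new batch. In isolation (a minimal sketch, with tau assumed to lie in (0, 1]):

import torch

E_ref = torch.zeros(3, 3)  # current parameter
E_new = torch.eye(3)       # parameter recomputed from a new batch
tau = 0.1
E_update = E_ref + tau * (E_new - E_ref)  # equals (1 - tau) * E_ref + tau * E_new
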
/redunet/layers/vector.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as tF
4 |
5 | import functional as F
6 | from .redulayer import ReduLayer
7 |
8 |
9 | class Vector(ReduLayer):
10 | def __init__(self, eta, eps, lmbda, num_classes, dimensions=None):
11 | super(Vector, self).__init__()
12 | self.eta = eta
13 | self.eps = eps
14 | self.lmbda = lmbda
15 | self.num_classes = num_classes
16 | self.d = dimensions
17 |
18 |         if self.d is not None:  # NOTE: otherwise initialized in child classes
19 | self.gam = nn.Parameter(torch.ones(num_classes) / num_classes)
20 | self.E = nn.Linear(self.d, self.d, bias=False)
21 | self.Cs = nn.ModuleList([nn.Linear(self.d, self.d, bias=False) for _ in range(num_classes)])
22 |
23 | def forward(self, Z, return_y=False):
24 | expd = self.E(Z)
25 | comp = torch.stack([C(Z) for C in self.Cs])
26 | clus, y_approx = self.nonlinear(comp)
27 | Z = Z + self.eta * (expd - clus)
28 | Z = F.normalize(Z)
29 | if return_y:
30 | return Z, y_approx
31 | return Z
32 |
33 | def nonlinear(self, Cz):
34 | norm = torch.linalg.norm(Cz.reshape(Cz.shape[0], Cz.shape[1], -1), axis=2)
35 | norm = torch.clamp(norm, min=1e-8)
36 | pred = tF.softmax(-self.lmbda * norm, dim=0).unsqueeze(2)
37 | y = torch.argmax(pred, axis=0) #TODO: for non argmax case
38 | gam = self.gam.unsqueeze(-1).unsqueeze(-1)
39 | out = torch.sum(gam * Cz * pred, axis=0)
40 | return out, y
41 |
42 | def compute_gam(self, X, y):
43 | m = X.shape[0]
44 | m_j = [torch.nonzero(y==j).size()[0] for j in range(self.num_classes)]
45 | gam = (torch.tensor(m_j).float() / m).flatten()
46 | return gam
47 |
48 | def compute_E(self, X):
49 | m, d = X.shape
50 | Z = X.T
51 | I = torch.eye(d, device=X.device)
52 | c = d / (m * self.eps)
53 | E = c * torch.inverse(I + c * Z @ Z.T)
54 | return E
55 |
56 | def compute_Cs(self, X, y):
57 | m, d = X.shape
58 | Z = X.T
59 | I = torch.eye(d, device=X.device)
60 | Cs = torch.zeros((self.num_classes, d, d))
61 | for j in range(self.num_classes):
62 | idx = (y == int(j))
63 | Z_j = Z[:, idx]
64 | m_j = Z_j.shape[1]
65 | c_j = d / (m_j * self.eps)
66 | Cs[j] = c_j * torch.inverse(I + c_j * Z_j @ Z_j.T)
67 | return Cs
68 |
69 | def preprocess(self, X):
70 | return F.normalize(X)
71 |
72 | def postprocess(self, X):
73 | return F.normalize(X)
74 |
75 |
--------------------------------------------------------------------------------
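
A sketch of one Vector layer in isolation (sizes and hyperparameters assumed; the repository's functional.normalize is taken as given; run from the repository root). The forward pass takes one step Z + eta * (E Z - clustered C_j Z terms) and re-projects onto the sphere:

import torch
from redunet import Vector

layer = Vector(eta=0.5, eps=0.1, lmbda=500, num_classes=2, dimensions=16)

Z = layer.preprocess(torch.randn(64, 16))  # unit-normalize the samples
y = torch.randint(0, 2, (64,))
layer.init(Z, y)                           # construct E, Cs, gam from the data
Z_out, y_approx = layer(Z, return_y=True)  # one ReduNet iteration
print(Z_out.shape, y_approx.shape)         # (64, 16) and (64, 1)
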
/redunet/modules.py:
--------------------------------------------------------------------------------
1 | from .layers.vector import Vector
2 | from .layers.fourier1d import Fourier1D
3 | from .layers.fourier2d import Fourier2D
4 | from .redunet import ReduNet
5 |
6 |
7 |
8 |
9 | def ReduNetVector(num_classes, num_layers, d, eta, eps, lmbda):
10 | redunet = ReduNet(
11 | *[Vector(eta, eps, lmbda, num_classes, d) for _ in range(num_layers)]
12 | )
13 | return redunet
14 |
15 | def ReduNet1D(num_classes, num_layers, channels, timesteps, eta, eps, lmbda):
16 | redunet = ReduNet(
17 | *[Fourier1D(eta, eps, lmbda, num_classes, (channels, timesteps)) for _ in range(num_layers)]
18 | )
19 | return redunet
20 |
21 | def ReduNet2D(num_classes, num_layers, channels, height, width, eta, eps, lmbda):
22 | redunet = ReduNet(
23 | *[Fourier2D(eta, eps, lmbda, num_classes, (channels, height, width)) for _ in range(num_layers)]
24 | )
25 | return redunet
--------------------------------------------------------------------------------
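
Each constructor above simply stacks identical layers into a ReduNet. For instance (a sketch; all sizes and hyperparameters assumed):

from redunet import ReduNet1D

# 10 classes, 3 layers over 8-channel signals of length 16
net = ReduNet1D(num_classes=10, num_layers=3, channels=8, timesteps=16,
                eta=0.5, eps=0.1, lmbda=500)
print(len(net))  # 3 stacked Fourier1D layers inside an nn.Sequential
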
/redunet/multichannel_weight.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch
3 | import torch.nn as nn
4 | from opt_einsum import contract
5 |
6 |
7 |
8 | # A weight matrix with multichannel capabilities. Inputs that multiply this weight
9 | # should have their channel dimension directly after the batch dimension; there is no limit to the number of channels.
10 | class MultichannelWeight(nn.Module):
11 | def __init__(self, channels, *dimension, dtype=torch.complex64):
12 | super(MultichannelWeight, self).__init__()
13 | self.weight = nn.Parameter(torch.randn(channels, channels, *dimension, dtype=dtype))
14 | self.shape = self.weight.shape
15 | self.dtype = dtype
16 |
17 | def __getitem__(self, item):
18 | return self.weight[item]
19 |
20 | def forward(self, V):
21 | return contract("bi...,ih...->bh...", V.type(self.dtype), self.weight.conj())
22 |
--------------------------------------------------------------------------------
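
The contraction "bi...,ih...->bh..." sums over the input's channel dimension (dim 1) against the conjugated weight, broadcasting over any trailing dimensions. A shape sketch with assumed sizes:

import torch
from redunet import MultichannelWeight

w = MultichannelWeight(8, 16)  # weight of shape (8, 8, 16), complex64 by default
V = torch.randn(32, 8, 16)     # (batch, channels, timesteps)
out = w(V)                     # cast to complex, contract over channels
print(out.shape, out.dtype)    # (32, 8, 16), torch.complex64
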
/redunet/projections/lift.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.distributions as distributions
5 |
6 |
7 |
8 | class Lift(nn.Module):
9 | def __init__(self, in_channel, out_channel, kernel_size, init_mode='gaussian1.0', stride=1, trainable=False, relu=True, seed=0):
10 | super(Lift, self).__init__()
11 | self.in_channel = in_channel
12 | self.out_channel = out_channel
13 | self.kernel_size = kernel_size
14 | self.init_mode = init_mode
15 | self.stride = stride
16 | self.trainable = trainable
17 | self.relu = relu
18 | self.seed = seed
19 |
20 | def set_weight(self, init_mode, size, trainable):
21 | torch.manual_seed(self.seed)
22 | if init_mode == 'gaussian0.1':
23 | p = distributions.normal.Normal(loc=0, scale=0.1)
24 | elif init_mode == 'gaussian1.0':
25 | p = distributions.normal.Normal(loc=0, scale=1.)
26 | elif init_mode == 'gaussian5.0':
27 | p = distributions.normal.Normal(loc=0, scale=5.)
28 | elif init_mode == 'uniform0.1':
29 | p = distributions.uniform.Uniform(-0.1, 0.1)
30 | elif init_mode == 'uniform0.5':
31 | p = distributions.uniform.Uniform(-0.5, 0.5)
32 | else:
33 | raise NameError(f'No such kernel: {init_mode}')
34 | kernel = p.sample(size).type(torch.float)
35 | self.kernel = nn.Parameter(kernel, requires_grad=trainable)
36 |
37 |
38 | class Lift1D(Lift):
39 | def __init__(self, in_channel, out_channel, kernel_size, init_mode='gaussian1.0', stride=1, trainable=False, relu=True, seed=0):
40 | super(Lift1D, self).__init__(in_channel, out_channel, kernel_size, init_mode, stride, trainable, relu, seed)
41 | self.size = (out_channel, in_channel, kernel_size)
42 | self.set_weight(init_mode, self.size, trainable)
43 |
44 | def forward(self, Z):
45 | Z = F.pad(Z, (0, self.kernel_size-1), 'circular')
46 | out = F.conv1d(Z, self.kernel, stride=self.stride)
47 | if self.relu:
48 | return F.relu(out)
49 | return out
50 |
51 |
52 | class Lift2D(Lift):
53 | def __init__(self, in_channel, out_channel, kernel_size, init_mode='gaussian1.0', stride=1, trainable=False, relu=True, seed=0):
54 | super(Lift2D, self).__init__(in_channel, out_channel, kernel_size, init_mode, stride, trainable, relu, seed)
55 | self.size = (out_channel, in_channel, kernel_size, kernel_size)
56 | self.set_weight(init_mode, self.size, trainable)
57 |
58 | def forward(self, Z):
59 | kernel = self.kernel.to(Z.device)
60 | Z = F.pad(Z, (0, self.kernel_size-1, 0, self.kernel_size-1), 'circular')
61 | out = F.conv2d(Z, kernel, stride=self.stride)
62 | if self.relu:
63 | return F.relu(out)
64 | return out
65 |
--------------------------------------------------------------------------------
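
Lift2D circularly pads the input and convolves it with a fixed random kernel, so the spatial size is preserved while channels are lifted. A shape sketch on MNIST-sized inputs (sizes assumed):

import torch
from redunet import Lift2D

lift = Lift2D(in_channel=1, out_channel=8, kernel_size=3)  # fixed gaussian1.0 kernel
X = torch.randn(4, 1, 28, 28)
Z = lift(X)     # circular pad, conv2d, then ReLU
print(Z.shape)  # (4, 8, 28, 28): channels lifted, spatial size unchanged
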
/redunet/redunet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from loss import compute_mcr2
4 | from .layers.redulayer import ReduLayer
5 |
6 |
7 |
8 | class ReduNet(nn.Sequential):
9 |     # ReduNet Architecture. This class inherits from nn.Sequential,
10 |     # hence it can be used for stacking layers of torch.nn.Module objects.
11 |
12 | def __init__(self, *modules):
13 | super(ReduNet, self).__init__(*modules)
14 | self._init_loss()
15 |
16 | def init(self, inputs, labels):
17 | # Initialize the network. Using inputs and labels, it constructs
18 | # the parameters E and Cs throughout each ReduLayer.
19 | with torch.no_grad():
20 | return self.forward(inputs,
21 | labels,
22 | init=True,
23 | loss=True)
24 |
25 | def update(self, inputs, labels, tau=0.1):
26 | # Update the network parameters E and Cs by
27 | # performing a moving average.
28 | with torch.no_grad():
29 | return self.forward(inputs,
30 | labels,
31 | tau=tau,
32 | update=True,
33 | loss=True)
34 |
35 | def zero(self):
36 | # Set every network parameters E and Cs to a zero matrix.
37 | with torch.no_grad():
38 | for module in self:
39 | if isinstance(module, ReduLayer):
40 | module.zero()
41 | return self
42 |
43 | def batch_forward(self, inputs, batch_size=1000, loss=False, cuda=True, device=None):
44 | # Perform forward pass in batches.
45 | outputs = []
46 | for i in range(0, inputs.shape[0], batch_size):
47 | print('batch:', i, end='\r')
48 | batch_inputs = inputs[i:i+batch_size]
49 | if device is not None:
50 | batch_inputs = batch_inputs.to(device)
51 | elif cuda:
52 | batch_inputs = batch_inputs.cuda()
53 | batch_outputs = self.forward(batch_inputs, loss=loss)
54 | outputs.append(batch_outputs.cpu())
55 | return torch.cat(outputs)
56 |
57 | def forward(self,
58 | inputs,
59 | labels=None,
60 | tau=0.1,
61 | init=False,
62 | update=False,
63 | loss=False):
64 |
65 | self._init_loss()
66 | self._inReduBlock = False
67 |
68 | for layer_i, module in enumerate(self):
69 | # preprocess for redunet layers
70 | if self._isEnterReduBlock(layer_i, module):
71 | inputs = module.preprocess(inputs)
72 | self._inReduBlock = True
73 |
74 | # If init is set to True, then initialize
75 | # layer using inputs and labels
76 | if init and self._isReduLayer(module):
77 | module.init(inputs, labels)
78 |
79 |             # If update is set to True, then update the layer
80 |             # parameters via a moving average using inputs and labels
81 | if update and self._isReduLayer(module):
82 | module.update(inputs, labels, tau)
83 |
84 | # Perform a forward pass
85 | if self._isReduLayer(module):
86 | inputs, preds = module(inputs, return_y=True)
87 | else:
88 | inputs = module(inputs)
89 |
90 | # compute loss for redunet layer
91 | if loss and isinstance(module, ReduLayer):
92 | losses = compute_mcr2(inputs, preds, module.eps)
93 | self._append_loss(layer_i, *losses)
94 |
95 | # postprocess for redunet layers
96 | if self._isExitReduBlock(layer_i, module):
97 | inputs = module.postprocess(inputs)
98 | self._inReduBlock = False
99 | return inputs
100 |
101 |
102 | def get_loss(self):
103 | return self.losses
104 |
105 | def _init_loss(self):
106 | self.losses = {'layer': [], 'loss_total':[], 'loss_expd': [], 'loss_comp': []}
107 |
108 | def _append_loss(self, layer_i, loss_total, loss_expd, loss_comp):
109 | self.losses['layer'].append(layer_i)
110 | self.losses['loss_total'].append(loss_total)
111 | self.losses['loss_expd'].append(loss_expd)
112 | self.losses['loss_comp'].append(loss_comp)
113 | print(f"{layer_i} | {loss_total:.6f} {loss_expd:.6f} {loss_comp:.6f}")
114 |
115 | def _isReduLayer(self, module):
116 | return isinstance(module, ReduLayer)
117 |
118 | def _isEnterReduBlock(self, _, module):
119 | # my first encounter of ReduLayer
120 | if not self._inReduBlock and self._isReduLayer(module):
121 | return True
122 | return False
123 |
124 | def _isExitReduBlock(self, layer_i, _):
125 | # I am in ReduBlock and I am the last layer of the network
126 |         if len(self) - 1 == layer_i and self._inReduBlock:
127 | return True
128 | # I am in ReduBlock and I am the last ReduLayer
129 | if self._inReduBlock and not self._isReduLayer(self[layer_i+1]):
130 | return True
131 | return False
132 |
--------------------------------------------------------------------------------
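
Putting the pieces together, a hedged end-to-end sketch of the forward-constructed workflow (sizes and hyperparameters assumed; run from the repository root). Note that init prints per-layer losses as it goes:

import torch
from redunet import ReduNetVector

net = ReduNetVector(num_classes=2, num_layers=5, d=16, eta=0.5, eps=0.1, lmbda=500)

X = torch.randn(64, 16)
y = torch.randint(0, 2, (64,))

with torch.no_grad():
    Z = net.init(X, y)               # construct E and Cs layer by layer
    losses = net.get_loss()          # per-layer Delta R, R, R^c
    Z_new = net(torch.randn(8, 16))  # plain forward pass on unseen data
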
/requirements.txt:
--------------------------------------------------------------------------------
1 | # This file may be used to create an environment using:
2 | # $ conda create --name <env> --file <this file>
3 | # platform: osx-64
4 | appnope=0.1.0=py37_0
5 | backcall=0.2.0=py_0
6 | blas=1.0=mkl
7 | brotlipy=0.7.0=py37h9ed2024_1003
8 | bzip2=1.0.8=h1de35cc_0
9 | ca-certificates=2021.1.19=hecd8cb5_1
10 | certifi=2020.12.5=py37hecd8cb5_0
11 | cffi=1.14.4=py37h2125817_0
12 | chardet=4.0.0=py37hecd8cb5_1003
13 | cryptography=3.3.1=py37hbcfaee0_0
14 | cycler=0.10.0=py37_0
15 | decorator=4.4.2=py_0
16 | ffmpeg=4.2.2=h97e5cf8_0
17 | freetype=2.10.4=ha233b18_0
18 | gettext=0.19.8.1=hb0f4f8b_2
19 | gmp=6.1.2=hb37e062_1
20 | gnutls=3.6.5=h91ad68e_1002
21 | idna=2.10=pyhd3eb1b0_0
22 | intel-openmp=2019.4=233
23 | ipykernel=5.3.4=py37h5ca1d4c_0
24 | ipython=7.18.1=py37h5ca1d4c_0
25 | ipython_genutils=0.2.0=py37_0
26 | jedi=0.17.2=py37_0
27 | joblib=0.17.0=py_0
28 | jpeg=9b=he5867d9_2
29 | jupyter_client=6.1.7=py_0
30 | jupyter_core=4.6.3=py37_0
31 | kiwisolver=1.2.0=py37h04f5b5a_0
32 | lame=3.100=h1de35cc_0
33 | lcms2=2.11=h92f6f08_0
34 | libcxx=10.0.0=1
35 | libedit=3.1.20191231=h1de35cc_1
36 | libffi=3.3=hb1e8313_2
37 | libgfortran=3.0.1=h93005f0_2
38 | libiconv=1.16=h1de35cc_0
39 | libopus=1.3.1=h1de35cc_0
40 | libpng=1.6.37=ha441bb4_0
41 | libsodium=1.0.18=h1de35cc_0
42 | libtiff=4.1.0=hcb84e12_1
43 | libuv=1.40.0=haf1e3a3_0
44 | libvpx=1.7.0=h378b8a2_0
45 | llvm-openmp=10.0.0=h28b9765_0
46 | lz4-c=1.9.2=h79c402e_3
47 | matplotlib=3.3.2=0
48 | matplotlib-base=3.3.2=py37h181983e_0
49 | mkl=2019.4=233
50 | mkl-service=2.3.0=py37hfbe908c_0
51 | mkl_fft=1.2.0=py37hc64f4ea_0
52 | mkl_random=1.1.1=py37h959d312_0
53 | ncurses=6.2=h0a44026_1
54 | nettle=3.4.1=h3018a27_0
55 | ninja=1.10.1=py37h879752b_0
56 | numpy=1.19.1=py37h3b9f5b6_0
57 | numpy-base=1.19.1=py37hcfb5961_0
58 | olefile=0.46=py37_0
59 | opencv-python=4.4.0.44=pypi_0
60 | openh264=2.1.0=hd9629dc_0
61 | openssl=1.1.1k=h9ed2024_0
62 | opt_einsum=3.1.0=py_0
63 | pandas=1.1.3=py37hb1e8313_0
64 | parso=0.7.0=py_0
65 | pexpect=4.8.0=py37_0
66 | pickleshare=0.7.5=py37_0
67 | pillow=8.0.0=py37h1a82f1a_0
68 | pip=20.2.4=py37_0
69 | prompt-toolkit=3.0.8=py_0
70 | ptyprocess=0.6.0=py37_0
71 | pycparser=2.20=py_2
72 | pygments=2.7.1=py_0
73 | pyopenssl=20.0.1=pyhd3eb1b0_1
74 | pyparsing=2.4.7=py_0
75 | pysocks=1.7.1=py37hecd8cb5_0
76 | python=3.7.9=h26836e1_0
77 | python-dateutil=2.8.1=py_0
78 | pytorch=1.8.0=py3.7_0
79 | pytz=2020.1=py_0
80 | pyzmq=19.0.2=py37hb1e8313_1
81 | readline=8.0=h1de35cc_0
82 | requests=2.25.1=pyhd3eb1b0_0
83 | scikit-learn=0.23.2=py37h959d312_0
84 | scipy=1.5.2=py37h912ce22_0
85 | setuptools=50.3.0=py37h0dc7051_1
86 | six=1.15.0=py_0
87 | sqlite=3.33.0=hffcf06c_0
88 | threadpoolctl=2.1.0=pyh5ca1d4c_0
89 | tk=8.6.10=hb0a8c7a_0
90 | torchaudio=0.8.0=py37
91 | torchvision=0.9.0=py37_cpu
92 | tornado=6.0.4=py37h1de35cc_1
93 | tqdm=4.50.2=py_0
94 | traitlets=5.0.5=py_0
95 | typing_extensions=3.7.4.3=py_0
96 | urllib3=1.26.3=pyhd3eb1b0_0
97 | wcwidth=0.2.5=py_0
98 | wheel=0.35.1=py_0
99 | x264=1!157.20191217=h1de35cc_0
100 | xlrd=1.2.0=py37_0
101 | xz=5.2.5=h1de35cc_0
102 | zeromq=4.3.3=hb1e8313_3
103 | zlib=1.2.11=h1de35cc_3
104 | zstd=1.4.5=h41d2c2f_0
105 |
--------------------------------------------------------------------------------
/train_forward.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 |
8 | from redunet import *
9 | import evaluate
10 | import load as L
11 | import functional as F
12 | import utils
13 | import plot
14 |
15 |
16 |
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument('--data', type=str, required=True, help='choice of dataset')
19 | parser.add_argument('--arch', type=str, required=True, help='choice of architecture')
20 | parser.add_argument('--samples', type=int, required=True, help="number of samples per update")
21 | parser.add_argument('--tail', type=str, default='', help='extra information to add to folder name')
22 | parser.add_argument('--save_dir', type=str, default='./saved_models/', help='base directory for saving.')
23 | parser.add_argument('--data_dir', type=str, default='./data/', help='base directory for datasets.')
24 | args = parser.parse_args()
25 |
26 | ## CUDA
27 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
28 |
29 | ## Model Directory
30 | model_dir = os.path.join(args.save_dir,
31 | 'forward',
32 | f'{args.data}+{args.arch}',
33 | f'samples{args.samples}'
34 | f'{args.tail}')
35 | os.makedirs(model_dir, exist_ok=True)
36 | utils.save_params(model_dir, vars(args))
37 | print(model_dir)
38 |
39 | ## Data
40 | trainset, testset, num_classes = L.load_dataset(args.data, data_dir=args.data_dir)
41 | X_train, y_train = F.get_samples(trainset, args.samples)
42 | X_train, y_train = X_train.to(device), y_train.to(device)
43 |
44 | ## Architecture
45 | net = L.load_architecture(args.data, args.arch)
46 | net = net.to(device)
47 |
48 | ## Training
49 | with torch.no_grad():
50 | Z_train = net.init(X_train, y_train)
51 | losses_train = net.get_loss()
52 | X_train, Z_train = F.to_cpu(X_train, Z_train)
53 |
54 | ## Saving
55 | utils.save_loss(model_dir, 'train', losses_train)
56 | utils.save_ckpt(model_dir, 'model', net)
57 |
58 | ## Plotting
59 | plot.plot_loss_mcr(model_dir, 'train')
60 |
61 | print(model_dir)
62 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import numpy as np
4 | import pandas as pd
5 | import torch
6 |
7 |
8 |
9 | def sort_dataset(data, labels, classes, stack=False):
10 | """Sort dataset based on classes.
11 |
12 | Parameters:
13 | data (np.ndarray): data array
14 | labels (np.ndarray): one dimensional array of class labels
15 | classes (int): number of classes
16 |         stack (bool): combine sorted data into one numpy array
17 |
18 | Return:
19 | sorted data (np.ndarray), sorted_labels (np.ndarray)
20 |
21 | """
22 | if type(classes) == int:
23 | classes = np.arange(classes)
24 | sorted_data = []
25 | sorted_labels = []
26 | for c in classes:
27 | idx = (labels == c)
28 | data_c = data[idx]
29 | labels_c = labels[idx]
30 | sorted_data.append(data_c)
31 | sorted_labels.append(labels_c)
32 | if stack:
33 | if isinstance(data, np.ndarray):
34 | sorted_data = np.vstack(sorted_data)
35 | sorted_labels = np.hstack(sorted_labels)
36 | else:
37 | sorted_data = torch.stack(sorted_data)
38 | sorted_labels = torch.cat(sorted_labels)
39 | return sorted_data, sorted_labels
40 |
41 | def save_params(model_dir, params, name='params', name_prefix=None):
42 | """Save params to a .json file. Params is a dictionary of parameters."""
43 | if name_prefix:
44 | model_dir = os.path.join(model_dir, name_prefix)
45 | os.makedirs(model_dir, exist_ok=True)
46 | path = os.path.join(model_dir, f'{name}.json')
47 | with open(path, 'w') as f:
48 | json.dump(params, f, indent=2, sort_keys=True)
49 |
50 | def load_params(model_dir):
51 | """Load params.json file in model directory and return dictionary."""
52 | path = os.path.join(model_dir, "params.json")
53 | with open(path, 'r') as f:
54 | _dict = json.load(f)
55 | return _dict
56 |
57 | def update_params(model_dir, dict_):
58 | params = load_params(model_dir)
59 | for key in dict_.keys():
60 | params[key] = dict_[key]
61 | save_params(model_dir, params)
62 | return params
63 |
64 | def create_csv(model_dir, filename, headers):
65 | """Create .csv file with filename in model_dir, with headers as the first line
66 | of the csv. """
67 | csv_path = os.path.join(model_dir, f'{filename}.csv')
68 | if os.path.exists(csv_path):
69 | os.remove(csv_path)
70 | with open(csv_path, 'w+') as f:
71 | f.write(','.join(map(str, headers)))
72 | return csv_path
73 |
74 | def append_csv(model_dir, filename, entries):
75 | """Save entries to csv. Entries is list of numbers. """
76 | csv_path = os.path.join(model_dir, f'{filename}.csv')
77 | assert os.path.exists(csv_path), 'CSV file is missing in project directory.'
78 | with open(csv_path, 'a') as f:
79 | f.write('\n'+','.join(map(str, entries)))
80 |
81 | def save_loss(model_dir, name, loss_dict):
82 | save_dir = os.path.join(model_dir, "loss")
83 | os.makedirs(save_dir, exist_ok=True)
84 | file_path = os.path.join(save_dir, "{}.csv".format(name))
85 | pd.DataFrame(loss_dict).to_csv(file_path)
86 |
87 | def save_features(model_dir, name, features, labels, layer=None):
88 | save_dir = os.path.join(model_dir, "features")
89 | os.makedirs(save_dir, exist_ok=True)
90 | np.save(os.path.join(save_dir, f"{name}_features.npy"), features)
91 | np.save(os.path.join(save_dir, f"{name}_labels.npy"), labels)
92 |
93 | def save_ckpt(model_dir, name, net):
94 | """Save PyTorch checkpoint to model_dir/checkpoints/ directory in model directory. """
95 | os.makedirs(os.path.join(model_dir, 'checkpoints'), exist_ok=True)
96 | torch.save(net.state_dict(), os.path.join(model_dir, 'checkpoints',
97 | '{}.pt'.format(name)))
98 |
99 | def load_ckpt(model_dir, name, net, eval_=True):
100 | """Load checkpoint from model directory. Checkpoints should be stored in
101 |     `model_dir/checkpoints/`.
102 | """
103 | ckpt_path = os.path.join(model_dir, 'checkpoints', f'{name}.pt')
104 | print('Loading checkpoint: {}'.format(ckpt_path))
105 | state_dict = torch.load(ckpt_path)
106 | net.load_state_dict(state_dict)
107 | del state_dict
108 | if eval_:
109 | net.eval()
110 | return net
111 |
112 |
--------------------------------------------------------------------------------
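
The CSV helpers above are designed as a pair: create_csv writes only the header line, and append_csv prepends a newline before each row, yielding a well-formed file. A small sketch (the directory name is hypothetical):

import os
import utils

model_dir = './saved_models/demo'  # hypothetical directory
os.makedirs(model_dir, exist_ok=True)
utils.create_csv(model_dir, 'acc', ['epoch', 'acc_train', 'acc_test'])
utils.append_csv(model_dir, 'acc', [0, 0.91, 0.88])
utils.append_csv(model_dir, 'acc', [1, 0.95, 0.90])
# acc.csv now contains:
# epoch,acc_train,acc_test
# 0,0.91,0.88
# 1,0.95,0.90
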