├── bispectral_networks ├── __init__.py ├── data │ ├── __init__.py │ ├── utils.py │ ├── data_loader.py │ ├── datasets.py │ └── transforms.py ├── nn │ ├── __init__.py │ ├── __pycache__ │ │ ├── layers.cpython-310.pyc │ │ ├── layers.cpython-39.pyc │ │ ├── model.cpython-310.pyc │ │ ├── model.cpython-39.pyc │ │ ├── __init__.cpython-310.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── functional.cpython-310.pyc │ │ └── functional.cpython-39.pyc │ ├── functional.py │ ├── model.py │ └── layers.py ├── analysis │ ├── __init__.py │ ├── knn.py │ ├── plotting.py │ └── adversary.py ├── config.py ├── normalizer.py ├── loss.py ├── logger.py └── trainer.py ├── logs ├── rotation_model │ ├── config.pt │ ├── checkpoints │ │ └── checkpoint_551.pt │ ├── events.out.tfevents.1653530727.kzbdd6.9000.0 │ ├── val_loss │ │ └── events.out.tfevents.1653530885.kzbdd6.9000.5 │ ├── train_loss │ │ └── events.out.tfevents.1653530885.kzbdd6.9000.1 │ ├── val_rep_loss │ │ └── events.out.tfevents.1653530885.kzbdd6.9000.6 │ ├── train_rep_loss │ │ └── events.out.tfevents.1653530885.kzbdd6.9000.2 │ ├── val_recon_loss │ │ └── events.out.tfevents.1653530885.kzbdd6.9000.7 │ ├── val_total_loss │ │ └── events.out.tfevents.1653530885.kzbdd6.9000.8 │ ├── train_recon_loss │ │ └── events.out.tfevents.1653530885.kzbdd6.9000.3 │ └── train_total_loss │ │ └── events.out.tfevents.1653530885.kzbdd6.9000.4 └── translation_model │ ├── config.pt │ ├── checkpoints │ └── checkpoint_232.pt │ ├── events.out.tfevents.1653534595.kzbdd6.24358.0 │ ├── val_loss │ └── events.out.tfevents.1653534682.kzbdd6.24358.5 │ ├── train_loss │ └── events.out.tfevents.1653534682.kzbdd6.24358.1 │ ├── val_rep_loss │ └── events.out.tfevents.1653534682.kzbdd6.24358.6 │ ├── train_rep_loss │ └── events.out.tfevents.1653534682.kzbdd6.24358.2 │ ├── val_recon_loss │ └── events.out.tfevents.1653534682.kzbdd6.24358.7 │ ├── val_total_loss │ └── events.out.tfevents.1653534682.kzbdd6.24358.8 │ ├── train_recon_loss │ └── events.out.tfevents.1653534682.kzbdd6.24358.3 │ └── train_total_loss │ └── events.out.tfevents.1653534682.kzbdd6.24358.4 ├── notebooks └── figs │ ├── rotation │ ├── W_imag.pdf │ ├── W_real.pdf │ ├── adversary.pdf │ ├── equivariance-0.pdf │ ├── equivariance-1.pdf │ ├── equivariance-2.pdf │ ├── invariance-0.pdf │ ├── invariance-1.pdf │ ├── invariance-2.pdf │ ├── test_examples.pdf │ └── test_distance_matrix.pdf │ └── translation │ ├── W_imag.pdf │ ├── W_real.pdf │ ├── adversary.pdf │ ├── invariance-0.pdf │ ├── invariance-1.pdf │ ├── invariance-2.pdf │ ├── test_examples.pdf │ ├── equivariance-0.pdf │ ├── equivariance-1.pdf │ ├── equivariance-2.pdf │ └── test_distance_matrix.pdf ├── requirements.txt ├── setup.py ├── train.py ├── LICENSE ├── README.md └── configs ├── translation_experiment.py └── rotation_experiment.py /bispectral_networks/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /bispectral_networks/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /bispectral_networks/nn/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /bispectral_networks/analysis/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /logs/rotation_model/config.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sophiaas/bispectral-networks/HEAD/logs/rotation_model/config.pt -------------------------------------------------------------------------------- /logs/translation_model/config.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sophiaas/bispectral-networks/HEAD/logs/translation_model/config.pt -------------------------------------------------------------------------------- /notebooks/figs/rotation/W_imag.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sophiaas/bispectral-networks/HEAD/notebooks/figs/rotation/W_imag.pdf -------------------------------------------------------------------------------- /notebooks/figs/rotation/W_real.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sophiaas/bispectral-networks/HEAD/notebooks/figs/rotation/W_real.pdf -------------------------------------------------------------------------------- /notebooks/figs/rotation/adversary.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sophiaas/bispectral-networks/HEAD/notebooks/figs/rotation/adversary.pdf -------------------------------------------------------------------------------- /notebooks/figs/translation/W_imag.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sophiaas/bispectral-networks/HEAD/notebooks/figs/translation/W_imag.pdf -------------------------------------------------------------------------------- /notebooks/figs/translation/W_real.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sophiaas/bispectral-networks/HEAD/notebooks/figs/translation/W_real.pdf -------------------------------------------------------------------------------- /notebooks/figs/rotation/equivariance-0.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sophiaas/bispectral-networks/HEAD/notebooks/figs/rotation/equivariance-0.pdf -------------------------------------------------------------------------------- /notebooks/figs/rotation/equivariance-1.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sophiaas/bispectral-networks/HEAD/notebooks/figs/rotation/equivariance-1.pdf -------------------------------------------------------------------------------- /notebooks/figs/rotation/equivariance-2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sophiaas/bispectral-networks/HEAD/notebooks/figs/rotation/equivariance-2.pdf -------------------------------------------------------------------------------- /notebooks/figs/rotation/invariance-0.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sophiaas/bispectral-networks/HEAD/notebooks/figs/rotation/invariance-0.pdf -------------------------------------------------------------------------------- /notebooks/figs/rotation/invariance-1.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sophiaas/bispectral-networks/HEAD/notebooks/figs/rotation/invariance-1.pdf -------------------------------------------------------------------------------- /notebooks/figs/rotation/invariance-2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sophiaas/bispectral-networks/HEAD/notebooks/figs/rotation/invariance-2.pdf -------------------------------------------------------------------------------- /notebooks/figs/rotation/test_examples.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sophiaas/bispectral-networks/HEAD/notebooks/figs/rotation/test_examples.pdf -------------------------------------------------------------------------------- /notebooks/figs/translation/adversary.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sophiaas/bispectral-networks/HEAD/notebooks/figs/translation/adversary.pdf -------------------------------------------------------------------------------- /notebooks/figs/translation/invariance-0.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sophiaas/bispectral-networks/HEAD/notebooks/figs/translation/invariance-0.pdf -------------------------------------------------------------------------------- /notebooks/figs/translation/invariance-1.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sophiaas/bispectral-networks/HEAD/notebooks/figs/translation/invariance-1.pdf -------------------------------------------------------------------------------- /notebooks/figs/translation/invariance-2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sophiaas/bispectral-networks/HEAD/notebooks/figs/translation/invariance-2.pdf -------------------------------------------------------------------------------- /notebooks/figs/translation/test_examples.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sophiaas/bispectral-networks/HEAD/notebooks/figs/translation/test_examples.pdf -------------------------------------------------------------------------------- /notebooks/figs/translation/equivariance-0.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sophiaas/bispectral-networks/HEAD/notebooks/figs/translation/equivariance-0.pdf -------------------------------------------------------------------------------- /notebooks/figs/translation/equivariance-1.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sophiaas/bispectral-networks/HEAD/notebooks/figs/translation/equivariance-1.pdf -------------------------------------------------------------------------------- /notebooks/figs/translation/equivariance-2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sophiaas/bispectral-networks/HEAD/notebooks/figs/translation/equivariance-2.pdf -------------------------------------------------------------------------------- /logs/rotation_model/checkpoints/checkpoint_551.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sophiaas/bispectral-networks/HEAD/logs/rotation_model/checkpoints/checkpoint_551.pt -------------------------------------------------------------------------------- /notebooks/figs/rotation/test_distance_matrix.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sophiaas/bispectral-networks/HEAD/notebooks/figs/rotation/test_distance_matrix.pdf -------------------------------------------------------------------------------- /logs/translation_model/checkpoints/checkpoint_232.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sophiaas/bispectral-networks/HEAD/logs/translation_model/checkpoints/checkpoint_232.pt -------------------------------------------------------------------------------- /notebooks/figs/translation/test_distance_matrix.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sophiaas/bispectral-networks/HEAD/notebooks/figs/translation/test_distance_matrix.pdf -------------------------------------------------------------------------------- /bispectral_networks/nn/__pycache__/layers.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sophiaas/bispectral-networks/HEAD/bispectral_networks/nn/__pycache__/layers.cpython-310.pyc -------------------------------------------------------------------------------- /bispectral_networks/nn/__pycache__/layers.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sophiaas/bispectral-networks/HEAD/bispectral_networks/nn/__pycache__/layers.cpython-39.pyc -------------------------------------------------------------------------------- /bispectral_networks/nn/__pycache__/model.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sophiaas/bispectral-networks/HEAD/bispectral_networks/nn/__pycache__/model.cpython-310.pyc -------------------------------------------------------------------------------- /bispectral_networks/nn/__pycache__/model.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sophiaas/bispectral-networks/HEAD/bispectral_networks/nn/__pycache__/model.cpython-39.pyc -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | torch 3 | cplxmodule 4 | matplotlib 5 | jupyter 6 | tensorboard 7 | scikit-image 8 | pytorch_metric_learning 9 | pandas 10 | scipy 11 | seaborn -------------------------------------------------------------------------------- /bispectral_networks/nn/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sophiaas/bispectral-networks/HEAD/bispectral_networks/nn/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /bispectral_networks/nn/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sophiaas/bispectral-networks/HEAD/bispectral_networks/nn/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /bispectral_networks/nn/__pycache__/functional.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sophiaas/bispectral-networks/HEAD/bispectral_networks/nn/__pycache__/functional.cpython-310.pyc -------------------------------------------------------------------------------- /bispectral_networks/nn/__pycache__/functional.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sophiaas/bispectral-networks/HEAD/bispectral_networks/nn/__pycache__/functional.cpython-39.pyc -------------------------------------------------------------------------------- /logs/rotation_model/events.out.tfevents.1653530727.kzbdd6.9000.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sophiaas/bispectral-networks/HEAD/logs/rotation_model/events.out.tfevents.1653530727.kzbdd6.9000.0 -------------------------------------------------------------------------------- /logs/translation_model/events.out.tfevents.1653534595.kzbdd6.24358.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sophiaas/bispectral-networks/HEAD/logs/translation_model/events.out.tfevents.1653534595.kzbdd6.24358.0 -------------------------------------------------------------------------------- /logs/rotation_model/val_loss/events.out.tfevents.1653530885.kzbdd6.9000.5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sophiaas/bispectral-networks/HEAD/logs/rotation_model/val_loss/events.out.tfevents.1653530885.kzbdd6.9000.5 -------------------------------------------------------------------------------- /logs/rotation_model/train_loss/events.out.tfevents.1653530885.kzbdd6.9000.1: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sophiaas/bispectral-networks/HEAD/logs/rotation_model/train_loss/events.out.tfevents.1653530885.kzbdd6.9000.1 -------------------------------------------------------------------------------- /logs/rotation_model/val_rep_loss/events.out.tfevents.1653530885.kzbdd6.9000.6: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sophiaas/bispectral-networks/HEAD/logs/rotation_model/val_rep_loss/events.out.tfevents.1653530885.kzbdd6.9000.6 -------------------------------------------------------------------------------- /logs/translation_model/val_loss/events.out.tfevents.1653534682.kzbdd6.24358.5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sophiaas/bispectral-networks/HEAD/logs/translation_model/val_loss/events.out.tfevents.1653534682.kzbdd6.24358.5 -------------------------------------------------------------------------------- /logs/rotation_model/train_rep_loss/events.out.tfevents.1653530885.kzbdd6.9000.2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sophiaas/bispectral-networks/HEAD/logs/rotation_model/train_rep_loss/events.out.tfevents.1653530885.kzbdd6.9000.2 -------------------------------------------------------------------------------- /logs/rotation_model/val_recon_loss/events.out.tfevents.1653530885.kzbdd6.9000.7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sophiaas/bispectral-networks/HEAD/logs/rotation_model/val_recon_loss/events.out.tfevents.1653530885.kzbdd6.9000.7 -------------------------------------------------------------------------------- /logs/rotation_model/val_total_loss/events.out.tfevents.1653530885.kzbdd6.9000.8: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sophiaas/bispectral-networks/HEAD/logs/rotation_model/val_total_loss/events.out.tfevents.1653530885.kzbdd6.9000.8 -------------------------------------------------------------------------------- /logs/translation_model/train_loss/events.out.tfevents.1653534682.kzbdd6.24358.1: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sophiaas/bispectral-networks/HEAD/logs/translation_model/train_loss/events.out.tfevents.1653534682.kzbdd6.24358.1 -------------------------------------------------------------------------------- /logs/rotation_model/train_recon_loss/events.out.tfevents.1653530885.kzbdd6.9000.3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sophiaas/bispectral-networks/HEAD/logs/rotation_model/train_recon_loss/events.out.tfevents.1653530885.kzbdd6.9000.3 -------------------------------------------------------------------------------- /logs/rotation_model/train_total_loss/events.out.tfevents.1653530885.kzbdd6.9000.4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sophiaas/bispectral-networks/HEAD/logs/rotation_model/train_total_loss/events.out.tfevents.1653530885.kzbdd6.9000.4 -------------------------------------------------------------------------------- /logs/translation_model/val_rep_loss/events.out.tfevents.1653534682.kzbdd6.24358.6: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sophiaas/bispectral-networks/HEAD/logs/translation_model/val_rep_loss/events.out.tfevents.1653534682.kzbdd6.24358.6 -------------------------------------------------------------------------------- /logs/translation_model/train_rep_loss/events.out.tfevents.1653534682.kzbdd6.24358.2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sophiaas/bispectral-networks/HEAD/logs/translation_model/train_rep_loss/events.out.tfevents.1653534682.kzbdd6.24358.2 -------------------------------------------------------------------------------- /logs/translation_model/val_recon_loss/events.out.tfevents.1653534682.kzbdd6.24358.7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sophiaas/bispectral-networks/HEAD/logs/translation_model/val_recon_loss/events.out.tfevents.1653534682.kzbdd6.24358.7 -------------------------------------------------------------------------------- /logs/translation_model/val_total_loss/events.out.tfevents.1653534682.kzbdd6.24358.8: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sophiaas/bispectral-networks/HEAD/logs/translation_model/val_total_loss/events.out.tfevents.1653534682.kzbdd6.24358.8 -------------------------------------------------------------------------------- /logs/translation_model/train_recon_loss/events.out.tfevents.1653534682.kzbdd6.24358.3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sophiaas/bispectral-networks/HEAD/logs/translation_model/train_recon_loss/events.out.tfevents.1653534682.kzbdd6.24358.3 -------------------------------------------------------------------------------- /logs/translation_model/train_total_loss/events.out.tfevents.1653534682.kzbdd6.24358.4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sophiaas/bispectral-networks/HEAD/logs/translation_model/train_total_loss/events.out.tfevents.1653534682.kzbdd6.24358.4 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | setuptools.setup( 4 | name="bispectral-networks", 5 | version="0.0.1", 6 | packages=setuptools.find_packages(), 7 | classifiers=[ 8 | "Programming Language :: Python :: 3", 9 | "Operating System :: OS Independent", 10 | ], 11 | python_requires='>=3.6', 12 | ) -------------------------------------------------------------------------------- /bispectral_networks/nn/functional.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import torch 3 | from cplxmodule import Cplx 4 | from cplxmodule.nn import CplxParameter 5 | 6 | 7 | def linear(x, W, b=None): 8 | if type(x) == Cplx or type(x) == CplxParameter or type(W) == Cplx or type(W) == CplxParameter: 9 | re = F.linear(x.real, W.real) - F.linear(x.imag, W.imag) 10 | if b is not None: 11 | re = re + b.real 12 | im = F.linear(x.real, W.imag) + F.linear(x.imag, W.real) 13 | if b is not None: 14 | im = im + b.imag 15 | out = Cplx(re, im) 16 | else: 17 | out = F.linear(x, W, b) 18 | return out 19 | 20 | 21 | def linear_conjtx(x, W, b=None): 22 | re = F.linear(x.real.T, W.real) - F.linear(-x.imag.T, W.imag) 23 | if b is not None: 24 | re = re + b.real 25 | im = F.linear(x.real.T, W.imag) + F.linear(-x.imag.T, W.real) 26 | if b is not None: 27 | im = im + b.imag 28 | out = Cplx(re, im) 29 | return out 30 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from bispectral_networks.trainer import run_trainer 3 | 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument( 7 | "--config", 8 | type=str, 9 | help="Name of .py config file with no extension.", 10 | default="translation_experiment", 11 | ) 12 | parser.add_argument("--device", type=int, help="device to run on, -1 for cpu", default=-1) 13 | parser.add_argument( 14 | "--n_examples", type=int, help="number of data examples", default=5e6 15 | ) 16 | parser.add_argument("--seed", type=int, default=None) 17 | 18 | 19 | args = parser.parse_args() 20 | if args.device == -1: 21 | args.device = 'cpu' 22 | 23 | print("Running experiment on device {}...".format(args.device)) 24 | exec("from configs.{} import master_config, logger_config".format(args.config)) 25 | 26 | def run_wrapper(): 27 | run_trainer( 28 | master_config=master_config, 29 | logger_config=logger_config, 30 | device=args.device, 31 | n_examples=args.n_examples, 32 | seed=args.seed 33 | ) 34 | 35 | run_wrapper() 36 | -------------------------------------------------------------------------------- /bispectral_networks/config.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | 3 | def _get_default_args(func): 4 | signature = inspect.signature(func) 5 | return { 6 | k: v.default 7 | for k, v in signature.parameters.items() 8 | if v.default is not inspect.Parameter.empty 9 | } 10 | 11 | 12 | class Config(dict): 13 | def __init__(self, config): 14 | """ 15 | Takes in a dictionary config of the following form: 16 | config = { 17 | "type": Class, 18 | "params": { 19 | "param1": val, 20 | "param2": val 21 | } 22 | } 23 | """ 24 | config = self.fill_defaults(config) 25 | super().__init__(**config) 26 | self.__dict__ = self 27 | 28 | def fill_defaults(self, config): 29 | defaults = _get_default_args(config["type"]) 30 | for k, v in defaults.items(): 31 | if k not in config["params"]: 32 | config["params"][k] = v 33 | return config 34 | 35 | def build(self): 36 | return self["type"](**self["params"]) -------------------------------------------------------------------------------- /bispectral_networks/normalizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Normalizer(torch.nn.Module): 5 | def __init__(self, variables): 6 | super().__init__() 7 | self.name = "{}_{}".format("normalizer", str(variables)) 8 | self.variables = variables if type(variables) == list else [variables] 9 | 10 | def forward(self, variable_dict): 11 | with torch.no_grad(): 12 | self.normalize(variable_dict) 13 | 14 | def normalize(self, variable_dict): 15 | raise NotImplementedError 16 | 17 | 18 | 19 | class L2Normalizer(Normalizer): 20 | def __init__(self, variables): 21 | super().__init__(variables) 22 | 23 | def normalize(self, variable_dict): 24 | for v in self.variables: 25 | var = variable_dict[v + ".real"].data + 1j * variable_dict[v + ".imag"].data 26 | variable_dict[v + ".real"].data /= torch.linalg.norm( 27 | var, dim=1, keepdims=True 28 | ) 29 | variable_dict[v + ".imag"].data /= torch.linalg.norm( 30 | var, dim=1, keepdims=True 31 | ) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | -------------------------------------------------------------------------------- /bispectral_networks/analysis/knn.py: -------------------------------------------------------------------------------- 1 | import scipy 2 | import numpy as np 3 | 4 | 5 | def analyze_dist(dist, dataset, top_n=100): 6 | breakdown = np.zeros((dist.shape[0], 2)) 7 | total = 0 8 | for i, row in enumerate(dist): 9 | i_label = dataset.labels[i] 10 | top_idxs = np.argsort(row)[:top_n] 11 | # Include ties 12 | maxval = row[top_idxs[-1]] 13 | top_idxs = np.where(row <= maxval)[0] 14 | for j in top_idxs: 15 | if i == j: 16 | continue 17 | else: 18 | j_label = dataset.labels[j] 19 | if j_label == i_label: 20 | # Same orbit 21 | breakdown[i, 0] += 1 22 | else: 23 | # Other 24 | breakdown[i, 1] += 1 25 | 26 | breakdown_mean = breakdown.sum(axis=0) / (breakdown.shape[0]) 27 | breakdown_percent = breakdown_mean / breakdown_mean.sum() 28 | return breakdown, breakdown_percent 29 | 30 | 31 | 32 | def knn_analysis(model, dataset, n): 33 | model.eval() 34 | output, _ = model.forward(dataset.data.float()) 35 | output = output.detach().cpu() 36 | output_dist = scipy.spatial.distance_matrix(output, output) 37 | nn_dist, nn_dist_mean = analyze_dist(output_dist, dataset, n) 38 | return output, output_dist, nn_dist_mean 39 | -------------------------------------------------------------------------------- /bispectral_networks/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class LpDistance: 6 | def __init__(self, p=2, pairwise=False): 7 | self.p = p 8 | self.pairwise = pairwise 9 | 10 | def __call__(self, x1, x2): 11 | if x1.dtype == torch.complex64 or x1.dtype == torch.complex128: 12 | dtype = x1.real.dtype 13 | else: 14 | dtype = x1.dtype 15 | if self.pairwise: 16 | return torch.nn.functional.pairwise_distance(x1, x2, p=self.p) 17 | else: 18 | rows, cols = np.meshgrid(range(len(x1)), range(len(x2))) 19 | rows = rows.flatten() 20 | cols = cols.flatten() 21 | dmat = torch.zeros((len(x1), len(x2)), dtype=dtype, device=x1.device) 22 | distances = torch.nn.functional.pairwise_distance(x1[rows], x2[cols], p=self.p) 23 | dmat[rows, cols] = distances 24 | return dmat 25 | 26 | 27 | class OrbitCollapse(torch.nn.Module): 28 | 29 | def __init__(self, 30 | distance=None): 31 | super().__init__() 32 | self.distance = distance 33 | 34 | def forward(self, embeddings, labels): 35 | L = 0 36 | count = 0 37 | for i in labels.unique(): 38 | idx = torch.where(labels==i)[0] 39 | dmat = self.distance(embeddings[idx], embeddings[idx]) 40 | ut_idx = np.triu_indices(dmat.shape[0], k=1) 41 | distances = dmat[ut_idx] 42 | L += distances.sum() 43 | count += len(distances) 44 | L /= count 45 | return L -------------------------------------------------------------------------------- /bispectral_networks/data/utils.py: -------------------------------------------------------------------------------- 1 | from bispectral_networks.data.datasets import TransformDataset 2 | import torch 3 | import numpy as np 4 | 5 | 6 | def gen_dataset(config): 7 | """ 8 | Generate a TransformDataset from a config dictionary with the following 9 | structure: 10 | config = { 11 | "pattern": {"type": obj, "params": {}}, 12 | "transforms": { 13 | "0": {"type": obj, "params": {}}, 14 | "1": {"type": obj, "params": {}} 15 | } 16 | } 17 | The "type" parameter in each dictionary specifies an uninstantiated dataset 18 | or transform class. The "params" parameter specifies a dictionary containing 19 | the keyword arguments needed to instantiate the class. 20 | """ 21 | if "seed" in config: 22 | torch.manual_seed(config['seed']) 23 | np.random.seed(config['seed']) 24 | # Catch for datasets and transforms that have no parameters 25 | if "params" not in config["pattern"]: 26 | config["pattern"]["params"] = {} 27 | for t in config["transforms"]: 28 | if "params" not in config["transforms"][t]: 29 | config["transforms"][t]["params"] = {} 30 | 31 | # Instantiate pattern object 32 | pattern = config["pattern"]["type"](**config["pattern"]["params"]) 33 | 34 | # Instantiate transform objects 35 | transforms = [ 36 | config["transforms"][k]["type"](**config["transforms"][k]["params"]) 37 | for k in sorted(config["transforms"]) 38 | ] 39 | 40 | # Generate dataset 41 | dataset = TransformDataset(pattern, transforms) 42 | return dataset 43 | -------------------------------------------------------------------------------- /bispectral_networks/nn/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import OrderedDict 3 | import cplxmodule 4 | from bispectral_networks.nn.layers import ( 5 | Bispectral, 6 | RowNorm, 7 | CplxToComplex 8 | 9 | ) 10 | 11 | 12 | class BispectralEmbedding(torch.nn.Module): 13 | def __init__( 14 | self, 15 | size_in, 16 | hdim, 17 | field="complex", 18 | constrained=False, 19 | bias=False, 20 | device="cpu", 21 | projection=True, 22 | linear_out=False, 23 | weight_init=cplxmodule.nn.init.cplx_trabelsi_independent_, 24 | name="bispectral-embedding", 25 | ): 26 | 27 | super().__init__() 28 | self.size_in = size_in 29 | self.name = name 30 | self.hdim = hdim 31 | self.field = field 32 | self.constrained = constrained 33 | self.bias = bias 34 | self.device = device 35 | self.weight_init = weight_init 36 | self.projection = projection 37 | self.linear_out = linear_out 38 | self.build_layers() 39 | 40 | def build_layers(self): 41 | layers = [ 42 | Bispectral( 43 | self.size_in, 44 | self.hdim, 45 | weight_init=self.weight_init, 46 | device=self.device, 47 | ), 48 | RowNorm(), 49 | CplxToComplex() 50 | ] 51 | 52 | self.layers = torch.nn.ModuleList(layers) 53 | 54 | def forward(self, x, term=0): 55 | x, recon = self.layers[0].forward(x, return_inv=True) 56 | for layer in self.layers[1:]: 57 | x = layer.forward(x) 58 | return x, recon -------------------------------------------------------------------------------- /bispectral_networks/analysis/plotting.py: -------------------------------------------------------------------------------- 1 | from mpl_toolkits.axes_grid1 import ImageGrid 2 | import matplotlib.pyplot as plt 3 | from matplotlib import animation 4 | import numpy as np 5 | from IPython.display import HTML 6 | 7 | 8 | 9 | def image_grid(data, shape=(10,10), figsize=(10,10), cmap='Greys_r', share_range=True, interpolation=None, save_name=None): 10 | 11 | fig = plt.figure(figsize=figsize) 12 | grid = ImageGrid(fig, 111, # similar to subplot(111) 13 | nrows_ncols=shape, # creates 10x10 grid of axes 14 | axes_pad=0.1, # pad between axes in inch. 15 | ) 16 | 17 | if share_range: 18 | vmin = data.min() 19 | vmax = data.max() 20 | 21 | for ax, im in zip(grid, data): 22 | # Iterating over the grid returns the Axes. 23 | if share_range: 24 | ax.imshow(im, vmin=vmin, vmax=vmax, cmap=cmap, interpolation=interpolation) 25 | else: 26 | ax.imshow(im, cmap=cmap, interpolation=interpolation) 27 | ax.set_axis_off() 28 | 29 | if save_name is not None: 30 | plt.savefig(save_name) 31 | 32 | 33 | def animated_video(vid, interval=25, figsize=(5,5), cmap='Greys_r'): 34 | def init(): 35 | return (im,) 36 | 37 | def animate(frame): 38 | im.set_data(frame) 39 | return (im,) 40 | 41 | fig, ax = plt.subplots(figsize=figsize) 42 | im = ax.imshow(np.zeros(vid[0].shape), vmin=np.min(vid), vmax=np.max(vid), cmap=cmap); 43 | plt.axis('off') 44 | anim = animation.FuncAnimation(fig, animate, init_func=init, 45 | frames=vid, interval=interval, blit=True) 46 | plt.close() 47 | return HTML(anim.to_jshtml()) 48 | 49 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Bispectral Neural Networks 2 | 3 | This repository is the official implementation of Bispectral Neural Networks. 4 | 5 | ## Installation 6 | 7 | To install the requirements and package, run: 8 | 9 | ``` 10 | pip install -r requirements.txt 11 | python setup.py install 12 | ``` 13 | 14 | ## Datasets 15 | 16 | To download the datasets, run: 17 | 18 | ``` 19 | pip install gdown 20 | gdown 10w3fKdO0eWEe2KxZxpf8YFndXdCNNR8b 21 | unzip datasets.zip 22 | rm datasets.zip 23 | ``` 24 | 25 | 26 | If your machine doesn't have wget, follow these steps: 27 | 1. Download the zip file [here](https://drive.google.com/file/d/10w3fKdO0eWEe2KxZxpf8YFndXdCNNR8b/view?usp=sharing). 28 | 2. Place the file in the top node of this directory, i.e. in `bispectral-networks/`. 29 | 3. Run: 30 | ``` 31 | unzip datasets.zip 32 | rm -r datasets.zip 33 | ``` 34 | 35 | ## Training 36 | 37 | To train the models in the paper, run the following commands. 38 | 39 | ``` 40 | python train.py --config rotation_experiment 41 | python train.py --config translation_experiment 42 | ``` 43 | 44 | To run on GPU, add the following argument, with the integer specifying the device number, i.e.: 45 | 46 | 47 | ``` 48 | --device 0 49 | ``` 50 | 51 | The full set of hyperparameters and training configurations are specified in the config files in the ```configs/``` folder. 52 | 53 | To view learning curves in Tensorboard, run: 54 | ``` 55 | tensorboard --logdir logs/ 56 | ``` 57 | 58 | ## Pre-trained Models 59 | 60 | The pre-trained models are included in the repo, in the following locations: 61 | 62 | ``` 63 | logs/rotation_model/ 64 | logs/translation_model/ 65 | ``` 66 | 67 | 68 | ## Results and Figures 69 | 70 | All results and figures from the paper are generated in the Jupyter notebooks located at: 71 | 72 | ``` 73 | notebooks/rotation_experiment_analysis.ipynb 74 | notebooks/translation_experiment_analysis.ipynb 75 | ``` 76 | 77 | ## License 78 | 79 | This repository is licensed under the MIT License. 80 | 81 | -------------------------------------------------------------------------------- /bispectral_networks/analysis/adversary.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class BasicGradientDescent(torch.nn.Module): 6 | def __init__( 7 | self, 8 | model, 9 | target_image, 10 | initial_image=None, 11 | distance_fn=torch.nn.functional.pairwise_distance, 12 | margin=1.0, 13 | pnorm=2, 14 | optimizer=torch.optim.Adam, 15 | scheduler=torch.optim.lr_scheduler.ReduceLROnPlateau, 16 | lr=0.001, 17 | save_interval=100, 18 | print_interval=1000, 19 | device="cpu", 20 | ): 21 | 22 | super().__init__() 23 | 24 | self.device = device 25 | self.distance_fn = distance_fn 26 | self.pnorm = pnorm 27 | self.margin = margin 28 | 29 | self.save_interval = save_interval 30 | self.print_interval = print_interval 31 | 32 | self.model = model 33 | for p in self.model.parameters(): 34 | p.requires_grad = False 35 | 36 | self.target_image = target_image.to(device) 37 | 38 | self.target_embedding, _ = model(target_image.to(device)) 39 | if len(self.target_embedding.shape) < 2: 40 | self.target_embedding = self.target_embedding.unsqueeze(0) 41 | self.target_embedding = self.target_embedding.detach() 42 | 43 | 44 | if initial_image is None: 45 | self.x_ = torch.nn.Parameter(torch.tensor(np.random.normal(loc=0, scale=1.0, size=target_image.shape)).to(device)) 46 | else: 47 | self.x_ = torch.nn.Parameter(torch.tensor(initial_image)).to(device) 48 | 49 | self.optimizer = optimizer(self.parameters(), lr=lr) 50 | self.scheduler = scheduler(self.optimizer) 51 | 52 | self.history = [] 53 | self.d_history = [] 54 | 55 | def train(self, max_iter=1000): 56 | i = 0 57 | within_margin = False 58 | 59 | while not (within_margin or i == max_iter): 60 | embedding, _ = self.model(self.x_) 61 | if len(embedding.shape) < 2: 62 | embedding = embedding.unsqueeze(0) 63 | embedding = embedding.type(self.target_embedding.dtype) 64 | d = (self.distance_fn(self.target_embedding, embedding)).mean() 65 | 66 | if d < self.margin: 67 | within_margin = True 68 | 69 | else: 70 | self.optimizer.zero_grad() 71 | d.backward() 72 | self.optimizer.step() 73 | i += 1 74 | 75 | if i % self.save_interval == 0: 76 | self.history.append(self.x_.cpu().detach().clone().numpy()) 77 | self.d_history.append(d.cpu().detach().numpy()) 78 | 79 | if i % self.print_interval == 0: 80 | print("Iter: {} | Distance: {}".format(i, d.detach())) 81 | self.scheduler.step(d) 82 | 83 | print("Final Distance: {}".format(d)) 84 | 85 | if not within_margin: 86 | print("Did not reach margin.") 87 | 88 | return self.x_.detach(), self.target_embedding.detach(), embedding.detach() 89 | -------------------------------------------------------------------------------- /bispectral_networks/nn/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import cplxmodule 4 | from cplxmodule import Cplx 5 | from cplxmodule.nn import CplxParameter 6 | from .functional import linear, linear_conjtx 7 | 8 | 9 | class RowNorm(torch.nn.Module): 10 | def forward(self, x): 11 | x = x - torch.mean(x.real + 1j * x.imag, axis=-1, keepdim=True) 12 | x = x / torch.linalg.norm(x.real + 1j * x.imag, axis=-1, keepdim=True) 13 | return Cplx(x.real, x.imag) 14 | 15 | 16 | class CplxToComplex(torch.nn.Module): 17 | def forward(self, x): 18 | return x.real + 1j * x.imag 19 | 20 | 21 | class Bispectral(torch.nn.Module): 22 | def __init__( 23 | self, 24 | size_in, 25 | size_out, 26 | weight_init=cplxmodule.nn.init.cplx_trabelsi_independent_, 27 | device="cpu", 28 | ): 29 | super().__init__() 30 | 31 | self.size_in, self.size_out = size_in, size_out 32 | self.device = device 33 | self.weight_init = weight_init 34 | 35 | self.reset_parameters() 36 | 37 | def forward(self, x, return_inv=False): 38 | return self.forward_(x, return_inv=return_inv) 39 | 40 | def reset_parameters(self): 41 | self.reset_parameters_() 42 | 43 | def reset_parameters_(self): 44 | size_out = self.size_out 45 | size_in = self.size_in 46 | 47 | self.W = Cplx.empty(size_out, size_in).to(self.device) 48 | self.weight_init(self.W) 49 | self.W = CplxParameter(self.W) 50 | 51 | def forward_(self, x, return_inv=False): 52 | if type(x) != Cplx: 53 | x = x.type(self.W.data.dtype) 54 | x = Cplx(x) 55 | if return_inv: 56 | l, l_inv = self.forward_linear(x, return_inv=return_inv) 57 | else: 58 | l = self.forward_linear(x) 59 | 60 | l_ = l.real + 1j * l.imag 61 | l_ = l_.unsqueeze(-1) 62 | l_cross = torch.matmul(l_, torch.swapaxes(l_, 1, -1)) 63 | l_cross = l_cross.reshape(l.shape[0], -1) 64 | l_cross = Cplx(l_cross.real, l_cross.imag) 65 | 66 | W_ = self.W.real + 1j * self.W.imag 67 | all_crosses = (W_[:, None, :] * W_[None, :, :]).conj() 68 | all_crosses = all_crosses.reshape((-1, self.size_out)).to(x.device) 69 | all_crosses = Cplx(all_crosses.real, all_crosses.imag) 70 | 71 | conj_term = linear(x, all_crosses) 72 | out = l_cross * conj_term 73 | 74 | # Take only upper triangular 75 | out = out.reshape(-1, self.size_out, self.size_out) 76 | idxs = np.triu_indices(self.size_out, k=0, m=None) 77 | out = out[:, idxs[0], idxs[1]] 78 | 79 | if return_inv: 80 | return out, l_inv 81 | else: 82 | return out 83 | 84 | def forward_linear(self, x, return_inv=False): 85 | if type(x) != Cplx: 86 | x = x.type(self.W.data.dtype) 87 | x = Cplx(x) 88 | 89 | l = linear(x, self.W) 90 | 91 | l_inv = linear_conjtx(self.W, l).real.T 92 | 93 | if return_inv: 94 | return l, l_inv 95 | else: 96 | return l -------------------------------------------------------------------------------- /bispectral_networks/logger.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import datetime 4 | import copy 5 | from torch.utils.tensorboard import SummaryWriter 6 | from bispectral_networks.config import Config 7 | 8 | 9 | class TBLogger: 10 | def __init__( 11 | self, 12 | config, 13 | log_interval=1, 14 | checkpoint_interval=10, 15 | logdir=None 16 | ): 17 | self.config = config 18 | self.log_interval = log_interval 19 | self.checkpoint_interval = checkpoint_interval 20 | self.logdir = logdir 21 | 22 | def begin(self, model, data_loader): 23 | try: 24 | self.create_logdir() 25 | torch.save(self.config, os.path.join(self.logdir, "config.pt")) 26 | writer = SummaryWriter(self.logdir) 27 | return writer 28 | except: 29 | raise Exception("Problem creating logging and/or checkpoint directory.") 30 | 31 | def end(self, trainer, variable_dict, epoch): 32 | self.save_checkpoint(trainer, epoch) 33 | 34 | def create_logdir(self): 35 | if self.logdir is None: 36 | self.logdir = os.path.join( 37 | "logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S") 38 | ) 39 | 40 | os.makedirs(self.logdir, exist_ok=True) 41 | os.mkdir(os.path.join(self.logdir, "checkpoints")) 42 | 43 | def log_step(self, writer, trainer, log_dict, variable_dict, epoch, val_log_dict=None): 44 | if epoch % self.log_interval == 0: 45 | writer.add_scalars("train", log_dict, global_step=epoch) 46 | if val_log_dict is not None: 47 | writer.add_scalars("val", val_log_dict, global_step=epoch) 48 | 49 | if epoch % self.checkpoint_interval == 0: 50 | self.save_checkpoint(trainer, epoch) 51 | 52 | def save_checkpoint(self, trainer, epoch): 53 | checkpoint = { 54 | "trainer": trainer, 55 | "model_state_dict": trainer.model.state_dict(), 56 | 'optimizer_state_dict': trainer.optimizer.state_dict(), 57 | } 58 | torch.save( 59 | checkpoint, 60 | os.path.join(self.logdir, "checkpoints", "checkpoint_{}.pt".format(epoch)), 61 | ) 62 | 63 | 64 | def load_checkpoint(logdir, device="cpu"): 65 | all_checkpoints = os.listdir(os.path.join(logdir, "checkpoints")) 66 | all_epochs = sorted([int(x.split("_")[1].split(".")[0]) for x in all_checkpoints]) 67 | last_epoch = all_epochs[-1] 68 | checkpoint = torch.load(os.path.join(logdir, “checkpoints”, “checkpoint_{}.pt”.format(last_epoch)), map_location=torch.device(device)) 69 | config = torch.load(os.path.join(logdir, "config.pt")) 70 | if not hasattr(checkpoint, "model"): 71 | trainer = checkpoint["trainer"] 72 | model_config = Config(trainer.logger.config["model"]) 73 | optimizer_config = Config(copy.deepcopy(trainer.logger.config["optimizer"])) 74 | trainer.model = model_config.build() 75 | trainer.model.load_state_dict(checkpoint["model_state_dict"]) 76 | optimizer_config["params"]["params"] = trainer.model.parameters() 77 | trainer.optimizer = optimizer_config.build() 78 | trainer.optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) 79 | checkpoint = trainer 80 | W = (checkpoint.model.layers[0].W.real + 1j * checkpoint.model.layers[0].W.imag).detach().numpy() 81 | patch_size = config["dataset"]["pattern"]["params"]["patch_size"] 82 | W = W.reshape(W.shape[0], patch_size, patch_size) 83 | return checkpoint, config, W 84 | -------------------------------------------------------------------------------- /configs/translation_experiment.py: -------------------------------------------------------------------------------- 1 | from bispectral_networks.config import Config 2 | 3 | """ 4 | DATA_LOADER 5 | """ 6 | from bispectral_networks.data.data_loader import MPerClassLoader 7 | 8 | data_loader_config = Config( 9 | { 10 | "type": MPerClassLoader, 11 | "params": { 12 | "batch_size": 100, 13 | "m": 10, 14 | "fraction_val": 0.2, 15 | "num_workers": 1, 16 | }, 17 | } 18 | ) 19 | 20 | """ 21 | DATASET 22 | """ 23 | from bispectral_networks.data.transforms import CyclicTranslation2D, Ravel, CenterMean, UnitStd 24 | from bispectral_networks.data.datasets import VanHateren 25 | 26 | 27 | pattern_config = Config( 28 | { 29 | "type": VanHateren, 30 | "params": {"path": "datasets/van-hateren/", 31 | "min_contrast": 0.1, 32 | "patches_per_image": 3, 33 | "patch_size": 16}, 34 | } 35 | ) 36 | 37 | transforms_config = { 38 | "0": Config( 39 | { 40 | "type": CenterMean, 41 | "params": {} 42 | } 43 | ), 44 | "1": Config( 45 | { 46 | "type": UnitStd, 47 | "params": {} 48 | } 49 | ), 50 | "2": Config( 51 | { 52 | "type": CyclicTranslation2D, 53 | "params": { 54 | "fraction_transforms": 1.0, 55 | "sample_method": "random" 56 | }, 57 | } 58 | ), 59 | "3": Config( 60 | { 61 | "type": Ravel, 62 | "params": {}, 63 | } 64 | ) 65 | } 66 | 67 | 68 | dataset_config = {"pattern": pattern_config, 69 | "transforms": transforms_config, 70 | "seed": 5} 71 | 72 | 73 | """ 74 | MODEL 75 | """ 76 | from bispectral_networks.nn.model import BispectralEmbedding 77 | from bispectral_networks.nn.layers import Bispectral 78 | 79 | model_config = Config( 80 | { 81 | "type": BispectralEmbedding, 82 | "params": {"size_in": 256, 83 | "hdim": 256}, 84 | } 85 | ) 86 | 87 | 88 | """ 89 | NORMALIZER 90 | """ 91 | from bispectral_networks.normalizer import L2Normalizer 92 | 93 | normalizer_config = Config({ 94 | "type": L2Normalizer, 95 | "params": { 96 | "variables": ["layers.0.W"] 97 | } 98 | 99 | }) 100 | 101 | 102 | """ 103 | LOSS 104 | """ 105 | from pytorch_metric_learning import losses, reducers 106 | from bispectral_networks.loss import OrbitCollapse, LpDistance 107 | loss_config = Config( 108 | { 109 | "type": OrbitCollapse, 110 | "params": { 111 | "distance": LpDistance(), 112 | }, 113 | } 114 | ) 115 | 116 | 117 | """ 118 | OPTIMIZER 119 | """ 120 | from torch.optim import Adam 121 | 122 | optimizer_config = Config({"type": Adam, "params": {"lr": 0.002}}) 123 | 124 | 125 | """ 126 | SCHEDULER 127 | """ 128 | from torch.optim.lr_scheduler import ReduceLROnPlateau 129 | 130 | scheduler_config = Config({"type": ReduceLROnPlateau, "params": {"factor": 0.5, "patience": 2, "min_lr": 1e-6}}) 131 | 132 | 133 | """ 134 | LOGGER 135 | """ 136 | 137 | from bispectral_networks.logger import TBLogger 138 | 139 | logger_config = Config( 140 | { 141 | "type": TBLogger, 142 | "params": { 143 | "log_interval": 1, 144 | "checkpoint_interval": 10, 145 | }, 146 | } 147 | ) 148 | 149 | 150 | """ 151 | MASTER CONFIG 152 | """ 153 | 154 | master_config = { 155 | "data_loader": data_loader_config, 156 | "dataset": dataset_config, 157 | "model": model_config, 158 | "optimizer": optimizer_config, 159 | "normalizer": normalizer_config, 160 | "scheduler": scheduler_config, 161 | "loss": loss_config, 162 | "seed": 200 163 | } -------------------------------------------------------------------------------- /configs/rotation_experiment.py: -------------------------------------------------------------------------------- 1 | from bispectral_networks.config import Config 2 | 3 | """ 4 | DATA_LOADER 5 | """ 6 | from bispectral_networks.data.data_loader import MPerClassLoader 7 | 8 | data_loader_config = Config( 9 | { 10 | "type": MPerClassLoader, 11 | "params": { 12 | "batch_size": 100, 13 | "m": 10, 14 | "fraction_val": 0.2, 15 | "num_workers": 0, 16 | }, 17 | } 18 | ) 19 | 20 | """ 21 | DATASET 22 | """ 23 | from bispectral_networks.data.transforms import SO2, Ravel, CenterMean, UnitStd, CircleCrop 24 | from bispectral_networks.data.datasets import VanHateren 25 | 26 | 27 | pattern_config = Config( 28 | { 29 | "type": VanHateren, 30 | "params": {"path": "datasets/van-hateren/", 31 | "min_contrast": 0.1, 32 | "patches_per_image": 3, 33 | "patch_size": 16}, 34 | } 35 | ) 36 | 37 | 38 | transforms_config = { 39 | "0": Config( 40 | { 41 | "type": CenterMean, 42 | "params": {} 43 | } 44 | ), 45 | "1": Config( 46 | { 47 | "type": UnitStd, 48 | "params": {} 49 | } 50 | ), 51 | "2": Config( 52 | { 53 | "type": SO2, 54 | "params": { 55 | "fraction_transforms": 0.3, 56 | "sample_method": "random" 57 | }, 58 | } 59 | ), 60 | "3": 61 | Config( 62 | { 63 | "type": CircleCrop, 64 | "params": {} 65 | } 66 | ), 67 | "4": Config( 68 | { 69 | "type": Ravel, 70 | "params": {}, 71 | } 72 | ) 73 | } 74 | 75 | 76 | dataset_config = {"pattern": pattern_config, 77 | "transforms": transforms_config, 78 | "seed": 5} 79 | 80 | 81 | """ 82 | MODEL 83 | """ 84 | from bispectral_networks.nn.model import BispectralEmbedding 85 | from bispectral_networks.nn.layers import Bispectral 86 | 87 | model_config = Config( 88 | { 89 | "type": BispectralEmbedding, 90 | "params": {"size_in": 256, 91 | "hdim": 256}, 92 | } 93 | ) 94 | 95 | 96 | """ 97 | NORMALIZER 98 | """ 99 | from bispectral_networks.normalizer import L2Normalizer 100 | 101 | normalizer_config = Config({ 102 | "type": L2Normalizer, 103 | "params": { 104 | "variables": ["layers.0.W"] 105 | } 106 | 107 | }) 108 | 109 | 110 | """ 111 | LOSS 112 | """ 113 | from pytorch_metric_learning import losses, reducers 114 | from bispectral_networks.loss import OrbitCollapse, LpDistance 115 | loss_config = Config( 116 | { 117 | "type": OrbitCollapse, 118 | "params": { 119 | "distance": LpDistance(), 120 | }, 121 | } 122 | ) 123 | 124 | 125 | """ 126 | OPTIMIZER 127 | """ 128 | from torch.optim import Adam 129 | 130 | optimizer_config = Config({"type": Adam, "params": {"lr": 0.002}}) 131 | 132 | 133 | """ 134 | SCHEDULER 135 | """ 136 | from torch.optim.lr_scheduler import ReduceLROnPlateau 137 | 138 | scheduler_config = Config({"type": ReduceLROnPlateau, "params": {"factor": 0.5, "patience": 2, "min_lr": 1e-6}}) 139 | 140 | 141 | """ 142 | LOGGER 143 | """ 144 | 145 | from bispectral_networks.logger import TBLogger 146 | 147 | logger_config = Config( 148 | { 149 | "type": TBLogger, 150 | "params": { 151 | "log_interval": 1, 152 | "checkpoint_interval": 10, 153 | }, 154 | } 155 | ) 156 | 157 | 158 | """ 159 | MASTER CONFIG 160 | """ 161 | 162 | master_config = { 163 | "data_loader": data_loader_config, 164 | "dataset": dataset_config, 165 | "model": model_config, 166 | "optimizer": optimizer_config, 167 | "normalizer": normalizer_config, 168 | "scheduler": scheduler_config, 169 | "loss": loss_config, 170 | "seed": 162 171 | } 172 | -------------------------------------------------------------------------------- /bispectral_networks/data/data_loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import copy 3 | import torch 4 | from torch.utils.data import DataLoader 5 | from pytorch_metric_learning.samplers import MPerClassSampler 6 | 7 | 8 | class TrainValLoader: 9 | def __init__(self, 10 | batch_size, 11 | fraction_val=0.2, 12 | num_workers=0, 13 | seed=0): 14 | assert ( 15 | fraction_val <= 1.0 and fraction_val >= 0.0 16 | ), "fraction_val must be a fraction between 0 and 1" 17 | 18 | np.random.seed(seed) 19 | 20 | self.batch_size = batch_size 21 | self.fraction_val = fraction_val 22 | self.seed = seed 23 | self.num_workers = num_workers 24 | 25 | def split_data(self, dataset): 26 | 27 | if self.fraction_val > 0.0: 28 | dataset_size = len(dataset) 29 | indices = list(range(dataset_size)) 30 | split = int(np.floor(self.fraction_val * len(dataset))) 31 | 32 | np.random.shuffle(indices) 33 | 34 | train_indices, val_indices = indices[split:], indices[:split] 35 | val_dataset = copy.deepcopy(dataset) 36 | val_dataset.data = val_dataset.data[val_indices] 37 | val_dataset.labels = val_dataset.labels[val_indices] 38 | 39 | train_dataset = copy.deepcopy(dataset) 40 | train_dataset.data = train_dataset.data[train_indices] 41 | train_dataset.labels = train_dataset.labels[train_indices] 42 | 43 | else: 44 | val_dataset = None 45 | 46 | return train_dataset, val_dataset 47 | 48 | def construct_data_loaders(self, train_dataset, val_dataset): 49 | if val_dataset is not None: 50 | val = torch.utils.data.DataLoader( 51 | val_dataset, 52 | batch_size=self.batch_size, 53 | shuffle=True, 54 | num_workers=self.num_workers, 55 | pin_memory=False 56 | ) 57 | 58 | else: 59 | val = None 60 | 61 | train = torch.utils.data.DataLoader( 62 | train_dataset, 63 | batch_size=self.batch_size, 64 | shuffle=True, 65 | num_workers=self.num_workers, 66 | pin_memory=False 67 | ) 68 | 69 | return train, valg 70 | 71 | def load(self, dataset): 72 | train_dataset, val_dataset = self.split_data(dataset) 73 | self.train, self.val = self.construct_data_loaders(train_dataset, val_dataset) 74 | 75 | 76 | class MPerClassLoader(TrainValLoader): 77 | def __init__(self, 78 | batch_size=100, 79 | m=10, 80 | fraction_val=0.2, 81 | num_workers=0, 82 | seed=0): 83 | 84 | super().__init__(batch_size=batch_size, 85 | fraction_val=fraction_val, 86 | num_workers=num_workers, 87 | seed=seed) 88 | self.m = m 89 | 90 | def construct_data_loaders(self, train_dataset, val_dataset): 91 | if val_dataset is not None: 92 | val = torch.utils.data.DataLoader( 93 | val_dataset, 94 | batch_size=self.batch_size, 95 | shuffle=True, 96 | num_workers=self.num_workers, 97 | pin_memory=False 98 | ) 99 | 100 | else: 101 | val = None 102 | 103 | train_sampler = MPerClassSampler(labels=train_dataset.labels, 104 | m=self.m, 105 | batch_size=self.batch_size) 106 | 107 | train = torch.utils.data.DataLoader( 108 | train_dataset, 109 | batch_size=self.batch_size, 110 | sampler=train_sampler, 111 | num_workers=self.num_workers, 112 | pin_memory=False 113 | ) 114 | 115 | return train, val 116 | -------------------------------------------------------------------------------- /bispectral_networks/data/datasets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import OrderedDict 3 | import numpy as np 4 | from torch.utils.data import Dataset 5 | import os 6 | import pandas as pd 7 | 8 | 9 | class TransformDataset: 10 | def __init__(self, dataset, transforms): 11 | """ 12 | Arguments 13 | --------- 14 | dataset (obj): 15 | Object from patterns.natural or patterns.synthetic 16 | transforms (list of obj): 17 | List of objects from transformations. The order of the objects 18 | determines the order in which they are applied. 19 | """ 20 | if type(transforms) != list: 21 | transforms = [transforms] 22 | self.transforms = transforms 23 | self.gen_transformations(dataset) 24 | if len(self.data.shape) == 3: 25 | self.img_size = tuple(self.data.shape[1:]) 26 | else: 27 | self.dim = self.data.shape[-1] 28 | 29 | def gen_transformations(self, dataset): 30 | transform_dict = OrderedDict() 31 | transformed_data = dataset.data.clone() 32 | new_labels = dataset.labels.clone() 33 | for transform in self.transforms: 34 | transformed_data, new_labels, transform_dict, t = transform( 35 | transformed_data, new_labels, transform_dict 36 | ) 37 | transform_dict[transform.name] = t 38 | self.data = transformed_data 39 | self.labels = new_labels 40 | self.transform_labels = transform_dict 41 | 42 | def __getitem__(self, idx): 43 | x = self.data[idx] 44 | y = self.labels[idx] 45 | return x, y 46 | 47 | def __len__(self): 48 | return len(self.data) 49 | 50 | 51 | class VanHateren(Dataset): 52 | def __init__( 53 | self, 54 | path="datasets/van-hateren/", 55 | normalize=True, 56 | select_img_path="select_imgs.txt", 57 | patches_per_image=10, 58 | patch_size=16, 59 | min_contrast=1.0, 60 | ): 61 | 62 | 63 | super().__init__() 64 | 65 | self.name = "van-hateren" 66 | self.dim = patch_size ** 2 67 | self.path = path 68 | self.patches_per_image = patches_per_image 69 | self.select_img_path = select_img_path 70 | self.normalize = normalize 71 | self.patch_size = patch_size 72 | self.min_contrast = min_contrast 73 | self.img_shape = (1024, 1536) 74 | 75 | full_images = self.load_images() 76 | 77 | self.data, self.labels = self.get_patches(full_images) 78 | 79 | 80 | def get_patches(self, full_images): 81 | data = [] 82 | labels = [] 83 | 84 | i = 0 85 | 86 | for img in full_images: 87 | for p in range(self.patches_per_image): 88 | low_contrast = True 89 | j = 0 90 | while low_contrast and j < 100: 91 | start_x = np.random.randint(0, self.img_shape[1] - self.patch_size) 92 | start_y = np.random.randint(0, self.img_shape[0] - self.patch_size) 93 | patch = img[ 94 | start_y : start_y + self.patch_size, start_x : start_x + self.patch_size 95 | ] 96 | if patch.std() >= self.min_contrast: 97 | low_contrast = False 98 | data.append(patch) 99 | labels.append(i) 100 | j += 1 101 | 102 | if j == 100 and not low_contrast: 103 | print("Couldn't find patch to meet contrast requirement. Skipping.") 104 | continue 105 | 106 | i += 1 107 | data = torch.tensor(np.array(data)) 108 | labels = torch.tensor(np.array(labels)) 109 | return data, labels 110 | 111 | 112 | def load_images(self): 113 | if self.select_img_path is not None: 114 | with open(self.path + self.select_img_path, "r") as f: 115 | img_paths = f.read().splitlines() 116 | else: 117 | img_paths = os.listdir(path + "images/") 118 | 119 | all_imgs = [] 120 | 121 | for i, img_path in enumerate(img_paths): 122 | try: 123 | with open(self.path + "images/" + img_path, 'rb') as handle: 124 | s = handle.read() 125 | except: 126 | print("Can't load image at path {}".format(self.path + img_path)) 127 | continue 128 | img = np.fromstring(s, dtype='uint16').byteswap() 129 | if self.normalize: 130 | # Sets image values to lie between 0 and 1 131 | img = img.astype(float) 132 | img -= img.min() 133 | img /= img.max() 134 | img -= img.mean() 135 | img *= 2 136 | img = img.reshape(self.img_shape) 137 | all_imgs.append(img) 138 | 139 | all_imgs = np.array(all_imgs) 140 | return all_imgs 141 | 142 | def __getitem__(self, idx): 143 | x = self.data[idx] 144 | y = self.labels[idx] 145 | return x, y 146 | 147 | def __len__(self): 148 | return len(self.data) 149 | 150 | 151 | class MNISTExemplars(Dataset): 152 | """ 153 | Dataset object for the MNIST dataset. 154 | Takes the MNIST file path, then loads, standardizes, and saves it internally. 155 | """ 156 | 157 | def __init__(self, path, digits=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], n_exemplars=1): 158 | 159 | super().__init__() 160 | 161 | self.name = "mnist" 162 | self.dim = 28 ** 2 163 | self.img_size = (28, 28) 164 | self.digits = digits 165 | self.n_exemplars = n_exemplars 166 | 167 | mnist = np.array(pd.read_csv(path)) 168 | 169 | labels = mnist[:, 0] 170 | mnist = mnist[:, 1:] 171 | mnist = mnist / 255 172 | mnist = mnist - mnist.mean(axis=1, keepdims=True) 173 | mnist = mnist / mnist.std(axis=1, keepdims=True) 174 | mnist = mnist.reshape((len(mnist), 28, 28)) 175 | 176 | label_idxs = {i: [j for j, x in enumerate(labels) if x == i] for i in range(10)} 177 | 178 | exemplar_data = [] 179 | labels = [] 180 | for d in digits: 181 | idxs = label_idxs[d] 182 | random_idxs = np.random.choice(idxs, size=self.n_exemplars, replace=False) 183 | for i in random_idxs: 184 | exemplar_data.append(mnist[i]) 185 | labels.append(d) 186 | 187 | self.data = torch.tensor(exemplar_data) 188 | self.labels = torch.tensor(labels).long() 189 | 190 | def __getitem__(self, idx): 191 | x = self.data[idx] 192 | y = self.labels[idx] 193 | return x, y 194 | 195 | def __len__(self): 196 | return len(self.data) -------------------------------------------------------------------------------- /bispectral_networks/data/transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import itertools 4 | from collections import OrderedDict 5 | from skimage.transform import rotate 6 | 7 | 8 | class Transform: 9 | def __init__(self): 10 | self.name = None 11 | 12 | def define_containers(self, tlabels): 13 | transformed_data, transforms, new_labels = [], [], [] 14 | new_tlabels = OrderedDict({k: [] for k in tlabels.keys()}) 15 | return transformed_data, new_labels, new_tlabels, transforms 16 | 17 | def reformat(self, transformed_data, new_labels, new_tlabels, transforms): 18 | try: 19 | transformed_data = torch.stack(transformed_data) 20 | except: 21 | transformed_data = torch.tensor(transformed_data) 22 | transforms = torch.tensor(transforms) 23 | # new_labels = torch.tensor(new_labels) 24 | new_labels = torch.stack(new_labels) 25 | for k in new_tlabels.keys(): 26 | new_tlabels[k] = torch.stack(new_tlabels[k]) 27 | return transformed_data, new_labels, new_tlabels, transforms 28 | 29 | 30 | class CenterMean(Transform): 31 | def __init__(self): 32 | super().__init__() 33 | self.name = "center-mean" 34 | 35 | def __call__(self, data, labels, tlabels): 36 | if len(data.shape) == 2: 37 | axis = -1 38 | elif len(data.shape) == 3: 39 | axis = (-1, -2) 40 | else: 41 | raise ValueError( 42 | "Operation is not defined for data of dimension {}".format( 43 | len(data.shape) 44 | ) 45 | ) 46 | means = data.mean(axis=axis, keepdims=True) 47 | transformed_data = data - means 48 | return transformed_data, labels, tlabels, means 49 | 50 | 51 | class UnitStd(Transform): 52 | def __init__(self): 53 | super().__init__() 54 | self.name = "unit-std" 55 | 56 | def __call__(self, data, labels, tlabels): 57 | if len(data.shape) == 2: 58 | axis = -1 59 | elif len(data.shape) == 3: 60 | axis = (-1, -2) 61 | else: 62 | raise ValueError( 63 | "Operation is not defined for data of dimension {}".format( 64 | len(data.shape) 65 | ) 66 | ) 67 | stds = data.std(axis=axis, keepdims=True) 68 | transformed_data = data / stds 69 | return transformed_data, labels, tlabels, stds 70 | 71 | 72 | class Ravel(Transform): 73 | def __init__(self): 74 | super().__init__() 75 | self.name = "ravel" 76 | 77 | def __call__(self, data, labels, tlabels): 78 | transformed_data = data.reshape(data.shape[0], -1) 79 | transforms = torch.zeros(len(data)) 80 | return transformed_data, labels, tlabels, transforms 81 | 82 | 83 | class CircleCrop(Transform): 84 | def __init__(self): 85 | super().__init__() 86 | self.name = "circle-crop" 87 | 88 | def __call__(self, data, labels, tlabels): 89 | assert ( 90 | len(data.shape) == 3 91 | ), "Data must have shape (n_datapoints, img_size[0], img_size[1])" 92 | 93 | img_size = data.shape[1:] 94 | 95 | v, h = np.mgrid[: img_size[0], : img_size[1]] 96 | equation = (v - ((img_size[0] - 1) / 2)) ** 2 + ( 97 | h - ((img_size[1] - 1) / 2) 98 | ) ** 2 99 | circle = equation < (equation.max() / 2) 100 | 101 | transformed_data = data.clone() 102 | transformed_data[:, ~circle] = 0.0 103 | transforms = torch.zeros(len(data)) 104 | 105 | return transformed_data, labels, tlabels, transforms 106 | 107 | 108 | class CyclicTranslation2D(Transform): 109 | 110 | def __init__(self, fraction_transforms=0.1, sample_method="linspace"): 111 | super().__init__() 112 | assert sample_method in [ 113 | "linspace", 114 | "random", 115 | ], "sample_method must be one of ['linspace', 'random']" 116 | self.fraction_transforms = fraction_transforms 117 | self.sample_method = sample_method 118 | self.name = "cyclic-translation-2d" 119 | 120 | def get_samples(self, dim_v, dim_h): 121 | n_transforms = int(self.fraction_transforms * dim_h * dim_v) 122 | if self.sample_method == "linspace": 123 | unit_v = dim_v / n_transforms 124 | unit_h = dim_h / n_transforms 125 | return [ 126 | (int(v), int(h)) 127 | for v, h in zip( 128 | np.arange(0, dim_v, unit_v), 129 | np.arange(0, dim_h, unit_h), 130 | ) 131 | ] 132 | else: 133 | all_transforms = list( 134 | itertools.product( 135 | np.arange(dim_v), 136 | np.arange(dim_h), 137 | ) 138 | ) 139 | select_transforms_idx = np.random.choice( 140 | range(len(all_transforms)), size=n_transforms, replace=False 141 | ) 142 | select_transforms = [ 143 | all_transforms[x] for x in sorted(select_transforms_idx) 144 | ] 145 | return select_transforms 146 | 147 | def __call__(self, data, labels, tlabels): 148 | assert ( 149 | len(data.shape) == 3 150 | ), "Data must have shape (n_datapoints, img_size[0], img_size[1])" 151 | 152 | transformed_data, new_labels, new_tlabels, transforms = self.define_containers( 153 | tlabels 154 | ) 155 | 156 | dim_v, dim_h = data.shape[-2:] 157 | select_transforms = self.get_samples(dim_v, dim_h) 158 | for i, x in enumerate(data): 159 | if self.sample_method == "random" and self.fraction_transforms != 1.0: 160 | select_transforms = self.get_samples(dim_v, dim_h) 161 | for tv, th in select_transforms: 162 | xt = torch.roll(x, (tv, th), dims=(-2, -1)) 163 | transformed_data.append(xt) 164 | transforms.append((int(tv), int(th))) 165 | new_labels.append(labels[i]) 166 | for k in new_tlabels.keys(): 167 | new_tlabels[k].append(tlabels[k][i]) 168 | 169 | transformed_data, new_labels, new_tlabels, transforms = self.reformat( 170 | transformed_data, new_labels, new_tlabels, transforms 171 | ) 172 | return transformed_data, new_labels, new_tlabels, transforms 173 | 174 | 175 | class SO2(Transform): 176 | def __init__(self, fraction_transforms=0.1, sample_method="linspace"): 177 | super().__init__() 178 | assert sample_method in [ 179 | "linspace", 180 | "random", 181 | ], "sample_method must be one of ['linspace', 'random']" 182 | self.fraction_transforms = fraction_transforms 183 | self.sample_method = sample_method 184 | self.name = "so2" 185 | 186 | def get_samples(self): 187 | n_transforms = int(self.fraction_transforms * 360) 188 | if self.sample_method == "linspace": 189 | return np.linspace(0, 359, n_transforms) 190 | else: 191 | select_transforms = np.random.choice( 192 | np.arange(360), size=n_transforms, replace=False 193 | ) 194 | select_transforms = sorted(select_transforms) 195 | return select_transforms 196 | 197 | def __call__(self, data, labels, tlabels): 198 | assert ( 199 | len(data.shape) == 3 200 | ), "Data must have shape (n_datapoints, img_size[0], img_size[1])" 201 | 202 | transformed_data, new_labels, new_tlabels, transforms = self.define_containers( 203 | tlabels 204 | ) 205 | 206 | select_transforms = self.get_samples() 207 | for i, x in enumerate(data): 208 | if self.sample_method == "random": 209 | select_transforms = self.get_samples() 210 | for t in select_transforms: 211 | xt = rotate(x, t) 212 | transformed_data.append(xt) 213 | transforms.append(t) 214 | new_labels.append(labels[i]) 215 | for k in new_tlabels.keys(): 216 | new_tlabels[k].append(tlabels[k][i]) 217 | 218 | transformed_data, new_labels, new_tlabels, transforms = self.reformat( 219 | transformed_data, new_labels, new_tlabels, transforms 220 | ) 221 | return transformed_data, new_labels, new_tlabels, transforms -------------------------------------------------------------------------------- /bispectral_networks/trainer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from collections import OrderedDict 4 | import copy 5 | from bispectral_networks.config import Config 6 | from bispectral_networks.data.utils import gen_dataset 7 | 8 | 9 | class BispectralTrainer(torch.nn.Module): 10 | def __init__( 11 | self, 12 | model, 13 | loss, 14 | optimizer, 15 | recon_coeff=150, 16 | logger=None, 17 | scheduler=None, 18 | normalizer=None, 19 | ): 20 | super().__init__() 21 | self.recon_coeff = recon_coeff 22 | self.model = model 23 | self.loss = loss 24 | self.logger = logger 25 | self.normalizer = normalizer 26 | self.optimizer = optimizer 27 | self.scheduler = scheduler 28 | self.epoch = 0 29 | self.n_examples = 0 30 | 31 | def __getstate__(self): 32 | d = self.__dict__ 33 | self_dict = {k : d[k] for k in d if k != '_modules'} 34 | module_dict = OrderedDict({'loss': self.loss}) 35 | if self.normalizer is not None: 36 | module_dict["normalizer"] = self.normalizer 37 | if self.scheduler is not None: 38 | module_dict["scheduler"] = self.scheduler 39 | self_dict['_modules'] = module_dict 40 | return self_dict 41 | 42 | def __setstate__(self, state): 43 | self.__dict__ = state 44 | 45 | def step(self, data_loader, grad=True): 46 | log_dict = {"loss": 0, "rep_loss": 0, "recon_loss": 0, "total_loss": 0} 47 | for i, (x, labels) in enumerate(data_loader): 48 | loss = 0 49 | total_loss = 0 50 | 51 | x = x.to(self.model.device) 52 | labels = labels.to(self.model.device) 53 | 54 | if grad: 55 | self.optimizer.zero_grad() 56 | out, recon = self.model.forward(x) 57 | 58 | else: 59 | with torch.no_grad(): 60 | out, recon = self.model.forward(x) 61 | 62 | rep_loss = abs(self.loss(out, labels)) 63 | recon_loss = self.recon_coeff * torch.nn.functional.mse_loss(recon, x.float()) 64 | 65 | log_dict["rep_loss"] += rep_loss 66 | log_dict["recon_loss"] += recon_loss 67 | total_loss += rep_loss + recon_loss 68 | 69 | if grad: 70 | total_loss.backward() 71 | self.optimizer.step() 72 | 73 | if self.normalizer is not None: 74 | self.normalizer(dict(self.model.named_parameters())) 75 | 76 | log_dict["total_loss"] += total_loss 77 | 78 | n_samples = len(data_loader) 79 | for key in log_dict.keys(): 80 | log_dict[key] /= n_samples 81 | 82 | plot_variable_dict = {"model": self.model} 83 | 84 | return log_dict, plot_variable_dict 85 | 86 | def train( 87 | self, 88 | data_loader, 89 | epochs, 90 | start_epoch=0, 91 | print_status_updates=True, 92 | print_interval=1, 93 | ): 94 | if self.logger is not None: 95 | writer = self.logger.begin(self.model, data_loader) 96 | 97 | try: 98 | for i in range(start_epoch, start_epoch + epochs + 1): 99 | self.epoch = i 100 | log_dict, plot_variable_dict = self.step(data_loader.train, grad=True) 101 | 102 | 103 | if data_loader.val is not None: 104 | # By default, plots are only generated on train steps 105 | val_log_dict, _ = self.evaluate( 106 | data_loader.val 107 | ) 108 | else: 109 | val_log_dict = None 110 | 111 | if self.scheduler is not None: 112 | if val_log_dict is not None: 113 | self.scheduler.step(val_log_dict["total_loss"]) 114 | else: 115 | self.scheduler.step(train_log_dict["total_loss"]) 116 | 117 | if self.logger is not None: 118 | self.logger.log_step( 119 | writer=writer, 120 | trainer=self, 121 | log_dict=log_dict, 122 | val_log_dict=val_log_dict, 123 | variable_dict=plot_variable_dict, 124 | epoch=self.epoch, 125 | ) 126 | 127 | self.n_examples += len(data_loader.train.dataset) 128 | 129 | if i % print_interval == 0 and print_status_updates == True: 130 | if data_loader.val is not None: 131 | self.print_update(log_dict, val_log_dict) 132 | else: 133 | self.print_update(log_dict) 134 | 135 | 136 | except KeyboardInterrupt: 137 | print("Stopping and saving run at epoch {}".format(i)) 138 | end_dict = {"model": self.model, "data_loader": data_loader} 139 | if self.logger is not None: 140 | self.logger.end(self, end_dict, self.epoch) 141 | 142 | def resume(self, data_loader, epochs): 143 | self.train(data_loader, epochs, start_epoch=self.epoch+1) 144 | 145 | @torch.no_grad() 146 | def evaluate(self, data_loader): 147 | results = self.step(data_loader, grad=False) 148 | return results 149 | 150 | def print_update(self, result_dict_train, result_dict_val=None): 151 | 152 | update_string = "Epoch {} || N Examples {} || Train Total Loss {:0.5f}".format( 153 | self.epoch, self.n_examples, result_dict_train["total_loss"] 154 | ) 155 | if result_dict_val: 156 | update_string += " || Validation Total Loss {:0.5f}".format( 157 | result_dict_val["total_loss"] 158 | ) 159 | print(update_string) 160 | 161 | 162 | def construct_trainer(master_config, logger_config=None): 163 | """ 164 | master_config has the following format: 165 | master_config = { 166 | "dataset": dataset_config, 167 | "model": model_config, 168 | "optimizer": optimizer_config, 169 | "loss": loss_config, 170 | "data_loader": data_loader_config, 171 | } 172 | with optional regularizer, normalizer, and learning rate scheduler 173 | """ 174 | 175 | if "seed" in master_config: 176 | torch.manual_seed(master_config["seed"]) 177 | np.random.seed(master_config["seed"]) 178 | 179 | model = master_config["model"].build() 180 | loss = master_config["loss"].build() 181 | 182 | logger_config["params"]["config"] = master_config 183 | logger = logger_config.build() 184 | 185 | optimizer_config = copy.deepcopy(master_config["optimizer"]) 186 | optimizer_config["params"]["params"] = model.parameters() 187 | optimizer = optimizer_config.build() 188 | 189 | train_config = Config( 190 | { 191 | "type": BispectralTrainer, 192 | "params": { 193 | "model": model, 194 | "loss": loss, 195 | "logger": logger, 196 | "optimizer": optimizer, 197 | }, 198 | } 199 | ) 200 | 201 | if "regularizer" in master_config: 202 | regularizer = master_config["regularizer"].build() 203 | train_config["params"]["regularizer"] = regularizer 204 | 205 | if "normalizer" in master_config: 206 | normalizer = master_config["normalizer"].build() 207 | train_config["params"]["normalizer"] = normalizer 208 | 209 | if "scheduler" in master_config: 210 | scheduler_config = copy.deepcopy(master_config["scheduler"]) 211 | scheduler_config["params"]["optimizer"] = optimizer 212 | scheduler = scheduler_config.build() 213 | train_config["params"]["scheduler"] = scheduler 214 | 215 | trainer = train_config.build() 216 | 217 | return trainer 218 | 219 | 220 | def run_trainer(master_config, 221 | logger_config, 222 | device=0, 223 | n_examples=1e9, 224 | seed=None): 225 | 226 | if seed is not None: 227 | master_config["seed"] = seed 228 | 229 | dataset = gen_dataset(master_config["dataset"]) 230 | 231 | data_loader = master_config["data_loader"].build() 232 | data_loader.load(dataset) 233 | 234 | trainer = construct_trainer(master_config, logger_config=logger_config) 235 | 236 | epochs = int(n_examples // len(data_loader.train.dataset.data)) 237 | trainer.model.device = device 238 | trainer.model = trainer.model.to(device) 239 | trainer.train(data_loader, epochs=epochs) 240 | --------------------------------------------------------------------------------