├── .gitignore
├── Dataset.py
├── Figures
│   ├── .DS_Store
│   ├── Figure1.pdf
│   ├── Figure1.svg
│   ├── Figure2.pdf
│   ├── Figure2.svg
│   ├── Figure3.pdf
│   ├── Figure3.svg
│   ├── Figure4.pdf
│   ├── Figure4.svg
│   ├── Figure5.pdf
│   ├── Figure5.svg
│   └── networks
│       └── .DS_Store
├── ModelState.py
├── Network.py
├── README.md
├── cifar.py
├── fig2_network_performance.py
├── fig3_unit_taxonomy.py
├── fig4_lesion_study.py
├── fig5_cifar10_exp.py
├── functions.py
├── helper.py
├── mnist.py
├── models
│   ├── .DS_Store
│   └── patterns_rev
│       └── seeded_mnist
│           └── .gitignore
├── paper_results.ipynb
├── plot.py
├── requirements.txt
├── results.sh
├── supplement.py
├── test.py
├── train.py
└── train_models.py

/.gitignore:
--------------------------------------------------------------------------------
models/patterns_rev/seeded_mnist/*

--------------------------------------------------------------------------------
/Dataset.py:
--------------------------------------------------------------------------------
import torch
from typing import List
from types import SimpleNamespace


class Dataset(SimpleNamespace):
    """Mutable container for the tensors of one dataset split.

    This is a SimpleNamespace, not a named tuple: attributes are passed as
    keyword arguments and may be reassigned later (cifar.load, for example,
    reshapes ``x`` after construction).

    Attributes:
        x: the samples; the cifar loader stores one flattened image per row.
        y: the class label of every sample (an integer tensor in the cifar loader).
        indices: one index tensor per class, listing the rows of ``x``/``y``
            that belong to that class.
    """
    x: torch.Tensor
    y: torch.Tensor
    indices: List[torch.Tensor]

--------------------------------------------------------------------------------
/Figures/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KietzmannLab/EmergentPredictiveCoding/4b926abd080c9c67f698d71a037af261fba40c30/Figures/.DS_Store
--------------------------------------------------------------------------------
/Figures/Figure1.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KietzmannLab/EmergentPredictiveCoding/4b926abd080c9c67f698d71a037af261fba40c30/Figures/Figure1.pdf
--------------------------------------------------------------------------------
/Figures/Figure2.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KietzmannLab/EmergentPredictiveCoding/4b926abd080c9c67f698d71a037af261fba40c30/Figures/Figure2.pdf
--------------------------------------------------------------------------------
/Figures/Figure3.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KietzmannLab/EmergentPredictiveCoding/4b926abd080c9c67f698d71a037af261fba40c30/Figures/Figure3.pdf
--------------------------------------------------------------------------------
/Figures/Figure4.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KietzmannLab/EmergentPredictiveCoding/4b926abd080c9c67f698d71a037af261fba40c30/Figures/Figure4.pdf
--------------------------------------------------------------------------------
/Figures/Figure5.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KietzmannLab/EmergentPredictiveCoding/4b926abd080c9c67f698d71a037af261fba40c30/Figures/Figure5.pdf
--------------------------------------------------------------------------------
/Figures/networks/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KietzmannLab/EmergentPredictiveCoding/4b926abd080c9c67f698d71a037af261fba40c30/Figures/networks/.DS_Store
--------------------------------------------------------------------------------
/ModelState.py:
--------------------------------------------------------------------------------
import torch
import os
import re


class ModelState:
    """A class to encapsulate a neural network model with a number of attributes associated with it.

    Serves as a place to store associated attributes to a model, such as the
    optimizer or training metadata. Checkpoints live under ``./models/`` and
    are named ``<title>.pth`` or, for numbered instances, ``<title>_<idx>.pth``.
    """

    def __init__(self,
                 model,
                 optimizer,
                 lr: float,
                 title: str,
                 results,
                 device: str):
        """
        Args:
            model: the torch module to wrap.
            optimizer: optimizer *class* (e.g. ``torch.optim.Adam``); it is
                instantiated here over ``model.parameters()``.
            lr: learning rate handed to the optimizer.
            title: checkpoint name relative to ``./models/`` (may contain
                sub-directories).
            results: dict of training metadata that is saved with the model.
            device: device string used as ``map_location`` when loading.
        """
        self.model = model
        self.optimizer = optimizer(self.model.parameters(), lr=lr)
        self.title = title
        self.epochs = 0
        self.results = results
        self.device = device

    def _filepath(self, idx=None):
        """Return the checkpoint path of this model (instance ``idx`` if given)."""
        suffix = "" if idx is None else "_" + str(idx)
        return "./models/" + self.title + suffix + ".pth"

    def save(self, idx=None):
        """Write epochs, model/optimizer state and results to disk.

        ``idx`` mirrors :meth:`load`, so numbered instances are written under
        the same name they are read back with. The target directory is created
        if missing, because titles may contain sub-directories.
        """
        filepath = self._filepath(idx)
        os.makedirs(os.path.dirname(filepath), exist_ok=True)
        torch.save({
            "epochs": self.epochs,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "results": self.results
        }, filepath)

    def load(self, idx=None):
        """Restore epochs, model/optimizer state and results from disk.

        The checkpoint stores numpy arrays in ``results``, which the
        ``weights_only=True`` default of newer PyTorch releases refuses to
        unpickle, so full unpickling is requested explicitly. Only load
        checkpoints from trusted sources.
        """
        filepath = self._filepath(idx)
        map_location = torch.device(self.device)
        try:
            state = torch.load(filepath, map_location=map_location, weights_only=False)
        except TypeError:  # older torch releases without the weights_only argument
            state = torch.load(filepath, map_location=map_location)
        self.epochs = state['epochs']
        self.model.load_state_dict(state['model_state_dict'])
        self.optimizer.load_state_dict(state['optimizer_state_dict'])
        self.results = state['results']

--------------------------------------------------------------------------------
/Network.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
import numpy as np
from ModelState import ModelState
import functions
from torch import nn


class Network(torch.nn.Module):
    """
    Recurrent Neural Network class containing parameters of the network
    and computes the forward pass.
    Returns hidden state of the network and preactivations of the units.

    The input is zero-padded to ``hidden_size`` and added to the recurrent
    drive ``h @ W``; ``input_size`` is therefore expected to be at most
    ``hidden_size`` (F.pad with a negative amount would crop the input).
    """
    def __init__(self, input_size: int, hidden_size: int, activation_func, weights_init=functions.init_params, prevbatch=False, conv=False, device=None):
        super(Network, self).__init__()

        self.input_size = input_size
        if conv:
            # NOTE(review): this layer is registered (and thus saved in the
            # state_dict) but never used by forward(); kept so checkpoints
            # created with conv=True still load.
            self.conv = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3)
        self.is_conv = conv
        self.hidden_size = hidden_size
        self.activation_func = activation_func
        self.W = torch.nn.Parameter(weights_init(hidden_size, hidden_size))
        self.prevbatch = prevbatch
        self.device = device

    def forward(self, x, state=None, synap_trans=False, mask=None):
        """Run one recurrent step.

        Args:
            x: input batch, one sample per row.
            state: previous hidden state; an all-zero state is used when None.
            synap_trans: unused, kept for call compatibility.
            mask: optional tensor multiplied elementwise into ``W``.

        Returns:
            ``(h, [a, h, W])``: the new hidden state and the list of terms the
            loss parser picks from (preactivation, activation, weights).

        Raises:
            NotImplementedError: for networks built with ``conv=True``; there
                is no convolutional forward pass (this used to surface as an
                UnboundLocalError on ``x_pad``).
        """
        if self.is_conv:
            raise NotImplementedError("Network.forward has no convolutional path; build the network with conv=False")

        if state is None:
            state = self.init_state(x.shape[0])
        h = state.to(self.device)
        x = x.to(self.device)

        # pad input so it matches the hidden state dimensions
        x_pad = F.pad(x, (0, self.hidden_size - self.input_size), "constant", 0)
        W = self.W if mask is None else self.W * mask
        a = h @ W + x_pad

        h = self.activation_func(a)
        # return state vector and list of losses
        return h, [a, h, self.W]

    def init_state(self, batch_size):
        """Return an all-zero hidden state of shape (batch_size, hidden_size) on torch's default device."""
        return torch.zeros((batch_size, self.hidden_size))


class State(ModelState):
    """ModelState around a :class:`Network`: loss parsing, prediction read-outs and result bookkeeping."""

    def __init__(self,
                 activation_func,
                 optimizer,
                 lr: float,
                 title: str,
                 input_size: int,
                 hidden_size: int,
                 device: str,
                 deterministic=True,
                 weights_init=functions.init_params,
                 prevbatch=False,
                 conv=False,
                 seed=None):
        # ``deterministic`` is accepted for compatibility but not used here.
        if seed is not None:
            torch.manual_seed(seed)
            np.random.seed(seed)
        self.seed = seed

        ModelState.__init__(
            self,
            Network(input_size, hidden_size, activation_func, weights_init=weights_init,
                    prevbatch=prevbatch, conv=conv, device=device).to(device),
            optimizer,
            lr,
            title,
            {
                "train loss": np.zeros(0),
                "test loss": np.zeros(0),
                "h": np.zeros(0),
                "Wl1": np.zeros(0),
                "Wl2": np.zeros(0)
            },
            device)

    def run(self, batch, loss_fn, state=None):
        """
        Runs a batch of sequences through the model

        Args:
            batch: indexed per time step as ``batch[t]``; ``batch.shape[0]``
                is the sequence length and ``batch.shape[1]`` the batch size.
            loss_fn: loss specification string, see :meth:`loss`.
            state: hidden state carried over from a previous batch; only used
                when the network was built with ``prevbatch=True``.

        Returns:
            loss (attached to the graph), a detached copy of it, and the final
            hidden state.
        """
        sequence_length = batch.shape[0]
        batch_size = batch.shape[1]
        # identity test: ``state`` is normally a tensor, whose ``==`` is overloaded
        if state is not None and self.model.prevbatch:
            h = state
        else:
            h = self.model.init_state(batch_size)

        loss = torch.zeros(1, dtype=torch.float, requires_grad=True).to(self.device)
        for i in range(sequence_length):
            h, l_a = self.model(batch[i], state=h)  # l_a is a list of potential loss terms
            loss = loss + self.loss(l_a, loss_fn)
        return loss, loss.detach(), h

    def get_next_state(self, state, x):
        """
        Return next state of model given current state and input
        """
        next_state, _ = self.model(x, state)
        return next_state

    def loss(self, loss_terms, loss):
        """Evaluate a loss specification string on ``loss_terms``.

        Term names are resolved by ``functions.parse_loss``; this method only
        splits the string:
            "<t1>"                  -> t1 (+ parse_loss result for a missing term)
            "<t1>and<t2>"           -> t1 + t2
            "<beta>beta<t1>and<t2>" -> t1 + beta * t2,
                                       e.g. "3708.0betal1_postandl2_weights"
        """
        beta = 1
        # split off the weighting factor first, so the term names below never
        # contain the "<beta>beta" prefix
        if 'beta' in loss:
            beta, loss = loss.split('beta')
            beta = float(beta)
        loss_t1, loss_t2 = loss, None
        if 'and' in loss:
            loss_t1, loss_t2 = loss.split('and')

        # parse loss terms
        loss_fn_t1, loss_arg_t1 = functions.parse_loss(loss_t1, loss_terms)
        loss_fn_t2, loss_arg_t2 = functions.parse_loss(loss_t2, loss_terms)

        return loss_fn_t1(loss_arg_t1) + beta * loss_fn_t2(loss_arg_t2)

    def predict(self, state, latent=False):
        """
        Returns the networks 'prediction' for the input: the recurrent drive
        ``state @ W``, restricted to the first ``input_size`` units unless
        ``latent`` is True.
        """
        state = state.to(self.device)
        pred = state @ self.model.W
        if not latent:
            return pred[:, :self.model.input_size]
        return pred[:, :]

    def predict_predonly(self, state, pred_mask, latent=False):
        """
        Returns the networks 'prediction' for the input
        from the prediction units only.

        Rows of a detached copy of ``W`` are zeroed where ``pred_mask == 1``
        (per the original comment, those mark the non-prediction units).
        """
        state = state.to(self.device)
        W_pred = self.model.W.clone().detach()
        # set all non prediction units to zero
        W_pred[pred_mask == 1, :] = 0
        pred = state @ W_pred
        if not latent:
            return pred[:, :self.model.input_size]
        return pred[:, :]

    def step(self, loss):
        """Backpropagate ``loss`` and apply one optimizer step."""
        loss.backward()
        # optional: nn.utils.clip_grad_value_(self.model.parameters(), clip_value=1.0)
        self.optimizer.step()

    def zero_grad(self):
        self.optimizer.zero_grad()

    def on_results(self, epoch: int, train_res, test_res, m_state):
        """Save training metadata and advance the epoch counter.

        ``m_state`` is unpacked as ``(h, Wl1, Wl2)``; ``epoch`` is not used.
        """
        h, Wl1, Wl2 = m_state
        functions.append_dict(self.results, {"train loss": train_res.cpu().numpy(), "test loss": test_res.cpu().numpy(), "h": h, "Wl1": Wl1, "Wl2": Wl2})
        self.epochs += 1

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Energy efficient predictive coding networks

### Description
Supplementary material for 'Predictive coding is a consequence of energy efficiency in recurrent neural networks'

### Dependencies

```pip install -r requirements.txt```

Or look in requirements.txt - be sure to use Python >=3.7

### Usage Notes:
- To replicate the results in the paper, please run results.sh (which in turn calls all of the figure generators) - takes 1h15m on a Mac M1
- Trained models can be found [here](https://osf.io/c57d4/), but if you want to train your own models you can run train_models.py to produce the model instances used in
determining the model preactivations. 15 | - Training has changed slightly since the paper was published, with early stopping implemented (set patience when calling train() - patience is the number of epochs to put up with no val_loss improvement) 16 | - Training takes 2h on a Mac M1 (vs 2 days for the original paper), with patience 5, and 20h with patience 200. I find that unless patience is really high, results aren't as good - presumably a long patience is needed to break out of some local minima. 17 | 18 | - Also see [paper_results.ipynb](https://github.com/KietzmannLab/EmergentPredictiveCoding/blob/master/paper_results.ipynb) - but note that it is no longer current 19 | 20 | ```python train_models.py``` 21 | 22 | The script will automatically run on a gpu if a gpu is available and cuda is set up. Otherwise the script will fall back to the cpu. If multiple gpu nodes are available, you can select which node you want the script to run on by prepending CUDA_VISIBLE_DEVICES, e.g.: 23 | 24 | ```CUDA_VISIBLE_DEVICES=GPU_ID python train_models.py``` 25 | 26 | ### Data Sets 27 | We use the MNIST database of handwritten digits and CIFAR10, a labelled subset of the tiny image database. We created wrappers (mnist.py, cifar.py) that load and transform the images into sequences in ascending class order (with wraparound from class 9 to class 0). The sequenced data set is used as data for the networks. The appropriate training and test data can be created by simply calling: 28 | 29 | ```import mnist``` 30 | 31 | ```training_set, validation_set, test_set = mnist.load(val_ratio=0.0)``` 32 | for MNIST and: 33 | 34 | ```import cifar``` 35 | 36 | ```training_set, validation_set, test_set = cifar.load(val_ratio=0.0, color=True)``` 37 | for CIFAR10. 38 | The (batches of) sequences can then be generated by: 39 | 40 | ```batches, labels = dataset.create_batches(batch_size=batch_size, sequence_length=sequence_length, shuffle=True)``` 41 | 42 | where `dataset` denotes training_set, validation_set or test_set.
### Usage Conditions
If you use this code in your work, we ask you to please cite:
Ali, A., Ahmad N., de Groot E., van Gerven M.A.J., Kietzmann T.C. (2021). **Predictive coding is a consequence of energy efficiency in recurrent neural networks.** doi: https://doi.org/10.1101/2021.02.16.430904

### TODO

- Fix runtime warnings on fig*.py

### M1 acceleration

The operator 'aten::sgn.out' was recently (as of 10.26.2023) added to PyTorch, so you might need a nightly
build to install support for it

```pip3 install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu```

--------------------------------------------------------------------------------
/cifar.py:
--------------------------------------------------------------------------------
import torch
import torchvision
from torchvision import transforms
from Dataset import Dataset
import random


class CIFAR10Dataset(Dataset):
    """Container class for the CIFAR10 database containing Tensors with the images and labels, as well as a list of indices for each category.

    ``repeat`` makes :meth:`create_batches` repeat every time step of the
    generated sequences that many times.
    """
    def __init__(self, x, y, indices, repeat=1):
        super().__init__(x=x, y=y, indices=indices)
        self.repeat = repeat

    def create_batches(self, batch_size, sequence_length, shuffle=True, distractor=False, fixed_starting_point=None):
        """Build class-ordered sequences (see :func:`create_sequences`) and
        repeat every time step ``self.repeat`` times along the sequence axis."""
        data, labels = create_sequences(self, sequence_length, batch_size, shuffle, distractor, fixed_starting_point)
        data = data.repeat_interleave(self.repeat, dim=1)
        labels = labels.repeat_interleave(self.repeat, dim=1)
        return data, labels


def create_sequences(dataset, sequence_length, batch_size, shuffle=True, distractor=False, fixed_starting_point=None):
    """Arrange the samples of ``dataset`` into sequences of ascending class.

    A sequence starting at class c continues c+1, c+2, ... (mod 10). Samples
    of each class are used without replacement; generation stops where the
    rarest class runs out.

    Args:
        dataset: object with ``x`` (one sample per row) and per-class ``indices``.
        sequence_length: number of time steps per sequence.
        batch_size: sequences per batch; values < 1 or larger than the number
            of available sequences yield one big batch.
        shuffle: shuffle the samples within each class and the starting classes.
        distractor: replace the class at position 8 of every sequence by a
            random different class (requires ``sequence_length > 8``).
        fixed_starting_point: if given (int in 0..9), every sequence starts there.

    Returns:
        ``x`` of shape (n_batches, sequence_length, batch_size, ninputs) and
        ``y`` of shape (n_batches, sequence_length, batch_size) holding the
        class of every sample.

    Raises:
        ValueError: if there are not enough samples for a single sequence.
    """
    # number of datapoints
    data_size, ninputs = dataset.x.shape

    # maximum theoretical amount of sequences
    max_sequences = int(data_size / sequence_length)

    # for test and validation it is not actually necessary to shuffle,
    # so for consistent testing/validation we can use the same sequences every time
    if shuffle:
        # shuffle all the data points per class
        indices = [d[torch.randperm(d.shape[0])] for d in dataset.indices]
        # choose random sequence starting points
        seq_starting_points = torch.randperm(max_sequences)
    else:
        indices = dataset.indices
        seq_starting_points = torch.arange(max_sequences)
    # if we want the same starting class for all the sequences
    if fixed_starting_point is not None:
        assert isinstance(fixed_starting_point, int) and fixed_starting_point in range(10)
        # integer dtype like randperm/arange above; torch.ones(...) * k was a
        # float tensor and turned the returned labels ``y`` into floats
        seq_starting_points = torch.full((max_sequences,), fixed_starting_point, dtype=torch.long)
    # repeat each starting point 'sequence_length' times, then add to each
    # entry its position within the sequence to get increasing classes
    sequences = seq_starting_points.repeat_interleave(sequence_length).view(max_sequences, sequence_length)
    sequences += torch.arange(sequence_length)
    # take the remainder to get actual classes from 0-9
    sequences %= 10

    # switch out the class at position 8 for a distractor if flag is true
    if distractor:
        for i in range(max_sequences):
            candidates = list(range(10))
            candidates.remove(int(sequences[i, 8]))
            sequences[i, 8] = random.choice(candidates)

    # flatten again
    sequences = sequences.flatten()
    # for every position in 'sequences', the row of dataset.x to use
    epoch_indices = torch.zeros(data_size, dtype=torch.long)
    # because not every class is equally represented, we have to keep track
    # of where in the sequence we run out of samples. This 'cutoff' is the
    # minimum over all classes
    cutoff = data_size

    for i in range(10):
        # mask to filter out the positions of this class
        mask = sequences == i
        # the cumulative sum of the mask gives an increasing index exactly at
        # the positions where the class occurs; it indexes 'indices'.
        # NOTE(review): cumsum starts at 1, so the first (shuffled) sample of
        # every class is never used; kept to preserve the generated data.
        indices_idx = torch.cumsum(mask, 0)
        # cut off where the index exceeds the number of samples of this class
        indices_idx = indices_idx[indices_idx < indices[i].shape[0]]
        # keep track of the earliest cutoff point for later
        cutoff = min(cutoff, indices_idx.shape[0])
        # also cut off the mask so it has the matching length
        mask = mask[:indices_idx.shape[0]]
        # place the data indices on the positions where this class occurs
        epoch_indices[:indices_idx.shape[0]][mask] = indices[i][indices_idx][mask]

    # if batch_size is invalid, create one big batch
    if batch_size < 1 or batch_size > int(cutoff / sequence_length):
        batch_size = int(cutoff / sequence_length)
    if batch_size == 0:
        raise ValueError("not enough samples per class to build a single sequence")

    # trim so we get an integer amount of batches and sequences
    cutoff = cutoff - cutoff % (batch_size * sequence_length)

    epoch_indices = epoch_indices[:cutoff]
    sequences = sequences[:cutoff]
    # select the data points and group per sequence and batch
    x = dataset.x[epoch_indices].view(-1, batch_size, sequence_length, ninputs).transpose(1, 2)
    y = sequences.view(-1, batch_size, sequence_length).transpose(1, 2)
    return x, y


def _class_indices(y):
    """Return, for each of the 10 classes, the positions in ``y`` holding that class."""
    return [torch.nonzero(y == c).flatten() for c in range(10)]


def load(val_ratio=0.1, color=False):
    """Load CIFAR10 data, transform to tensors (grayscale unless color=True) and calculate indices for each category.

    The last ``val_ratio`` fraction of the training images forms the
    validation set. Every image is flattened with its colour channels stacked
    horizontally, so each row of ``x`` has nchannels*32*32 entries.

    Returns:
        training_set, validation_set, test_set as CIFAR10Dataset objects.
    """
    if color:
        transform = transforms.ToTensor()
        nchannels = 3
    else:  # gray scale
        transform = transforms.Compose([transforms.Grayscale(), transforms.ToTensor()])
        nchannels = 1

    train_data = torchvision.datasets.CIFAR10("./datasets/", train=True, transform=transform, download=True)
    test_data = torchvision.datasets.CIFAR10("./datasets/", train=False, transform=transform, download=True)

    npixels = 32 * 32
    validation_size = int(val_ratio * len(train_data))
    train_size = len(train_data) - validation_size

    train_x = torch.zeros(train_size, nchannels, npixels)
    train_y = torch.zeros(train_size, dtype=torch.int)
    val_x = torch.zeros(validation_size, nchannels, npixels)
    val_y = torch.zeros(validation_size, dtype=torch.int)
    for i, (image, label) in enumerate(train_data):
        # (nchannels, 32, 32) -> (nchannels, 32*32); the validation branch used
        # to view as (validation_size, ...), which failed for any val_ratio > 0
        image = image.view(nchannels, npixels)
        if i < train_size:
            train_x[i] = image
            train_y[i] = label
        else:
            val_x[i - train_size] = image
            val_y[i - train_size] = label

    test_size = len(test_data)
    test_x = torch.zeros(test_size, nchannels, npixels)
    test_y = torch.zeros(test_size, dtype=torch.int)
    for i, (image, label) in enumerate(test_data):
        test_x[i] = image.view(nchannels, npixels)
        test_y[i] = label

    # horizontal stack of color channels; save image indices for each category
    training_set = CIFAR10Dataset(x=train_x.view(train_size, nchannels * npixels), y=train_y, indices=_class_indices(train_y))
    validation_set = CIFAR10Dataset(x=val_x.view(validation_size, nchannels * npixels), y=val_y, indices=_class_indices(val_y))
    test_set = CIFAR10Dataset(x=test_x.view(test_size, nchannels * npixels), y=test_y, indices=_class_indices(test_y))
    return training_set, validation_set, test_set


def means(dataset: CIFAR10Dataset):
    """Return the per-class mean sample as a (10, ninputs) tensor.

    Uses the actual row width, so colour data (3*32*32 inputs) works too.
    """
    ndata, ninputs = dataset.x.shape
    means = torch.empty(10, ninputs)
    for i in range(10):
        means[i] = torch.mean(dataset.x[dataset.indices[i]], dim=0)
    return means


def medians(dataset: CIFAR10Dataset):
    """Return the per-class median sample as a (10, ninputs) tensor."""
    ndata, ninputs = dataset.x.shape
    medians = torch.empty(10, ninputs)
    for i in range(10):
        medians[i] = torch.median(dataset.x[dataset.indices[i]], dim=0).values
    return medians


if __name__ == '__main__':
    train, val, test = load()

--------------------------------------------------------------------------------
/fig2_network_performance.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Apr 29 09:24:57 2022

@author: tempali

In this analysis we compare how well L1 pre does vs. L1 post.
"""

# imports

import torch
import numpy as np
import argparse
import matplotlib.pyplot as plt
import os
import pandas as pd
from functions import get_device

parser = argparse.ArgumentParser(description='device')
parser.add_argument('--i', type=str, help='Device index')
args = parser.parse_args()

DEVICE = get_device()

R_PATH = 'Results/Fig2/Data/'
F_PATH = 'Results/Fig2/'
M_PATH = 'patterns_rev/seeded_mnist/'

hdf_path = R_PATH+'network_stats.h5'

LOAD = False
SEED = None
# create the data and the figure directory (the figure check used to create R_PATH)
for path in (R_PATH, F_PATH):
    os.makedirs(os.path.dirname(path), exist_ok=True)

if SEED is not None:
    torch.manual_seed(SEED)
    np.random.seed(SEED)

# set up hdf5 file to store the results
if not os.path.exists(hdf_path):
    store = pd.HDFStore(hdf_path)
    store.close()
INPUT_SIZE = 28*28
Z_CRIT = 2.576  # 99%
SEQ_LENGTH = 10
# dataset loaders
import mnist

# framework files
import Network
import helper
import plot
from matplotlib.ticker import MaxNLocator

# load data
train_set, validation_set, test_set = mnist.load(val_ratio=0.0)

# load pre, post MNIST networks
nets = [[], [], []]

NUM_INSTANCES = 10
# load networks for bootstrap
losses = ['l1_pre', 'l1_post', [str(beta)+'beta'+'l1_postandl2_weights' for beta in [3708.0]][0]]
# set up dictionaries to fill in the data
ec_results, ap_results, st_results, pre_results = dict(), dict(), dict(), dict()
result_list = [('ec', ec_results), ('ap', ap_results), ('st', st_results), ('pre', pre_results)]
for loss_ind, loss in enumerate(losses):
    for i in range(0, NUM_INSTANCES):
        net = Network.State(activation_func=torch.nn.ReLU(),
                            optimizer=torch.optim.Adam,
                            lr=1e-4,
                            input_size=INPUT_SIZE,
                            hidden_size=INPUT_SIZE,
                            title=M_PATH+"mnist_net_"+loss,
                            device=DEVICE)
        net.load(i)
        nets[loss_ind].append(net)


# fig 2A: RNN_pre performs better than RNN_post
#------------------------------------------------------------------------------
# open file to read/write data; the context manager closes it afterwards.
# The file always exists at this point (created above), so LOAD alone decides
# whether the statistics are recomputed.
with pd.HDFStore(hdf_path) as store:
    if not LOAD:
        # calculate energy consumption for the losses
        for loss_ind, loss in enumerate(losses):
            energies = dict()  # dict of dicts
            for (ename, e_results) in result_list:
                bs_sample_dict = helper.bootstrap_model_activity(nets[loss_ind], train_set, test_set, energy = ename, seed=None, lesioned=False)
                # NOTE(review): en_samples/en_bounds computed below are
                # overwritten by bs_sample_dict before use; this loop looks
                # like dead (and expensive) computation - confirm and remove.
                en_samples = np.zeros((len(nets[loss_ind]), SEQ_LENGTH))
                for i, net in enumerate(nets[loss_ind]):
                    mean_en, _ = helper.model_activity_lesioned(net, train_set, test_set, lesion_type='pred', seq_length=10, energy=ename, save=False,
                                                                latent=False, data_type='mnist', Z_crit=Z_CRIT)
                    # fill sample matrices
                    en_samples[i, :] = mean_en

                # compute bootstrap bounds and store results in dataframe
                [en_bounds] = helper.compute_bootstrap([en_samples])
                en_samples, en_bounds = bs_sample_dict['norm'], bs_sample_dict['bs_norm']
                energies[ename] = [en_samples[0], en_bounds[0]]
            df_loss = pd.DataFrame(data=energies)
            store[loss] = df_loss
    df_pre, df_post, df_pw = store['l1_pre'], store['l1_post'], store['3708.0betal1_postandl2_weights']


def _plot_mean_band(ax, x, samples, bootstraps, label, color, start_index=0):
    """Plot the across-instance mean of ``samples`` with its bootstrap band.

    Returns the unsliced lower/upper bootstrap bounds for setting axis limits.
    """
    mu = np.mean(samples, axis=0)[start_index:]  # empirical mean of reservoir activity
    ax.plot(x, mu, label=label, color=color)
    lower, upper = helper.extract_lower_upper(bootstraps)
    ax.fill_between(x, lower[start_index:], upper[start_index:], color=color, alpha=0.3)
    return lower, upper


def _style_broken_axes(ax_top, ax_bott):
    """Style two vertically stacked axes as one plot with a broken y-axis."""
    ax_top.spines.bottom.set_visible(False)
    ax_bott.spines.top.set_visible(False)
    ax_top.spines.top.set_visible(False)
    ax_top.tick_params(labeltop=False)  # don't put tick labels at the top
    ax_top.tick_params(bottom=False)

    d = .4  # proportion of vertical to horizontal extent of the slanted line
    kwargs = dict(marker=[(-1, -d), (1, d)], markersize=12,
                  linestyle="none", color='k', mec='k', mew=1, clip_on=False)
    ax_top.plot([0, 1], [0, 0], transform=ax_top.transAxes, **kwargs)
    ax_bott.plot([0, 1], [1, 1], transform=ax_bott.transAxes, **kwargs)

    ax_bott.xaxis.set_major_locator(MaxNLocator(integer=True))
    ax_top.grid(True)
    ax_bott.grid(True)


# plot one curve figure per energy measure
enames = list(zip(*result_list))[0]
start_index = 0
x = np.arange(start_index + 1, SEQ_LENGTH + 1)
for ename in enames:
    # get samples for pre post and weighted post
    pre_samples, pre_bootstraps = df_pre[ename]
    act_samples, act_bootstraps = df_post[ename]
    pw_samples, pw_bootstraps = df_pw[ename]
    if ename == 'ap':
        fig, ax = plt.subplots(1, 1)
        _plot_mean_band(ax, x, pre_samples, pre_bootstraps, "RNN_pre", '#EE6666', start_index)
        _plot_mean_band(ax, x, act_samples, act_bootstraps, "RNN_post", 'cornflowerblue', start_index)
        ax.xaxis.set_major_locator(MaxNLocator(integer=True))
    else:
        # broken y-axis: RNN_post on top, RNN_pre (and, except for 'st', RNN_post+l2(W)) below
        fig, (ax_top, ax_bott) = plt.subplots(2, 1, sharex=True)
        lower_pre, upper_pre = _plot_mean_band(ax_bott, x, pre_samples, pre_bootstraps, "RNN_pre", '#EE6666', start_index)
        lower_act, upper_act = _plot_mean_band(ax_top, x, act_samples, act_bootstraps, "RNN_post", 'cornflowerblue', start_index)
        if ename == 'st':
            lower_bott, upper_bott = lower_pre, upper_pre
        else:  # total energy consumption and remaining measures
            lower_bott, upper_bott = _plot_mean_band(ax_bott, x, pw_samples, pw_bootstraps, "RNN_post+l2(W)", 'black', start_index)
        # set limits of axes using the bootstrap bounds
        ax_top.set_ylim(min(lower_act)-0.015, max(upper_act)+0.015)
        ax_bott.set_ylim(min(lower_bott)-0.015, max(upper_bott)+0.015)
        _style_broken_axes(ax_top, ax_bott)
    plot.save_fig(fig, F_PATH + ename+'curve_MNIST')
    plt.close(fig)

#------------------------------------------------------------------------------
# fig 2C: Plot digit predictions
trained_net = nets[0][0]  # example trained (l1_pre) network for visualisation
act_net = nets[1][0]      # example l1_post network
# example untrained network for visualisation
untrained_net = Network.State(activation_func=torch.nn.ReLU(),
                              optimizer=torch.optim.Adam,
                              lr=1e-4,
                              input_size=INPUT_SIZE,
                              hidden_size=INPUT_SIZE,
                              title="",
                              device=DEVICE)

# get visualisations for trained & untrained network
X, P, _, T = plot.example_sequence_state(trained_net, test_set)
_, Pu, _, _ = plot.example_sequence_state(untrained_net, test_set)
_, Pcat, _, _ = plot.example_sequence_state(act_net, test_set)
# get median total drive
M = mnist.medians(train_set)
# NOTE(review): if P is a list, ``P + list(M)`` concatenates (20 images) rather
# than adding the medians elementwise, so with shape=(10,1) this panel may just
# repeat P - confirm the intent.
panels = [
    (X, "input_drive", {}),
    (P, "internal_drive_trained", {}),
    (Pu, "internal_drive_untrained", {}),
    (T, "total_drive", {}),
    (P + list(M), "total_drive_median_digit", {}),
    (Pcat, "internal_drive_l1postwithoutcolorb", {'colorbar': False}),
    (Pcat, "internal_drive_l1postw", {'colorbar': True}),
]
for images, name, extra in panels:
    fig, axes = plot.display(images, lims=None, shape=(10,1), figsize=(3,3), axes_visible=False, layout='tight', **extra)
    plot.save_fig(fig, F_PATH+name, bbox_inches='tight')
    plt.close(fig)

--------------------------------------------------------------------------------
/fig3_unit_taxonomy.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Mon May 2 11:08:50 2022 5 | 6 | @author: tempali 7 | """ 8 | 9 | import torch 10 | import numpy as np 11 | import argparse 12 | import matplotlib.pyplot as plt 13 | import pandas as pd 14 | import helper 15 | import plot 16 | import seaborn as sns 17 | import os 18 | from functions import get_device 19 | 20 | parser = argparse.ArgumentParser(description='device') 21 | parser.add_argument('--i', type=str, help='Device index') 22 | args = parser.parse_args() 23 | plt.style.use('ggplot') 24 | 25 | DEVICE = get_device() 26 | 27 | R_PATH = 'Results/Fig3/Data/' 28 | F_PATH = 'Results/Fig3/' 29 | M_PATH = 'patterns_rev/seeded_mnist/' 30 | hdf_path = R_PATH+'network_stats.h5' 31 | 32 | LOAD = False 33 | SEED = 2553 34 | if not os.path.isdir(os.path.dirname(R_PATH)): 35 | os.makedirs(os.path.dirname(R_PATH), exist_ok=True) 36 | if not os.path.isdir(os.path.dirname(F_PATH)): 37 | os.makedirs(os.path.dirname(R_PATH), exist_ok=True) 38 | 39 | if SEED != None: 40 | torch.manual_seed(SEED) 41 | np.random.seed(SEED) 42 | 43 | INPUT_SIZE = 28*28 44 | 45 | 46 | # dataset loaders 47 | import mnist 48 | 49 | import Network 50 | 51 | 52 | 53 | train_set, validation_set, test_set = mnist.load(val_ratio=0.0) 54 | # mnist dimensions 55 | nc, nx, ny = 1, 28, 28 56 | nunits = nx*ny 57 | n_instances = 10 58 | seq_length = 10 59 | nclasses = 10 60 | LOSS_FN = 'l1_pre' 61 | nets = [] 62 | 63 | for i in range(n_instances): 64 | net= Network.State(activation_func=torch.nn.ReLU(), 65 | optimizer=torch.optim.Adam, 66 | lr=1e-4, 67 | input_size=INPUT_SIZE, 68 | hidden_size=INPUT_SIZE, 69 | title=M_PATH+"mnist_net_"+LOSS_FN, 70 | device=DEVICE) 71 | 72 | net.load(i) 73 | nets.append(net) 74 | 75 | batch_size=1 76 | 77 | 78 | #------------------------------------------------------------------------------ 79 | ## fig 3A: 
plot topographic distribution of unit types and pixel variance 80 | # use the first network for visualisation purposes 81 | net = nets[0] 82 | if not os.path.exists(hdf_path) or LOAD == False: 83 | type_mask, type_stats = helper.compute_unit_types(net, test_set, train_set) 84 | type_dict = {'Mask': type_mask, 'Stats': type_stats} 85 | typedf = pd.DataFrame(data=type_dict) 86 | # save dataframe 87 | store = pd.HDFStore(hdf_path) 88 | store['type_stats'+str(net)] = typedf 89 | store.close() 90 | else: 91 | store = pd.HDFStore(hdf_path) 92 | typedf = store['type_stats'+str(net)] 93 | store.close() 94 | type_mask = typedf['Mask'] 95 | type_stats = typedf['Stats'] 96 | type_mask = type_mask.reshape(nc*nx,ny) 97 | # plot topographic distribution and save figure 98 | fig = plot.topographic_distribution(type_mask) 99 | plot.save_fig(fig, F_PATH + 'topographic_distribution_mnist') 100 | 101 | #------------------------------------------------------------------------------ 102 | ## Fig 3B: Input variance of prediction and error units 103 | u_types = ['prediction', 'error', 'hybrid', 'unspecified'] 104 | ## specify dictionary for all network instances 105 | pop_dict = {'Unit type':[], 'N': [], 'Median input variance':[], 'Network': []} 106 | for n, net in enumerate(nets): 107 | net_path = R_PATH + 'net'+str(n) 108 | if not os.path.exists(hdf_path) or LOAD == False: 109 | type_mask, type_stats = helper.compute_unit_types(net, test_set, train_set) 110 | type_dict = {'Mask': type_mask, 'Stats': type_stats} 111 | typedf = pd.DataFrame(data=type_dict) 112 | # save dataframe 113 | store = pd.HDFStore(hdf_path) 114 | store['type_stats_net'+str(n)] = typedf 115 | store.close() 116 | else: 117 | store = pd.HDFStore(hdf_path) 118 | typedf = store['type_stats_net'+str(n)] 119 | store.close() 120 | type_mask = typedf['Mask'] 121 | type_stats = typedf['Stats'] 122 | # reshape type mask for proper indexing 123 | type_mask = type_mask.reshape(nunits) 124 | # # retrieve indices of unit types 
(prediction, error & hybrid) 125 | err_inds = [i for i, e in enumerate(type_mask) if e in [0,1]] 126 | pred_inds = [i for i, p in enumerate(type_mask) if p in [2,3]] 127 | hybrid_inds = [i for i, h in enumerate(type_mask) if h in [4,5]] 128 | un_inds = [i for i, u in enumerate(type_mask) if u == 6] 129 | 130 | if not os.path.exists(hdf_path) or LOAD == False: 131 | # # get prediction and error unit indices 132 | 133 | # record input pixel variance per category 134 | var = torch.zeros(nclasses, INPUT_SIZE) 135 | # pred_inds, err_inds = [] , [] 136 | for cat in range(nclasses): 137 | var[cat] = torch.var(test_set.x[test_set.indices[cat]],dim=0) 138 | 139 | 140 | # set up dictionary for single network 141 | var_dict = {'Unit type': [], 'Input variance': [], 'Nr classes':[], 'Categories': []} 142 | 143 | # pure prediction units 144 | for p in pred_inds: 145 | cpred, _, _ , _ = type_stats[p] 146 | var_pred = torch.zeros(len(cpred)) 147 | for i, cat in enumerate(cpred): 148 | targ_pred = (cat - 1) % seq_length 149 | var_pred[i] = var[targ_pred, p] 150 | 151 | 152 | var_dict['Unit type'].append('prediction') 153 | var_dict['Input variance'].append(var_pred.mean().item()) 154 | var_dict['Nr classes'].append(len(cpred)) 155 | var_dict['Categories'].append(cpred) 156 | 157 | # pure error units 158 | for e in err_inds: 159 | _, cerr, _ , _ = type_stats[e] 160 | var_err = torch.zeros(len(cerr)) 161 | for i, cat in enumerate(cerr): 162 | targ_err = cat 163 | var_err[i] = var[targ_err, e] 164 | 165 | var_dict['Unit type'].append('error') 166 | var_dict['Input variance'].append(var_err.mean().item()) 167 | var_dict['Nr classes'].append(len(cerr)) 168 | var_dict['Categories'].append(cerr) 169 | 170 | # hybrid units 171 | for h in hybrid_inds: 172 | cpred, cerr, _ , _ = type_stats[h] 173 | var_pred, var_err = torch.zeros(len(cpred)), torch.zeros(len(cerr)) 174 | for i, cat in enumerate(cpred): 175 | targ_pred = (cat - 1) % seq_length 176 | var_pred[i] = var[targ_pred, h] 177 | 178 
| for i, cat in enumerate(cerr): 179 | targ_err = cat 180 | var_err[i] = var[targ_err, h] 181 | 182 | var_dict['Unit type'].append('hybrid') 183 | var_dict['Input variance'].append((var_pred.mean().item(), var_err.mean().item())) 184 | var_dict['Nr classes'].append((len(cpred), len(cerr))) 185 | var_dict['Categories'].append((cpred, cerr)) 186 | 187 | # unspecified 188 | for u in un_inds: 189 | var_u = torch.zeros(nclasses) 190 | for cat in range(nclasses): 191 | var_u[cat] = var[cat, u] 192 | var_dict['Unit type'].append('unspecified') 193 | var_dict['Input variance'].append(var_u.mean().item()) 194 | var_dict['Nr classes'].append(0) 195 | var_dict['Categories'].append([]) 196 | 197 | # create a dataframe to store the variances per unit type for single network 198 | netdf = pd.DataFrame(data=var_dict) 199 | # save dataframe 200 | store = pd.HDFStore(hdf_path) 201 | store['mnist_net'+str(net)] = netdf 202 | store.close() 203 | else: # load input variance data 204 | store = pd.HDFStore(hdf_path) 205 | netdf = store['mnist_net'+str(net)] 206 | store.close() 207 | for u_type in u_types: 208 | pop_dict['Unit type'].append(u_type) 209 | if u_type == 'hybrid': 210 | u_type_var = list(netdf.loc[netdf['Unit type'] == u_type]['Input variance']) 211 | pred_var, err_var = torch.tensor([p for p, e in u_type_var]), torch.tensor([e for p, e in u_type_var]) 212 | # compute medians separately and add them to the df 213 | pop_dict['Median input variance'].append((torch.median(pred_var).item(), torch.median(err_var).item())) 214 | else: 215 | u_type_var = netdf.loc[netdf['Unit type'] == u_type]['Input variance'].median() 216 | pop_dict['Median input variance'].append(u_type_var) 217 | pop_dict['N'].append(len(netdf.loc[netdf['Unit type'] == u_type])) 218 | pop_dict['Network'].append('Network ' + str(n+1)) 219 | 220 | popdf = pd.DataFrame(data=pop_dict) 221 | # save dataframe 222 | store = pd.HDFStore(hdf_path) 223 | store['popinfo'] = popdf 224 | store.close() 225 | # plot input 
variance for each prediction and error unit 226 | fig, ax = plt.subplots(figsize=(7,7)) 227 | 228 | 229 | df_prederr = popdf.loc[popdf['Unit type'].isin(['prediction', 'error'])] 230 | 231 | 232 | ax = sns.barplot(x='Unit type', y='Median input variance', data=df_prederr, capsize=.2, color='#868484ff') 233 | 234 | plot.save_fig(fig, F_PATH + 'Input_variance_unit_types_mnist') 235 | 236 | #------------------------------------------------------------------------------ 237 | # ## fig 3: compute average number of prediction and error units 238 | summary_stats = {'Unit type':[], 'Mean number of units':[], 'Std':[]} 239 | for u_type in u_types: 240 | mean = popdf.loc[popdf['Unit type'] == u_type]['N'].mean() 241 | std = popdf.loc[popdf['Unit type'] == u_type]['N'].std() 242 | summary_stats['Unit type'].append(u_type) 243 | summary_stats['Mean number of units'].append(mean) 244 | summary_stats['Std'].append(std) 245 | 246 | # Put stats in dataframe and save them to disk 247 | summary_stats = pd.DataFrame(data=summary_stats) 248 | store = pd.HDFStore(hdf_path) 249 | store['summary_stats'] = summary_stats 250 | store.close() 251 | print(summary_stats) 252 | -------------------------------------------------------------------------------- /fig4_lesion_study.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Mon May 2 11:08:50 2022 5 | 6 | @author: tempali 7 | """ 8 | 9 | import torch 10 | import numpy as np 11 | import argparse 12 | import matplotlib.pyplot as plt 13 | import pandas as pd 14 | 15 | import helper 16 | import plot 17 | import os 18 | from functions import get_device 19 | 20 | parser = argparse.ArgumentParser(description='device') 21 | parser.add_argument('--i', type=str, help='Device index') 22 | args = parser.parse_args() 23 | plt.style.use('ggplot') 24 | 25 | DEVICE = get_device() 26 | 27 | R_PATH = 'Results/Fig4/Data/' 28 | F_PATH = 'Results/Fig4/' 
29 | M_PATH = 'patterns_rev/seeded_mnist/' 30 | hdf_path = R_PATH+'network_stats.h5' 31 | 32 | LOAD = False 33 | SEED = 2553 34 | if not os.path.isdir(os.path.dirname(R_PATH)): 35 | os.makedirs(os.path.dirname(R_PATH), exist_ok=True) 36 | if not os.path.isdir(os.path.dirname(F_PATH)): 37 | os.makedirs(os.path.dirname(R_PATH), exist_ok=True) 38 | 39 | if SEED != None: 40 | torch.manual_seed(SEED) 41 | np.random.seed(SEED) 42 | 43 | # set up hdf5 file to store the results 44 | if not os.path.exists(hdf_path): 45 | store = pd.HDFStore(hdf_path) 46 | store.close() 47 | INPUT_SIZE = 28*28 48 | Z_CRIT = 2.576 #99% 49 | 50 | # dataset loaders 51 | import mnist 52 | 53 | import Network 54 | 55 | 56 | 57 | train_set, validation_set, test_set = mnist.load(val_ratio=0.0) 58 | # mnist dimensions 59 | nc, nx, ny = 1, 28, 28 60 | nunits = nx*ny 61 | n_instances = 10 62 | seq_length = 10 63 | nclasses = 10 64 | LOSS_FN = 'l1_pre' 65 | nets = [] 66 | 67 | for i in range(n_instances): 68 | net= Network.State(activation_func=torch.nn.ReLU(), 69 | optimizer=torch.optim.Adam, 70 | lr=1e-4, 71 | input_size=INPUT_SIZE, 72 | hidden_size=INPUT_SIZE, 73 | title=M_PATH+"mnist_net_"+LOSS_FN, 74 | device=DEVICE) 75 | 76 | net.load(i) 77 | nets.append(net) 78 | 79 | batch_size=1 80 | # Use the first network for visualisation purposes 81 | net = nets[0] 82 | store = pd.HDFStore(hdf_path) 83 | #------------------------------------------------------------------------------ 84 | ## fig 4A: Lesion study MNIST 85 | if not os.path.exists(hdf_path) or LOAD == False: 86 | bs_sample_dict = helper.bootstrap_model_activity(nets, train_set, test_set, seed=None, lesioned=True) 87 | 88 | 89 | les_df = pd.DataFrame(data=bs_sample_dict) 90 | 91 | store['lesionstudy'] = les_df 92 | 93 | else: 94 | les_df = store['lesionstudy'] 95 | norm_samples, lesion_samples, cont_samples = store['norm'][0], \ 96 | store['lesion'][0], store['cont'][0] 97 | [bs_norm, bs_lesion, bs_cont] = les_df['bs_bounds'] 98 | 99 | # get 
samples 100 | norm_samples, lesion_samples, cont_samples= les_df['norm'][0], les_df['lesion'][0], les_df['cont'][0] 101 | # get bs_bounds 102 | bs_norm, bs_lesion, bs_cont = les_df['bs_norm'][0], les_df['bs_lesion'][0], les_df['bs_cont'][0] 103 | # plot results 104 | fig, ax = plt.subplots(1, 1) 105 | 106 | 107 | x = np.arange(1,seq_length+1) 108 | ax.set_xticks(x) 109 | 110 | mu_norm = np.mean(norm_samples, axis=0) # empirical mean of original RNN 111 | ax.plot(x, mu_norm, label="original RNN", color= '#EE6666') 112 | lower_norm, upper_norm = helper.extract_lower_upper(bs_norm) 113 | 114 | ax.fill_between(x, lower_norm, upper_norm, color='#EE6666', alpha=0.3) 115 | 116 | 117 | mu_les = np.mean(lesion_samples, axis=0) # empirical mean of sample set 118 | ax.plot(x, mu_les, label="prediction units lesioned", color= '#EECC55') 119 | 120 | lower_les, upper_les = helper.extract_lower_upper(bs_lesion) 121 | ax.fill_between(x, lower_les, upper_les, color='#EECC55', alpha=0.3) 122 | 123 | mu_cont = np.mean(cont_samples, axis=0) # empirical mean of sample set 124 | ax.plot(x, mu_cont, label="control lesioning", color= '#5efc03') 125 | lower_cont, upper_cont = helper.extract_lower_upper(bs_cont) 126 | 127 | ax.fill_between(x, lower_cont, upper_cont, color='#5efc03', alpha=0.3) 128 | 129 | ax.legend() 130 | 131 | plot.save_fig(fig, F_PATH + 'lesion_study_MNIST') 132 | 133 | #------------------------------------------------------------------------------ 134 | #------------------------------------------------------------------------------ 135 | ## fig 4B: Visualise internal network drive for lesioned and non-lesioned network 136 | ## only visualise [0] 137 | pred_mask = helper._pred_mask(net, test_set, train_set) 138 | # show internal network drive normal network 139 | fig_norm, _ = plot.pred_after_timestep(net, test_set) 140 | plot.save_fig(fig_norm, F_PATH + 'Internal drive without lesions') 141 | # show lesioned internal network drive network 142 | fig_les, _ = 
plot.pred_after_timestep(net, test_set, mask=pred_mask) 143 | plot.save_fig(fig_les, F_PATH + 'Internal drive with lesions') 144 | #------------------------------------------------------------------------------ 145 | 146 | #------------------------------------------------------------------------------ 147 | ## fig 4C: Compute and plot postsynaptic drive dynamics 148 | 149 | if not os.path.exists(hdf_path) or LOAD == False: 150 | pred_stats, err_stats = plot.bootstrap_post_dynamics(nets, test_set) 151 | dyn_dict = {'pred': pred_stats, 'err':err_stats} 152 | postdyn_df = pd.DataFrame(dyn_dict) 153 | 154 | store['postdyn'] = postdyn_df 155 | 156 | else: 157 | postdyn_df = store['postdyn'] 158 | pred_stats, err_stats = postdyn_df['pred'], postdyn_df['err'] 159 | 160 | 161 | fig, (ax1, ax2) = plt.subplots(2, 1) 162 | 163 | ax2.set_ylim(-11, 2.5) # pred stats 164 | ax1.set_ylim(-0.06, -0.01) # err stats 165 | #fig.subplots_adjust(hspace=0.1) # adjust space between axes 166 | x = np.arange(0,9,1) 167 | # plot the lines and confidence bounds 168 | ax1.plot(x, err_stats['samples'], color='b', label='error units ') 169 | ax1.fill_between(x, err_stats['l_bound'], err_stats['h_bound'], color='b', alpha=0.3) 170 | 171 | 172 | ax2.plot(x, pred_stats['samples'], color='r', label='prediction units') 173 | ax2.fill_between(x, pred_stats['l_bound'], pred_stats['h_bound'], color='r', alpha=0.3) 174 | 175 | ax1.legend() 176 | 177 | ax2.legend() 178 | 179 | plt.gca().set_aspect('auto') 180 | plt.grid(True) 181 | 182 | fig.tight_layout() 183 | plot.save_fig(fig, F_PATH + 'postsynaptic_drive_dynamics') 184 | store.close() 185 | #------------------------------------------------------------------------------ -------------------------------------------------------------------------------- /fig5_cifar10_exp.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Mon May 2 11:08:50 2022 
5 | 6 | @author: tempali 7 | """ 8 | 9 | import torch 10 | import numpy as np 11 | import argparse 12 | import matplotlib.pyplot as plt 13 | import pandas as pd 14 | import helper 15 | import plot 16 | import seaborn as sns 17 | import os 18 | import random 19 | from functions import get_device 20 | from matplotlib.ticker import MaxNLocator 21 | 22 | def seed_everything(seed=42): 23 | random.seed(seed) 24 | os.environ['PYTHONHASHSEED'] = str(seed) 25 | np.random.seed(seed) 26 | torch.manual_seed(seed) 27 | torch.backends.cudnn.deterministic = True 28 | torch.backends.cudnn.benchmark = False 29 | 30 | parser = argparse.ArgumentParser(description='device') 31 | parser.add_argument('--i', type=str, help='Device index') 32 | args = parser.parse_args() 33 | plt.style.use('ggplot') 34 | 35 | DEVICE = get_device() 36 | 37 | print('Using {}'.format(DEVICE)) 38 | 39 | R_PATH = 'Results/Fig5/Data/' 40 | F_PATH = 'Results/Fig5/' 41 | M_PATH = 'final_networks/seeded_cifar_nets/' 42 | hdf_path = R_PATH+'network_stats.h5' 43 | 44 | LOAD = False 45 | SEED = 2553 46 | if not os.path.isdir(os.path.dirname(R_PATH)): 47 | os.makedirs(os.path.dirname(R_PATH), exist_ok=True) 48 | if not os.path.isdir(os.path.dirname(F_PATH)): 49 | os.makedirs(os.path.dirname(R_PATH), exist_ok=True) 50 | 51 | if SEED != None: 52 | torch.manual_seed(SEED) 53 | np.random.seed(SEED) 54 | # set up hdf5 file to store the results 55 | if not os.path.exists(hdf_path) or LOAD==False: 56 | store = pd.HDFStore(hdf_path) 57 | store.close() 58 | INPUT_SIZE = 32*32*3 59 | 60 | 61 | # dataset loaders 62 | import cifar 63 | 64 | import Network 65 | 66 | 67 | 68 | train_set, validation_set, test_set = cifar.load(val_ratio=0.0, color=True) 69 | 70 | 71 | Z_CRIT= 1.96 #95% CI 72 | # cifar dimensions 73 | nc, nx, ny = 3, 32, 32 74 | nunits = INPUT_SIZE 75 | n_instances = 10 76 | nclasses = 10 77 | seq_length = 10 78 | LOSS_FN = 'l1_pre' 79 | nets = [[], []] 80 | c_types = ['cifar_net_'] # add cifar_latent_ if you want to 
test latent models 81 | hid_size = [0, 32] 82 | LOADS = [False, False] 83 | for c, c_type in enumerate(c_types): 84 | for i in range(n_instances): 85 | net= Network.State(activation_func=torch.nn.ReLU(), 86 | optimizer=torch.optim.Adam, 87 | lr=1e-4, 88 | input_size=INPUT_SIZE, 89 | hidden_size=INPUT_SIZE+hid_size[c], 90 | title=M_PATH+c_type+str(i), 91 | device=DEVICE) 92 | 93 | net.load() 94 | nets[c].append(net) 95 | 96 | batch_size=1 97 | for c, c_type in enumerate(c_types): 98 | #------------------------------------------------------------------------------ 99 | ## Fig 5B: Topographic distribution and Input variance of prediction and error units 100 | u_types = ['prediction', 'error', 'hybrid', 'all other units'] #['prediction', 'all other units'] 101 | ## specify dictionary for all network instances 102 | pop_dict = {'Unit type':[], 'N': [], 'Median input variance':[], 'Network': []} 103 | for n, net in enumerate(nets[c]): 104 | net_path = R_PATH + 'net'+str(n) 105 | if not os.path.exists(hdf_path) or LOADS[0] == False: 106 | type_mask, type_stats = helper.compute_unit_types(net, test_set, train_set, seed=SEED) 107 | type_dict = {'Mask': type_mask, 'Stats': type_stats} 108 | typedf = pd.DataFrame(data=type_dict) 109 | # save dataframe 110 | store = pd.HDFStore(hdf_path) 111 | store['type_stats_'+c_type+str(n)] = typedf 112 | store.close() 113 | else: 114 | store = pd.HDFStore(hdf_path) 115 | typedf = store['type_stats_'+c_type+str(n)] 116 | store.close() 117 | type_mask = typedf['Mask'] 118 | type_stats = typedf['Stats'] 119 | # reshape type mask for proper indexing 120 | type_mask = torch.tensor(list(type_mask)) 121 | type_mask = type_mask.reshape(nunits) 122 | # # retrieve indices of unit types (prediction, error & hybrid) 123 | #err_inds = [i for i, e in enumerate(type_mask) if e in [0,1]] 124 | pred_inds = [i for i, p in enumerate(type_mask) if p in [2,3]] 125 | #hybrid_inds = [i for i, h in enumerate(type_mask) if h in [4,5]] 126 | un_inds = [i for i, u 
in enumerate(type_mask) if u not in [2,3]] 127 | 128 | if not os.path.exists(hdf_path) or LOADS[1] == False: 129 | # # get prediction and error unit indices 130 | 131 | # record input pixel variance per category 132 | var = torch.zeros(nclasses, INPUT_SIZE) 133 | # pred_inds, err_inds = [] , [] 134 | for cat in range(nclasses): 135 | var[cat] = torch.var(test_set.x[test_set.indices[cat]],dim=0) 136 | 137 | 138 | # set up dictionary for single network 139 | var_dict = {'Unit type': [], 'Input variance': [], 'Nr classes':[], 'Categories': []} 140 | 141 | # pure prediction units 142 | for p in pred_inds: 143 | cpred, _, _ , _ = type_stats[p] 144 | var_pred = torch.zeros(len(cpred)) 145 | for i, cat in enumerate(cpred): 146 | targ_pred = (cat - 1) % seq_length 147 | var_pred[i] = var[targ_pred, p] 148 | 149 | 150 | var_dict['Unit type'].append('prediction') 151 | var_dict['Input variance'].append(var_pred.mean().item()) 152 | var_dict['Nr classes'].append(len(cpred)) 153 | var_dict['Categories'].append(cpred) 154 | 155 | 156 | # all other units 157 | for u in un_inds: 158 | var_u = torch.zeros(nclasses) 159 | for cat in range(nclasses): 160 | var_u[cat] = var[cat, u] 161 | var_dict['Unit type'].append('all other units') 162 | var_dict['Input variance'].append(var_u.mean().item()) 163 | var_dict['Nr classes'].append(0) 164 | var_dict['Categories'].append([]) 165 | 166 | # create a dataframe to store the variances per unit type for single network 167 | netdf = pd.DataFrame(data=var_dict) 168 | # save dataframe 169 | store = pd.HDFStore(hdf_path) 170 | store[c_type+str(n)] = netdf 171 | store.close() 172 | 173 | else: # load input variance data 174 | store = pd.HDFStore(hdf_path) 175 | netdf = store[c_type+str(n)] 176 | store.close() 177 | 178 | 179 | for u_type in u_types: 180 | pop_dict['Unit type'].append(u_type) 181 | if list(netdf.loc[netdf['Unit type'] == u_type]) == []: # unit type not in this network 182 | pop_dict['Median input variance'].append(0) 183 | elif 
u_type == 'hybrid': # take the prediction variance 184 | u_type_var = list(netdf.loc[netdf['Unit type'] == u_type]['Input variance']) 185 | pred_var = torch.tensor([p for p, e in u_type_var]) 186 | pop_dict['Median input variance'].append(torch.median(pred_var).item()) 187 | else: 188 | u_type_var = netdf.loc[netdf['Unit type'] == u_type]['Input variance'].median() 189 | pop_dict['Median input variance'].append(u_type_var) 190 | pop_dict['N'].append(len(netdf.loc[netdf['Unit type'] == u_type])) 191 | pop_dict['Network'].append('Network ' + str(n+1)) 192 | 193 | popdf = pd.DataFrame(data=pop_dict) 194 | # save dataframe 195 | store = pd.HDFStore(hdf_path) 196 | store['popinfo'] = popdf 197 | 198 | fig = plot.topographic_distribution(type_mask.reshape(3, 32, 32)) 199 | plot.save_fig(fig, F_PATH + 'topographic_distribution_'+c_type) 200 | # plot input variance for each prediction and error unit 201 | fig, ax = plt.subplots(figsize=(7,7)) 202 | 203 | df_prederr = popdf.loc[popdf['Unit type'].isin(['prediction', 'unspecified'])] 204 | 205 | ax = sns.barplot(x='Unit type', y='Median input variance', data=df_prederr, capsize=.2, color='#868484ff') 206 | plot.save_fig(fig, F_PATH + 'Input_variance_unit_types_'+c_type) 207 | 208 | #------------------------------------------------------------------------------ 209 | # ## fig 3A: compute average number of prediction and error units 210 | summary_stats = {'Unit type':[], 'Mean number of units':[], 'Std':[]} 211 | for u_type in u_types: 212 | mean = popdf.loc[popdf['Unit type'] == u_type]['N'].mean() 213 | std = popdf.loc[popdf['Unit type'] == u_type]['N'].std() 214 | summary_stats['Unit type'].append(u_type) 215 | summary_stats['Mean number of units'].append(mean) 216 | summary_stats['Std'].append(std) 217 | 218 | # Put stats in dataframe and save them to disk 219 | summary_stats = pd.DataFrame(data=summary_stats) 220 | store = pd.HDFStore(hdf_path) 221 | store['summary_stats'+str(c_type)] = summary_stats 222 | 223 | 
print(summary_stats) 224 | 225 | 226 | #------------------------------------------------------------------------------ 227 | ## fig 5C: lesioning study CIFAR10 228 | # checkif samples are already computed 229 | if not os.path.exists(hdf_path) or LOADS[1] == False: 230 | bs_sample_dict = helper.bootstrap_model_activity(nets[0], train_set, test_set, seed=None, lesioned=True) 231 | les_df = pd.DataFrame(data=bs_sample_dict) 232 | store['lesionstudy'] = les_df 233 | 234 | else: 235 | les_df = store['lesionstudy'] 236 | norm_samples, lesion_samples, cont_samples = store['norm'][0], \ 237 | store['lesion'][0], store['cont'][0] 238 | [bs_norm, bs_lesion, bs_cont] = les_df['bs_bounds'] 239 | 240 | 241 | # get samples 242 | norm_samples, lesion_samples, cont_samples= les_df['norm'][0], les_df['lesion'][0], les_df['cont'][0] 243 | # get bs_bounds 244 | bs_norm, bs_lesion, bs_cont = les_df['bs_norm'][0], les_df['bs_lesion'][0], les_df['bs_cont'][0] 245 | # plot results 246 | # create figure plot mean values and 95% CI 247 | 248 | fig, (ax_top, ax_bott) = plt.subplots(2, 1, sharex=True) 249 | 250 | 251 | x = np.arange(1,seq_length+1) 252 | 253 | 254 | mu_norm = np.mean(norm_samples, axis=0) # empirical mean of original RNN 255 | ax_bott.plot(x, mu_norm, label="original RNN", color= '#EE6666') 256 | lower_norm, upper_norm = helper.extract_lower_upper(bs_norm) 257 | 258 | ax_bott.fill_between(x, lower_norm, upper_norm, color='#EE6666', alpha=0.3) 259 | 260 | 261 | mu_les = np.mean(lesion_samples, axis=0) # empirical mean of sample set 262 | ax_top.plot(x, mu_les, label="prediction units lesioned", color= '#EECC55') 263 | 264 | lower_les, upper_les = helper.extract_lower_upper(bs_lesion) 265 | ax_top.fill_between(x, lower_les, upper_les, color='#EECC55', alpha=0.3) 266 | 267 | mu_cont = np.mean(cont_samples, axis=0) # empirical mean of sample set 268 | ax_bott.plot(x, mu_cont, label="control lesioning", color= '#5efc03') 269 | lower_cont, upper_cont = 
helper.extract_lower_upper(bs_cont) 270 | 271 | ax_bott.fill_between(x, lower_cont, upper_cont, color='#5efc03', alpha=0.3) 272 | 273 | # set limits of axes using the bootstrap bounds 274 | ax_top.set_ylim(min(lower_les)-0.01, max(upper_les)+0.01) 275 | ax_bott.set_ylim(min(lower_norm)-0.01, max(upper_norm)+0.01) 276 | 277 | ax_bott.xaxis.set_major_locator(MaxNLocator(integer=True)); 278 | 279 | ax_top.spines.bottom.set_visible(False) 280 | ax_bott.spines.top.set_visible(False) 281 | ax_top.spines.top.set_visible(False) 282 | 283 | ax_top.tick_params(labeltop=False) # don't put tick labels at the top 284 | ax_top.tick_params(bottom=False) 285 | 286 | h1, l1 = ax_top.get_legend_handles_labels() 287 | h2, l2 = ax_bott.get_legend_handles_labels() 288 | ax_bott.legend(h1+h2, l1+l2, loc=1, prop={'size': 8}) 289 | d = .4 # proportion of vertical to horizontal extent of the slanted line 290 | kwargs = dict(marker=[(-1, -d), (1, d)], markersize=10, 291 | linestyle="none", color='k', mec='k', mew=1, clip_on=False) 292 | ax_top.plot([0, 1], [0, 0], transform=ax_top.transAxes, **kwargs) 293 | ax_bott.plot([0, 1], [1, 1], transform=ax_bott.transAxes, **kwargs) 294 | 295 | ax_top.grid(True); ax_bott.grid(True) 296 | ax_bott.tick_params(labeltop=False) # don't put tick labels at the top 297 | 298 | plot.save_fig(fig, F_PATH + 'lesion_study_CIFAR_'+c_type) 299 | store.close() 300 | #------------------------------------------------------------------------------ 301 | # Uncomment if you want to look at latent models 302 | # R_PATH_latent = 'Results/Fig5/Data/LatentModel/' 303 | # latent_hdf_path = R_PATH_latent+'latent_network_stats.h5' 304 | # if not os.path.isdir(os.path.dirname(R_PATH_latent)): 305 | # os.makedirs(os.path.dirname(R_PATH_latent), exist_ok=True) 306 | # latent_store = pd.HDFStore(latent_hdf_path) 307 | # N_latent = 32 308 | # latent_nets = [] 309 | # for i in range(0, 10): 310 | # net32 = Network.State(activation_func=torch.nn.ReLU(), 311 | # 
optimizer=torch.optim.Adam, 312 | # lr=1e-4, 313 | # input_size=INPUT_SIZE, 314 | # hidden_size=INPUT_SIZE+N_latent, 315 | # title=M_PATH+c_types[1]+str(i), 316 | # device=DEVICE) 317 | # net32.load() 318 | # latent_nets.append(net32) 319 | 320 | # latent_preds, non_latent_preds = [], [] 321 | # for n, net in enumerate(latent_nets): 322 | # type_mask, type_stats = helper.compute_unit_types(net, test_set, train_set) 323 | # type_dict = {'Mask': type_mask, 'Stats': type_stats} 324 | # typedf = pd.DataFrame(data=type_dict) 325 | # # save dataframe 326 | # latent_store = pd.HDFStore(latent_hdf_path) 327 | # latent_store['type_stats_'+c_types[1]+str(n)] = typedf 328 | # latent_store.close() 329 | # # reshape type mask for proper indexing 330 | # type_mask = type_mask.reshape(nunits+N_latent) 331 | # # # retrieve indices of unit types (prediction, error & hybrid) 332 | # err_inds = [i for i, e in enumerate(type_mask) if e in [0,1]] 333 | # pred_inds = [i for i, p in enumerate(type_mask) if p in [2,3]] 334 | # hybrid_inds = [i for i, h in enumerate(type_mask) if h in [4,5]] 335 | # un_inds = [i for i, u in enumerate(type_mask) if u == 6] 336 | # for ind in pred_inds: 337 | # if ind > INPUT_SIZE: 338 | # latent_preds.append(ind) 339 | # else: 340 | # non_latent_preds.append(ind) 341 | 342 | # print(latent_preds) 343 | #------------------------------------------------------------------------------ -------------------------------------------------------------------------------- /functions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | import time 5 | from typing import Dict 6 | 7 | 8 | def get_device(): 9 | if torch.cuda.is_available(): 10 | DEVICE = 'cuda' 11 | torch.set_default_tensor_type(torch.cuda.FloatTensor) 12 | elif torch.backends.mps.is_available(): 13 | DEVICE = torch.device("mps") 14 | else: 15 | DEVICE = 'cpu' 16 | print('Using 
{}'.format(DEVICE)) 17 | return DEVICE 18 | 19 | def get_time() -> int: 20 | '''Returns current time in ms''' 21 | return int(round(time.time() * 1000)) 22 | 23 | class Timer: 24 | def __init__(self): 25 | self.reset() 26 | 27 | def lap(self): 28 | self.t_lap = get_time() 29 | 30 | def get(self): 31 | return get_time() - self.t_lap 32 | 33 | def reset(self): 34 | self.t_total = get_time() 35 | self.t_lap = get_time() 36 | 37 | def __str__(self): 38 | t = self.get() 39 | ms = t % 1000 40 | t = int(t / 1000) 41 | s = t % 60 42 | t = int(t / 60) 43 | m = t % 60 44 | if t == 0: 45 | return "{}.{:03}".format(s,ms) 46 | else: 47 | t = int(t / 60) 48 | h = t 49 | if t == 0: 50 | return "{}:{:02}.{:03}".format(m,s,ms) 51 | else: 52 | return "{}:{:02}:{:02}.{:03}".format(h,m,s,ms) 53 | 54 | def append_dict(dict_a:Dict[str,np.ndarray], dict_b:Dict[str,np.ndarray]): 55 | for k, v in dict_b.items(): 56 | dict_a[k] = np.concatenate((dict_a[k], v)) 57 | 58 | def L1Loss(x:torch.FloatTensor): 59 | 60 | return torch.mean(torch.abs(x)) 61 | 62 | def L2Loss(x:torch.FloatTensor): 63 | 64 | return torch.mean(torch.pow(x, 2)) 65 | 66 | def Linear(x:torch.FloatTensor): 67 | return x 68 | 69 | def parse_loss(args, terms): 70 | if args == None: 71 | return L1Loss, torch.tensor(0.0) 72 | 73 | pre, post, weights = terms 74 | arg1, arg2 = args.split('_') 75 | if arg1 == 'l1': 76 | loss_fn = L1Loss 77 | else: 78 | loss_fn = L2Loss 79 | 80 | if arg2 == 'pre': 81 | loss_arg = pre 82 | elif arg2 == 'post': 83 | loss_arg = post 84 | else: 85 | loss_arg = weights 86 | 87 | return loss_fn, loss_arg 88 | 89 | def init_params(size_x, size_y): 90 | return ((torch.rand(size_x, size_y) * 2.) - 1. ) * np.sqrt(1. 
/ size_x) 91 | 92 | def normalize(x, p=2.0, dim=1): 93 | return F.normalize(x, p, dim) 94 | -------------------------------------------------------------------------------- /helper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | import matplotlib as mpl 5 | import matplotlib.pyplot as plt 6 | 7 | import bootstrapped.bootstrap as bs 8 | import bootstrapped.stats_functions as bs_stats 9 | import mnist 10 | 11 | import cifar 12 | 13 | from Dataset import Dataset 14 | from ModelState import ModelState 15 | 16 | import scipy.stats as st 17 | 18 | # nested dict for color mapping unit types (fig 3A/fig5A) 19 | CMAPPING = {'0': {'0': {'0':0, '1': 1}, '1': {'0':2,'1':3}}, 20 | '1':{'0':4, '1':5} 21 | } 22 | 23 | # 24 | # 25 | # 26 | # ---------- Script with helper functions for plot.py ----------- 27 | # 28 | # 29 | 30 | 31 | 32 | def _calc_energy(net, preactivation, energy, mask=None, med=False): 33 | if energy == 'ap': 34 | return _calc_ap(net, preactivation) 35 | elif energy =='st': 36 | return _calc_st(net, preactivation, mask, med) 37 | # return summary of ap st (Sengupta et al 2010) 38 | return (1/3)*_calc_ap(net, preactivation) + (2/3)*_calc_st(net, preactivation, mask, med) 39 | 40 | 41 | def _calc_ap(net, preactivation): 42 | # calculate outputs 43 | act = F.relu(preactivation) 44 | return torch.abs(act) 45 | 46 | def _calc_st(net, preactivation, mask, med=False): 47 | # calculate outputs 48 | act = F.relu(preactivation) 49 | if med: 50 | return act 51 | abs_W = net.model.W.detach() 52 | if mask is not None and len(mask.squeeze().shape) > 1: # weights need to be masked 53 | abs_W = abs_W * mask 54 | abs_act, abs_W = torch.abs(act), torch.abs(abs_W) 55 | synaptrans = torch.sum(abs_act.unsqueeze(-1) * abs_W, axis=1) 56 | 57 | return synaptrans 58 | # 59 | # --- Helper functions for bootstrap plotting fig 2A, 4A & 5C --- 60 | # 61 | 62 | 63 | def 
compute_pixel_variance(images): 64 | """ 65 | computes variance of pixels for each channel seperately 66 | """ 67 | nsamples, nc, npix = images.shape 68 | pixel_var = torch.zeros(nc, npix) 69 | for c in range(nc): 70 | var_c = images[:, c, :].var(axis=0) 71 | pixel_var[c, :] = var_c 72 | return pixel_var 73 | 74 | def find_pred_units(net, dataset, seq_length=10, Z_crit=2.576): 75 | nclasses=10 76 | preact_stats = compute_preact_stats(net, dataset) 77 | nunits = net.model.hidden_size 78 | pred_rule = torch.zeros(nunits, nclasses) 79 | 80 | for cls_plt in range(10): 81 | med, mad = preact_stats[:, cls_plt, 0], preact_stats[:, cls_plt, 1] 82 | 83 | # scale MAD to obtain a pseudo standard deviation 84 | # https://stats.stackexchange.com/questions/355943/scale-factor-for-mad-for-non-normal-distribution) 85 | 86 | for i in range(nunits): # Z_crit CI 87 | if (torch.abs(med[i]) - torch.abs(Z_crit*mad[i])) > 0: 88 | pred_rule[i, cls_plt] = 1 89 | return pred_rule 90 | 91 | def find_error_units(net, test_set, seq_length=10, target=None, Z_crit=2.576): 92 | batch_size = 1 93 | 94 | class_error_units = torch.zeros(seq_length, net.model.hidden_size) 95 | t_ind = 8 # look at penultimate timepoint 96 | error_units = torch.zeros(net.model.hidden_size) 97 | for target in range(seq_length): 98 | starting_point = (target - t_ind) % 10 99 | # create normal sequences 100 | norm_seq = test_set.create_batches(batch_size, seq_length, fixed_starting_point=starting_point) 101 | # create distractor sequences 102 | dis_seq = test_set.create_batches(batch_size, seq_length, distractor=True,fixed_starting_point=starting_point) 103 | # collect responses of networks on test set 104 | responses = extract_responses(net, norm_seq) 105 | # collect responses of networks on distractor set 106 | d_responses = extract_responses(net, dis_seq) 107 | anomalies = detect_anomalies(net, responses, d_responses, Z_crit) 108 | for i in range(len(anomalies)): 109 | if anomalies[i] == 1: 110 | 
class_error_units[target, i] = 1 # 111 | error_units[i] = 1 112 | return error_units, class_error_units 113 | 114 | def detect_anomalies(net, responses, d_responses, Z_crit=2.576): 115 | t_ind = 8 # look at final time point 116 | n_units, n_samples = responses.shape[0], torch.tensor(responses.shape[1]) 117 | mean_responses, std_responses = responses[:,:, t_ind].mean(axis=1), responses[:,:, t_ind].std(axis=1) 118 | mean_distractor, std_d_responses = d_responses[:,:, t_ind].mean(axis=1), d_responses[:,:, t_ind].std(axis=1) 119 | 120 | #Z_crit = 2.576 #2.576 # 99% 121 | #Z_crit = 1.96 122 | anomalies = torch.zeros(n_units) 123 | 124 | for i in range(n_units): 125 | mu_i, mu_id = mean_responses[i], mean_distractor[i] 126 | # calculate standard errors and compute Z scores 127 | s_i, s_id = std_responses[i]/torch.sqrt(n_samples), std_d_responses[i]/torch.sqrt(n_samples) 128 | Z = torch.abs((mu_i - mu_id) / torch.sqrt(s_i**2 + s_id**2)) 129 | if Z >= Z_crit: 130 | anomalies[i] = 1 131 | return anomalies 132 | 133 | def extract_responses(net, test_set, seq_length=10): 134 | """ 135 | Collect responses from h_t from the network on the test 136 | data 137 | """ 138 | test_data, test_labels = test_set 139 | 140 | n_units, nbatch, batch_size = net.model.W.shape[0], test_data.shape[0], 1 141 | responses = torch.zeros(n_units, nbatch, seq_length) 142 | 143 | for i,batch in enumerate(test_data): 144 | state = net.model.init_state(batch_size) 145 | 146 | for t in range(seq_length): 147 | #state = net.get_next_state(state,batch[t]) 148 | state, l_terms= net.model.forward(batch[t], state) 149 | a, h, W = l_terms 150 | # collect unit responses h_t 151 | responses[:, i, t] = h.squeeze() #state.squeeze() 152 | return responses.detach() 153 | 154 | 155 | def compute_allclass_rd(net, test_set, seq_length=10): 156 | responses, drives = list(zip(*[tuple(compute_responses_drive(net,test_set, target, seq_length)) for target in range(0, 9)])) 157 | return torch.cat(responses, axis=1), 
torch.cat(drives, axis=1) 158 | 159 | def compute_unit_types(net:ModelState, dataset:Dataset, training_set:Dataset=None, Z_crit=2.576, seed=2553): 160 | """ 161 | Helper function that determines the types of units in the network 162 | 163 | The types are: 164 | 0: pure error unit (e*) 165 | 1: pure prediction unit (p*) 166 | 2: hybrid (h*) 167 | 3: unspecified (u) 168 | 169 | The pure units are devided based on how many classes they predict/error 170 | signal for: 171 | 172 | subtypes pure error units: 173 | 0.1: 1 class (e1) 174 | 0.2: multiclasses (e2) 175 | 176 | subtypes pure prediction units: 177 | 1.1: 1 class (p1) 178 | 1.2: multiclasses (p2) 179 | 180 | subtypes hybrid units: 181 | 2.1: hybrid unit within (prediction and error unit for the same class) (h1) 182 | 2.2: hybrid unit across (prediction and error unit for different classes) (h2) 183 | 184 | Resulting in 6 different typings 185 | 186 | These are assigned as: 187 | 0-1: pure error (0: e1, 1: e2) 188 | 2-3: pure prediction unit (2: p1, 3: p2) 189 | 4-5: hybrid (4: h1, 5: h2) 190 | 6: unspecified (u) 191 | """ 192 | if seed != None: 193 | torch.manual_seed(seed) 194 | np.random.seed(seed) 195 | #preact_stats = compute_preact_stats(net, dataset) 196 | nunits, nclasses = net.model.hidden_size, 10 197 | pred_rule = find_pred_units(net, dataset, Z_crit=Z_crit) 198 | _, error_rule = find_error_units(net, dataset, Z_crit=Z_crit) 199 | units_stats = [] 200 | 201 | for i in range(nunits): 202 | # count the number of classes the unit is predictive and error signaling for 203 | #n_pred, n_err = 0,0 204 | # track the classes that the unit is predictive and error signaling for 205 | cpred, cerr = [], [] 206 | for j in range(nclasses): 207 | if pred_rule[i,j] == 1: 208 | cpred.append(j) 209 | #n_pred += 1 210 | if error_rule[j,i] == 1: 211 | cerr.append(j) 212 | #n_err += 1 213 | # record if unit is predictive and error signaling 214 | within = list(set(cpred).intersection(cerr)) 215 | unique_pred = 
set(cpred).difference(set(cerr)) 216 | unique_err = set(cerr).difference(set(cpred)) 217 | if len(unique_pred) == 0 or len(unique_err) == 0: # cannot be across if one the lists are empty 218 | across = [] 219 | else: # just take the union of the two sets 220 | across = list(unique_pred.union(unique_err)) 221 | # record stats for unit i 222 | # cpred: the classes the unit is predictive for, cerr: the classes the unit is error signaling for 223 | # within: the classes the unit is both predictive and error signaling for 224 | # across: classes that the unit is either predictive or error signaling for 225 | units_stats.append((cpred, cerr, within, across)) 226 | 227 | # parse type (traverse decision tree) 228 | # assign a type to the units 229 | units_types = torch.zeros(nunits) 230 | for i, stats in enumerate(units_stats): 231 | cpred, cerr, within, across = stats 232 | # decide if unspecified or not 233 | if len(cpred) == 0 and len(cerr) ==0: 234 | units_types[i] = 6 235 | 236 | else: 237 | # decide if hybrid or not 238 | ply1 = int(len(within) > 0 or len(across) > 1) 239 | 240 | if ply1: #hybrid branche 241 | # decide if within/across 242 | ply2 = int((len(across) > 1)) 243 | units_types[i] = CMAPPING[str(ply1)][str(ply2)] 244 | 245 | else: # PE branch 246 | # decide if prediction unit 247 | ply2 = int((len(cpred) > 0)) 248 | # decide if multiclass 249 | ply3 = int(len(cpred) > 1 or len(cerr) >1) 250 | units_types[i] = CMAPPING[str(ply1)][str(ply2)][str(ply3)] 251 | 252 | 253 | return units_types, units_stats 254 | 255 | def compute_responses_drive(net, test_set, target=0, seq_length=10): 256 | """ 257 | collect network responses & drive h_k & network drive p_k+1 258 | look at h_k, p_k+1 (you want to correlate unit output i. vs. unit drive j.) 
259 | 260 | Output: response matrix, drive matrix (NxOxK) where N=nr units, O= nr 261 | observations, K= sequence length = largest temporal history 262 | """ 263 | batch_size = 1 264 | 265 | # record network predictions (activities of units) 266 | state = net.model.init_state(batch_size) 267 | 268 | # collect the response matrices in here 269 | response_list, drive_list, synaptrans_list = [], [], [] 270 | # upper bound on temporal history since you need to be able to predict 271 | # one time step in the future and need to deal with 0-indexing 272 | K = seq_length-1 273 | 274 | for k in range(K): 275 | # determine where the sequence starts given temporal history k 276 | seq_start = (target - k) % 10 277 | 278 | # create sequences 279 | batch_data, batch_labels = test_set.create_batches(batch_size, \ 280 | seq_length, fixed_starting_point=seq_start) 281 | nbatch = batch_data.shape[0] 282 | 283 | 284 | response_k, drive_k, synaptrans_k = torch.zeros((net.model.W.shape[0], nbatch)).to(net.device),\ 285 | torch.zeros((net.model.W.shape[0], nbatch)).to(net.device), torch.zeros((net.model.W.shape[0], nbatch)).to(net.device) 286 | 287 | 288 | abs_W = net.model.W.detach()#torch.abs(net.model.W.detach()) 289 | 290 | # get observed responses and predictions associated with target 291 | for i, batch in enumerate(batch_data): 292 | # move state forward to k 293 | for m in range(0, k+1): # m in [0,..,k] 294 | state = net.get_next_state(state,batch[m]) 295 | 296 | # collect drives and responses for target (h_k, p_k+1) 297 | response_k[:, i] = state.to(net.device) 298 | drive_k[:, i] = net.predict(state).squeeze() 299 | synaptrans_k[:, i] = torch.sum(response_k[:, i].unsqueeze(-1) * abs_W, axis=1) 300 | # reset state 301 | state = net.model.init_state(batch_size) 302 | 303 | # add the responses and drives to the list 304 | response_list.append(response_k) 305 | drive_list.append(drive_k) 306 | synaptrans_list.append(synaptrans_k) 307 | 308 | # construct full matrices, normalize 
and return them 309 | responses, drives, synaptrans = torch.stack(response_list, dim=-1), torch.stack(drive_list, dim=-1), torch.stack(synaptrans_list, dim=-1) 310 | 311 | return responses.detach(), drives.detach(), synaptrans.detach() 312 | 313 | 314 | def compute_targ_pred_corrmat(responses, drive): 315 | """ 316 | records the correlation between h_k and p_k for temporal window 317 | k 318 | 319 | Output: a correlation matrix with dimensions N**2xK, where entry i,j contains the correlation 320 | between h^i_k and p^j_k+1, where K is the temporal window and N the 321 | number of units 322 | 323 | 324 | 325 | """ 326 | n_units, n_obs, T = responses.shape 327 | K = T-1 # upper bound on temporal history 328 | 329 | corr_mat = torch.zeros((n_units, n_units, K)) 330 | 331 | for k in range(0, K): 332 | # compute correlation matrix for u_tk: (h_t-k, p_t) 333 | h_k, p_k1 = responses[:, :, k], drive[:, :, k+1] 334 | # compute correlation coefficient 335 | c_k = torch.tensor(np.ma.corrcoef(np.ma.masked_invalid(h_k), \ 336 | np.ma.masked_invalid(p_k1)))[:n_units, n_units:] # only look at second quadrant 337 | for u_r in range(n_units): 338 | for u_d in range(n_units): 339 | corr_mat[u_r, u_d,k] = c_k[u_r, u_d] 340 | return corr_mat 341 | 342 | def compute_post_drive_bootstrap(pred,error ,seq_length=9): 343 | """ compute bootstrap bounds for each time point""" 344 | bs_pred, bs_error = [], [] 345 | for t in range(seq_length): 346 | bs_pred.append(bs.bootstrap(pred[:, t], stat_func=bs_stats.mean, iteration_batch_size=None)) 347 | bs_error.append(bs.bootstrap(error[:, t], stat_func=bs_stats.mean, iteration_batch_size=None)) 348 | return bs_pred, bs_error 349 | 350 | def compute_bootstrap(samples_list, seq_length=10): 351 | """ compute bootstrap bounds for each timepoint and set of samples""" 352 | bs_list = [[] for samples in samples_list] 353 | for t in range(seq_length): 354 | for i, bsamples in enumerate(bs_list): 355 | samples = samples_list[i] 356 | 
bsamples.append(bs.bootstrap(samples[:, t], stat_func=bs_stats.mean, iteration_batch_size=None)) 357 | return bs_list 358 | 359 | def compute_bootstrap_dep(notn, meds, gmed, net, net_les=None, net_les_rev=None ,seq_length=10): 360 | """ compute bootstrap bounds for each time point""" 361 | bs_notn, bs_meds, bs_gmed, bs_net, bs_netles, bs_netles_rev = [],[],[],[], [], [] 362 | 363 | for t in range(seq_length): 364 | bs_notn.append(bs.bootstrap(notn[:,t], stat_func=bs_stats.mean, iteration_batch_size=None)) 365 | bs_meds.append(bs.bootstrap(meds[:, t], stat_func=bs_stats.mean, iteration_batch_size=None)) 366 | bs_gmed.append(bs.bootstrap(gmed[:, t], stat_func=bs_stats.mean, iteration_batch_size=None)) 367 | bs_net.append(bs.bootstrap(net[:, t], stat_func=bs_stats.mean, iteration_batch_size=None)) 368 | 369 | 370 | if net_les is not None: 371 | bs_netles.append(bs.bootstrap(net_les[:, t], stat_func=bs_stats.mean, iteration_batch_size=None)) 372 | bs_netles_rev.append(bs.bootstrap(net_les_rev[:, t], stat_func=bs_stats.mean, iteration_batch_size=None)) 373 | return bs_notn, bs_meds, bs_gmed, bs_net, bs_netles, bs_netles_rev 374 | 375 | def extract_lower_upper(bs_list): 376 | """ 377 | wrapper function that extracts upper and lower bounds of the confidence 378 | interval 379 | """ 380 | lower, upper = [bs.lower_bound for bs in bs_list], [bs.upper_bound for bs in bs_list] 381 | return lower,upper 382 | 383 | # 384 | # --- Helper function for Appendix A Figures A1 & A2 --- 385 | # 386 | def _run_seq_from_digit(digit, steps, net:ModelState, dataset:Dataset, mask=None): 387 | """Create sequences with the same starting digit through a model and return the hidden state 388 | 389 | Parameters: 390 | - digit: the last digit in the sequence 391 | - steps: sequence length, or steps before the sequence gets to the 'digit' 392 | - net: model 393 | - dataset: dataset to use 394 | - mask: mask can be used to turn off (i.e. 
lesion) certain units 395 | """ 396 | fixed_starting_point = (digit - steps) % 10 397 | b, _ = dataset.create_batches(batch_size=-1, sequence_length=steps, shuffle=True, fixed_starting_point=fixed_starting_point) 398 | batch = b.squeeze() # removed 0 because of weird 399 | 400 | h = net.model.init_state(1) 401 | h = h.to(net.device) 402 | for i in range(steps): 403 | # check if mask needs to be applied 404 | if mask is not None: 405 | mask = mask.to(net.device) 406 | # check if mask is for error or for prediction 407 | if len(mask.shape) > 1: # error mask 408 | h, l_a = net.model(batch[i], state=h, mask=mask) 409 | else: 410 | h, l_a = net.model(batch[i], state=h) 411 | h = h * mask 412 | else: 413 | h, l_a = net.model(batch[i], state=h) 414 | return h.detach() 415 | 416 | # 417 | # --- Helper functions for lesion plots (Figures 4, 5C) 418 | # 419 | 420 | def pred_class_mask(net:ModelState, test_set:Dataset, target=0, Z_crit=2.576): 421 | """ 422 | 423 | 424 | returns prediction unit mask for class: target 425 | 426 | """ 427 | target = (target - 1) % 10 # activation will affect prediction one time step later 428 | n_units = net.model.W.shape[0] 429 | # shape: nunits x nclasses x 2 430 | preact_stats = compute_preact_stats(net, test_set) 431 | 432 | med, mad = preact_stats[:, :, 0], preact_stats[:,:, 1] 433 | 434 | pred_mask = torch.ones(n_units) 435 | for i in range(n_units): 436 | if (torch.abs(med[i][target]) - torch.abs(Z_crit*mad[i][target])) > 0: 437 | pred_mask[i] = 0 # unit i is predictive for class target 438 | 439 | return pred_mask 440 | 441 | 442 | def _pred_mask(net:ModelState, test_set:Dataset, training_set:Dataset, latent=False, reverse=False, Z_crit=2.576): 443 | """ 444 | Wrapper function for calling the routine that computes the mask for the networks 445 | """ 446 | pred_mask = _pred_mask_mad(net, test_set,training_set, latent=latent, reverse=reverse, Z_crit=Z_crit) 447 | return pred_mask 448 | 449 | def _error_mask(net:ModelState, test_set, 
training_set, latent=False, reverse=False): 450 | """ 451 | Knock out lateral connections between error units that are not 452 | prediction units 453 | 454 | """ 455 | error_units, _ = find_error_units(net, test_set) 456 | error_indices = (error_units).nonzero().squeeze() 457 | pred_units = _pred_mask(net, test_set, training_set= training_set, latent=latent, reverse=reverse) 458 | pred_indices = (pred_units == 0).nonzero().squeeze() 459 | unique_error = error_indices[~error_indices.unsqueeze(1).eq(pred_indices).any(1)] 460 | unique_pred = pred_indices[~pred_indices.unsqueeze(1).eq(error_indices).any(1)] 461 | mask = torch.ones(net.model.W.shape) 462 | for i in range(mask.shape[0]): 463 | for j in range(mask.shape[1]): 464 | if i in error_indices and j in pred_indices: 465 | mask[i][j] = 0 466 | elif i in pred_indices and i in error_indices: # yellow unit 467 | mask[i][j] = 0 # prevent yellow units from inhibiting at t=1 468 | return mask 469 | 470 | def _pred_mask_mad(net:ModelState, test_set:Dataset, training_set:Dataset, latent=False, reverse=False, Z_crit=2.576): 471 | """ 472 | Returns a mask for the network units, where each entry is 1 if the 473 | associated unit has a bias in its final time point median 474 | preactivation and standard error in at least one class. 475 | The rationale behind this approach is that a unit with nonzero preactivation 476 | has to have a functional role in supressing activity induced by the incoming 477 | digit since it would have been supressed by the objective function otherwise. 
478 | 479 | """ 480 | if type(training_set) is mnist.MNISTDataset: 481 | 482 | class_meds = mnist.medians(training_set) 483 | else: # cifar 484 | class_meds = cifar.medians(training_set) 485 | 486 | preact_stats = compute_preact_stats(net, test_set) 487 | 488 | med, mad = preact_stats[:, :, 0], preact_stats[:, :, 1] 489 | 490 | n_units, n_classes = net.model.W.shape[0], len(class_meds) 491 | A_mask = torch.zeros(n_units) 492 | 493 | 494 | for i in range(n_units): 495 | for j in range(n_classes): 496 | if (torch.abs(med[i][j]) - torch.abs(Z_crit*mad[i][j])) > 0: 497 | A_mask[i] = 1 # unit i is predictive for class j 498 | 499 | pred_mask = torch.ones(net.model.W.shape[0]) 500 | 501 | if reverse: 502 | N_pred = sum(A_mask == 1).item() 503 | 504 | idx = (A_mask == 0).nonzero().flatten() 505 | perm = torch.randperm(len(idx)) 506 | idx = idx[perm[:N_pred]] 507 | pred_mask[idx] = 0 508 | else: 509 | pred_mask[A_mask == 1] = 0 510 | return pred_mask 511 | 512 | 513 | # 514 | # --- Helper function (general) --- 515 | # 516 | def truncate_colormap(cmap, minval=0.0, maxval=1.0, n=100): 517 | """ 518 | Adapted from https://stackoverflow.com/a/18926541 519 | """ 520 | if isinstance(cmap, str): 521 | cmap = plt.get_cmap(cmap) 522 | new_cmap = mpl.colors.LinearSegmentedColormap.from_list( 523 | 'trunc({n},{a:.2f},{b:.2f})'.format(n=cmap.name, a=minval, b=maxval), 524 | cmap(np.linspace(minval, maxval, n))) 525 | return new_cmap 526 | # 527 | # --- Helper function for identifying prediction units 528 | # Figures 3A,B 4B, 5B,C, Appendix A3,A4 --- 529 | # 530 | def compute_preact_stats(net:ModelState, dataset:Dataset, nclasses=10, ntime=10): 531 | """ 532 | Computer for each unit the average final time point median preactivation and MAD 533 | for each class 534 | 535 | Output: preact_stats matrix n_units x nclasses x 2 536 | """ 537 | 538 | preact_stats = torch.zeros((net.model.hidden_size, nclasses, 2)) 539 | 540 | # generate sequences that end in the same class 541 | for t in 
[ntime - 1]: # only look at final time point (0-indexed) 542 | for category in range(nclasses): 543 | starting_point = int(category - t + ntime) 544 | if starting_point > (ntime - 1): # cycle back 545 | starting_point -= ntime 546 | 547 | data, labels = dataset.create_batches(-1,ntime, shuffle=False,fixed_starting_point=starting_point) 548 | 549 | nb, ntime,batch_size,ninputs = data.shape 550 | 551 | data = data.squeeze(0) 552 | labels = labels.squeeze(0) 553 | batch_size = data.shape[1] 554 | h_net = net.model.init_state(batch_size) 555 | 556 | for i in range(data.shape[0]): # calculate response variance of category up until t 557 | x = data[i] 558 | h_net, l_net = net.model(x, state=h_net) 559 | #energy = _calc_energy(net, l_net[0], l_net[1]) 560 | if i == t: 561 | med, mad= l_net[0].to(net.device).median(axis=0).values, torch.tensor(st.median_abs_deviation(l_net[0].cpu().detach().numpy(), axis=0, scale='normal')) 562 | 563 | preact_stats[:, category, 0] = med 564 | preact_stats[:, category, 1] = mad 565 | 566 | return preact_stats.detach() 567 | 568 | 569 | # --- Helper functions that compute preactivation figures lesioned & non-lesioned 570 | # Figure 2A, 4B & 5C --- 571 | # 572 | def model_activity(net:ModelState, 573 | training_set:Dataset, 574 | test_set:Dataset, 575 | seq_length=10, 576 | data_type='mnist', 577 | color=False, 578 | save=True): 579 | """ 580 | calculates model preactivation and preactivation bounds 581 | for unlesioned models 582 | """ 583 | nclass = 10 # change this if you want to change the number of classes 584 | # category medians and median for all images 585 | if data_type == 'mnist': 586 | meds = mnist.medians(training_set) 587 | global_median = training_set.x.median(dim=0).values 588 | N = 784 589 | elif data_type == 'cifar': 590 | meds = cifar.medians(training_set) 591 | global_median = training_set.x.median(dim=0).values 592 | N = 3072 # only compute results over non-latent units 593 | # calc energy demands for theoretical 
benchmarks 594 | #meds = _calc_energy(net, meds, torch.nn.ReLU(meds)) # first preact 595 | #global_median = _calc_energy(net, global_median, torch.nn.ReLU(global_median)) 596 | with torch.no_grad(): 597 | data, labels = test_set.create_batches(-1, seq_length, shuffle=True) 598 | nb, ntime,batch_size,ninputs = data.shape 599 | 600 | 601 | data = data.squeeze(0) 602 | labels = labels.squeeze(0) 603 | batch_size = data.shape[1] 604 | 605 | # result lists 606 | mu_notn = [] 607 | mu_meds = [] 608 | mu_gmed = [] 609 | mu_net = [] 610 | mu_input = [] 611 | mu_latent = [] 612 | 613 | 614 | h_net = torch.zeros(batch_size, N) 615 | h_net = net.model.init_state(batch_size) 616 | 617 | for t in range(data.shape[0]): 618 | 619 | x = data[t] 620 | y = labels[t] 621 | # calculate energy demands of x 622 | #x = _calc_energy(net, x, torch.nn.ReLU(x)) first test with preact 623 | # repeat global median for each input image 624 | gmedian = torch.zeros_like(x) 625 | gmedian[:,:] = global_median 626 | # find the corresponding median for each input image 627 | median = torch.zeros_like(x) 628 | for i in range(nclass): 629 | median[y==i,:] = meds[i] 630 | 631 | # calculate hidden state 632 | h_meds = (x - median) 633 | h_gmed = (x - gmedian) 634 | 635 | 636 | # calculate L1 loss for each unit, assuming equal amounts of units in each model 637 | m_notn = x.abs().sum(dim=1)/net.model.input_size 638 | m_meds = h_meds.abs().sum(dim=1)/net.model.input_size 639 | m_gmed = h_gmed.abs().sum(dim=1)/net.model.input_size 640 | 641 | 642 | 643 | h_net, l_net = net.model(x, state=h_net) 644 | # calculate energy demands for the network 645 | m_net = torch.cat([a[:,:ninputs] for a in l_net[0]], dim=1).abs().mean(dim=1).mean() 646 | m_input = torch.cat([a[:,:ninputs] for a in l_net[0]], dim=1).abs().mean(dim=1).mean() 647 | m_latent = torch.cat([a[:,ninputs:] for a in l_net[0]], dim=1).abs().mean(dim=1).mean() 648 | 649 | # commented out for later analyses 650 | #m_net = _calc_energy(net, m_net, 
torch.nn.ReLU(m_net)) 651 | #m_input = _calc_energy(net, m_input, torch.nn.ReLU(m_input)) 652 | #m_latent = _calc_energy(net, m_latent, torch.nn.ReLU(m_latent)) 653 | 654 | # Calculate the mean 655 | mu_notn.append(m_notn.mean().cpu().item()) 656 | mu_meds.append(m_meds.mean().cpu().item()) 657 | mu_gmed.append(m_gmed.mean().cpu().item()) 658 | mu_net.append(m_net.mean().cpu().item()) 659 | mu_input.append(m_input.mean().cpu().item()) 660 | mu_latent.append(m_latent.mean().cpu().item()) 661 | 662 | 663 | 664 | return data, np.array(mu_notn), np.array(mu_meds), np.array(mu_gmed), np.array(mu_net), np.array(mu_input), np.array(mu_latent) 665 | 666 | 667 | def bootstrap_model_activity(nets:[ModelState], 668 | train_set:Dataset, 669 | test_set:Dataset, 670 | seq_length=10, 671 | energy='ec', 672 | lesioned=True, 673 | lesion_type='pred', 674 | latent=False, 675 | seed=None, 676 | Z_crit=2.576, 677 | data_type='mnist'): 678 | """ 679 | 680 | Calculates energy consumption of models and 681 | all CI 99%/95% bootstrapped with replacement 682 | 683 | """ 684 | if seed != None: 685 | torch.manual_seed(seed) 686 | np.random.seed(seed) 687 | 688 | 689 | # initialize sample matrices 690 | norm_samples = np.zeros((len(nets), seq_length)) 691 | lesion_samples = np.zeros((len(nets), seq_length)) 692 | cont_samples = np.zeros((len(nets), seq_length)) 693 | 694 | for i, net in enumerate(nets): 695 | mu_norm, mu_les =\ 696 | model_activity_lesioned(net, train_set, test_set, lesion_type='pred', seq_length=10, energy=energy, save=False,\ 697 | latent=False, data_type='mnist',Z_crit=Z_crit) 698 | 699 | # calculate energy curves with control lesion 700 | _, mu_cont=\ 701 | model_activity_lesioned(net, train_set, test_set, lesion_type='pred', seq_length=10, energy=energy, save=False,\ 702 | latent=False, data_type='mnist', reverse=True, Z_crit=Z_crit) 703 | # fill sample matrices 704 | norm_samples[i, :] = mu_norm 705 | lesion_samples[i, :] = mu_les 706 | cont_samples[i, :] = mu_cont 707 
| 708 | # compute bootstrap bounds 709 | [bs_norm, bs_lesion, bs_cont] = compute_bootstrap([norm_samples, lesion_samples, cont_samples]) 710 | # store samples and bs in dictionary 711 | bs_sample_dict = {'norm': [norm_samples], 'lesion': [lesion_samples], 'cont': \ 712 | [cont_samples], 'bs_norm': [bs_norm],'bs_lesion': [bs_lesion], 'bs_cont':[bs_cont]} 713 | return bs_sample_dict 714 | 715 | def model_activity_lesioned(net:ModelState, training_set:Dataset, test_set:Dataset, lesion_type='pred', 716 | seq_length=10, energy='ec', save=True, 717 | latent=False, data_type='mnist', reverse=False, Z_crit=2.576): 718 | """ 719 | calculates model preactivation and preactivation bounds 720 | for lesioned models 721 | """ 722 | if data_type == 'mnist': 723 | batch_size = -1 # full dataset 724 | else: 725 | batch_size = 32 726 | if lesion_type == 'error': 727 | mask = _error_mask(net, test_set, training_set,latent=latent, reverse=False) 728 | else: 729 | mask = _pred_mask(net, test_set, training_set= training_set, latent=latent, reverse=reverse, Z_crit=Z_crit) 730 | 731 | with torch.no_grad(): 732 | data, labels = test_set.create_batches(batch_size, seq_length, shuffle=True) 733 | nbatch, ntime, batch_size, ninputs = data.shape 734 | #data = data.squeeze(0) 735 | #labels = labels.squeeze(0) 736 | #batch_size = data.shape[1] 737 | 738 | # result lists 739 | mu_net, mu_netles = torch.zeros(ntime), torch.zeros(ntime) 740 | 741 | #mu_input = [] 742 | #mu_latent = [] 743 | 744 | h_net = net.model.init_state(batch_size) 745 | h_netles = net.model.init_state(batch_size) 746 | # create seperate states to prevent leaking across batches 747 | state = h_net.unsqueeze(0).repeat_interleave(nbatch, dim=0) 748 | lesioned_state = h_netles.unsqueeze(0).repeat_interleave(nbatch, dim=0) 749 | for t in range(ntime): 750 | m_net, m_netles = [], [] 751 | for b in range(nbatch): 752 | h_net, h_netles = state[b], lesioned_state[b] 753 | x = data[b,t] 754 | h_net, l_net = net.model(x, state=h_net) 
755 | if lesion_type == 'pred': 756 | h_netles = h_netles * mask # perform lesion 757 | h_netles, l_netles = net.model(x, state=h_netles) 758 | else: # lesion error units 759 | h_netles, l_netles = net.model(x, state=h_net, mask=mask) 760 | 761 | # calculate energy of the hidden states 762 | if energy != 'pre': 763 | l_net[0] = _calc_energy(net, l_net[0], energy) 764 | l_netles[0] = _calc_energy(net, l_netles[0], energy, mask) 765 | 766 | #m_net[b,:] = torch.cat([a[:,:ninputs]for a in [l_net[0]]], dim=1).abs().mean(dim=1) 767 | 768 | #m_netles[b,:] = torch.cat([a[:,:ninputs] for a in [l_netles[0]]], dim=1).abs().mean(dim=1) 769 | #m_input = torch.cat([a[:,:ninputs] for a in [l_netles[0]]], dim=1).abs().mean(dim=1).mean() 770 | #m_latent = torch.cat([a[:,ninputs:] for a in [l_netles[0]]], dim=1).abs().mean(dim=1).mean() 771 | m_net += torch.cat([a[:,:ninputs] for a in [l_net[0]]], dim=1).abs().mean(dim=1).tolist() 772 | m_netles += torch.cat([a[:,:ninputs] for a in [l_netles[0]]], dim=1).abs().mean(dim=1).tolist() 773 | # update state and lesioned state for batch b 774 | state[b], lesioned_state[b] = h_net, h_netles 775 | # m_net = m_net.mean(axis=0) 776 | #m_netles = m_netles.mean(axis=0) 777 | # Calculate the mean 778 | mu_net[t], mu_netles[t] = torch.tensor(m_net).mean(), torch.tensor(m_netles).mean() 779 | #mu_net.append(m_net.flatten().mean()) 780 | #mu_netles.append(m_netles.flatten().mean()) 781 | #mu_input.append(m_input.mean().cpu().item()) 782 | #mu_latent.append(m_latent.mean().cpu().item()) 783 | return np.array(mu_net), np.array(mu_netles) 784 | -------------------------------------------------------------------------------- /mnist.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from Dataset import Dataset 4 | import random 5 | class MNISTDataset(Dataset): 6 | """Container class for the MNIST database containing Tensors with the images and labels, as well as a list of 
indices for each category 7 | """ 8 | def __init__(self, x, y, indices, repeat=1): 9 | super(Dataset, self).__init__(x=x, y=y, indices=indices) 10 | self.repeat = repeat 11 | 12 | def create_batches(self, batch_size, sequence_length, shuffle=True, distractor=False,fixed_starting_point=None): 13 | data, labels = create_sequences(self, sequence_length, batch_size, shuffle, distractor, fixed_starting_point) 14 | data = data.repeat_interleave(self.repeat, dim=1) 15 | labels = labels.repeat_interleave(self.repeat, dim=1) 16 | return data, labels 17 | 18 | def create_sequences(dataset, sequence_length, batch_size, shuffle=True, distractor=False, fixed_starting_point=None): 19 | # number of datapoints 20 | data_size = dataset.x.shape[0] 21 | 22 | # maximum theoretical amount of sequences 23 | max_sequences = int(data_size / sequence_length) 24 | 25 | # for test and validation it is not actually necessary to shuffle, 26 | # so for consistent testing/validation we can use the same sequences every time 27 | if shuffle: 28 | # shuffle all the data points per digit class 29 | indices = [dataset.indices[i][torch.randperm(d.shape[0])] for i,d in enumerate(dataset.indices)] 30 | # choose random sequence starting points 31 | seq_starting_points = torch.randperm(max_sequences) 32 | else: 33 | indices = dataset.indices 34 | seq_starting_points = torch.arange(max_sequences) 35 | # if we want the same starting digit for all the sequences 36 | if fixed_starting_point is not None: 37 | assert(isinstance(fixed_starting_point, int) and fixed_starting_point in list(range(10))) 38 | seq_starting_points = torch.ones(max_sequences) * fixed_starting_point 39 | # from the starting points, create sequences of the required length 40 | # first we repeat each starting point 'sequence_length' times 41 | sequences = seq_starting_points.repeat_interleave(sequence_length).view(max_sequences, sequence_length) 42 | # we then add to each digit the index of its position within the sequence, 43 | # so we 
get increasing numbers in the sequence 44 | for i in range(1, sequence_length): 45 | sequences[:,i] += i 46 | # take the remainder of all numbers in sequence to get actual digits from 0-9 47 | sequences %= 10 48 | # switch out digit at position 8 for a distractor if flag is true 49 | if distractor: 50 | for i in range(max_sequences): 51 | digit = sequences[i,8] 52 | candidates = list(range(0,10)) 53 | candidates.remove(digit) 54 | sequences[i, 8] = random.choice(candidates) 55 | # flatten again 56 | sequences = sequences.flatten() 57 | # create an array to store the indices for the digits in 'data' 58 | epoch_indices = torch.zeros(data_size, dtype=torch.long) 59 | # because not every digit is equally represented, 60 | # we have to keep track of where in the sequence we have run out of 61 | # digits. This 'cutoff' is the minimum between all digits 62 | cutoff = data_size 63 | 64 | for i in range(10): 65 | # mask to filter out the positions of this digit 66 | mask = sequences==i 67 | # calculating the cumulative sum of the mask gives us a nice increasing 68 | # index exactly at the points of where the digit is in the list of sequences. 
69 | # we can use this as an index for 'indices' 70 | indices_idx = torch.cumsum(mask, 0) 71 | # we cut 'idx' off where the index exceeds the number of digits we actually have 72 | # for this case 73 | indices_idx = indices_idx[indices_idx < indices[i].shape[0]] 74 | # keep track of the earliest cutoff point for later 75 | cutoff = min(cutoff, indices_idx.shape[0]) 76 | # also cutoff the mask so it has the right shape 77 | mask = mask[:indices_idx.shape[0]] 78 | # we select the data indices from 'indices' with 'indices_idx', mask that 79 | # so we are left with the data indices on the positions where the digits occur 80 | # in the sequences 81 | epoch_indices[:indices_idx.shape[0]][mask] = indices[i][indices_idx][mask] 82 | 83 | # if batch_size is invalid, create one big batch 84 | if batch_size < 1 or batch_size > int(cutoff / sequence_length): 85 | batch_size = int(cutoff / sequence_length) 86 | 87 | # we cut off the cutoff point so we can create an integer amount of batches and sequences 88 | cutoff = cutoff - cutoff % (batch_size * sequence_length) 89 | 90 | epoch_indices = epoch_indices[:cutoff] 91 | sequences = sequences[:cutoff] 92 | # select the data points and group per sequence and batch 93 | x = dataset.x[epoch_indices].view(-1, batch_size, sequence_length, 28*28).transpose(1,2) 94 | y = sequences.view(-1, batch_size, sequence_length).transpose(1,2) 95 | return x, y 96 | 97 | 98 | 99 | def load(val_ratio = 0.1): 100 | """Load MNIST data, transform to tensors and calculate indices for each category 101 | """ 102 | train_data = torchvision.datasets.MNIST("./datasets/", train=True, transform=torchvision.transforms.ToTensor(), download=True) 103 | test_data = torchvision.datasets.MNIST("./datasets/", train=False, transform=torchvision.transforms.ToTensor(), download=True) 104 | 105 | validation_size = int(val_ratio * len(train_data)) 106 | train_size = len(train_data) - validation_size 107 | 108 | # reformat the dataset(s) in a sensible format 109 | train_x 
= torch.zeros((train_size, 28*28)) 110 | train_y = torch.zeros(train_size, dtype=torch.int) 111 | val_x = torch.zeros((validation_size, 28*28)) 112 | val_y = torch.zeros(validation_size, dtype=torch.int) 113 | for i, d in enumerate(train_data): 114 | if i < train_size: 115 | train_x[i] = d[0].view(28*28) 116 | train_y[i] = d[1] 117 | else: 118 | val_x[i-train_size] = d[0].view(28*28) 119 | val_y[i-train_size] = d[1] 120 | # safe image indices for each category 121 | train_indices = [torch.nonzero(train_y==i).flatten() for i in range(10)] 122 | val_indices = [torch.nonzero(val_y==i).flatten() for i in range(10)] 123 | training_set = MNISTDataset(x=train_x, y=train_y, indices=train_indices) 124 | validation_set = MNISTDataset(x=val_x, y=val_y, indices=val_indices) 125 | 126 | test_x = torch.zeros((len(test_data), 28*28)) 127 | test_y = torch.zeros(len(test_data), dtype=torch.int) 128 | for i, d in enumerate(test_data): 129 | test_x[i] = d[0].view(28*28) 130 | test_y[i] = d[1] 131 | test_indices = [torch.nonzero(test_y==i).flatten() for i in range(10)] 132 | test_set = MNISTDataset(x=test_x, y=test_y, indices=test_indices) 133 | 134 | return training_set, validation_set, test_set 135 | 136 | def means(dataset:MNISTDataset): 137 | means = torch.Tensor(10,28*28) 138 | for i in range(10): 139 | means[i] = torch.mean(dataset.x[dataset.indices[i]],dim=0) 140 | return means 141 | 142 | def medians(dataset:MNISTDataset): 143 | medians = torch.Tensor(10,28*28) 144 | for i in range(10): 145 | medians[i] = torch.median(dataset.x[dataset.indices[i]],dim=0).values 146 | return medians 147 | 148 | if __name__ == '__main__': 149 | load() -------------------------------------------------------------------------------- /models/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KietzmannLab/EmergentPredictiveCoding/4b926abd080c9c67f698d71a037af261fba40c30/models/.DS_Store 
-------------------------------------------------------------------------------- /models/patterns_rev/seeded_mnist/.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KietzmannLab/EmergentPredictiveCoding/4b926abd080c9c67f698d71a037af261fba40c30/models/patterns_rev/seeded_mnist/.gitignore -------------------------------------------------------------------------------- /plot.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import os 4 | import matplotlib as mpl 5 | import matplotlib.pyplot as plt 6 | from matplotlib.ticker import MaxNLocator 7 | from matplotlib import cycler 8 | 9 | from Dataset import Dataset 10 | from ModelState import ModelState 11 | 12 | # import helper functions for plotting 13 | import helper 14 | # Global matplotlib settings 15 | 16 | colors = cycler('color', 17 | ['#EE6666', '#3388BB', '#9988DD', 18 | '#EECC55', '#88BB44', '#FFBBBB']) 19 | 20 | 21 | plt.rc('axes', axisbelow=True, prop_cycle=colors) 22 | plt.rc('grid', linestyle='--') 23 | plt.rc('xtick', direction='out', color='black') 24 | plt.rc('ytick', direction='out', color='black') 25 | plt.rc('lines', linewidth=2) 26 | 27 | # for a bit nicer font in plots 28 | mpl.rcParams['font.family'] = ['sans-serif'] 29 | mpl.rcParams['font.size'] = 18 30 | 31 | plt.style.use('ggplot') 32 | 33 | # --------------- Helper/non core functions --------------- 34 | # 35 | 36 | def save_fig(fig, filepath, bbox_inches=None): 37 | """Convenience wrapper for saving figures in a default "../Results/" directory and auto appends file extensions ".svg" 38 | and ".png" 39 | """ 40 | fig.savefig(filepath + ".svg", bbox_inches=bbox_inches) 41 | fig.savefig(filepath + ".png", bbox_inches=bbox_inches) 42 | 43 | def axes_iterator(axes): 44 | """Iterate over axes. 
Whether it is a single axis object, a list of axes, or a list of a list of axes 45 | """ 46 | if isinstance(axes, np.ndarray): 47 | for ax in axes: 48 | yield from axes_iterator(ax) 49 | else: 50 | yield axes 51 | 52 | def init_axes(len_x, figsize, shape=None, colorbar=False): 53 | """Convenience function for creating subplots with configuratons 54 | 55 | Parameters: 56 | - len_x: amount of subfigures 57 | - figsize: size per subplot. Actual figure size depends on the subfigure configuration and if colorbars are visible. 58 | - shape: subfigure configuration in rows and columns. If 'None', a configuration is chosen to minimise width and height. Default: None 59 | - colorbar: whether colorbars are going to be used. Used for figsize calculation 60 | """ 61 | if shape is not None: 62 | assert isinstance(shape, tuple) 63 | ncols = shape[0] 64 | nrows = shape[1] 65 | else: 66 | nrows = int(np.sqrt(len_x)) 67 | ncols = int(len_x / nrows) 68 | while not nrows*ncols == len_x: 69 | nrows -= 1 70 | ncols = int(len_x / nrows) 71 | 72 | #figsize = (figsize[1] * ncols + colorbar*0.5*figsize[0], figsize[0] * nrows) 73 | figsize = (figsize[1] * ncols + 0.5*colorbar*figsize[0], figsize[0] * nrows) 74 | return plt.subplots(nrows, ncols, figsize=figsize) 75 | 76 | def set_size(w,h, ax=None): 77 | """ w, h: width, height in inches """ 78 | if not ax: ax=plt.gca() 79 | l = ax.figure.subplotpars.left 80 | r = ax.figure.subplotpars.right 81 | t = ax.figure.subplotpars.top 82 | b = ax.figure.subplotpars.bottom 83 | figw = float(w)/(r-l) 84 | figh = float(h)/(t-b) 85 | ax.figure.set_size_inches(figw, figh) 86 | 87 | def display(imgs, 88 | lims=(-1.0, 1.0), 89 | cmap='seismic', 90 | size=None, 91 | figsize=(4,4), 92 | shape=None, 93 | colorbar=True, 94 | axes_visible=True, 95 | layout='regular', 96 | figax=None): 97 | """Convenience function for plotting multiple tensors as images. 98 | 99 | Function to quickly display multiple tensors as images in a grid. 
100 | Image dimensions are expected to be square and are taken to be the square root of the tensor size. 101 | Tensor dimensions may be arbitrary. 102 | The images are automatically layed out in a compact grid, but this can be overridden. 103 | 104 | Parameters: 105 | - imgs: (list of) input tensor(s) (torch.Tensor or numpy.Array) 106 | - lims: pixel value interval. If 'None', it is set to the highest absolute value in both directions, positive and negative. Default: (-1,1) 107 | - cmap: color map. Default: 'seismic' 108 | - size: image width and height. If 'None', it is set to the first round square of the tensor size. Default: None 109 | - figsize: size per image. Actual figure size depends on the subfigure configuration and if colorbars are visible. Default: (4,4) 110 | - shape: subfigure configuration, in rows and columns of images. If 'None', a configuration is chosen to minimise width and height. Default: None 111 | - colorbar: show colorbar for only last row of axes. Default: False 112 | - axes_visible: show/hide axes. Default: True 113 | - layout: matplotlib layout. Default: 'regular' 114 | - figax: if not 'None', use existing figure and axes object. 
Default: None 115 | -cmaps: pass list of colormaps 116 | """ 117 | if not isinstance(imgs, list): 118 | imgs = [imgs] 119 | shape = (1,1) 120 | 121 | if size is not None: 122 | if not isinstance(size, tuple): 123 | size = (size, size) 124 | 125 | # convert to numpy if not already so 126 | imgs = [im.detach().cpu().numpy() if isinstance(im, torch.Tensor) else im for im in imgs] 127 | 128 | if lims is None: 129 | mx = max([max(im.max(),abs(im.min())) for im in imgs]) 130 | lims = (-mx, mx) 131 | 132 | if figax is None: 133 | fig, axes = init_axes(len(imgs), figsize, shape=shape, colorbar=colorbar) 134 | else: 135 | fig, axes = figax 136 | 137 | for i, ax in enumerate(axes_iterator(axes)): 138 | 139 | img = imgs[i] 140 | ax.grid() 141 | if size is None: 142 | _size = int(np.sqrt(img.size)) 143 | img = img[:_size*_size].reshape(_size,_size) 144 | else: 145 | img = img[:size[0]*size[1]].reshape(size[0],size[1]) 146 | 147 | plot_im = ax.imshow(img,cmap=cmap) 148 | 149 | ax.label_outer() 150 | 151 | if lims is not None: 152 | plot_im.set_clim(lims[0], lims[1]) 153 | 154 | if axes_visible == False: 155 | ax.tick_params(axis='both', which='both', bottom=False, top=False, labelbottom=False, right=False, left=False, labelleft=False) 156 | 157 | 158 | if colorbar: 159 | 160 | # if isinstance(axes, np.ndarray): 161 | # for rax in axes: 162 | 163 | # if isinstance(rax, np.ndarray): 164 | # fig.colorbar(plot_im, ax=rax, shrink=0.80, location='right'); 165 | # else: 166 | # fig.colorbar(plot_im, ax=rax, shrink=0.80); 167 | #else: 168 | #fig.colorbar(plot_im, ax=axes, shrink=0.80); 169 | if isinstance(axes, np.ndarray): 170 | fig.colorbar(plot_im, ax = axes[-1], location='right') 171 | fig.colorbar(plot_im, ax=axes, shrink=0.80); 172 | set_size(figsize[0]+0.5, figsize[1], axes[-1]) 173 | else: 174 | fig.colorbar(plot_im, ax=axes) 175 | set_size(figsize[0]+0.5, figsize, axes[-1]) 176 | 177 | if layout == 'tight': 178 | fig.tight_layout() 179 | 180 | return fig, axes 181 | 182 | 183 
def scatter(x, y, discrete=False, figsize=(8, 6), color='r', xlabel="", ylabel="", legend=None, figax=None):
    """Convenience function to create scatter plots.

    Parameters:
    - x: x data points. Array, or list of arrays (one series per entry).
    - y: y data points; when x is a list, a sequence with one entry per series
    - discrete: whether x-axis ticks should be integer values. Default: False
    - figsize: matplotlib figsize. Default: (8,6)
    - color: marker color; only used when x is a single array (a list of series
      uses the axes color cycle). Default: 'r'
    - xlabel: Default: ""
    - ylabel: Default: ""
    - legend: labels passed to axes.legend. Default: None (no legend)
    - figax: if not 'None', use existing figure and axes objects. Default: None

    Returns: fig, axes
    """
    if figax is None:
        fig, axes = plt.subplots(1, figsize=figsize)
    else:
        fig, axes = figax

    if isinstance(x, list):
        for x_series, y_series in zip(x, y):
            axes.scatter(x_series, y_series)
    else:
        axes.scatter(x, y, c=color)

    if discrete:
        axes.xaxis.set_major_locator(MaxNLocator(integer=True))
    axes.set_xlabel(xlabel)
    axes.set_ylabel(ylabel)
    axes.grid()

    if legend is not None:
        axes.legend(legend)

    return fig, axes


def training_progress(net: ModelState, save=True):
    """Plot the training- and test-set loss curves stored in ``net.results``.

    Parameters:
    - net: model whose ``results`` dict holds "train loss" and "test loss" sequences
    - save: save the figure as "training-progress" (.svg and .png). Default: True

    Returns: fig, axes
    """
    fig, axes = init_axes(1, figsize=(6, 8))

    train_loss = net.results["train loss"]
    test_loss = net.results["test loss"]
    axes.plot(np.arange(1, len(train_loss) + 1), train_loss, label="Training set")
    axes.plot(np.arange(1, len(test_loss) + 1), test_loss, label="Test set")

    axes.xaxis.set_major_locator(MaxNLocator(integer=True))
    axes.set_xlabel("Training time", fontsize=16)
    axes.set_ylabel("Loss", fontsize=16)
    axes.legend()
    axes.set_title('Loss network', fontsize=18)
    axes.grid(True)
    axes.spines['right'].set_visible(False)
    axes.spines['top'].set_visible(False)
    fig.tight_layout()

    if save:
        save_fig(fig, "training-progress", bbox_inches='tight')
    return fig, axes


#
# --------------- Plotting code for figures paper ---------------
#


#
# Figure 2A and Figure 4B, 5C
#

def bootstrap_post_dynamics(net_list: 'list[ModelState]',
                            test_set: Dataset,
                            seq_length=10):
    """Average the ``synaptrans`` time course of prediction and error units per network and bootstrap over networks.

    For every network, prediction units are those where
    ``helper.pred_class_mask(net, test_set, target=0)`` equals 0, and error units
    are the non-zero entries in the first row of the second value returned by
    ``helper.find_error_units``, excluding units that are also prediction units
    (hybrids).

    Returns: (pred_stats, err_stats), two dicts with
    - 'samples': mean curve over all networks (length seq_length-1)
    - 'h_bound': upper bound from helper.extract_lower_upper
    - 'l_bound': lower bound from helper.extract_lower_upper
    """
    err_samples = np.zeros((len(net_list), seq_length - 1))
    pred_samples = np.zeros((len(net_list), seq_length - 1))
    pred_stats, err_stats = dict(), dict()
    for i, net in enumerate(net_list):
        _, _, synaptrans = helper.compute_responses_drive(net, test_set, target=0)

        pred_units = helper.pred_class_mask(net, test_set, target=0)
        _, error_units_c = helper.find_error_units(net, test_set)
        # flatten (not squeeze) so a single selected unit still yields a 1-D index tensor
        pred_indices = (pred_units == 0).nonzero().flatten()
        error_indices = error_units_c[0, :].nonzero().flatten()
        # remove hybrid units
        error_indices = error_indices[~error_indices.unsqueeze(1).eq(pred_indices).any(1)]
        # select error & prediction units and average over dims 1 and 0
        pred_curve = synaptrans[pred_indices].mean(axis=1).mean(axis=0)
        error_curve = synaptrans[error_indices].mean(axis=1).mean(axis=0)

        # record curves
        pred_samples[i, :] = pred_curve.cpu().numpy()
        err_samples[i, :] = error_curve.cpu().numpy()

    pred_stats['samples'], err_stats['samples'] = pred_samples.mean(axis=0), err_samples.mean(axis=0)
    bs_pred, bs_err = helper.compute_post_drive_bootstrap(pred_samples, err_samples)
    # helper.extract_lower_upper returns (lower, upper), as unpacked everywhere
    # else in this module; the previous code stored them under swapped keys
    l_pred, h_pred = helper.extract_lower_upper(bs_pred)
    l_err, h_err = helper.extract_lower_upper(bs_err)
    pred_stats['h_bound'], pred_stats['l_bound'] = h_pred, l_pred
    err_stats['h_bound'], err_stats['l_bound'] = h_err, l_err
    return pred_stats, err_stats
297 | 298 | 299 | def display_activity_lossfn(model_results, 300 | lesioned = False, 301 | save=True, 302 | reverse=False, 303 | energy_type='ec', 304 | data_type='mnist'): 305 | """ 306 | visualises energy consumption of networks trained with different 307 | loss functions 308 | """ 309 | data, bootstraps, samples = model_results['l1_pre'] 310 | # get results for activation 311 | _, _, _, bs_net_act, bs_netles_act, bs_netles_rev_act = model_results['l1_post'][1] 312 | _, _, _, net_act_samples, netles_act_samples, netles_act_rev_samples = model_results['l1_post'][-1] 313 | # # get results for activation + weights 314 | # _, _, _, bs_net_weight, bs_netles_weight, bs_netles_rev_weight = model_results['l1_postandl2_weights'][1] 315 | # _, _, _, net_weight_samples, netles_weight_samples, netles_weight_rev_samples = model_results['l1_postandl2_weights'][-1] 316 | 317 | bs_notn, bs_meds, bs_gmed, bs_net, bs_netles, bs_netles_rev = bootstraps 318 | notn_samples, meds_samples, gmed_samples, net_samples, net_les_samples, net_les_samples_rev = samples 319 | # create figure plot mean values and 95% CI 320 | if energy_type == 'ap': 321 | fig, (ax1) = plt.subplots(1, 1, sharex=True) 322 | ax2 = ax1 323 | else: 324 | fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True) 325 | fig.subplots_adjust(hspace=0.125) # adjust space between axes 326 | start_index = 0 327 | 328 | if energy_type == 'ec': 329 | start_index = 1 330 | ax2.set_ylim(0.165, 0.20) # RNN_pre 331 | ax1.set_ylim(0.7, 0.79) # RNN_post/RNN_post+weights 332 | if energy_type == 'st': 333 | start_index = 0 334 | ax2.set_ylim(0.2, 0.26) # RNN_pre 335 | ax1.set_ylim(0.5, 6) # RNN_post/RNN_post+weights 336 | 337 | x = np.arange(start_index+1,data.shape[0]+1) 338 | # add l1(preactivation) models 339 | mu_net = np.mean(net_samples, axis=0)[start_index:] # empirical mean of reservoir activity 340 | l1_pre = ax2.plot(x, mu_net, label="RNN_pre", color= '#EE6666') 341 | lower_net, upper_net = helper.extract_lower_upper(bs_net) 342 | 
ax2.fill_between(x, lower_net[start_index:], upper_net[start_index:], color='#EE6666', alpha=0.3) 343 | #ax2.tick_params(axis='y', labelcolor='#EE6666') 344 | # add l1(act) models 345 | 346 | mu_net_act = np.mean(net_act_samples, axis=0)[start_index:] # empirical mean of reservoir activity 347 | l1_post = ax1.plot(x, mu_net_act, label="RNN_post", color= 'cornflowerblue') 348 | #ax1.tick_params(axis='y', labelcolor='m') 349 | lower_net_act, upper_net_act = helper.extract_lower_upper(bs_net_act) 350 | ax1.fill_between(x, lower_net_act[start_index:], upper_net_act[start_index:], color='cornflowerblue', alpha=0.3) 351 | 352 | # add l1(post) + l2(weights) models 353 | # mu_net_weight = np.mean(net_weight_samples, axis=0)[start_index:] # empirical mean of reservoir activity 354 | # l1l2_postW = ax1.plot(x, mu_net_weight, linestyle='--', label="RNN_post+weights", color= 'cyan') 355 | # lower_net_weight, upper_net_weight = helper.extract_lower_upper(bs_net_weight) 356 | # ax2.fill_between(x, lower_net_weight[start_index:], upper_net_weight[start_index:], color='cyan', alpha=0.3) 357 | 358 | 359 | 360 | 361 | if energy_type == 'ec' or energy_type=='st': 362 | ax1.spines.bottom.set_visible(False) 363 | ax2.spines.top.set_visible(False) 364 | ax1.spines.top.set_visible(False) 365 | #ax1.xaxis.tick_top() 366 | ax1.tick_params(labeltop=False) # don't put tick labels at the top 367 | ax1.tick_params(bottom=False) 368 | 369 | d = .4 # proportion of vertical to horizontal extent of the slanted line 370 | kwargs = dict(marker=[(-1, -d), (1, d)], markersize=12, 371 | linestyle="none", color='k', mec='k', mew=1, clip_on=False) 372 | ax1.plot([0, 1], [0, 0], transform=ax1.transAxes, **kwargs) 373 | ax2.plot([0, 1], [1, 1], transform=ax2.transAxes, **kwargs) 374 | 375 | 376 | 377 | if energy_type == 'ec' or energy_type=='st': 378 | h1, l1 = ax1.get_legend_handles_labels() 379 | h2, l2 = ax2.get_legend_handles_labels() 380 | ax1.legend(h1+h2, l1+l2, loc=0) 381 | else: 382 | ax1.legend() 
383 | 384 | 385 | ax2.xaxis.set_major_locator(MaxNLocator(integer=True)); 386 | 387 | ax1.grid(True) 388 | ax2.grid(True) 389 | #ax1.spines['right'].set_visible(False) 390 | #ax1.spines['top'].set_visible(False) 391 | #ax1.xaxis.tick_top() 392 | ax1.tick_params(labeltop=False) # don't put tick labels at the top 393 | #ax2.xaxis.tick_bottom() 394 | #ax1.xaxis.set_tick_params(which='major', size=10, width=2, labelsize=8) 395 | #ax1.xaxis.set_tick_params(which='major', size=10, width=2, labelsize=8) 396 | #ax1.yaxis.set_tick_params(which='major', size=10, width=2, labelsize=8) 397 | #ax2.yaxis.set_tick_params(which='major', size=10, width=2, labelsize=8) 398 | 399 | 400 | if save is True: 401 | if lesioned: 402 | save_fig(fig, "energy_curves" + "_"+ energy_type + "_"+data_type+"/lesioned-model-activity", bbox_inches='tight') 403 | else: 404 | save_fig(fig, "energy_curves" + "_"+ energy_type + "_"+data_type+"/model-activity", bbox_inches='tight') 405 | return fig, (ax1, ax2) 406 | 407 | def display_model_activity(model_results, 408 | lesioned = False, 409 | save=True, 410 | reverse=False, 411 | data_type='mnist'): 412 | """ 413 | visualises energy consumption of networks trained with different 414 | loss functions 415 | """ 416 | data, bootstraps, samples = model_results['l1_pre'] 417 | # get results for activation 418 | _, _, _, bs_net_act, bs_netles_act, bs_netles_rev_act = model_results['l1_post'][1] 419 | _, _, _, net_act_samples, netles_act_samples, netles_act_rev_samples = model_results['l1_post'][-1] 420 | # get results for activation + weights 421 | _, _, _, bs_net_weight, bs_netles_weight, bs_netles_rev_weight = model_results['l1_postandl2_weights'][1] 422 | _, _, _, net_weight_samples, netles_weight_samples, netles_weight_rev_samples = model_results['l1_postandl2_weights'][-1] 423 | 424 | bs_notn, bs_meds, bs_gmed, bs_net, bs_netles, bs_netles_rev = bootstraps 425 | notn_samples, meds_samples, gmed_samples, net_samples, net_les_samples, net_les_samples_rev = 
samples 426 | # create figure plot mean values and 95% CI 427 | fig, axes = plt.subplots(1, figsize=(14,10)) 428 | 429 | x = np.arange(2,data.shape[0]+1) 430 | #mu_gmed = np.mean(gmed_samples, axis=0) # empirical mean of global median 431 | # axes.plot(x, mu_gmed, label="dataset median inhibition", color= '0.2') 432 | # lower_gmed, upper_gmed = helper.extract_lower_upper(bs_gmed) 433 | # axes.fill_between(x, lower_gmed, upper_gmed, color='0.2', alpha=0.3) 434 | # mu_meds = np.mean(meds_samples, axis=0) # empirical mean of category median 435 | # axes.plot(x, mu_meds, label="category median inhibition", color='0.7') 436 | # lower_med, upper_med = helper.extract_lower_upper(bs_meds) 437 | # axes.fill_between(x, lower_med, upper_med, color='0.7', alpha=0.3) 438 | 439 | # add a space between bounds and data that is input-dependent 440 | #axes.plot([],[], linestyle='', label=' ') 441 | 442 | #if not lesioned: # add input drive to the figure 443 | #mu_notn = np.mean(notn_samples, axis=0) # empirical mean of input drive sequences 444 | #axes.plot(x, mu_notn, label="input", color= '#3388BB') 445 | #lower_notn, upper_notn = helper.extract_lower_upper(bs_notn) 446 | #axes.fill_between(x, lower_notn, upper_notn, color='#3388BB', alpha=0.3) 447 | # add l1(preactivation) models 448 | mu_net = np.mean(net_samples, axis=0)[1:] # empirical mean of reservoir activity 449 | l1_pre = axes.plot(x, mu_net, label="RNN_pre", color= '#EE6666') 450 | lower_net, upper_net = helper.extract_lower_upper(bs_net) 451 | axes.fill_between(x, lower_net[1:], upper_net[1:], color='#EE6666', alpha=0.3) 452 | axes.tick_params(axis='y', labelcolor='#EE6666') 453 | # add l1(act) models 454 | axes2= axes.twinx() 455 | mu_net_act = np.mean(net_act_samples, axis=0)[1:] # empirical mean of reservoir activity 456 | l1_post = axes2.plot(x, mu_net_act, label="RNN_post", color= 'm') 457 | axes2.tick_params(axis='y', labelcolor='m') 458 | 459 | lower_net_act, upper_net_act = helper.extract_lower_upper(bs_net_act) 
460 | axes2.fill_between(x, lower_net_act[1:], upper_net_act[1:], color='m', alpha=0.3) 461 | 462 | # add l1(post) + l2(weights) models 463 | mu_net_weight = np.mean(net_weight_samples, axis=0)[1:] # empirical mean of reservoir activity 464 | l1l2_postW = axes2.plot(x, mu_net_weight, linestyle='--', label="RNN_post+weights", color= 'm') 465 | lower_net_weight, upper_net_weight = helper.extract_lower_upper(bs_net_weight) 466 | #axes2.fill_between(x, lower_net_weight[1:], upper_net_weight[1:], color='m', alpha=0.1) 467 | 468 | 469 | 470 | if lesioned: # add lesioned reservoir to the figure 471 | mu_netles = np.mean(net_les_samples, axis=0) # empirical mean of sample set 472 | axes.plot(x, mu_netles, label="prediction units lesioned", color= '#EECC55') 473 | lower_netles, upper_netles = helper.extract_lower_upper(bs_netles) 474 | axes.fill_between(x, lower_netles, upper_netles, color='#EECC55', alpha=0.3) 475 | if reverse: 476 | mu_netles_rev = np.mean(net_les_samples_rev, axis=0) # empirical mean of sample set 477 | axes.plot(x, mu_netles_rev, linestyle='--', label="error units lesioned", color= '#5efc03') 478 | lower_netles_rev, upper_netles_rev = helper.extract_lower_upper(bs_netles_rev) 479 | 480 | axes.fill_between(x, lower_netles_rev, upper_netles_rev, color='#5efc03', alpha=0.3) 481 | 482 | axes.xaxis.set_major_locator(MaxNLocator(integer=True)); 483 | 484 | axes.legend(fontsize=18,labelspacing=0.1, facecolor='0.95') 485 | 486 | # lns = [l1_pre, l1_post, l1l2_postW] 487 | # labs = [l.get_label() for l in lns] 488 | # axes.legend(lns, labs, loc=0) 489 | h1, l1 = axes.get_legend_handles_labels() 490 | h2, l2 = axes2.get_legend_handles_labels() 491 | axes.legend(h1+h2, l1+l2, loc=0) 492 | axes.grid(True) 493 | axes.spines['right'].set_visible(False) 494 | axes.spines['top'].set_visible(False) 495 | axes.xaxis.set_tick_params(which='major', size=10, width=2, labelsize=16) 496 | axes.yaxis.set_tick_params(which='major', size=10, width=2, labelsize=16) 497 | 
axes2.xaxis.set_tick_params(which='major', size=10, width=2, labelsize=16) 498 | axes2.yaxis.set_tick_params(which='major', size=10, width=2, labelsize=16) 499 | 500 | 501 | if save is True: 502 | if lesioned: 503 | save_fig(fig, "preactivation_curves" + data_type+"/lesioned-model-activity", bbox_inches='tight') 504 | else: 505 | save_fig(fig, "preactivation_curves" + data_type+"/model-activity", bbox_inches='tight') 506 | return fig, axes,axes2 507 | 508 | # 509 | # Figure 2C 510 | # 511 | def example_sequence_state(net:ModelState, dataset:Dataset, latent=False, seed=2553, save=False): 512 | """ 513 | visualises input and internal drive for a sample sequence 514 | """ 515 | if seed != None: 516 | torch.manual_seed(seed) 517 | np.random.seed(seed) 518 | 519 | batches, _ = dataset.create_batches(batch_size=-1, sequence_length=10, shuffle=False) 520 | 521 | ex_seq = batches[0,:,:,:] 522 | input_size = ex_seq.shape[-1] # make sure we only visualize input units and no latent resources 523 | X = []; P = []; H=[]; T=[]; L=[] 524 | 525 | h = net.model.init_state(ex_seq.shape[1]) 526 | 527 | for x in ex_seq: 528 | 529 | p = net.predict(h, latent) 530 | h, l_a = net.model(x, state=h) 531 | #x_mu, p_mu = x[:,:input_size].mean(dim=0), p[:,:input_size].mean(dim=0) 532 | #x_std, p_std = x[:, :input_size].std(dim=0), p[:, :input_size].std(dim=0) 533 | X.append(x[0,:input_size].detach().cpu()) 534 | P.append(p[0,:input_size].detach().cpu()) 535 | H.append(h[0,:input_size].detach().cpu()) 536 | T.append(l_a[0][0,:input_size].detach().cpu()) 537 | # standardize input and internal drive so that they're on the same scale 538 | # x_scaled = (x[0,:input_size]- x_mu) / x_std 539 | # p_scaled = (p[0,:input_size]- p_mu) / p_std 540 | # x_scaled[torch.isnan(x_scaled)]=0 541 | # p_scaled[torch.isnan(p_scaled)]=0 542 | #t = x_scaled.detach().cpu() + p_scaled.detach().cpu() 543 | 544 | if latent: # look at latent unit drive 545 | L.append(p[:,input_size:].mean(dim=0).detach().cpu()) 546 | 547 
| # fig = plt.figure(figsize=(3,3)) 548 | # if latent: 549 | # fig, axes = display(X+P+L, shape=(10,3), figsize=(3,3), axes_visible=False, layout='tight') 550 | # else: 551 | # fig, axes = display(X+P+H+T, shape=(10,3), figsize=(3,3), axes_visible=False, layout='tight') 552 | 553 | 554 | # if save is True: 555 | # save_fig(fig, "example_sequence_state", bbox_inches='tight') 556 | return X, P, H, T 557 | 558 | # 559 | # Figure 3B/5B 560 | # 561 | def pixel_variance(pix_var): 562 | """Plot variance of each pixel and channel""" 563 | vmi, vma = 0, pix_var.max() 564 | fig, ax = plt.subplots(1, 1) 565 | im = ax.imshow(pix_var, vmin=vmi, vmax=vma, cmap='gray') 566 | ax.grid(False) 567 | fig.colorbar(im) 568 | return fig 569 | 570 | def topographic_distribution(type_mask): 571 | """ 572 | 573 | plots topographic distribution of prediction and error units in 574 | data space. 575 | 576 | """ 577 | 578 | import seaborn as sns 579 | from matplotlib.colors import ListedColormap 580 | blues = ["#3399ff", "#0000ff"] # pure error 581 | reds = ["#ff9999","#ff0000"] # pure prediction 582 | browns = ["#ffff00ff","#ffaa00ff"] # hybrid 583 | grey= ["#cccccc"] # unspecified 584 | combined = blues + reds + browns + grey 585 | cmap = ListedColormap(sns.color_palette(combined).as_hex()) 586 | 587 | if len(type_mask.shape) > 2: # channel dim exists 588 | fig,ax = plt.subplots(1, 3) 589 | nc, nx, ny = type_mask.shape 590 | for c in range(nc): 591 | ax[c].imshow(type_mask[c], cmap=cmap) 592 | ax[c].grid(False) 593 | else: 594 | fig,ax = plt.subplots(1, 1) 595 | ax.imshow(type_mask, cmap=cmap) 596 | ax.grid(False) 597 | 598 | return fig 599 | 600 | 601 | 602 | 603 | # 604 | # Appendix A Figures A1 & A2 & A3 605 | # 606 | 607 | def pred_after_timestep(net, dataset, mask=None, digits=[0], seed=2553): 608 | """ 609 | visualises internal drive after 0-9 preceding frames. 
610 | """ 611 | if seed != None: 612 | torch.manual_seed(seed) 613 | np.random.seed(seed) 614 | 615 | 616 | 617 | imgs= [] 618 | ntime=10 619 | nunits = net.model.input_size 620 | 621 | 622 | for digit in digits: 623 | imgs = imgs + [net.predict(torch.zeros(1,nunits))] +\ 624 | [net.predict(helper._run_seq_from_digit(digit, i, net, dataset, mask=mask)).mean(dim=0) for i in range(1,ntime)] 625 | 626 | fig, axes = display(imgs, shape=(ntime, len(digits)), axes_visible=False) 627 | 628 | 629 | fig.tight_layout() 630 | return fig, axes 631 | 632 | def pred_after_timestep_predonly(net, dataset, mask, pred_mask, digits=[0], seed=2553): 633 | """ 634 | visualises internal drive after 0-9 preceding frames. Where the internal 635 | drive only comes from prediction units. 636 | """ 637 | if seed != None: 638 | torch.manual_seed(seed) 639 | np.random.seed(seed) 640 | 641 | 642 | 643 | imgs= [] 644 | ntime=10 645 | nunits = net.model.input_size 646 | 647 | 648 | for digit in digits: 649 | imgs = imgs + [net.predict(torch.zeros(1,nunits))] +\ 650 | [net.predict_predonly(helper._run_seq_from_digit(digit, i, net, dataset, mask=mask), pred_mask=pred_mask).mean(dim=0) for i in range(1,ntime)] 651 | 652 | fig, axes = display(imgs, shape=(ntime, len(digits)), axes_visible=False) 653 | 654 | 655 | fig.tight_layout() 656 | return fig, axes 657 | 658 | 659 | def color_code_pred_units(mnist, cifar, save=False): 660 | mnist_net, test_set_m = mnist 661 | cifar_net, test_set_c = cifar 662 | pred_mnist, pred_cifar = torch.zeros(mnist_net.model.W.shape[0]), torch.zeros(cifar_net.model.W.shape[0]) 663 | for target in range(0, 10): 664 | pred_mask_m = helper.pred_class_mask(mnist_net, test_set_m, target=target) 665 | pred_mask_c = helper.pred_class_mask(cifar_net, test_set_c, target=target) 666 | pred_mnist += pred_mask_m 667 | pred_cifar += pred_mask_c 668 | 669 | fig, axes = plt.subplots(1, 2) 670 | cmap_base = 'seismic' 671 | vmin, vmax = 0, 1 672 | cmap = helper.truncate_colormap(cmap_base, 
vmin, vmax) 673 | im1, im2 = axes[0].imshow(pred_mnist.view(28,28), cmap=cmap), axes[1].imshow(pred_cifar.view(32*32,3)[...,0], cmap=cmap) 674 | axes[0].grid(False); axes[1].grid(False) 675 | fig.subplots_adjust(right=0.8) 676 | cbar_ax = fig.add_axes([0.85, 0.25, 0.05, 0.5]) 677 | fig.colorbar(im2, cax=cbar_ax) 678 | 679 | 680 | 681 | if save is True: 682 | save_fig(fig, "pixel_var_mnist_cifar", bbox_inches='tight') -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | bootstrapped 2 | matplotlib 3 | numpy 4 | scipy 5 | seaborn 6 | tables 7 | torch 8 | torchvision -------------------------------------------------------------------------------- /results.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | time python fig2_network_performance.py 4 | time python fig3_unit_taxonomy.py 5 | time python fig4_lesion_study.py 6 | time python fig5_cifar10_exp.py 7 | -------------------------------------------------------------------------------- /supplement.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Fri Apr 29 09:24:57 2022 5 | 6 | @author: tempali 7 | 8 | Code used to generate the supplemental figures 9 | """ 10 | 11 | # imports 12 | 13 | import torch 14 | import numpy as np 15 | import argparse 16 | import os 17 | import pandas as pd 18 | 19 | parser = argparse.ArgumentParser(description='device') 20 | parser.add_argument('--i', type=str, help='Device index') 21 | args = parser.parse_args() 22 | 23 | if torch.cuda.is_available(): 24 | DEVICE = 'cuda' 25 | torch.set_default_tensor_type(torch.cuda.FloatTensor) 26 | else: 27 | DEVICE = 'cpu' 28 | 29 | 30 | print('Using {}'.format(DEVICE)) 31 | 32 | R_PATH = 'Results/Supl/Data/' 33 | F_PATH = 'Results/Supl/' 34 | M_PATH = 
'final_networks/mnist_nets/' 35 | hdf_path = R_PATH+'network_stats.h5' 36 | 37 | LOAD = False 38 | SEED = 2553 39 | if not os.path.isdir(os.path.dirname(R_PATH)): 40 | os.makedirs(os.path.dirname(R_PATH), exist_ok=True) 41 | if not os.path.isdir(os.path.dirname(F_PATH)): 42 | os.makedirs(os.path.dirname(R_PATH), exist_ok=True) 43 | 44 | if SEED != None: 45 | torch.manual_seed(SEED) 46 | np.random.seed(SEED) 47 | 48 | # set up hdf5 file to store the results 49 | if not os.path.exists(hdf_path): 50 | store = pd.HDFStore(hdf_path) 51 | store.close() 52 | INPUT_SIZE = 28*28 53 | Z_CRIT = 2.576 #99% 54 | 55 | # dataset loaders 56 | import mnist 57 | 58 | # framework files 59 | import Network 60 | import helper 61 | import plot 62 | 63 | 64 | # load data 65 | train_set, validation_set, test_set = mnist.load(val_ratio=0.0) 66 | 67 | # load pre and post MNIST networks 68 | nets = [[], []] 69 | 70 | n_instances = 10 71 | # load networks for bootstrap 72 | losses = ['l1_pre', 'l1_post'] 73 | # set up dictionaries to fill in the data 74 | ec_results, ap_results, st_results = dict(), dict(), dict() 75 | 76 | for loss_ind, loss in enumerate(losses): 77 | for i in range(0, n_instances): 78 | net = Network.State(activation_func=torch.nn.ReLU(), 79 | optimizer=torch.optim.Adam, 80 | lr=1e-4, 81 | input_size=INPUT_SIZE, 82 | hidden_size=INPUT_SIZE, 83 | title=M_PATH+"mnist_net_"+loss, 84 | device=DEVICE) 85 | net.load(i) 86 | nets[loss_ind].append(net) 87 | net = nets[0][0] 88 | 89 | 90 | 91 | # fig A1 & A2: plot digit predictions and median MNIST digit 92 | digits = list(range(0, 10)) 93 | #------------------------------------------------------------------------------ 94 | fig, ax = plot.pred_after_timestep(net, test_set, mask=None, digits=digits, seed=2553) 95 | plot.save_fig(fig, F_PATH+"A1", bbox_inches='tight') 96 | #------------------------------------------------------------------------------ 97 | #fig A2: plot lesioned predictions + median MNIST digit 98 | pred_mask = 
helper._pred_mask(net, test_set, train_set) 99 | fig, ax = plot.pred_after_timestep(net, test_set, mask=pred_mask, digits=digits, seed=2553) 100 | plot.save_fig(fig, F_PATH+"A2", bbox_inches='tight') 101 | fig, ax = plot.display(train_set.x.median(dim=0).values, axes_visible=False) 102 | plot.save_fig(fig, F_PATH+"A2_med", bbox_inches='tight') 103 | #------------------------------------------------------------------------------ 104 | #fig A3: plot class specific lesioned predictions for each digit 105 | masks = [] 106 | target = 7 107 | c_mask = helper.pred_class_mask(net, test_set, target=target, Z_crit=Z_CRIT) 108 | type_mask, type_stats = helper.compute_unit_types(net, test_set, train_set) 109 | plot.topographic_distribution(type_mask.reshape(28,28)) 110 | fig, ax = plot.pred_after_timestep(net, test_set, mask=c_mask, digits=digits, seed=2553) 111 | plot.save_fig(fig, F_PATH+"A3", bbox_inches='tight') 112 | #------------------------------------------------------------------------------ 113 | # fig A4: topo dist untrained network 114 | untrained_net = Network.State(activation_func=torch.nn.ReLU(), 115 | optimizer=torch.optim.Adam, 116 | lr=1e-4, 117 | input_size=INPUT_SIZE, 118 | hidden_size=INPUT_SIZE, 119 | title='', 120 | device=DEVICE) 121 | type_mask, type_stats = helper.compute_unit_types(untrained_net, test_set, train_set) 122 | fig = plot.topographic_distribution(type_mask.reshape(28, 28)) 123 | plot.save_fig(fig, F_PATH+"A4", bbox_inches='tight') -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | #### Imports 2 | import torch 3 | import functions 4 | import mnist 5 | import Network 6 | import plot 7 | 8 | import cifar 9 | 10 | 11 | #### Load datasets 12 | if torch.cuda.is_available(): 13 | DEVICE = 'cuda' 14 | torch.set_default_tensor_type(torch.cuda.FloatTensor) 15 | else: 16 | DEVICE = 'cpu' 17 | 18 | print('Using 
{}'.format(DEVICE)) 19 | 20 | INPUT_SIZE_MNIST = 28*28 21 | INPUT_SIZE_CIFAR = 32*32*3 22 | BATCH_SIZE = 32 23 | SEQ_LENGTH = 10 24 | LOSS_FN = functions.L1Loss 25 | # load MNIST 26 | training_set_m, validation_set_m, test_set_m = mnist.load(val_ratio=0.0) 27 | # load CIFAR10 28 | training_set_c, validation_set_c, test_set_c = cifar.load(val_ratio=0.0, color=True) 29 | 30 | #### Load trained networks for MNIST & CIFAR10 31 | mnist_nets = [] 32 | cifar_nets = [] 33 | n_instances = 10 34 | # load networks for bootstrap 35 | for i in range(0, n_instances): 36 | mnist_net = Network.State(activation_func=torch.nn.ReLU(), 37 | optimizer=torch.optim.Adam, 38 | lr=1e-4, 39 | input_size=INPUT_SIZE_MNIST, 40 | hidden_size=INPUT_SIZE_MNIST, 41 | title="networks/mnist_networks/mnist_net", 42 | device=DEVICE) 43 | mnist_net.load(i) 44 | mnist_nets.append(mnist_net) 45 | 46 | cifar_net = Network.State(activation_func=torch.nn.ReLU(), 47 | optimizer=torch.optim.Adam, 48 | lr=1e-4, 49 | input_size=INPUT_SIZE_CIFAR, 50 | hidden_size=INPUT_SIZE_CIFAR, 51 | title="networks/cifar_networks/cifar_net", 52 | device=DEVICE) 53 | cifar_net.load(i) 54 | cifar_nets.append(cifar_net) 55 | 56 | #### Figure 4B: evolution of preactivation for model, lesioned model, control & three benchmarks¶ 57 | plot.bootstrap_model_activity(mnist_nets, training_set_m, test_set_m, seed=None, lesioned=True, save=True, data_type='mnist') 58 | 59 | #### Figure 5C: evolution of preactivation for model, lesioned model, control & three benchmarks for CIFAR10 60 | plot.bootstrap_model_activity(cifar_nets, training_set_c, test_set_c, seed=None, lesioned=True, save=True, data_type='cifar') 61 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | from typing import Callable 4 | import functions 5 | from ModelState import ModelState 6 | from Dataset import Dataset 7 
| 8 | def test_epoch(ms: ModelState, 9 | dataset: Dataset, 10 | loss_fn: Callable[[torch.FloatTensor, torch.FloatTensor], torch.FloatTensor], 11 | batch_size: int, 12 | sequence_length: int): 13 | batches, labels = dataset.create_batches(batch_size=batch_size, sequence_length=sequence_length, shuffle=True) 14 | num_batches = batches.shape[0] 15 | batch_size = batches.shape[2] 16 | tot_loss = 0 17 | tot_res = None 18 | state = None 19 | for i, batch in enumerate(batches): 20 | 21 | with torch.no_grad(): 22 | loss, res, state = test_batch(ms, batch, loss_fn, state) 23 | 24 | tot_loss += loss 25 | 26 | if tot_res is None: 27 | tot_res = res 28 | else: 29 | tot_res += res 30 | tot_loss /= num_batches 31 | tot_res /= num_batches 32 | print("Test loss: {:.8f}".format(tot_loss)) 33 | return tot_loss, tot_res 34 | 35 | def test_batch(ms: ModelState, 36 | batch: torch.FloatTensor, 37 | loss_fn: Callable[[torch.FloatTensor, torch.FloatTensor], torch.FloatTensor], 38 | state) -> float: 39 | loss, res, state = ms.run(batch, loss_fn, state) 40 | return loss.item(), res, state 41 | 42 | def train_batch(ms: ModelState, 43 | batch: torch.FloatTensor, 44 | loss_fn: Callable[[torch.FloatTensor, torch.FloatTensor], torch.FloatTensor], 45 | state) -> float: 46 | 47 | loss, res, state = ms.run(batch, loss_fn, state) 48 | 49 | ms.step(loss) 50 | ms.zero_grad() 51 | return loss.item(), res, state 52 | 53 | def train_epoch(ms: ModelState, 54 | dataset: Dataset, 55 | loss_fn: Callable[[torch.FloatTensor, torch.FloatTensor], torch.FloatTensor], 56 | batch_size: int, 57 | sequence_length: int, 58 | verbose = True) -> float: 59 | 60 | batches, labels = dataset.create_batches(batch_size=batch_size, sequence_length=sequence_length, shuffle=True) 61 | num_batches = batches.shape[0] 62 | batch_size = batches.shape[2] 63 | 64 | t = functions.Timer() 65 | tot_loss = 0. 
66 | tot_res = None 67 | state = None 68 | for i, batch in enumerate(batches): 69 | 70 | loss, res, state = train_batch(ms, batch, loss_fn, state) 71 | tot_loss += loss 72 | 73 | if tot_res is None: 74 | tot_res = res 75 | else: 76 | tot_res += res 77 | 78 | if verbose and (i+1) % int(num_batches/10) == 0: 79 | dt = t.get(); t.lap() 80 | print("Batch {}/{}, ms/batch: {}, loss: {:.5f}".format(i, num_batches, dt / (num_batches/10), tot_loss/(i))) 81 | 82 | tot_loss /= num_batches 83 | tot_res /= num_batches 84 | 85 | 86 | print("Training loss: {:.8f}".format(tot_loss)) 87 | 88 | return tot_loss, tot_res, state.detach() 89 | 90 | def train(ms: ModelState, 91 | train_ds: Dataset, 92 | test_ds: Dataset, 93 | loss_fn: Callable[[torch.FloatTensor, torch.FloatTensor], torch.FloatTensor], 94 | num_epochs: int = 1, 95 | batch_size: int = 32, 96 | sequence_length: int = 3, 97 | patience: int = 200, 98 | verbose = False): 99 | ms_name = ms.title.split('/')[-1] 100 | best_epoch = 0; tries = 0 101 | best_loss = sys.float_info.max 102 | best_network = None 103 | 104 | for epoch in range(ms.epochs+1, ms.epochs+1 + num_epochs): 105 | print("Epoch {}, Lossfn {}".format(epoch, ms_name)) 106 | 107 | train_loss, train_res, h = train_epoch(ms, train_ds, loss_fn, batch_size, sequence_length, verbose=verbose) 108 | 109 | test_loss, test_res = test_epoch(ms, test_ds, loss_fn, batch_size, sequence_length) 110 | # if epoch == 1 or epoch == num_epochs - 10: 111 | # W = ms.model.W.detach() 112 | # torch.save(W, 'models/'+ms.title+'W_'+ str(epoch)+'.pt') 113 | h, W_l1, W_l2 = functions.L1Loss(h), functions.L1Loss(ms.model.W.detach()), functions.L2Loss(ms.model.W.detach()) 114 | m_state = [[h.cpu().numpy()], [W_l1.cpu().numpy()], [W_l2.cpu().numpy()]] 115 | ms.on_results(epoch, train_res, test_res, m_state) 116 | 117 | if (test_loss < best_loss): 118 | best_loss = test_loss 119 | best_epoch = epoch 120 | best_network = ms.model.W.detach() 121 | tries = 0 122 | else: 123 | print("Loss did not 
improve from", best_loss) 124 | tries = tries + 1 125 | if (tries >= patience): 126 | print("Stopping early") 127 | break; 128 | -------------------------------------------------------------------------------- /train_models.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import cifar 3 | import mnist 4 | import Network 5 | import random 6 | import torch 7 | from functions import get_device 8 | from train import train 9 | 10 | parser = argparse.ArgumentParser(description='device') 11 | parser.add_argument('--i', type=str, help='Device index') 12 | args = parser.parse_args() 13 | 14 | DEVICE = get_device() 15 | 16 | INPUT_SIZE= 28*28 17 | BATCH_SIZE = 32 18 | SEQ_LENGTH = 10 19 | 20 | # dataset loaders 21 | 22 | train_set, validation_set, test_set = mnist.load(val_ratio=0.0) 23 | 24 | """ 25 | Create and train ten instances of energy efficient RNNs for MNIST 26 | """ 27 | n_instances = 10 # number of model instances 28 | #losses = [str(beta)+'beta'+'l1_postandl2_weights' for beta in [3708.0] ] 29 | losses = ['l1_pre', 'l1_post', [str(beta)+'beta'+'l1_postandl2_weights' for beta in [3708.0]][0]] 30 | seeds = [[random.randint(0,10000) for i in range(n_instances)] for j in range(len(losses))] 31 | #seeds = [[random.randint(0,10000) for i in range(n_instances)]] 32 | # train MNIST networks 33 | 34 | for loss_ind, loss in enumerate(losses): 35 | for i in range(0, n_instances): 36 | print("loss", loss_ind, "instance", i) 37 | net = Network.State(activation_func=torch.nn.ReLU(), 38 | optimizer=torch.optim.Adam, 39 | lr=1e-4, 40 | input_size=INPUT_SIZE, 41 | hidden_size=INPUT_SIZE, 42 | title="patterns_rev/seeded_mnist/mnist_net_"+loss+"_"+str(i), 43 | device=DEVICE, 44 | seed=seeds[loss_ind][i]) 45 | 46 | train(net, 47 | train_ds=train_set, 48 | test_ds=test_set, 49 | loss_fn=loss, 50 | num_epochs=200, 51 | batch_size=BATCH_SIZE, 52 | sequence_length=SEQ_LENGTH, 53 | verbose=False) 54 | 55 | # # save model 56 | 
net.save() 57 | 58 | """ 59 | Create and train ten instances of energy efficient RNNs for CIFAR10 60 | """ 61 | INPUT_SIZE = 3072 62 | HIDDEN_SIZE = 3072 # add 32 to this number if you want to have extra latent resources 63 | BATCH_SIZE = 32 64 | SEQ_LENGTH = 10 65 | LOSS_FN = 'l1_pre' 66 | 67 | training_set, validation_set, test_set = cifar.load(val_ratio=0.0, color=True) 68 | 69 | """ 70 | Create and train ten instances of energy efficient RNNs with cifar 10 71 | # """ 72 | N = 10 # number of model instances per seed 73 | 74 | seeds = [random.randint(0,10000) for i in range(N)] 75 | 76 | for i in range(N): 77 | 78 | cifar_net= Network.State(activation_func=torch.nn.ReLU(), 79 | optimizer=torch.optim.Adam, 80 | lr=1e-4, 81 | input_size=INPUT_SIZE, 82 | hidden_size=HIDDEN_SIZE, 83 | title="/final_networks/seeded_cifar_nets/cifar_net_"+str(i), 84 | device=DEVICE, 85 | seed=seeds[i]) 86 | 87 | 88 | cifar_net.save() 89 | 90 | train(cifar_net, 91 | train_ds=training_set, 92 | test_ds=test_set, 93 | loss_fn=LOSS_FN, 94 | num_epochs=1000, 95 | batch_size=BATCH_SIZE, 96 | sequence_length=SEQ_LENGTH, 97 | verbose=False 98 | ) 99 | ## save model 100 | cifar_net.save() 101 | --------------------------------------------------------------------------------