├── README.md
├── main.py
├── neural_model.py
├── deep_nfa_env.yml
├── trainer.py
├── verify_deep_NFA.py
└── dataset.py

/README.md:
--------------------------------------------------------------------------------
1 | Code for verifying the Deep Neural Feature Ansatz from https://arxiv.org/abs/2212.13881
2 |
3 | # Training Neural Networks
4 | Code for training fully connected networks on all datasets considered in the paper is available in main.py. Checkpoints (including the network at initialization) are written to a directory named saved_nns.
5 |
6 | # Verifying Ansatz
7 | Code for verifying the ansatz is available in verify_deep_NFA.py. It requires loading a saved neural net and the corresponding dataset.
8 |
9 | # Software
10 | Software versions are listed in deep_nfa_env.yml. This code primarily requires PyTorch 1.13 with functorch (installed via pip).
11 |
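12 | # Example
13 | A minimal end-to-end sketch (an illustration, not one of the original scripts). It assumes a CUDA GPU and the SVHN data already downloaded under ~/datasets/, since the loaders are constructed with download=False:
14 |
15 | ```python
16 | import dataset
17 | import trainer
18 | from main import get_name
19 | from verify_deep_NFA import verify_NFA
20 |
21 | configs = {'num_epochs': 500, 'learning_rate': 0.1, 'weight_decay': 0,
22 |            'init': 'default', 'optimizer': 'sgd', 'freeze': False,
23 |            'width': 1024, 'depth': 3, 'act': 'relu'}
24 |
25 | # Train; checkpoints are written to saved_nns/
26 | trainloader, valloader, testloader = dataset.get_svhn()
27 | trainer.train_network(trainloader, valloader, testloader, 10,
28 |                       name=get_name('svhn', configs), configs=configs)
29 |
30 | # Verify the ansatz for the first layer of the trained network
31 | verify_NFA('saved_nns/' + get_name('svhn', configs) + '.pth', 'svhn', layer_idx=0)
32 | ```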
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import random
4 | import dataset
5 | import trainer
6 |
7 | SEED = 1717
8 |
9 | torch.manual_seed(SEED)
10 | random.seed(SEED)
11 | np.random.seed(SEED)
12 | torch.cuda.manual_seed(SEED)
13 |
14 |
15 | def get_name(dataset_name, configs):
16 |     name_str = dataset_name + ':'
17 |     for key in configs:
18 |         name_str += key + ':' + str(configs[key]) + ':'
19 |     name_str += 'nn'
20 |
21 |     return name_str
22 |
23 | def main():
24 |
25 |     # Pick configs to save model
26 |     configs = {}
27 |     configs['num_epochs'] = 500
28 |     configs['learning_rate'] = .1
29 |     configs['weight_decay'] = 0
30 |     configs['init'] = 'default'
31 |     configs['optimizer'] = 'sgd'
32 |     configs['freeze'] = False
33 |     configs['width'] = 1024
34 |     configs['depth'] = 3
35 |     configs['act'] = 'relu'
36 |
37 |     # Code to load and train net on selected dataset.
38 |     # Datasets used in paper are in dataset.py
39 |     # SVHN
40 |     NUM_CLASSES = 10
41 |     trainloader, valloader, testloader = dataset.get_svhn()
42 |     trainer.train_network(trainloader, valloader, testloader, NUM_CLASSES,
43 |                           name=get_name('svhn', configs), configs=configs)
44 |
45 |
46 | if __name__ == "__main__":
47 |     main()
--------------------------------------------------------------------------------
/neural_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from copy import deepcopy
5 |
6 |
7 | class Nonlinearity(torch.nn.Module):
8 |     def __init__(self, name='relu'):
9 |         super(Nonlinearity, self).__init__()
10 |         self.name = name
11 |
12 |     def forward(self, x):
13 |         if self.name == 'relu':
14 |             return F.relu(x)
15 |         if self.name == 'sigmoid':
16 |             return torch.sigmoid(x)
17 |         if self.name == 'leaky_relu':
18 |             return F.leaky_relu(x)
19 |         if self.name == 'sine':
20 |             return torch.sin(x)
21 |         if self.name == 'tanh':
22 |             return torch.tanh(x)
23 |         raise ValueError('Unknown nonlinearity: ' + self.name)
24 |
25 |
26 | class Net(nn.Module):
27 |
28 |     def __init__(self, dim, depth=1, width=1024, num_classes=2, act_name='relu'):
29 |         super(Net, self).__init__()
30 |         bias = False
31 |         self.dim = dim
32 |         self.width = width
33 |         self.depth = depth
34 |         self.name = act_name
35 |
36 |         if depth == 1:
37 |             self.first = nn.Linear(dim, width, bias=bias)
38 |             self.fc = nn.Sequential(Nonlinearity(name=self.name),
39 |                                     nn.Linear(width, num_classes, bias=bias))
40 |         else:
41 |             module = nn.Sequential(Nonlinearity(name=self.name),
42 |                                    nn.Linear(width, width, bias=bias))
43 |             num_layers = depth - 1
44 |             self.first = nn.Sequential(nn.Linear(dim, width,
45 |                                                  bias=bias))
46 |             self.middle = nn.ModuleList([deepcopy(module)
47 |                                          for idx in range(num_layers)])
48 |
49 |             self.last = nn.Sequential(Nonlinearity(name=self.name),
50 |                                       nn.Linear(width, num_classes,
51 |                                                 bias=bias))
52 |
53 |     def forward(self, x):
54 |         if self.depth == 1:
55 |             return self.fc(self.first(x))
56 |         else:
57 |             o = self.first(x)
58 |             for idx, m in enumerate(self.middle):
59 |                 o = m(o)
60 |             o = self.last(o)
61 |             return o
--------------------------------------------------------------------------------
/deep_nfa_env.yml:
--------------------------------------------------------------------------------
1 | name: pytorch1.13
2 | channels:
3 |   - pytorch
4 |   - nvidia
5 |   - defaults
6 | dependencies:
7 |   - _libgcc_mutex=0.1=main
8 |   - _openmp_mutex=5.1=1_gnu
9 |   - blas=1.0=mkl
10 |   - brotli=1.0.9=h5eee18b_7
11 |   - brotli-bin=1.0.9=h5eee18b_7
12 |   - brotlipy=0.7.0=py37h27cfd23_1003
13 |   - bzip2=1.0.8=h7b6447c_0
14 |   - ca-certificates=2023.01.10=h06a4308_0
15 |   - certifi=2022.12.7=py37h06a4308_0
16 |   - cffi=1.15.1=py37h5eee18b_3
17 |   - charset-normalizer=2.0.4=pyhd3eb1b0_0
18 |   - cryptography=39.0.1=py37h9ce1e76_0
19 |   - cuda-cudart=11.7.99=0
20 |   - cuda-cupti=11.7.101=0
21 |   - cuda-libraries=11.7.1=0
22 |   - cuda-nvrtc=11.7.99=0
23 |   - cuda-nvtx=11.7.91=0
24 |   - cuda-runtime=11.7.1=0
25 |   - cycler=0.11.0=pyhd3eb1b0_0
26 |   - dbus=1.13.18=hb2f20db_0
27 |   - expat=2.4.9=h6a678d5_0
28 |   - ffmpeg=4.3=hf484d3e_0
29 |   - fftw=3.3.9=h27cfd23_1
30 |   - fontconfig=2.14.1=h4c34cd2_2
31 |   - fonttools=4.25.0=pyhd3eb1b0_0
32 |   - freetype=2.12.1=h4a9f257_0
33 |   - giflib=5.2.1=h5eee18b_3
34 |   - glib=2.69.1=he621ea3_2
35 |   - gmp=6.2.1=h295c915_3
36 |   - gnutls=3.6.15=he1e5248_0
37 |   - gst-plugins-base=1.14.1=h6a678d5_1
38 |   - gstreamer=1.14.1=h5eee18b_1
39 |   - icu=58.2=he6710b0_3
40 |   - idna=3.4=py37h06a4308_0
41 |   - intel-openmp=2021.4.0=h06a4308_3561
42 |   - joblib=1.1.1=py37h06a4308_0
43 |   - jpeg=9e=h5eee18b_1
44 |   - kiwisolver=1.4.4=py37h6a678d5_0
45 |   - krb5=1.19.4=h568e23c_0
46 |   - lame=3.100=h7b6447c_0
47 |   - lcms2=2.12=h3be6417_0
48 |   - ld_impl_linux-64=2.38=h1181459_1
49 |   - lerc=3.0=h295c915_0
50 |   - libbrotlicommon=1.0.9=h5eee18b_7
51 |   - libbrotlidec=1.0.9=h5eee18b_7
52 |   - libbrotlienc=1.0.9=h5eee18b_7
53 |   - libclang=14.0.6=default_hc6dbbc7_1
54 |   - libclang13=14.0.6=default_he11475f_1
55 |   - libcublas=11.10.3.66=0
56 |   - libcufft=10.7.2.124=h4fbf590_0
57 |   - libcufile=1.6.0.25=0
58 |   - libcurand=10.3.2.56=0
59 |   - libcusolver=11.4.0.1=0
60 |   - libcusparse=11.7.4.91=0
61 |   - libdeflate=1.17=h5eee18b_0
62 |   - libedit=3.1.20221030=h5eee18b_0
63 |   - libevent=2.1.12=h8f2d780_0
64 |   - libffi=3.4.2=h6a678d5_6
65 |   - libgcc-ng=11.2.0=h1234567_1
66 |   - libgfortran-ng=11.2.0=h00389a5_1
67 |   - libgfortran5=11.2.0=h1234567_1
68 |   - libgomp=11.2.0=h1234567_1
69 |   - libiconv=1.16=h7f8727e_2
70 |   - libidn2=2.3.2=h7f8727e_0
71 |   - libllvm14=14.0.6=hdb19cb5_2
72 |   - libnpp=11.7.4.75=0
73 |   - libnvjpeg=11.8.0.2=0
74 |   - libpng=1.6.39=h5eee18b_0
75 |   - libpq=12.9=h16c4e8d_3
76 |   - libstdcxx-ng=11.2.0=h1234567_1
77 |   - libtasn1=4.19.0=h5eee18b_0
78 |   - libtiff=4.5.0=h6a678d5_2
79 |   - libunistring=0.9.10=h27cfd23_0
80 |   - libuuid=1.41.5=h5eee18b_0
81 |   - libwebp=1.2.4=h11a3e52_1
82 |   - libwebp-base=1.2.4=h5eee18b_1
83 |   - libxcb=1.15=h7f8727e_0
84 |   - libxkbcommon=1.0.1=h5eee18b_1
85 |   - libxml2=2.10.3=hcbfbd50_0
86 |   - libxslt=1.1.37=h2085143_0
87 |   - lz4-c=1.9.4=h6a678d5_0
88 |   - matplotlib=3.5.3=py37h06a4308_0
89 |   - matplotlib-base=3.5.3=py37hf590b9c_0
90 |   - mkl=2021.4.0=h06a4308_640
91 |   - mkl-service=2.4.0=py37h7f8727e_0
92 |   - mkl_fft=1.3.1=py37hd3c417c_0
93 |   - mkl_random=1.2.2=py37h51133e4_0
94 |   - munkres=1.1.4=py_0
95 |   - ncurses=6.4=h6a678d5_0
96 |   - nettle=3.7.3=hbbd107a_1
97 |   - nspr=4.33=h295c915_0
98 |   - nss=3.74=h0370c37_0
99 |   - numpy=1.21.5=py37h6c91a56_3
100 |   - numpy-base=1.21.5=py37ha15fc14_3
101 |   - openh264=2.1.1=h4ff587b_0
102 |   - openssl=1.1.1t=h7f8727e_0
103 |   - packaging=22.0=py37h06a4308_0
104 |   - pcre=8.45=h295c915_0
105 |   - pillow=9.4.0=py37h6a678d5_0
106 |   - pip=22.3.1=py37h06a4308_0
107 |   - ply=3.11=py37_0
108 |   - pycparser=2.21=pyhd3eb1b0_0
109 |   - pyopenssl=23.0.0=py37h06a4308_0
110 |   - pyparsing=3.0.9=py37h06a4308_0
111 |   - pyqt=5.15.7=py37h6a678d5_1
112 |   - pyqt5-sip=12.11.0=py37h6a678d5_1
113 |   - pysocks=1.7.1=py37_1
114 |   - python=3.7.16=h7a1cb2a_0
115 |   - python-dateutil=2.8.2=pyhd3eb1b0_0
116 |   - pytorch=1.13.1=py3.7_cuda11.7_cudnn8.5.0_0
117 |   - pytorch-cuda=11.7=h778d358_3
118 |   - pytorch-mutex=1.0=cuda
119 |   - qt-main=5.15.2=h8373d8f_8
120 |   - qt-webengine=5.15.9=hbbf29b9_6
121 |   - qtwebkit=5.212=h3fafdc1_5
122 |   - readline=8.2=h5eee18b_0
123 |   - requests=2.28.1=py37h06a4308_0
124 |   - scikit-learn=1.0.2=py37h51133e4_1
125 |   - scipy=1.7.3=py37h6c91a56_2
126 |   - setuptools=65.6.3=py37h06a4308_0
127 |   - sip=6.6.2=py37h6a678d5_0
128 |   - six=1.16.0=pyhd3eb1b0_1
129 |   - sqlite=3.41.2=h5eee18b_0
130 |   - threadpoolctl=2.2.0=pyh0d69192_0
131 |   - tk=8.6.12=h1ccaba5_0
132 |   - toml=0.10.2=pyhd3eb1b0_0
133 |   - torchaudio=0.13.1=py37_cu117
134 |   - torchvision=0.14.1=py37_cu117
135 |   - tornado=6.2=py37h5eee18b_0
136 |   - typing_extensions=4.3.0=py37h06a4308_0
137 |   - urllib3=1.26.14=py37h06a4308_0
138 |   - wheel=0.38.4=py37h06a4308_0
139 |   - xz=5.2.10=h5eee18b_1
140 |   - zlib=1.2.13=h5eee18b_0
141 |   - zstd=1.5.5=hc292b87_0
142 |   - pip:
143 |     - functorch==1.13.1
144 |     - h5py==3.8.0
145 |     - hickle==5.0.2
146 |     - jsonpatch==1.32
147 |     - jsonpointer==2.3
148 |     - kaggle==1.5.13
149 |     - networkx==2.6.3
150 |     - patchify==0.2.3
151 |     - python-slugify==8.0.1
152 |     - text-unidecode==1.3
153 |     - tqdm==4.65.0
154 |     - visdom==0.2.4
155 |     - websocket-client==1.5.1
156 | prefix: /home/aradha/anaconda3/envs/pytorch1.13
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import torch
4 | import neural_model
5 | import numpy as np
6 | from sklearn.metrics import r2_score
7 |
8 |
9 | def select_optimizer(name, lr, net, weight_decay):
10 |     if name == 'sgd':
11 |         return torch.optim.SGD(net.parameters(), lr=lr, weight_decay=weight_decay)
12 |     elif name == 'adam':
13 |         return torch.optim.Adam(net.parameters(), lr=lr, weight_decay=weight_decay)
14 |     raise ValueError('Unknown optimizer: ' + name)
15 |
16 |
17 | def train_network(train_loader, val_loader, test_loader, num_classes,
18 |                   name='default_nn', configs=None, regression=False):
19 |
20 |     # Infer the input dimension from the first batch
21 |     for idx, batch in enumerate(train_loader):
22 |         inputs, labels = batch
23 |         _, dim = inputs.shape
24 |         break
25 |
26 |     if configs is not None:
27 |         num_epochs = configs['num_epochs'] + 1
28 |         net = neural_model.Net(dim, width=configs['width'],
29 |                                depth=configs['depth'],
30 |                                num_classes=num_classes,
31 |                                act_name=configs['act'])
32 |
33 |         if configs['init'] != 'default':
34 |             # Re-draw the first layer's weights at the given init scale
35 |             for idx, param in enumerate(net.parameters()):
36 |                 if idx == 0:
37 |                     init = torch.Tensor(param.size()).normal_().float() * configs['init']
38 |                     param.data = init
39 |
40 |         if configs['freeze']:
41 |             # Train only the first layer; freeze all later layers
42 |             for idx, param in enumerate(net.parameters()):
43 |                 if idx > 0:
44 |                     param.requires_grad = False
45 |
46 |         optimizer = select_optimizer(configs['optimizer'],
47 |                                      configs['learning_rate'],
48 |                                      net,
49 |                                      configs['weight_decay'])
50 |     else:
51 |         num_epochs = 501
52 |         net = neural_model.Net(dim, num_classes=num_classes)
53 |         optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
54 |
55 |     os.makedirs('saved_nns', exist_ok=True)
56 |     d = {}
57 |     d['state_dict'] = net.state_dict()
58 |     if name is not None:
59 |         torch.save(d, 'saved_nns/init_' + name + '.pth')
60 |
61 |     net.cuda()
62 |     best_val_acc = 0
63 |     best_test_acc = 0
64 |     best_val_loss = float("inf")
65 |     best_test_loss = 0
66 |
67 |     for i in range(num_epochs):
68 |
69 |         train_loss = train_step(net, optimizer, train_loader)
70 |         val_loss = val_step(net, val_loader)
71 |         test_loss = val_step(net, test_loader)
72 |         if regression:
73 |             train_acc = get_r2(net, train_loader)
74 |             val_acc = get_r2(net, val_loader)
75 |             test_acc = get_r2(net, test_loader)
76 |         else:
77 |             train_acc = get_acc(net, train_loader)
78 |             val_acc = get_acc(net, val_loader)
79 |             test_acc = get_acc(net, test_loader)
80 |
81 |         if val_acc >= best_val_acc:
82 |             best_val_acc = val_acc
83 |             best_test_acc = test_acc
84 |             net.cpu()
85 |             d = {}
86 |             d['state_dict'] = net.state_dict()
87 |             if name is not None:
88 |                 torch.save(d, 'saved_nns/' + name + '.pth')
89 |             net.cuda()
90 |
91 |         if val_loss <= best_val_loss:
92 |             best_val_loss = val_loss
93 |             best_test_loss = test_loss
94 |
95 |         print("Epoch: ", i,
96 |               "Train Loss: ", train_loss, "Test Loss: ", test_loss,
97 |               "Train Acc: ", train_acc, "Test Acc: ", test_acc,
98 |               "Best Val Acc: ", best_val_acc, "Best Val Loss: ", best_val_loss,
99 |               "Best Test Acc: ", best_test_acc, "Best Test Loss: ", best_test_loss)
100 |
101 |     net.cpu()
102 |
103 |     if name is not None:
104 |         d = {}
105 |         d['state_dict'] = net.state_dict()
106 |         torch.save(d, 'saved_nns/' + name + '_final.pth')
107 |     return train_acc, best_val_acc, best_test_acc
108 |
109 |
110 | def train_step(net, optimizer, train_loader):
111 |     net.train()
112 |     start = time.time()
113 |     train_loss = 0.
114 |
115 |     for batch_idx, batch in enumerate(train_loader):
116 |         optimizer.zero_grad()
117 |         inputs, targets = batch
118 |         output = net(inputs.cuda())
119 |         target = targets.cuda()
120 |         loss = torch.mean(torch.pow(output - target, 2))
121 |         loss.backward()
122 |         optimizer.step()
123 |         train_loss += loss.cpu().data.numpy() * len(inputs)
124 |     end = time.time()
125 |     print("Time: ", end - start)
126 |     train_loss = train_loss / len(train_loader.dataset)
127 |     return train_loss
128 |
129 |
130 | def val_step(net, val_loader):
131 |     net.eval()
132 |     val_loss = 0.
133 |     for batch_idx, batch in enumerate(val_loader):
134 |         inputs, targets = batch
135 |         with torch.no_grad():
136 |             output = net(inputs.cuda())
137 |             target = targets.cuda()
138 |             loss = torch.mean(torch.pow(output - target, 2))
139 |         val_loss += loss.cpu().data.numpy() * len(inputs)
140 |     val_loss = val_loss / len(val_loader.dataset)
141 |     return val_loss
142 |
143 |
144 | def get_acc(net, loader):
145 |     net.eval()
146 |     count = 0
147 |     for batch_idx, batch in enumerate(loader):
148 |         inputs, targets = batch
149 |         with torch.no_grad():
150 |             output = net(inputs.cuda())
151 |             target = targets.cuda()
152 |
153 |         preds = torch.argmax(output, dim=-1)
154 |         labels = torch.argmax(target, dim=-1)
155 |
156 |         count += torch.sum(labels == preds).cpu().data.numpy()
157 |     return count / len(loader.dataset) * 100
158 |
159 |
160 | def get_r2(net, loader):
161 |     net.eval()
162 |     preds = []
163 |     labels = []
164 |     for batch_idx, batch in enumerate(loader):
165 |         inputs, targets = batch
166 |         with torch.no_grad():
167 |             output = net(inputs.cuda()).flatten().cpu().numpy()
168 |         preds.append(output)
169 |         labels.append(targets.flatten().numpy())
170 |     preds = np.concatenate(preds)
171 |     labels = np.concatenate(labels)
172 |     return r2_score(labels, preds)
--------------------------------------------------------------------------------
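For the synthetic regression task in the paper (targets y = x_1 * x_2 on Gaussian inputs, produced by dataset.get_two_coordinates), train_network can be called with regression=True so that R^2 is tracked in place of classification accuracy. A minimal sketch, assuming a CUDA GPU (the trainer moves the network to GPU unconditionally); the checkpoint name 'two_coordinates' is arbitrary:

```python
import dataset
import trainer

# 2000 train / 1000 test points of y = x_1 * x_2 in 100 dimensions
train_loader, val_loader, test_loader = dataset.get_two_coordinates()
trainer.train_network(train_loader, val_loader, test_loader, num_classes=1,
                      name='two_coordinates', regression=True)
```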
/verify_deep_NFA.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import random
4 | import dataset
5 | import neural_model
6 | from torch.linalg import norm
7 | from functorch import jacrev, vmap
8 |
9 | SEED = 1717
10 |
11 | torch.manual_seed(SEED)
12 | random.seed(SEED)
13 | np.random.seed(SEED)
14 | torch.cuda.manual_seed(SEED)
15 |
16 |
17 | def get_name(dataset_name, configs):
18 |     name_str = dataset_name + ':'
19 |     for key in configs:
20 |         name_str += key + ':' + str(configs[key]) + ':'
21 |     name_str += 'nn'
22 |     return name_str
23 |
24 |
25 | def load_nn(path, width, depth, dim, num_classes, layer_idx=0,
26 |             remove_init=False, act_name='relu'):
27 |
28 |     if remove_init:
29 |         suffix = path.split('/')[-1]
30 |         prefix = './saved_nns/'
31 |
32 |         init_net = neural_model.Net(dim, width=width, depth=depth,
33 |                                     num_classes=num_classes,
34 |                                     act_name=act_name)
35 |         d = torch.load(prefix + 'init_' + suffix)
36 |         init_net.load_state_dict(d['state_dict'])
37 |         init_params = list(init_net.parameters())
38 |
39 |     net = neural_model.Net(dim, width=width, depth=depth,
40 |                            num_classes=num_classes,
41 |                            act_name=act_name)
42 |
43 |     d = torch.load(path)
44 |     net.load_state_dict(d['state_dict'])
45 |
46 |     for idx, p in enumerate(net.parameters()):
47 |         if idx == layer_idx:
48 |             M = p.data.numpy()
49 |             print(M.shape)
50 |             if remove_init:
51 |                 M0 = init_params[idx].data.numpy()
52 |                 M -= M0
53 |             break
54 |
55 |     # Gram matrix of the layer's weights, normalized by the layer width
56 |     M = M.T @ M * 1/len(M)
57 |
58 |     return net, M
59 |
60 |
61 | def load_init_nn(path, width, depth, dim, num_classes, layer_idx=0, act_name='relu'):
62 |     suffix = path.split('/')[-1]
63 |     prefix = './saved_nns/'
64 |
65 |     net = neural_model.Net(dim, width=width, depth=depth,
66 |                            num_classes=num_classes, act_name=act_name)
67 |     d = torch.load(prefix + 'init_' + suffix)
68 |     net.load_state_dict(d['state_dict'])
69 |
70 |     for idx, p in enumerate(net.parameters()):
71 |         if idx == layer_idx:
72 |             M = p.data.numpy()
73 |             print(M.shape)
74 |             break
75 |
76 |     M = M.T @ M * 1/len(M)
77 |     return net, M
78 |
79 |
80 | def get_layer_output(net, trainloader, layer_idx=0):
81 |     net.eval()
82 |     out = []
83 |     for idx, batch in enumerate(trainloader):
84 |         data, labels = batch
85 |         if layer_idx == 0:
86 |             out.append(data.cpu())
87 |         elif layer_idx == 1:
88 |             o = neural_model.Nonlinearity(name=net.name)(net.first(data))
89 |             out.append(o.cpu())
90 |         elif layer_idx > 1:
91 |             o = net.first(data)
92 |             for l_idx, m in enumerate(net.middle):
93 |                 o = m(o)
94 |                 if l_idx + 1 == layer_idx:
95 |                     o = neural_model.Nonlinearity(name=net.name)(o)
96 |                     out.append(o.cpu())
97 |                     break
98 |     out = torch.cat(out, dim=0)
99 |     net.cpu()
100 |     return out
101 |
102 |
103 | def build_subnetwork(net, dim, width, depth, num_classes,
104 |                      layer_idx=0, random_net=False, act_name='relu'):
105 |
106 |     # Subnetwork consisting of layers layer_idx, ..., depth of the trained net
107 |     net_ = neural_model.Net(dim, depth=depth - layer_idx,
108 |                             width=width, num_classes=num_classes,
109 |                             act_name=act_name)
110 |
111 |     params = list(net.parameters())
112 |     if not random_net:
113 |         for idx, p_ in enumerate(net_.parameters()):
114 |             p_.data = params[idx + layer_idx].data
115 |
116 |     return net_
117 |
118 |
119 | def get_jacobian(net, data):
120 |     with torch.no_grad():
121 |         return vmap(jacrev(net))(data).transpose(0, 2).transpose(0, 1)
122 |
123 |
124 | def egop(net, dataset, centering=False):
125 |     device = torch.device('cuda')
126 |     bs = 1000
127 |     batches = torch.split(dataset, bs)
128 |     net = net.cuda()
129 |     G = 0
130 |
131 |     Js = []
132 |     for batch_idx, data in enumerate(batches):
133 |         data = data.to(device)
134 |         print("Computing Jacobian for batch: ", batch_idx, len(batches))
135 |         J = get_jacobian(net, data)
136 |         Js.append(J.cpu())
137 |
138 |         # Optional for stopping EGOP computation early
139 |         #if batch_idx > 30:
140 |         #    break
141 |     Js = torch.cat(Js, dim=-1)
142 |     if centering:
143 |         J_mean = torch.mean(Js, dim=-1).unsqueeze(-1)
144 |         Js = Js - J_mean
145 |
146 |     Js = torch.transpose(Js, 2, 0)
147 |     Js = torch.transpose(Js, 1, 2)
148 |     print(Js.shape)
149 |     batches = torch.split(Js, bs)
150 |     for batch_idx, J in enumerate(batches):
151 |         print(batch_idx, len(batches))
152 |         m, c, d = J.shape
153 |         J = J.cuda()
154 |         # Sum of gradient outer products over samples and output classes
155 |         G += torch.einsum('mcd,mcD->dD', J, J).cpu()
156 |         del J
157 |     G = G * 1/len(Js)
158 |
159 |     return G
160 |
161 |
162 | def correlate(M, G):
163 |     # Cosine similarity between two matrices viewed as vectors
164 |     M = M.double()
165 |     G = G.double()
166 |     normM = norm(M.flatten())
167 |     normG = norm(G.flatten())
168 |
169 |     corr = torch.dot(M.flatten(), G.flatten()) / (normM * normG)
170 |     return corr
171 |
172 |
173 | def read_configs(path):
174 |     tokens = path.strip().split(':')
175 |     print(tokens)
176 |     act_name = 'relu'
177 |     for idx, t in enumerate(tokens):
178 |         if t == 'width':
179 |             width = int(tokens[idx+1])
180 |         if t == 'depth':
181 |             depth = int(tokens[idx+1])
182 |         if t == 'act':
183 |             act_name = tokens[idx+1]
184 |
185 |     return width, depth, act_name
186 |
187 |
188 | def verify_NFA(path, dataset_name, feature_idx=None, layer_idx=0):
189 |     remove_init = True
190 |     random_net = False
191 |
192 |     if dataset_name == 'celeba':
193 |         NUM_CLASSES = 2
194 |         FEATURE_IDX = feature_idx
195 |         SIZE = 96
196 |         c = 3
197 |         dim = c * SIZE * SIZE
198 |     elif dataset_name == 'svhn' or dataset_name == 'cifar':
199 |         NUM_CLASSES = 10
200 |         SIZE = 32
201 |         c = 3
202 |         dim = c * SIZE * SIZE
203 |     elif dataset_name == 'cifar_mnist':
204 |         NUM_CLASSES = 10
205 |         c = 3
206 |         SIZE = 32
207 |         dim = c * SIZE * SIZE * 2
208 |     elif dataset_name == 'stl_star':
209 |         NUM_CLASSES = 2
210 |         c = 3
211 |         SIZE = 96
212 |         dim = c * SIZE * SIZE
213 |     else:
214 |         raise ValueError('Unknown dataset: ' + dataset_name)
215 |
216 |     width, depth, act_name = read_configs(path)
217 |
218 |     net, M = load_nn(path, width, depth, dim, NUM_CLASSES, layer_idx=layer_idx,
219 |                      remove_init=remove_init, act_name=act_name)
220 |     net0, M0 = load_init_nn(path, width, depth, dim, NUM_CLASSES, layer_idx=layer_idx,
221 |                             act_name=act_name)
222 |     subnet = build_subnetwork(net, M.shape[0], width, depth, NUM_CLASSES, layer_idx=layer_idx,
223 |                               random_net=random_net, act_name=act_name)
224 |
225 |     init_correlation = correlate(torch.from_numpy(M),
226 |                                  torch.from_numpy(M0))
227 |
228 |     print("Init Net Feature Matrix Correlation: ", init_correlation)
229 |
230 |     if dataset_name == 'celeba':
231 |         trainloader, valloader, testloader = dataset.get_celeba(FEATURE_IDX,
232 |                                                                 num_train=20000,
233 |                                                                 num_test=1)
234 |     elif dataset_name == 'svhn':
235 |         trainloader, valloader, testloader = dataset.get_svhn(num_train=1000,
236 |                                                               num_test=1)
237 |     elif dataset_name == 'cifar':
238 |         trainloader, valloader, testloader = dataset.get_cifar(num_train=1000,
239 |                                                                num_test=1)
240 |     elif dataset_name == 'cifar_mnist':
241 |         trainloader, valloader, testloader = dataset.get_cifar_mnist(num_train_per_class=1000,
242 |                                                                      num_test_per_class=1)
243 |     elif dataset_name == 'stl_star':
244 |         trainloader, valloader, testloader = dataset.get_stl_star(num_train=1000,
245 |                                                                   num_test=1)
246 |     out = get_layer_output(net, trainloader, layer_idx=layer_idx)
247 |     G = egop(subnet, out, centering=True)
248 |     G2 = egop(subnet, out, centering=False)
249 |
250 |     centered_correlation = correlate(torch.from_numpy(M), G)
251 |     uncentered_correlation = correlate(torch.from_numpy(M), G2)
252 |     print("Full Matrix Correlation Centered: ", centered_correlation)
253 |     print("Full Matrix Correlation Uncentered: ", uncentered_correlation)
254 |
255 |     return init_correlation, centered_correlation, uncentered_correlation
256 |
257 |
258 | def main():
259 |
260 |     path = ''  # Path to saved neural net model
261 |     idxs = [0, 1, 2]  # Layers for which to compute EGOP
262 |     init, centered, uncentered = [], [], []
263 |     for idx in idxs:
264 |         results = verify_NFA(path, 'svhn', layer_idx=idx)
265 |         i, c, u = results
266 |         init.append(i.item())
267 |         centered.append(c.item())
268 |         uncentered.append(u.item())
269 |     for idx in idxs:
270 |         print("Layer " + str(idx), init[idx], centered[idx], uncentered[idx])
271 |
272 |
273 | if __name__ == "__main__":
274 |     main()
--------------------------------------------------------------------------------
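For reference, the quantity egop() estimates and the ansatz correlate() tests: load_nn returns the Gram matrix of the weights of layer $i$, while egop() computes the (optionally centered) average gradient outer product of the subnetwork $f$ mapping the input of layer $i$ to the logits. The Deep Neural Feature Ansatz states that these are approximately proportional,

$$\frac{1}{k}\, W_i^\top W_i \;\propto\; \frac{1}{n} \sum_{j=1}^{n} \sum_{c=1}^{C} \nabla_h f_c(h_j)\, \nabla_h f_c(h_j)^\top,$$

where $W_i \in \mathbb{R}^{k \times d}$ is the weight matrix of layer $i$, $h_j$ is the input to layer $i$ on training sample $x_j$, and $f_c$ is the $c$-th output of the subnetwork from layer $i$ onward. Proportionality is measured here by cosine similarity, so the normalization constants $1/k$ and $1/n$ do not affect the reported correlations.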
"__main__": 268 | main() 269 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import torchvision.transforms as transforms 4 | from torch.utils.data import Dataset, DataLoader 5 | from sklearn.model_selection import train_test_split 6 | import numpy as np 7 | from tqdm import tqdm 8 | from numpy.linalg import norm 9 | 10 | 11 | def one_hot_data(dataset, num_classes, num_samples=-1, shift_label=False): 12 | labelset = {} 13 | for i in range(num_classes): 14 | one_hot = torch.zeros(num_classes) 15 | one_hot[i] = 1 16 | labelset[i] = one_hot 17 | 18 | offset = 0 19 | if shift_label: 20 | offset = -1 21 | 22 | subset = [(ex.flatten(), labelset[label + offset]) \ 23 | for idx, (ex, label) in enumerate(dataset) if idx < num_samples] 24 | 25 | return subset 26 | 27 | 28 | def split(trainset, p=.8): 29 | train, val = train_test_split(trainset, train_size=p) 30 | return train, val 31 | 32 | 33 | 34 | def get_svhn(split_percentage=.8, num_train=np.float('inf'), num_test=np.float('inf')): 35 | 36 | NUM_CLASSES = 10 37 | transform = transforms.Compose([transforms.ToTensor()]) 38 | svhn_path = '~/datasets/' 39 | 40 | trainset = torchvision.datasets.SVHN(root=svhn_path, 41 | split='train', 42 | transform=transform, 43 | download=False) 44 | trainset = one_hot_data(trainset, NUM_CLASSES, num_samples=num_train) 45 | trainset, valset = split(trainset, p=split_percentage) 46 | 47 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, 48 | shuffle=True, num_workers=1) 49 | 50 | valloader = torch.utils.data.DataLoader(valset, batch_size=128, 51 | shuffle=False, num_workers=1) 52 | 53 | testset = torchvision.datasets.SVHN(root=svhn_path, 54 | split='test', 55 | transform=transform, 56 | download=False) 57 | 58 | testset = one_hot_data(testset, NUM_CLASSES, num_samples=num_test) 59 | 60 | testloader = torch.utils.data.DataLoader(testset, batch_size=128, 61 | shuffle=False, num_workers=1) 62 | 63 | print("Num Train: ", len(trainset), "Num Val: ", len(valset), 64 | "Num Test: ", len(testset)) 65 | return trainloader, valloader, testloader 66 | 67 | 68 | def get_cifar(split_percentage=.8, num_train=np.float('inf'), num_test=np.float('inf')): 69 | 70 | NUM_CLASSES = 10 71 | transform = transforms.Compose([transforms.ToTensor()]) 72 | path = '~/datasets/' 73 | 74 | trainset = torchvision.datasets.CIFAR10(root=path, 75 | train=True, 76 | transform=transform, 77 | download=False) 78 | 79 | trainset = one_hot_data(trainset, NUM_CLASSES, num_samples=num_train) 80 | trainset, valset = split(trainset, p=split_percentage) 81 | 82 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, 83 | shuffle=True, num_workers=1) 84 | 85 | valloader = torch.utils.data.DataLoader(valset, batch_size=128, 86 | shuffle=False, num_workers=1) 87 | 88 | testset = torchvision.datasets.CIFAR10(root=path, 89 | train=False, 90 | transform=transform, 91 | download=False) 92 | 93 | testset = one_hot_data(testset, NUM_CLASSES, num_samples=num_test) 94 | 95 | testloader = torch.utils.data.DataLoader(testset, batch_size=128, 96 | shuffle=False, num_workers=1) 97 | 98 | print("Num Train: ", len(trainset), "Num Val: ", len(valset), 99 | "Num Test: ", len(testset)) 100 | return trainloader, valloader, testloader 101 | 102 | 103 | def sample_data(num, d): 104 | X = np.random.normal(size=(num, d)) 105 | y = X[:, 0] * X[:, 1] 106 | y = y.reshape(-1, 1) 107 | 
return torch.from_numpy(X).float(), torch.from_numpy(y).float() 108 | 109 | 110 | def get_two_coordinates(split_percentage=.8, num_train=2000, num_test=1000, d=100): 111 | X, y = sample_data(num_train, d) 112 | trainset = list(zip(X, y)) 113 | trainset, valset = split(trainset, p=split_percentage) 114 | X_test, y_test = sample_data(num_test, d) 115 | testset = list(zip(X_test, y_test)) 116 | 117 | train_loader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=1) 118 | val_loader = DataLoader(valset, batch_size=128, shuffle=False, num_workers=1) 119 | test_loader = DataLoader(testset, batch_size=128, shuffle=False, num_workers=1) 120 | print("Num Train: ", len(trainset), "Num Val: ", len(valset), 121 | "Num Test: ", len(testset)) 122 | 123 | return train_loader, val_loader, test_loader 124 | 125 | 126 | def celeba_subset(dataset, feature_idx, num_samples=-1): 127 | 128 | NUM_CLASSES = 2 129 | labelset = {} 130 | for i in range(NUM_CLASSES): 131 | one_hot = torch.zeros(NUM_CLASSES) 132 | one_hot[i] = 1 133 | labelset[i] = one_hot 134 | 135 | by_class = {} 136 | features = [] 137 | for idx in tqdm(range(len(dataset))): 138 | ex, label = dataset[idx] 139 | features.append(label[feature_idx]) 140 | g = label[feature_idx].numpy().item() 141 | #ex = torch.mean(ex, dim=0) 142 | ex = ex.flatten() 143 | ex = ex / norm(ex) 144 | if g in by_class: 145 | by_class[g].append((ex, labelset[g])) 146 | else: 147 | by_class[g] = [(ex, labelset[g])] 148 | if idx > num_samples: 149 | break 150 | data = [] 151 | if 1 in by_class: 152 | max_len = min(25000, len(by_class[1])) 153 | data.extend(by_class[1][:max_len]) 154 | data.extend(by_class[0][:max_len]) 155 | else: 156 | max_len = 1 157 | data.extend(by_class[0][:max_len]) 158 | return data 159 | 160 | 161 | 162 | def get_celeba(feature_idx, split_percentage=.8, 163 | num_train=np.float('inf'), num_test=np.float('inf')): 164 | celeba_path = '~/datasets/' 165 | SIZE = 96 166 | transform = transforms.Compose( 167 | [transforms.Resize([SIZE,SIZE]), 168 | transforms.ToTensor() 169 | ]) 170 | 171 | trainset = torchvision.datasets.CelebA(root=celeba_path, 172 | split='train', 173 | transform=transform, 174 | download=False) 175 | trainset = celeba_subset(trainset, feature_idx, num_samples=num_train) 176 | trainset, valset = split(trainset, p=split_percentage) 177 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, 178 | shuffle=True, num_workers=1) 179 | 180 | valloader = torch.utils.data.DataLoader(valset, batch_size=128, 181 | shuffle=False, num_workers=1) 182 | 183 | testset = torchvision.datasets.CelebA(root=celeba_path, 184 | split='test', 185 | transform=transform, 186 | download=False) 187 | testset = celeba_subset(testset, feature_idx, num_samples=num_test) 188 | testloader = torch.utils.data.DataLoader(testset, batch_size=128, 189 | shuffle=False, num_workers=1) 190 | 191 | print("Train Size: ", len(trainset), "Val Size: ", len(valset), "Test Size: ", len(testset)) 192 | return trainloader, valloader, testloader 193 | 194 | 195 | def group_by_class(dataset): 196 | labelset = {} 197 | for i in range(10): 198 | labelset[i] = [] 199 | for i, batch in enumerate(dataset): 200 | img, label = batch 201 | labelset[label].append(img.view(1, 3, 32, 32)) 202 | return labelset 203 | 204 | 205 | def merge_data(cifar, mnist, n): 206 | cifar_by_label = group_by_class(cifar) 207 | mnist_by_label = group_by_class(mnist) 208 | 209 | merged_data = [] 210 | merged_labels = [] 211 | 212 | labelset = {} 213 | 214 | for i in range(10): 215 | 
one_hot = torch.zeros(1, 10) 216 | one_hot[0, i] = 1 217 | labelset[i] = one_hot 218 | 219 | for l in cifar_by_label: 220 | 221 | cifar_data = torch.cat(cifar_by_label[l]) 222 | mnist_data = torch.cat(mnist_by_label[l]) 223 | min_len = min(len(mnist_data), len(cifar_data)) 224 | m = min(n, min_len) 225 | cifar_data = cifar_data[:m] 226 | mnist_data = mnist_data[:m] 227 | 228 | merged = torch.cat([cifar_data, mnist_data], axis=-1) 229 | #for i in range(3): 230 | # vis.image(merged[i]) 231 | merged_data.append(merged.reshape(m, -1)) 232 | #print(merged.shape) 233 | merged_labels.append(np.repeat(labelset[l], m, axis=0)) 234 | merged_data = torch.cat(merged_data, axis=0) 235 | 236 | merged_labels = np.concatenate(merged_labels, axis=0) 237 | merged_labels = torch.from_numpy(merged_labels) 238 | 239 | return list(zip(merged_data, merged_labels)) 240 | 241 | 242 | def get_cifar_mnist(split_percentage=.8, num_train_per_class=np.float('inf'), 243 | num_test_per_class=np.float('inf')): 244 | 245 | path = '~/datasets/' 246 | transform = transforms.Compose( 247 | [#transforms.Resize([32,32]), 248 | transforms.ToTensor() 249 | ]) 250 | 251 | mnist_transform = transforms.Compose( 252 | [transforms.Resize([32,32]), 253 | transforms.ToTensor(), 254 | transforms.Lambda(lambda x: x.repeat(3, 1, 1)) 255 | ]) 256 | 257 | 258 | cifar_trainset = torchvision.datasets.CIFAR10(root=path, 259 | train=True, 260 | transform=transform, 261 | download=False) 262 | 263 | mnist_trainset = torchvision.datasets.MNIST(root=path, 264 | train=True, 265 | transform=mnist_transform, 266 | download=False) 267 | trainset = merge_data(cifar_trainset, mnist_trainset, num_train_per_class) 268 | trainset, valset = split(trainset, p=split_percentage) 269 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, 270 | shuffle=True, num_workers=2) 271 | valloader = torch.utils.data.DataLoader(valset, batch_size=128, 272 | shuffle=False, num_workers=1) 273 | 274 | cifar_testset = torchvision.datasets.CIFAR10(root=path, 275 | train=False, 276 | transform=transform, 277 | download=False) 278 | 279 | mnist_testset = torchvision.datasets.MNIST(root=path, 280 | train=False, 281 | transform=mnist_transform, 282 | download=False) 283 | 284 | testset = merge_data(cifar_testset, mnist_testset, num_test_per_class) 285 | testloader = torch.utils.data.DataLoader(testset, batch_size=128, 286 | shuffle=False, num_workers=2) 287 | 288 | print("Num Train: ", len(trainset), "Num Val: ", len(valset), 289 | "Num Test: ", len(testset)) 290 | return trainloader, valloader, testloader 291 | 292 | 293 | def draw_star(ex, v, c=3): 294 | ex[:c, 5:6, 7:14] = v 295 | ex[:c, 4, 9:12] = v 296 | ex[:c, 3, 10] = v 297 | ex[:c, 6, 8:13] = v 298 | ex[:c, 7, 9:12] = v 299 | ex[:c, 8, 8:13] = v 300 | ex[:c, 9, 8:10] = v 301 | ex[:c, 9, 11:13] = v 302 | return ex 303 | 304 | 305 | def one_hot_stl_toy(dataset, num_samples=-1): 306 | labelset = {} 307 | for i in range(2): 308 | one_hot = torch.zeros(2) 309 | one_hot[i] = 1 310 | labelset[i] = one_hot 311 | 312 | subset = [(ex, label) for idx, (ex, label) in enumerate(dataset) \ 313 | if idx < num_samples and (label == 0 or label == 9)] 314 | 315 | adjusted = [] 316 | for idx, (ex, label) in enumerate(subset): 317 | if label == 9: 318 | ex = draw_star(ex,1, c=2) 319 | y = 1 320 | else: 321 | ex = draw_star(ex, 0) 322 | y = 0 323 | ex = ex.flatten() 324 | adjusted.append((ex, labelset[y])) 325 | return adjusted 326 | 327 | 328 | 329 | def get_stl_star(split_percentage=.8, num_train=np.float('inf'), 330 | 
num_test=np.float('inf')): 331 | SIZE = 96 332 | transform = transforms.Compose( 333 | [transforms.Resize([SIZE, SIZE]), 334 | transforms.ToTensor() 335 | ]) 336 | 337 | path = '~/datasets/' 338 | trainset = torchvision.datasets.STL10(root=path, 339 | split='train', 340 | #train=True, 341 | transform=transform, 342 | download=False) 343 | trainset = one_hot_stl_toy(trainset, num_samples=num_train) 344 | trainset, valset = split(trainset, p=split_percentage) 345 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, 346 | shuffle=True, num_workers=2) 347 | 348 | valloader = torch.utils.data.DataLoader(valset, batch_size=128, 349 | shuffle=False, num_workers=1) 350 | testset = torchvision.datasets.STL10(root=path, 351 | split='test', 352 | transform=transform, 353 | download=False) 354 | testset = one_hot_stl_toy(testset, num_samples=num_test) 355 | testloader = torch.utils.data.DataLoader(testset, batch_size=128, 356 | shuffle=False, num_workers=2) 357 | print("Num Train: ", len(trainset), "Num Val: ", len(valset), 358 | "Num Test: ", len(testset)) 359 | return trainloader, valloader, testloader 360 | --------------------------------------------------------------------------------
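As a quick illustration of the encoding these loaders produce (toy tensors, purely hypothetical; not part of the original scripts):

```python
import torch
import dataset

# Two fake 3x4x4 "images" with integer labels
toy = [(torch.randn(3, 4, 4), 1), (torch.randn(3, 4, 4), 0)]
pairs = dataset.one_hot_data(toy, num_classes=2, num_samples=2)

ex, label = pairs[0]
print(ex.shape)  # torch.Size([48]) -- inputs are flattened for the fully connected net
print(label)     # tensor([0., 1.]) -- one-hot targets, matching the MSE loss in trainer.py
```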