├── README.md ├── activation.py ├── dataset.py ├── demo.py ├── engine.py ├── environment_pytorch==1.7.1.yml ├── graphs.ipynb ├── hebbconv.py ├── hebblinear.py ├── images └── architecture.PNG ├── layer.py ├── log.py ├── model.py ├── multi_layer.py ├── nb_utils.py ├── post_hoc_loss.py ├── presets.json ├── ray_search.py ├── train.py └── utils.py /activation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from typing import Union 6 | 7 | def get_activation( 8 | activation_fn: str, 9 | param: float = 1., 10 | dim: int = 1 11 | ): 12 | """ 13 | Select the corresponding activation class 14 | Parameters 15 | ---------- 16 | activation_fn: one of 'triangle', 'relu', 'repu', 'sigmoid', 'tanh', 'exp', 'softmax', 'hard' 17 | param: activation-specific parameter: the power for 'triangle' and 'repu', 18 | beta for 'sigmoid' and 'tanh', t_invert for 'exp' and 'softmax' 19 | dim: dimension(s) along which 'softmax' is computed 20 | 21 | 22 | Returns 23 | ------- 24 | The corresponding activation module 25 | """ 26 | if activation_fn == 'triangle': 27 | return Triangle(power=param) 28 | if activation_fn == 'relu': 29 | return nn.ReLU() 30 | if activation_fn == 'repu': 31 | return RePU(power=param) 32 | if activation_fn == 'sigmoid': 33 | return Sigmoid(beta=param) 34 | if activation_fn == 'tanh': 35 | return Tanh(beta=param) 36 | if activation_fn == 'exp': 37 | return Exp(t_invert=param) 38 | if activation_fn == 'softmax': 39 | return SoftMax(t_invert=param, dim=dim) 40 | if activation_fn == 'hard': 41 | return Hard() 42 | raise ValueError('Unknown activation_fn: %s' % activation_fn) 43 | class RePU(nn.Module): 44 | r"""Applies the RePU (rectified power unit) function element-wise: relu(input) ** power 45 | """ 46 | 47 | def __init__(self, power: float, inplace: bool = False): 48 | super(RePU, self).__init__() 49 | self.power = power 50 | self.inplace = inplace 51 | 52 | def forward(self, input: Tensor) -> Tensor: 53 | return F.relu(input, inplace=self.inplace)**self.power 54 | 55 | def extra_repr(self) -> str: 56 | return 'power=%s'%self.power 57 | 58 | class Tanh(nn.Module): 59 | r"""Applies tanh(beta * input) element-wise: 60 | """ 61 | def __init__(self, beta: float): 62 | super(Tanh, self).__init__() 63 | self.beta = beta 64 | 65 | def forward(self, input: Tensor) -> Tensor: 66 | return torch.tanh(input * self.beta) 67 | 68 | def extra_repr(self) -> str: 69 | return 'beta=%s'%self.beta 70 | 71 | class Sigmoid(nn.Module): 72 | r"""Applies sigmoid(beta * input) element-wise: 73 | """ 74 | def __init__(self, beta: float): 75 | super(Sigmoid, self).__init__() 76 | self.beta = beta 77 | 78 | def forward(self, input: Tensor) -> Tensor: 79 | return torch.sigmoid(input * self.beta) 80 | 81 | def extra_repr(self) -> str: 82 | return 'beta=%s'%self.beta 83 | 84 | class Triangle(nn.Module): 85 | r"""Applies the Triangle activation element-wise: relu(input - mean(input, dim=1)) ** power 86 | """ 87 | 88 | def __init__(self, power: float=1, inplace: bool = True): 89 | super(Triangle, self).__init__() 90 | self.inplace = inplace 91 | self.power = power 92 | 93 | def forward(self, input: Tensor) -> Tensor: 94 | input = input - torch.mean(input.data, axis=1, keepdims=True) 95 | return F.relu(input, inplace=self.inplace) ** self.power 96 | 97 | def extra_repr(self) -> str: 98 | return 'power=%s'%self.power 99 | 100 | 101 | class Exp(nn.Module): 102 | r"""Applies exp(t_invert * input) element-wise: 103 | """ 104 | def __init__(self, t_invert: float): 105 | super(Exp, self).__init__() 106 | self.t_invert = t_invert 107 | def forward(self, input: Tensor) -> Tensor: 108 | return torch.exp(input * self.t_invert) 109 | 110 | def extra_repr(self) -> str: 111 | return 't_invert=%s'%self.t_invert 112 | 113 | class Hard(nn.Module): 114 | r"""Returns a one-hot encoding of the argmax along dim 1 (hard winner-take-all): 115 | """ 116 | def 
__init__(self,): 117 | super(Hard, self).__init__() 118 | def forward(self, input: Tensor) -> Tensor: 119 | return nn.functional.one_hot(input.argmax(dim=1), num_classes=input.shape[1]).to( 120 | torch.float) 121 | 122 | 123 | class SoftMax(nn.Module): 124 | r"""Applies the softmax function element-wise: 125 | """ 126 | 127 | def __init__(self, t_invert: float, dim: Union[int, tuple] = 1): 128 | super(SoftMax, self).__init__() 129 | self.t_invert = t_invert 130 | self.dim = dim 131 | 132 | def forward(self, input: Tensor) -> Tensor: 133 | if isinstance(self.dim, int): 134 | return torch.softmax(self.t_invert * input, dim=self.dim) 135 | 136 | # tuple dim: swap the two competing dimensions to positions 0 and 1, flatten them, and softmax jointly 137 | shape = list(input.shape) 138 | if self.dim[0] != 0: 139 | shape_dim_0 = shape[self.dim[0]] 140 | input = input.transpose(0, self.dim[0]) 141 | shape[self.dim[0]] = shape[0] 142 | shape[0] = shape_dim_0 143 | if self.dim[1] != 1: 144 | shape_dim_1 = shape[self.dim[1]] 145 | input = input.transpose(1, self.dim[1]) 146 | shape[self.dim[1]] = shape[1] 147 | shape[1] = shape_dim_1 148 | input = input.reshape([shape[0]*shape[1]]+shape[2:]) 149 | input = torch.softmax(self.t_invert * input, dim=0) 150 | input = input.reshape(shape) 151 | if self.dim[1] != 1: 152 | input = input.transpose(1, self.dim[1]) 153 | if self.dim[0] != 0: 154 | input = input.transpose(0, self.dim[0]) 155 | return input 156 | 157 | 158 | 159 | 160 | def extra_repr(self) -> str: 161 | return 't_invert=%s, dim=%s'%(self.t_invert, self.dim) -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import random 3 | 4 | try: 5 | from utils import seed_init_fn, DATASET 6 | except ImportError: 7 | from hebb.utils import seed_init_fn, DATASET 8 | import numpy as np 9 | import os 10 | import os.path as op 11 | import torch 12 | from torch.utils.data.sampler import Sampler, SubsetRandomSampler 13 | from torchvision import datasets, transforms 14 | import torchvision.transforms.functional as TF 15 | from torchvision.datasets import CIFAR10, CIFAR100, MNIST, FashionMNIST, STL10, ImageNet, ImageFolder 16 | from typing import Optional, Any 17 | 18 | 19 | class AddGaussianNoise(object): 20 | def __init__(self, mean=0., std=1.): 21 | self.std = std 22 | self.mean = mean 23 | 24 | def __call__(self, tensor): 25 | return tensor + torch.randn(tensor.size(), device=tensor.device) * self.std + self.mean 26 | 27 | 28 | def imagenet_tf(width, height): 29 | return transforms.Compose([ 30 | transforms.RandomResizedCrop((width, height)), 31 | transforms.RandomHorizontalFlip(), 32 | transforms.ToTensor(), 33 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 34 | std=[0.229, 0.224, 0.225]), 35 | ]) 36 | 37 | 38 | def imagenet_test(width, height): 39 | return transforms.Compose([ 40 | transforms.Resize((width, height)), 41 | transforms.ToTensor(), 42 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 43 | std=[0.229, 0.224, 0.225]), 44 | ]) 45 | 46 | 47 | def advanced_transform(width, height): 48 | return transforms.Compose([ 49 | transforms.RandomApply([transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=20 / 360)], 50 | p=0.5), 51 | transforms.RandomApply([transforms.ColorJitter(saturation=1)], p=0.5), 52 | transforms.RandomHorizontalFlip(), 53 | transforms.Pad(8), 54 | transforms.RandomApply( 55 | [transforms.Lambda(lambda x: TF.resize(x, (48 + random.randint(-6, 6), 48 + random.randint(-6, 6))))], 56 | p=0.3), 57 | 
transforms.RandomApply([transforms.RandomAffine(degrees=10, shear=10)], p=0.3), 58 | transforms.CenterCrop(40), 59 | transforms.RandomApply([transforms.RandomCrop((width, height))], p=0.5), 60 | transforms.CenterCrop((width, height)), 61 | ]) 62 | 63 | 64 | def crop_flip(width, height): 65 | return transforms.Compose( 66 | [ 67 | transforms.RandomCrop( 68 | (width, height), padding=4, padding_mode="reflect" 69 | ), 70 | transforms.RandomHorizontalFlip(p=0.5), 71 | 72 | ] 73 | ) 74 | 75 | 76 | def select_dataset(dataset_config, device, dataset_path): 77 | test_transform = None 78 | val_indices = None 79 | split = dataset_config["split"] if "split" in dataset_config else "train" 80 | if dataset_config['name'] == 'CIFAR10': 81 | dataset_class = FastCIFAR10 82 | indices = list(range(50000)) 83 | 84 | if dataset_config['augmentation']: 85 | dataset_train_class = AugFastCIFAR10 86 | dataset_config['num_workers'] = 4 87 | device = 'cpu' 88 | transform = crop_flip(dataset_config['width'], dataset_config['height']) 89 | else: 90 | dataset_train_class = FastCIFAR10 91 | transform = None 92 | 93 | elif dataset_config['name'] == 'CIFAR100': 94 | dataset_class = FastCIFAR100 95 | indices = list(range(50000)) 96 | 97 | if dataset_config['augmentation']: 98 | dataset_train_class = AugFastCIFAR100 99 | dataset_config['num_workers'] = 4 100 | device = 'cpu' 101 | transform = crop_flip(dataset_config['width'], dataset_config['height']) 102 | else: 103 | dataset_train_class = FastCIFAR100 104 | transform = None 105 | 106 | elif dataset_config['name'] == 'MNIST': 107 | dataset_class = FastMNIST 108 | indices = list(range(60000)) 109 | 110 | if dataset_config['augmentation']: 111 | dataset_train_class = AugFastMNIST 112 | dataset_config['num_workers'] = 4 113 | device = 'cpu' 114 | transform = transforms.Compose([crop_flip(dataset_config['width'], dataset_config['height']), 115 | AddGaussianNoise(std=dataset_config['noise_std'])]) 116 | else: 117 | dataset_train_class = FastMNIST 118 | transform = None 119 | 120 | elif dataset_config['name'] == 'FashionMNIST': 121 | dataset_class = FastFashionMNIST 122 | indices = list(range(60000)) 123 | 124 | if dataset_config['augmentation']: 125 | dataset_train_class = AugFastFashionMNIST 126 | dataset_config['num_workers'] = 4 127 | device = 'cpu' 128 | transform = crop_flip(dataset_config['width'], dataset_config['height']) 129 | else: 130 | dataset_train_class = FastFashionMNIST 131 | transform = None 132 | 133 | elif dataset_config['name'].startswith('ImageNette'): 134 | device = 'cpu' 135 | dataset_class = ImageNette 136 | indices = list(range(9469)) 137 | 138 | if dataset_config['px'] == 'default': 139 | dataset_path = '/home/username/.fastai/data/imagenette2' 140 | elif dataset_config['px'] == 320: 141 | dataset_path = '/home/username/.fastai/data/imagenette2-320' 142 | else: 143 | dataset_path = os.path.join(DATASET, 'imagenette2-160') # '/home/username/.fastai/data/imagenette2-160' 144 | 145 | if dataset_config['augmentation']: 146 | dataset_train_class = ImageNette 147 | dataset_config['num_workers'] = 4 148 | device = 'cpu' 149 | transform = imagenet_tf(dataset_config['width'], dataset_config['height']) 150 | else: 151 | dataset_train_class = ImageNette 152 | transform = imagenet_test(dataset_config['width'], dataset_config['height']) 153 | test_transform = imagenet_test(dataset_config['width'], dataset_config['height']) 154 | 155 | elif dataset_config['name'].startswith('ImageNetV2'): 156 | device = 'cpu' 157 | dataset_class = ImageNetV2 158 | indices = list(range(10000))
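        # each ImageNetV2 variant is a fixed 10,000-image test set (10 images per ImageNet class), hence the index range above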
159 | 160 | if dataset_config['name'][10:] == 'MatchedFrequency': 161 | dataset_path = '/scratch/hrodriguez/workspace/data/imagenetv2-matched-frequency-format-val' 162 | elif dataset_config['name'][10:] == 'Threshold07': 163 | dataset_path = '/scratch/hrodriguez/workspace/data/imagenetv2-threshold0.7-format-val' 164 | elif dataset_config['name'][10:] == 'TopImages': 165 | dataset_path = '/scratch/hrodriguez/workspace/data/imagenetv2-top-images-format-val' 166 | else: 167 | raise ValueError('Unknown ImageNetV2 variant: %s' % dataset_config['name']) 168 | 169 | if dataset_config['augmentation']: 170 | dataset_train_class = ImageNetV2 171 | dataset_config['num_workers'] = 4 172 | device = 'cpu' 173 | transform = imagenet_tf(dataset_config['width'], dataset_config['height']) 174 | else: 175 | dataset_train_class = ImageNetV2 176 | transform = imagenet_test(dataset_config['width'], dataset_config['height']) 177 | test_transform = imagenet_test(dataset_config['width'], dataset_config['height']) 178 | 179 | elif dataset_config['name'] == 'ImageNet': 180 | device = 'cpu' 181 | dataset_class = AugImageNet 182 | indices = list(range(1000000)) 183 | 184 | dataset_path = '/scratch/datasets/ilsvrc12/' 185 | 186 | if dataset_config['augmentation']: 187 | dataset_train_class = AugImageNet 188 | dataset_config['num_workers'] = 4 189 | device = 'cpu' 190 | transform = imagenet_tf(dataset_config['width'], dataset_config['height']) 191 | else: 192 | dataset_train_class = AugImageNet 193 | transform = imagenet_test(dataset_config['width'], dataset_config['height']) 194 | test_transform = imagenet_test(dataset_config['width'], dataset_config['height']) 195 | 196 | elif dataset_config['name'] == 'STL10': 197 | device = 'cpu' 198 | dataset_class = FastSTL10 199 | if split == 'train': 200 | indices = list(range(5000)) 201 | elif split == 'unlabeled': 202 | indices = list(range(100000)) 203 | else: 204 | indices = list(range(105000)) 205 | 206 | if dataset_config['augmentation']: 207 | dataset_train_class = AugFastSTL10 208 | dataset_config['num_workers'] = 4 209 | device = 'cpu' 210 | transform = crop_flip(dataset_config['width'], dataset_config['height']) 211 | else: 212 | dataset_train_class = FastSTL10 213 | transform = None 214 | else: 215 | raise ValueError('Unknown dataset: %s' % dataset_config['name']) 216 | 217 | if not isinstance(dataset_config['training_class'], str): 218 | # we have to select indices up to training_sample (the training set size), otherwise the future origin_dataset 219 | # won't have enough indices (it only stores the datapoints of the chosen training_class(es)). 220 | # Another caveat: although you give training_sample, the validation set is taken from it. 221 | # In the end: if you want to validate, do it with only the same class(es). 222 | not_all_classes_samples = dataset_config['training_sample'] 223 | if dataset_config['validation']: 224 | not_all_classes_samples += dataset_config['val_sample'] 225 | not_all_classes_indices = indices[:not_all_classes_samples] 226 | indices = copy.deepcopy(not_all_classes_indices) 227 | 228 | if dataset_config['shuffle']: 229 | np.random.shuffle(indices) 230 | train_indices = indices[:dataset_config['training_sample']] 231 | 232 | if dataset_config['validation']: 233 | val_indices = indices[:dataset_config['val_sample']] 234 | train_indices = indices[ 235 | dataset_config['val_sample']:(dataset_config['training_sample'] + dataset_config['val_sample'])] 236 | return dataset_train_class, dataset_class, test_transform, transform, device, train_indices, val_indices, split, dataset_path 237 | 238 | 239 | def make_data_loaders(dataset_config, batch_size, 
device, dataset_path=DATASET): 240 | """ 241 | Load the configured dataset and create training and testing dataloaders 242 | 243 | Parameters 244 | ---------- 245 | dataset_config : dict 246 | Configuration of the expected dataset 247 | batch_size: int, batch size of the training dataloader 248 | dataset_path : str 249 | Path to the dataset folder. 250 | 251 | Returns 252 | ------- 253 | train_loader : torch.utils.data.DataLoader 254 | Training dataloader. 255 | test_loader : torch.utils.data.DataLoader 256 | Testing dataloader. 257 | 258 | """ 259 | g = torch.Generator() 260 | if dataset_config['seed'] is not None: 261 | seed_init_fn(dataset_config['seed']) 262 | g.manual_seed(dataset_config['seed'] % 2 ** 32) 263 | 264 | dataset_train_class, dataset_class, test_transform, transform, device, train_indices, val_indices, split, dataset_path = select_dataset( 265 | dataset_config, device, dataset_path) 266 | 267 | train_sampler = SubsetRandomSampler(train_indices, generator=g) 268 | origin_dataset = dataset_train_class( 269 | dataset_path, 270 | split=split, 271 | train=True, 272 | download=dataset_config['name'] not in ['ImageNet'], # TODO: make this depend on whether dataset exists or not 273 | transform=transform, 274 | zca=dataset_config['zca_whitened'], 275 | device=device, 276 | train_class=dataset_config['training_class'] 277 | ) 278 | train_loader = torch.utils.data.DataLoader(dataset=origin_dataset, 279 | batch_size=batch_size, 280 | num_workers=dataset_config['num_workers'], 281 | sampler=train_sampler) 282 | 283 | if val_indices is not None: 284 | val_sampler = SubsetRandomSampler(val_indices) 285 | test_loader = torch.utils.data.DataLoader(dataset=origin_dataset, 286 | batch_size=batch_size, 287 | num_workers=dataset_config['num_workers'], 288 | sampler=val_sampler) 289 | else: 290 | test_loader = torch.utils.data.DataLoader( 291 | dataset_class( 292 | dataset_path, 293 | split="val" if dataset_config['name'] in ['ImageNet', 'ImageNette', 294 | 'ImageNetV2MatchedFrequency'] else "test", 295 | train=False, 296 | zca=dataset_config['zca_whitened'], 297 | transform=test_transform, 298 | device=device 299 | ), 300 | batch_size=batch_size if dataset_config['name'] in ['STL10', 'ImageNet', 'ImageNette', 301 | 'ImageNetV2MatchedFrequency', 'ImageNetV2TopImages', 302 | 'ImageNetV2Threshold07'] else 1000, 303 | num_workers=dataset_config['num_workers'], 304 | shuffle=dataset_config['shuffle'], 305 | ) 306 | return train_loader, test_loader 307 | 308 | 309 | def whitening_zca(x: torch.Tensor, transpose=True, dataset: str = "CIFAR10"): 310 | path = op.join(DATASET, dataset + "_zca.pt") 311 | zca = None 312 | try: 313 | zca = torch.load(path, map_location='cpu')['zca'] 314 | except (OSError, KeyError): 315 | pass 316 | 317 | if zca is None: 318 | x = x.cpu().numpy() # the covariance/SVD computation below works in numpy 319 | if transpose: 320 | x = x.copy().transpose(0, 3, 1, 2) 321 | 322 | x = x.copy().reshape(x.shape[0], -1) 323 | 324 | cov = np.cov(x, rowvar=False) 325 | 326 | u, s, v = np.linalg.svd(cov) 327 | 328 | SMOOTHING_CONST = 1e-1 329 | zca = np.dot(u, np.dot(np.diag(1.0 / np.sqrt(s + SMOOTHING_CONST)), u.T)) 330 | zca = torch.from_numpy(zca).float() 331 | 332 | os.makedirs(os.path.dirname(path), exist_ok=True) 333 | torch.save({'zca': zca}, path) 334 | 335 | return zca 336 | 337 | 338 | # *************************************************** Imagenet-10 *************************************************** 339 | 340 | class AugImageNet(ImageNet): 341 | def __init__(self, *args, **kwargs): 342 | device = kwargs.pop('device', "cpu") 343 | zca = kwargs.pop('zca', False) 344 | train_class = kwargs.pop('train_class', 'all') 345 | train = 
kwargs.pop('train', True) 346 | super().__init__(*args, **kwargs) 347 | 348 | 349 | class ImageNette(ImageFolder): 350 | 351 | def __init__(self, root: str, split: str = 'train', download: Optional[bool] = None, **kwargs: Any) -> None: 352 | root = self.root = os.path.expanduser(root) 353 | device = kwargs.pop('device', "cpu") 354 | zca = kwargs.pop('zca', False) 355 | train_class = kwargs.pop('train_class', 'all') 356 | train = kwargs.pop('train', True) 357 | assert split in ['val', 'train'] 358 | self.split = split 359 | 360 | super(ImageNette, self).__init__(self.split_folder, **kwargs) 361 | 362 | @property 363 | def split_folder(self) -> str: 364 | return os.path.join(self.root, self.split) 365 | 366 | def extra_repr(self) -> str: 367 | return "Split: {split}".format(**self.__dict__) 368 | 369 | 370 | class ImageNetV2(ImageFolder): 371 | 372 | def __init__(self, root: str, split: str = 'train', download: Optional[bool] = None, **kwargs: Any) -> None: 373 | root = self.root = os.path.expanduser(root) 374 | device = kwargs.pop('device', "cpu") 375 | zca = kwargs.pop('zca', False) 376 | train_class = kwargs.pop('train_class', 'all') 377 | train = kwargs.pop('train', True) 378 | assert split in ['test', 379 | 'val'] # although it's called 'val', it is really a test set; we don't use it for model development 380 | self.split = split 381 | 382 | super(ImageNetV2, self).__init__(self.split_folder, **kwargs) 383 | 384 | @property 385 | def split_folder(self) -> str: 386 | return self.root 387 | 388 | def extra_repr(self) -> str: 389 | return "Split: {split}".format(**self.__dict__) 390 | 391 | 392 | # *************************************************** STL-10 *************************************************** 393 | 394 | class FastSTL10(STL10): 395 | """ 396 | Improves performance of training on STL10 by removing the PIL interface and pre-loading on the GPU (2-3x speedup). 
397 | 398 | Taken from https://github.com/y0ast/pytorch-snippets/tree/main/fast_mnist 399 | """ 400 | 401 | def __init__(self, *args, **kwargs): 402 | device = kwargs.pop('device', "cpu") 403 | zca = kwargs.pop('zca', False) 404 | train_class = kwargs.pop('train_class', 'all') 405 | train = kwargs.pop('train', True) 406 | super().__init__(*args, **kwargs) 407 | 408 | mean = (0.4914, 0.48216, 0.44653) 409 | std = (0.247, 0.2434, 0.2616) 410 | 411 | # norm = transforms.Normalize(mean, std) 412 | 413 | self.data = torch.tensor(self.data, dtype=torch.float, device=device).div_(255) 414 | 415 | if train: 416 | if not isinstance(train_class, str): 417 | index_class = np.isin(self.labels, train_class) 418 | self.data = self.data[index_class] 419 | self.labels = np.array(self.labels)[index_class] 420 | self.len = self.data.shape[0] 421 | 422 | if zca: 423 | self.data = (self.data - torch.tensor(mean, device=self.data.device).view(3, 1, 1)) / torch.tensor(std, device=self.data.device).view(3, 1, 1) 424 | self.zca = whitening_zca(self.data, transpose=False, dataset="STL10") 425 | zca_whitening = transforms.LinearTransformation(self.zca, torch.zeros(self.zca.size(1))) 426 | self.data = self.data.float() 427 | 428 | # self.data = torch.movedim(self.data, -1, 1) # -> set dim to: (batch, channels, height, width) 429 | # self.data = norm(self.data) 430 | if zca: 431 | self.data = zca_whitening(self.data) 432 | print("self.data.mean(), self.data.std()", self.data.mean(), self.data.std()) 433 | 434 | # self.data = self.data.to(device) # Rescale to [0, 1] 435 | 436 | # self.data = self.data.div_(CIFAR10_STD) #(NOT) Normalize to 0 centered with 1 std 437 | 438 | self.labels = torch.tensor(self.labels, device=device) 439 | 440 | def __getitem__(self, index: int): 441 | """ 442 | Parameters 443 | ---------- 444 | index : int 445 | Index of the element to be returned 446 | 447 | Returns 448 | ------- 449 | tuple: (image, target) where target is the index of the target class 450 | """ 451 | if self.labels is not None: 452 | img, target = self.data[index], int(self.labels[index]) 453 | else: 454 | img, target = self.data[index], None 455 | 456 | return img, target 457 | 458 | 459 | class AugFastSTL10(FastSTL10): 460 | """ 461 | Improves performance of training on STL10 by removing the PIL interface and pre-loading on the GPU (2-3x speedup). 462 | 463 | Taken from https://github.com/y0ast/pytorch-snippets/tree/main/fast_mnist 464 | """ 465 | 466 | def __getitem__(self, index: int): 467 | """ 468 | Parameters 469 | ---------- 470 | index : int 471 | Index of the element to be returned 472 | 473 | Returns 474 | ------- 475 | tuple: (image, target) where target is the index of the target class 476 | """ 477 | 478 | if self.labels is not None: 479 | img, target = self.data[index], int(self.labels[index]) 480 | else: 481 | img, target = self.data[index], None 482 | 483 | if self.transform is not None: 484 | img = self.transform(img) 485 | 486 | return img, target 487 | 488 | 489 | # *************************************************** CIFAR-10 *************************************************** 490 | 491 | class FastCIFAR10(CIFAR10): 492 | """ 493 | Improves performance of training on CIFAR10 by removing the PIL interface and pre-loading on the GPU (2-3x speedup). 
494 | 495 | Taken from https://github.com/y0ast/pytorch-snippets/tree/main/fast_mnist 496 | """ 497 | 498 | def __init__(self, *args, **kwargs): 499 | device = kwargs.pop('device', "cpu") 500 | zca = kwargs.pop('zca', False) 501 | train_class = kwargs.pop('train_class', 'all') 502 | split = kwargs.pop('split', 'train') 503 | super().__init__(*args, **kwargs) 504 | 505 | self.split = split 506 | 507 | mean = (0.4914, 0.48216, 0.44653) 508 | std = (0.247, 0.2434, 0.2616) 509 | 510 | norm = transforms.Normalize(mean, std) 511 | 512 | self.data = torch.tensor(self.data, dtype=torch.float, device=device).div_(255) 513 | 514 | if self.train: 515 | if not isinstance(train_class, str): 516 | index_class = np.isin(self.targets, train_class) 517 | self.data = self.data[index_class] 518 | self.targets = np.array(self.targets)[index_class] 519 | self.len = self.data.shape[0] 520 | 521 | if zca: 522 | self.data = (self.data - torch.tensor(mean, device=self.data.device)) / torch.tensor(std, device=self.data.device) 523 | self.zca = whitening_zca(self.data) 524 | zca_whitening = transforms.LinearTransformation(self.zca, torch.zeros(self.zca.size(1))) 525 | self.data = self.data.float() 526 | 527 | self.data = torch.movedim(self.data, -1, 1) # -> set dim to: (batch, channels, height, width) 528 | # self.data = norm(self.data) 529 | if zca: 530 | self.data = zca_whitening(self.data) 531 | print("self.data.mean(), self.data.std()", self.data.mean(), self.data.std()) 532 | 533 | # self.data = self.data.to(device) # Rescale to [0, 1] 534 | 535 | # self.data = self.data.div_(CIFAR10_STD) #(NOT) Normalize to 0 centered with 1 std 536 | 537 | self.targets = torch.tensor(self.targets, device=device) 538 | 539 | def __getitem__(self, index: int): 540 | """ 541 | Parameters 542 | ---------- 543 | index : int 544 | Index of the element to be returned 545 | 546 | Returns 547 | ------- 548 | tuple: (image, target) where target is the index of the target class 549 | """ 550 | img = self.data[index] 551 | target = self.targets[index] 552 | 553 | return img, target 554 | 555 | 556 | class AugFastCIFAR10(FastCIFAR10): 557 | """ 558 | Improves performance of training on CIFAR10 by removing the PIL interface and pre-loading on the GPU (2-3x speedup). 559 | 560 | Taken from https://github.com/y0ast/pytorch-snippets/tree/main/fast_mnist 561 | """ 562 | 563 | def __getitem__(self, index: int): 564 | """ 565 | Parameters 566 | ---------- 567 | index : int 568 | Index of the element to be returned 569 | 570 | Returns 571 | ------- 572 | tuple: (image, target) where target is the index of the target class 573 | """ 574 | img = self.transform(self.data[index]) 575 | target = self.targets[index] 576 | 577 | return img, target 578 | 579 | 580 | class FastCIFAR100(CIFAR100): 581 | """ 582 | Improves performance of training on CIFAR100 by removing the PIL interface and pre-loading on the GPU (2-3x speedup). 
583 | 584 | Taken from https://github.com/y0ast/pytorch-snippets/tree/main/fast_mnist 585 | """ 586 | 587 | def __init__(self, *args, **kwargs): 588 | device = kwargs.pop('device', "cpu") 589 | zca = kwargs.pop('zca', False) 590 | train_class = kwargs.pop('train_class', 'all') 591 | split = kwargs.pop('split', 'train') 592 | super().__init__(*args, **kwargs) 593 | 594 | self.split = split 595 | 596 | mean = (0.4914, 0.48216, 0.44653) 597 | std = (0.247, 0.2434, 0.2616) 598 | 599 | norm = transforms.Normalize(mean, std) 600 | 601 | self.data = torch.tensor(self.data, dtype=torch.float, device=device).div_(255) 602 | 603 | if self.train: 604 | if not isinstance(train_class, str): 605 | index_class = np.isin(self.targets, train_class) 606 | self.data = self.data[index_class] 607 | self.targets = np.array(self.targets)[index_class] 608 | self.len = self.data.shape[0] 609 | print(self.len) 610 | 611 | if zca: 612 | self.data = (self.data - torch.tensor(mean, device=self.data.device)) / torch.tensor(std, device=self.data.device) 613 | self.zca = whitening_zca(self.data, dataset="CIFAR100") 614 | zca_whitening = transforms.LinearTransformation(self.zca, torch.zeros(self.zca.size(1))) 615 | self.data = self.data.float() 616 | 617 | self.data = torch.movedim(self.data, -1, 1) # -> set dim to: (batch, channels, height, width) 618 | # self.data = norm(self.data) 619 | if zca: 620 | self.data = zca_whitening(self.data) 621 | print(self.data.mean(), self.data.std()) 622 | 623 | # self.data = self.data.to(device) # Rescale to [0, 1] 624 | 625 | # self.data = self.data.div_(CIFAR10_STD) #(NOT) Normalize to 0 centered with 1 std 626 | 627 | self.targets = torch.tensor(self.targets, device=device) 628 | 629 | def __getitem__(self, index: int): 630 | """ 631 | Parameters 632 | ---------- 633 | index : int 634 | Index of the element to be returned 635 | 636 | Returns 637 | ------- 638 | tuple: (image, target) where target is the index of the target class 639 | """ 640 | img = self.data[index] 641 | target = self.targets[index] 642 | 643 | return img, target 644 | 645 | 646 | class AugFastCIFAR100(FastCIFAR100): 647 | """ 648 | Improves performance of training on CIFAR100 by removing the PIL interface and pre-loading on the GPU (2-3x speedup). 
649 | 650 | Taken from https://github.com/y0ast/pytorch-snippets/tree/main/fast_mnist 651 | """ 652 | 653 | def __getitem__(self, index: int): 654 | """ 655 | Parameters 656 | ---------- 657 | index : int 658 | Index of the element to be returned 659 | 660 | Returns 661 | ------- 662 | tuple: (image, target) where target is the index of the target class 663 | """ 664 | img = self.transform(self.data[index]) 665 | target = self.targets[index] 666 | 667 | return img, target 668 | 669 | 670 | # *************************************************** MNIST *************************************************** 671 | 672 | class FastMNIST(MNIST): 673 | def __init__(self, *args, **kwargs): 674 | device = kwargs.pop('device', "cpu") 675 | zca = kwargs.pop('zca', False) 676 | train_class = kwargs.pop('train_class', 'all') 677 | split = kwargs.pop('split', 'train') 678 | super().__init__(*args, **kwargs) 679 | 680 | self.split = split 681 | 682 | if self.train: 683 | if not isinstance(train_class, str): 684 | print(train_class) 685 | self.targets = np.array(self.targets) 686 | index_class = np.isin(self.targets, train_class) 687 | self.data = self.data[index_class] 688 | self.targets = self.targets[index_class] 689 | self.len = self.data.shape[0] 690 | 691 | # Scale data to [0,1] 692 | self.data = torch.tensor(self.data, dtype=torch.float, device=device).div_(255).unsqueeze(1) 693 | 694 | self.targets = torch.tensor(self.targets, device=device) 695 | 696 | # Normalize it with the usual MNIST mean and std 697 | # self.data = self.data.sub_(0.1307).div_(0.3081) 698 | 699 | # Put both data and targets on GPU in advance 700 | 701 | def __getitem__(self, index): 702 | """ 703 | Args: 704 | index (int): Index 705 | 706 | Returns: 707 | tuple: (image, target) where target is index of the target class. 708 | """ 709 | img, target = self.data[index], self.targets[index] 710 | 711 | return img, target 712 | 713 | 714 | class AugFastMNIST(FastMNIST): 715 | """ 716 | Improves performance of training on MNIST by removing the PIL interface and pre-loading on the GPU (2-3x speedup). 
717 | 718 | Taken from https://github.com/y0ast/pytorch-snippets/tree/main/fast_mnist 719 | """ 720 | 721 | def __getitem__(self, index: int): 722 | """ 723 | Parameters 724 | ---------- 725 | index : int 726 | Index of the element to be returned 727 | 728 | Returns 729 | ------- 730 | tuple: (image, target) where target is the index of the target class 731 | """ 732 | img = self.transform(self.data[index]) 733 | target = self.targets[index] 734 | 735 | return img, target 736 | 737 | 738 | # *************************************************** FashionMNIST *************************************************** 739 | 740 | class FastFashionMNIST(FashionMNIST): 741 | def __init__(self, *args, **kwargs): 742 | device = kwargs.pop('device', "cpu") 743 | zca = kwargs.pop('zca', False) 744 | train_class = kwargs.pop('train_class', 'all') 745 | split = kwargs.pop('split', 'train') 746 | super().__init__(*args, **kwargs) 747 | self.split = split 748 | if self.train: 749 | if not isinstance(train_class, str): 750 | print(train_class) 751 | self.targets = np.array(self.targets) 752 | index_class = np.isin(self.targets, train_class) 753 | self.data = self.data[index_class] 754 | self.targets = self.targets[index_class] 755 | self.len = self.data.shape[0] 756 | 757 | # Scale data to [0,1] 758 | self.data = torch.tensor(self.data, dtype=torch.float, device=device).div_(255).unsqueeze(1) 759 | 760 | self.targets = torch.tensor(self.targets, device=device) 761 | 762 | # Normalize it with the usual MNIST mean and std 763 | # self.data = self.data.sub_(0.1307).div_(0.3081) 764 | 765 | # Put both data and targets on GPU in advance 766 | 767 | def __getitem__(self, index): 768 | """ 769 | Args: 770 | index (int): Index 771 | 772 | Returns: 773 | tuple: (image, target) where target is index of the target class. 774 | """ 775 | img, target = self.data[index], self.targets[index] 776 | 777 | return img, target 778 | 779 | 780 | class AugFastFashionMNIST(FastFashionMNIST): 781 | """ 782 | Improves performance of training on FashionMNIST by removing the PIL interface and pre-loading on the GPU (2-3x speedup). 
783 | 784 | Taken from https://github.com/y0ast/pytorch-snippets/tree/main/fast_mnist 785 | """ 786 | 787 | def __getitem__(self, index: int): 788 | """ 789 | Parameters 790 | ---------- 791 | index : int 792 | Index of the element to be returned 793 | 794 | Returns 795 | ------- 796 | tuple: (image, target) where target is the index of the target class 797 | """ 798 | img = self.transform(self.data[index]) 799 | target = self.targets[index] 800 | 801 | return img, target 802 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | """ 2 | Demo single-file script to train a ConvNet on CIFAR10 using SoftHebb, an unsupervised, efficient and bio-plausible 3 | learning algorithm 4 | """ 5 | import math 6 | import warnings 7 | 8 | import torch 9 | from torch import nn, optim 10 | import torch.nn.functional as F 11 | from torch.nn.modules.utils import _pair 12 | from torch.optim.lr_scheduler import StepLR 13 | import torchvision 14 | 15 | 16 | class SoftHebbConv2d(nn.Module): 17 | def __init__( 18 | self, 19 | in_channels: int, 20 | out_channels: int, 21 | kernel_size: int, 22 | stride: int = 1, 23 | padding: int = 0, 24 | dilation: int = 1, 25 | groups: int = 1, 26 | t_invert: float = 12, 27 | ) -> None: 28 | """ 29 | Simplified implementation of Conv2d learnt with SoftHebb; an unsupervised, efficient and bio-plausible 30 | learning algorithm. 31 | This simplified implementation omits certain configurable aspects, like using a bias, groups>1, etc. which can 32 | be found in the full implementation in hebbconv.py 33 | """ 34 | super(SoftHebbConv2d, self).__init__() 35 | assert groups == 1, "Simple implementation does not support groups > 1." 36 | self.in_channels = in_channels 37 | self.out_channels = out_channels 38 | self.kernel_size = _pair(kernel_size) 39 | self.stride = _pair(stride) 40 | self.dilation = _pair(dilation) 41 | self.groups = groups 42 | self.padding_mode = 'reflect' 43 | self.F_padding = (padding, padding, padding, padding) 44 | weight_range = 25 / math.sqrt((in_channels / groups) * kernel_size * kernel_size) 45 | self.weight = nn.Parameter(weight_range * torch.randn((out_channels, in_channels // groups, *self.kernel_size))) 46 | self.t_invert = torch.tensor(t_invert) 47 | 48 | def forward(self, x): 49 | x = F.pad(x, self.F_padding, self.padding_mode) # pad input 50 | # perform conv, obtain weighted input u \in [B, OC, OH, OW] 51 | weighted_input = F.conv2d(x, self.weight, None, self.stride, 0, self.dilation, self.groups) 52 | 53 | if self.training: 54 | # ===== find post-synaptic activations y = s(u)*softmax(u, dim=C), s(u) = 2*I[u==max(u,dim=C)] - 1 ===== 55 | # Post-synaptic activation, for plastic update, is weighted input passed through a softmax. 56 | # Non-winning neurons (those not with the highest activation) receive the negated post-synaptic activation. 
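            # Illustration with hypothetical numbers (t_invert = 1): for one output pixel with pre-activations
            # u = [2.0, 0.5, 0.1] over 3 channels, softmax(u) ≈ [0.73, 0.16, 0.11], so the plastic activations
            # become y ≈ [+0.73, -0.16, -0.11]: the winner is pushed in the Hebbian direction, the losers in the
            # anti-Hebbian direction.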
57 | batch_size, out_channels, height_out, width_out = weighted_input.shape 58 | # Flatten non-competing dimensions (B, OC, OH, OW) -> (OC, B*OH*OW) 59 | flat_weighted_inputs = weighted_input.transpose(0, 1).reshape(out_channels, -1) 60 | # Compute the winner neuron for each batch element and pixel 61 | flat_softwta_activs = torch.softmax(self.t_invert * flat_weighted_inputs, dim=0) 62 | flat_softwta_activs = - flat_softwta_activs # Turn all postsynaptic activations into anti-Hebbian 63 | win_neurons = torch.argmax(flat_weighted_inputs, dim=0) # winning neuron for each pixel in each input 64 | competing_idx = torch.arange(flat_weighted_inputs.size(1)) # indices of all pixel-input elements 65 | # Turn winner neurons' activations back to hebbian 66 | flat_softwta_activs[win_neurons, competing_idx] = - flat_softwta_activs[win_neurons, competing_idx] 67 | softwta_activs = flat_softwta_activs.view(out_channels, batch_size, height_out, width_out).transpose(0, 1) 68 | # ===== compute plastic update Δw = y*(x - u*w) = y*x - (y*u)*w ======================================= 69 | # Use Convolutions to apply the plastic update. Sweep over inputs with postsynaptic activations. 70 | # Each weighting of an input pixel & an activation pixel updates the kernel element that connected them in 71 | # the forward pass. 72 | yx = F.conv2d( 73 | x.transpose(0, 1), # (B, IC, IH, IW) -> (IC, B, IH, IW) 74 | softwta_activs.transpose(0, 1), # (B, OC, OH, OW) -> (OC, B, OH, OW) 75 | padding=0, 76 | stride=self.dilation, 77 | dilation=self.stride, 78 | groups=1 79 | ).transpose(0, 1) # (IC, OC, KH, KW) -> (OC, IC, KH, KW) 80 | 81 | # sum over batch, output pixels: each kernel element will influence all batches and output pixels. 82 | yu = torch.sum(torch.mul(softwta_activs, weighted_input), dim=(0, 2, 3)) 83 | delta_weight = yx - yu.view(-1, 1, 1, 1) * self.weight 84 | delta_weight.div_(torch.abs(delta_weight).amax() + 1e-30) # scale the update so its largest absolute entry is 1 85 | self.weight.grad = delta_weight # store in grad to be used with common optimizers 86 | 87 | return weighted_input 88 | 89 | 90 | class DeepSoftHebb(nn.Module): 91 | def __init__(self): 92 | super(DeepSoftHebb, self).__init__() 93 | # block 1 94 | self.bn1 = nn.BatchNorm2d(3, affine=False) 95 | self.conv1 = SoftHebbConv2d(in_channels=3, out_channels=96, kernel_size=5, padding=2, t_invert=1,) 96 | self.activ1 = Triangle(power=0.7) 97 | self.pool1 = nn.MaxPool2d(kernel_size=4, stride=2, padding=1) 98 | # block 2 99 | self.bn2 = nn.BatchNorm2d(96, affine=False) 100 | self.conv2 = SoftHebbConv2d(in_channels=96, out_channels=384, kernel_size=3, padding=1, t_invert=0.65,) 101 | self.activ2 = Triangle(power=1.4) 102 | self.pool2 = nn.MaxPool2d(kernel_size=4, stride=2, padding=1) 103 | # block 3 104 | self.bn3 = nn.BatchNorm2d(384, affine=False) 105 | self.conv3 = SoftHebbConv2d(in_channels=384, out_channels=1536, kernel_size=3, padding=1, t_invert=0.25,) 106 | self.activ3 = Triangle(power=1.) 
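        # For 32x32 CIFAR-10 inputs, each of the three stride-2 pooling stages (pool1, pool2, and pool3 below)
        # halves the spatial resolution (32 -> 16 -> 8 -> 4), so block 3 flattens to 1536 * 4 * 4 = 24576
        # features, matching the classifier's input size.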
107 | self.pool3 = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) 108 | # block 4 109 | self.flatten = nn.Flatten() 110 | self.classifier = nn.Linear(24576, 10) 111 | self.classifier.weight.data = 0.11048543456039805 * torch.rand(10, 24576) 112 | self.dropout = nn.Dropout(0.5) 113 | 114 | def forward(self, x): 115 | # block 1 116 | out = self.pool1(self.activ1(self.conv1(self.bn1(x)))) 117 | # block 2 118 | out = self.pool2(self.activ2(self.conv2(self.bn2(out)))) 119 | # block 3 120 | out = self.pool3(self.activ3(self.conv3(self.bn3(out)))) 121 | # block 4 122 | return self.classifier(self.dropout(self.flatten(out))) 123 | 124 | 125 | class Triangle(nn.Module): 126 | def __init__(self, power: float = 1, inplace: bool = True): 127 | super(Triangle, self).__init__() 128 | self.inplace = inplace 129 | self.power = power 130 | 131 | def forward(self, input: torch.Tensor) -> torch.Tensor: 132 | input = input - torch.mean(input.data, axis=1, keepdims=True) 133 | return F.relu(input, inplace=self.inplace) ** self.power 134 | 135 | 136 | class WeightNormDependentLR(optim.lr_scheduler._LRScheduler): 137 | """ 138 | Custom Learning Rate Scheduler for unsupervised training of SoftHebb Convolutional blocks. 139 | Difference between current neuron norm and theoretical converged norm (=1) scales the initial lr. 140 | """ 141 | 142 | def __init__(self, optimizer, power_lr, last_epoch=-1, verbose=False): 143 | self.optimizer = optimizer 144 | self.initial_lr_groups = [group['lr'] for group in self.optimizer.param_groups] # store initial lrs 145 | self.power_lr = power_lr 146 | super().__init__(optimizer, last_epoch, verbose) 147 | 148 | def get_lr(self): 149 | if not self._get_lr_called_within_step: 150 | warnings.warn("To get the last learning rate computed by the scheduler, " 151 | "please use `get_last_lr()`.", UserWarning) 152 | new_lr = [] 153 | for i, group in enumerate(self.optimizer.param_groups): 154 | for param in group['params']: 155 | # difference between current neuron norm and theoretical converged norm (=1) scales the initial lr 156 | # initial_lr * |neuron_norm - 1| ** 0.5 157 | norm_diff = torch.abs(torch.linalg.norm(param.view(param.shape[0], -1), dim=1, ord=2) - 1) + 1e-10 158 | new_lr.append(self.initial_lr_groups[i] * (norm_diff ** self.power_lr)[:, None, None, None]) 159 | return new_lr 160 | 161 | 162 | class TensorLRSGD(optim.SGD): 163 | @torch.no_grad() 164 | def step(self, closure=None): 165 | """Performs a single optimization step, using a non-scalar (tensor) learning rate. 166 | 167 | Arguments: 168 | closure (callable, optional): A closure that reevaluates the model 169 | and returns the loss. 
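
        Note: group['lr'] may itself be a tensor, e.g. the per-neuron rates of shape
        (out_channels, 1, 1, 1) produced by WeightNormDependentLR, which broadcast
        elementwise against the gradient in the update p.add_(-group['lr'] * d_p).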
170 | """ 171 | loss = None 172 | if closure is not None: 173 | with torch.enable_grad(): 174 | loss = closure() 175 | 176 | for group in self.param_groups: 177 | weight_decay = group['weight_decay'] 178 | momentum = group['momentum'] 179 | dampening = group['dampening'] 180 | nesterov = group['nesterov'] 181 | 182 | for p in group['params']: 183 | if p.grad is None: 184 | continue 185 | d_p = p.grad 186 | if weight_decay != 0: 187 | d_p = d_p.add(p, alpha=weight_decay) 188 | if momentum != 0: 189 | param_state = self.state[p] 190 | if 'momentum_buffer' not in param_state: 191 | buf = param_state['momentum_buffer'] = torch.clone(d_p).detach() 192 | else: 193 | buf = param_state['momentum_buffer'] 194 | buf.mul_(momentum).add_(d_p, alpha=1 - dampening) 195 | if nesterov: 196 | d_p = d_p.add(buf, alpha=momentum) 197 | else: 198 | d_p = buf 199 | 200 | p.add_(-group['lr'] * d_p) 201 | return loss 202 | 203 | 204 | class CustomStepLR(StepLR): 205 | """ 206 | Custom Learning Rate schedule with step functions for supervised training of linear readout (classifier) 207 | """ 208 | 209 | def __init__(self, optimizer, nb_epochs): 210 | threshold_ratios = [0.2, 0.35, 0.5, 0.6, 0.7, 0.8, 0.9] 211 | self.step_threshold = [int(nb_epochs * r) for r in threshold_ratios] 212 | super().__init__(optimizer, -1, False) 213 | 214 | def get_lr(self): 215 | if self.last_epoch in self.step_threshold: 216 | return [group['lr'] * 0.5 217 | for group in self.optimizer.param_groups] 218 | return [group['lr'] for group in self.optimizer.param_groups] 219 | 220 | 221 | class FastCIFAR10(torchvision.datasets.CIFAR10): 222 | """ 223 | Improves performance of training on CIFAR10 by removing the PIL interface and pre-loading on the GPU (2-3x speedup). 224 | 225 | Taken from https://github.com/y0ast/pytorch-snippets/tree/main/fast_mnist 226 | """ 227 | 228 | def __init__(self, *args, **kwargs): 229 | device = kwargs.pop('device', "cpu") 230 | super().__init__(*args, **kwargs) 231 | 232 | self.data = torch.tensor(self.data, dtype=torch.float, device=device).div_(255) 233 | self.data = torch.movedim(self.data, -1, 1) # -> set dim to: (batch, channels, height, width) 234 | self.targets = torch.tensor(self.targets, device=device) 235 | 236 | def __getitem__(self, index: int): 237 | """ 238 | Parameters 239 | ---------- 240 | index : int 241 | Index of the element to be returned 242 | 243 | Returns 244 | ------- 245 | tuple: (image, target) where target is the index of the target class 246 | """ 247 | img = self.data[index] 248 | target = self.targets[index] 249 | 250 | return img, target 251 | 252 | 253 | # Main training loop CIFAR10 254 | if __name__ == "__main__": 255 | device = torch.device('cuda:0') 256 | model = DeepSoftHebb() 257 | model.to(device) 258 | 259 | unsup_optimizer = TensorLRSGD([ 260 | {"params": model.conv1.parameters(), "lr": -0.08, }, # SGD does descent, so set lr to negative 261 | {"params": model.conv2.parameters(), "lr": -0.005, }, 262 | {"params": model.conv3.parameters(), "lr": -0.01, }, 263 | ], lr=0) 264 | unsup_lr_scheduler = WeightNormDependentLR(unsup_optimizer, power_lr=0.5) 265 | 266 | sup_optimizer = optim.Adam(model.classifier.parameters(), lr=0.001) 267 | sup_lr_scheduler = CustomStepLR(sup_optimizer, nb_epochs=50) 268 | criterion = nn.CrossEntropyLoss() 269 | 270 | trainset = FastCIFAR10('./data', train=True, download=True) 271 | unsup_trainloader = torch.utils.data.DataLoader(trainset, batch_size=10, shuffle=True, ) 272 | sup_trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, 
shuffle=True, ) 273 | 274 | testset = FastCIFAR10('./data', train=False) 275 | testloader = torch.utils.data.DataLoader(testset, batch_size=1000, shuffle=False) 276 | 277 | # Unsupervised training with SoftHebb 278 | running_loss = 0.0 279 | for i, data in enumerate(unsup_trainloader, 0): 280 | inputs, _ = data 281 | inputs = inputs.to(device) 282 | 283 | # zero the parameter gradients 284 | unsup_optimizer.zero_grad() 285 | 286 | # forward + update computation 287 | with torch.no_grad(): 288 | outputs = model(inputs) 289 | 290 | # optimize 291 | unsup_optimizer.step() 292 | unsup_lr_scheduler.step() 293 | 294 | # Supervised training of classifier 295 | # set requires grad false and eval mode for all modules but classifier 296 | unsup_optimizer.zero_grad() 297 | model.conv1.requires_grad_(False) 298 | model.conv2.requires_grad_(False) 299 | model.conv3.requires_grad_(False) 300 | model.conv1.eval() 301 | model.conv2.eval() 302 | model.conv3.eval() 303 | model.bn1.eval() 304 | model.bn2.eval() 305 | model.bn3.eval() 306 | for epoch in range(50): 307 | model.classifier.train() 308 | model.dropout.train() 309 | running_loss = 0.0 310 | correct = 0 311 | total = 0 312 | for i, data in enumerate(sup_trainloader, 0): 313 | inputs, labels = data 314 | inputs = inputs.to(device) 315 | labels = labels.to(device) 316 | 317 | # zero the parameter gradients 318 | sup_optimizer.zero_grad() 319 | 320 | # forward + backward + optimize 321 | outputs = model(inputs) 322 | loss = criterion(outputs, labels) 323 | loss.backward() 324 | sup_optimizer.step() 325 | 326 | # compute training statistics 327 | running_loss += loss.item() 328 | if epoch % 10 == 0 or epoch == 49: 329 | total += labels.size(0) 330 | _, predicted = torch.max(outputs.data, 1) 331 | correct += (predicted == labels).sum().item() 332 | sup_lr_scheduler.step() 333 | # Evaluation on test set 334 | if epoch % 10 == 0 or epoch == 49: 335 | print(f'Accuracy of the network on the train images: {100 * correct // total} %') 336 | print(f'[{epoch + 1}] loss: {running_loss / total:.3f}') 337 | 338 | # on the test set 339 | model.eval() 340 | running_loss = 0. 341 | correct = 0 342 | total = 0 343 | # since we're not training, we don't need to calculate the gradients for our outputs 344 | with torch.no_grad(): 345 | for data in testloader: 346 | images, labels = data 347 | images = images.to(device) 348 | labels = labels.to(device) 349 | # calculate outputs by running images through the network 350 | outputs = model(images) 351 | # the class with the highest energy is what we choose as prediction 352 | _, predicted = torch.max(outputs.data, 1) 353 | total += labels.size(0) 354 | correct += (predicted == labels).sum().item() 355 | loss = criterion(outputs, labels) 356 | running_loss += loss.item() 357 | 358 | print(f'Accuracy of the network on the 10000 test images: {100 * correct / total} %') 359 | print(f'test loss: {running_loss / total:.3f}') 360 | -------------------------------------------------------------------------------- /engine.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import time 4 | 5 | 6 | def train_BP(model, criterion, optimizer, loader, device, measures): 7 | """ 8 | Train only the traditional blocks with backprop 9 | """ 10 | # with torch.autograd.set_detect_anomaly(True): 11 | t = time.time() 12 | for inputs, target in loader: 13 | ## 1. 
forward propagation 14 | inputs = inputs.float().to(device, non_blocking=True) 15 | target = target.to(device, non_blocking=True) 16 | 17 | output = model(inputs) 18 | # print(r"%s" % (time.time() - t)) 19 | 20 | ## 2. loss calculation 21 | loss = criterion(output, target) 22 | 23 | ## 3. compute gradient and do SGD step 24 | optimizer.zero_grad() 25 | loss.backward() 26 | optimizer.step() 27 | 28 | # print(optimizer.param_groups) 29 | 30 | ## 4. Accuracy assessment 31 | predict = output.data.max(1)[1] 32 | 33 | acc = predict.eq(target.data).sum() 34 | # Save if measurement is wanted 35 | 36 | # print(model.blocks[1].layer.weight.mean(), model.blocks[1].layer.weight.std()) 37 | 38 | convergence, R1 = model.convergence() 39 | measures.step(target.shape[0], loss.clone().detach().cpu(), acc.cpu(), convergence, R1, model.get_lr()) 40 | 41 | return measures, optimizer.param_groups[0]['lr'] 42 | 43 | 44 | def train_hebb(model, loader, device, measures=None, criterion=None): 45 | """ 46 | Train only the hebbian blocks 47 | """ 48 | t = time.time() 49 | loss_acc = (not model.is_hebbian()) and (criterion is not None) 50 | with torch.no_grad(): 51 | for inputs, target in loader: 52 | # print(inputs.min(), inputs.max(), inputs.mean(), inputs.std()) 53 | ## 1. forward propagation 54 | inputs = inputs.float().to(device) # , non_blocking=True) 55 | output = model(inputs) 56 | 57 | # print(r"%s"%(time.time()-t)) 58 | 59 | if loss_acc: 60 | target = target.to(device, non_blocking=True) 61 | 62 | ## 2. loss calculation 63 | loss = criterion(output, target) 64 | 65 | ## 3. Accuracy assessment 66 | predict = output.data.max(1)[1] 67 | acc = predict.eq(target.data).sum() 68 | # Save if measurement is wanted 69 | conv, r1 = model.convergence() 70 | measures.step(target.shape[0], loss.clone().detach().cpu(), acc.cpu(), conv, r1, model.get_lr()) 71 | model.update() 72 | 73 | info = model.radius() 74 | convergence, R1 = model.convergence() 75 | return measures, model.get_lr(), info, convergence, R1 76 | 77 | 78 | def train_sup_hebb(model, loader, device, measures=None, criterion=None): 79 | """ 80 | Train the hebbian blocks with supervised (label-driven) plasticity on the last block 81 | """ 82 | t = time.time() 83 | loss_acc = (not model.is_hebbian()) and (criterion is not None) 84 | with torch.no_grad(): 85 | for inputs, target in loader: 86 | # print(inputs.min(), inputs.max(), inputs.mean(), inputs.std()) 87 | ## 1. forward propagation 88 | inputs = inputs.float().to(device) 89 | output = model(inputs) 90 | model.blocks[-1].layer.plasticity(x=model.blocks[-1].layer.forward_store['x'], 91 | pre_x=model.blocks[-1].layer.forward_store['pre_x'], 92 | wta=torch.nn.functional.one_hot(target, num_classes= 93 | model.blocks[-1].layer.forward_store['pre_x'].shape[1]).type( 94 | model.blocks[-1].layer.forward_store['pre_x'].type())) 95 | 96 | if loss_acc: 97 | target = target.to(device, non_blocking=True) 98 | 99 | ## 2. loss calculation 100 | loss = criterion(output, target) 101 | 102 | ## 3. 
Accuracy assessment 103 | predict = output.data.max(1)[1] 104 | acc = predict.eq(target.data).sum() 105 | # Save if measurement is wanted 106 | conv, r1 = model.convergence() 107 | measures.step(target.shape[0], loss.clone().detach().cpu(), acc.cpu(), conv, r1, model.get_lr()) 108 | 109 | model.update() 110 | 111 | info = model.radius() 112 | convergence, R1 = model.convergence() 113 | return measures, model.get_lr(), info, convergence, R1 114 | 115 | 116 | def train_unsup(model, loader, device, 117 | blocks=[]): # note: no optimizer argument; it is never used for unsupervised training in this repo 118 | """ 119 | Unsupervised learning only works with hebbian learning 120 | """ 121 | model.train(blocks=blocks) # set unsup blocks to train mode 122 | _, lr, info, convergence, R1 = train_hebb(model, loader, device) 123 | return lr, info, convergence, R1 124 | 125 | 126 | def train_sup(model, criterion, optimizer, loader, device, measures, learning_mode, blocks=[]): 127 | """ 128 | train hybrid model. 129 | learning_mode=HB --> train_hebb 130 | learning_mode=BP --> train_BP 131 | """ 132 | if len(blocks) == 1: 133 | model.train(blocks=blocks) 134 | if model.get_block(blocks[0]).is_hebbian(): 135 | measures, lr, info, convergence, R1 = train_sup_hebb(model, loader, device, measures, criterion) 136 | else: 137 | measures, lr = train_BP(model, criterion, optimizer, loader, device, measures) 138 | else: 139 | model.train(blocks=blocks) 140 | if learning_mode == 'HB': 141 | measures, lr, info, convergence, R1 = train_sup_hebb(model, loader, device, measures, criterion) 142 | else: 143 | measures, lr = train_BP(model, criterion, optimizer, loader, device, measures) 144 | return measures, lr 145 | 146 | 147 | def evaluate_unsup(model, train_loader, test_loader, device, blocks): 148 | """ 149 | Unsupervised evaluation, only supports MLP architectures 150 | 151 | """ 152 | if model.get_block(blocks[-1]).arch == 'MLP': 153 | sub_model = model.sub_model(blocks) 154 | return evaluate_hebb(sub_model, train_loader, test_loader, device) 155 | else: 156 | return 0., 0. 157 | 158 | 159 | def evaluate_hebb(model, train_loader, test_loader, device): 160 | if train_loader.dataset.split == 'unlabeled': 161 | print('Unlabeled dataset, cannot perform unsupervised evaluation') 162 | return 0, 0 163 | preactivations, winner_ids, neuron_labels, targets = infer_dataset(model, train_loader, device) 164 | acc_train = get_accuracy(model, winner_ids, targets, preactivations, neuron_labels, device) 165 | 166 | preactivations_test, winner_ids_test, _, targets_test = infer_dataset(model, test_loader, device) 167 | acc_test = get_accuracy(model, winner_ids_test, targets_test, preactivations_test, neuron_labels, device) 168 | return float(acc_train.cpu()), float(acc_test.cpu()) 169 | 170 | 171 | def infer_dataset(model, loader, device): 172 | model.eval() 173 | targets_lst = [] 174 | winner_ids = [] 175 | preactivations_lst = [] 176 | wta_lst = [] 177 | with torch.no_grad(): 178 | for inputs, targets in loader: 179 | ## 1. 
forward propagation 180 | inputs = inputs[targets != -1] 181 | targets = targets[targets != -1] 182 | if targets.nelement() != 0: 183 | inputs = inputs.float().to(device, non_blocking=True) 184 | preactivations, wta = model.foward_x_wta(inputs) 185 | preactivations_lst.append(preactivations) 186 | wta_lst.append(wta) 187 | targets_lst += targets.tolist() 188 | winner_ids_minibatch = wta.argmax(dim=1) 189 | winner_ids += winner_ids_minibatch.tolist() 190 | 191 | winner_ids = torch.FloatTensor(winner_ids).to(torch.int64).to(device) 192 | targets = torch.FloatTensor(targets_lst).to(torch.int64).to(device) 193 | preactivations = torch.cat(preactivations_lst).to(device) 194 | wta = torch.cat(wta_lst).to(device) 195 | neuron_labels = get_neuron_labels(model, winner_ids, targets, preactivations, wta) 196 | return preactivations, winner_ids, neuron_labels, targets 197 | 198 | 199 | def evaluate_sup(model, criterion, loader, device): 200 | """ 201 | Evaluate the model, returning loss and acc 202 | """ 203 | model.eval() 204 | loss_sum = 0 205 | acc_sum = 0 206 | n_inputs = 0 207 | 208 | with torch.no_grad(): 209 | for inputs, target in loader: 210 | ## 1. forward propagation 211 | inputs = inputs.float().to(device, non_blocking=True) 212 | target = target.to(device, non_blocking=True) 213 | output = model(inputs) 214 | 215 | ## 2. loss calculation 216 | loss = criterion(output, target) 217 | loss_sum += loss.clone().detach() 218 | 219 | ## 3. Accuracy assessment 220 | predict = output.data.max(1)[1] 221 | acc = predict.eq(target.data).sum() 222 | acc_sum += acc 223 | n_inputs += target.shape[0] 224 | 225 | return loss_sum.cpu() / n_inputs, 100 * acc_sum.cpu() / n_inputs 226 | 227 | 228 | def accuracy(output, target, topk=(1,)): 229 | """Computes the accuracy over the k top predictions for the specified values of k""" 230 | with torch.no_grad(): 231 | maxk = max(topk) 232 | batch_size = target.size(0) 233 | 234 | _, pred = output.topk(maxk, 1, True, True) 235 | pred = pred.t() 236 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 237 | 238 | res = [] 239 | for k in topk: 240 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 241 | res.append(correct_k.mul_(100.0 / batch_size)) 242 | return res 243 | 244 | 245 | """Code below needs to be rewritten""" 246 | 247 | 248 | def get_neuron_labels(model, winner_ids, targets, preactivations, wta): 249 | targets_onehot = nn.functional.one_hot(targets, num_classes=preactivations.shape[1]).to(torch.float32) 250 | winner_ids_onehot = nn.functional.one_hot(winner_ids, num_classes=preactivations.shape[1]).to(torch.float32) 251 | responses_matrix = torch.matmul(winner_ids_onehot.t(), targets_onehot) 252 | 253 | neuron_outputs_for_label_total = torch.matmul(wta.t(), targets_onehot) 254 | 255 | responses_matrix[responses_matrix.sum(dim=1) == 0] = neuron_outputs_for_label_total[ 256 | responses_matrix.sum(dim=1) == 0] 257 | neuron_labels = responses_matrix.argmax(1) 258 | return neuron_labels 259 | 260 | 261 | def get_accuracy(model, winner_ids, targets, preactivations, neuron_labels, device): 262 | n_samples = preactivations.shape[0] 263 | # if not model.ensemble: 264 | predlabels = torch.FloatTensor([neuron_labels[i] for i in winner_ids]).to(device) 265 | ''' 266 | else: 267 | if model.test_uses_softmax: 268 | soft_acts = activation(preactivations, model.t_invert, model.activation_fn, dim=1, power=model.power, normalize=True) 269 | winner_ensembles = [ 270 | np.argmax([np.sum(np.where(neuron_labels == ensemble, soft_acts[sample], np.asarray(0))) for 271 | 
ensemble in range(10)]) for sample in range(n_samples)] 272 | else: 273 | winner_ensembles = [ 274 | np.argmax([np.sum(np.where(neuron_labels == ensemble, preactivations[sample], np.asarray(0))) for 275 | ensemble in range(10)]) for sample in range(n_samples)] 276 | predlabels = winner_ensembles 277 | ''' 278 | correct_pred = predlabels == targets 279 | n_correct = correct_pred.sum() 280 | accuracy = n_correct / len(targets) 281 | return 100 * accuracy.cpu() 282 | -------------------------------------------------------------------------------- /environment_pytorch==1.7.1.yml: -------------------------------------------------------------------------------- 1 | name: softhebb 2 | channels: 3 | - pytorch 4 | - fastchan 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=main 9 | - _openmp_mutex=4.5=1_gnu 10 | - _tflow_select=2.3.0=eigen 11 | - absl-py=0.13.0=pyhd8ed1ab_0 12 | - adversarial-robustness-toolbox=1.7.1=pyhd8ed1ab_0 13 | - aiohttp=3.7.4.post0=py38h497a2fe_0 14 | - aiohttp-cors=0.7.0=py_0 15 | - aioredis=1.3.1=py_0 16 | - argon2-cffi=20.1.0=py38h497a2fe_2 17 | - astor=0.8.1=pyh9f0ad1d_0 18 | - astunparse=1.6.3=pyhd8ed1ab_0 19 | - async-timeout=3.0.1=py_1000 20 | - async_generator=1.10=py_0 21 | - attrs=21.2.0=pyhd8ed1ab_0 22 | - backcall=0.2.0=pyh9f0ad1d_0 23 | - backports=1.0=py_2 24 | - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0 25 | - bayesian-optimization=1.1.0=py_0 26 | - blas=1.0=mkl 27 | - bleach=4.1.0=pyhd8ed1ab_0 28 | - blessings=1.7=py38h578d9bd_1004 29 | - blinker=1.4=py_1 30 | - boto3=1.20.24=pyhd3eb1b0_0 31 | - botocore=1.23.24=pyhd3eb1b0_0 32 | - brotli=1.0.9=he6710b0_2 33 | - brotlipy=0.7.0=py38h497a2fe_1001 34 | - bzip2=1.0.8=h7b6447c_0 35 | - c-ares=1.17.1=h7f98852_1 36 | - ca-certificates=2022.3.29=h06a4308_1 37 | - cachetools=4.2.2=pyhd8ed1ab_0 38 | - catalogue=2.0.6=py38h578d9bd_0 39 | - certifi=2021.10.8=py38h578d9bd_2 40 | - cffi=1.14.6=py38ha65f79e_0 41 | - chardet=4.0.0=py38h578d9bd_1 42 | - charset-normalizer=2.0.0=pyhd8ed1ab_0 43 | - click=8.0.1=py38h578d9bd_0 44 | - cma=3.0.3=py_0 45 | - colorama=0.4.4=pyh9f0ad1d_0 46 | - colorful=0.5.4=pyhd8ed1ab_0 47 | - cryptography=3.4.7=py38ha5dfef3_0 48 | - cudatoolkit=10.1.243=h6bb024c_0 49 | - curl=7.78.0=h1ccaba5_0 50 | - cycler=0.10.0=py38_0 51 | - cymem=2.0.5=py38h2531618_0 52 | - cython-blis=0.7.4=py38h27cfd23_1 53 | - dataclasses=0.8=pyhc8e2a94_1 54 | - dbus=1.13.18=hb2f20db_0 55 | - debugpy=1.4.1=py38h709712a_0 56 | - decorator=5.1.0=pyhd8ed1ab_0 57 | - defusedxml=0.7.1=pyhd8ed1ab_0 58 | - eagerpy=0.29.0=pyh9f0ad1d_0 59 | - einops=0.3.2=pyhd8ed1ab_0 60 | - entrypoints=0.3=pyhd8ed1ab_1003 61 | - expat=2.4.1=h2531618_2 62 | - fastdownload=0.0.5=py_0 63 | - ffmpeg=4.3=hf484d3e_0 64 | - ffmpeg-python=0.2.0=py_0 65 | - filelock=3.0.12=pyh9f0ad1d_0 66 | - fontconfig=2.13.1=h6c09931_0 67 | - fonttools=4.25.0=pyhd3eb1b0_0 68 | - foolbox=3.3.1=pyh44b312d_1 69 | - freetype=2.10.4=h5ab3b9f_0 70 | - fsspec=2021.10.1=pyhd8ed1ab_0 71 | - future=0.18.2=py38h578d9bd_3 72 | - gast=0.3.3=py_0 73 | - gettext=0.19.8.1=h0b5b191_1005 74 | - git=2.32.0=pl5262hc120c5b_1 75 | - gitdb=4.0.7=pyhd8ed1ab_0 76 | - gitpython=3.1.24=pyhd8ed1ab_0 77 | - glib=2.69.0=h5202010_0 78 | - gmp=6.2.1=h2531618_2 79 | - gnutls=3.6.15=he1e5248_0 80 | - google-api-core=1.31.1=pyhd8ed1ab_0 81 | - google-auth=1.34.0=pyh6c4a22f_0 82 | - google-auth-oauthlib=0.4.1=py_2 83 | - google-cloud-core=1.7.1=pyhd3eb1b0_0 84 | - google-cloud-storage=1.31.0=py_0 85 | - google-crc32c=1.1.2=py38h27cfd23_0 86 | - 
google-pasta=0.2.0=pyh8c360ce_0 87 | - google-resumable-media=1.3.1=pyhd3eb1b0_1 88 | - googleapis-common-protos=1.53.0=py38h578d9bd_0 89 | - gpustat=0.6.0=pyhd8ed1ab_1 90 | - grpcio=1.38.1=py38hdd6454d_0 91 | - gst-plugins-base=1.14.0=h8213a91_2 92 | - gstreamer=1.14.0=h28cd5cc_2 93 | - h5py=2.10.0=nompi_py38h9915d05_106 94 | - hdf5=1.10.6=nompi_h7c3c948_1111 95 | - hiredis=1.1.0=py38h1e0a361_1 96 | - icu=58.2=he6710b0_3 97 | - idna=3.1=pyhd3deb0d_0 98 | - importlib-metadata=4.6.3=py38h578d9bd_0 99 | - intel-openmp=2021.3.0=h06a4308_3350 100 | - ipykernel=6.4.1=py38he5a9106_0 101 | - ipython=7.28.0=py38he5a9106_0 102 | - ipython_genutils=0.2.0=py_1 103 | - ipywidgets=7.6.5=pyhd8ed1ab_0 104 | - jedi=0.18.0=py38h578d9bd_2 105 | - jinja2=3.0.1=pyhd8ed1ab_0 106 | - jmespath=0.10.0=pyhd3eb1b0_0 107 | - joblib=1.0.1=pyhd8ed1ab_0 108 | - jpeg=9b=h024ee3a_2 109 | - jsonschema=3.2.0=pyhd8ed1ab_3 110 | - jupyter_client=7.0.3=pyhd8ed1ab_0 111 | - jupyter_core=4.8.1=py38h578d9bd_0 112 | - jupyterlab_pygments=0.1.2=pyh9f0ad1d_0 113 | - jupyterlab_widgets=1.0.2=pyhd8ed1ab_0 114 | - keras-preprocessing=1.1.2=pyhd8ed1ab_0 115 | - kiwisolver=1.3.1=py38h2531618_0 116 | - kornia=0.5.8=pyhd8ed1ab_0 117 | - krb5=1.19.2=hcc1bbae_0 118 | - lame=3.100=h7b6447c_0 119 | - langcodes=3.3.0=pyhd8ed1ab_0 120 | - lcms2=2.12=h3be6417_0 121 | - ld_impl_linux-64=2.35.1=h7274673_9 122 | - libblas=3.9.0=11_linux64_mkl 123 | - libcblas=3.9.0=11_linux64_mkl 124 | - libcrc32c=1.1.1=he6710b0_2 125 | - libcurl=7.78.0=h0b77cf5_0 126 | - libedit=3.1.20191231=he28a2e2_2 127 | - libev=4.33=h516909a_1 128 | - libffi=3.3=he6710b0_2 129 | - libgcc-ng=9.3.0=h5101ec6_17 130 | - libgfortran-ng=7.5.0=h14aa051_19 131 | - libgfortran4=7.5.0=h14aa051_19 132 | - libgomp=9.3.0=h5101ec6_17 133 | - libiconv=1.15=h63c8f33_5 134 | - libidn2=2.3.2=h7f8727e_0 135 | - liblapack=3.9.0=11_linux64_mkl 136 | - libllvm10=10.0.1=he513fc3_3 137 | - libnghttp2=1.43.0=h812cca2_0 138 | - libpng=1.6.37=hbc83047_0 139 | - libprotobuf=3.17.2=h780b84a_1 140 | - libsodium=1.0.18=h36c2ea0_1 141 | - libssh2=1.9.0=ha56f1ee_6 142 | - libstdcxx-ng=9.3.0=hd4cf53a_17 143 | - libtasn1=4.16.0=h27cfd23_0 144 | - libtiff=4.2.0=h85742a9_0 145 | - libunistring=0.9.10=h27cfd23_0 146 | - libuuid=1.0.3=h1bed415_2 147 | - libuv=1.40.0=h7b6447c_0 148 | - libwebp-base=1.2.0=h27cfd23_0 149 | - libxcb=1.14=h7b6447c_0 150 | - libxml2=2.9.12=h03d6c58_0 151 | - llvmlite=0.36.0=py38h4630a5e_0 152 | - lz4-c=1.9.3=h295c915_1 153 | - markdown=3.3.4=pyhd8ed1ab_0 154 | - markupsafe=2.0.1=py38h497a2fe_0 155 | - matplotlib=3.4.2=py38h06a4308_0 156 | - matplotlib-base=3.4.2=py38hab158f2_0 157 | - matplotlib-inline=0.1.3=pyhd8ed1ab_0 158 | - mistune=0.8.4=py38h497a2fe_1004 159 | - mkl=2021.3.0=h06a4308_520 160 | - mkl-service=2.4.0=py38h7f8727e_0 161 | - mkl_fft=1.3.0=py38h42c9631_2 162 | - mkl_random=1.2.2=py38h51133e4_0 163 | - msgpack-python=1.0.2=py38h1fd1430_1 164 | - multidict=5.1.0=py38h497a2fe_1 165 | - munkres=1.1.4=py_0 166 | - murmurhash=1.0.5=py38h2531618_0 167 | - mypy=0.910=py38h497a2fe_0 168 | - mypy_extensions=0.4.3=py38h578d9bd_3 169 | - nbclient=0.5.4=pyhd8ed1ab_0 170 | - nbconvert=6.2.0=py38h578d9bd_0 171 | - nbformat=5.1.3=pyhd8ed1ab_0 172 | - ncurses=6.2=he6710b0_1 173 | - nest-asyncio=1.5.1=pyhd8ed1ab_0 174 | - nettle=3.7.3=hbbd107a_1 175 | - ninja=1.10.2=hff7bd54_1 176 | - notebook=6.4.4=pyha770c72_0 177 | - numba=0.53.1=py38h8b71fd7_1 178 | - numpy=1.20.3=py38hf144106_0 179 | - numpy-base=1.20.3=py38h74d4b33_0 180 | - nvidia-ml=7.352.0=py_0 181 | - oauthlib=3.1.1=pyhd8ed1ab_0 
182 | - olefile=0.46=py_0 183 | - opencensus=0.7.13=pyhd8ed1ab_0 184 | - opencensus-context=0.1.2=py38h578d9bd_4 185 | - openh264=2.1.0=hd408876_0 186 | - openjpeg=2.3.0=h05c96fa_1 187 | - openssl=1.1.1n=h7f8727e_0 188 | - opt_einsum=3.3.0=pyhd8ed1ab_1 189 | - packaging=21.0=pyhd8ed1ab_0 190 | - pandas=1.3.1=py38h1abd341_0 191 | - pandoc=2.14.2=h7f98852_0 192 | - pandocfilters=1.5.0=pyhd8ed1ab_0 193 | - parso=0.8.2=pyhd8ed1ab_0 194 | - pathy=0.6.0=pyhd3eb1b0_0 195 | - patsy=0.5.1=py_0 196 | - pcre=8.45=h295c915_0 197 | - pcre2=10.35=h14c3975_1 198 | - perl=5.32.1=0_h7f98852_perl5 199 | - pexpect=4.8.0=pyh9f0ad1d_2 200 | - pickleshare=0.7.5=py_1003 201 | - pillow=8.3.1=py38h2c7a002_0 202 | - pip=21.2.2=py38h06a4308_0 203 | - preshed=3.0.5=py38h2531618_4 204 | - prometheus_client=0.11.0=pyhd8ed1ab_0 205 | - prompt-toolkit=3.0.20=pyha770c72_0 206 | - protobuf=3.17.2=py38h709712a_0 207 | - psutil=5.8.0=py38h497a2fe_1 208 | - ptyprocess=0.7.0=pyhd3deb0d_0 209 | - pyasn1=0.4.8=py_0 210 | - pyasn1-modules=0.2.7=py_0 211 | - pycparser=2.20=pyh9f0ad1d_2 212 | - pydantic=1.8.2=py38h497a2fe_0 213 | - pydeprecate=0.3.1=pyhd8ed1ab_0 214 | - pydub=0.25.1=pyhd8ed1ab_0 215 | - pygments=2.10.0=pyhd8ed1ab_0 216 | - pyjwt=2.1.0=pyhd8ed1ab_0 217 | - pynndescent=0.5.6=pyh6c4a22f_0 218 | - pyopenssl=20.0.1=pyhd8ed1ab_0 219 | - pyparsing=2.4.7=pyhd3eb1b0_0 220 | - pyqt=5.9.2=py38h05f1152_4 221 | - pyrsistent=0.17.3=py38h497a2fe_2 222 | - pysocks=1.7.1=py38h578d9bd_3 223 | - python=3.8.11=h12debd9_0_cpython 224 | - python-dateutil=2.8.2=pyhd3eb1b0_0 225 | - python_abi=3.8=2_cp38 226 | - pytorch=1.7.1=py3.8_cuda10.1.243_cudnn7.6.3_0 227 | - pytorch-lightning=1.4.9=pyhd8ed1ab_0 228 | - pytorch-model-summary=0.1.1=py_0 229 | - pytz=2021.1=pyhd8ed1ab_0 230 | - pyu2f=0.1.5=pyhd8ed1ab_0 231 | - pyyaml=5.4.1=py38h497a2fe_0 232 | - pyzmq=19.0.2=py38ha71036d_2 233 | - qt=5.9.7=h5867ecd_1 234 | - ray-core=1.4.0=py38h9ba0119_1 235 | - ray-tune=1.4.0=py38h578d9bd_1 236 | - readline=8.1=h27cfd23_0 237 | - redis-py=3.5.3=pyh9f0ad1d_0 238 | - requests=2.26.0=pyhd8ed1ab_0 239 | - requests-oauthlib=1.3.0=pyh9f0ad1d_0 240 | - resampy=0.2.2=py_0 241 | - rsa=4.7.2=pyh44b312d_0 242 | - s3transfer=0.5.0=pyhd3eb1b0_0 243 | - scikit-learn=0.22.1=py38hcdab131_1 244 | - scipy=1.4.1=py38h18bccfc_3 245 | - send2trash=1.8.0=pyhd8ed1ab_0 246 | - setproctitle=1.1.10=py38h497a2fe_1004 247 | - setuptools=52.0.0=py38h06a4308_0 248 | - shellingham=1.3.1=pyhd3eb1b0_0 249 | - sip=4.19.13=py38he6710b0_0 250 | - six=1.16.0=pyhd3eb1b0_0 251 | - smart_open=5.1.0=pyhd3eb1b0_0 252 | - smmap=3.0.5=pyh44b312d_0 253 | - spacy=3.2.1=py38hae6d005_0 254 | - spacy-legacy=3.0.8=pyhd8ed1ab_0 255 | - spacy-loggers=1.0.1=pyhd8ed1ab_0 256 | - sqlite=3.36.0=hc218d9a_0 257 | - srsly=2.4.1=py38h2531618_0 258 | - statsmodels=0.12.2=py38h5c078b8_0 259 | - tabulate=0.8.9=pyhd8ed1ab_0 260 | - tbb=2020.2=h4bd325d_4 261 | - tensorboard=2.6.0=pyhd8ed1ab_0 262 | - tensorboard-data-server=0.6.0=py38h2b97feb_0 263 | - tensorboard-plugin-wit=1.8.0=pyh44b312d_0 264 | - tensorboardx=2.4=pyhd8ed1ab_0 265 | - tensorflow=2.3.0=eigen_py38h71ff20e_0 266 | - tensorflow-base=2.3.0=eigen_py38hb57a387_0 267 | - tensorflow-estimator=2.5.0=pyh81a9013_1 268 | - termcolor=1.1.0=py_2 269 | - terminado=0.12.1=py38h578d9bd_0 270 | - testpath=0.5.0=pyhd8ed1ab_0 271 | - thinc=8.0.13=py38hae6d005_0 272 | - tk=8.6.10=hbc83047_0 273 | - toml=0.10.2=pyhd8ed1ab_0 274 | - torchaudio=0.7.2=py38 275 | - torchmetrics=0.5.1=pyhd8ed1ab_0 276 | - torchvision=0.8.2=py38_cu101 277 | - tornado=6.1=py38h27cfd23_0 
278 | - tqdm=4.62.0=pyhd8ed1ab_0 279 | - traitlets=5.1.0=pyhd8ed1ab_0 280 | - typer=0.4.0=pyhd8ed1ab_0 281 | - typing-extensions=3.10.0.0=hd3eb1b0_0 282 | - typing_extensions=3.10.0.0=pyh06a4308_0 283 | - umap-learn=0.5.3=py38h578d9bd_0 284 | - urllib3=1.26.6=pyhd8ed1ab_0 285 | - wasabi=0.8.2=pyhd3eb1b0_0 286 | - wcwidth=0.2.5=pyh9f0ad1d_2 287 | - webencodings=0.5.1=py_1 288 | - werkzeug=2.0.1=pyhd8ed1ab_0 289 | - wheel=0.36.2=pyhd3eb1b0_0 290 | - widgetsnbextension=3.5.1=py38h578d9bd_4 291 | - wrapt=1.12.1=py38h497a2fe_3 292 | - xz=5.2.5=h7b6447c_0 293 | - yaml=0.2.5=h516909a_0 294 | - yarl=1.6.3=py38h497a2fe_2 295 | - zeromq=4.3.4=h9c3ff4c_0 296 | - zipp=3.5.0=pyhd8ed1ab_0 297 | - zlib=1.2.11=h7b6447c_3 298 | - zstd=1.4.9=haebb681_0 299 | - pip: 300 | - --trusted-host pypi.org 301 | - --trusted-host pypi.python.org 302 | - --trusted-host files.pythonhosted.org 303 | - --trusted-host download.pytorch.org 304 | - efficientnet-pytorch==0.7.1 305 | - kmeans-pytorch==0.3 306 | - torchsummary==1.5.1 307 | prefix: /home/username/anaconda3/envs/softhebb # if this path cannot be found, it will be installed under anaconda dir 308 | -------------------------------------------------------------------------------- /hebblinear.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import Generator, Union 4 | 5 | try: 6 | from utils import init_weight, normalize, activation, unsup_lr_scheduler 7 | except: 8 | from hebb.utils import init_weight, normalize, activation, unsup_lr_scheduler 9 | import einops 10 | from tabulate import tabulate 11 | 12 | 13 | class HebbHardLinear(nn.Module): 14 | def __init__( 15 | self, 16 | in_features: int, 17 | n_neurons: int, 18 | lebesgue_p: int, 19 | weight_distribution: str, 20 | weight_range: float, 21 | weight_offset: float, 22 | lr_scheduler: str, 23 | bias: bool = False 24 | ) -> None: 25 | """ 26 | Hard Winner take all implementation 27 | """ 28 | 29 | super().__init__() 30 | 31 | self.stat = torch.zeros(3, n_neurons) 32 | 33 | self.learning_update = False 34 | self.was_update = True 35 | 36 | self.in_features = in_features 37 | self.n_neurons = n_neurons 38 | self.lebesgue_p = lebesgue_p 39 | 40 | self.register_buffer( 41 | 'weight', 42 | init_weight((n_neurons, in_features), weight_distribution, weight_range, weight_offset) 43 | ) 44 | 45 | self.register_buffer("delta_w", torch.zeros_like(self.weight), persistent=False) 46 | 47 | self.register_buffer( 48 | "rad", 49 | torch.ones(n_neurons), 50 | persistent=False 51 | ) 52 | self.get_radius() 53 | 54 | self.lr_scheduler_config = lr_scheduler.copy() 55 | self.lr_adaptive = self.lr_scheduler_config['adaptive'] 56 | 57 | self.reset() 58 | 59 | self.conv = 0 60 | 61 | self.register_buffer('bias', None) 62 | 63 | def stat_wta(self): 64 | count = self.stat.clone() 65 | count[2:] = (100 * self.stat[2:].t() / self.stat[2:].sum(1)).t() 66 | count = count[:, :20] 67 | x = list(range(count.shape[1])) 68 | y = [['{lr:.1e}'.format(lr=lr) for lr in count[0].tolist()]] + [['{x:.2f}'.format(x=x) for x in y.tolist()] for 69 | y in count[1:]] 70 | return tabulate(y, headers=x, tablefmt='orgtbl') 71 | 72 | def reset(self): 73 | if self.lr_adaptive: 74 | self.register_buffer("lr", torch.ones_like(self.weight), persistent=False) 75 | self.lr_scheduler = unsup_lr_scheduler(lr=self.lr_scheduler_config['lr'], 76 | nb_epochs=self.lr_scheduler_config['nb_epochs'], 77 | ratio=self.lr_scheduler_config['ratio'], 78 | 
speed=self.lr_scheduler_config['speed'], 79 | div=self.lr_scheduler_config['div'], 80 | decay=self.lr_scheduler_config['decay']) 81 | 82 | self.update_lr() 83 | else: 84 | self.lr_scheduler = unsup_lr_scheduler(lr=self.lr_scheduler_config['lr'], 85 | nb_epochs=self.lr_scheduler_config['nb_epochs'], 86 | ratio=self.lr_scheduler_config['ratio'], 87 | speed=self.lr_scheduler_config['speed'], 88 | div=self.lr_scheduler_config['div'], 89 | decay=self.lr_scheduler_config['decay']) 90 | self.lr = next(self.lr_scheduler) 91 | 92 | def get_pre_activations(self, x: torch.Tensor) -> torch.Tensor: 93 | """ 94 | Compute the pre-activation (the neural current) of the hebbian layer 95 | Parameters 96 | ---------- 97 | x : torch.Tensor 98 | Input of the layer, 99 | expected shape (batch_size, in_features) 100 | Returns 101 | ------- 102 | pre_x : torch.Tensor 103 | Pre-activation of the hebbian layer (batch_size, n_neurons) 104 | """ 105 | pre_x = torch.matmul(x, 106 | (torch.sign(self.weight) * torch.abs(self.weight) ** (self.lebesgue_p - 1)).t() 107 | ) 108 | 109 | if self.bias is not None: 110 | pre_x = torch.add(pre_x, self.bias) 111 | 112 | return pre_x 113 | 114 | def get_lr(self): 115 | if self.lr_adaptive: 116 | return self.lr.mean().cpu() 117 | return self.lr 118 | 119 | def get_wta(self, pre_x: torch.Tensor) -> torch.Tensor: 120 | """ 121 | Compute the hard winner take all 122 | ---------- 123 | pre_x : torch.Tensor 124 | Pre-activations (batch_size, n_neurons) 125 | Returns 126 | ------- 127 | wta : torch.Tensor 128 | one-hot winner-take-all code of the batch (batch_size, n_neurons) 129 | """ 130 | wta = nn.functional.one_hot(pre_x.argmax(dim=1), num_classes=pre_x.shape[1]).to( 131 | torch.float) 132 | self.stat[2] += wta.sum(0).cpu() 133 | return wta 134 | 135 | def forward(self, x: torch.Tensor, return_x_wta: bool = False) -> torch.Tensor: 136 | """ 137 | Compute output of the layer (forward pass). 138 | Parameters 139 | ---------- 140 | x : torch.Tensor 141 | Input. Expected to be of shape (batch_size, ...), where ... denotes an arbitrary 142 | sequence of dimensions, with product equal to in_features. 143 | """ 144 | if False: 145 | x = 10 * nn.functional.normalize(x) 146 | 147 | pre_x = self.get_pre_activations(x) 148 | 149 | # If only propagating the pre-activations, there is no need to do the rest 150 | if not self.learning_update and not return_x_wta: 151 | return pre_x 152 | 153 | wta = self.get_wta(pre_x) 154 | 155 | if return_x_wta: 156 | return pre_x, wta 157 | 158 | if self.learning_update: 159 | self.plasticity(x, pre_x, wta) 160 | 161 | return pre_x 162 | 163 | def train(self, mode: bool = True) -> None: 164 | """ 165 | Set the learning update to the mode expected. 166 | mode:True --> training 167 | 168 | mode:False --> predict 169 | """ 170 | self.learning_update = mode 171 | 172 | def delta_weight( 173 | self, 174 | x: torch.Tensor, 175 | pre_x: torch.Tensor, 176 | wta: torch.Tensor, ) -> torch.Tensor: 177 | """ 178 | Compute the change of weights 179 | Parameters 180 | ---------- 181 | x : torch.Tensor 182 | x. Input (batch_size, in_features). 183 | pre_x : torch.Tensor 184 | pre_x. Linear transformation of the input (batch_size, n_neurons). 185 | wta : torch.Tensor 186 | wta. Winner take all (batch_size, n_neurons).
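Note: the update computed below is delta_w = wta.T @ x - (wta * pre_x).sum(0) * weight, i.e. a Hebbian term pulling each winning neuron's weights toward the inputs it won on, minus a decay proportional to the winner's own pre-activation; the result is rescaled by its largest absolute entry before the learning rate is applied.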
187 | Returns 188 | ------- 189 | delta_weight : torch.Tensor 190 | """ 191 | # ---Compute change of weights---# 192 | 193 | yx = torch.matmul(wta.t(), x) # Overlap between winner take all and inputs 194 | yu = torch.multiply(wta, pre_x) 195 | yu = torch.sum(yu.t(), dim=1).unsqueeze(1) 196 | # Overlap between preactivation and winner take all 197 | # Results are summed over batches, resulting in a shape of (output_size,) 198 | delta_weight = yx - yu.view(-1, 1) * self.weight 199 | 200 | # ---Normalize---# 201 | nc = torch.abs(delta_weight).amax() # .amax(1, keepdim=True) 202 | delta_weight.div_(nc + 1e-30) 203 | 204 | return delta_weight 205 | 206 | def plasticity( 207 | self, 208 | x: torch.Tensor, 209 | pre_x: torch.Tensor = None, 210 | wta: torch.Tensor = None) -> None: 211 | """ 212 | Update weight and bias accordingly to the plasticity computation 213 | Parameters 214 | ---------- 215 | x : torch.Tensor 216 | x. Input (batch_size, in_features). 217 | pre_x : torch.Tensor 218 | pre_x. Conv2d transformation of the input (batch_size, in_features). 219 | wta : torch.Tensor 220 | wta. Winner take all (batch_size, in_features). 221 | 222 | """ 223 | if pre_x is None: 224 | pre_x = self._conv_forward(x, self.weight, self.bias) 225 | # for some algo (krotov) pre_x and Conv2d trans are different 226 | pre_x = self.get_pre_x(x, pre_x) 227 | wta = self.get_wta(pre_x) 228 | 229 | self.delta_w = self.delta_weight(x, pre_x, wta) 230 | # self.weight.add_(self.lr * delta_weight) 231 | 232 | # self.update() 233 | 234 | if self.bias is not None: 235 | self.delta_b = self.delta_bias(wta) 236 | 237 | def update(self): 238 | """ 239 | Update weight and bias accordingly to the plasticity computation 240 | Returns 241 | ------- 242 | 243 | """ 244 | self.weight.add_(self.lr * self.delta_w) 245 | self.was_update = True 246 | 247 | if self.bias is not None: 248 | self.bias.add_(self.lr * self.lrb * self.delta_b) 249 | # self.bias.clip_(-1, 0) 250 | self.update_lr() 251 | 252 | def update_lr(self) -> None: 253 | if self.lr_adaptive: 254 | norm = self.get_radius() 255 | 256 | nc = 1e-10 257 | 258 | # lr_amplitude = next(self.lr_scheduler) 259 | 260 | lr_amplitude = self.lr_scheduler_config['lr'] 261 | 262 | lr = lr_amplitude * torch.pow(torch.abs(norm - torch.ones_like(norm)) + nc, 263 | self.lr_scheduler_config['power_lr']) 264 | 265 | # lr = lr.clip(max=lr_amplitude) 266 | 267 | self.stat[0] = lr.clone() 268 | 269 | self.lr = einops.repeat(lr, 'o -> o i', i=self.in_features) 270 | else: 271 | self.lr = next(self.lr_scheduler) 272 | self.stat[0] = self.lr 273 | 274 | def get_radius(self): 275 | if self.was_update: 276 | weight = self.weight.view(self.weight.shape[0], -1) 277 | self.rad = torch.linalg.norm(weight, dim=1, ord=self.lebesgue_p) 278 | self.was_update = False 279 | return self.rad 280 | 281 | def radius(self) -> float: 282 | """ 283 | Returns 284 | ------- 285 | radius : float 286 | """ 287 | 288 | meanb = torch.mean(self.bias) if self.bias is not None else 0. 289 | stdb = torch.std(self.bias) if self.bias is not None else 0. 
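# The per-neuron "radius" is the Lebesgue-p norm of each weight row; the report string built below combines bias and weight statistics with the mean/std of the radii and their mean/max deviation from the average, followed by the WTA statistics table.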
290 | weight = self.weight.view(self.weight.shape[0], -1) 291 | mean = torch.mean(weight, axis=1) 292 | mean_weight = torch.mean(mean) 293 | std_weigh = torch.std(weight) 294 | norm = torch.linalg.norm(weight, dim=1, ord=self.lebesgue_p) 295 | self.stat[1] = norm 296 | mean_radius = torch.mean(norm) 297 | std_radius = torch.std(norm) 298 | max_radius = torch.amax(torch.abs(norm - mean_radius * torch.ones_like(norm))) 299 | mean2_radius = torch.mean(torch.abs(norm - mean_radius * torch.ones_like(norm))) 300 | 301 | return 'MB:{mb:.3e}/SB:{sb:.3e}/MW:{m:.3e}/SW:{s:.3e}/MR:{mean:.3e}/SR:{std:.3e}/MeD:{mean2:.3e}/MaD:{max:.3e}'.format( 302 | mb=meanb, 303 | sb=stdb, 304 | m=mean_weight, 305 | s=std_weigh, 306 | mean=mean_radius, 307 | std=std_radius, 308 | mean2=mean2_radius, 309 | max=max_radius) + '\n' + self.stat_wta() 310 | 311 | def convergence(self) -> float: 312 | """ 313 | Returns 314 | ------- 315 | convergence : float 316 | Metric of convergence as the nb of filter closed to 1 317 | """ 318 | weight = self.weight.view(self.weight.shape[0], -1) 319 | 320 | norm = torch.linalg.norm(weight, dim=1, ord=self.lebesgue_p) 321 | # mean_radius = torch.mean(norm) 322 | conv = torch.mean(torch.abs(norm - torch.ones_like(norm))) 323 | 324 | R1 = torch.sum(torch.abs(norm - torch.ones_like(norm)) < 5e-3) 325 | return float(conv.cpu()), int(R1.cpu()) 326 | 327 | def extra_repr2(self) -> str: 328 | return 'in_features={}, out_features={}, lebesgue_p={}, bias={}'.format( 329 | self.in_features, self.n_neurons, self.lebesgue_p, self.bias is not None 330 | ) 331 | 332 | def extra_repr(self) -> str: 333 | return self.extra_repr2() 334 | 335 | def __label__(self): 336 | s = '{in_features}{n_neurons}{lebesgue_p}' 337 | return 'H' + s.format(**self.__dict__) 338 | 339 | 340 | class HebbHardKrotovLinear(HebbHardLinear): 341 | def __init__( 342 | self, 343 | in_features: int, 344 | n_neurons: int, 345 | lebesgue_p: int, 346 | weight_distribution: str, 347 | weight_range: float, 348 | weight_offset: float, 349 | lr_scheduler: Generator, 350 | bias: bool = False, 351 | delta: float = 0.05, 352 | ranking_param: int = 2 353 | ) -> None: 354 | """ 355 | Krotov implementation from the HardLinear class 356 | """ 357 | 358 | super(HebbHardKrotovLinear, self).__init__(in_features, n_neurons, lebesgue_p, weight_distribution, 359 | weight_range, weight_offset, lr_scheduler, bias) 360 | 361 | self.delta = delta 362 | self.ranking_param = ranking_param 363 | self.stat = torch.zeros(4, n_neurons) 364 | 365 | def extra_repr(self): 366 | s = ', ranking_param=%s, delta=%s' % (self.ranking_param, self.delta) 367 | return self.extra_repr2() + s 368 | 369 | def get_wta(self, pre_x: torch.Tensor) -> torch.Tensor: 370 | """ 371 | Compute the krotov winner take all 372 | ---------- 373 | pre_x : torch.Tensor 374 | pre_x 375 | Returns 376 | ------- 377 | wta : torch.Tensor 378 | preactivation or the current of the hebbian layer 379 | """ 380 | _, ranks = pre_x.sort(descending=True, dim=1) 381 | wta = nn.functional.one_hot(pre_x.argmax(dim=1), num_classes=pre_x.shape[1]).to( 382 | torch.float) 383 | 384 | self.stat[2] += wta.sum(0).cpu() 385 | # wta = wta - self.delta * nn.functional.one_hot(ranks[:, self.ranking_param-1], num_classes=pre_x.shape[1]) 386 | batch_indices = torch.arange(pre_x.size(0)) 387 | _, ranking_indices = pre_x.topk(self.ranking_param, dim=1) 388 | wta[batch_indices, ranking_indices[batch_indices, self.ranking_param - 1]] = -self.delta 389 | 390 | self.stat[3] += 
torch.histc(ranking_indices[batch_indices, self.ranking_param - 1].float(), # histc requires a floating-point input 391 | bins=self.n_neurons, min=0, # the linear layer has n_neurons outputs ("out_channels" was a leftover from the conv version) 392 | max=self.n_neurons - 1).cpu() 393 | # print(wta[batch_indices, ranking_indices[batch_indices, 0]].mean()) 394 | return wta 395 | 396 | 397 | class HebbSoftLinear(HebbHardLinear): 398 | def __init__( 399 | self, 400 | in_features: int, 401 | n_neurons: int, 402 | lebesgue_p: int, 403 | weight_distribution: str, 404 | weight_range: float, 405 | weight_offset: float, 406 | lr_scheduler: Generator, 407 | lr_bias: float, 408 | bias: bool = False, 409 | activation_fn: str = 'exp', 410 | t_invert: float = 12 411 | ) -> None: 412 | """ 413 | Soft winner-take-all variant of HebbHardLinear 414 | """ 415 | super(HebbSoftLinear, self).__init__(in_features, n_neurons, lebesgue_p, weight_distribution, 416 | weight_range, weight_offset, lr_scheduler, bias) 417 | 418 | self.activation_fn = activation_fn 419 | self.t_invert = torch.tensor(t_invert) 420 | 421 | if bias: 422 | self.register_buffer('bias', torch.ones(n_neurons) \ 423 | * torch.log(torch.tensor(1 / n_neurons)) / self.t_invert 424 | ) # uniform initial priors, and account for softmax's T_invert 425 | 426 | self.lrb = torch.tensor(1 / t_invert) 427 | 428 | def extra_repr(self): 429 | s = ', t_invert=%s, bias=%s, lr_bias=%s' % ( 430 | float(self.t_invert), not self.bias is None, round(float(self.lrb), 4)) 431 | return self.extra_repr2() + s 432 | 433 | def get_wta(self, pre_x: torch.Tensor) -> torch.Tensor: 434 | """ 435 | Compute the soft winner take all 436 | ---------- 437 | pre_x : torch.Tensor 438 | Pre-activations (batch_size, n_neurons) 439 | Returns 440 | ------- 441 | wta : torch.Tensor 442 | soft winner-take-all activations of the batch (batch_size, n_neurons) 443 | """ 444 | wta = activation(pre_x, t_invert=self.t_invert, activation_fn=self.activation_fn, normalize=True) 445 | self.stat[2] += wta.sum(0).cpu() 446 | return wta 447 | 448 | def delta_bias(self, wta: torch.Tensor) -> torch.Tensor: 449 | """ 450 | Compute the change of bias 451 | Parameters 452 | ---------- 453 | wta : torch.Tensor 454 | wta. Winner take all (batch_size, n_neurons).
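Note: for activation_fn='exp' the bias stores log-priors, so with p = exp(t_invert * bias) the update below is (wta.sum(0) - batch_size * p) / p, which vanishes once each neuron's mean soft-WTA activity matches its prior; as in delta_weight, the result is rescaled by its largest absolute entry.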
455 | """ 456 | batch_size = wta.shape[0] 457 | if self.activation_fn == 'exp': 458 | ebb = torch.exp(self.t_invert * self.bias) # e^(bias*t_invert) 459 | # ---Compute change of bias---# 460 | delta_bias = (torch.sum(wta, dim=0) - ebb * batch_size) / ebb 461 | elif self.activation_fn == 'relu': 462 | delta_bias = (torch.sum(wta, dim=0) - wta.shape[0] * self.bias - batch_size) # eta * (y-w-1) 463 | 464 | nc = torch.abs(delta_bias).amax() # .amax(1, keepdim=True) 465 | delta_bias.div_(nc + 1e-30) 466 | 467 | return delta_bias 468 | 469 | 470 | class HebbSoftKrotovLinear(HebbSoftLinear): 471 | def __init__( 472 | self, 473 | in_features: int, 474 | n_neurons: int, 475 | lebesgue_p: int, 476 | weight_distribution: str, 477 | weight_range: float, 478 | weight_offset: float, 479 | lr_scheduler: Generator, 480 | lr_bias: float, 481 | bias: bool = False, 482 | delta: float = 0.05, 483 | ranking_param: int = 2, 484 | activation_fn: str = 'exp', 485 | t_invert: float = 12 486 | ) -> None: 487 | """ 488 | Krotov implementation from the SoftLinear class 489 | """ 490 | 491 | super(HebbSoftKrotovLinear, self).__init__(in_features, n_neurons, lebesgue_p, weight_distribution, 492 | weight_range, 493 | weight_offset, lr_scheduler, lr_bias, bias, activation_fn, t_invert) 494 | 495 | self.delta = delta 496 | self.ranking_param = ranking_param 497 | 498 | self.m_winner = [] 499 | self.m_anti_winner = [] 500 | self.mode = 0 501 | self.stat = torch.zeros(4, n_neurons) 502 | 503 | def extra_repr(self): 504 | s = ', t_invert=%s, bias=%s, lr_bias=%s' % ( 505 | float(self.t_invert), not self.bias is None, round(float(self.lrb), 4)) 506 | s += ', ranking_param=%s, delta=%s' % (self.ranking_param, self.delta) 507 | return self.extra_repr2() + s 508 | 509 | def get_wta(self, pre_x: torch.Tensor) -> torch.Tensor: 510 | """ 511 | Compute the soft krotov winner take all 512 | ---------- 513 | pre_x : torch.Tensor 514 | pre_x 515 | Returns 516 | ------- 517 | wta : torch.Tensor 518 | preactivation or the current of the hebbian layer 519 | """ 520 | batch_size, out_channels = pre_x.shape 521 | # pre_x = pre_x - torch.mean(pre_x, axis=1, keepdims=True) 522 | # pre_x[pre_x < 0] = -float("Inf") 523 | wta = activation(pre_x, t_invert=self.t_invert, activation_fn=self.activation_fn, normalize=True) 524 | self.stat[2] += wta.sum(0).cpu() 525 | # print(wta.sum(0).cpu()) 526 | batch_indices = torch.arange(pre_x.size(0)) 527 | if self.mode == 0: 528 | wta = -wta 529 | # _, ranking_indices = pre_x_flat.topk(1, dim=1) 530 | # ranking_indices = ranking_indices[batch_indices,0] 531 | ranking_indices = torch.argmax(pre_x, dim=1) 532 | wta[batch_indices, ranking_indices] *= -1 533 | self.m_winner.append(wta[batch_indices, ranking_indices].mean().cpu()) 534 | self.m_anti_winner.append(1 - self.m_winner[-1]) 535 | if self.mode == 1: 536 | _, ranking_indices = pre_x.topk(self.ranking_param, dim=1) 537 | 538 | self.m_anti_winner.append( 539 | wta[batch_indices, ranking_indices[batch_indices, self.ranking_param - 1]].mean().cpu()) 540 | self.m_winner.append(wta[batch_indices, ranking_indices[batch_indices, 0]].mean().cpu()) 541 | 542 | # print(wta[batch_indices, ranking_indices[batch_indices, self.ranking_param-1]].mean()) 543 | wta[batch_indices, ranking_indices[batch_indices, self.ranking_param - 1]] *= -self.delta 544 | # print(wta[batch_indices, ranking_indices[batch_indices, self.ranking_param-1]].mean()) 545 | # print(wta[batch_indices, ranking_indices[batch_indices, 0]].mean()) 546 | if self.mode == 2: 547 | _, ranking_indices = 
pre_x.topk(self.ranking_param, dim=1) 548 | 549 | self.m_anti_winner.append( 550 | wta[batch_indices, ranking_indices[batch_indices, self.ranking_param - 1]].mean().cpu()) 551 | self.m_winner.append(wta[batch_indices, ranking_indices[batch_indices, 0]].mean().cpu()) 552 | 553 | # print(wta[batch_indices, ranking_indices[batch_indices, self.ranking_param-1]].mean()) 554 | wta[batch_indices, ranking_indices[batch_indices, self.ranking_param - 1]] = -self.delta 555 | return wta 556 | 557 | def radius(self) -> float: 558 | """ 559 | Returns 560 | ------- 561 | radius : float 562 | """ 563 | meanb = torch.mean(self.bias) if self.bias is not None else 0. 564 | stdb = torch.std(self.bias) if self.bias is not None else 0. 565 | weight = self.weight.view(self.weight.shape[0], -1) 566 | mean = torch.mean(weight, axis=1) 567 | mean_weight = torch.mean(mean) 568 | std_weigh = torch.std(weight) 569 | norm = torch.linalg.norm(weight, dim=1, ord=self.lebesgue_p) 570 | self.stat[1] = norm 571 | mean_radius = torch.mean(norm) 572 | std_radius = torch.std(norm) 573 | max_radius = torch.amax(torch.abs(norm - mean_radius * torch.ones_like(norm))) 574 | mean2_radius = torch.mean(torch.abs(norm - mean_radius * torch.ones_like(norm))) 575 | 576 | m_winner = torch.mean(torch.tensor(self.m_winner)) 577 | m_anti_winner = torch.mean(torch.tensor(self.m_anti_winner)) 578 | 579 | self.m_winner = [] 580 | self.m_anti_winner = [] 581 | 582 | return 'MB:{mb:.3e}/SB:{sb:.3e}/MW:{m:.3e}/SW:{s:.3e}/MR:{mean:.3e}/SR:{std:.3e}/MeD:{mean2:.3e}/MaD:{max:.3e}/MW:{m_winner:.3f}/MAW:{m_anti_winner:.3f}'.format( 583 | mb=meanb, 584 | sb=stdb, 585 | m=mean_weight, 586 | s=std_weigh, 587 | mean=mean_radius, 588 | std=std_radius, 589 | mean2=mean2_radius, 590 | max=max_radius, 591 | m_winner=m_winner, 592 | m_anti_winner=m_anti_winner 593 | ) + '\n' + self.stat_wta() + '\n' 594 | 595 | 596 | class SupervisedSoftHebbLinear(HebbSoftKrotovLinear): 597 | def __init__(self, **kwargs): 598 | self.forward_store = {} 599 | self.async_updates = True # TODO make this a parameter from preset 600 | super().__init__(**kwargs) 601 | 602 | def get_wta(self, pre_x: torch.Tensor) -> torch.Tensor: 603 | """ should not be called""" 604 | raise NotImplementedError 605 | 606 | def forward( 607 | self, x: torch.Tensor, return_x_wta: bool = False, 608 | ) -> torch.Tensor: 609 | """ 610 | Compute output of the layer (forward pass). 611 | Parameters 612 | ---------- 613 | x : torch.Tensor 614 | Input. Expected to be of shape (batch_size, ...), where ... denotes an arbitrary 615 | sequence of dimensions, with product equal to in_features. 
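Note: with async_updates=True the layer does not update itself during this call; it only caches x and pre_x in forward_store, so that plasticity() can be invoked later with a wta that incorporates the supervised targets.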
616 | """ 617 | 618 | pre_x = self.get_pre_activations(x) 619 | 620 | # If only propagating the pre-activations, there is no need to do the rest 621 | if not self.learning_update and not return_x_wta: 622 | return pre_x 623 | 624 | # if clamped_wta is None and not self.async_updates: 625 | # wta = self.get_wta(pre_x) # we don't need to do this 626 | # else: 627 | # wta = clamped_wta 628 | if not self.async_updates: 629 | wta = self.get_wta(pre_x) 630 | if return_x_wta: 631 | return pre_x, wta 632 | 633 | if self.learning_update: 634 | # this does happen, we should change it, or change the behaviour based on this 635 | # pdb.set_trace() # we shouldn't perform the learning update here, but later when we have the targets 636 | if self.async_updates: 637 | self.forward_store['x'] = x 638 | self.forward_store['pre_x'] = pre_x 639 | else: 640 | self.plasticity(x, pre_x, wta) 641 | return pre_x 642 | 643 | def plasticity( 644 | self, 645 | x: torch.Tensor, 646 | pre_x: torch.Tensor = None, 647 | wta: torch.Tensor = None) -> None: 648 | """ 649 | Update weight and bias according to the plasticity computation 650 | Parameters 651 | ---------- 652 | x : torch.Tensor 653 | x. Input (batch_size, in_features). 654 | pre_x : torch.Tensor 655 | pre_x. Linear transformation of the input (batch_size, n_neurons). 656 | wta : torch.Tensor 657 | wta. Winner take all (batch_size, n_neurons). 658 | 659 | """ 660 | if pre_x is None: 661 | raise ValueError # although actually we could recompute, but throw error for now 662 | 663 | self.delta_w = self.delta_weight(x, pre_x, wta) 664 | 665 | if self.bias is not None: 666 | self.delta_b = self.delta_bias(wta) 667 | 668 | # the idea is to call this at a later point, once the labels are available too 669 | 670 | 671 | def select_linear_layer( 672 | params) -> Union[HebbHardLinear, HebbHardKrotovLinear, HebbSoftLinear, HebbSoftKrotovLinear, 673 | SupervisedSoftHebbLinear]: 674 | """ 675 | Select the appropriate Hebbian linear layer from a set of params 676 | ---------- 677 | params : dict 678 | Layer configuration; params['softness'] selects the implementation.
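The mapping is: 'hard' -> HebbHardLinear, 'hardkrotov' -> HebbHardKrotovLinear, 'soft' -> HebbSoftLinear, 'softkrotov' -> HebbSoftKrotovLinear, 'supervisedsoftkrotov' -> SupervisedSoftHebbLinear; any other value raises ValueError.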
679 | Returns 680 | ------- 681 | layer : bio 682 | preactivation or the current of the hebbian layer 683 | 684 | """ 685 | if params['softness'] == 'hard': 686 | layer = HebbHardLinear( 687 | in_features=params['in_channels'], 688 | n_neurons=params['out_channels'], 689 | lebesgue_p=params['lebesgue_p'], 690 | weight_distribution=params['weight_init'], 691 | weight_range=params['weight_init_range'], 692 | weight_offset=params['weight_init_offset'], 693 | lr_scheduler=params['lr_scheduler'], 694 | bias=params['add_bias']) 695 | elif params['softness'] == 'hardkrotov': 696 | layer = HebbHardKrotovLinear( 697 | in_features=params['in_channels'], 698 | n_neurons=params['out_channels'], 699 | lebesgue_p=params['lebesgue_p'], 700 | weight_distribution=params['weight_init'], 701 | weight_range=params['weight_init_range'], 702 | weight_offset=params['weight_init_offset'], 703 | lr_scheduler=params['lr_scheduler'], 704 | bias=params['add_bias'], 705 | delta=params['delta'], 706 | ranking_param=params['ranking_param']) 707 | elif params['softness'] == 'soft': 708 | layer = HebbSoftLinear( 709 | in_features=params['in_channels'], 710 | n_neurons=params['out_channels'], 711 | lebesgue_p=params['lebesgue_p'], 712 | weight_distribution=params['weight_init'], 713 | weight_range=params['weight_init_range'], 714 | weight_offset=params['weight_init_offset'], 715 | lr_scheduler=params['lr_scheduler'], 716 | lr_bias=params['lr_bias'], 717 | bias=params['add_bias'], 718 | activation_fn=params['soft_activation_fn'], 719 | t_invert=params['t_invert']) 720 | elif params['softness'] == 'softkrotov': 721 | layer = HebbSoftKrotovLinear( 722 | in_features=params['in_channels'], 723 | n_neurons=params['out_channels'], 724 | lebesgue_p=params['lebesgue_p'], 725 | weight_distribution=params['weight_init'], 726 | weight_range=params['weight_init_range'], 727 | weight_offset=params['weight_init_offset'], 728 | lr_scheduler=params['lr_scheduler'], 729 | lr_bias=params['lr_bias'], 730 | bias=params['add_bias'], 731 | delta=params['delta'], 732 | ranking_param=params['ranking_param'], 733 | activation_fn=params['soft_activation_fn'], 734 | t_invert=params['t_invert']) 735 | elif params['softness'] == 'supervisedsoftkrotov': 736 | layer = SupervisedSoftHebbLinear( 737 | in_features=params['in_channels'], 738 | n_neurons=params['out_channels'], 739 | lebesgue_p=params['lebesgue_p'], 740 | weight_distribution=params['weight_init'], 741 | weight_range=params['weight_init_range'], 742 | weight_offset=params['weight_init_offset'], 743 | lr_scheduler=params['lr_scheduler'], 744 | lr_bias=params['lr_bias'], 745 | bias=params['add_bias'], 746 | delta=params['delta'], 747 | ranking_param=params['ranking_param'], 748 | activation_fn=params['soft_activation_fn'], 749 | t_invert=params['t_invert']) 750 | else: 751 | raise ValueError 752 | return layer 753 | -------------------------------------------------------------------------------- /images/architecture.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuromorphicComputing/SoftHebb/f4252307633ba38a26ab52504f5fd40352f5a6fd/images/architecture.PNG -------------------------------------------------------------------------------- /layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | import torch.nn as nn 4 | 5 | try: 6 | from utils import RESULT 7 | except: 8 | from hebb.utils import RESULT 9 | 10 | import torch.nn.functional as F 
11 | from typing import Callable, List, Optional 12 | from hebblinear import select_linear_layer 13 | from hebbconv import select_Conv2d_layer 14 | from activation import get_activation 15 | import os.path as op 16 | import einops 17 | 18 | 19 | class AttDropout(nn.Dropout): 20 | def forward(self, input): 21 | if self.training: 22 | nb_channels = input.shape[1] 23 | std = input.std(1) 24 | pb = self.p * ((-std) / std.max() + 1) 25 | pb = einops.repeat(pb, 'b w h -> b k w h', k=nb_channels) 26 | dropout = torch.bernoulli(pb) 27 | input[dropout == 1] = 0 28 | return input 29 | else: 30 | return input 31 | 32 | 33 | class MaxNorm(nn.Module): 34 | def __init__(self, ) -> None: 35 | super(MaxNorm, self).__init__() 36 | 37 | def forward(self, input: Tensor) -> Tensor: 38 | shape = input.shape 39 | input_flat = input.reshape(shape[0], -1) 40 | 41 | return (input_flat / torch.unsqueeze(input_flat.amax(dim=1), 1)).view(shape) 42 | 43 | 44 | class NormNorm(nn.Module): 45 | def __init__(self) -> None: 46 | super(NormNorm, self).__init__() 47 | self.multiplier = 10 48 | 49 | def forward(self, input: Tensor) -> Tensor: 50 | return self.multiplier * nn.functional.normalize(input) 51 | 52 | def extra_repr(self) -> str: 53 | return 'multiplier={}'.format( 54 | self.multiplier 55 | ) 56 | 57 | 58 | ''' 59 | class batchstd1d(nn.Module): 60 | def __init__(self, num_features) -> None: 61 | super(NormNorm, self).__init__() 62 | self.num_features = num_features 63 | self.register_buffer('running_std', torch.ones(num_features)) 64 | self.register_buffer("num_batches_tracked", None) 65 | 66 | def forward(self, input: Tensor) -> Tensor: 67 | if self.training and self.track_running_stats: 68 | if self.num_batches_tracked is not None: # type: ignore[has-type] 69 | self.num_batches_tracked = self.num_batches_tracked + 1 70 | self.running_std = 71 | 72 | return self.multiplier*nn.functional.normalize(input) 73 | 74 | def extra_repr(self) -> str: 75 | return 'multiplier={}'.format( 76 | self.multiplier 77 | ) 78 | ''' 79 | 80 | 81 | class BasicBlock(nn.Module): 82 | def __init__( 83 | self, 84 | arch: str, 85 | preset: str, 86 | num: int, 87 | in_channels: int, 88 | hebbian: bool, 89 | layer: nn.Module, 90 | resume: str = None, 91 | activation: Callable = None, 92 | operations: List = None, 93 | pool: Optional[nn.Module] = None, 94 | batch_norm: Optional[nn.Module] = None, 95 | dropout: Optional[nn.Module] = None, 96 | att_dropout: Optional[nn.Module] = None, 97 | ) -> None: 98 | super(BasicBlock, self).__init__() 99 | self.arch = arch 100 | self.preset = preset 101 | self.num = num 102 | self.in_channels = in_channels 103 | self.operations = operations 104 | self.layer = layer 105 | self.pool = pool 106 | self.batch_norm = batch_norm 107 | self.activation = activation 108 | self.dropout = dropout 109 | self.att_dropout = att_dropout 110 | self.hebbian = hebbian 111 | self.resume = resume 112 | if resume is not None: 113 | self.resume_block() 114 | 115 | def get_name(self): 116 | s = '' 117 | for operation in self.operations: 118 | s += operation.__class__.__name__ 119 | if self.dropout is not None: 120 | s += self.dropout.__str__() 121 | if self.is_hebbian(): 122 | s += self.layer.__label__() 123 | else: 124 | s += self.layer.__str__()[:10] 125 | if self.batch_norm is not None: 126 | s += self.batch_norm.__class__.__name__ 127 | return s 128 | 129 | def is_hebbian(self): 130 | return self.hebbian 131 | 132 | def radius(self): 133 | return self.layer.radius() 134 | 135 | def get_lr(self): 136 | return 
self.layer.get_lr() 137 | 138 | def foward_x_wta(self, x: Tensor) -> Tensor: 139 | x = self.operations(x) 140 | return self.layer(x, return_x_wta=True) 141 | 142 | def update(self): 143 | if self.is_hebbian(): 144 | self.layer.update() 145 | 146 | def sequential(self): 147 | elements = [] 148 | if self.att_dropout is not None: 149 | elements.append(self.att_dropout) 150 | 151 | if self.operations: 152 | elements.append(self.operations) 153 | 154 | if self.dropout is not None: 155 | elements.append(self.dropout) 156 | 157 | elements.append(self.layer) 158 | 159 | if self.activation is not None: 160 | elements.append(self.activation) 161 | 162 | if self.batch_norm is not None: 163 | elements.append(self.batch_norm) 164 | 165 | if self.pool is not None: 166 | elements.append(self.pool) 167 | 168 | return nn.Sequential(*elements) 169 | 170 | def forward(self, x: Tensor) -> Tensor: 171 | # print('*'*(self.num+1), x.mean()) 172 | # torch.cuda.empty_cache() 173 | 174 | if self.att_dropout is not None: 175 | x = self.att_dropout(x) 176 | 177 | x = self.operations(x) 178 | 179 | if self.dropout is not None: 180 | x = self.dropout(x) 181 | 182 | # x = self.layer(x.detach()) if self.is_hebbian() else self.layer(x) 183 | x = self.layer(x) 184 | 185 | if self.activation is not None: 186 | x = self.activation(x) 187 | 188 | if self.batch_norm is not None: 189 | x = self.batch_norm(x) 190 | 191 | if self.pool is not None: 192 | x = self.pool(x) 193 | 194 | # torch.cuda.empty_cache() 195 | return x 196 | 197 | def __str__(self): 198 | print('\n', '----- Architecture Block %s, number %s -----' % (self.get_name(), self.num)) 199 | if self.att_dropout is not None: 200 | print('-', self.att_dropout.__str__()) 201 | for operation in self.operations: 202 | print('-', operation.__str__()) 203 | if self.dropout is not None: 204 | print('-', self.dropout.__str__()) 205 | print('-', self.layer.__str__()) 206 | 207 | if self.activation is not None: 208 | print('-', self.activation.__str__()) 209 | if self.batch_norm is not None: 210 | print('-', self.batch_norm.__str__()) 211 | if self.pool is not None: 212 | print('-', self.pool.__str__()) 213 | if self.resume is not None: 214 | print('***', self.resume) 215 | 216 | def resume_block(self, device: str = 'cpu'): 217 | model_path = op.join(RESULT, 'layer', 'block%s' % self.num, self.get_name(), 'checkpoint.pth.tar') 218 | if op.isfile(model_path): 219 | try: 220 | checkpoint = torch.load(model_path, map_location=device) 221 | self.load_state_dict(checkpoint['state_dict']) 222 | self.resume = 'Block %s loaded successfuly' % self.get_name() 223 | except Exception as e: 224 | self.resume = 'File %s exist but %s' % (self.get_name(), e) 225 | else: 226 | self.resume = 'Block %s not found' % self.get_name() 227 | 228 | 229 | def generate_block(params) -> BasicBlock: 230 | """ 231 | 232 | Parameters 233 | ---------- 234 | params 235 | 236 | Returns 237 | ------- 238 | 239 | """ 240 | config = params['layer'] 241 | 242 | pool = None 243 | batch_norm = None 244 | operations = [] 245 | activation = None 246 | dropout = None 247 | att_dropout = None 248 | 249 | if 'operation' in params: 250 | if 'batchnorm2d' in params['operation']: 251 | if config['arch'] == 'MLP': 252 | if 'flatten' in params['operation']: 253 | operations.append(nn.BatchNorm2d(config['old_channels'], affine=False)) 254 | else: 255 | operations.append(nn.BatchNorm1d(config['in_channels'], affine=False)) 256 | else: 257 | operations.append(nn.BatchNorm2d(config['in_channels'], affine=False)) 258 | 259 | if 
'flatten' in params['operation']: 260 | operations.append(nn.Flatten()) 261 | 262 | if 'batchnorm1d' in params['operation']: 263 | if config['arch'] == 'MLP': 264 | operations.append(nn.BatchNorm1d(config['in_channels'], affine=False)) 265 | elif 'flatten' in params['operation']: 266 | operations.append(nn.BatchNorm1d(config['in_channels'], affine=False)) 267 | 268 | if 'max' in params['operation']: 269 | operations.append(MaxNorm()) 270 | if 'normnorm' in params['operation']: 271 | operations.append(NormNorm()) 272 | 273 | if config['arch'] == 'MLP': 274 | 275 | if config['hebbian']: 276 | layer = select_linear_layer(config) 277 | else: 278 | layer = nn.Linear(config['in_channels'], config['out_channels']) 279 | if 'batch_norm' in params and params['batch_norm']: 280 | batch_norm = nn.BatchNorm1d(config['out_channels'], affine=False) 281 | 282 | elif config['arch'] == 'CNN': 283 | if config['hebbian']: 284 | layer = select_Conv2d_layer(config) 285 | else: 286 | layer = nn.Conv2d( 287 | config['in_channels'], 288 | config['out_channels'], 289 | bias=config['add_bias'], 290 | kernel_size=config['kernel_size'], 291 | stride=config['stride'], 292 | padding=config['padding'], 293 | padding_mode=config['padding_mode'], 294 | dilation=config['dilation'], 295 | groups=config['groups'] 296 | ) 297 | if params['pool'] is not None: 298 | if params['pool']['type'] == 'max': 299 | pool = nn.MaxPool2d(kernel_size=params['pool']['kernel_size'], stride=params['pool']['stride'], 300 | padding=params['pool']['padding']) 301 | elif params['pool']['type'] == 'avg': 302 | pool = nn.AvgPool2d(kernel_size=params['pool']['kernel_size'], stride=params['pool']['stride'], 303 | padding=params['pool']['padding']) 304 | 305 | if 'batch_norm' in params and params['batch_norm']: 306 | batch_norm = nn.BatchNorm2d(config['out_channels'], affine=False) 307 | 308 | if params['activation'] is not None: 309 | activation = get_activation( 310 | activation_fn=params['activation']['function'], 311 | param=params['activation']['param'], 312 | dim=1) 313 | 314 | if 'dropout' in params and isinstance(params['dropout'], float): 315 | dropout = nn.Dropout(p=params['dropout']) 316 | 317 | if 'att_dropout' in params and isinstance(params['att_dropout'], float): 318 | att_dropout = AttDropout(p=params['att_dropout']) 319 | 320 | block = BasicBlock( 321 | arch=params['arch'], 322 | preset=params['preset'], 323 | num=params['num'], 324 | in_channels=config['in_channels'], 325 | hebbian=config['hebbian'], 326 | layer=layer, 327 | resume=None if "resume" not in params else params['resume'], 328 | activation=activation, 329 | operations=nn.Sequential(*operations), 330 | pool=pool, 331 | batch_norm=batch_norm, 332 | dropout=dropout, 333 | att_dropout=att_dropout 334 | ) 335 | return block 336 | -------------------------------------------------------------------------------- /log.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | 4 | try: 5 | from utils import RESULT 6 | except: 7 | from hebb.utils import RESULT 8 | import torch 9 | import os 10 | import os.path as op 11 | 12 | 13 | class LogSupBatch: 14 | def __init__(self): 15 | self.data = [] 16 | self.metric_dict = {0: 'n', 1: 'train_loss', 2: 'train_acc', 3: 'convergence', 4: 'R1', 5: 'lr'} 17 | 18 | def step(self, n, train_loss, train_acc, convergence=0, R1=0, lr=1): 19 | self.data.append([n, train_loss, train_acc, convergence, R1, lr]) 20 | 21 | def get_summary(self): 22 | if self.data: 23 | data_np = 
self.get_numpy() 24 | nb_sample = data_np[:, 0].sum() 25 | mean_loss = data_np[:, 1].sum() / nb_sample 26 | mean_acc = data_np[:, 2].sum() / nb_sample 27 | return mean_loss, 100 * mean_acc 28 | else: 29 | return 0, 0 30 | 31 | def get_numpy(self): 32 | return np.array(self.data, dtype=np.float32) 33 | 34 | def reset(self): 35 | self.__init__() 36 | 37 | def to_dict(self): 38 | return {'data': self.data} 39 | 40 | def from_dict(self, dict_): 41 | self.data = dict_['data'] 42 | return self 43 | 44 | 45 | class LogSup: 46 | def __init__(self, config): 47 | self.metric_dict = {0: 'train_loss', 1: 'train_acc', 2: 'test_loss', 3: 'test_acc', 4: 'convergence', 5: 'R1'} 48 | self.metric = '' 49 | self.mode = '' 50 | self.batch = [] 51 | self.initial_start = time.time() 52 | self.start = self.initial_start 53 | self.convergence = None 54 | self.epoch_time = 0 55 | if config is not None: 56 | self.epoch = 0 57 | self.lr = config['lr'] 58 | self.nb_epoch = config['nb_epoch'] 59 | self.print_freq = config['print_freq'] 60 | self.data = [] 61 | self.metric = 'test_acc' 62 | self.metric_id = {value: key for key, value in self.metric_dict.items()}[self.metric] + 1 # the data rows are prefixed with the epoch, so shift the metric index by one 63 | self.mode = 'min' if self.metric.endswith('loss') else 'max' 64 | self.best_perf = 0 if self.mode == 'max' else 100 65 | self.perf = self.best_perf 66 | self.is_best = True 67 | 68 | def step(self, epoch, logbatch, test_loss, test_acc, lr, save=False): 69 | 70 | train_loss, train_acc = logbatch.get_summary() 71 | self.lr = float(lr) 72 | self.data.append([int(epoch), float(train_loss), float(train_acc), float(test_loss), float(test_acc)]) 73 | self.perf = self.data[-1][self.metric_id] 74 | self.is_best = self.perf > self.best_perf if self.mode == 'max' else self.perf < self.best_perf 75 | self.best_perf = max(self.perf, self.best_perf) if self.mode == 'max' else min(self.perf, self.best_perf) 76 | if save: 77 | self.batch.append(logbatch) 78 | self.epoch_time = time.time() - self.start 79 | self.start = time.time() 80 | return self.new_log_batch() 81 | 82 | def new_log_batch(self): 83 | return LogSupBatch() 84 | 85 | def verbose(self): 86 | epoch, train_loss, train_acc, test_loss, test_acc = self.data[-1] 87 | 88 | print('Epoch: [{0}/{1}]\t' 89 | 'lr: {lr:.2e}\t' 90 | 'time: {total_time}\t' 91 | 'Loss_train {train_loss:.5f}\t' 92 | 'Acc_train {train_acc:.2f}\t/\t' 93 | 'Loss_test {test_loss:.5f}\t' 94 | 'Acc_test {test_acc:.2f}' 95 | .format(epoch, self.nb_epoch, lr=self.lr, time=self.epoch_time, 96 | total_time=time.strftime("%H:%M:%S", time.gmtime(time.time() - self.initial_start)), 97 | train_acc=train_acc, train_loss=train_loss, 98 | test_loss=test_loss, test_acc=test_acc)) 99 | 100 | def get_numpy(self): 101 | return np.array(self.data, dtype=np.float32) 102 | 103 | def to_dict(self): 104 | return {'data': self.data, 105 | 'metric': self.metric, 106 | 'best_perf': self.best_perf, 107 | 'mode': self.mode, 108 | 'batch': [b.to_dict() for b in self.batch] 109 | } 110 | 111 | def from_dict(self, dict_): 112 | self.data = dict_['data'] 113 | self.batch = [LogSupBatch().from_dict(d) for d in dict_['batch']] 114 | self.batch.append(LogSupBatch()) 115 | self.best_perf = dict_['best_perf'] 116 | if dict_['metric'] != self.metric: 117 | if self.mode == 'max': 118 | self.best_perf = self.get_numpy()[:, self.metric_id].max() 119 | elif self.mode == 'min': 120 | self.best_perf = self.get_numpy()[:, self.metric_id].min() 121 | return self 122 | 123 | 124 | class LogUnsup(LogSup): 125 | def __init__(self, config): 126 | super().__init__(config) 127 |
self.metric_dict = {0: 'train_acc', 1: 'test_acc', 2: 'convergence', 3: 'R1'} 128 | self.nb_epoch = config['nb_epoch'] if config is not None else 0 129 | self.metric = 'test_acc' 130 | self.metric_id = 2 # index of test_acc in the data rows, whose first column is the epoch 131 | self.mode = 'max' 132 | self.best_perf = 0 # the mode is 'max', so start from the lowest possible accuracy 133 | self.perf = self.best_perf 134 | self.is_best = True 135 | self.info = '' 136 | 137 | def step(self, epoch, train_acc, test_acc, info, convergence, R1, lr): 138 | self.lr = float(lr) 139 | self.data.append([int(epoch), float(train_acc), float(test_acc), float(convergence), int(R1)]) 140 | self.info = info 141 | self.perf = self.data[-1][self.metric_id] 142 | self.is_best = self.perf > self.best_perf if self.mode == 'max' else self.perf < self.best_perf 143 | self.best_perf = max(self.perf, self.best_perf) if self.mode == 'max' else min(self.perf, self.best_perf) 144 | self.epoch_time = time.time() - self.start 145 | self.start = time.time() 146 | 147 | def verbose(self): 148 | epoch, train_acc, test_acc, convergence, R1 = self.data[-1] 149 | print('Epoch: [{0}/{1}]\t' 150 | 'lr: {lr:.2e}\t' 151 | 'time: {total_time}\t' 152 | 'Acc_train {train_acc:.2f}\t' 153 | 'Acc_test {test_acc:.2f}\t' 154 | 'convergence: {convergence:.2e}\t' 155 | 'R1: {R1}\t' 156 | 'Info {info}' 157 | .format(epoch, self.nb_epoch, lr=self.lr, time=self.epoch_time, 158 | total_time=time.strftime("%H:%M:%S", time.gmtime(time.time() - self.initial_start)), 159 | train_acc=train_acc, convergence=convergence, R1=R1, info=self.info, test_acc=test_acc)) 160 | 161 | def get_numpy(self): 162 | return np.array(self.data, dtype=np.float32) 163 | 164 | 165 | class Log: 166 | def __init__(self, configs={}): 167 | self.sup = {} 168 | self.unsup = {} 169 | for id, config in configs.items(): 170 | if config['mode'] == 'unsupervised': 171 | self.unsup[id] = LogUnsup(config) 172 | else: 173 | self.sup[id] = LogSup(config) 174 | 175 | def to_dict(self): 176 | return {'sup': {id: sup.to_dict() for id, sup in self.sup.items()}, 177 | 'unsup': {id: unsup.to_dict() for id, unsup in self.unsup.items()}} 178 | 179 | def from_dict(self, dict_): 180 | self.sup = {} 181 | self.unsup = {} 182 | for id, config in dict_['sup'].items(): 183 | self.sup[id] = LogSup(None).from_dict(config) 184 | for id, config in dict_['unsup'].items(): 185 | self.unsup[id] = LogUnsup(None).from_dict(config) 186 | return self 187 | 188 | 189 | def save_logs(log, model_name, filename='final.pth.tar'): 190 | folder_path = op.join(RESULT, 'network', model_name, 'measures') 191 | if not op.isdir(folder_path): 192 | os.makedirs(folder_path) # makedirs: the parent directories may not exist yet 193 | 194 | torch.save({ 195 | 'log': log.to_dict() 196 | }, op.join(folder_path, filename)) 197 | 198 | 199 | def load_logs(model_name, filename='final.pth.tar'): 200 | folder_path = op.join(RESULT, 'network', model_name, 'measures') 201 | if not op.isdir(folder_path): 202 | os.makedirs(folder_path) 203 | dict = torch.load(op.join(folder_path, filename))['log'] 204 | log = Log().from_dict(dict) 205 | return dict, log 206 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | try: 5 | from utils import RESULT, activation 6 | except: 7 | from hebb.utils import RESULT, activation 8 | from layer import generate_block 9 | import os 10 | import os.path as op 11 | 12 | 13 | def load_layers(params, model_name, resume=None, verbose=True, model_path_override=None): 14 | """ 15 | Create Model and load state if resume 16 | """ 17
| 18 | if resume is not None: 19 | if model_path_override is None: 20 | model_path = op.join(RESULT, 'network', model_name, 'models', 'checkpoint.pth.tar') 21 | else: 22 | model_path = model_path_override 23 | 24 | if op.isfile(model_path): 25 | checkpoint = torch.load(model_path) # , map_location=device) 26 | state_dict = checkpoint['state_dict'] 27 | params2 = checkpoint['config'] 28 | if resume == 'without_classifier': 29 | classifier_key = list(params.keys())[-1] 30 | params2[classifier_key] = params[classifier_key] 31 | 32 | model = MultiLayer(params2) 33 | 34 | state_dict2 = model.state_dict() 35 | 36 | if resume == 'without_classifier': 37 | for key, value in state_dict.items(): 38 | if resume == 'without_classifier' and str(params[classifier_key]['num']) in key: 39 | continue 40 | if key in state_dict2: 41 | state_dict2[key] = value 42 | model.load_state_dict(state_dict2) 43 | else: 44 | model.load_state_dict(state_dict) 45 | # log.from_dict(checkpoint['measures']) 46 | starting_epoch = 0 # checkpoint['epoch'] 47 | print('\n', 'Model %s loaded successfuly with best perf' % (model_name)) 48 | # shutil.rmtree(op.join(RESULT, params.folder_name, 'figures')) 49 | # os.mkdir(op.join(RESULT, params.folder_name, 'figures')) 50 | else: 51 | print('\n', 'Model %s not found' % model_name) 52 | model = MultiLayer(params) 53 | print('\n') 54 | else: 55 | model = MultiLayer(params) 56 | 57 | if verbose: 58 | model.__str__() 59 | 60 | return model 61 | 62 | 63 | def save_layers(model, model_name, epoch, blocks, filename='checkpoint.pth.tar', storing_path=None): 64 | """ 65 | Save model and each of its training blocks 66 | """ 67 | if storing_path is None: 68 | if not op.isdir(RESULT): 69 | os.makedirs(RESULT) 70 | if not op.isdir(op.join(RESULT, 'network')): 71 | os.mkdir(op.join(RESULT, 'network')) 72 | os.mkdir(op.join(RESULT, 'layer')) 73 | 74 | folder_path = op.join(RESULT, 'network', model_name) 75 | if not op.isdir(folder_path): 76 | os.makedirs(op.join(folder_path, 'models')) 77 | storing_path = op.join(folder_path, 'models') 78 | 79 | 80 | torch.save({ 81 | 'state_dict': model.state_dict(), 82 | 'config': model.config, 83 | 'epoch': epoch 84 | }, op.join(storing_path, filename)) 85 | 86 | for block_id in blocks: 87 | block = model.get_block(block_id) 88 | block_path = op.join(RESULT, 'layer', 'block%s' % block.num) 89 | if not op.isdir(block_path): 90 | os.makedirs(block_path) 91 | folder_path = op.join(block_path, block.get_name()) 92 | if not op.isdir(folder_path): 93 | os.mkdir(folder_path) 94 | torch.save({ 95 | 'state_dict': block.state_dict(), 96 | 'epoch': epoch 97 | }, op.join(folder_path, filename)) 98 | 99 | 100 | class MultiLayer(nn.Module): 101 | """ 102 | MultiLayer Network created from list of preset blocks 103 | """ 104 | 105 | def __init__(self, blocks_params: dict, blocks: nn.Module = None) -> None: 106 | super().__init__() 107 | self.train_mode = None 108 | self.train_blocks = [] 109 | 110 | self.config = blocks_params 111 | if blocks_params is not None: 112 | blocks = [] 113 | for _, params in blocks_params.items(): 114 | blocks.append(generate_block(params)) 115 | self.blocks = nn.Sequential(*blocks) 116 | else: 117 | self.blocks = nn.Sequential(*blocks) 118 | 119 | def foward_x_wta(self, x: torch.Tensor) -> torch.Tensor: 120 | for id, block in self.generator_block(): 121 | if id != len(self.blocks) - 1: 122 | x = block(x) 123 | else: 124 | return block.foward_x_wta(x) 125 | 126 | def forward(self, x: torch.Tensor) -> torch.Tensor: 127 | x = self.blocks(x) 128 | 
129 | 
130 |     def get_block(self, id):
131 |         return self.blocks[id]
132 | 
133 |     def sub_model(self, block_ids):
134 |         sub_blocks = []
135 |         max_id = max(block_ids)
136 |         for id, block in self.generator_block():
137 |             sub_blocks.append(self.get_block(id))
138 |             if id == max_id:
139 |                 break
140 | 
141 |         return MultiLayer(None, sub_blocks)
142 | 
143 |     def is_hebbian(self) -> bool:
144 |         """
145 |         Return True if the last block of the model is Hebbian
146 |         """
147 |         return self.blocks[-1].is_hebbian()
148 | 
149 |     def get_lr(self) -> float:
150 |         """
151 |         Return the lr of the last trained Hebbian block
152 |         """
153 |         if self.train_blocks:
154 |             for i in reversed(self.train_blocks):
155 |                 if self.blocks[i].is_hebbian():  # train_blocks holds block indices, so index directly (cf. radius())
156 |                     return self.blocks[i].get_lr()
157 |         if self.blocks[0].is_hebbian():
158 |             return self.blocks[0].get_lr()
159 |         return 0
160 | 
161 |     def radius(self, layer=None) -> str:
162 |         """
163 |         Return the radius report of the trained Hebbian blocks (or of the first block)
164 |         """
165 |         if layer is not None:
166 |             return self.blocks[layer].radius()
167 |         if self.train_blocks:
168 |             r = []
169 |             for i in reversed(self.train_blocks):
170 |                 if self.blocks[i].is_hebbian():
171 |                     r.append(self.blocks[i].radius())
172 |             return '\n ************************************************************** \n'.join(r)
173 |         if self.blocks[0].is_hebbian():
174 |             return self.blocks[0].radius()
175 |         return ''
176 | 
177 |     def convergence(self):
178 |         """
179 |         Return the convergence measures (conv, R1) of the last Hebbian block
180 |         """
181 |         for i in range(1, len(self.blocks) + 1):
182 |             if self.blocks[-i].is_hebbian():
183 |                 return self.blocks[-i].layer.convergence()
184 |         return 0, 0
185 | 
186 |     def reset(self):
187 |         if self.blocks[0].is_hebbian():
188 |             self.blocks[0].layer.reset()
189 | 
190 |     def generator_block(self):
191 |         for id, block in enumerate(self.blocks):
192 |             yield id, block
193 | 
194 |     def update(self):
195 |         for block in self.train_blocks:
196 |             self.get_block(block).update()
197 | 
198 |     def __str__(self):
199 |         for _, block in self.generator_block():
200 |             block.__str__()
201 | 
202 |     def train(self, mode=True, blocks=[]):
203 |         """
204 |         Set the learning update to the expected mode.
205 |         mode:True, BP:False, HB:True --> training Hebbian layer
206 |         mode:True, BP:True, HB:False --> training fc
207 |         mode:True, BP:True, HB:True --> training Hebbian + fc blocks
208 |         mode:False --> predict
209 |         """
210 |         self.training = mode
211 |         self.train_blocks = blocks
212 |         # print('train mode %s and layer %s'%(mode, blocks))
213 | 
214 |         for param in self.parameters():
215 |             param.requires_grad = False
216 |         for _, block in self.generator_block():
217 |             block.eval()
218 | 
219 |         for block in blocks:
220 |             module = self.get_block(block)
221 | 
222 |             module.train(mode)
223 |             for param in module.parameters():
224 |                 param.requires_grad = True
225 | 
226 | 
227 | class HebbianOptimizer:
228 |     def __init__(self, model):
229 |         """Custom optimizer that delegates the weight updates of unsupervised (Hebbian) layers to those layers themselves.
230 | 
231 |         Args:
232 |             model (torch.nn.Module): Pytorch model
233 |         """
234 |         self.model = model
235 |         self.param_groups = []
236 | 
237 |     @torch.no_grad()
238 |     def step(self, *args):
239 |         """Performs a single optimization step.
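
        Each Hebbian block computes and stores its own weight update during its
        forward pass, so step() only asks those blocks to apply it; parameters of
        gradient-trained blocks are left to a regular optimizer (both are combined
        through AggregateOptim below).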
240 | """ 241 | loss = None 242 | 243 | for block in self.model.blocks: 244 | if block.is_hebbian(): 245 | block.update(*args) 246 | 247 | def zero_grad(self): 248 | pass 249 | 250 | 251 | class AggregateOptim: 252 | def __init__(self, optimizers): 253 | """Custom optimizer aggregating several optimizers together to run simulaneously 254 | 255 | Args: 256 | optimizers (List[torch.autograd.optim.Optimizer]): List of optimizers which need to be called simultaneously 257 | """ 258 | self.optimizers = optimizers 259 | self.param_groups = [] 260 | for optim in self.optimizers: 261 | self.param_groups.extend(optim.param_groups) 262 | 263 | def __repr__(self): 264 | representations = [] 265 | for optim in self.optimizers: 266 | representations.append(repr(optim)) 267 | return '\n'.join(representations) 268 | 269 | def step(self): 270 | for optim in self.optimizers: 271 | optim.step() 272 | 273 | def zero_grad(self): 274 | for optim in self.optimizers: 275 | optim.zero_grad() 276 | -------------------------------------------------------------------------------- /multi_layer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from utils import load_presets, get_device, load_config_dataset, seed_init_fn, str2bool 4 | from model import load_layers 5 | from train import run_sup, run_unsup, check_dimension, training_config, run_hybrid 6 | from log import Log, save_logs 7 | import warnings 8 | 9 | warnings.filterwarnings("ignore") 10 | 11 | parser = argparse.ArgumentParser(description='Multi layer Hebbian Training') 12 | 13 | parser.add_argument('--preset', choices=load_presets(), default=None, 14 | type=str, help='Preset of hyper-parameters ' + 15 | ' | '.join(load_presets()) + 16 | ' (default: None)') 17 | 18 | parser.add_argument('--dataset-unsup', choices=load_config_dataset(), default='MNIST', 19 | type=str, help='Dataset possibilities ' + 20 | ' | '.join(load_config_dataset()) + 21 | ' (default: MNIST)') 22 | 23 | parser.add_argument('--dataset-sup', choices=load_config_dataset(), default='MNIST', 24 | type=str, help='Dataset possibilities ' + 25 | ' | '.join(load_config_dataset()) + 26 | ' (default: MNIST)') 27 | 28 | parser.add_argument('--training-mode', choices=['successive', 'consecutive', 'simultaneous'], default='successive', 29 | type=str, help='Training possibilities ' + 30 | ' | '.join(['successive', 'consecutive', 'simultaneous']) + 31 | ' (default: successive)') 32 | 33 | parser.add_argument('--resume', choices=[None, "all", "without_classifier"], default=None, 34 | type=str, help='Resume Model ' + 35 | ' | '.join(["best", "last"]) + 36 | ' (default: None)') 37 | 38 | parser.add_argument('--model-name', default=None, type=str, help='Model Name') 39 | 40 | parser.add_argument('--training-blocks', default=None, nargs='+', type=int, 41 | help='Selection of the blocks that will be trained') 42 | 43 | parser.add_argument('--seed', default=None, type=int, 44 | help='Selection of the blocks that will be trained') 45 | 46 | parser.add_argument('--gpu-id', default=0, type=int, metavar='N', 47 | help='Id of gpu selected for training (default: 0)') 48 | 49 | parser.add_argument('--save', default=True, type=str2bool, metavar='N', 50 | help='') 51 | 52 | parser.add_argument('--validation', default=False, type=str2bool, metavar='N', 53 | help='') 54 | 55 | 56 | def main(blocks, name_model, resume, save, dataset_sup_config, dataset_unsup_config, train_config, gpu_id): 57 | device = get_device(gpu_id) 58 | model = load_layers(blocks, 
name_model, resume) 59 | 60 | model = model.to(device) 61 | 62 | log = Log(train_config) 63 | 64 | for id, config in train_config.items(): 65 | if config['mode'] == 'unsupervised': 66 | run_unsup( 67 | config['nb_epoch'], 68 | config['print_freq'], 69 | config['batch_size'], 70 | name_model, 71 | dataset_unsup_config, 72 | model, 73 | device, 74 | log.unsup[id], 75 | blocks=config['blocks'], 76 | save=save 77 | ) 78 | elif config['mode'] == 'supervised': 79 | run_sup( 80 | config['nb_epoch'], 81 | config['print_freq'], 82 | config['batch_size'], 83 | config['lr'], 84 | name_model, 85 | dataset_sup_config, 86 | model, 87 | device, 88 | log.sup[id], 89 | blocks=config['blocks'], 90 | save=save 91 | ) 92 | else: 93 | run_hybrid( 94 | config['nb_epoch'], 95 | config['print_freq'], 96 | config['batch_size'], 97 | config['lr'], 98 | name_model, 99 | dataset_sup_config, 100 | model, 101 | device, 102 | log.sup[id], 103 | blocks=config['blocks'], 104 | save=save 105 | ) 106 | 107 | save_logs(log, name_model) 108 | 109 | 110 | if __name__ == '__main__': 111 | params = parser.parse_args() 112 | name_model = params.preset if params.model_name is None else params.model_name 113 | blocks = load_presets(params.preset) 114 | dataset_sup_config = load_config_dataset(params.dataset_sup, params.validation) 115 | dataset_unsup_config = load_config_dataset(params.dataset_unsup, params.validation) 116 | if params.seed is not None: 117 | dataset_sup_config['seed'] = params.seed 118 | dataset_unsup_config['seed'] = params.seed 119 | 120 | if dataset_sup_config['seed'] is not None: 121 | seed_init_fn(dataset_sup_config['seed']) 122 | 123 | blocks = check_dimension(blocks, dataset_sup_config) 124 | 125 | train_config = training_config(blocks, dataset_sup_config, dataset_unsup_config, params.training_mode, 126 | params.training_blocks) 127 | 128 | main(blocks, name_model, params.resume, params.save, dataset_sup_config, dataset_unsup_config, train_config, 129 | params.gpu_id) 130 | -------------------------------------------------------------------------------- /nb_utils.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import json 3 | import os 4 | import os.path as op 5 | from typing import Union, Any 6 | 7 | import eagerpy as ep 8 | import foolbox 9 | from foolbox import PyTorchModel 10 | from foolbox.attacks.base import T, get_criterion, raise_if_kwargs 11 | from foolbox.attacks.gradient_descent_base import BaseGradientDescent, normalize_lp_norms, uniform_l2_n_balls 12 | from foolbox.criteria import Misclassification, TargetedMisclassification 13 | from foolbox.devutils import flatten 14 | from foolbox.distances import l2 15 | from foolbox.models.base import Model 16 | from foolbox.types import Bounds 17 | import matplotlib.pyplot as plt 18 | from matplotlib.colors import ListedColormap 19 | from matplotlib.patches import Rectangle 20 | import numpy as np 21 | import pandas as pd 22 | from sklearn.decomposition import PCA 23 | import torch 24 | import umap 25 | 26 | from model import load_layers 27 | from log import load_logs 28 | from utils import SEARCH 29 | 30 | 31 | def find_checkpoint(path): 32 | for exp in os.listdir(path): 33 | exp_path = os.path.join(path, exp) 34 | if os.path.isdir(exp_path) and os.path.isfile(os.path.join(exp_path, 'checkpoint.pth.tar')): 35 | return exp_path 36 | print('No checkpoint found') 37 | return None 38 | 39 | 40 | def resume_model(preset, model_path_override=None): 41 | model = load_layers([], preset, 'last', 
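                        # note: any non-None `resume` value reloads the stored checkpoint, and the block
                        # configuration saved inside it is used, so `[]` can be passed as params here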
verbose=False, model_path_override=model_path_override)
42 |     return model
43 | 
44 | def get_results(path, metric='test_acc'):
45 |     seeds_results = []
46 |     for exp in os.listdir(path):
47 |         exp_path = os.path.join(path, exp)
48 |         if os.path.isdir(exp_path):
49 |             with open(os.path.join(exp_path, 'progress.csv')) as f:
50 |                 results = csv.DictReader(f)
51 |                 for row in results:
52 |                     result = float(row[metric])
53 |             seeds_results.append(result)
54 |     assert len(seeds_results) > 0, 'Experiment results not found'
55 |     return seeds_results
56 | 
57 | def get_mean_std(path, metric='test_acc'):
58 |     seeds_results = get_results(path, metric)
59 |     return np.mean(seeds_results), np.std(seeds_results)
60 | 
61 | 
62 | class ChooseNeuronFlat(torch.nn.Module):
63 |     def __init__(self, idx):
64 |         super(ChooseNeuronFlat, self).__init__()
65 |         self.idx = idx
66 | 
67 |     def forward(self, x):
68 |         return x[:, self.idx].view(-1)
69 | 
70 | 
71 | class SingleMax(foolbox.criteria.Criterion):
72 |     def __init__(self, max_val: float, eps: float):
73 |         super().__init__()
74 |         self.max_val: float = max_val
75 |         self.eps: float = eps
76 | 
77 |     def __repr__(self) -> str:
78 |         return f"{self.__class__.__name__}(Max={self.max_val}, eps={self.eps})"
79 | 
80 |     def __call__(self, perturbed: T, outputs: T) -> T:
81 |         outputs_, restore_type = ep.astensor_(outputs)
82 |         del perturbed, outputs
83 |         if self.max_val is None:
84 |             is_adv = (outputs_ > 0).logical_or(outputs_ <= 0)  # always True: every input counts as adversarial when no target value is set (plain `or` is ambiguous on tensors)
85 |         else:
86 |             is_adv = (self.max_val - outputs_) < self.eps
87 |         return restore_type(is_adv)
88 | 
89 | 
90 | def attack_model(model, data, labels, epsilons, iterations, random_start=True, nb_start=20, step_size=0.01 / 0.3,
91 |                  criterion=None, device=None):
92 |     attack = L2ProjGradientDescent(steps=iterations, random_start=random_start, rel_stepsize=step_size)
93 |     bounds = (0, 1)
94 |     fmodel = PyTorchModel(model, device=device, bounds=bounds)
95 |     raw_advs, clipped_advs, success = attack(model=fmodel, inputs=data, criterion=criterion, epsilons=epsilons)
96 |     raw_advs = torch.stack(raw_advs)
97 |     clipped_advs = torch.stack(clipped_advs)
98 |     # For debugging
99 |     # print('Raw : x norm', raw_advs.norm().item(), 'x max', raw_advs.max().item(), 'x min', raw_advs.min().item())
100 |     # print('Clipped : x norm', clipped_advs.norm().item(), 'x max', clipped_advs.max().item(), 'x min', clipped_advs.min().item())
101 |     for start in range(1, nb_start):
102 |         new_raw_advs, new_clipped_advs, new_success = attack(model=fmodel, inputs=data, criterion=criterion,
103 |                                                              epsilons=epsilons)
104 |         new_raw_advs = torch.stack(new_raw_advs)
105 |         raw_advs[new_success == 1] = new_raw_advs[new_success == 1]
106 | 
107 |         new_clipped_advs = torch.stack(new_clipped_advs)
108 |         clipped_advs[new_success == 1] = new_clipped_advs[new_success == 1]
109 |         success[new_success == 1] = 1
110 |     return raw_advs, clipped_advs, success
111 | 
112 | 
113 | class L2ProjGradientDescent(BaseGradientDescent):
114 |     distance = l2
115 | 
116 |     def get_random_start(self, x0: ep.Tensor, epsilon: float) -> ep.Tensor:
117 |         print('Getting random start')
118 |         batch_size, n = flatten(x0).shape
119 |         r = uniform_l2_n_balls(x0, batch_size, n).reshape(x0.shape)
120 |         return x0 + 0.00001 * epsilon * r
121 | 
122 |     def normalize(
123 |         self, gradients: ep.Tensor, *, x: ep.Tensor, bounds: Bounds
124 |     ) -> ep.Tensor:
125 |         # gradients are returned as-is (no lp normalization); see project() below
126 |         return gradients
127 | 
128 |     def project(self, x: ep.Tensor, x0: ep.Tensor, epsilon: float) -> ep.Tensor:
129 |         sphere = epsilon * normalize_lp_norms(x, p=2).abs()
130 |         sphere =
sphere - sphere.min() 131 | if not sphere.max() > 0 and not sphere.min() < 0: 132 | return sphere 133 | else: 134 | return sphere / sphere.max() 135 | 136 | def run( 137 | self, 138 | model: Model, 139 | inputs: T, 140 | criterion: Union[Misclassification, TargetedMisclassification, T], 141 | *, 142 | epsilon: float, 143 | **kwargs: Any, 144 | ) -> T: 145 | raise_if_kwargs(kwargs) 146 | x0, restore_type = ep.astensor_(inputs) 147 | criterion_ = get_criterion(criterion) 148 | del inputs, criterion, kwargs 149 | 150 | # perform a gradient ascent (targeted attack) or descent (untargeted attack) 151 | if isinstance(criterion_, Misclassification): 152 | gradient_step_sign = 1.0 153 | classes = criterion_.labels 154 | loss_fn = self.get_loss_fn(model, classes) 155 | elif hasattr(criterion_, "target_classes"): 156 | gradient_step_sign = -1.0 157 | classes = criterion_.target_classes # type: ignore 158 | loss_fn = self.get_loss_fn(model, classes) 159 | elif hasattr(criterion_, "max_val"): 160 | def loss_fn_max(inputs: ep.Tensor) -> ep.Tensor: 161 | out = model(inputs) 162 | return out 163 | 164 | loss_fn = loss_fn_max 165 | gradient_step_sign = 1.0 166 | else: 167 | raise ValueError("unsupported criterion") 168 | 169 | if self.abs_stepsize is None: 170 | stepsize = self.rel_stepsize * epsilon 171 | else: 172 | stepsize = self.abs_stepsize 173 | 174 | if self.random_start: 175 | x = self.get_random_start(x0, epsilon) 176 | x = ep.clip(x, *model.bounds) 177 | else: 178 | x = x0 179 | 180 | for i_s in range(self.steps): 181 | val, gradients = self.value_and_grad(loss_fn, x) 182 | gradients = self.normalize(gradients, x=x, bounds=model.bounds) 183 | x = x + gradient_step_sign * stepsize * gradients 184 | x = self.project(x, x0, epsilon) 185 | x = ep.clip(x, *model.bounds) 186 | return restore_type(x) 187 | 188 | 189 | def get_representations(model, data_loader, device, layer, n_inputs_max=np.inf, module=None, reload=True): 190 | print("Calculate representations...") 191 | inputs = [] 192 | reps = [] 193 | targets = [] 194 | n_inputs = 0 195 | for step, (img, target) in enumerate(data_loader): 196 | # print("batch number: ", step, " of ", len(data_loader)) 197 | residual_co = [] 198 | model_input = img.to(device) 199 | inputs.append(model_input.clone()) 200 | for l in range(layer): 201 | block = model.blocks[l] 202 | 203 | model_input = block(model_input) 204 | 205 | reps.append(model_input.detach()) 206 | 207 | targets.append(target) 208 | n_inputs = n_inputs + target.shape[0] 209 | if n_inputs >= n_inputs_max: 210 | break 211 | 212 | inputs = torch.cat(inputs).cpu() 213 | reps = torch.cat(reps).cpu() 214 | targets = torch.cat(targets).cpu() 215 | 216 | return inputs, reps, targets 217 | 218 | 219 | def rep_PCA(X, k): 220 | pca = PCA(n_components=k) 221 | pca.fit(X) 222 | print('PCA done and explained %s' % np.sum(pca.explained_variance_ratio_)) 223 | Y = pca.transform(X) 224 | return Y 225 | 226 | 227 | def rep_2d(inputs, reps, targets, n_points=None, apply_pca=True, avg_pixels=False, method='tsne'): 228 | n_samples = targets.shape[0] 229 | if n_points == None: 230 | n_points = n_samples 231 | 232 | d_inputs = inputs.reshape(n_samples, -1) 233 | 234 | if avg_pixels: 235 | reps = torch.mean(reps, (2, 3)) # spatial mean pooling 236 | 237 | d_reps = reps.reshape(n_samples, -1) 238 | 239 | if apply_pca: 240 | if d_reps.shape[1] > 100: 241 | d_reps = rep_PCA(d_reps, 100) 242 | 243 | print("Doing %s..." 
% method)
244 |     if method == 'tsne':  # must match the default argument of rep_2d
245 |         from sklearn import manifold  # local import: `manifold` is not imported in the module header
246 |         tsne_reps = manifold.TSNE(perplexity=50)
247 |         # t_inputs = tsne_inputs.fit_transform(d_inputs[:n_points,:])
248 |         t_reps = tsne_reps.fit_transform(d_reps[:n_points, :])
249 |     else:
250 |         reducer = umap.UMAP(random_state=42)
251 |         reducer.fit(d_reps[:n_points, :])
252 |         t_reps = reducer.transform(d_reps[:n_points, :])
253 | 
254 |     # tSNE_plot(opt, t_inputs, targets[:n_points], class_names, fig_name_ext = 'input')
255 |     return t_reps, targets[:n_points]
256 | 
257 | 
258 | def plot_2d(t_data, targets, class_names, plot_labels=True, title=None, xlim=None, ylim=None, marker_size=50,
259 |             no_border=False, STORE_GRAPH=None):
260 |     plt.close('all')
261 |     plt.rcParams.update({'font.size': 25,
262 |                          'lines.linewidth': 2,
263 |                          'lines.linestyle': '-',
264 |                          'lines.markersize': 8})
265 | 
266 |     colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22',
267 |               '#17becf']
268 |     # color_targets = [colors[t] for t in targets]
269 |     cmap2 = ListedColormap(colors)
270 | 
271 |     fig, ax = plt.subplots(figsize=(20, 15))
272 | 
273 |     scatter = plt.scatter(t_data[:, 0], t_data[:, 1], s=marker_size, c=targets, cmap=cmap2)
274 | 
275 |     if xlim is not None:
276 |         plt.xlim(xlim)
277 |     if ylim is not None:
278 |         plt.ylim(ylim)
279 |     if plot_labels:
280 |         plt.legend(handles=scatter.legend_elements()[0],
281 |                    labels=class_names.values(),
282 |                    prop={'size': 18})
283 |     if title is not None:
284 |         if no_border:
285 |             fig.patch.set_visible(False)
286 |             ax.axis('off')
287 |         else:
288 |             plt.title(title)
289 | 
290 |     try:
291 |         fig.savefig(op.join(STORE_GRAPH, title + '.png'))
292 |         fig.savefig(op.join(STORE_GRAPH, title + '.svg'), format='svg')
293 |     except Exception:
294 |         print('---Save not available')
295 | 
296 | 
297 | def unravel_index(indices: torch.LongTensor, shape) -> torch.LongTensor:
298 |     """Converts flat indices into unraveled coordinates in a target shape.
299 |     This is a `torch` implementation of `numpy.unravel_index`.
300 |     Args:
301 |         indices: A tensor of (flat) indices, (*, N).
302 |         shape: The targeted shape, (D,).
303 |     Returns:
304 |         The unraveled coordinates, (*, N, D).
305 |     """
306 | 
307 |     coord = []
308 | 
309 |     for dim in reversed(shape):
310 |         coord.append(indices % dim)
311 |         indices = indices // dim
312 | 
313 |     coord = torch.stack(coord[::-1], dim=-1)
314 | 
315 |     return coord
316 | 
317 | 
318 | def max_activation(reps, n_neurons=4, n_max=10):
319 |     s = reps.shape
320 |     neuron_inds = np.random.choice([i for i in range(s[1])], size=(n_neurons), replace=False)
321 |     responses = reps[:, neuron_inds, :, :].permute(1, 0, 2, 3).reshape(n_neurons, -1)  # n_neurons, b*x*y
322 | 
323 |     patches_list = []
324 |     for n in range(n_neurons):
325 |         inds_flattened = responses[n].argsort(descending=True)  # b*x*y (sorted by value, largest first)
326 |         inds = unravel_index(inds_flattened, (s[0], s[2], s[3]))  # b*x*y, 3 (time, x, y)
327 |         patches = inds[:n_max].clone()  # n_max, 3 (max.
activating patches)
328 | 
329 |         ctr = 0
330 |         tns = []
331 |         tns.append(patches[0][0])
332 |         for npatch in range(1, n_max):
333 |             tnp = patches[npatch][0]
334 |             while tnp in tns:
335 |                 ctr += 1
336 |                 patches[npatch][:] = inds[n_max + ctr]
337 |                 tnp = patches[npatch][0]
338 |             tns.append(tnp)
339 | 
340 |         patches_list.append(patches)
341 |     return patches_list
342 | 
343 | 
344 | def imgs_patches(model, layer, indexes):
345 |     sizes = [1]
346 |     operations = []
347 |     size = 1
348 |     for l in range(layer - 1, -1, -1):
349 |         conv = model.blocks[l].layer
350 | 
351 |         d_conv = conv.dilation[0]
352 |         s_conv = conv.stride[0]
353 |         k_conv = conv.kernel_size[0]
354 |         p_conv = conv.padding[0]
355 | 
356 |         pool = model.blocks[l].pool
357 |         k_pool = pool.kernel_size
358 |         s_pool = pool.stride
359 | 
360 |         size = (size - 1) * s_pool + k_pool
361 | 
362 |         size = (size - 1) * s_conv + (k_conv - 1) * d_conv + 1
363 | 
364 |         sizes.append(size)
365 |         operations.append(lambda x1, x2, s=s_pool, k=k_pool: (x1 * s - s // 2, (x2 - 1) * s + s // 2))
366 |         operations.append(lambda x1, x2, p=p_conv, d=d_conv, s=s_conv, k=k_conv: (
367 |             x1 * d * s - k // 2, (x2 - 1) * d * s + (k - k // 2)))
368 | 
369 |     patches = []
370 |     for n in range(len(indexes)):
371 |         patches_n = []
372 |         for b, w, h in indexes[n]:
373 |             h1 = int(h)
374 |             h2 = int(h1 + 1)
375 |             w1 = int(w)
376 |             w2 = int(w1 + 1)
377 |             for op in operations:
378 |                 h1, h2 = op(h1, h2)
379 |                 w1, w2 = op(w1, w2)
380 |             patches_n.append([b, (w1, w2), (h1, h2)])
381 |         patches.append(patches_n)
382 | 
383 |     return patches
384 | 
385 | 
386 | def plot_patches(imgs, patches, layer, time=0, STORE_GRAPH=None):
387 |     n_examples = len(patches)
388 |     n_plots = len(patches[0])
389 |     fig, axes = plt.subplots(nrows=n_examples, ncols=n_plots, figsize=(20, int(20 * n_examples / n_plots)), sharex=True,
390 |                              sharey=True)  # , gridspec_kw={'wspace': 0.05})
391 | 
392 |     for i in range(n_examples):
393 |         for j in range(n_plots):
394 |             # b, (w0, w1), (h0, h1) = patches[i][j]
395 |             b, (h0, h1), (w0, w1) = patches[i][j]
396 |             ax = axes[i][j]
397 | 
398 |             ax.imshow(imgs[b].permute(1, 2, 0))
399 | 
400 |             rect = Rectangle((w0, h0), (w1 - w0), (h1 - h0), linewidth=1, edgecolor='r', facecolor='none')
401 | 
402 |             # Add the patch to the Axes
403 |             ax.add_patch(rect)
404 |             ax.set_xticks([])
405 |             ax.set_yticks([])
406 |             ax.set_aspect('equal')
407 |     plt.show()
408 |     try:
409 |         fig.savefig(op.join(STORE_GRAPH, 'patch_activation_' + str(layer) + '_' + str(time) + '.png'), format='png')
410 |         fig.savefig(op.join(STORE_GRAPH, 'patch_activation_' + str(layer) + '_' + str(time) + '.svg'), format='svg')
411 |     except Exception:
412 |         print('---Save not available')
413 | 
414 | 
415 | def get_param(dict_):
416 |     def p(d):
417 |         parms = {}
418 |         for key, value in d.items():
419 |             if isinstance(value, dict):
420 |                 for x, y in p(value).items():
421 |                     parms.update({key + '/' + x: y})
422 |             else:
423 |                 parms.update({key: value})
424 |         return parms
425 | 
426 |     return p(dict_)
427 | 
428 | 
429 | def load_results(exp_path, folder):
430 |     try:
431 |         result = pd.read_csv(op.join(exp_path, folder, 'progress.csv'))
432 | 
433 |         with open(op.join(exp_path, folder, 'params.json')) as f:
434 |             data = json.load(f)
435 | 
436 |         for param, value in get_param(data).items():
437 |             result[param] = value
438 |     except Exception:
439 |         print('---Could not load results for %s' % folder)
440 |         result = []
441 |     return result
442 | 
443 | 
444 | # def get_data(exp):
445 | #     exp_path = op.join(SEARCH, exp)
446 | #     folders = [f for f in os.listdir(exp_path) if op.isdir(op.join(exp_path, f))]
447 | #     data =
[load_results(exp_path, f) for f in folders] 448 | # data = [x for x in data if isinstance(x, pd.DataFrame)] 449 | # return data 450 | 451 | def get_data(exp, return_folders=False): 452 | exp_path = op.join(SEARCH, exp) 453 | folders = [f for f in os.listdir(exp_path) if op.isdir(op.join(exp_path, f))] 454 | data = [load_results(exp_path, f) for f in folders] 455 | df_data = [] 456 | used_folders = [] 457 | for idx_x, x in enumerate(data): 458 | if isinstance(x, pd.DataFrame): 459 | df_data.append(x) 460 | used_folders.append(folders[idx_x]) 461 | assert len(df_data) == len(used_folders) 462 | if return_folders: 463 | return df_data, used_folders 464 | else: 465 | return df_data 466 | 467 | 468 | # def error(data, bars=16, wts='test_acc'): 469 | # if 'convergence' in data[0].columns: 470 | # for d in data: 471 | # d.rename(columns={'convergence': 'R1'}, inplace=True) 472 | # conv_acc = pd.DataFrame([d.iloc[-1][['R1', wts]] for d in data]) 473 | # nb_r1 = conv_acc['R1'].max() 474 | # off = nb_r1 / bars / 2 475 | # conv_acc['bars'] = conv_acc.apply(lambda x: min(nb_r1, max(0, int((x['R1'] + off) / nb_r1 * bars) * nb_r1 / bars)), 476 | # 1) 477 | # conv_acc = conv_acc.sort_values('bars') 478 | # error_acc = conv_acc[['bars', wts]].groupby('bars').agg(['mean', 'std']).fillna(method='ffill') 479 | # 480 | # return error_acc 481 | 482 | 483 | def error(data, bars=16, wts='test_acc'): 484 | if 'convergence' in data[0].columns: 485 | convergence_metric = 'convergence' 486 | else: 487 | convergence_metric = 'R1' 488 | conv_acc = pd.DataFrame([d.iloc[-1][[convergence_metric, wts]] for d in data]) 489 | nb_r1 = conv_acc[convergence_metric].max() 490 | off = nb_r1/bars/2 491 | conv_acc['bars'] = conv_acc.apply(lambda x: min(nb_r1, max(0,int((x[convergence_metric]+off)/nb_r1*bars) * nb_r1 / bars)), 1) 492 | conv_acc = conv_acc.sort_values('bars') 493 | error_acc = conv_acc[['bars', wts]].groupby('bars').agg(['mean','std']).fillna(method='ffill') 494 | return error_acc 495 | 496 | 497 | # def extract_data(data, features=None, wts='test_acc'): 498 | # if features is None: 499 | # features = ['b0/layer/t_invert', 'R1', wts] 500 | # conv_acc = pd.DataFrame([d.iloc[-1][features] for d in data]) 501 | # conv_acc = conv_acc.groupby(features[:-2]).agg({wts: ['mean', 'std']}).droplevel(0, 1).reset_index() 502 | # conv_acc['t'] = 1 / conv_acc['b0/layer/t_invert'] 503 | # conv_acc = conv_acc.sort_values('b0/layer/t_invert') 504 | # return conv_acc 505 | 506 | 507 | def extract_data(data, features=None, wts='test_acc'): 508 | if features is None: 509 | features = ['b0/layer/softness', 'b0/layer/t_invert', 'R1', wts, 'dataset_unsup/seed'] 510 | conv_acc = pd.DataFrame([d.iloc[-1][features] for d in data]) 511 | if 'b0/layer/t_invert' in features: 512 | conv_acc = conv_acc.groupby(features[:-3]).agg({i: ['mean', 'std'] for i in features[-3:-1]}) 513 | conv_acc.columns = conv_acc.columns.map('_'.join) 514 | conv_acc = conv_acc.reset_index() 515 | conv_acc['t'] = 1 / conv_acc['b0/layer/t_invert'] 516 | conv_acc = conv_acc.sort_values('t') 517 | return conv_acc 518 | 519 | # def extract_data(data, features=None, wts='test_acc'): 520 | # if features is None: 521 | # features = [wts] 522 | # conv_acc = pd.DataFrame([d.iloc[-1][features] for d in data]) 523 | # return conv_acc 524 | 525 | 526 | def load_data(exps, t='t1'): 527 | datas = [] 528 | for exp in exps: 529 | data=[] 530 | _, log = load_logs(exp) 531 | for b in log.sup[t].batch[:-1]: 532 | d = b.get_numpy() 533 | d[:, 2] *= 100/ d[:, 0] 534 | data.extend(d) 535 | 
datas.append(np.array(data))
536 |     return datas
537 | 
538 | def moving_avg(x, rolling_window=None):
539 |     if rolling_window is None:
540 |         rolling_window = max(50, int(len(x)/30))
541 |     cumsum, moving_aves = [0], []
542 | 
543 |     for i, y in enumerate(x, 1):
544 |         cumsum.append(cumsum[i-1] + y)
545 |         if i > rolling_window:
546 |             r = min(len(cumsum), rolling_window)
547 |             moving_ave = (cumsum[i] - cumsum[i-r])/r
548 |             moving_aves.append(moving_ave)
549 |     return np.array(rolling_window * [moving_aves[0]] + moving_aves) if moving_aves else np.asarray(x, dtype=float)  # guard: x shorter than the window
550 | 
551 | 
552 | def plot_filter(filters, rows=10, cols=8, title=None):
553 |     if filters.shape[0] < rows * cols:
554 |         if filters.shape[0] <= cols:
555 |             rows = 2
556 |             cols = int(filters.shape[0] / 2)
557 |         elif filters.shape[0] <= rows:
558 |             cols = int(filters.shape[0] / rows)
559 |         else:
560 |             rows = int(filters.shape[0] / cols)
561 | 
562 |     if cols > 2:
563 | 
564 |         fig, axes = plt.subplots(
565 |             rows,
566 |             cols,
567 |             figsize=(20, int(20 * rows / cols))
568 |         )
569 | 
570 |         vmin = filters[:(rows + 1) * cols].min()
571 |         vmax = filters[:(rows + 1) * cols].max()
572 | 
573 |         for row in range(rows):
574 |             for col in range(cols):
575 |                 i = row * cols + col
576 |                 filter = filters[i]
577 | 
578 |                 if len(filter.shape) == 3:
579 |                     filter -= np.min(filter)  # Normalize
580 |                     filter /= np.max(filter)
581 |                     axes[row, col].imshow(filter, cmap='coolwarm')  # , vmin=vmin, vmax=vmax)
582 |                 else:
583 |                     axes[row, col].imshow(filter, cmap='coolwarm', vmin=vmin, vmax=vmax)
584 |                 axes[row, col].axis("off")
585 |                 # axes[row, col].title.set_text('{i},r: {r:.2f}'.format(i=i, r=np.sum(np.abs(filter) ** 2)))
586 |     else:
587 |         fig, axes = plt.subplots(1, filters.shape[0], figsize=(20, 6))
588 |         for row in range(filters.shape[0]):
589 |             filter = filters[row]
590 | 
591 |             axes[row].imshow(filter, cmap='coolwarm', vmin=0, vmax=1)
592 |             axes[row].axis("off")
593 | 
594 |     if title is not None:
595 |         plt.title(title)
596 | 
597 |     plt.show()
598 | 
599 |     try:
600 |         fig.savefig(op.join(STORE_GRAPH, title + '.png'))
601 |         fig.savefig(op.join(STORE_GRAPH, title + '.svg'), format='svg')
602 |     except Exception:
603 |         print('---Save not available')
--------------------------------------------------------------------------------
/post_hoc_loss.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | 
3 | from utils import load_presets, get_device, load_config_dataset, seed_init_fn
4 | from model import load_layers
5 | from train import run_sup, check_dimension
6 | from log import Log, save_logs
7 | import copy
8 | import warnings
9 | warnings.filterwarnings("ignore")
10 | 
11 | parser = argparse.ArgumentParser(description='Post hoc loss')
12 | 
13 | parser.add_argument('--preset', choices=load_presets(), default=None,
14 |                     type=str, help='Preset of hyper-parameters ' +
15 |                                    ' | '.join(load_presets()) +
16 |                                    ' (default: None)')
17 | 
18 | parser.add_argument('--dataset-unsup', choices=load_config_dataset(), default='MNIST',
19 |                     type=str, help='Dataset possibilities ' +
20 |                                    ' | '.join(load_config_dataset()) +
21 |                                    ' (default: MNIST)')
22 | 
23 | parser.add_argument('--dataset-sup', choices=load_config_dataset(), default='MNIST',
24 |                     type=str, help='Dataset possibilities ' +
25 |                                    ' | '.join(load_config_dataset()) +
26 |                                    ' (default: MNIST)')
27 | 
28 | parser.add_argument('--model-name', default=None, type=str, help='Model Name')
29 | 
30 | 
31 | parser.add_argument('--seed', default=None, type=int,
32 |                     help='Random seed (default: None)')
33 | 
34 |
parser.add_argument('--gpu-id', default=0, type=int, metavar='N', 35 | help='Id of gpu selected for training (default: 0)') 36 | 37 | 38 | def training_config(blocks, dataset_sup_config, dataset_unsup_config): 39 | """ 40 | Define the training order of blocks: 41 | -successive: one block after the other 42 | -consecutive: Hebbian blocks then BP blocks 43 | -simultaneous: All at once with an hybrid learning 44 | 45 | Parameters 46 | ---------- 47 | blocks: dict 48 | configuration of every blocks in the model 49 | dataset_config: dict 50 | configuration of the dataset 51 | mode: str 52 | 53 | Returns 54 | ------- 55 | train_layer_order: dict 56 | configuration of the training blocks order 57 | 58 | 59 | """ 60 | for id in range(len(blocks)): 61 | blocks['b%s' % id]['layer']['lr_scheduler'] = {'decay':'cste', 'lr': 0.1} 62 | blocks_train = range(len(blocks)) 63 | 64 | train_layer_order = {} 65 | layer = {'sup':[], 'unsup':[]} 66 | lr = {'sup': [], 'unsup': []} 67 | 68 | for id in blocks_train: 69 | block = blocks['b%s' % id] 70 | config = block['layer'] 71 | if config['hebbian']: 72 | layer['unsup'].append(id) 73 | lr['unsup'].append(config['lr']) 74 | 75 | else: 76 | layer['sup'].append(id) 77 | lr['unsup'].append(config['lr']) 78 | lr['sup'].append(config['lr_sup']) 79 | if layer['unsup']: 80 | train_layer_order['t0'] = { 81 | 'blocks': layer['unsup'], 82 | 'mode': 'supervised', 83 | 'learning_mode': 'HB', 84 | 'batch_size': dataset_unsup_config['batch_size'], 85 | 'nb_epoch': dataset_unsup_config['nb_epoch'], 86 | 'print_freq': dataset_unsup_config['print_freq'], 87 | 'lr': min(lr['unsup']) 88 | } 89 | config = blocks['b0']['layer'] 90 | 91 | config['lr_scheduler'] = { 92 | 'lr': config['lr'], 93 | 'adaptive': config['adaptive'], 94 | 'nb_epochs': train_layer_order['t0']['nb_epoch'], 95 | 'ratio': train_layer_order['t0']['batch_size'] / dataset_unsup_config['training_sample'], 96 | 'speed': config['speed'], 97 | 'div': config['lr_div'], 98 | 'decay': config['lr_decay'], 99 | 'power_lr': config['power_lr'] 100 | } 101 | else: 102 | train_layer_order['t0'] = { 103 | 'blocks': layer['sup'], 104 | 'mode': 'supervised', 105 | 'learning_mode': 'BP', 106 | 'batch_size': dataset_unsup_config['batch_size'], 107 | 'nb_epoch': dataset_unsup_config['nb_epoch'], 108 | 'print_freq': dataset_unsup_config['print_freq'], 109 | 'lr': min(lr['unsup']) 110 | } 111 | train_layer_order['t2'] = train_layer_order['t0'].copy() 112 | train_layer_order['t2']['blocks'] = [0] 113 | 114 | train_layer_order['t1'] = { 115 | 'blocks': [1], 116 | 'mode': 'supervised', 117 | 'batch_size': dataset_sup_config['batch_size'], 118 | 'nb_epoch': dataset_sup_config['nb_epoch'], 119 | 'print_freq': dataset_sup_config['print_freq'], 120 | 'lr': min(lr['sup']) 121 | } 122 | 123 | return train_layer_order 124 | 125 | 126 | def main(blocks, name_model, dataset_sup_config, dataset_unsup_config, train_config, gpu_id): 127 | device = get_device(gpu_id) 128 | model = load_layers(blocks, name_model, resume=None) 129 | 130 | model = model.to(device) 131 | 132 | log = Log(train_config) 133 | 134 | # copy all the parameter of the first layer to overide in the future 135 | b0_copy = copy.deepcopy(model.get_block(0).state_dict()) 136 | 137 | config0 = train_config['t0'] 138 | config1 = train_config['t1'] 139 | config2 = train_config['t2'] 140 | 141 | if config0['learning_mode'] == 'HB': 142 | run_sup( 143 | final_epoch=config0['nb_epoch'], 144 | print_freq=config0['print_freq'], 145 | batch_size=config0['batch_size'], 146 | lr=config0['lr'], 147 
| folder_name=name_model, 148 | dataset_config=dataset_unsup_config, 149 | model=model, 150 | device=device, 151 | log=log.sup['t0'], 152 | blocks=config0['blocks'], 153 | learning_mode=config0['learning_mode'] 154 | ) 155 | print('First layer trained') 156 | 157 | run_sup( 158 | final_epoch=config1['nb_epoch'], 159 | print_freq=config1['print_freq'], 160 | batch_size=config1['batch_size'], 161 | lr=config1['lr'], 162 | folder_name=name_model, 163 | dataset_config=dataset_sup_config, 164 | model=model, 165 | device=device, 166 | log=log.sup['t1'], 167 | blocks=config1['blocks'] 168 | ) 169 | 170 | 171 | print('Second layer trained') 172 | # overide parameters of the first layer 173 | 174 | print(model.get_block(0).state_dict()) 175 | print(b0_copy) 176 | model.get_block(0).load_state_dict(b0_copy) 177 | model.reset() 178 | 179 | 180 | run_sup( 181 | final_epoch=config2['nb_epoch'], 182 | print_freq=config2['print_freq'], 183 | batch_size=config2['batch_size'], 184 | lr=config2['lr'], 185 | folder_name=name_model, 186 | dataset_config=dataset_sup_config, 187 | model=model, 188 | device=device, 189 | log=log.sup['t2'], 190 | blocks=config2['blocks'], 191 | learning_mode=config2['learning_mode'], 192 | save_batch=True 193 | ) 194 | 195 | save_logs(log, name_model) 196 | 197 | 198 | if __name__ == '__main__': 199 | params = parser.parse_args() 200 | name_model = params.preset if params.model_name is None else params.model_name 201 | blocks = load_presets(params.preset) 202 | dataset_unsup_config = load_config_dataset(params.dataset_unsup, False) 203 | dataset_sup_config = load_config_dataset(params.dataset_sup, False) 204 | if params.seed is not None: 205 | dataset_sup_config['seed'] = params.seed 206 | 207 | if dataset_sup_config['seed'] is not None: 208 | seed_init_fn(dataset_sup_config['seed']) 209 | 210 | blocks = check_dimension(blocks, dataset_sup_config) 211 | 212 | 213 | train_config = training_config(blocks, dataset_sup_config, dataset_unsup_config) 214 | 215 | main(blocks, name_model, dataset_sup_config, dataset_unsup_config, train_config, params.gpu_id) 216 | -------------------------------------------------------------------------------- /ray_search.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import pdb 4 | 5 | from utils import SEARCH, load_presets, get_device, load_config_dataset, merge_parameter, seed_init_fn, str2bool 6 | from model import load_layers 7 | import torch 8 | from train import run_sup, run_unsup, check_dimension, training_config, run_hybrid 9 | from log import Log 10 | import ray 11 | from ray import tune 12 | from ray.tune.suggest.basic_variant import BasicVariantGenerator 13 | from ray.tune import CLIReporter 14 | from functools import partial 15 | import warnings 16 | import numpy as np 17 | 18 | warnings.filterwarnings("ignore") 19 | 20 | metric_names = ['train_loss', 'train_acc', 'test_loss', 'test_acc', 'convergence', 'R1'] 21 | 22 | parser = argparse.ArgumentParser(description='Multi layer Hebbian Training') 23 | 24 | parser.add_argument('--preset', choices=load_presets(), default=None, 25 | type=str, help='Preset of hyper-parameters ' + 26 | ' | '.join(load_presets()) + 27 | ' (default: None)') 28 | 29 | parser.add_argument('--dataset-unsup', choices=load_config_dataset(), default='MNIST', 30 | type=str, help='Dataset possibilities ' + 31 | ' | '.join(load_config_dataset()) + 32 | ' (default: MNIST)') 33 | 34 | parser.add_argument('--dataset-sup', choices=load_config_dataset(), 
default='MNIST', 35 | type=str, help='Dataset possibilities ' + 36 | ' | '.join(load_config_dataset()) + 37 | ' (default: MNIST)') 38 | 39 | parser.add_argument('--training-mode', choices=['successive', 'consecutive', 'simultaneous'], default='successive', 40 | type=str, help='Training possibilities ' + 41 | ' | '.join(['successive', 'consecutive', 'simultaneous']) + 42 | ' (default: consecutive)') 43 | 44 | parser.add_argument('--resume', choices=[None, "all", "without_classifier"], default=None, 45 | type=str, help='Resume Model ' + 46 | ' | '.join(["all", "without_classifier"]) + 47 | ' (default: None)') 48 | 49 | parser.add_argument('--metric', choices=metric_names, default='test_acc', 50 | type=str, help='Primary Metric' + 51 | ' | '.join(metric_names) + 52 | ' (default: test_acc)') 53 | 54 | parser.add_argument('--training-blocks', default=None, nargs='+', type=int, 55 | help='Selection of the blocks that will be trained') 56 | 57 | parser.add_argument('--folder-name', default=None, type=str, 58 | help='Name of the experiment') 59 | 60 | parser.add_argument('--num-samples', default=1, type=int, 61 | help='number of search into the hparams space') 62 | 63 | parser.add_argument('--model-name', default=None, type=str, help='Model Name') 64 | 65 | parser.add_argument('--validation-sup', default=False, type=str2bool, metavar='N', 66 | help='') 67 | 68 | parser.add_argument('--validation-unsup', default=False, type=str2bool, metavar='N', 69 | help='') 70 | 71 | parser.add_argument('--config', default='seed', type=str, metavar='N', 72 | help='') 73 | 74 | parser.add_argument('--gpu-exp', default=1, type=int, metavar='N', 75 | help='') 76 | 77 | parser.add_argument('--save-model', default=False, action='store_true', 78 | help='Save model checkpoints, configs, etc') 79 | 80 | parser.add_argument('--debug', default=False, action='store_true', help='Debug mode (ray local)') 81 | 82 | 83 | def get_config(config_name): 84 | if config_name == 'regimes': 85 | t_invert_search = [1.25 ** (x - 50) for x in range(100)] 86 | softness_search = ["soft", "softkrotov"] 87 | seeds = [0, 1, 2] 88 | configs = [] 89 | for i_softness in softness_search: 90 | for i_t_invert in t_invert_search: 91 | for i_seed in seeds: 92 | i_config = { 93 | f'b{i_layer}': { 94 | "layer": { 95 | 't_invert': i_t_invert, 96 | "softness": i_softness, 97 | } 98 | } for i_layer in range(3)} 99 | i_config['dataset_unsup'] = { 100 | 'seed': i_seed, 101 | } 102 | configs.append(i_config) 103 | 104 | config = tune.grid_search(configs) 105 | 106 | elif config_name == 'radius': 107 | config = { 108 | 'b0': { 109 | "layer": { 110 | 'radius': tune.grid_search([1.25 ** (x - 10) for x in range(27)]), 111 | } 112 | }, 113 | 'dataset_unsup': { 114 | 'seed': tune.grid_search([0, 1, 2]), 115 | } 116 | } 117 | elif config_name == 'one_seed': 118 | config = { 119 | 'dataset_unsup': { 120 | 'seed': 0 121 | } 122 | } 123 | else: 124 | config = { 125 | 'dataset_unsup': { 126 | 'seed': tune.grid_search([0, 1, 2, 3]) 127 | } 128 | } 129 | print("config_name", config_name) 130 | print("config", config) 131 | return config 132 | 133 | 134 | def main(params, dataset_sup_config, dataset_unsup_config, blocks, config): 135 | for block_id, block in blocks.items(): 136 | if block_id in config: 137 | blocks[block_id] = merge_parameter(block.copy(), config[block_id]) 138 | print("blocks", blocks) 139 | 140 | if "dataset_unsup" in config: 141 | dataset_unsup_config = merge_parameter(dataset_unsup_config, config['dataset_unsup']) 142 | 143 | if "dataset_sup" in 
config: 144 | dataset_sup_config = merge_parameter(dataset_sup_config, config['dataset_sup']) 145 | 146 | if dataset_unsup_config['seed'] is not None: 147 | seed_init_fn(dataset_unsup_config['seed']) 148 | 149 | device = get_device() 150 | 151 | blocks = check_dimension(blocks, dataset_sup_config) 152 | 153 | print("dataset_sup_config, dataset_unsup_config", dataset_sup_config, dataset_unsup_config) 154 | train_config = training_config(blocks, dataset_sup_config, dataset_unsup_config, params.training_mode, 155 | params.training_blocks) 156 | 157 | print("train_config", train_config) 158 | 159 | model = load_layers(blocks, params.name_model, params.resume) 160 | 161 | model.reset() 162 | 163 | model = model.to(device) 164 | 165 | log = Log(train_config) 166 | for id, config in train_config.items(): 167 | if config['mode'] == 'unsupervised': 168 | run_unsup( 169 | config['nb_epoch'], 170 | config['print_freq'], 171 | config['batch_size'], 172 | params.name_model, 173 | dataset_unsup_config, 174 | model, 175 | device, 176 | log.unsup[id], 177 | blocks=config['blocks'], 178 | report=tune.report, 179 | save=params.save_model, 180 | reset=False, 181 | model_dir=tune.session.get_trial_dir(), 182 | ) 183 | elif config['mode'] == 'supervised': 184 | print('Running supervised') 185 | run_sup( 186 | config['nb_epoch'], 187 | config['print_freq'], 188 | config['batch_size'], 189 | config['lr'], 190 | params.name_model, 191 | dataset_sup_config, 192 | model, 193 | device, 194 | log.sup[id], 195 | blocks=config['blocks'], 196 | report=tune.report, 197 | save=params.save_model, 198 | model_dir=tune.session.get_trial_dir(), 199 | ) 200 | else: 201 | run_hybrid( 202 | config['nb_epoch'], 203 | config['print_freq'], 204 | config['batch_size'], 205 | config['lr'], 206 | params.name_model, 207 | dataset_sup_config, 208 | model, 209 | device, 210 | log.sup[id], 211 | blocks=config['blocks'], 212 | report=tune.report, 213 | save=params.save_model, 214 | model_dir=tune.session.get_trial_dir(), 215 | ) 216 | 217 | 218 | if __name__ == '__main__': 219 | params = parser.parse_args() 220 | 221 | config = get_config(params.config) 222 | 223 | params.name_model = params.preset if params.model_name is None else params.model_name # TODO change this for better model storage 224 | blocks = load_presets(params.preset) 225 | 226 | dataset_sup_config = load_config_dataset(params.dataset_sup, params.validation_sup) 227 | dataset_unsup_config = load_config_dataset(params.dataset_unsup, params.validation_unsup) 228 | 229 | if params.debug is True: 230 | # local_mode=True for debugging . 
It seems there's no need to init ray for these usecase 231 | ray.init(local_mode=True) 232 | 233 | reporter = CLIReporter(max_progress_rows=12) 234 | for metric in metric_names: 235 | reporter.add_metric_column(metric) 236 | 237 | algo_search = BasicVariantGenerator() 238 | 239 | trial_exp = partial( 240 | main, params, dataset_sup_config, dataset_unsup_config, blocks 241 | ) 242 | # TODO: use ray for model storing, as it is better aware of the different variants 243 | analysis = tune.run( 244 | trial_exp, 245 | resources_per_trial={ 246 | "cpu": 4, 247 | "gpu": max(1 / params.gpu_exp, torch.cuda.device_count() * 4 / 86) 248 | }, 249 | metric=params.metric, 250 | mode='min' if params.metric.endswith('loss') else 'max', 251 | search_alg=algo_search, 252 | config=config, 253 | progress_reporter=reporter, 254 | num_samples=params.num_samples, 255 | local_dir=SEARCH, 256 | name=params.folder_name) 257 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from utils import CustomStepLR, double_factorial 2 | from model import save_layers, HebbianOptimizer, AggregateOptim 3 | from engine import train_sup, train_unsup, evaluate_unsup, evaluate_sup 4 | from dataset import make_data_loaders 5 | import torch 6 | import torch.optim as optim 7 | import torch.nn as nn 8 | import numpy as np 9 | 10 | 11 | def check_dimension(blocks, dataset_config): 12 | """ 13 | Make each block dimension of the model corresponds to the next one. 14 | Parameters 15 | ---------- 16 | blocks: dict 17 | configuration of every blocks in the model 18 | dataset_config: dict 19 | configuration of the dataset 20 | Returns 21 | ------- 22 | blocks: dict 23 | configuration of every blocks in the model with correct dimensionality 24 | 25 | """ 26 | 27 | in_channels, out_channels_final, in_width, in_height = dataset_config['channels'], \ 28 | dataset_config['out_channels'], \ 29 | dataset_config['width'], \ 30 | dataset_config['height'] 31 | 32 | for id in range(len(blocks)): 33 | block = blocks['b%s' % id] 34 | assert block['num'] == id, 'Block b%s has not the correct number %s ' % (id, block['num']) 35 | config = block['layer'] 36 | if id == len(blocks) - 1 and not config['hebbian']: 37 | config['out_channels'] = out_channels_final 38 | # assert out_channels_final == config['out_channels'], \ 39 | # 'Output channels %s is different than number of classes %s'%(config['out_channels'], out_channels_final) 40 | 41 | if 'operation' in block and 'flatten' in block['operation']: 42 | config['in_channels'] = int(in_channels * in_width * in_height) 43 | config['old_channels'] = in_channels 44 | else: 45 | config['in_channels'] = in_channels 46 | 47 | if block['arch'] == 'CNN': 48 | # config['padding'] = config['kernel_size']//2 49 | in_width = int((in_width + 2 * config['padding'] - config['dilation'] * ( 50 | config['kernel_size'] - 1) - 1) / config['stride']) + 1 51 | in_height = int((in_height + 2 * config['padding'] - config['dilation'] * ( 52 | config['kernel_size'] - 1) - 1) / config['stride']) + 1 53 | if block['pool'] is not None: 54 | # block['pool']['padding'] = int(int(block['pool']['kernel_size']) / 2 - 1) 55 | in_width = int((in_width - 1 * (block['pool']['kernel_size'] - 1) + 2 * block['pool']['padding'] - 1) / 56 | block['pool']['stride'] + 1) 57 | in_height = int( 58 | (in_height - 1 * (block['pool']['kernel_size'] - 1) + 2 * block['pool']['padding'] - 1) / 59 | block['pool']['stride'] + 1) 60 | 
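        # Worked example of the two formulas above (hypothetical numbers): a 32x32 input
        # through a conv with kernel_size=5, stride=1, padding=2, dilation=1 gives
        # int((32 + 2*2 - 1*(5-1) - 1) / 1) + 1 = 32, and a following pool with
        # kernel_size=2, stride=2, padding=0 gives int((32 - (2-1) + 2*0 - 1) / 2 + 1) = 16.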
print('block %s, size : %s %s %s' % (id, config['out_channels'], in_width, in_height)) 61 | in_channels = config['out_channels'] # prepare for next loop 62 | 63 | lp = blocks['b%s' % id]['layer']['lebesgue_p'] 64 | initial_r = blocks['b%s' % id]['layer']['radius'] ** (lp) 65 | 66 | if blocks['b%s' % id]['arch'] == 'CNN': 67 | kenel_size = blocks['b%s' % id]['layer']['kernel_size'] 68 | input_channel = blocks['b%s' % id]['layer']['in_channels'] 69 | groups = blocks['b%s' % id]['layer']['groups'] 70 | n_neurons = input_channel / groups * kenel_size ** 2 71 | else: 72 | n_neurons = blocks['b%s' % id]['layer']['in_channels'] 73 | if "operation" in block and "batchnorm" in block["operation"]: 74 | blocks['b%s' % id]['layer']['weight_init'] = 'normal' 75 | 76 | t = double_factorial(lp - 1) * (np.sqrt(2 / np.pi) if lp % 2 != 0 else 1) 77 | blocks['b%s' % id]['layer']['weight_init_range'] = np.power((initial_r / (n_neurons * t)), 1 / lp) 78 | else: 79 | blocks['b%s' % id]['layer']['weight_init'] = 'positive' 80 | blocks['b%s' % id]['layer']['weight_init_range'] = np.power(((lp + 1) * initial_r / (n_neurons)), 1 / lp) 81 | 82 | print('range = %s' % blocks['b%s' % id]['layer']['weight_init_range']) 83 | return blocks 84 | 85 | 86 | def training_config(blocks, dataset_sup_config, dataset_unsup_config, mode, blocks_train=None): 87 | """ 88 | Define the training order of blocks: 89 | -successive: one block after the other 90 | -consecutive: Hebbian blocks then BP blocks 91 | -simultaneous: All at once with an hybrid learning 92 | 93 | Parameters 94 | ---------- 95 | blocks: dict 96 | configuration of every blocks in the model 97 | dataset_config: dict 98 | configuration of the dataset 99 | mode: str 100 | 101 | Returns 102 | ------- 103 | train_layer_order: dict 104 | configuration of the training blocks order 105 | 106 | 107 | """ 108 | for id in range(len(blocks)): 109 | blocks['b%s' % id]['layer']['lr_scheduler'] = {'decay': 'cste', 'lr': 0.1} 110 | blocks_train = range(len(blocks)) if blocks_train is None else blocks_train 111 | if mode == 'successive': 112 | train_layer_order = {} 113 | train_id = 0 114 | for id in blocks_train: 115 | block = blocks['b%s' % id] 116 | config = block['layer'] 117 | if config['hebbian']: 118 | 119 | train_layer_order['t%s' % train_id] = { 120 | 'blocks': [id], 121 | 'mode': 'unsupervised', 122 | 'lr': config['lr'], 123 | 'nb_epoch': dataset_unsup_config['nb_epoch'], 124 | 'batch_size': dataset_unsup_config['batch_size'], 125 | 'print_freq': dataset_unsup_config['print_freq'] 126 | } 127 | config['lr_scheduler'] = { 128 | 'lr': config['lr'], 129 | 'adaptive': config['adaptive'], 130 | 'nb_epochs': dataset_unsup_config['nb_epoch'], 131 | 'ratio': dataset_unsup_config['batch_size'] / dataset_unsup_config['training_sample'], 132 | 'speed': config['speed'], 133 | 'div': config['lr_div'], 134 | 'decay': config['lr_decay'], 135 | 'power_lr': config['power_lr'] 136 | } 137 | last_hebbian = True 138 | train_id += 1 139 | else: 140 | train_layer_order['t%s' % train_id] = { 141 | 'blocks': [id], 142 | 'mode': 'supervised', 143 | 'lr': config['lr_sup'], 144 | 'nb_epoch': dataset_sup_config['nb_epoch'], 145 | 'batch_size': dataset_sup_config['batch_size'], 146 | 'print_freq': dataset_sup_config['print_freq'] 147 | } 148 | train_id += 1 149 | elif mode == 'consecutive': 150 | train_layer_order = {} 151 | layer = {'sup': [], 'unsup': []} 152 | lr = {'sup': [], 'unsup': []} 153 | 154 | for id in blocks_train: 155 | block = blocks['b%s' % id] 156 | config = block['layer'] 157 | # 
this allows to have supervised Hebbian 158 | is_unsup = config['hebbian'] and config.get('metric_mode', 'unsupervised') != 'supervised' 159 | if is_unsup: 160 | layer['unsup'].append(id) 161 | lr['unsup'].append(config['lr']) 162 | else: 163 | layer['sup'].append(id) 164 | lr['sup'].append(config['lr_sup']) 165 | if layer['unsup']: # if the list is not empty, i.e. we have unsup blocks 166 | train_layer_order['t0'] = { 167 | 'blocks': layer['unsup'], 168 | 'mode': 'unsupervised', 169 | 'batch_size': dataset_unsup_config['batch_size'], 170 | 'nb_epoch': dataset_unsup_config['nb_epoch'], 171 | 'print_freq': dataset_unsup_config['print_freq'], 172 | 'lr': min(lr['unsup']) 173 | } 174 | if layer['sup']: # if the list is not empty, i.e. we have sup blocks 175 | t_id = 't1' if layer['unsup'] else 't0' 176 | train_layer_order[t_id] = { 177 | 'blocks': layer['sup'], 178 | 'mode': 'supervised', 179 | 'batch_size': dataset_sup_config['batch_size'], 180 | 'nb_epoch': dataset_sup_config['nb_epoch'], 181 | 'print_freq': dataset_sup_config['print_freq'], 182 | 'lr': min(lr['sup']) 183 | } 184 | for id in range(len(blocks)): 185 | block = blocks['b%s' % id] 186 | config = block['layer'] 187 | if config['hebbian']: 188 | config['lr_scheduler'] = { 189 | 'lr': config['lr'], 190 | 'adaptive': config['adaptive'], 191 | 'nb_epochs': train_layer_order['t0']['nb_epoch'], 192 | 'ratio': train_layer_order['t0']['batch_size'] / dataset_unsup_config['training_sample'], 193 | 'speed': config['speed'], 194 | 'div': config['lr_div'], 195 | 'decay': config['lr_decay'], 196 | 'power_lr': config['power_lr'] 197 | } 198 | elif mode == 'simultaneous': 199 | train_layer_order = { 200 | 'blocks': [], 201 | 'lr': [], 202 | 'mode': 'hybrid', 203 | 'batch_size': dataset_sup_config['batch_size'], 204 | 'nb_epoch': dataset_sup_config['nb_epoch'], 205 | 'print_freq': dataset_sup_config['print_freq'], 206 | } 207 | 208 | for id in blocks_train: 209 | block = blocks['b%s' % id] 210 | config = block['layer'] 211 | train_layer_order['blocks'].append(id) 212 | if not config['hebbian']: 213 | train_layer_order['lr'].append(config['lr_sup']) 214 | 215 | train_layer_order['lr'] = min(train_layer_order['lr']) 216 | for id in range(len(blocks)): 217 | block = blocks['b%s' % id] 218 | config = block['layer'] 219 | if config['hebbian']: 220 | config['lr_scheduler'] = { 221 | 'lr': config['lr'], 222 | 'adaptive': config['adaptive'], 223 | 'nb_epochs': train_layer_order['nb_epoch'], 224 | 'ratio': train_layer_order['batch_size'] / dataset_sup_config['training_sample'], 225 | 'speed': config['speed'], 226 | 'div': config['lr_div'], 227 | 'decay': config['lr_decay'], 228 | 'power_lr': config['power_lr'] 229 | } 230 | train_layer_order = {'t1': train_layer_order} 231 | else: 232 | raise ValueError 233 | return train_layer_order 234 | 235 | 236 | def run_hybrid( 237 | final_epoch: int, 238 | print_freq: int, 239 | batch_size: int, 240 | lr: float, 241 | folder_name: str, 242 | dataset_config: dict, 243 | model, 244 | device, 245 | log, 246 | blocks, 247 | learning_mode: str = 'BP', 248 | save_batch: bool = True, 249 | save: bool = True, 250 | report=None, 251 | plot_fc=None, 252 | model_dir=None, 253 | ): 254 | """ 255 | Hybrid training of one model, happens during simultaneous training mode 256 | 257 | """ 258 | 259 | print('\n', '********** Hybrid learning of blocks %s **********' % blocks) 260 | 261 | train_loader, test_loader = make_data_loaders(dataset_config, batch_size, device) 262 | 263 | optimizer_sgd = optim.Adam( 264 | model.parameters(), 
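                               # Adam only drives the gradient-trained (BP) blocks; Hebbian blocks
                               # update themselves via HebbianOptimizer, and AggregateOptim below
                               # steps both optimizers together on every batch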
lr=lr) # , weight_decay=1e-4) 265 | criterion = nn.CrossEntropyLoss() 266 | hebbian_optimizer = HebbianOptimizer(model) 267 | scheduler = CustomStepLR(optimizer_sgd, final_epoch) 268 | optimizer = AggregateOptim((hebbian_optimizer, optimizer_sgd)) 269 | log_batch = log.new_log_batch() 270 | for epoch in range(1, final_epoch + 1): 271 | measures, lr = train_sup(model, criterion, optimizer, train_loader, device, log_batch, learning_mode, blocks) 272 | 273 | if scheduler is not None: 274 | scheduler.step() 275 | 276 | if epoch % print_freq == 0 or epoch == final_epoch or epoch == 1: 277 | 278 | loss_test, acc_test = evaluate_sup(model, criterion, test_loader, device) 279 | 280 | log_batch = log.step(epoch, log_batch, loss_test, acc_test, lr, save=save_batch) 281 | 282 | if report is not None: 283 | _, train_loss, train_acc, test_loss, test_acc = log.data[-1] 284 | 285 | conv, R1 = model.convergence() 286 | report(train_loss=train_loss, train_acc=train_acc, test_loss=test_loss, test_acc=test_acc, 287 | convergence=conv, R1=R1) 288 | 289 | else: 290 | log.verbose() 291 | if save: 292 | save_layers(model, folder_name, epoch, blocks, storing_path=model_dir) 293 | 294 | if plot_fc is not None: 295 | for block in blocks: 296 | plot_fc(model, block) 297 | 298 | 299 | def run_unsup( 300 | final_epoch: int, 301 | print_freq: int, 302 | batch_size: int, 303 | folder_name: str, 304 | dataset_config: dict, 305 | model, 306 | device, 307 | log, 308 | blocks, 309 | save: bool = True, 310 | report=None, 311 | plot_fc=None, 312 | reset=False, 313 | model_dir=None 314 | ): 315 | """ 316 | Unsupervised training of hebbians blocks of one model 317 | 318 | """ 319 | print('\n', '********** Hebbian Unsupervised learning of blocks %s **********' % blocks) 320 | 321 | train_loader, test_loader = make_data_loaders(dataset_config, batch_size, device) 322 | 323 | for epoch in range(1, final_epoch + 1): 324 | lr, info, convergence, R1 = train_unsup(model, train_loader, device, blocks) 325 | 326 | if epoch % print_freq == 0 or epoch == final_epoch or epoch == 1: 327 | 328 | acc_train, acc_test = evaluate_unsup(model, train_loader, test_loader, device, blocks) 329 | 330 | log.step(epoch, acc_train, acc_test, info, convergence, R1, lr) 331 | 332 | if report is not None: 333 | report(train_loss=0., train_acc=acc_train, test_loss=0., test_acc=acc_test, convergence=convergence, 334 | R1=R1) 335 | # else: 336 | log.verbose() 337 | 338 | if save: 339 | save_layers(model, folder_name, epoch, blocks, storing_path=model_dir) 340 | 341 | if plot_fc is not None: 342 | for block in blocks: 343 | plot_fc(model, block) 344 | if reset: 345 | model.reset() 346 | 347 | 348 | def run_sup( 349 | final_epoch: int, 350 | print_freq: int, 351 | batch_size: int, 352 | lr: float, 353 | folder_name: str, 354 | dataset_config: dict, 355 | model, 356 | device, 357 | log, 358 | blocks, 359 | learning_mode: str = 'BP', 360 | save_batch: bool = False, 361 | save: bool = True, 362 | report=None, 363 | plot_fc=None, 364 | model_dir=None 365 | ): 366 | """ 367 | Supervised training of BP blocks of one model 368 | 369 | """ 370 | 371 | print('\n', '********** Supervised learning of blocks %s **********' % blocks) 372 | 373 | train_loader, test_loader = make_data_loaders(dataset_config, batch_size, device) 374 | 375 | criterion = nn.CrossEntropyLoss() 376 | log_batch = log.new_log_batch() 377 | if all([model.get_block(b).is_hebbian() for b in blocks]): 378 | # optimizer, scheduler, log_batch = None, None, None 379 | optimizer, scheduler = None, None 380 | 
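    # a purely Hebbian selection needs no gradient optimizer or scheduler:
    # train_sup will then only drive the blocks' own plasticity updates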
def run_unsup(
        final_epoch: int,
        print_freq: int,
        batch_size: int,
        folder_name: str,
        dataset_config: dict,
        model,
        device,
        log,
        blocks,
        save: bool = True,
        report=None,
        plot_fc=None,
        reset=False,
        model_dir=None
):
    """
    Unsupervised training of the Hebbian blocks of one model.
    """
    print('\n', '********** Hebbian Unsupervised learning of blocks %s **********' % blocks)

    train_loader, test_loader = make_data_loaders(dataset_config, batch_size, device)

    for epoch in range(1, final_epoch + 1):
        lr, info, convergence, R1 = train_unsup(model, train_loader, device, blocks)

        if epoch % print_freq == 0 or epoch == final_epoch or epoch == 1:

            acc_train, acc_test = evaluate_unsup(model, train_loader, test_loader, device, blocks)

            log.step(epoch, acc_train, acc_test, info, convergence, R1, lr)

            if report is not None:
                report(train_loss=0., train_acc=acc_train, test_loss=0., test_acc=acc_test, convergence=convergence,
                       R1=R1)
            # verbose logging runs whether or not a reporter is attached
            log.verbose()

            if save:
                save_layers(model, folder_name, epoch, blocks, storing_path=model_dir)

    if plot_fc is not None:
        for block in blocks:
            plot_fc(model, block)
    if reset:
        model.reset()


def run_sup(
        final_epoch: int,
        print_freq: int,
        batch_size: int,
        lr: float,
        folder_name: str,
        dataset_config: dict,
        model,
        device,
        log,
        blocks,
        learning_mode: str = 'BP',
        save_batch: bool = False,
        save: bool = True,
        report=None,
        plot_fc=None,
        model_dir=None
):
    """
    Supervised training of the backprop (BP) blocks of one model.
    """

    print('\n', '********** Supervised learning of blocks %s **********' % blocks)

    train_loader, test_loader = make_data_loaders(dataset_config, batch_size, device)

    criterion = nn.CrossEntropyLoss()
    log_batch = log.new_log_batch()
    if all([model.get_block(b).is_hebbian() for b in blocks]):
        # optimizer, scheduler, log_batch = None, None, None
        optimizer, scheduler = None, None
    else:
        # criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=lr)  # , weight_decay=1e-6)
        scheduler = CustomStepLR(optimizer, final_epoch)

    for epoch in range(1, final_epoch + 1):
        measures, lr = train_sup(model, criterion, optimizer, train_loader, device, log_batch, learning_mode, blocks)

        if scheduler is not None:
            scheduler.step()

        if epoch % print_freq == 0 or epoch == final_epoch or epoch == 1:

            # unlike evaluate_unsup (train and test accuracy), evaluate_sup returns test loss and test accuracy
            loss_test, acc_test = evaluate_sup(model, criterion, test_loader, device)

            log_batch = log.step(epoch, log_batch, loss_test, acc_test, lr, save_batch)

            if report is not None:
                _, train_loss, train_acc, test_loss, test_acc = log.data[-1]
                conv, R1 = model.convergence()
                report(train_loss=train_loss, train_acc=train_acc, test_loss=test_loss, test_acc=test_acc,
                       convergence=conv, R1=R1)
            else:
                log.verbose()

            if save:
                save_layers(model, folder_name, epoch, blocks, storing_path=model_dir)

    if plot_fc is not None:
        for block in blocks:
            plot_fc(model, block)
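# --- Illustrative sketch (not part of the original file) ---
# A typical composition of the two runners above: first train the Hebbian
# blocks without labels with run_unsup, then fit the remaining backprop block
# on top with run_sup. All argument values and block indices below are
# hypothetical placeholders.
def _sketch_layerwise_training(model, device, log_unsup, log_sup, dataset_config, model_dir):
    run_unsup(final_epoch=50, print_freq=10, batch_size=64, folder_name='demo',
              dataset_config=dataset_config, model=model, device=device,
              log=log_unsup, blocks=[0, 1], model_dir=model_dir)
    run_sup(final_epoch=50, print_freq=10, batch_size=64, lr=1e-3, folder_name='demo',
            dataset_config=dataset_config, model=model, device=device,
            log=log_sup, blocks=[2], model_dir=model_dir)
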
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import argparse
import os.path as op
import json
import random
from math import e

import numpy as np
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import StepLR

username = op.expanduser('~').split('/')[-1]
data_candidate = ('/scratch' if 'hrodriguez' == username else '/home') + f'/{username}/workspace'
DATA = op.realpath(op.expanduser(data_candidate))
RESULT = op.join(DATA, 'results', 'hebb', 'result')  # everything from multi_layer.py
SEARCH = op.join(DATA, 'results', 'hebb', 'search')  # everything from ray_search
DATASET = op.join(DATA, 'data')


def get_folder_name(params):
    """
    From a set of parameters, derive the name of the model and thus its folder name.

    Parameters
    ----------
    params : namespace or dict
        Hyperparameters.

    Returns
    -------
    folder_name : str
        Folder name, i.e. the name of one model.
    """
    if isinstance(params, dict):
        if params['folder_name'] is not None:
            return params['folder_name']
        if params['preset'] is not None:
            folder_name = params['preset']
        else:
            names = ['arch', 'n_neurons', 'lr', 't_invert']
            folder_name = '_'.join([str(params[name]) for name in names])
            if params['post_hoc_loss']:
                folder_name = 'post_hoc_loss_' + folder_name
    else:
        if params.folder_name is not None:
            return params.folder_name
        if params.preset is not None:
            folder_name = params.preset
        else:
            names = ['arch', 'n_neurons', 'lr', 't_invert']
            folder_name = '_'.join([str(getattr(params, name)) for name in names])
            if params.post_hoc_loss:
                folder_name = 'post_hoc_loss_' + folder_name

    return folder_name


def activation(x, t_invert=e, activation_fn='exp', dim=1, power=15, beta=1, normalize=False):
    """
    Applies the selected activation function to an n-dimensional input Tensor.
    For 'softmax' (and normalized 'exp'), the elements of the output Tensor are
    rescaled so that they lie in the range [0, 1] and sum to 1, using PyTorch's
    softmax. The t_invert parameter (an inverse temperature) moves the softmax
    along the range between winner-take-all (WTA) and all-take-all (AllTA).

    Parameters
    ----------
    x : torch.Tensor
        Input activations.
    t_invert : float
        Inverse temperature of the softmax/exp. The default is e.
    activation_fn : str
        Activation function name. The default is 'exp'.
    dim : int
        Dimension along which the softmax is computed. The default is 1.

    Returns
    -------
    torch.Tensor
        The activated tensor.
    """
    if (activation_fn == 'exp' and normalize) or activation_fn == 'softmax':
        # this can lead to errors when passed -inf, which is a design choice of the functions that call this;
        # it would be good to use a custom softmax where we could pass a small value in the denominator,
        # however it seems it is not trivial to construct a softmax that achieves similar performance as PyTorch's
        return torch.softmax(t_invert * x, dim)
    if activation_fn == 'exp':
        x = torch.exp(x * t_invert)
    elif activation_fn == 'relu':
        x = torch.relu(x)
    elif activation_fn == 'sigmoid':
        x = torch.sigmoid(x)
    elif activation_fn == 'repu':
        x = torch.relu(x) ** power
    elif activation_fn == 'repu_norm':
        x = torch.relu(x) ** power
        normalize = True
    elif activation_fn == 'tanh':
        x = torch.tanh(beta * x)
    if normalize and x.sum() != 0:
        return (x.t() / x.sum(dim=1)).t()
    else:
        return x
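# --- Illustrative sketch (not part of the original file) ---
# t_invert acts as an inverse temperature in activation(): large values push
# the softmax towards a winner-take-all one-hot output, small values towards a
# uniform distribution. The input values below are made up.
def _sketch_t_invert_demo():
    x = torch.tensor([[1.0, 2.0, 3.0]])
    for t in (0.1, 1.0, 100.0):
        y = activation(x, t_invert=t, activation_fn='softmax')
        print(t, y)  # t=0.1: near-uniform, t=100: essentially one-hot
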
def get_device(gpu_id=0):
    """
    Get the correct device, either cuda with the selected id or cpu.

    Parameters
    ----------
    gpu_id : int
        GPU id. The default is 0.

    Returns
    -------
    device : torch.device
        Torch device, either gpu or cpu.
    """
    use_cuda = torch.cuda.is_available() and gpu_id is not None
    device = torch.device('cuda:' + str(gpu_id) if use_cuda else 'cpu')
    return device


def seed_init_fn(seed):
    """
    Dataloader worker init function; if seed is not None, every epoch and
    experiment will get the same data.

    Parameters
    ----------
    seed : int
        Seed id.

    Returns
    -------
    None.
    """
    seed = seed % 2 ** 32
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    return


def merge_parameter(based_params, override_params):
    """
    Update the parameters in ``based_params`` with ``override_params``.
    Can be useful to override parsed command line arguments.

    Parameters
    ----------
    based_params : namespace or dict
        Base parameters. A key-value mapping.
    override_params : dict or None
        Parameters to override. Usually the parameters got from ``get_next_parameters()``.
        When it is None, nothing will happen.

    Returns
    -------
    params : namespace or dict
        The updated ``based_params``. Note that ``based_params`` will be updated in place. The return
        value is only for convenience.
    """
    if override_params is None:
        return based_params
    is_dict = isinstance(based_params, dict)
    for k, v in override_params.items():
        if is_dict:
            # if k not in params:
            #     raise ValueError('Key \'%s\' not found in parameters.' % k)
            if k not in based_params:
                based_params[k] = v
            elif isinstance(based_params[k], dict):
                if isinstance(v, dict):
                    based_params[k] = merge_parameter(based_params[k], v)
                else:
                    based_params[k] = v
        else:
            # if not hasattr(params, k):
            #     raise ValueError('Key \'%s\' not found in parameters.' % k)
            if not hasattr(based_params, k):
                setattr(based_params, k, v)
            elif isinstance(getattr(based_params, k), dict):
                if isinstance(v, dict):
                    setattr(based_params, k, merge_parameter(getattr(based_params, k), v))
                else:
                    setattr(based_params, k, v)
    return based_params
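# --- Illustrative sketch (not part of the original file) ---
# merge_parameter fills in missing keys and merges nested dicts recursively;
# note that an existing non-dict value in based_params is left untouched, and
# an existing dict is only replaced outright when the override for that key is
# a non-dict. The example values below are made up.
def _sketch_merge_parameter_demo():
    base = {'lr': 0.1, 'layer': {'kernel_size': 5}}
    override = {'lr': 0.5, 'layer': {'stride': 2}, 'momentum': 0.9}
    merged = merge_parameter(base, override)
    # merged == {'lr': 0.1, 'layer': {'kernel_size': 5, 'stride': 2}, 'momentum': 0.9}
    return merged
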
def str2bool(v):
    """
    Return a boolean from a string.

    Parameters
    ----------
    v : str
        Argparse argument.

    Returns
    -------
    bool
        The parsed boolean value.
    """
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


def init_weight(shape, weight_distribution, weight_range, weight_offset=0):
    """
    Weight initialization from a distribution.

    Parameters
    ----------
    shape: tuple
        Expected shape of the weight tensor.
    weight_distribution: str
        Name of the distribution.
    weight_range:
        Multiplier of the weight.
    weight_offset:
        Value added to the weight.

    Returns
    -------
    weight: Tensor
    """
    if weight_distribution == 'positive':
        return weight_range * torch.rand(shape) + weight_offset
    elif weight_distribution == 'negative':
        return -weight_range * torch.rand(shape) + weight_offset
    elif weight_distribution == 'zero_mean':
        # uniform in [-weight_range, weight_range) so the distribution is centered on zero
        return weight_range * (2 * torch.rand(shape) - 1) + weight_offset
    elif weight_distribution == 'normal':
        return weight_range * torch.randn(shape) + weight_offset


def double_factorial(x):
    if x <= 2:
        return x
    return x * double_factorial(x - 2)


def LrLinearDecay(lr, nb_epoch, ratio):
    """
    Linear decay generator.
    """
    delta = lr * ratio / nb_epoch
    while True:
        yield max(0., lr)
        lr = lr - delta


def LrExpDecay(lr, nb_epoch, ratio, lr_div=100, speed=10):
    """
    Exponential decay generator.
    """
    relative_speed = speed * ratio / nb_epoch
    # to guarantee that the minimum value is indeed lr / lr_div
    min_lr = (lr / lr_div - lr * np.exp(-speed)) / (1 - np.exp(-speed))

    while True:
        yield lr
        lr = (lr - min_lr) / np.exp(relative_speed) + min_lr


def LrCste(lr):
    """
    Constant learning-rate generator.
    """
    while True:
        yield lr


def unsup_lr_scheduler(lr, nb_epochs=1, ratio=1, speed=1, div=150, decay: str = 'linear'):
    """
    Select the lr scheduler; returns a generator.
    """
    if nb_epochs == 0 or decay == 'constant':
        return LrCste(lr)
    if decay == 'linear':
        return LrLinearDecay(lr, nb_epochs, ratio)
    if decay == 'exp':
        return LrExpDecay(lr, nb_epochs, ratio, speed=speed, lr_div=div)
    return LrCste(lr)
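# --- Illustrative sketch (not part of the original file) ---
# The schedulers above are infinite generators: a training loop draws one
# learning rate per step with next(). With linear decay over 100 epochs and
# ratio=1, the rate falls to near zero by the last epoch; values are examples.
def _sketch_lr_schedule_demo():
    sched = unsup_lr_scheduler(lr=0.1, nb_epochs=100, ratio=1, decay='linear')
    lrs = [next(sched) for _ in range(100)]
    print(lrs[0], lrs[50], lrs[-1])  # 0.1, 0.05, 0.001
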
def normalize(normalize_type):
    if normalize_type == 'norm':
        return lambda x: nn.functional.normalize(x)
    return lambda x: x


def generate_config(preset, arch):
    """
    Generate a layer config from the name of the layer.

    Parameters
    ----------
    preset: str
        Preset name encoding the initial config.
    arch: str
        Architecture.

    Returns
    -------
    config: dict
    """
    config = {}
    preset = preset.split("-")
    if preset[0] == 'BP':
        config['hebbian'] = False
    else:
        config['hebbian'] = True
        config['softness'] = preset[0]

    for param in preset[1:]:
        if param.startswith('c'):
            config['out_channels'] = int(param[1:])
        if param.startswith('lr'):
            config['lr'] = float(param[2:])
        if param.startswith('ls'):
            config['lr_sup'] = float(param[2:])
        if param.startswith('lb'):
            config['lebesgue_p'] = float(param[2:])
        if param.startswith('lp'):
            config['power_lr'] = float(param[2:])
        if param.startswith('t'):
            config['t_invert'] = float(param[1:])
        if param.startswith('b'):
            config['add_bias'] = bool(int(param[1:]))
        if param.startswith('a'):
            config['delta'] = float(param[1:])
        if param.startswith('r'):
            config['radius'] = float(param[1:])
        if param.startswith('v'):
            config['adaptive'] = bool(int(param[1:]))

    if arch == 'CNN':
        for param in preset[1:]:
            if param.startswith('c'):
                config['out_channels'] = int(param[1:])
            elif param.startswith('k'):
                config['kernel_size'] = int(param[1:])
            elif param.startswith('d'):
                config['dilation'] = int(param[1:])
            elif param.startswith('p'):
                config['padding'] = int(param[1:])
            elif param.startswith('s'):
                config['stride'] = int(param[1:])
            elif param.startswith('m'):
                config['mask_thsd'] = float(param[1:])
            elif param.startswith('g'):
                config['groups'] = int(param[1:])
            elif param.startswith('e'):
                config['pre_triangle'] = bool(int(param[1:]))

    return config


def load_presets(name=None):
    """
    Load the blocks config from the name of the model.
    """
    presets = json.load(open('presets.json'))
    if name is None:
        return list(presets['model'].keys())
    blocks = presets['model'][name]
    for id, block in blocks.items():
        if block['preset'] in presets['layer'][block['arch']]:
            over_config = presets['layer'][block['arch']][block['preset']].copy()
        else:
            over_config = generate_config(block['preset'], block['arch'])  # an option is to pass the supervision here

        if 'layer' in blocks[id]:
            # added to override parameters from the default layer (e.g. 'metric_mode' in MLP) without causing larger changes
            over_config = merge_parameter(over_config, blocks[id]['layer'])
        blocks[id]['layer'] = merge_parameter(presets['layer'][block['arch']]['default'].copy(), over_config)

        if 'pool' in block and block['pool'] is not None:
            type_, kernel_size, stride, padding = block['pool'].split('_')
            blocks[id]['pool'] = {'type': type_, 'kernel_size': int(kernel_size), 'stride': int(stride),
                                  'padding': int(padding)}
        else:
            blocks[id]['pool'] = None

        if 'activation' in block and block['activation'] is not None:
            param = 1
            activation = block['activation']
            activation_param = activation.split('_')
            if len(activation_param) == 2:
                activation = activation_param[0]
                param = float(activation_param[1])
            blocks[id]['activation'] = {'function': activation, 'param': param}
        else:
            blocks[id]['activation'] = None

    return blocks
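# --- Illustrative sketch (not part of the original file) ---
# The preset string is a small '-'-separated mini-language: the first token
# selects BP versus a Hebbian softness mode, and every following token is a
# one- or two-letter key plus a value. The preset name below is made up.
def _sketch_generate_config_demo():
    config = generate_config('soft-c96-k5-p2-t0.5-lr0.01', arch='CNN')
    # -> {'hebbian': True, 'softness': 'soft', 'out_channels': 96,
    #     'lr': 0.01, 't_invert': 0.5, 'kernel_size': 5, 'padding': 2}
    return config
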
def load_config_dataset(name=None, validation=True):
    """
    Load the dataset config from the name of the dataset.
    """
    dataset = json.load(open('presets.json'))['dataset']
    if name is None:
        lst_dataset = []
        for key, value in dataset.items():
            for prop in value.keys():
                if prop == 'default':
                    lst_dataset.append(key)
                else:
                    lst_dataset.append(key + '_' + prop)

        return lst_dataset

    if '_' in name:
        dataset_name, dataset_prop = name.split('_')
    else:
        dataset_name = name
        dataset_prop = 'default'

    all_dataset_config = dataset[dataset_name]
    dataset_config = merge_parameter(dataset['default'], all_dataset_config['default'])
    dataset_config = merge_parameter(dataset_config, all_dataset_config[dataset_prop])

    dataset_config['validation'] = validation
    if validation:
        dataset_config['val_sample'] = int(
            np.floor(dataset_config['training_sample'] * dataset_config['validation_split']))
        dataset_config['training_sample'] = dataset_config['training_sample'] - dataset_config['val_sample']
    return dataset_config


class CustomStepLR(StepLR):
    """Step schedule that halves the learning rate at fixed fractions of training."""

    def __init__(self, optimizer, nb_epochs):
        self.step_threshold = []
        if nb_epochs < 20:
            self.step_threshold = []
        elif nb_epochs < 50:
            self.step_threshold.append(int(nb_epochs * 0.5))
            self.step_threshold.append(int(nb_epochs * 0.75))
        else:
            for fraction in (0.2, 0.35, 0.5, 0.6, 0.7, 0.8, 0.9):
                self.step_threshold.append(int(nb_epochs * fraction))

        # step_size and gamma are irrelevant here since get_lr is overridden below
        super().__init__(optimizer, -1, False)

    def get_lr(self):
        if self.last_epoch in self.step_threshold:
            return [group['lr'] * 0.5
                    for group in self.optimizer.param_groups]
        return [group['lr'] for group in self.optimizer.param_groups]


class PowerLoss(nn.Module):
    """Loss |c - t| ** m, where t is the one-hot target re-encoded with -1 in place of 0."""

    def __init__(self, nb_output=10, m=6):
        super().__init__()
        self.nb_output = nb_output
        self.m = m

    def forward(self, c, t):
        # build the +/-1 targets: +1 at the label index, -1 elsewhere
        t = torch.eye(self.nb_output, dtype=torch.float, device=c.device)[t]
        t[t == 0] = -1.
        loss = (c - t).abs() ** self.m
        return loss.sum()
--------------------------------------------------------------------------------
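# --- Illustrative appendix sketch (not part of the repository) ---
# Worked example of utils.PowerLoss: for one sample with label 1 and m=2, the
# target becomes [-1., 1., -1.] and the loss is the sum of the element-wise
# squared gaps. The tensor values below are made up.
import torch
from utils import PowerLoss

criterion = PowerLoss(nb_output=3, m=2)
c = torch.tensor([[-0.5, 0.5, -1.0]])
t = torch.tensor([1])
loss = criterion(c, t)  # (0.5)**2 + (0.5)**2 + (0.0)**2 = 0.5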