├── .gitignore
├── LICENSE
├── README.md
├── datasets
│   ├── __init__.py
│   ├── base.py
│   ├── cifar10.py
│   ├── cifar100.py
│   ├── imagenet.py
│   ├── mnist.py
│   ├── registry.py
│   ├── test
│   │   ├── test_cifar10.py
│   │   ├── test_cifar100.py
│   │   └── test_mnist.py
│   └── tiny_imagenet.py
├── experiments
│   ├── branch
│   │   ├── base.py
│   │   ├── desc.py
│   │   ├── oneshot_prune.py
│   │   ├── randomly_prune.py
│   │   ├── randomly_reinitialize.py
│   │   ├── registry.py
│   │   ├── retrain.py
│   │   └── runner.py
│   ├── finetune
│   │   ├── desc.py
│   │   ├── runner.py
│   │   └── test
│   │       └── test_runner.py
│   ├── lottery
│   │   ├── desc.py
│   │   └── runner.py
│   ├── rewindLR
│   │   ├── desc.py
│   │   ├── runner.py
│   │   └── test
│   │       └── test_runner.py
│   ├── runner_registry.py
│   └── scratch
│       ├── desc.py
│       ├── runner.py
│       └── test
│           └── test_runner.py
├── figs
│   └── sparsedd.png
├── foundations
│   ├── __init__.py
│   ├── desc.py
│   ├── hparams.py
│   ├── local.py
│   ├── paths.py
│   ├── runner.py
│   └── step.py
├── main.py
├── models
│   ├── __init__.py
│   ├── base.py
│   ├── bn_initializers.py
│   ├── cifar_pytorch_resnet.py
│   ├── cifar_vgg.py
│   ├── imagenet_resnet.py
│   ├── initializers.py
│   ├── mnist_mlp.py
│   ├── registry.py
│   └── tinyimagenet_resnet.py
├── pruning
│   ├── __init__.py
│   ├── base.py
│   ├── gradient.py
│   ├── magnitude.py
│   ├── mask.py
│   ├── network_slimming.py
│   ├── pruned_model.py
│   ├── random.py
│   ├── registry.py
│   └── test
│       ├── __init__.py
│       ├── test_magnitude.py
│       ├── test_mask.py
│       ├── test_network_slimming.py
│       ├── test_pruned_model.py
│       └── test_random.py
├── show_result.py
├── testing
│   ├── test_case.py
│   └── toy_model.py
├── training
│   ├── __init__.py
│   ├── checkpointing.py
│   ├── desc.py
│   ├── metric_logger.py
│   ├── optimizers.py
│   ├── runner.py
│   ├── standard_callbacks.py
│   ├── test
│   │   └── test_train.py
│   └── train.py
└── utils
    ├── __init__.py
    ├── arg_utils.py
    ├── shared_args.py
    └── tensor_utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | /**/__pycache__
2 | /**/*.pyc
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2022 Zheng He
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | -------------------------------------------------------------------------------- /datasets/cifar10.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | 6 | import numpy as np 7 | import os 8 | from PIL import Image 9 | import sys 10 | import torchvision 11 | 12 | from datasets import base 13 | from foundations.local import Platform 14 | 15 | class CIFAR10(torchvision.datasets.CIFAR10): 16 | """A subclass to suppress an annoying print statement in the torchvision CIFAR-10 library. 17 | 18 | Not strictly necessary - you can just use `torchvision.datasets.CIFAR10 if the print 19 | message doesn't bother you. 20 | """ 21 | 22 | def download(self): 23 | with open(os.devnull, 'w') as fp: 24 | sys.stdout = fp 25 | super(CIFAR10, self).download() 26 | sys.stdout = sys.__stdout__ 27 | 28 | 29 | class Dataset(base.ImageDataset): 30 | """The CIFAR-10 dataset.""" 31 | 32 | @staticmethod 33 | def num_train_examples(): return 50000 34 | 35 | @staticmethod 36 | def num_test_examples(): return 10000 37 | 38 | @staticmethod 39 | def num_classes(): return 10 40 | 41 | @staticmethod 42 | def get_train_set(use_augmentation, num_workers): 43 | augment = [torchvision.transforms.RandomHorizontalFlip(), torchvision.transforms.RandomCrop(32, 4)] 44 | train_set = CIFAR10(train=True, root=os.path.join(Platform().dataset_root, 'cifar10'), download=True) 45 | return Dataset(train_set.data, np.array(train_set.targets), augment if use_augmentation else []) 46 | 47 | @staticmethod 48 | def get_test_set(num_workers): 49 | test_set = CIFAR10(train=False, root=os.path.join(Platform().dataset_root, 'cifar10'), download=True) 50 | return Dataset(test_set.data, np.array(test_set.targets)) 51 | 52 | def __init__(self, examples, labels, image_transforms=None): 53 | super(Dataset, self).__init__(examples, labels, image_transforms or [], 54 | [torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 55 | 56 | def asymmetric_noisy_labels(self, seed:int, fraction: float) -> None: 57 | """Inject asymmetric label noise into the specified fraction of the dataset by pair flipping.""" 58 | # https://github.com/shengliu66/ELR/blob/master/ELR/data_loader/cifar10.py 59 | _labels = self._labels.copy() 60 | for i in range(self.num_classes()): 61 | indices = np.where(_labels == i)[0] 62 | num_to_noisify_label_i = np.ceil(len(indices) * fraction).astype(int) 63 | np.random.RandomState(seed=seed+i).shuffle(indices) 64 | for j, idx in enumerate(indices): 65 | if j < num_to_noisify_label_i: 66 | # self.noise_indx.append(idx) 67 | # truck -> automobile 68 | if i == 9: 69 | self._labels[idx] = 1 70 | # bird -> airplane 71 | elif i == 2: 72 | self._labels[idx] = 0 73 | # cat -> dog 74 | elif i == 3: 75 | self._labels[idx] = 5 76 | # dog -> cat 77 | elif i == 5: 78 | self._labels[idx] = 3 79 | # deer -> horse 80 | elif i == 4: 81 | self._labels[idx] = 7 82 | 83 | 84 | def example_to_image(self, example): 85 | return Image.fromarray(example) 86 | 87 | 88 | 89 | DataLoader = base.DataLoader 90 | -------------------------------------------------------------------------------- /datasets/cifar100.py: -------------------------------------------------------------------------------- 1 | # 
Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | 6 | import numpy as np 7 | import os 8 | from PIL import Image 9 | import sys 10 | import torchvision 11 | 12 | from datasets import base 13 | from foundations.local import Platform 14 | from numpy.testing import assert_array_almost_equal 15 | 16 | class CIFAR100(torchvision.datasets.CIFAR100): 17 | """A subclass to suppress an annoying print statement in the torchvision CIFAR-100 library. 18 | 19 | Not strictly necessary - you can just use `torchvision.datasets.CIFAR100 if the print 20 | message doesn't bother you. 21 | """ 22 | 23 | def download(self): 24 | with open(os.devnull, 'w') as fp: 25 | sys.stdout = fp 26 | super(CIFAR100, self).download() 27 | sys.stdout = sys.__stdout__ 28 | 29 | 30 | class Dataset(base.ImageDataset): 31 | """The CIFAR-100 dataset.""" 32 | 33 | @staticmethod 34 | def num_train_examples(): return 50000 35 | 36 | @staticmethod 37 | def num_test_examples(): return 10000 38 | 39 | @staticmethod 40 | def num_classes(): return 100 41 | 42 | @staticmethod 43 | def get_train_set(use_augmentation, num_workers): 44 | augment = [torchvision.transforms.RandomHorizontalFlip(), torchvision.transforms.RandomCrop(32, 4)] 45 | train_set = CIFAR100(train=True, root=os.path.join(Platform().dataset_root, 'cifar100'), download=True) 46 | return Dataset(train_set.data, np.array(train_set.targets), augment if use_augmentation else []) 47 | 48 | @staticmethod 49 | def get_test_set(num_workers): 50 | test_set = CIFAR100(train=False, root=os.path.join(Platform().dataset_root, 'cifar100'), download=True) 51 | return Dataset(test_set.data, np.array(test_set.targets)) 52 | 53 | def __init__(self, examples, labels, image_transforms=None): 54 | super(Dataset, self).__init__(examples, labels, image_transforms or [], 55 | [torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 56 | 57 | def asymmetric_noisy_labels(self, seed:int, fraction: float) -> None: 58 | """Inject asymmetric label noise into the specified fraction of the dataset by pair flipping, 59 | by flipping each class into the next class within the same super-class.""" 60 | P = np.eye(self.num_classes()) 61 | # n = self.cfg_trainer['percent'] 62 | nb_superclasses = 20 63 | nb_subclasses = 5 64 | 65 | def build_transition_matrix(size, fraction): 66 | trans_matrix = (1. 
- fraction) * np.eye(size) 67 | for i in np.arange(size - 1): 68 | trans_matrix[i, i + 1] = fraction 69 | 70 | # adjust last row 71 | trans_matrix[size - 1, 0] = fraction 72 | 73 | assert_array_almost_equal(trans_matrix.sum(axis=1), 1, 1) 74 | return trans_matrix 75 | 76 | # if n > 0.0: 77 | for i in np.arange(nb_superclasses): 78 | init, end = i * nb_subclasses, (i+1) * nb_subclasses 79 | P[init:end, init:end] = build_transition_matrix(nb_subclasses, fraction) 80 | 81 | # y_train_noisy = self.multiclass_noisify(self.train_labels, P=P, 82 | # random_state=0) 83 | # actual_noise = (y_train_noisy != self.train_labels).mean() 84 | # assert actual_noise > 0.0 85 | # self.train_labels = y_train_noisy 86 | self.multiclass_labels_noisify(seed=seed, trans_matrix=P) 87 | 88 | 89 | def example_to_image(self, example): 90 | return Image.fromarray(example) 91 | 92 | 93 | DataLoader = base.DataLoader 94 | -------------------------------------------------------------------------------- /datasets/imagenet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | 6 | import concurrent 7 | import numpy as np 8 | import os 9 | from PIL import Image 10 | import torchvision 11 | 12 | from datasets import base 13 | from foundations.local import Platform 14 | 15 | 16 | def _get_samples(root, y_name, y_num): 17 | y_dir = os.path.join(root, y_name) 18 | if not os.path.isdir(y_dir): return [] 19 | output = [(os.path.join(y_dir, f), y_num) for f in os.listdir(y_dir) if f.lower().endswith('jpeg')] 20 | return output 21 | 22 | 23 | class Dataset(base.ImageDataset): 24 | """ImageNet""" 25 | 26 | def __init__(self, loc: str, image_transforms, num_workers=0): 27 | # Load the data. 
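# Illustrative note, not part of the original file: the expected on-disk layout
# is inferred from _get_samples above rather than stated explicitly -- one
# sub-directory per class under `loc`, each holding JPEG files, e.g.
#     <imagenet_root>/train/n01440764/n01440764_10026.JPEG
# Only file names ending in 'jpeg' (case-insensitive) are collected, and class
# indices follow the sorted order of the directory names.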
28 | classes = sorted(os.listdir(loc)) 29 | samples = [] 30 | 31 | if num_workers > 0: 32 | executor = concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) 33 | futures = [executor.submit(_get_samples, loc, y_name, y_num) for y_num, y_name in enumerate(classes)] 34 | for d in concurrent.futures.wait(futures)[0]: samples += d.result() 35 | else: 36 | for y_num, y_name in enumerate(classes): 37 | samples += _get_samples(loc, y_name, y_num) 38 | 39 | examples, labels = zip(*samples) 40 | super(Dataset, self).__init__( 41 | np.array(examples), np.array(labels), image_transforms, 42 | [torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 43 | 44 | @staticmethod 45 | def num_train_examples(): return 1281167 46 | 47 | @staticmethod 48 | def num_test_examples(): return 50000 49 | 50 | @staticmethod 51 | def num_classes(): return 1000 52 | 53 | @staticmethod 54 | def _augment_transforms(): 55 | return [ 56 | torchvision.transforms.RandomResizedCrop(224, scale=(0.1, 1.0), ratio=(0.8, 1.25)), 57 | torchvision.transforms.RandomHorizontalFlip() 58 | ] 59 | 60 | @staticmethod 61 | def _transforms(): 62 | return [torchvision.transforms.Resize(256), torchvision.transforms.CenterCrop(224)] 63 | 64 | @staticmethod 65 | def get_train_set(use_augmentation, num_workers): 66 | transforms = Dataset._augment_transforms() if use_augmentation else Dataset._transforms() 67 | return Dataset(os.path.join(Platform().imagenet_root, 'train'), transforms, num_workers) 68 | 69 | @staticmethod 70 | def get_test_set(num_workers): 71 | return Dataset(os.path.join(Platform().imagenet_root, 'val'), Dataset._transforms(), num_workers) 72 | 73 | @staticmethod 74 | def example_to_image(example): 75 | with open(example, 'rb') as fp: 76 | return Image.open(fp).convert('RGB') 77 | 78 | 79 | DataLoader = base.DataLoader 80 | -------------------------------------------------------------------------------- /datasets/mnist.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | 6 | import os 7 | from PIL import Image 8 | import numpy as np 9 | import torchvision 10 | 11 | from datasets import base 12 | from foundations.local import Platform 13 | 14 | 15 | class Dataset(base.ImageDataset): 16 | """The MNIST dataset.""" 17 | 18 | @staticmethod 19 | def num_train_examples(): return 60000 20 | 21 | @staticmethod 22 | def num_test_examples(): return 10000 23 | 24 | @staticmethod 25 | def num_classes(): return 10 26 | 27 | @staticmethod 28 | def get_train_set(use_augmentation, num_workers): 29 | # No augmentation for MNIST. 
30 | train_set = torchvision.datasets.MNIST( 31 | train=True, root=os.path.join(Platform().dataset_root, 'mnist'), download=True) 32 | return Dataset(train_set.data, train_set.targets) 33 | 34 | @staticmethod 35 | def get_test_set(num_workers): 36 | test_set = torchvision.datasets.MNIST( 37 | train=False, root=os.path.join(Platform().dataset_root, 'mnist'), download=True) 38 | return Dataset(test_set.data, test_set.targets) 39 | 40 | def __init__(self, examples, labels): 41 | tensor_transforms = [torchvision.transforms.Normalize(mean=[0.1307], std=[0.3081])] 42 | super(Dataset, self).__init__(examples, labels, [], tensor_transforms) 43 | 44 | def asymmetric_noisy_labels(self, seed:int, fraction: float) -> None: 45 | """Inject asymmetric label noise into the specified fraction of the dataset by pair flipping.""" 46 | # https://github.com/xiaoboxia/CDR/blob/main/utils.py 47 | 48 | P = np.eye(10) 49 | # 2 -> 7 50 | P[2, 2], P[2, 7] = 1. - fraction, fraction 51 | # 5 <-> 6 52 | P[5, 5], P[5, 6] = 1. - fraction, fraction 53 | P[6, 6], P[6, 5] = 1. - fraction, fraction 54 | # 3 -> 8 55 | P[3, 3], P[3, 8] = 1. - fraction, fraction 56 | 57 | self.multiclass_labels_noisify(seed=seed, trans_matrix=P) 58 | 59 | 60 | def example_to_image(self, example): 61 | return Image.fromarray(example.numpy(), mode='L') 62 | 63 | 64 | DataLoader = base.DataLoader 65 | -------------------------------------------------------------------------------- /datasets/registry.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | 6 | import numpy as np 7 | 8 | from datasets import base, mnist, cifar10, cifar100, tiny_imagenet 9 | from foundations.hparams import DatasetHparams 10 | 11 | 12 | registered_datasets = {'mnist': mnist,'cifar10': cifar10, 'cifar100': cifar100, 13 | 'tiny_imagenet': tiny_imagenet} 14 | 15 | 16 | def get(dataset_hparams: DatasetHparams, train: bool = True, subsample_labels_type: str = None): 17 | """Get the train or test set corresponding to the hyperparameters.""" 18 | 19 | seed = dataset_hparams.transformation_seed or 0 20 | 21 | # Get the dataset itself. 22 | if dataset_hparams.dataset_name in registered_datasets: 23 | use_augmentation = train and not dataset_hparams.do_not_augment 24 | if train: 25 | dataset = registered_datasets[dataset_hparams.dataset_name].Dataset.get_train_set(use_augmentation, 26 | dataset_hparams.num_workers) 27 | else: 28 | dataset = registered_datasets[dataset_hparams.dataset_name].Dataset.get_test_set(dataset_hparams.num_workers) 29 | else: 30 | raise ValueError('No such dataset: {}'.format(dataset_hparams.dataset_name)) 31 | 32 | # Transform the dataset. 
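# Illustrative usage sketch, not part of the original file (it assumes
# DatasetHparams accepts these fields as keyword arguments, which matches how
# they are read below but is not shown in this module):
#
#   from foundations.hparams import DatasetHparams
#   hp = DatasetHparams(dataset_name='cifar10', batch_size=128,
#                       noisy_labels_fraction=0.4, noisy_labels_type='symmetric')
#   train_loader = get(hp, train=True)  # CIFAR-10 with 40% symmetric label noise
#
# The checks below make random labels and noisy labels mutually exclusive and
# require a noisy_labels_type whenever noisy_labels_fraction is set.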
33 | if dataset_hparams.random_labels_fraction is not None and dataset_hparams.noisy_labels_fraction is not None: 34 | raise ValueError('random_labels_fraction and noisy_labels_fraction cannot be assigned at the same time.') 35 | 36 | if train and dataset_hparams.random_labels_fraction is not None: 37 | dataset.randomize_labels(seed=seed, fraction=dataset_hparams.random_labels_fraction) 38 | 39 | if train and dataset_hparams.noisy_labels_fraction is not None: 40 | if dataset_hparams.noisy_labels_type == 'symmetric': 41 | dataset.symmetric_noisy_labels(seed=seed, fraction=dataset_hparams.noisy_labels_fraction) 42 | elif dataset_hparams.noisy_labels_type == 'asymmetric': 43 | dataset.asymmetric_noisy_labels(seed=seed, fraction=dataset_hparams.noisy_labels_fraction) 44 | elif dataset_hparams.noisy_labels_type == 'pairflip': 45 | dataset.pairflip_noisy_labels(seed=seed, fraction=dataset_hparams.noisy_labels_fraction) 46 | elif dataset_hparams.noisy_labels_type is None: 47 | raise ValueError('Please specify the type of noisy labels.') 48 | else: 49 | raise ValueError('Noisy label type of {} is not implemented.'.format(dataset_hparams.noisy_labels_type)) 50 | 51 | if train and dataset_hparams.subsample_fraction is not None: 52 | dataset.subsample(seed=seed, fraction=dataset_hparams.subsample_fraction) 53 | 54 | if train and dataset_hparams.blur_factor is not None: 55 | if not isinstance(dataset, base.ImageDataset): 56 | raise ValueError('Can blur images.') 57 | else: 58 | dataset.blur(seed=seed, blur_factor=dataset_hparams.blur_factor) 59 | 60 | if dataset_hparams.unsupervised_labels is not None: 61 | if dataset_hparams.unsupervised_labels != 'rotation': 62 | raise ValueError('Unknown unsupervised labels: {}'.format(dataset_hparams.unsupervised_labels)) 63 | elif not isinstance(dataset, base.ImageDataset): 64 | raise ValueError('Can only do unsupervised rotation to images.') 65 | else: 66 | dataset.unsupervised_rotation(seed=seed) 67 | 68 | # Create the loader. 
69 | return registered_datasets[dataset_hparams.dataset_name].DataLoader( 70 | dataset, batch_size=dataset_hparams.batch_size, num_workers=dataset_hparams.num_workers) 71 | 72 | 73 | def iterations_per_epoch(dataset_hparams: DatasetHparams): 74 | """Get the number of iterations per training epoch.""" 75 | 76 | if dataset_hparams.dataset_name in registered_datasets: 77 | num_train_examples = registered_datasets[dataset_hparams.dataset_name].Dataset.num_train_examples() 78 | else: 79 | raise ValueError('No such dataset: {}'.format(dataset_hparams.dataset_name)) 80 | 81 | if dataset_hparams.subsample_fraction is not None: 82 | num_train_examples *= dataset_hparams.subsample_fraction 83 | 84 | return np.ceil(num_train_examples / dataset_hparams.batch_size).astype(int) 85 | 86 | 87 | def num_classes(dataset_hparams: DatasetHparams): 88 | """Get the number of classes.""" 89 | 90 | if dataset_hparams.dataset_name in registered_datasets: 91 | num_classes = registered_datasets[dataset_hparams.dataset_name].Dataset.num_classes() 92 | else: 93 | raise ValueError('No such dataset: {}'.format(dataset_hparams.dataset_name)) 94 | 95 | if dataset_hparams.unsupervised_labels is not None: 96 | if dataset_hparams.unsupervised_labels != 'rotation': 97 | raise ValueError('Unknown unsupervised labels: {}'.format(dataset_hparams.unsupervised_labels)) 98 | else: 99 | return 4 100 | 101 | return num_classes 102 | -------------------------------------------------------------------------------- /datasets/test/test_cifar100.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | 6 | import numpy as np 7 | 8 | from datasets import cifar100 9 | from testing import test_case 10 | 11 | 12 | class TestDataset(test_case.TestCase): 13 | def setUp(self): 14 | super(TestDataset, self).setUp() 15 | self.test_set = cifar100.Dataset.get_test_set(num_workers=0) 16 | self.train_set = cifar100.Dataset.get_train_set(use_augmentation=True, num_workers=0) 17 | self.train_set_noaugment = cifar100.Dataset.get_train_set(use_augmentation=False, num_workers=0) 18 | 19 | def test_not_none(self): 20 | self.assertIsNotNone(self.test_set) 21 | self.assertIsNotNone(self.train_set) 22 | self.assertIsNotNone(self.train_set_noaugment) 23 | 24 | def test_size(self): 25 | self.assertEqual(cifar100.Dataset.num_classes(), 100) 26 | self.assertEqual(cifar100.Dataset.num_train_examples(), 50000) 27 | self.assertEqual(cifar100.Dataset.num_test_examples(), 10000) 28 | 29 | # test random labels 30 | def test_randomize_labels_half(self): 31 | labels_before = self.test_set._labels.tolist() 32 | self.test_set.randomize_labels(0, 0.5) 33 | examples_match = np.sum(np.equal(labels_before, self.test_set._labels).astype(int)) 34 | self.assertEqual(examples_match, 5048) 35 | 36 | def test_randomize_labels_none(self): 37 | labels_before = self.test_set._labels.tolist() 38 | self.test_set.randomize_labels(0, 0) 39 | examples_match = np.sum(np.equal(labels_before, self.test_set._labels).astype(int)) 40 | self.assertEqual(examples_match, 10000) 41 | 42 | def test_randomize_labels_all(self): 43 | labels_before = self.test_set._labels.tolist() 44 | self.test_set.randomize_labels(0, 1) 45 | examples_match = np.sum(np.equal(labels_before, self.test_set._labels).astype(int)) 46 | self.assertEqual(examples_match, 97) 47 | 48 | # test symmetric noisy labels 49 | 
def test_symmetric_noisy_labels_half(self): 50 | # labels_before = self.test_set._labels.tolist() 51 | labels_before = self.test_set._labels.copy() 52 | self.test_set.symmetric_noisy_labels(0, 0.5) 53 | labels_after = self.test_set._labels 54 | examples_match = np.sum(np.equal(labels_before, labels_after).astype(int)) 55 | self.assertEqual(examples_match, 4964) 56 | for i in range(0,100): 57 | i_labels_after = labels_after[labels_before==i] 58 | num_i_class = labels_before.tolist().count(i) 59 | for j in range(100): 60 | num_i_to_j_class = i_labels_after.tolist().count(j) 61 | frac = num_i_to_j_class/num_i_class 62 | if i == j : 63 | self.assertAlmostEqual( frac, 0.5, delta=0.11) 64 | else: 65 | self.assertAlmostEqual( frac, 0.5/99, delta=0.1) 66 | 67 | def test_symmetric_noisy_labels_none(self): 68 | labels_before = self.test_set._labels.copy() 69 | self.test_set.symmetric_noisy_labels(0, 0) 70 | labels_after = self.test_set._labels 71 | examples_match = np.sum(np.equal(labels_before, labels_after).astype(int)) 72 | self.assertEqual(examples_match, 10000) 73 | 74 | def test_symmetric_noisy_labels_all(self): 75 | labels_before = self.test_set._labels.copy() 76 | self.test_set.symmetric_noisy_labels(0, 1) 77 | labels_after = self.test_set._labels 78 | examples_match = np.sum(np.equal(labels_before, labels_after).astype(int)) 79 | self.assertEqual(examples_match, 0) 80 | 81 | # test asymmetric noisy labels 82 | def test_asymmetric_noisy_labels_half(self): 83 | # labels_before = self.test_set._labels.tolist() 84 | labels_before = self.test_set._labels.copy() 85 | self.test_set.asymmetric_noisy_labels(1, 0.5) 86 | labels_after = self.test_set._labels 87 | examples_match = np.sum(np.equal(labels_before, labels_after).astype(int)) 88 | self.assertEqual(examples_match, 4954) 89 | nb_superclasses = 20 90 | nb_subclasses = 5 91 | i = 0 92 | for sup in range(nb_superclasses): 93 | for sub in range(nb_subclasses): 94 | num_i_class = labels_before.tolist().count(i) 95 | i_labels_after = labels_after[labels_before==i] 96 | if sub == nb_subclasses - 1: 97 | num_i_to_j_class = i_labels_after.tolist().count(sup * nb_subclasses) 98 | else: 99 | num_i_to_j_class = i_labels_after.tolist().count(i + 1) 100 | frac = num_i_to_j_class/num_i_class 101 | self.assertAlmostEqual( frac , 0.5, delta=0.2) 102 | i += 1 103 | 104 | 105 | def test_asymmetric_noisy_labels_none(self): 106 | labels_before = self.test_set._labels.copy() 107 | self.test_set.asymmetric_noisy_labels(0, 0) 108 | labels_after = self.test_set._labels 109 | examples_match = np.sum(np.equal(labels_before, labels_after).astype(int)) 110 | self.assertEqual(examples_match, 10000) 111 | 112 | def test_asymmetric_noisy_labels_all(self): 113 | labels_before = self.test_set._labels.copy() 114 | self.test_set.asymmetric_noisy_labels(0, 1) 115 | labels_after = self.test_set._labels 116 | examples_match = np.sum(np.equal(labels_before, labels_after).astype(int)) 117 | self.assertEqual(examples_match, 0) 118 | 119 | def test_subsample(self): 120 | # Subsample the test set. 
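# (Expected sizes: 10,000 test examples x 0.1 = 1,000 and 50,000 train
#  examples x 0.1 = 5,000, matching the assertions below.)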
121 | 122 | self.test_set.subsample(0, 0.1) 123 | self.assertEqual(len(self.test_set), 1000) 124 | 125 | self.train_set.subsample(0, 0.1) 126 | self.assertEqual(len(self.train_set), 5000) 127 | 128 | def test_subsample_twice(self): 129 | self.train_set.subsample(1, 0.1) 130 | with self.assertRaises(ValueError): 131 | self.train_set.subsample(1, 0.1) 132 | 133 | 134 | test_case.main() -------------------------------------------------------------------------------- /datasets/tiny_imagenet.py: -------------------------------------------------------------------------------- 1 | # Tiny ImageNet data set available from `http://cs231n.stanford.edu/tiny-imagenet-200.zip`. 2 | 3 | import concurrent 4 | import numpy as np 5 | import os 6 | from PIL import Image 7 | import torchvision 8 | from datasets import base 9 | from foundations.local import Platform 10 | 11 | 12 | def _get_samples(root, y_name, y_num): 13 | y_dir = os.path.join(root, y_name) 14 | # y_dir = os.path.join(y_dir, 'images') 15 | if not os.path.isdir(y_dir): return [] 16 | output = [(os.path.join(y_dir, f), y_num) for f in os.listdir(y_dir) if f.lower().endswith('jpeg')] 17 | return output 18 | 19 | 20 | class Dataset(base.ImageDataset): 21 | """Tiny-ImageNet""" 22 | 23 | def __init__(self, loc: str, image_transforms, num_workers=0): 24 | # Load the data. 25 | classes = sorted(os.listdir(loc)) 26 | samples = [] 27 | 28 | if num_workers > 0: 29 | executor = concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) 30 | futures = [executor.submit(_get_samples, loc, y_name, y_num) for y_num, y_name in enumerate(classes)] 31 | for d in concurrent.futures.wait(futures)[0]: samples += d.result() 32 | else: 33 | for y_num, y_name in enumerate(classes): 34 | samples += _get_samples(loc, y_name, y_num) 35 | 36 | examples, labels = zip(*samples) 37 | super(Dataset, self).__init__( 38 | np.array(examples), np.array(labels), image_transforms, 39 | [torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 40 | 41 | @staticmethod 42 | def num_train_examples(): return 100000 43 | 44 | @staticmethod 45 | def num_test_examples(): return 10000 46 | 47 | @staticmethod 48 | def num_classes(): return 200 49 | 50 | @staticmethod 51 | def _augment_transforms(): 52 | return [ 53 | torchvision.transforms.RandomResizedCrop(32), 54 | torchvision.transforms.RandomHorizontalFlip() 55 | ] 56 | 57 | @staticmethod 58 | def _transforms(): 59 | return [torchvision.transforms.Resize(45), torchvision.transforms.CenterCrop(32)] 60 | 61 | @staticmethod 62 | def get_train_set(use_augmentation, num_workers): 63 | transforms = Dataset._augment_transforms() if use_augmentation else Dataset._transforms() 64 | return Dataset(os.path.join(Platform().tiny_imagenet_root, 'train'), transforms, num_workers) 65 | 66 | @staticmethod 67 | def get_test_set(num_workers): 68 | return Dataset(os.path.join(Platform().tiny_imagenet_root, 'val'), Dataset._transforms(), num_workers) 69 | 70 | @staticmethod 71 | def example_to_image(example): 72 | with open(example, 'rb') as fp: 73 | return Image.open(fp).convert('RGB') 74 | 75 | 76 | DataLoader = base.DataLoader 77 | -------------------------------------------------------------------------------- /experiments/branch/base.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | 6 | import abc 7 | import argparse 8 | import inspect 9 | from dataclasses import dataclass, field, fields, make_dataclass 10 | import sys 11 | from typing import List 12 | 13 | from experiments.branch.desc import make_BranchDesc 14 | 15 | from experiments.finetune.desc import FinetuningDesc 16 | from experiments.lottery.desc import LotteryDesc 17 | from experiments.rewindLR.desc import RewindingDesc 18 | from experiments.scratch.desc import ScratchDesc 19 | from foundations.desc import Desc 20 | from foundations.hparams import Hparams 21 | from foundations.runner import Runner 22 | from utils import arg_utils, shared_args 23 | 24 | main_descs = {'finetune': FinetuningDesc, 'lottery': LotteryDesc, 'rewindLR': RewindingDesc, 'scratch': ScratchDesc} 25 | 26 | @dataclass 27 | class Branch(Runner): 28 | """A branch. Implement `branch_function`, add a name and description, and add to the registry.""" 29 | replicate: int 30 | levels: str 31 | desc: Desc 32 | verbose: bool = False 33 | level: int = None 34 | 35 | # Interface that needs to be overriden for each branch. 36 | @staticmethod 37 | @abc.abstractmethod 38 | def description() -> str: 39 | """A description of this branch. Override this.""" 40 | pass 41 | 42 | @staticmethod 43 | @abc.abstractmethod 44 | def name() -> str: 45 | """The name of this branch. Override this.""" 46 | pass 47 | 48 | @abc.abstractmethod 49 | def branch_function(self) -> None: 50 | """The method that is called to execute the branch. 51 | 52 | Override this method with any additional arguments that the branch will need. 53 | These arguments will be converted into command-line arguments for the branch. 54 | Each argument MUST have a type annotation. The first argument must still be self. 55 | """ 56 | pass 57 | 58 | # Interface that is useful for writing branches. 59 | @property 60 | def main_experiment(self) -> str: 61 | """The main experiments on which the branch is based""" 62 | 63 | main_experiment = sys.argv[1].split('_')[0] 64 | return main_experiment 65 | 66 | @property 67 | def main_desc(self) -> Desc: 68 | """The main description of this experiment.""" 69 | 70 | return self.desc.main_desc 71 | 72 | @property 73 | def experiment_name(self) -> str: 74 | """The name of this experiment.""" 75 | 76 | return self.desc.hashname 77 | 78 | @property 79 | def branch_root(self) -> str: 80 | """The root for where branch results will be stored for a specific invocation of run().""" 81 | 82 | return self.main_desc.run_path(self.replicate, self.level, self.experiment_name) 83 | 84 | @property 85 | def zero_branch_root(self) -> str: 86 | """The level_0 folder root of the main experiment.""" 87 | 88 | return self.main_desc.run_path(self.replicate, 0, self.experiment_name) 89 | 90 | @property 91 | def level_root(self) -> str: 92 | """The root of the main experiment on which this branch is based.""" 93 | 94 | return self.main_desc.run_path(self.replicate, self.level) 95 | 96 | # Interface that deals with command line arguments. 97 | @dataclass 98 | class ArgHparams(Hparams): 99 | levels: str 100 | pretrain_training_steps: str = None 101 | 102 | _name: str = 'Experiments Hyperparameters' 103 | _description: str = 'Hyperparameters that control the pruning and retraining process.' 104 | _levels: str = \ 105 | 'The pruning levels on which to run this branch. 
Can include a comma-separate list of levels or ranges, '\ 106 | 'e.g., 1,2-4,9' 107 | _pretrain_training_steps: str = 'The number of steps to train the network prior to the branch process.' 108 | 109 | @classmethod 110 | def add_args(cls, parser: argparse.ArgumentParser): 111 | defaults = shared_args.maybe_get_default_hparams() 112 | shared_args.JobArgs.add_args(parser) 113 | Branch.ArgHparams.add_args(parser) 114 | cls.BranchDesc.add_args(parser, defaults) 115 | 116 | @staticmethod 117 | def level_str_to_int_list(levels: str): 118 | level_list = [] 119 | elements = levels.split(',') 120 | for element in elements: 121 | if element.isdigit(): 122 | level_list.append(int(element)) 123 | elif len(element.split('-')) == 2: 124 | level_list += list(range(int(element.split('-')[0]), int(element.split('-')[1]) + 1)) 125 | else: 126 | raise ValueError(f'Invalid level: {element}') 127 | return sorted(list(set(level_list))) 128 | 129 | @classmethod 130 | def create_from_args(cls, args: argparse.Namespace): 131 | levels = Branch.level_str_to_int_list(args.levels) 132 | 133 | return cls(args.replicate, levels, cls.BranchDesc.create_from_args(args), not args.quiet) 134 | 135 | @classmethod 136 | def create_from_hparams(cls, replicate, levels: List[int], desc: Desc, hparams: Hparams, verbose=False): 137 | return cls(replicate, levels, cls.BranchDesc(desc, hparams), verbose) 138 | 139 | def display_output_location(self): 140 | print(self.branch_root) 141 | 142 | def run(self): 143 | for self.level in self.levels: 144 | if self.verbose: 145 | print('='*82) 146 | print(f'Branch {self.name()} (Replicate {self.replicate}, Level {self.level})\n' + '-'*82) 147 | print(f'{self.main_desc.display}\n{self.desc.branch_hparams.display}') 148 | print(f'Output Location: {self.branch_root}\n' + '='*82 + '\n') 149 | 150 | args = {f.name: getattr(self.desc.branch_hparams, f.name) 151 | for f in fields(self.BranchHparams) if not f.name.startswith('_')} 152 | self.branch_function(**args) 153 | 154 | # Initialize instances and subclasses (metaprogramming). 155 | def __init_subclass__(cls): 156 | """Metaprogramming: modify the attributes of the subclass based on information in run(). 157 | 158 | The goal is to make it possible for users to simply write a single run() method and have 159 | as much functionality as possible occur automatically. Specifically, this function converts 160 | the annotations and defaults in run() into a `BranchHparams` property. 
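
        An illustrative example based on the randomly_prune branch in this
        repository: a branch function declared (abridged) as

            def branch_function(self, seed: int, strategy: str = 'layerwise')

        yields a BranchHparams dataclass with a required `seed` field and a
        `strategy` field defaulting to 'layerwise', both of which are then
        exposed as command-line arguments for the branch.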
161 | """ 162 | 163 | fields = [] 164 | for arg_name, parameter in list(inspect.signature(cls.branch_function).parameters.items())[1:]: 165 | t = parameter.annotation 166 | if t == inspect._empty: raise ValueError(f'Argument {arg_name} needs a type annotation.') 167 | elif t in [str, float, int, bool] or (isinstance(t, type) and issubclass(t, Hparams)): 168 | if parameter.default != inspect._empty: fields.append((arg_name, t, field(default=parameter.default))) 169 | else: fields.append((arg_name, t)) 170 | else: 171 | raise ValueError('Invalid branch type: {}'.format(parameter.annotation)) 172 | 173 | main_experiment = sys.argv[1].split('_')[0] if len(sys.argv) > 1 else None 174 | if main_experiment is not None and len(sys.argv[1].split('_'))==2: 175 | if main_experiment not in main_descs.keys(): 176 | raise ValueError('{} has not been registered as a main experiment'.format(main_experiment)) 177 | 178 | fields += [('_name', str, 'Branch Arguments'), ('_description', str, 'Arguments specific to the branch.')] 179 | setattr(cls, 'BranchHparams', make_dataclass('BranchHparams', fields, bases=(Hparams,))) 180 | setattr(cls, 'BranchDesc', make_BranchDesc(cls.BranchHparams, main_descs[main_experiment], cls.name())) 181 | -------------------------------------------------------------------------------- /experiments/branch/desc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | 6 | import argparse 7 | from dataclasses import dataclass 8 | 9 | from foundations import desc 10 | 11 | from foundations.desc import Desc 12 | 13 | 14 | def make_BranchDesc(BranchHparams: type, MainDesc: Desc, name: str): 15 | @dataclass 16 | class BranchDesc(desc.Desc): 17 | main_desc: MainDesc 18 | branch_hparams: BranchHparams 19 | 20 | @staticmethod 21 | def name_prefix(): return 'branch_' + name 22 | 23 | @staticmethod 24 | def add_args(parser: argparse.ArgumentParser, defaults: Desc = None): 25 | MainDesc.add_args(parser, defaults) 26 | BranchHparams.add_args(parser) 27 | 28 | @classmethod 29 | def create_from_args(cls, args: argparse.Namespace): 30 | return BranchDesc(MainDesc.create_from_args(args), BranchHparams.create_from_args(args)) 31 | 32 | return BranchDesc 33 | -------------------------------------------------------------------------------- /experiments/branch/randomly_prune.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | 6 | import torch 7 | 8 | from experiments.branch import base 9 | import models.registry 10 | from pruning.mask import Mask 11 | from pruning.pruned_model import PrunedModel 12 | from training import train 13 | from utils.tensor_utils import vectorize, unvectorize, shuffle_tensor, shuffle_state_dict 14 | 15 | 16 | class Branch(base.Branch): 17 | def branch_function(self, seed: int, strategy: str = 'layerwise', start_at: str = 'rewind', 18 | layers_to_ignore: str = ''): 19 | # Randomize the mask. 20 | mask = Mask.load(self.level_root) 21 | 22 | # Randomize while keeping the same layerwise proportions as the original mask. 
23 | if strategy == 'layerwise': mask = Mask(shuffle_state_dict(mask, seed=seed)) 24 | 25 | # Randomize globally throughout all prunable layers. 26 | elif strategy == 'global': mask = Mask(unvectorize(shuffle_tensor(vectorize(mask), seed=seed), mask)) 27 | 28 | # Randomize evenly across all layers. 29 | elif strategy == 'even': 30 | sparsity = mask.sparsity 31 | for i, k in sorted(mask.keys()): 32 | layer_mask = torch.where(torch.arange(mask[k].size) < torch.ceil(sparsity * mask[k].size), 33 | torch.ones_like(mask[k].size), torch.zeros_like(mask[k].size)) 34 | mask[k] = shuffle_tensor(layer_mask, seed=seed+i).reshape(mask[k].size) 35 | 36 | # Identity. 37 | elif strategy == 'identity': pass 38 | 39 | # Error. 40 | else: raise ValueError(f'Invalid strategy: {strategy}') 41 | 42 | # Reset the masks of any layers that shouldn't be pruned. 43 | if layers_to_ignore: 44 | for k in layers_to_ignore.split(','): mask[k] = torch.ones_like(mask[k]) 45 | 46 | # Save the new mask. 47 | mask.save(self.branch_root) 48 | 49 | # Determine the start step. 50 | if start_at == 'init': 51 | start_step = self.main_desc.str_to_step('0ep') 52 | state_step = start_step 53 | elif start_at == 'end': 54 | start_step = self.main_desc.str_to_step('0ep') 55 | state_step = self.main_desc.train_end_step 56 | elif start_at == 'rewind': 57 | start_step = self.main_desc.train_start_step 58 | state_step = start_step 59 | else: 60 | raise ValueError(f'Invalid starting point {start_at}') 61 | 62 | # Train the model with the new mask. 63 | # model_for_reset = models.registry.load(self.zero_branch_root, self.main_desc.train_end_step, 64 | # self.main_desc.model_hparams, self.main_desc.train_outputs, 65 | # self.main_desc.pruning_hparams.pruning_strategy) 66 | # freeze_pruned_weights = self.main_desc.pruning_hparams.freeze_pruned_weights 67 | model = PrunedModel(models.registry.load(self.level_root, state_step, self.main_desc.model_hparams, 68 | self.main_desc.train_outputs, self.main_desc.pruning_hparams.pruning_strategy), mask) 69 | train.standard_train(model, self.branch_root, self.main_desc.dataset_hparams, 70 | self.main_desc.training_hparams, start_step=start_step, verbose=self.verbose) 71 | 72 | @staticmethod 73 | def description(): 74 | return "Randomly prune the model." 
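# Illustrative invocation sketch (the runner name follows the
# `main.py <main experiment>_branch <branch>` convention from
# experiments/branch/runner.py; the flag names for the branch arguments are
# assumed to mirror the branch_function parameters above):
#
#   python main.py lottery_branch randomly_prune --levels 1,2-4 \
#       --seed 7 --strategy global --start_at rewind [main-experiment flags...]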
75 | 76 | @staticmethod 77 | def name(): 78 | return 'randomly_prune' 79 | -------------------------------------------------------------------------------- /experiments/branch/randomly_reinitialize.py: -------------------------------------------------------------------------------- 1 | from experiments.branch import base 2 | import models.registry 3 | from pruning.mask import Mask 4 | from pruning.pruned_model import PrunedModel 5 | from training import train 6 | 7 | 8 | class Branch(base.Branch): 9 | def branch_function(self, start_at_step_zero: bool = False): 10 | # model_for_reset = models.registry.load(self.zero_branch_root, self.main_desc.train_end_step, 11 | # self.main_desc.model_hparams, self.main_desc.train_outputs) 12 | # freeze_pruned_weights = self.main_desc.pruning_hparams.freeze_pruned_weights 13 | model = PrunedModel(models.registry.get(self.main_desc.model_hparams, outputs=self.main_desc.train_outputs, 14 | pruning_strategy = self.main_desc.pruning_hparams.pruning_strategy), Mask.load(self.level_root)) 15 | start_step = self.main_desc.str_to_step('0it') if start_at_step_zero else self.main_desc.train_start_step 16 | Mask.load(self.level_root).save(self.branch_root) 17 | train.standard_train(model, self.branch_root, self.main_desc.dataset_hparams, 18 | self.main_desc.training_hparams, start_step=start_step, verbose=self.verbose) 19 | 20 | @staticmethod 21 | def description(): 22 | return "Randomly reinitialize the model." 23 | 24 | @staticmethod 25 | def name(): 26 | return 'randomly_reinitialize' -------------------------------------------------------------------------------- /experiments/branch/registry.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | 6 | from experiments.branch.base import Branch 7 | from experiments.branch import retrain, randomly_prune, randomly_reinitialize, oneshot_prune 8 | 9 | registered_branches = { 10 | 'randomly_prune': randomly_prune.Branch, 11 | 'randomly_reinitialize': randomly_reinitialize.Branch, 12 | 'retrain': retrain.Branch, 13 | 'oneshot_prune': oneshot_prune.Branch 14 | } 15 | 16 | 17 | def get(branch_name: str) -> Branch: 18 | if branch_name not in registered_branches: 19 | raise ValueError('No such branch: {}'.format(branch_name)) 20 | else: 21 | return registered_branches[branch_name] 22 | -------------------------------------------------------------------------------- /experiments/branch/retrain.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | 6 | import copy 7 | import os 8 | from training import checkpointing, standard_callbacks 9 | import datasets.registry 10 | from foundations import hparams, paths 11 | from foundations.step import Step 12 | from experiments.branch import base 13 | import models.registry 14 | from pruning.mask import Mask 15 | from pruning.pruned_model import PrunedModel 16 | from training import train 17 | 18 | 19 | class Branch(base.Branch): 20 | def branch_function( 21 | self, 22 | retrain_d: hparams.DatasetHparams, 23 | retrain_t: hparams.TrainingHparams, 24 | start_at_step_zero: bool = False, 25 | ): 26 | 27 | evaluate_every_epoch: bool = True 28 | 29 | # Get the mask and model. 30 | m = models.registry.load(self.level_root, self.main_desc.train_start_step, self.main_desc.model_hparams, self.main_desc.train_outputs, 31 | self.main_desc.pruning_hparams.pruning_strategy) 32 | freeze_pruned_weights = self.main_desc.pruning_hparams.freeze_pruned_weights 33 | if freeze_pruned_weights == 'init': 34 | model_for_reset = copy.deepcopy(m) 35 | elif (freeze_pruned_weights == 'final' or freeze_pruned_weights == 'permuted') and self.level != 0: 36 | model_for_reset = models.registry.load(self.main_desc.run_path(self.replicate, self.level-1), self.main_desc.train_end_step, 37 | self.main_desc.model_hparams, self.main_desc.train_outputs, 38 | self.main_desc.pruning_hparams.pruning_strategy) 39 | else: 40 | model_for_reset = None 41 | 42 | m = PrunedModel(m, Mask.load(self.level_root), model_for_reset, freeze_pruned_weights) 43 | 44 | start_step = Step.from_iteration(0 if start_at_step_zero else self.main_desc.train_start_step.iteration, 45 | datasets.registry.iterations_per_epoch(retrain_d)) 46 | 47 | # If the model file for the end of training already exists in this location, do not train. 48 | iterations_per_epoch = datasets.registry.iterations_per_epoch(retrain_d) 49 | end_step = Step.from_str(retrain_t.training_steps, iterations_per_epoch) 50 | if (models.registry.exists(self.branch_root, end_step) and 51 | os.path.exists(paths.logger(self.branch_root))): return 52 | 53 | train_loader = datasets.registry.get(retrain_d, train=True) 54 | test_loader = datasets.registry.get(retrain_d, train=False) 55 | test_eval_callback = standard_callbacks.create_eval_callback('test', test_loader, verbose=self.verbose) 56 | train_eval_callback = standard_callbacks.create_eval_callback('train', train_loader, verbose=self.verbose) 57 | 58 | # Basic checkpointing and state saving at the beginning and end. 59 | result = [ 60 | standard_callbacks.run_at_step(start_step, standard_callbacks.save_model), 61 | standard_callbacks.run_at_step(end_step, standard_callbacks.save_model), 62 | standard_callbacks.run_at_step(end_step, standard_callbacks.save_logger), 63 | standard_callbacks.run_every_epoch(checkpointing.save_checkpoint_callback), 64 | ] 65 | 66 | # Test every epoch if requested. 67 | if self.verbose: result.append(standard_callbacks.run_every_epoch(standard_callbacks.create_timekeeper_callback())) 68 | 69 | # Ensure that testing occurs at least at the beginning and end of training. 
70 | if start_step.it != 0 or not evaluate_every_epoch: result = [standard_callbacks.run_at_step(start_step, test_eval_callback)] + result 71 | if end_step.it != 0 or not evaluate_every_epoch: result = [standard_callbacks.run_at_step(end_step, test_eval_callback)] + result 72 | 73 | # Do the same for the train set if requested. 74 | if evaluate_every_epoch: result = [standard_callbacks.run_every_epoch(train_eval_callback)] + result 75 | 76 | if start_step.it != 0 or not evaluate_every_epoch: result = [standard_callbacks.run_at_step(start_step, train_eval_callback)] + result 77 | 78 | if end_step.it != 0 or not evaluate_every_epoch: result = [standard_callbacks.run_at_step(end_step, train_eval_callback)] + result 79 | 80 | train.train(retrain_t, m, train_loader, self.branch_root, result, start_step=start_step) 81 | 82 | 83 | 84 | @staticmethod 85 | def description(): 86 | return "Retrain the model with different hyperparameters." 87 | 88 | @staticmethod 89 | def name(): 90 | return 'retrain' 91 | 92 | 93 | -------------------------------------------------------------------------------- /experiments/branch/runner.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | 6 | import argparse 7 | from dataclasses import dataclass 8 | import sys 9 | 10 | from utils import arg_utils 11 | from foundations.runner import Runner 12 | from experiments.branch import registry 13 | 14 | @dataclass 15 | class BranchRunner(Runner): 16 | """A meta-runner that calls the branch-specific runner.""" 17 | 18 | runner: Runner 19 | 20 | @staticmethod 21 | def description(): 22 | return "Run a branch of the main experiment." 23 | 24 | @staticmethod 25 | def add_args(parser): 26 | # Produce help text for selecting the branch. 27 | branch_names = sorted(registry.registered_branches.keys()) 28 | helptext = '='*82 + '\nA Framework on Pruning Robustness Based on open-lth\n' + '-'*82 29 | helptext += '\nChoose a branch to run:' 30 | for branch_name in branch_names: 31 | helptext += "\n * {} [main experiment]_branch {} [...] => {}".format( 32 | sys.argv[0], branch_name, registry.get(branch_name).description()) 33 | helptext += '\n' + '='*82 34 | 35 | # Print an error message if appropriate. 36 | runner_name = arg_utils.maybe_get_arg('subcommand', positional=True) 37 | branch_name = arg_utils.maybe_get_arg('subcommand', positional=True, position=1) 38 | if len(runner_name.split('_')) != 2 or branch_name not in branch_names: 39 | print(helptext) 40 | sys.exit(1) 41 | 42 | # Add the arguments for the branch. 43 | parser.add_argument('branch_name', type=str) 44 | registry.get(branch_name).add_args(parser) 45 | 46 | @staticmethod 47 | def create_from_args(args: argparse.Namespace): 48 | return BranchRunner(registry.get(sys.argv[2]).create_from_args(args)) 49 | 50 | def display_output_location(self): 51 | self.runner.display_output_location() 52 | 53 | def run(self) -> None: 54 | self.runner.run() 55 | -------------------------------------------------------------------------------- /experiments/finetune/desc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | 6 | import argparse 7 | from dataclasses import dataclass, replace 8 | import os 9 | from typing import Union 10 | 11 | from utils import arg_utils 12 | from datasets import registry as datasets_registry 13 | from foundations.desc import Desc 14 | from foundations import hparams 15 | from foundations.step import Step 16 | import pruning.registry 17 | from foundations.local import Platform 18 | 19 | @dataclass 20 | class FinetuningDesc(Desc): 21 | """The hyperparameters necessary to describe a pruning and finetuning backbone.""" 22 | 23 | model_hparams: hparams.ModelHparams 24 | dataset_hparams: hparams.DatasetHparams 25 | training_hparams: hparams.TrainingHparams 26 | pruning_hparams: hparams.PruningHparams 27 | finetuning_hparams: hparams.FinetuningHparams 28 | pretrain_dataset_hparams: hparams.DatasetHparams = None 29 | pretrain_training_hparams: hparams.TrainingHparams = None 30 | 31 | @staticmethod 32 | def name_prefix(): return 'finetune' 33 | 34 | @staticmethod 35 | def _add_pretrain_argument(parser): 36 | help_text = \ 37 | 'Perform a pre-training phase prior to running the main pruning and finetuning process. Setting this argument '\ 38 | 'will enable arguments to control how the dataset and training during this pre-training phase. ' 39 | parser.add_argument('--pretrain', action='store_true', help=help_text) 40 | 41 | @staticmethod 42 | def add_args(parser: argparse.ArgumentParser, defaults: 'FinetuningDesc' = None): 43 | 44 | # Add the finetuning/pretraining arguments. 45 | pretrain = arg_utils.maybe_get_arg('pretrain', boolean_arg=True) 46 | 47 | pretraining_parser = parser.add_argument_group( 48 | 'Pretraining Arguments', 'Arguments that control how the network is pre-trained') 49 | FinetuningDesc._add_pretrain_argument(pretraining_parser) 50 | 51 | # Get the proper pruning hparams. 52 | pruning_strategy = arg_utils.maybe_get_arg('pruning_strategy') 53 | if defaults and not pruning_strategy: pruning_strategy = defaults.pruning_hparams.pruning_strategy 54 | if pruning_strategy: 55 | pruning_hparams = pruning.registry.get_pruning_hparams(pruning_strategy) 56 | if defaults and defaults.pruning_hparams.pruning_strategy == pruning_strategy: 57 | def_ph = defaults.pruning_hparams 58 | else: 59 | def_ph = None 60 | else: 61 | pruning_hparams = hparams.PruningHparams 62 | def_ph = None 63 | 64 | # Add the main arguments. 65 | hparams.DatasetHparams.add_args(parser, defaults=defaults.dataset_hparams if defaults else None) 66 | hparams.ModelHparams.add_args(parser, defaults=defaults.model_hparams if defaults else None) 67 | hparams.TrainingHparams.add_args(parser, defaults=defaults.training_hparams if defaults else None) 68 | hparams.FinetuningHparams.add_args(parser, defaults=defaults.finetuning_hparams if defaults else None, prefix='finetune') 69 | pruning_hparams.add_args(parser, defaults=def_ph if defaults else None) 70 | # Set the finetuning arguments 71 | # def_ft = replace(defaults.training_hparams, **defaults.finetuning_hparams.__dict__) 72 | # hparams.TrainingHparams.add_args(parser, defaults=def_ft if defaults else None, prefix='finetune') 73 | 74 | # Handle pretraining. 
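# Illustrative sketch, not part of the original file: passing --pretrain (added
# above) together with a non-zero pretraining schedule, e.g.
#
#   --pretrain --pretrain_training_steps 10ep
#
# (the prefixed flag name is an assumption based on the 'pretrain' prefix used
# below) activates this block, which registers a second copy of the training
# and dataset hyperparameters under the 'pretrain' prefix, with training_steps
# defaulting to '0ep' when defaults are supplied.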
75 | if pretrain: 76 | if defaults: def_th = replace(defaults.training_hparams, training_steps='0ep') 77 | hparams.TrainingHparams.add_args(parser, defaults=def_th if defaults else None, 78 | name='Training Hyperparameters for Pretraining', prefix='pretrain') 79 | hparams.DatasetHparams.add_args(parser, defaults=defaults.dataset_hparams if defaults else None, 80 | name='Dataset Hyperparameters for Pretraining', prefix='pretrain') 81 | 82 | @classmethod 83 | def create_from_args(cls, args: argparse.Namespace) -> 'FinetuningDesc': 84 | # Get the main arguments. 85 | dataset_hparams = hparams.DatasetHparams.create_from_args(args) 86 | model_hparams = hparams.ModelHparams.create_from_args(args) 87 | training_hparams = hparams.TrainingHparams.create_from_args(args) 88 | pruning_hparams = pruning.registry.get_pruning_hparams(args.pruning_strategy).create_from_args(args) 89 | ft_hparams = hparams.FinetuningHparams.create_from_args(args, prefix='finetune') 90 | finetuning_hparams = replace(training_hparams, **ft_hparams.__dict__) 91 | 92 | # Create the desc. 93 | desc = cls(model_hparams, dataset_hparams, training_hparams, pruning_hparams, finetuning_hparams) 94 | 95 | # Handle pretraining. 96 | if args.pretrain and not Step.str_is_zero(args.pretrain_training_steps): 97 | desc.pretrain_dataset_hparams = hparams.DatasetHparams.create_from_args(args, prefix='pretrain') 98 | desc.pretrain_dataset_hparams._name = 'Pretraining ' + desc.pretrain_dataset_hparams._name 99 | desc.pretrain_training_hparams = hparams.TrainingHparams.create_from_args(args, prefix='pretrain') 100 | desc.pretrain_training_hparams._name = 'Pretraining ' + desc.pretrain_training_hparams._name 101 | 102 | return desc 103 | 104 | def str_to_step(self, s: str, pretrain: bool = False) -> Step: 105 | dataset_hparams = self.pretrain_dataset_hparams if pretrain else self.dataset_hparams 106 | iterations_per_epoch = datasets_registry.iterations_per_epoch(dataset_hparams) 107 | return Step.from_str(s, iterations_per_epoch) 108 | 109 | @property 110 | def pretrain_end_step(self): 111 | return self.str_to_step(self.pretrain_training_hparams.training_steps, True) 112 | 113 | @property 114 | def train_start_step(self): 115 | if self.pretrain_training_hparams: return self.str_to_step(self.pretrain_training_hparams.training_steps) 116 | else: return self.str_to_step('0it') 117 | 118 | @property 119 | def train_end_step(self): 120 | return self.str_to_step(self.training_hparams.training_steps) if self.training_hparams._convergence_training_steps is None \ 121 | else self.str_to_step(self.training_hparams._convergence_training_steps) 122 | 123 | @property 124 | def finetune_start_step(self): 125 | return self.str_to_step('0it') 126 | 127 | @property 128 | def finetune_end_step(self): 129 | return self.str_to_step(self.finetuning_hparams.training_steps) if self.finetuning_hparams._convergence_training_steps is None \ 130 | else self.str_to_step(self.finetuning_hparams._convergence_training_steps) 131 | 132 | @property 133 | def pretrain_outputs(self): 134 | return datasets_registry.num_classes(self.pretrain_dataset_hparams) 135 | 136 | @property 137 | def train_outputs(self): 138 | return datasets_registry.num_classes(self.dataset_hparams) 139 | 140 | def run_path(self, replicate: int, pruning_level: Union[str, int], experiment: str = 'main'): 141 | """The location where any run is stored.""" 142 | 143 | if not isinstance(replicate, int) or replicate <= 0: 144 | raise ValueError('Bad replicate: {}'.format(replicate)) 145 | 146 | return 
os.path.join(Platform().root, self.hashname, 147 | f'replicate_{replicate}', f'level_{pruning_level}', experiment) 148 | 149 | @property 150 | def display(self): 151 | ls = [self.dataset_hparams.display, self.model_hparams.display, 152 | self.training_hparams.display, self.pruning_hparams.display, 153 | self.finetuning_hparams.display] 154 | if self.pretrain_training_hparams: 155 | ls += [self.pretrain_dataset_hparams.display, self.pretrain_training_hparams.display] 156 | return '\n'.join(ls) 157 | -------------------------------------------------------------------------------- /experiments/lottery/runner.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | 6 | import argparse 7 | import copy 8 | 9 | from utils import shared_args 10 | from dataclasses import dataclass 11 | from foundations.runner import Runner 12 | import models.registry 13 | from experiments.lottery.desc import LotteryDesc 14 | import pruning.registry 15 | from pruning.mask import Mask 16 | from pruning.pruned_model import PrunedModel 17 | from training import train 18 | 19 | 20 | @dataclass 21 | class LotteryRunner(Runner): 22 | replicate: int 23 | levels: int 24 | desc: LotteryDesc 25 | verbose: bool = True 26 | evaluate_every_epoch: bool = True 27 | 28 | @staticmethod 29 | def description(): 30 | return 'Run a lottery ticket hypothesis experiment.' 31 | 32 | @staticmethod 33 | def _add_levels_argument(parser): 34 | help_text = \ 35 | 'The number of levels of iterative pruning to perform. At each level, the network is trained to ' \ 36 | 'completion, pruned, and rewound, preparing it for the next iteration. The full network is trained ' \ 37 | 'at level 0, and level 1 is the first level at which pruning occurs. Set this argument to 0 to ' \ 38 | 'just train the full network or to N to prune the network N times.' 39 | parser.add_argument('--levels', required=True, type=int, help=help_text) 40 | 41 | @staticmethod 42 | def add_args(parser: argparse.ArgumentParser) -> None: 43 | # Get preliminary information. 44 | defaults = shared_args.maybe_get_default_hparams() 45 | 46 | # Add the job arguments. 47 | shared_args.JobArgs.add_args(parser) 48 | lottery_parser = parser.add_argument_group( 49 | 'Lottery Ticket Hyperparameters', 'Hyperparameters that control the lottery ticket process.') 50 | LotteryRunner._add_levels_argument(lottery_parser) 51 | LotteryDesc.add_args(parser, defaults) 52 | 53 | @staticmethod 54 | def create_from_args(args: argparse.Namespace) -> 'LotteryRunner': 55 | return LotteryRunner(args.replicate, args.levels, LotteryDesc.create_from_args(args), 56 | not args.quiet, not args.evaluate_only_at_end) 57 | 58 | def display_output_location(self): 59 | print(self.desc.run_path(self.replicate, 0)) 60 | 61 | def run(self) -> None: 62 | if self.verbose: 63 | print('='*82 + f'\nLottery Ticket Experiment (Replicate {self.replicate})\n' + '-'*82) 64 | print(self.desc.display) 65 | print(f'Output Location: {self.desc.run_path(self.replicate, 0)}' + '\n' + '='*82 + '\n') 66 | 67 | self.desc.save(self.desc.run_path(self.replicate, 0)) 68 | if self.desc.pretrain_training_hparams: self._pretrain() 69 | self._establish_initial_weights() 70 | 71 | for level in range(self.levels+1): 72 | self._prune_level(level) 73 | self._train_level(level) 74 | 75 | # Helper methods for running the lottery. 
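# Illustrative launch sketch (the model, dataset, and pruning flags come from
# the registered hparams classes and are not shown in this file, so the exact
# flag set is an assumption):
#
#   python main.py lottery --levels 3 --replicate 1 [model/dataset/pruning flags]
#
# Level 0 trains the unpruned network; each later level applies a mask derived
# from the previous level's final weights to the level-0 weights saved at the
# rewind point (see _prune_level and _train_level below) and retrains.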
76 | def _pretrain(self): 77 | location = self.desc.run_path(self.replicate, 'pretrain') 78 | if models.registry.exists(location, self.desc.pretrain_end_step): return 79 | 80 | if self.verbose: print('-'*82 + '\nPretraining\n' + '-'*82) 81 | model = models.registry.get(self.desc.model_hparams, outputs=self.desc.pretrain_outputs, 82 | pruning_strategy = self.desc.pruning_hparams.pruning_strategy) 83 | train.standard_train(model, location, self.desc.pretrain_dataset_hparams, self.desc.pretrain_training_hparams, 84 | verbose=self.verbose, evaluate_every_epoch=self.evaluate_every_epoch) 85 | 86 | def _establish_initial_weights(self): 87 | location = self.desc.run_path(self.replicate, 0) 88 | if models.registry.exists(location, self.desc.train_start_step): return 89 | 90 | new_model = models.registry.get(self.desc.model_hparams, outputs=self.desc.train_outputs, 91 | pruning_strategy = self.desc.pruning_hparams.pruning_strategy) 92 | 93 | # If there was a pretrained model, retrieve its final weights and adapt them for training. 94 | if self.desc.pretrain_training_hparams is not None: 95 | pretrain_loc = self.desc.run_path(self.replicate, 'pretrain') 96 | old = models.registry.load(pretrain_loc, self.desc.pretrain_end_step, 97 | self.desc.model_hparams, self.desc.pretrain_outputs, 98 | self.desc.pruning_hparams.pruning_strategy) 99 | state_dict = {k: v for k, v in old.state_dict().items()} 100 | 101 | # Select a new output layer if number of classes differs. 102 | if self.desc.train_outputs != self.desc.pretrain_outputs: 103 | state_dict.update({k: new_model.state_dict()[k] for k in new_model.output_layer_names}) 104 | 105 | new_model.load_state_dict(state_dict) 106 | 107 | new_model.save(location, self.desc.train_start_step) 108 | 109 | def _train_level(self, level: int): 110 | location = self.desc.run_path(self.replicate, level) 111 | if models.registry.exists(location, self.desc.train_end_step): return 112 | 113 | model = models.registry.load(self.desc.run_path(self.replicate, 0), self.desc.train_start_step, 114 | self.desc.model_hparams, self.desc.train_outputs, 115 | self.desc.pruning_hparams.pruning_strategy) 116 | 117 | freeze_pruned_weights = self.desc.pruning_hparams.freeze_pruned_weights 118 | if freeze_pruned_weights == 'init' and level != 0: 119 | model_for_reset = copy.deepcopy(model) 120 | elif (freeze_pruned_weights == 'final' or freeze_pruned_weights == 'permuted') and level != 0: 121 | model_for_reset = models.registry.load(self.desc.run_path(self.replicate, level-1), self.desc.train_end_step, 122 | self.desc.model_hparams, self.desc.train_outputs, 123 | self.desc.pruning_hparams.pruning_strategy) 124 | else: 125 | model_for_reset = None 126 | pruned_model = PrunedModel(model, Mask.load(location),model_for_reset, freeze_pruned_weights) 127 | pruned_model.save(location, self.desc.train_start_step) 128 | if self.verbose: 129 | print('-'*82 + '\nPruning Level {}\n'.format(level) + '-'*82) 130 | train.standard_train(pruned_model, location, self.desc.dataset_hparams, self.desc.training_hparams, 131 | start_step=self.desc.train_start_step, verbose=self.verbose, 132 | evaluate_every_epoch=self.evaluate_every_epoch) 133 | 134 | def _prune_level(self, level: int): 135 | new_location = self.desc.run_path(self.replicate, level) 136 | if Mask.exists(new_location): return 137 | 138 | if level == 0: 139 | Mask.ones_like(models.registry.get(self.desc.model_hparams, outputs=self.desc.train_outputs, 140 | pruning_strategy = self.desc.pruning_hparams.pruning_strategy)).save(new_location) 141 | 
else: 142 | old_location = self.desc.run_path(self.replicate, level-1) 143 | model = models.registry.load(old_location, self.desc.train_end_step, 144 | self.desc.model_hparams, self.desc.train_outputs, 145 | self.desc.pruning_hparams.pruning_strategy) 146 | pruning.registry.get(self.desc.pruning_hparams)(model, Mask.load(old_location), self.desc.dataset_hparams).save(new_location) 147 | -------------------------------------------------------------------------------- /experiments/rewindLR/runner.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | 6 | import argparse 7 | import copy 8 | 9 | from numpy import mod 10 | 11 | from utils import shared_args 12 | from dataclasses import dataclass 13 | from foundations.runner import Runner 14 | import models.registry 15 | from experiments.rewindLR.desc import RewindingDesc 16 | import pruning.registry 17 | from pruning.mask import Mask 18 | from pruning.pruned_model import PrunedModel 19 | from training import train 20 | 21 | 22 | @dataclass 23 | class RewindingRunner(Runner): 24 | replicate: int 25 | levels: int 26 | desc: RewindingDesc 27 | verbose: bool = True 28 | evaluate_every_epoch: bool = True 29 | 30 | @staticmethod 31 | def description(): 32 | return 'Run a pruning and rewinding experiment.' 33 | 34 | @staticmethod 35 | def _add_levels_argument(parser): 36 | help_text = \ 37 | 'The number of levels of iterative pruning to perform. At each level, the network is trained to ' \ 38 | 'completion, pruned, and rewound, preparing it for the next iteration. The full network is trained ' \ 39 | 'at level 0, and level 1 is the first level at which pruning occurs. Set this argument to 0 to ' \ 40 | 'just train the full network or to N to prune the network N times.' 41 | parser.add_argument('--levels', required=True, type=int, help=help_text) 42 | 43 | @staticmethod 44 | def add_args(parser: argparse.ArgumentParser) -> None: 45 | # Get preliminary information. 46 | defaults = shared_args.maybe_get_default_hparams() 47 | 48 | # Add the job arguments. 49 | shared_args.JobArgs.add_args(parser) 50 | rewinding_parser = parser.add_argument_group( 51 | 'Rewinding Hyperparameters', 'Hyperparameters that control the pruning and rewinding process.') 52 | RewindingRunner._add_levels_argument(rewinding_parser) 53 | RewindingDesc.add_args(parser, defaults) 54 | 55 | @staticmethod 56 | def create_from_args(args: argparse.Namespace) -> 'RewindingRunner': 57 | return RewindingRunner(args.replicate, args.levels, RewindingDesc.create_from_args(args), 58 | not args.quiet, not args.evaluate_only_at_end) 59 | 60 | def display_output_location(self): 61 | print(self.desc.run_path(self.replicate, 0)) 62 | 63 | def run(self) -> None: 64 | if self.verbose: 65 | print('='*82 + f'\nThe learning rate rewinding Experiment (Replicate {self.replicate})\n' + '-'*82) 66 | print(self.desc.display) 67 | print(f'Output Location: {self.desc.run_path(self.replicate, 0)}' + '\n' + '='*82 + '\n') 68 | 69 | self.desc.save(self.desc.run_path(self.replicate, 0)) 70 | if self.desc.pretrain_training_hparams: self._pretrain() 71 | self._establish_initial_weights() 72 | 73 | for level in range(self.levels+1): 74 | self._prune_level(level) 75 | self._train_level(level) 76 | 77 | # Helper methods for running the pruning and rewinding process. 
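    # Note: unlike the lottery runner, _train_level below starts every level > 0 from the
    # *final* weights of the previous level instead of rewinding to the level-0 initialization;
    # only the learning-rate schedule restarts, which is what "learning rate rewinding" means.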
78 | def _pretrain(self): 79 | location = self.desc.run_path(self.replicate, 'pretrain') 80 | if models.registry.exists(location, self.desc.pretrain_end_step): return 81 | 82 | if self.verbose: print('-'*82 + '\nPretraining\n' + '-'*82) 83 | model = models.registry.get(self.desc.model_hparams, outputs=self.desc.pretrain_outputs, 84 | pruning_strategy = self.desc.pruning_hparams.pruning_strategy) 85 | train.standard_train(model, location, self.desc.pretrain_dataset_hparams, self.desc.pretrain_training_hparams, 86 | verbose=self.verbose, evaluate_every_epoch=self.evaluate_every_epoch) 87 | 88 | def _establish_initial_weights(self): 89 | location = self.desc.run_path(self.replicate, 0) 90 | if models.registry.exists(location, self.desc.train_start_step): return 91 | 92 | new_model = models.registry.get(self.desc.model_hparams, outputs=self.desc.train_outputs, 93 | pruning_strategy = self.desc.pruning_hparams.pruning_strategy) 94 | 95 | # If there was a pretrained model, retrieve its final weights and adapt them for training. 96 | if self.desc.pretrain_training_hparams is not None: 97 | pretrain_loc = self.desc.run_path(self.replicate, 'pretrain') 98 | old = models.registry.load(pretrain_loc, self.desc.pretrain_end_step, 99 | self.desc.model_hparams, self.desc.pretrain_outputs, 100 | self.desc.pruning_hparams.pruning_strategy) 101 | state_dict = {k: v for k, v in old.state_dict().items()} 102 | 103 | # Select a new output layer if number of classes differs. 104 | if self.desc.train_outputs != self.desc.pretrain_outputs: 105 | state_dict.update({k: new_model.state_dict()[k] for k in new_model.output_layer_names}) 106 | 107 | new_model.load_state_dict(state_dict) 108 | 109 | new_model.save(location, self.desc.train_start_step) 110 | 111 | def _train_level(self, level: int): 112 | location = self.desc.run_path(self.replicate, level) 113 | if models.registry.exists(location, self.desc.train_end_step): return 114 | 115 | # use the final weight values from the end of training from the previous level 116 | if level != 0: 117 | model = models.registry.load(self.desc.run_path(self.replicate, level-1), self.desc.train_end_step, 118 | self.desc.model_hparams, self.desc.train_outputs, 119 | self.desc.pruning_hparams.pruning_strategy) 120 | # else use the initial weight values from level 0 121 | else: 122 | model = models.registry.load(self.desc.run_path(self.replicate, 0), self.desc.train_start_step, 123 | self.desc.model_hparams, self.desc.train_outputs, 124 | self.desc.pruning_hparams.pruning_strategy) 125 | 126 | freeze_pruned_weights = self.desc.pruning_hparams.freeze_pruned_weights 127 | if freeze_pruned_weights == 'init' and level != 0: 128 | model_for_reset = models.registry.load(self.desc.run_path(self.replicate, 0), self.desc.train_start_step, 129 | self.desc.model_hparams, self.desc.train_outputs, 130 | self.desc.pruning_hparams.pruning_strategy) 131 | elif (freeze_pruned_weights == 'final' or freeze_pruned_weights == 'permuted') and level != 0: 132 | model_for_reset = copy.deepcopy(model) 133 | else: 134 | model_for_reset = None 135 | pruned_model = PrunedModel(model, Mask.load(location), model_for_reset, freeze_pruned_weights) 136 | pruned_model.save(location, self.desc.train_start_step) 137 | if self.verbose: 138 | print('-'*82 + '\nPruning Level {}\n'.format(level) + '-'*82) 139 | train.standard_train(pruned_model, location, self.desc.dataset_hparams, self.desc.training_hparams, 140 | start_step=self.desc.train_start_step, verbose=self.verbose, 141 | 
evaluate_every_epoch=self.evaluate_every_epoch) 142 | 143 | def _prune_level(self, level: int): 144 | new_location = self.desc.run_path(self.replicate, level) 145 | if Mask.exists(new_location): return 146 | 147 | if level == 0: 148 | Mask.ones_like(models.registry.get(self.desc.model_hparams, self.desc.train_outputs, 149 | pruning_strategy = self.desc.pruning_hparams.pruning_strategy)).save(new_location) 150 | else: 151 | old_location = self.desc.run_path(self.replicate, level-1) 152 | model = models.registry.load(old_location, self.desc.train_end_step, 153 | self.desc.model_hparams, self.desc.train_outputs, 154 | self.desc.pruning_hparams.pruning_strategy) 155 | pruning.registry.get(self.desc.pruning_hparams)(model, Mask.load(old_location), self.desc.dataset_hparams).save(new_location) 156 | -------------------------------------------------------------------------------- /experiments/runner_registry.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | 6 | from foundations.runner import Runner 7 | from training.runner import TrainingRunner 8 | from experiments.rewindLR.runner import RewindingRunner 9 | from experiments.lottery.runner import LotteryRunner 10 | from experiments.finetune.runner import FinetuningRunner 11 | from experiments.scratch.runner import ScratchRunner 12 | from experiments.branch.runner import BranchRunner 13 | 14 | registered_runners = {'train': TrainingRunner, 'rewindLR': RewindingRunner, 'lottery': LotteryRunner, 'finetune': FinetuningRunner, 15 | 'scratch': ScratchRunner, 'branch': BranchRunner} 16 | 17 | 18 | def get(runner_name: str) -> Runner: 19 | if runner_name not in registered_runners: 20 | raise ValueError('No such runner: {}'.format(runner_name)) 21 | else: 22 | return registered_runners[runner_name] -------------------------------------------------------------------------------- /experiments/scratch/desc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | 6 | import argparse 7 | import copy 8 | from dataclasses import dataclass, replace 9 | import os 10 | from typing import Union 11 | 12 | from utils import arg_utils 13 | from datasets import registry as datasets_registry 14 | from foundations.desc import Desc 15 | from foundations import hparams 16 | from foundations.step import Step 17 | import pruning.registry 18 | from foundations.local import Platform 19 | @dataclass 20 | class ScratchDesc(Desc): 21 | """The hyperparameters necessary to describe a pruning and retraining from re-initialized scratch backbone.""" 22 | 23 | model_hparams: hparams.ModelHparams 24 | dataset_hparams: hparams.DatasetHparams 25 | training_hparams: hparams.TrainingHparams 26 | pruning_hparams: hparams.PruningHparams 27 | pretrain_dataset_hparams: hparams.DatasetHparams = None 28 | pretrain_training_hparams: hparams.TrainingHparams = None 29 | 30 | @staticmethod 31 | def name_prefix(): return 'scratch' 32 | 33 | @staticmethod 34 | def _add_pretrain_argument(parser): 35 | help_text = \ 36 | 'Perform a pre-training phase prior to running the main process. 
Setting this argument will enable '\ 37 | 'arguments to control how the dataset and training during this pre-training phase. Rewinding '\ 38 | 'is a specific case of pre-training where rewinding uses the same dataset and training procedure '\ 39 | 'as the main training run.' 40 | parser.add_argument('--pretrain', action='store_true', help=help_text) 41 | 42 | @staticmethod 43 | def add_args(parser: argparse.ArgumentParser, defaults: 'ScratchDesc' = None): 44 | pretrain = arg_utils.maybe_get_arg('pretrain', boolean_arg=True) 45 | pretraining_parser = parser.add_argument_group( 46 | 'Pretraining Arguments', 'Arguments that control how the network is pre-trained') 47 | ScratchDesc._add_pretrain_argument(pretraining_parser) 48 | # Get the proper pruning hparams. 49 | pruning_strategy = arg_utils.maybe_get_arg('pruning_strategy') 50 | if defaults and not pruning_strategy: pruning_strategy = defaults.pruning_hparams.pruning_strategy 51 | if pruning_strategy: 52 | pruning_hparams = pruning.registry.get_pruning_hparams(pruning_strategy) 53 | if defaults and defaults.pruning_hparams.pruning_strategy == pruning_strategy: 54 | def_ph = defaults.pruning_hparams 55 | else: 56 | def_ph = None 57 | else: 58 | pruning_hparams = hparams.PruningHparams 59 | def_ph = None 60 | 61 | # Add the main arguments. 62 | hparams.DatasetHparams.add_args(parser, defaults=defaults.dataset_hparams if defaults else None) 63 | hparams.ModelHparams.add_args(parser, defaults=defaults.model_hparams if defaults else None) 64 | hparams.TrainingHparams.add_args(parser, defaults=defaults.training_hparams if defaults else None) 65 | pruning_hparams.add_args(parser, defaults=def_ph if defaults else None) 66 | 67 | # Handle pretraining. 68 | if pretrain: 69 | if defaults: def_th = replace(defaults.training_hparams, training_steps='0ep') 70 | hparams.TrainingHparams.add_args(parser, defaults=def_th if defaults else None, 71 | name='Training Hyperparameters for Pretraining', prefix='pretrain') 72 | hparams.DatasetHparams.add_args(parser, defaults=defaults.dataset_hparams if defaults else None, 73 | name='Dataset Hyperparameters for Pretraining', prefix='pretrain') 74 | 75 | @classmethod 76 | def create_from_args(cls, args: argparse.Namespace) -> 'ScratchDesc': 77 | # Get the main arguments. 78 | dataset_hparams = hparams.DatasetHparams.create_from_args(args) 79 | model_hparams = hparams.ModelHparams.create_from_args(args) 80 | training_hparams = hparams.TrainingHparams.create_from_args(args) 81 | pruning_hparams = pruning.registry.get_pruning_hparams(args.pruning_strategy).create_from_args(args) 82 | 83 | # Create the desc. 84 | desc = cls(model_hparams, dataset_hparams, training_hparams, pruning_hparams) 85 | 86 | # Handle pretraining. 
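        # Note: the 'pretrain' prefix means these hyperparameters are read from prefixed flags
        # such as --pretrain_training_steps (referenced below); they are only attached when
        # --pretrain is set and the pretraining step count is nonzero.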
87 | if args.pretrain and not Step.str_is_zero(args.pretrain_training_steps): 88 | desc.pretrain_dataset_hparams = hparams.DatasetHparams.create_from_args(args, prefix='pretrain') 89 | desc.pretrain_dataset_hparams._name = 'Pretraining ' + desc.pretrain_dataset_hparams._name 90 | desc.pretrain_training_hparams = hparams.TrainingHparams.create_from_args(args, prefix='pretrain') 91 | desc.pretrain_training_hparams._name = 'Pretraining ' + desc.pretrain_training_hparams._name 92 | 93 | return desc 94 | 95 | def str_to_step(self, s: str, pretrain: bool = False) -> Step: 96 | dataset_hparams = self.pretrain_dataset_hparams if pretrain else self.dataset_hparams 97 | iterations_per_epoch = datasets_registry.iterations_per_epoch(dataset_hparams) 98 | return Step.from_str(s, iterations_per_epoch) 99 | 100 | @property 101 | def pretrain_end_step(self): 102 | return self.str_to_step(self.pretrain_training_hparams.training_steps, True) 103 | 104 | @property 105 | def train_start_step(self): 106 | if self.pretrain_training_hparams: return self.str_to_step(self.pretrain_training_hparams.training_steps) 107 | else: return self.str_to_step('0it') 108 | 109 | @property 110 | def train_end_step(self): 111 | return self.str_to_step(self.training_hparams.training_steps) if self.training_hparams._convergence_training_steps is None \ 112 | else self.str_to_step(self.training_hparams._convergence_training_steps) 113 | 114 | @property 115 | def pretrain_outputs(self): 116 | return datasets_registry.num_classes(self.pretrain_dataset_hparams) 117 | 118 | @property 119 | def train_outputs(self): 120 | return datasets_registry.num_classes(self.dataset_hparams) 121 | 122 | def run_path(self, replicate: int, pruning_level: Union[str, int], experiment: str = 'main'): 123 | """The location where any run is stored.""" 124 | 125 | if not isinstance(replicate, int) or replicate <= 0: 126 | raise ValueError('Bad replicate: {}'.format(replicate)) 127 | 128 | return os.path.join(Platform().root, self.hashname, 129 | f'replicate_{replicate}', f'level_{pruning_level}', experiment) 130 | 131 | @property 132 | def display(self): 133 | ls = [self.dataset_hparams.display, self.model_hparams.display, 134 | self.training_hparams.display, self.pruning_hparams.display] 135 | if self.pretrain_training_hparams: 136 | ls += [self.pretrain_dataset_hparams.display, self.pretrain_training_hparams.display] 137 | return '\n'.join(ls) 138 | -------------------------------------------------------------------------------- /experiments/scratch/runner.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | 6 | import argparse 7 | import copy 8 | 9 | from utils import shared_args 10 | from dataclasses import dataclass 11 | from foundations.runner import Runner 12 | import models.registry 13 | from experiments.scratch.desc import ScratchDesc 14 | import pruning.registry 15 | from pruning.mask import Mask 16 | from pruning.pruned_model import PrunedModel 17 | from training import train 18 | 19 | 20 | @dataclass 21 | class ScratchRunner(Runner): 22 | replicate: int 23 | levels: int 24 | desc: ScratchDesc 25 | verbose: bool = True 26 | evaluate_every_epoch: bool = True 27 | 28 | @staticmethod 29 | def description(): 30 | return 'Run a pruning and retraining from re-initialized scratch experiment.' 
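    # Note: unlike the lottery and rewindLR runners, run() below calls
    # _establish_initial_weights inside the per-level loop, so each pruning level retrains a
    # freshly re-initialized network under the mask inherited from the previous level
    # ("pruning and retraining from scratch").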
31 | 32 | @staticmethod 33 | def _add_levels_argument(parser): 34 | help_text = \ 35 | 'The number of levels of iterative pruning to perform. At each level, the network is trained to ' \ 36 | 'completion, pruned, and retrained, preparing it for the next iteration. The full network is trained ' \ 37 | 'at level 0, and level 1 is the first level at which pruning occurs. Set this argument to 0 to ' \ 38 | 'just train the full network or to N to prune the network N times.' 39 | parser.add_argument('--levels', required=True, type=int, help=help_text) 40 | 41 | @staticmethod 42 | def add_args(parser: argparse.ArgumentParser) -> None: 43 | # Get preliminary information. 44 | defaults = shared_args.maybe_get_default_hparams() 45 | 46 | # Add the job arguments. 47 | shared_args.JobArgs.add_args(parser) 48 | scratch_parser = parser.add_argument_group( 49 | 'Scratch Hyperparameters', 'Hyperparameters that control the pruning and retraining from scratch process.') 50 | ScratchRunner._add_levels_argument(scratch_parser) 51 | ScratchDesc.add_args(parser, defaults) 52 | 53 | @staticmethod 54 | def create_from_args(args: argparse.Namespace) -> 'ScratchRunner': 55 | return ScratchRunner(args.replicate, args.levels, ScratchDesc.create_from_args(args), 56 | not args.quiet, not args.evaluate_only_at_end) 57 | 58 | def display_output_location(self): 59 | print(self.desc.run_path(self.replicate, 0)) 60 | 61 | def run(self) -> None: 62 | if self.verbose: 63 | print('='*82 + f'\nThe Pruning and Retraining from Scratch Experiment (Replicate {self.replicate})\n' + '-'*82) 64 | print(self.desc.display) 65 | print(f'Output Location: {self.desc.run_path(self.replicate, 0)}' + '\n' + '='*82 + '\n') 66 | 67 | self.desc.save(self.desc.run_path(self.replicate, 0)) 68 | if self.desc.pretrain_training_hparams: self._pretrain() 69 | 70 | for level in range(self.levels+1): 71 | self._establish_initial_weights(level) 72 | self._prune_level(level) 73 | self._train_level(level) 74 | 75 | # Helper methods for running the pruning and rewinding process. 76 | def _pretrain(self): 77 | location = self.desc.run_path(self.replicate, 'pretrain') 78 | if models.registry.exists(location, self.desc.pretrain_end_step): return 79 | 80 | if self.verbose: print('-'*82 + '\nPretraining\n' + '-'*82) 81 | model = models.registry.get(self.desc.model_hparams, outputs=self.desc.pretrain_outputs, 82 | pruning_strategy = self.desc.pruning_hparams.pruning_strategy) 83 | train.standard_train(model, location, self.desc.pretrain_dataset_hparams, self.desc.pretrain_training_hparams, 84 | verbose=self.verbose, evaluate_every_epoch=self.evaluate_every_epoch) 85 | 86 | def _establish_initial_weights(self, level): 87 | location = self.desc.run_path(self.replicate, level) 88 | if models.registry.exists(location, self.desc.train_start_step): return 89 | 90 | new_model = models.registry.get(self.desc.model_hparams, outputs=self.desc.train_outputs, 91 | pruning_strategy = self.desc.pruning_hparams.pruning_strategy) 92 | 93 | # If there was a pretrained model, retrieve its final weights and adapt them for training (only for level 0 94 | # models of other levels will be loaded with re-initialization). 
95 | if self.desc.pretrain_training_hparams is not None and level == 0: 96 | pretrain_loc = self.desc.run_path(self.replicate, 'pretrain') 97 | old = models.registry.load(pretrain_loc, self.desc.pretrain_end_step, 98 | self.desc.model_hparams, self.desc.pretrain_outputs, 99 | self.desc.pruning_hparams.pruning_strategy) 100 | state_dict = {k: v for k, v in old.state_dict().items()} 101 | 102 | # Select a new output layer if number of classes differs. 103 | if self.desc.train_outputs != self.desc.pretrain_outputs: 104 | state_dict.update({k: new_model.state_dict()[k] for k in new_model.output_layer_names}) 105 | 106 | new_model.load_state_dict(state_dict) 107 | 108 | new_model.save(location, self.desc.train_start_step) 109 | 110 | def _train_level(self, level: int): 111 | location = self.desc.run_path(self.replicate, level) 112 | if models.registry.exists(location, self.desc.train_end_step): return 113 | # use the randomly re-initialized weights 114 | model = models.registry.load(location, self.desc.train_start_step, 115 | self.desc.model_hparams, self.desc.train_outputs, 116 | self.desc.pruning_hparams.pruning_strategy) 117 | 118 | freeze_pruned_weights = self.desc.pruning_hparams.freeze_pruned_weights 119 | if freeze_pruned_weights == 'init' and level != 0: 120 | model_for_reset = copy.deepcopy(model) 121 | elif (freeze_pruned_weights == 'final' or freeze_pruned_weights == 'permuted') and level != 0: 122 | model_for_reset = models.registry.load(self.desc.run_path(self.replicate, level-1), self.desc.train_end_step, 123 | self.desc.model_hparams, self.desc.train_outputs, 124 | self.desc.pruning_hparams.pruning_strategy) 125 | else: 126 | model_for_reset = None 127 | 128 | pruned_model = PrunedModel(model, Mask.load(location), model_for_reset, freeze_pruned_weights) 129 | pruned_model.save(location, self.desc.train_start_step) 130 | if self.verbose: 131 | print('-'*82 + '\nPruning Level {}\n'.format(level) + '-'*82) 132 | train.standard_train(pruned_model, location, self.desc.dataset_hparams, self.desc.training_hparams, 133 | start_step=self.desc.train_start_step, verbose=self.verbose, 134 | evaluate_every_epoch=self.evaluate_every_epoch) 135 | 136 | def _prune_level(self, level: int): 137 | new_location = self.desc.run_path(self.replicate, level) 138 | if Mask.exists(new_location): return 139 | 140 | if level == 0: 141 | Mask.ones_like(models.registry.get(self.desc.model_hparams, self.desc.train_outputs, 142 | pruning_strategy = self.desc.pruning_hparams.pruning_strategy)).save(new_location) 143 | else: 144 | old_location = self.desc.run_path(self.replicate, level-1) 145 | model = models.registry.load(old_location, self.desc.train_end_step, 146 | self.desc.model_hparams, self.desc.train_outputs, 147 | self.desc.pruning_hparams.pruning_strategy) 148 | pruning.registry.get(self.desc.pruning_hparams)(model, Mask.load(old_location), self.desc.dataset_hparams).save(new_location) 149 | -------------------------------------------------------------------------------- /experiments/scratch/test/test_runner.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | 6 | import copy 7 | import json 8 | import numpy as np 9 | import os 10 | import torch 11 | 12 | import datasets.registry 13 | from foundations import paths 14 | from foundations.step import Step 15 | from experiments.scratch.runner import ScratchRunner 16 | from experiments.scratch.desc import ScratchDesc 17 | import models.registry 18 | from pruning.mask import Mask 19 | from testing import test_case 20 | 21 | 22 | class TestRunner(test_case.TestCase): 23 | def setUp(self): 24 | super(TestRunner, self).setUp() 25 | desc = models.registry.get_default_hparams('cifar_resnet_8_2') 26 | self.desc = ScratchDesc(desc.model_hparams, desc.dataset_hparams, desc.training_hparams, desc.pruning_hparams) 27 | 28 | def to_step(self, s): 29 | return Step.from_str(s, datasets.registry.iterations_per_epoch(self.desc.dataset_hparams)) 30 | 31 | def assertLevelFilesPresent(self, level_root, start_step, end_step, masks=False): 32 | with self.subTest(level_root=level_root): 33 | self.assertTrue(os.path.exists(paths.model(level_root, start_step))) 34 | self.assertTrue(os.path.exists(paths.model(level_root, end_step))) 35 | self.assertTrue(os.path.exists(paths.logger(level_root))) 36 | if masks: 37 | self.assertTrue(os.path.exists(paths.mask(level_root))) 38 | self.assertTrue(os.path.exists(paths.sparsity_report(level_root))) 39 | 40 | def test_level0_2it(self): 41 | self.desc.training_hparams.training_steps = '2it' 42 | ScratchRunner(replicate=2, levels=0, desc=self.desc, verbose=False).run() 43 | level_root = self.desc.run_path(2, 0) 44 | 45 | # Ensure the important files are there. 46 | self.assertLevelFilesPresent(level_root, self.to_step('0it'), self.to_step('2it')) 47 | 48 | # Ensure that the mask is all 1's. 49 | mask = Mask.load(level_root) 50 | for v in mask.numpy().values(): self.assertTrue(np.all(np.equal(v, 1))) 51 | with open(paths.sparsity_report(level_root)) as fp: 52 | sparsity_report = json.loads(fp.read()) 53 | self.assertEqual(sparsity_report['unpruned'] / sparsity_report['total'], 1) 54 | 55 | 56 | def test_level3_2it(self): 57 | self.desc.training_hparams.training_steps = '2it' 58 | ScratchRunner(replicate=2, levels=3, desc=self.desc, verbose=False).run() 59 | 60 | level0_weights = paths.model(self.desc.run_path(2, 0), self.to_step('0it')) 61 | level0_weights = {k: v.cpu().numpy() for k, v in torch.load(level0_weights).items()} 62 | 63 | for level in range(0, 4): 64 | level_root = self.desc.run_path(2, level) 65 | self.assertLevelFilesPresent(level_root, self.to_step('0it'), self.to_step('2it')) 66 | 67 | # Check the mask. 68 | pct = 0.8**level 69 | mask = Mask.load(level_root).numpy() 70 | 71 | # Check the mask itself. 72 | total, total_present = 0.0, 0.0 73 | for v in mask.values(): 74 | total += v.size 75 | total_present += np.sum(v) 76 | self.assertTrue(np.allclose(pct, total_present / total, atol=0.01)) 77 | 78 | # Check the sparsity report. 79 | with open(paths.sparsity_report(level_root)) as fp: 80 | sparsity_report = json.loads(fp.read()) 81 | self.assertTrue(np.allclose(pct, sparsity_report['unpruned'] / sparsity_report['total'], atol=0.01)) 82 | 83 | # Ensure that initial weights of each level are different from the original initialization. 
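        # (Because ScratchRunner re-initializes the network at every level, the checks below
        # assert that each level's starting weights differ both from the masked level-0
        # weights and from the previous level's starting weights.)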
84 | if level != 0: 85 | level_weights = paths.model(level_root, self.to_step('0it')) 86 | level_weights = {k: v.cpu().numpy() for k, v in torch.load(level_weights).items()} 87 | for k in level0_weights: 88 | if 'weight' in k: 89 | self.assertFalse((level_weights[k]==level0_weights[k] * mask.get(k, 1)).all()) 90 | 91 | # Ensure that initial weights of each level are different from the initial weights of the previous level. 92 | if level != 0: 93 | previous_level_start_weights = paths.model(self.desc.run_path(2, level-1), self.to_step('0it')) 94 | previous_level_start_weights = {k: v.cpu().numpy() for k, v in torch.load(previous_level_start_weights).items()} 95 | for k in level_weights: 96 | if 'weight' in k: 97 | self.assertFalse((level_weights[k]==previous_level_start_weights[k] * mask.get(k, 1)).all()) 98 | # self.assertStateAllNotEqual(level_weights, {k: v * mask.get(k, 1) for k, v in previous_level_end_weights.items()}) 99 | 100 | 101 | 102 | test_case.main() 103 | -------------------------------------------------------------------------------- /figs/sparsedd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/he-zh/sparse-double-descent/f9bbdccab9d520ec40b8b9f051a7269b2cca8caa/figs/sparsedd.png -------------------------------------------------------------------------------- /foundations/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | -------------------------------------------------------------------------------- /foundations/desc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | 6 | import abc 7 | import argparse 8 | from dataclasses import dataclass, fields 9 | import hashlib 10 | import os 11 | from foundations.hparams import Hparams 12 | from foundations import paths 13 | 14 | 15 | @dataclass 16 | class Desc(abc.ABC): 17 | """The bundle of hyperparameters necessary for a particular kind of job. Contains many hparams objects. 18 | 19 | Each hparams object should be a field of this dataclass. 
20 | """ 21 | 22 | @staticmethod 23 | @abc.abstractmethod 24 | def name_prefix() -> str: 25 | """The name to prefix saved runs with.""" 26 | 27 | pass 28 | 29 | @property 30 | def hashname(self) -> str: 31 | """The name under which experiments with these hyperparameters will be stored.""" 32 | 33 | fields_dict = {f.name: getattr(self, f.name) for f in fields(self)} 34 | hparams_strs = [str(fields_dict[k]) for k in sorted(fields_dict) if isinstance(fields_dict[k], Hparams)] 35 | hash_str = hashlib.md5(';'.join(hparams_strs).encode('utf-8')).hexdigest() 36 | return f'{self.name_prefix()}_{hash_str}' 37 | 38 | @staticmethod 39 | @abc.abstractmethod 40 | def add_args(parser: argparse.ArgumentParser, defaults: 'Desc' = None) -> None: 41 | """Add the necessary command-line arguments.""" 42 | 43 | pass 44 | 45 | @staticmethod 46 | @abc.abstractmethod 47 | def create_from_args(args: argparse.Namespace) -> 'Desc': 48 | """Create from command line arguments.""" 49 | 50 | pass 51 | 52 | def save(self, output_location): 53 | if not os.path.exists(output_location): os.makedirs(output_location) 54 | if os.path.exists(paths.hparams(output_location)): 55 | return 56 | fields_dict = {f.name: getattr(self, f.name) for f in fields(self)} 57 | hparams_strs = [fields_dict[k].display for k in sorted(fields_dict) if isinstance(fields_dict[k], Hparams)] 58 | with open(paths.hparams(output_location), 'w') as fp: 59 | fp.write('\n'.join(hparams_strs)) 60 | -------------------------------------------------------------------------------- /foundations/local.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | 6 | from dataclasses import dataclass 7 | import foundations 8 | import os 9 | import torch 10 | import pathlib 11 | 12 | from foundations.hparams import Hparams 13 | 14 | @dataclass 15 | class Platform(Hparams): 16 | fix_all_random_seeds: int = None 17 | torch_seed: int = None 18 | 19 | _name: str = 'Platform Hyperparameters' 20 | _description: str = 'Hyperparameters that control the plaform on which the job is run.' 21 | _fix_all_random_seeds: int = 'The random seed to control cpu, gpu, data loader and random mask, this will make reproducibility possible' 22 | _torch_seed: str = 'The pytorch random seed that controls the randomness for cpu and cuda, like model initialization' 23 | 24 | # Manage the available devices. 25 | 26 | @property 27 | def device_str(self): 28 | # GPU device. 29 | if torch.cuda.is_available() and torch.cuda.device_count() > 0: 30 | return 'cuda' 31 | # CPU device. 
32 | else: 33 | return 'cpu' 34 | @property 35 | def device_ids(self): 36 | if torch.cuda.is_available() and torch.cuda.device_count() > 0: 37 | device_ids = [int(x) for x in range(torch.cuda.device_count())] 38 | return device_ids 39 | else: return None 40 | 41 | @property 42 | def torch_device(self): 43 | return torch.device(self.device_str) 44 | 45 | @property 46 | def is_parallel(self): 47 | return torch.cuda.is_available() and torch.cuda.device_count() > 1 48 | 49 | # important root for datasets and stored files 50 | 51 | @property 52 | def root(self): 53 | return os.path.join(pathlib.Path.home(), '/data/hezheng/result') 54 | 55 | @property 56 | def dataset_root(self): 57 | return os.path.join(pathlib.Path.home(), '/data/hezheng/datasets/') 58 | 59 | @property 60 | def tiny_imagenet_root(self): 61 | return os.path.join(pathlib.Path.home(), '/data/hezheng/datasets/tiny-imagenet-200') 62 | -------------------------------------------------------------------------------- /foundations/paths.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | 6 | import os 7 | 8 | 9 | def checkpoint(root): return os.path.join(root, 'checkpoint.pth') 10 | 11 | 12 | def logger(root): return os.path.join(root, 'logger') 13 | 14 | 15 | def mask(root): return os.path.join(root, 'mask.pth') 16 | 17 | 18 | def sparsity_report(root): return os.path.join(root, 'sparsity_report.json') 19 | 20 | 21 | def model(root, step): return os.path.join(root, 'model_ep{}_it{}.pth'.format(step.ep, step.it)) 22 | 23 | 24 | def hparams(root): return os.path.join(root, 'hparams.log') 25 | -------------------------------------------------------------------------------- /foundations/runner.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | 6 | import abc 7 | import argparse 8 | 9 | 10 | class Runner(abc.ABC): 11 | """An instance of a training run of some kind.""" 12 | 13 | @staticmethod 14 | @abc.abstractmethod 15 | def description() -> str: 16 | """A description of this runner.""" 17 | 18 | pass 19 | 20 | @staticmethod 21 | @abc.abstractmethod 22 | def add_args(parser: argparse.ArgumentParser) -> None: 23 | """Add all command line flags necessary for this runner.""" 24 | 25 | pass 26 | 27 | @staticmethod 28 | @abc.abstractmethod 29 | def create_from_args(args: argparse.Namespace) -> 'Runner': 30 | """Create a runner from command line arguments.""" 31 | 32 | pass 33 | 34 | @abc.abstractmethod 35 | def display_output_location(self) -> None: 36 | """Print the output location for the job.""" 37 | 38 | pass 39 | 40 | @abc.abstractmethod 41 | def run(self) -> None: 42 | """Run the job.""" 43 | 44 | pass 45 | -------------------------------------------------------------------------------- /foundations/step.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | 6 | 7 | class Step: 8 | """Represents a particular step of training. 
9 | 10 | A step can be represented as either an iteration or a pair of an epoch and an iteration within that epoch. 11 | This class encapsulates a step of training such that it can be freely converted between the two representations. 12 | """ 13 | 14 | def __init__(self, iteration: int, iterations_per_epoch: int) -> 'Step': 15 | if iteration < 0: raise ValueError('iteration must >= 0.') 16 | if iterations_per_epoch <= 0: raise ValueError('iterations_per_epoch must be > 0.') 17 | self._iteration = iteration 18 | self._iterations_per_epoch = iterations_per_epoch 19 | 20 | @staticmethod 21 | def str_is_zero(s: str): 22 | return s in ['0ep', '0it', '0ep0it'] 23 | 24 | @staticmethod 25 | def from_iteration(iteration: int, iterations_per_epoch: int) -> 'Step': 26 | return Step(iteration, iterations_per_epoch) 27 | 28 | @staticmethod 29 | def from_epoch(epoch: int, iteration: int, iterations_per_epoch: int) -> 'Step': 30 | return Step(epoch * iterations_per_epoch + iteration, iterations_per_epoch) 31 | 32 | @staticmethod 33 | def from_str(s: str, iterations_per_epoch: int) -> 'Step': 34 | """Creates a step from a string that describes the number of epochs, iterations, or both. 35 | 36 | Epochs: '120ep' 37 | Iterations: '2000it' 38 | Both: '120ep50it'""" 39 | 40 | if 'ep' in s and 'it' in s: 41 | ep = int(s.split('ep')[0]) 42 | it = int(s.split('ep')[1].split('it')[0]) 43 | if s != '{}ep{}it'.format(ep, it): raise ValueError('Malformed string step: {}'.format(s)) 44 | return Step.from_epoch(ep, it, iterations_per_epoch) 45 | elif 'ep' in s: 46 | ep = int(s.split('ep')[0]) 47 | if s != '{}ep'.format(ep): raise ValueError('Malformed string step: {}'.format(s)) 48 | return Step.from_epoch(ep, 0, iterations_per_epoch) 49 | elif 'it' in s: 50 | it = int(s.split('it')[0]) 51 | if s != '{}it'.format(it): raise ValueError('Malformed string step: {}'.format(s)) 52 | return Step.from_iteration(it, iterations_per_epoch) 53 | else: 54 | raise ValueError('Malformed string step: {}'.format(s)) 55 | 56 | @staticmethod 57 | def zero(iterations_per_epoch: int) -> 'Step': 58 | return Step(0, iterations_per_epoch) 59 | 60 | @property 61 | def iteration(self): 62 | """The overall number of steps of training completed so far.""" 63 | return self._iteration 64 | 65 | @property 66 | def ep(self): 67 | """The current epoch of training.""" 68 | return self._iteration // self._iterations_per_epoch 69 | 70 | @property 71 | def it(self): 72 | """The iteration within the current epoch of training.""" 73 | return self._iteration % self._iterations_per_epoch 74 | 75 | def _check(self, other): 76 | if not isinstance(other, Step): 77 | raise ValueError('Invalid type for other: {}.'.format(type(other))) 78 | if self._iterations_per_epoch != other._iterations_per_epoch: 79 | raise ValueError('Cannot compare steps when epochs are of different lengths.') 80 | 81 | def __lt__(self, other): 82 | self._check(other) 83 | return self._iteration < other._iteration 84 | 85 | def __le__(self, other): 86 | self._check(other) 87 | return self._iteration <= other._iteration 88 | 89 | def __eq__(self, other): 90 | self._check(other) 91 | return self._iteration == other._iteration 92 | 93 | def __ne__(self, other): 94 | self._check(other) 95 | return self._iteration != other._iteration 96 | 97 | def __gt__(self, other): 98 | self._check(other) 99 | return self._iteration > other._iteration 100 | 101 | def __ge__(self, other): 102 | self._check(other) 103 | return self._iteration >= other._iteration 104 | 105 | def __str__(self): 106 | return 
'(Iteration {}; Iterations per Epoch: {})'.format(self._iteration, self._iterations_per_epoch) 107 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | 6 | import argparse 7 | import sys 8 | import numpy as np 9 | from numpy import random 10 | import os 11 | import torch 12 | 13 | from experiments import runner_registry 14 | from foundations.local import Platform 15 | from utils import arg_utils 16 | 17 | 18 | def main(): 19 | # The welcome message. 20 | welcome = '='*82 + '\nA Framework on Sparse Double Descent Based on open-lth\n' + '-'*82 21 | 22 | # Choose an initial command. 23 | helptext = welcome + "\nChoose a command to run:" 24 | startup_path = sys.argv[0].split('/')[-1] 25 | for name, runner in runner_registry.registered_runners.items(): 26 | if name != 'branch': 27 | helptext += "\n * {} {} [...] => {}".format(startup_path, name, runner.description()) 28 | else: 29 | for _name, _runner in runner_registry.registered_runners.items(): 30 | if _name == name: continue 31 | helptext += "\n * {} {}_{} [...] => {}".format(startup_path, _name, name, runner.description()) 32 | helptext += '\n' + '='*82 33 | 34 | runner_name = arg_utils.maybe_get_arg('subcommand', positional=True) 35 | if runner_name is None or runner_name.split('_')[-1] not in runner_registry.registered_runners: 36 | print(helptext) 37 | sys.exit(1) 38 | 39 | runner_name = runner_name.split('_')[-1] 40 | # Add the arguments for that command. 41 | usage = '\n' + welcome + '\n' 42 | usage += 'main.py {} [...] => {}'.format(runner_name, runner_registry.get(runner_name).description()) 43 | usage += '\n' + '='*82 + '\n' 44 | 45 | parser = argparse.ArgumentParser(usage=usage, conflict_handler='resolve') 46 | parser.add_argument('subcommand') 47 | parser.add_argument('--display_output_location','-d', action='store_true', 48 | help='Display the output location for this job.') 49 | 50 | parser.add_argument('--gpu', type = str, default='3', help='The GPU devices to run the job') 51 | 52 | # Get the platform arguments. 53 | Platform.add_args(parser) 54 | 55 | # Add arguments for the various runners. 
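    # Note: runner_name was normalized above by keeping only the token after the last
    # underscore, so a subcommand such as 'lottery_branch' resolves to the registered
    # 'branch' runner, whose arguments are added next.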
56 | runner_registry.get(runner_name).add_args(parser) 57 | 58 | args = parser.parse_args() 59 | 60 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 61 | 62 | if args.fix_all_random_seeds: 63 | for key in args.__dict__.keys(): 64 | if 'torch_seed' in key or 'data_order_seed' in key or 'transformation_seed' in key or 'random_mask_seed' in key: 65 | setattr(args, key, args.fix_all_random_seeds) 66 | # args.torch_seed = args.fix_all_random_seeds 67 | # args.data_order_seed = args.fix_all_random_seeds 68 | # args.transformation_seed = args.fix_all_random_seeds 69 | # if hasattr(args, 'random_mask_seed'): args.random_mask_seed = args.fix_all_random_seeds 70 | 71 | platform = Platform.create_from_args(args) 72 | torch_seed = platform.torch_seed 73 | if torch_seed is not None: 74 | torch.manual_seed(torch_seed) 75 | torch.backends.cudnn.deterministic = True 76 | torch.backends.cudnn.benchmark = False 77 | 78 | if args.display_output_location: 79 | runner_registry.get(runner_name).create_from_args(args).display_output_location() 80 | sys.exit(0) 81 | 82 | experiment_runner = runner_registry.get(runner_name).create_from_args(args) 83 | experiment_runner.run() 84 | 85 | 86 | if __name__ == '__main__': 87 | main() 88 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | -------------------------------------------------------------------------------- /models/base.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | 6 | import abc 7 | import torch 8 | import typing 9 | import os 10 | from foundations.step import Step 11 | 12 | 13 | 14 | class Model(torch.nn.Module, abc.ABC): 15 | """The base class used by all models in this codebase.""" 16 | 17 | _prunable_layer_type: str = 'default' 18 | 19 | @staticmethod 20 | @abc.abstractmethod 21 | def is_valid_model_name(model_name: str) -> bool: 22 | """Is the model name string a valid name for models in this class?""" 23 | 24 | pass 25 | 26 | @staticmethod 27 | @abc.abstractmethod 28 | def get_model_from_name( 29 | model_name: str, 30 | outputs: int, 31 | initializer: typing.Callable[[torch.nn.Module], None] 32 | ) -> 'Model': 33 | """Returns an instance of this class as described by the model_name string.""" 34 | 35 | pass 36 | 37 | @property 38 | def prunable_layer_type(self) -> str: 39 | """The type of nn.module that is valid for pruning. 40 | 41 | By default, only the weights of convolutional and linear layers are prunable. 42 | If network-slimming pruning is utilized, then BN will be set as the prunable type. 43 | """ 44 | return self._prunable_layer_type 45 | 46 | @prunable_layer_type.setter 47 | def prunable_layer_type(self, type: str): 48 | if type in ['default', 'BN']: 49 | self._prunable_layer_type = type 50 | else: 51 | raise ValueError('Unrecognized prunable_layer_type: {}'.format(type)) 52 | 53 | @property 54 | def prunable_layer_names(self) -> typing.List[str]: 55 | """A list of the names of Tensors of this model that are valid for pruning.
56 | 57 | By default, only the weights of convolutional and linear layers are prunable. 58 | If network-slimming pruning is utilized, then the weights and biases of batch normlization 59 | layers will be set as prunable. 60 | """ 61 | if self.prunable_layer_type == 'BN': 62 | return [name + m for name, module in self.named_modules() if 63 | isinstance(module, torch.nn.modules.BatchNorm2d) for m in ['.weight', '.bias']] 64 | else: 65 | return [name + '.weight' for name, module in self.named_modules() if 66 | isinstance(module, torch.nn.modules.conv.Conv2d) or 67 | isinstance(module, torch.nn.modules.linear.Linear)] 68 | 69 | @property 70 | @abc.abstractmethod 71 | def output_layer_names(self) -> typing.List[str]: 72 | """A list of the names of the Tensors of the output layer of this model.""" 73 | 74 | pass 75 | 76 | @property 77 | @abc.abstractmethod 78 | def loss_criterion(self) -> torch.nn.Module: 79 | """The loss criterion to use for this model.""" 80 | 81 | pass 82 | 83 | def updateBN(self): 84 | """ 85 | Add additional subgradient descent of batch normalization weights on the sparsity-induced penalty term 86 | for network-slimming pruning 87 | """ 88 | pass 89 | 90 | def save(self, save_location: str, save_step: Step): 91 | if not os.path.exists(save_location): os.makedirs(save_location) 92 | torch.save(self.state_dict(), os.path.join(save_location, 'model_ep{}_it{}.pth'.format(save_step.ep, save_step.it))) 93 | 94 | 95 | class DataParallel(Model, torch.nn.DataParallel): 96 | def __init__(self, module: Model): 97 | super(DataParallel, self).__init__(module=module) 98 | 99 | @property 100 | def prunable_layer_type(self): return self.module.prunable_layer_type 101 | 102 | @property 103 | def prunable_layer_names(self): return self.module.prunable_layer_names 104 | 105 | @property 106 | def output_layer_names(self): return self.module.output_layer_names 107 | 108 | @property 109 | def loss_criterion(self): return self.module.loss_criterion 110 | 111 | @staticmethod 112 | def get_model_from_name(model_name, outputs, initializer): raise NotImplementedError 113 | 114 | @staticmethod 115 | def is_valid_model_name(model_name): raise NotImplementedError 116 | 117 | @staticmethod 118 | def default_hparams(): raise NotImplementedError 119 | 120 | def updateBN(self): 121 | return self.module.updateBN() 122 | 123 | def save(self, save_location: str, save_step: Step): 124 | self.module.save(save_location, save_step) 125 | 126 | 127 | class DistributedDataParallel(Model, torch.nn.parallel.DistributedDataParallel): 128 | def __init__(self, module: Model, device_ids): 129 | super(DistributedDataParallel, self).__init__(module=module, device_ids=device_ids) 130 | 131 | @property 132 | def prunable_layer_type(self): return self.module.prunable_layer_type 133 | 134 | @property 135 | def prunable_layer_names(self): return self.module.prunable_layer_names 136 | 137 | @property 138 | def output_layer_names(self): return self.module.output_layer_names 139 | 140 | @property 141 | def loss_criterion(self): return self.module.loss_criterion 142 | 143 | @staticmethod 144 | def get_model_from_name(model_name, outputs, initializer): raise NotImplementedError 145 | 146 | @staticmethod 147 | def is_valid_model_name(model_name): raise NotImplementedError 148 | 149 | @staticmethod 150 | def default_hparams(): raise NotImplementedError 151 | 152 | def updateBN(self): 153 | return self.module.updateBN() 154 | 155 | def save(self, save_location: str, save_step: Step): 156 | self.module.save(save_location, save_step) 157 
| -------------------------------------------------------------------------------- /models/bn_initializers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | 6 | import torch 7 | 8 | 9 | def uniform(w): 10 | if isinstance(w, torch.nn.BatchNorm2d): 11 | w.weight.data = torch.rand(w.weight.data.shape) 12 | w.bias.data = torch.zeros_like(w.bias.data) 13 | 14 | 15 | def fixed(w): 16 | if isinstance(w, torch.nn.BatchNorm2d): 17 | w.weight.data = torch.ones_like(w.weight.data) 18 | w.bias.data = torch.zeros_like(w.bias.data) 19 | 20 | 21 | def oneone(w): 22 | if isinstance(w, torch.nn.BatchNorm2d): 23 | w.weight.data = torch.ones_like(w.weight.data) 24 | w.bias.data = torch.ones_like(w.bias.data) 25 | 26 | 27 | def positivenegative(w): 28 | if isinstance(w, torch.nn.BatchNorm2d): 29 | uniform(w) 30 | w.weight.data = w.weight.data * 2 - 1 31 | w.bias.data = torch.zeros_like(w.bias.data) 32 | -------------------------------------------------------------------------------- /models/cifar_pytorch_resnet.py: -------------------------------------------------------------------------------- 1 | # https://github.com/weiaicunzai/pytorch-cifar100/blob/master/models/resnet.py 2 | 3 | from functools import partial 4 | import torch 5 | import torchvision 6 | 7 | from foundations import hparams 8 | from models import base 9 | from pruning import magnitude 10 | from experiments.finetune.desc import FinetuningDesc 11 | 12 | class ResNet(torchvision.models.ResNet): 13 | def __init__(self, block, layers, num_classes=100, width=64): 14 | """To make it possible to vary the width, we need to override the constructor of the torchvision resnet.""" 15 | 16 | torch.nn.Module.__init__(self) # Skip the parent constructor. This replaces it. 17 | self._norm_layer = torch.nn.BatchNorm2d 18 | self.inplanes = width 19 | self.dilation = 1 20 | self.groups = 1 21 | self.base_width = 64 22 | 23 | # The initial convolutional layer. 24 | self.conv1 = torch.nn.Conv2d(3, self.inplanes, kernel_size=3, padding=1, bias=False) 25 | self.bn1 = self._norm_layer(self.inplanes) 26 | self.relu = torch.nn.ReLU(inplace=True) 27 | # self.maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 28 | 29 | # The subsequent blocks. 30 | self.layer1 = self._make_layer(block, width, layers[0], stride=1) 31 | self.layer2 = self._make_layer(block, width*2, layers[1], stride=2, dilate=False) 32 | self.layer3 = self._make_layer(block, width*4, layers[2], stride=2, dilate=False) 33 | self.layer4 = self._make_layer(block, width*8, layers[3], stride=2, dilate=False) 34 | 35 | # The last layers. 36 | self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1)) 37 | self.fc = torch.nn.Linear(width*8*block.expansion, num_classes) 38 | 39 | # Default init. 
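        # (The loop below applies Kaiming-normal init to convolutions and constant
        # weight=1 / bias=0 to batch norm; the initializer applied in Model.__init__ further
        # down can then override this, presumably according to the model_init and
        # batchnorm_init hparams.)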
40 | for m in self.modules(): 41 | if isinstance(m, torch.nn.Conv2d): 42 | torch.nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 43 | elif isinstance(m, torch.nn.BatchNorm2d): 44 | torch.nn.init.constant_(m.weight, 1) 45 | torch.nn.init.constant_(m.bias, 0) 46 | 47 | def _forward_impl(self, x): 48 | x = self.conv1(x) 49 | x = self.bn1(x) 50 | x = self.relu(x) 51 | # x = self.maxpool(x) 52 | 53 | x = self.layer1(x) 54 | x = self.layer2(x) 55 | x = self.layer3(x) 56 | x = self.layer4(x) 57 | 58 | x = self.avgpool(x) 59 | x = torch.flatten(x, 1) 60 | x = self.fc(x) 61 | 62 | return x 63 | 64 | 65 | class Model(base.Model): 66 | """A residual neural network as originally designed for ImageNet.""" 67 | 68 | def __init__(self, model_fn, initializer, outputs=None): 69 | super(Model, self).__init__() 70 | 71 | self.model = model_fn(num_classes=outputs or 100) 72 | self.criterion = torch.nn.CrossEntropyLoss() 73 | self.apply(initializer) 74 | 75 | def forward(self, x): 76 | return self.model(x) 77 | 78 | @property 79 | def output_layer_names(self): 80 | return ['model.fc.weight', 'model.fc.bias'] 81 | 82 | @staticmethod 83 | def is_valid_model_name(model_name): 84 | return (model_name.startswith('cifar_pytorch_resnet_') and 85 | 5 >= len(model_name.split('_')) >= 4 and 86 | model_name.split('_')[3].isdigit() and 87 | int(model_name.split('_')[3]) in [18, 34, 50, 101, 152, 200]) 88 | 89 | @staticmethod 90 | def get_model_from_name(model_name, initializer, outputs=100): 91 | """Name: cifar_pytorch_resnet_D[_W]. 92 | 93 | D is the model depth (e.g., 50 for ResNet-50). W is the model width - the number of filters in the first 94 | residual layers. By default, this number is 64.""" 95 | 96 | if not Model.is_valid_model_name(model_name): 97 | raise ValueError('Invalid model name: {}'.format(model_name)) 98 | 99 | num = int(model_name.split('_')[3]) 100 | if num == 18: model_fn = partial(ResNet, torchvision.models.resnet.BasicBlock, [2, 2, 2, 2]) 101 | elif num == 34: model_fn = partial(ResNet, torchvision.models.resnet.BasicBlock, [3, 4, 6, 3]) 102 | elif num == 50: model_fn = partial(ResNet, torchvision.models.resnet.Bottleneck, [3, 4, 6, 3]) 103 | elif num == 101: model_fn = partial(ResNet, torchvision.models.resnet.Bottleneck, [3, 4, 23, 3]) 104 | elif num == 152: model_fn = partial(ResNet, torchvision.models.resnet.Bottleneck, [3, 8, 36, 3]) 105 | elif num == 200: model_fn = partial(ResNet, torchvision.models.resnet.Bottleneck, [3, 24, 36, 3]) 106 | elif num == 269: model_fn = partial(ResNet, torchvision.models.resnet.Bottleneck, [3, 30, 48, 8]) 107 | 108 | if len(model_name.split('_')) == 5: 109 | width = int(model_name.split('_')[4]) 110 | model_fn = partial(model_fn, width=width) 111 | 112 | return Model(model_fn, initializer, outputs) 113 | 114 | @property 115 | def loss_criterion(self): 116 | return self.criterion 117 | 118 | @staticmethod 119 | def default_hparams(): 120 | model_hparams = hparams.ModelHparams( 121 | model_name='cifar_pytorch_resnet_18', 122 | model_init='kaiming_normal', 123 | batchnorm_init='uniform', 124 | ) 125 | 126 | dataset_hparams = hparams.DatasetHparams( 127 | dataset_name='cifar100', 128 | batch_size=128, 129 | ) 130 | 131 | training_hparams = hparams.TrainingHparams( 132 | optimizer_name='sgd', 133 | momentum=0.9, 134 | milestone_steps='80ep,120ep', 135 | lr=0.1, 136 | gamma=0.1, 137 | weight_decay=1e-4, 138 | training_steps='160ep', 139 | ) 140 | 141 | pruning_hparams = magnitude.PruningHparams( 142 | pruning_strategy='magnitude', 143 | 
pruning_fraction=0.2, 144 | pruning_scope='global', 145 | ) 146 | 147 | finetuning_hparams = hparams.FinetuningHparams( 148 | lr=0.001, 149 | training_steps='160ep', 150 | ) 151 | 152 | return FinetuningDesc(model_hparams, dataset_hparams, training_hparams, pruning_hparams, finetuning_hparams) 153 | 154 | -------------------------------------------------------------------------------- /models/cifar_vgg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from foundations import hparams 10 | from pruning import magnitude 11 | from experiments.finetune.desc import FinetuningDesc 12 | from models import base 13 | 14 | 15 | class Model(base.Model): 16 | """A VGG-style neural network designed for CIFAR-10.""" 17 | 18 | class ConvModule(nn.Module): 19 | """A single convolutional module in a VGG network.""" 20 | 21 | def __init__(self, in_filters, out_filters): 22 | super(Model.ConvModule, self).__init__() 23 | self.conv = nn.Conv2d(in_filters, out_filters, kernel_size=3, padding=1) 24 | self.bn = nn.BatchNorm2d(out_filters) 25 | 26 | def forward(self, x): 27 | return F.relu(self.bn(self.conv(x))) 28 | 29 | def __init__(self, plan, initializer, outputs=10): 30 | super(Model, self).__init__() 31 | 32 | layers = [] 33 | filters = 3 34 | 35 | for spec in plan: 36 | if spec == 'M': 37 | layers.append(nn.MaxPool2d(kernel_size=2, stride=2)) 38 | else: 39 | layers.append(Model.ConvModule(filters, spec)) 40 | filters = spec 41 | 42 | self.layers = nn.Sequential(*layers) 43 | self.fc = nn.Linear(512, outputs) 44 | self.criterion = nn.CrossEntropyLoss() 45 | 46 | self.apply(initializer) 47 | 48 | def forward(self, x): 49 | x = self.layers(x) 50 | x = nn.AvgPool2d(2)(x) 51 | x = x.view(x.size(0), -1) 52 | x = self.fc(x) 53 | return x 54 | 55 | @property 56 | def output_layer_names(self): 57 | return ['fc.weight', 'fc.bias'] 58 | 59 | @staticmethod 60 | def is_valid_model_name(model_name): 61 | return (model_name.startswith('cifar_vgg_') and 62 | len(model_name.split('_')) == 3 and 63 | model_name.split('_')[2].isdigit() and 64 | int(model_name.split('_')[2]) in [11, 13, 16, 19]) 65 | 66 | @staticmethod 67 | def get_model_from_name(model_name, initializer, outputs=10): 68 | if not Model.is_valid_model_name(model_name): 69 | raise ValueError('Invalid model name: {}'.format(model_name)) 70 | 71 | outputs = outputs or 10 72 | 73 | num = int(model_name.split('_')[2]) 74 | if num == 11: 75 | plan = [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512] 76 | elif num == 13: 77 | plan = [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512] 78 | elif num == 16: 79 | plan = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512] 80 | elif num == 19: 81 | plan = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512] 82 | else: 83 | raise ValueError('Unknown VGG model: {}'.format(model_name)) 84 | 85 | return Model(plan, initializer, outputs) 86 | 87 | @property 88 | def loss_criterion(self): 89 | return self.criterion 90 | 91 | @staticmethod 92 | def default_hparams(): 93 | model_hparams = hparams.ModelHparams( 94 | model_name='cifar_vgg_16', 95 | model_init='kaiming_normal', 96 | batchnorm_init='uniform', 97 | ) 98 | 99 | dataset_hparams = 
hparams.DatasetHparams( 100 | dataset_name='cifar10', 101 | batch_size=128 102 | ) 103 | 104 | training_hparams = hparams.TrainingHparams( 105 | optimizer_name='sgd', 106 | momentum=0.9, 107 | milestone_steps='80ep,120ep', 108 | lr=0.1, 109 | gamma=0.1, 110 | weight_decay=1e-4, 111 | training_steps='160ep' 112 | ) 113 | 114 | pruning_hparams = magnitude.PruningHparams( 115 | pruning_strategy='magnitude', 116 | pruning_fraction=0.2, 117 | pruning_scope='global', 118 | pruning_layers_to_ignore='fc.weight' 119 | ) 120 | 121 | finetuning_hparams = hparams.FinetuningHparams( 122 | lr=0.001, 123 | training_steps='160ep', 124 | ) 125 | 126 | return FinetuningDesc(model_hparams, dataset_hparams, training_hparams, pruning_hparams, finetuning_hparams) 127 | -------------------------------------------------------------------------------- /models/imagenet_resnet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | 6 | from functools import partial 7 | import torch 8 | import torchvision 9 | 10 | from foundations import hparams 11 | from models import base 12 | from pruning import magnitude 13 | from experiments.finetune.desc import FinetuningDesc 14 | 15 | class ResNet(torchvision.models.ResNet): 16 | def __init__(self, block, layers, num_classes=1000, width=64): 17 | """To make it possible to vary the width, we need to override the constructor of the torchvision resnet.""" 18 | 19 | torch.nn.Module.__init__(self) # Skip the parent constructor. This replaces it. 20 | self._norm_layer = torch.nn.BatchNorm2d 21 | self.inplanes = width 22 | self.dilation = 1 23 | self.groups = 1 24 | self.base_width = 64 25 | 26 | # The initial convolutional layer. 27 | self.conv1 = torch.nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) 28 | self.bn1 = self._norm_layer(self.inplanes) 29 | self.relu = torch.nn.ReLU(inplace=True) 30 | self.maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 31 | 32 | # The subsequent blocks. 33 | self.layer1 = self._make_layer(block, width, layers[0]) 34 | self.layer2 = self._make_layer(block, width*2, layers[1], stride=2, dilate=False) 35 | self.layer3 = self._make_layer(block, width*4, layers[2], stride=2, dilate=False) 36 | self.layer4 = self._make_layer(block, width*8, layers[3], stride=2, dilate=False) 37 | 38 | # The last layers. 39 | self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1)) 40 | self.fc = torch.nn.Linear(width*8*block.expansion, num_classes) 41 | 42 | # Default init. 
43 | for m in self.modules(): 44 | if isinstance(m, torch.nn.Conv2d): 45 | torch.nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 46 | elif isinstance(m, torch.nn.BatchNorm2d): 47 | torch.nn.init.constant_(m.weight, 1) 48 | torch.nn.init.constant_(m.bias, 0) 49 | 50 | 51 | class Model(base.Model): 52 | """A residual neural network as originally designed for ImageNet.""" 53 | 54 | def __init__(self, model_fn, initializer, outputs=None): 55 | super(Model, self).__init__() 56 | 57 | self.model = model_fn(num_classes=outputs or 1000) 58 | self.criterion = torch.nn.CrossEntropyLoss() 59 | self.apply(initializer) 60 | 61 | def forward(self, x): 62 | return self.model(x) 63 | 64 | @property 65 | def output_layer_names(self): 66 | return ['model.fc.weight', 'model.fc.bias'] 67 | 68 | @staticmethod 69 | def is_valid_model_name(model_name): 70 | return (model_name.startswith('imagenet_resnet_') and 71 | 4 >= len(model_name.split('_')) >= 3 and 72 | model_name.split('_')[2].isdigit() and 73 | int(model_name.split('_')[2]) in [18, 34, 50, 101, 152, 200]) 74 | 75 | @staticmethod 76 | def get_model_from_name(model_name, initializer, outputs=1000): 77 | """Name: imagenet_resnet_D[_W]. 78 | 79 | D is the model depth (e.g., 50 for ResNet-50). W is the model width - the number of filters in the first 80 | residual layers. By default, this number is 64.""" 81 | 82 | if not Model.is_valid_model_name(model_name): 83 | raise ValueError('Invalid model name: {}'.format(model_name)) 84 | 85 | num = int(model_name.split('_')[2]) 86 | if num == 18: model_fn = partial(ResNet, torchvision.models.resnet.BasicBlock, [2, 2, 2, 2]) 87 | elif num == 34: model_fn = partial(ResNet, torchvision.models.resnet.BasicBlock, [3, 4, 6, 3]) 88 | elif num == 50: model_fn = partial(ResNet, torchvision.models.resnet.Bottleneck, [3, 4, 6, 3]) 89 | elif num == 101: model_fn = partial(ResNet, torchvision.models.resnet.Bottleneck, [3, 4, 23, 3]) 90 | elif num == 152: model_fn = partial(ResNet, torchvision.models.resnet.Bottleneck, [3, 8, 36, 3]) 91 | elif num == 200: model_fn = partial(ResNet, torchvision.models.resnet.Bottleneck, [3, 24, 36, 3]) 92 | elif num == 269: model_fn = partial(ResNet, torchvision.models.resnet.Bottleneck, [3, 30, 48, 8]) 93 | 94 | if len(model_name.split('_')) == 4: 95 | width = int(model_name.split('_')[3]) 96 | model_fn = partial(model_fn, width=width) 97 | 98 | return Model(model_fn, initializer, outputs) 99 | 100 | @property 101 | def loss_criterion(self): 102 | return self.criterion 103 | 104 | @staticmethod 105 | def default_hparams(): 106 | """These hyperparameters will reach 76.1% top-1 accuracy on ImageNet. 107 | 108 | To get these results with a smaller batch size, scale the learning rate linearly. 109 | That is, batch size 512 -> lr 0.2, 256 -> 0.1, etc. 
110 | """ 111 | 112 | model_hparams = hparams.ModelHparams( 113 | model_name='imagenet_resnet_50', 114 | model_init='kaiming_normal', 115 | batchnorm_init='uniform', 116 | ) 117 | 118 | dataset_hparams = hparams.DatasetHparams( 119 | dataset_name='imagenet', 120 | batch_size=1024, 121 | ) 122 | 123 | training_hparams = hparams.TrainingHparams( 124 | optimizer_name='sgd', 125 | momentum=0.9, 126 | milestone_steps='30ep,60ep,80ep', 127 | lr=0.4, 128 | gamma=0.1, 129 | weight_decay=1e-4, 130 | training_steps='90ep', 131 | warmup_steps='5ep', 132 | ) 133 | 134 | pruning_hparams = magnitude.PruningHparams( 135 | pruning_strategy='magnitude', 136 | pruning_fraction=0.2, 137 | pruning_scope='global', 138 | ) 139 | 140 | finetuning_hparams = hparams.FinetuningHparams( 141 | # optimizer_name='sgd', 142 | # momentum=0.9, 143 | lr=0.0001, 144 | # weight_decay=1e-4, 145 | training_steps='30ep', 146 | ) 147 | 148 | return FinetuningDesc(model_hparams, dataset_hparams, training_hparams, pruning_hparams, finetuning_hparams) 149 | 150 | -------------------------------------------------------------------------------- /models/initializers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | 6 | import torch 7 | 8 | 9 | def binary(w): 10 | if isinstance(w, torch.nn.Linear) or isinstance(w, torch.nn.Conv2d): 11 | torch.nn.init.kaiming_normal_(w.weight) 12 | sigma = w.weight.data.std() 13 | w.weight.data = torch.sign(w.weight.data) * sigma 14 | 15 | 16 | def kaiming_normal(w): 17 | if isinstance(w, torch.nn.Linear) or isinstance(w, torch.nn.Conv2d): 18 | torch.nn.init.kaiming_normal_(w.weight) 19 | 20 | 21 | def kaiming_uniform(w): 22 | if isinstance(w, torch.nn.Linear) or isinstance(w, torch.nn.Conv2d): 23 | torch.nn.init.kaiming_uniform_(w.weight) 24 | 25 | 26 | def orthogonal(w): 27 | if isinstance(w, torch.nn.Linear) or isinstance(w, torch.nn.Conv2d): 28 | torch.nn.init.orthogonal_(w.weight) 29 | -------------------------------------------------------------------------------- /models/mnist_mlp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from foundations import hparams 10 | from models import base 11 | from pruning import magnitude 12 | from experiments.finetune.desc import FinetuningDesc 13 | 14 | class Model(base.Model): 15 | '''A fully-connected model for mnist''' 16 | 17 | def __init__(self, plan, initializer, outputs=10): 18 | super(Model, self).__init__() 19 | 20 | layers = [] 21 | current_size = 784 # 28 * 28 = number of pixels in MNIST image. 22 | for size in plan: 23 | layers.append(nn.Linear(current_size, size)) 24 | current_size = size 25 | 26 | self.fc_layers = nn.ModuleList(layers) 27 | self.fc = nn.Linear(current_size, outputs) 28 | self.criterion = nn.CrossEntropyLoss() 29 | 30 | self.apply(initializer) 31 | 32 | def forward(self, x): 33 | x = x.view(x.size(0), -1) # Flatten. 
34 | for layer in self.fc_layers: 35 | x = F.relu(layer(x)) 36 | 37 | return self.fc(x) 38 | 39 | @property 40 | def output_layer_names(self): 41 | return ['fc.weight', 'fc.bias'] 42 | 43 | @staticmethod 44 | def is_valid_model_name(model_name): 45 | return (model_name.startswith('mnist_mlp') and 46 | len(model_name.split('_')) > 2 and 47 | all([x.isdigit() and int(x) > 0 for x in model_name.split('_')[2:]])) 48 | 49 | @staticmethod 50 | def get_model_from_name(model_name, initializer, outputs=None): 51 | """The name of a model is mnist_mlp_N1[_N2...]. 52 | 53 | N1, N2, etc. are the number of neurons in each fully-connected layer excluding the 54 | output layer (10 neurons by default). A MLP with 300 neurons in the first hidden layer, 55 | 100 neurons in the second hidden layer, and 10 output neurons is 'mnist_mlp_300_100'. 56 | """ 57 | 58 | outputs = outputs or 10 59 | 60 | if not Model.is_valid_model_name(model_name): 61 | raise ValueError('Invalid model name: {}'.format(model_name)) 62 | 63 | plan = [int(n) for n in model_name.split('_')[2:]] 64 | return Model(plan, initializer, outputs) 65 | 66 | @property 67 | def loss_criterion(self): 68 | return self.criterion 69 | 70 | @staticmethod 71 | def default_hparams(): 72 | model_hparams = hparams.ModelHparams( 73 | model_name='mnist_mlp_300_100', 74 | model_init='kaiming_normal', 75 | batchnorm_init='uniform' 76 | ) 77 | 78 | dataset_hparams = hparams.DatasetHparams( 79 | dataset_name='mnist', 80 | batch_size=128 81 | ) 82 | 83 | training_hparams = hparams.TrainingHparams( 84 | optimizer_name='sgd', 85 | lr=0.1, 86 | training_steps='160ep', 87 | ) 88 | 89 | pruning_hparams = magnitude.PruningHparams( 90 | pruning_strategy='magnitude', 91 | pruning_fraction=0.2, 92 | pruning_scope='global', 93 | pruning_layers_to_ignore='fc.weight' 94 | ) 95 | 96 | finetuning_hparams = hparams.FinetuningHparams( 97 | # optimizer_name='sgd', 98 | lr=0.1, 99 | training_steps='160ep' 100 | ) 101 | 102 | return FinetuningDesc(model_hparams, dataset_hparams, training_hparams, pruning_hparams, finetuning_hparams) 103 | -------------------------------------------------------------------------------- /models/registry.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | 6 | import pruning 7 | import torch 8 | import os 9 | from foundations import paths 10 | from foundations.hparams import ModelHparams 11 | from foundations.step import Step 12 | from models import cifar_vgg, mnist_mlp, imagenet_resnet, cifar_pytorch_resnet, tinyimagenet_resnet 13 | from models import bn_initializers, initializers 14 | 15 | registered_models = [mnist_mlp.Model, cifar_vgg.Model, imagenet_resnet.Model, 16 | cifar_pytorch_resnet.Model, tinyimagenet_resnet.Model] 17 | 18 | 19 | def get(model_hparams: ModelHparams, outputs=None, pruning_strategy = None): 20 | """Get the model for the corresponding hyperparameters.""" 21 | 22 | # Select the initializer. 23 | if hasattr(initializers, model_hparams.model_init): 24 | initializer = getattr(initializers, model_hparams.model_init) 25 | else: 26 | raise ValueError('No initializer: {}'.format(model_hparams.model_init)) 27 | 28 | # Select the BatchNorm initializer. 
29 | if hasattr(bn_initializers, model_hparams.batchnorm_init): 30 | bn_initializer = getattr(bn_initializers, model_hparams.batchnorm_init) 31 | else: 32 | raise ValueError('No batchnorm initializer: {}'.format(model_hparams.batchnorm_init)) 33 | 34 | # Create the overall initializer function. 35 | def init_fn(w): 36 | initializer(w) 37 | bn_initializer(w) 38 | 39 | # Select the model. 40 | model = None 41 | for registered_model in registered_models: 42 | if registered_model.is_valid_model_name(model_hparams.model_name): 43 | model = registered_model.get_model_from_name(model_hparams.model_name, init_fn, outputs) 44 | break 45 | 46 | if model is None: 47 | raise ValueError('No such model: {}'.format(model_hparams.model_name)) 48 | 49 | # Set prunable layers type 50 | model.prunable_layer_type = 'BN' if pruning_strategy == 'network_slimming' else 'default' 51 | 52 | # Freeze various subsets of the network. 53 | bn_names = [] 54 | for k, v in model.named_modules(): 55 | if isinstance(v, torch.nn.BatchNorm2d): 56 | bn_names += [k + '.weight', k + '.bias'] 57 | 58 | if model_hparams.others_frozen_exceptions: 59 | others_exception_names = model_hparams.others_frozen_exceptions.split(',') 60 | for name in others_exception_names: 61 | if name not in model.state_dict(): 62 | raise ValueError(f'Invalid name to except: {name}') 63 | else: 64 | others_exception_names = [] 65 | 66 | for k, v in model.named_parameters(): 67 | if k in bn_names and model_hparams.batchnorm_frozen: 68 | v.requires_grad = False 69 | elif k in model.output_layer_names and model_hparams.output_frozen: 70 | v.requires_grad = False 71 | elif k not in bn_names and k not in model.output_layer_names and model_hparams.others_frozen: 72 | if k in others_exception_names: continue 73 | v.requires_grad = False 74 | 75 | return model 76 | 77 | 78 | def load(save_location: str, save_step: Step, model_hparams, outputs=None, pruning_strategy = None): 79 | state_dict = torch.load(paths.model(save_location, save_step)) 80 | model = get(model_hparams, outputs, pruning_strategy) 81 | model.load_state_dict(state_dict) 82 | return model 83 | 84 | 85 | def exists(save_location, save_step): 86 | return os.path.exists(paths.model(save_location, save_step)) 87 | 88 | 89 | def get_default_hparams(model_name): 90 | """Get the default hyperparameters for a particular model.""" 91 | 92 | for registered_model in registered_models: 93 | if registered_model.is_valid_model_name(model_name): 94 | params = registered_model.default_hparams() 95 | params.model_hparams.model_name = model_name 96 | return params 97 | 98 | raise ValueError('No such model: {}'.format(model_name)) 99 | -------------------------------------------------------------------------------- /models/tinyimagenet_resnet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
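# A minimal usage sketch (hypothetical helper, not part of models/registry.py above).
# It shows how the registry is normally driven: look up the default hyperparameters
# bundled with a model name, then build the network with its default weight and
# BatchNorm initializers. Assumes the repository root is on the Python path;
# 'cifar_vgg_16' stands in for any registered model name.
def _example_build_model_from_registry():
    import models.registry

    desc = models.registry.get_default_hparams('cifar_vgg_16')
    model = models.registry.get(desc.model_hparams, outputs=10)

    # get() also applies the batchnorm/output/others freezing flags carried in
    # ModelHparams, so the returned module is ready for the training runners.
    return model, sum(p.numel() for p in model.parameters())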
2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | 6 | from functools import partial 7 | import torch 8 | import torchvision 9 | 10 | from foundations import hparams 11 | from models import base 12 | from pruning import magnitude 13 | from experiments.finetune.desc import FinetuningDesc 14 | 15 | 16 | class ResNet(torchvision.models.ResNet): 17 | def __init__(self, block, layers, num_classes=1000, width=64): 18 | """To make it possible to vary the width, we need to override the constructor of the torchvision resnet.""" 19 | 20 | torch.nn.Module.__init__(self) # Skip the parent constructor. This replaces it. 21 | self._norm_layer = torch.nn.BatchNorm2d 22 | self.inplanes = width 23 | self.dilation = 1 24 | self.groups = 1 25 | self.base_width = 64 26 | 27 | # The initial convolutional layer. 28 | self.conv1 = torch.nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) 29 | self.bn1 = self._norm_layer(self.inplanes) 30 | self.relu = torch.nn.ReLU(inplace=True) 31 | self.maxpool = (lambda x : x) 32 | # The subsequent blocks. 33 | self.layer1 = self._make_layer(block, width, layers[0]) 34 | self.layer2 = self._make_layer(block, width*2, layers[1], stride=2) 35 | self.layer3 = self._make_layer(block, width*4, layers[2], stride=2) 36 | self.layer4 = self._make_layer(block, width*8, layers[3], stride=2) 37 | 38 | # The last layers. 39 | self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1)) 40 | self.fc = torch.nn.Linear(width*8*block.expansion, num_classes) 41 | 42 | # Default init. 43 | for m in self.modules(): 44 | if isinstance(m, torch.nn.Conv2d): 45 | torch.nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 46 | elif isinstance(m, torch.nn.BatchNorm2d): 47 | torch.nn.init.constant_(m.weight, 1) 48 | torch.nn.init.constant_(m.bias, 0) 49 | 50 | 51 | class Model(base.Model): 52 | """A residual neural network as originally designed for ImageNet.""" 53 | 54 | def __init__(self, model_fn, initializer, outputs=None): 55 | super(Model, self).__init__() 56 | 57 | self.model = model_fn(num_classes=outputs or 1000) 58 | self.criterion = torch.nn.CrossEntropyLoss() 59 | self.apply(initializer) 60 | 61 | def forward(self, x): 62 | return self.model(x) 63 | 64 | @property 65 | def output_layer_names(self): 66 | return ['model.fc.weight', 'model.fc.bias'] 67 | 68 | @staticmethod 69 | def is_valid_model_name(model_name): 70 | valid_start = model_name.startswith('tinyimagenet_resnet_') 71 | valid_length = 4 >= len(model_name.split('_')) >= 3 72 | valid_depth = model_name.split('_')[2].isdigit() and int(model_name.split('_')[2]) in [18, 34, 50, 101, 152, 200] 73 | return valid_start and valid_length and valid_depth 74 | 75 | @staticmethod 76 | def get_model_from_name(model_name, initializer, outputs=1000): 77 | """Name: imagenet_resnet_D[_W]. 78 | 79 | D is the model depth (e.g., 50 for ResNet-50). W is the model width - the number of filters in the first 80 | residual layers. 
By default, this number is 64.""" 81 | 82 | if not Model.is_valid_model_name(model_name): 83 | raise ValueError('Invalid model name: {}'.format(model_name)) 84 | 85 | num = int(model_name.split('_')[2]) 86 | if num == 18: model_fn = partial(ResNet, torchvision.models.resnet.BasicBlock, [2, 2, 2, 2]) 87 | elif num == 34: model_fn = partial(ResNet, torchvision.models.resnet.BasicBlock, [3, 4, 6, 3]) 88 | elif num == 50: model_fn = partial(ResNet, torchvision.models.resnet.Bottleneck, [3, 4, 6, 3]) 89 | elif num == 101: model_fn = partial(ResNet, torchvision.models.resnet.Bottleneck, [3, 4, 23, 3]) 90 | elif num == 152: model_fn = partial(ResNet, torchvision.models.resnet.Bottleneck, [3, 8, 36, 3]) 91 | elif num == 200: model_fn = partial(ResNet, torchvision.models.resnet.Bottleneck, [3, 24, 36, 3]) 92 | elif num == 269: model_fn = partial(ResNet, torchvision.moedls.resnet.Bottleneck, [3, 30, 48, 8]) 93 | 94 | if len(model_name.split('_')) == 4: 95 | width = int(model_name.split('_')[3]) 96 | model_fn = partial(model_fn, width=width) 97 | 98 | return Model(model_fn, initializer, outputs) 99 | 100 | @property 101 | def loss_criterion(self): 102 | return self.criterion 103 | 104 | @staticmethod 105 | def default_hparams(): 106 | """These hyperparameters will reach 76.1% top-1 accuracy on ImageNet and XX.X% top-1 accuracy on TinyImageNet. 107 | 108 | To get these results with a smaller batch size, scale the learning rate linearly. 109 | That is, batch size 512 -> lr 0.2, 256 -> 0.1, etc. 110 | """ 111 | 112 | # Model hyperparameters. 113 | model_hparams = hparams.ModelHparams( 114 | model_name='tinyimagenet_resnet_50', 115 | model_init='kaiming_normal', 116 | batchnorm_init='uniform', 117 | ) 118 | 119 | # Dataset hyperparameters. 120 | dataset_hparams = hparams.DatasetHparams( 121 | dataset_name='tinyimagenet', batch_size=256 122 | ) 123 | 124 | # Training hyperparameters. 125 | training_hparams = hparams.TrainingHparams( 126 | optimizer_name='sgd', 127 | momentum=0.9, 128 | milestone_steps='100ep,150ep', 129 | lr=0.2, 130 | gamma=0.1, 131 | weight_decay=1e-4, 132 | training_steps='200ep', 133 | warmup_steps='5ep', 134 | ) 135 | 136 | # Pruning hyperparameters. 137 | pruning_hparams = magnitude.PruningHparams( 138 | pruning_strategy='magnitude', 139 | pruning_fraction=0.2, 140 | pruning_scope='global', 141 | ) 142 | 143 | finetuning_hparams = hparams.FinetuningHparams( 144 | # optimizer_name='sgd', 145 | # momentum=0.9, 146 | lr=0.0001, 147 | # weight_decay=1e-4, 148 | training_steps='30ep', 149 | ) 150 | 151 | return FinetuningDesc(model_hparams, dataset_hparams, training_hparams, pruning_hparams, finetuning_hparams) 152 | -------------------------------------------------------------------------------- /pruning/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | -------------------------------------------------------------------------------- /pruning/base.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
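# A minimal sketch (hypothetical helper, not part of the original files) of the
# `tinyimagenet_resnet_D[_W]` naming convention implemented by
# models/tinyimagenet_resnet.py above: depth D picks the block layout and the
# optional W overrides the default base width of 64. Applying only the weight
# initializer here is a simplification; models/registry.py combines it with a
# BatchNorm initializer. Assumes the repository root is importable.
def _example_parse_tinyimagenet_resnet_name():
    from models import initializers, tinyimagenet_resnet

    assert tinyimagenet_resnet.Model.is_valid_model_name('tinyimagenet_resnet_18')
    assert not tinyimagenet_resnet.Model.is_valid_model_name('tinyimagenet_resnet_7')

    # ResNet-18 with a halved base width (32 filters in the first stage) and
    # 200 outputs, matching TinyImageNet's 200 classes.
    return tinyimagenet_resnet.Model.get_model_from_name(
        'tinyimagenet_resnet_18_32', initializers.kaiming_normal, outputs=200)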
2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | 6 | import abc 7 | 8 | from foundations.hparams import PruningHparams 9 | from models import base 10 | from pruning.mask import Mask 11 | 12 | 13 | class Strategy(abc.ABC): 14 | @staticmethod 15 | @abc.abstractmethod 16 | def get_pruning_hparams() -> type: 17 | pass 18 | 19 | @staticmethod 20 | @abc.abstractmethod 21 | def prune(pruning_hparams: PruningHparams, trained_model: base.Model, current_mask: Mask = None) -> Mask: 22 | pass 23 | -------------------------------------------------------------------------------- /pruning/gradient.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # https://github.com/JJGO/shrinkbench/blob/master/strategies/magnitude.py 4 | 5 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 6 | # licensed under the MIT license 7 | 8 | import copy 9 | import dataclasses 10 | from foundations.local import Platform 11 | import datasets 12 | from training.train import train 13 | import numpy as np 14 | 15 | from foundations import hparams 16 | import models.base 17 | from pruning import base 18 | from pruning.mask import Mask 19 | 20 | 21 | @dataclasses.dataclass 22 | class PruningHparams(hparams.PruningHparams): 23 | pruning_fraction: float = 0.2 24 | pruning_scope: str = 'global' 25 | pruning_layers_to_ignore: str = None 26 | layers_to_prune: str = None 27 | 28 | _name = 'Hyperparameters for Unstructured Magnitude-Gradient Pruning' 29 | _description = 'Hyperparameters that modify the way pruning occurs.' 30 | _pruning_fraction = 'The fraction of additional weights to prune from the network.' 31 | _pruning_scope = 'A paramter that enables global pruning or layer-wise pruning, choose from global/layer' 32 | _pruning_layers_to_ignore = 'A comma-separated list of addititonal tensors that should not be pruned.' 33 | _layers_to_prune = 'Specify the layers that should be pruned, to prune first/last nth layers in all prunable layers, use first_n / last_n .' 34 | 35 | class Strategy(base.Strategy): 36 | @staticmethod 37 | def get_pruning_hparams() -> type: 38 | return PruningHparams 39 | 40 | @staticmethod 41 | def prune(pruning_hparams: PruningHparams, trained_model: models.base.Model, current_mask: Mask = None, dataset_hparams: hparams.DatasetHparams = None): 42 | 43 | current_mask = Mask.ones_like(trained_model).numpy() if current_mask is None else current_mask.numpy() 44 | 45 | model = copy.deepcopy(trained_model) 46 | 47 | prunable_tensors = set(model.prunable_layer_names) 48 | if pruning_hparams.pruning_layers_to_ignore: 49 | prunable_tensors -= set(pruning_hparams.pruning_layers_to_ignore.split(',')) 50 | 51 | # Get the model weights. 52 | weights = {k: v.clone().cpu().detach().numpy() 53 | for k, v in model.state_dict().items() 54 | if k in prunable_tensors} 55 | 56 | # Compute gradients of parameters 57 | train_loader = datasets.registry.get(dataset_hparams, train=True) 58 | train_loader.shuffle(None) 59 | examples, labels = next(iter(train_loader)) 60 | model.zero_grad() 61 | model.train() 62 | loss = model.loss_criterion(model(examples), labels) 63 | loss.backward() 64 | # Get the model gradients. 
65 | gradients = {k: v.grad.clone().cpu().detach().numpy() 66 | for k, v in model.named_parameters() 67 | if k in prunable_tensors and v.grad is not None} 68 | 69 | if pruning_hparams.pruning_scope == 'global': 70 | 71 | # Determine the number of weights that need to be pruned. 72 | number_of_remaining_weights = np.sum([np.sum(v) for v in current_mask.values()]) 73 | number_of_weights_to_prune = np.ceil( 74 | pruning_hparams.pruning_fraction * number_of_remaining_weights).astype(int) 75 | 76 | # Compute importance scores of all the unpruned weights in the model, which is |weight * gradient|. 77 | importance_scores = {k: np.abs(v * gradients[k]) for k, v in weights.items()} 78 | importance_vector = np.concatenate([v[current_mask[k] == 1] for k, v in importance_scores.items()]) 79 | threshold = np.sort(importance_vector)[number_of_weights_to_prune] 80 | 81 | new_mask = Mask({k: np.where(v > threshold, current_mask[k], np.zeros_like(v)) 82 | for k, v in importance_scores.items()}) 83 | 84 | elif pruning_hparams.pruning_scope == 'layer': 85 | new_mask_dict = {} 86 | for k, v in weights.items(): 87 | # Determine the number of weights that need to be pruned. 88 | number_of_remaining_weights = np.sum(current_mask[k]) 89 | number_of_weights_to_prune = np.ceil( 90 | pruning_hparams.pruning_fraction * number_of_remaining_weights).astype(int) 91 | 92 | # Compute importance scores of all the unpruned weights in the layer, which is |weight * gradient|. 93 | importance_scores = np.abs(v * gradients[k]) 94 | importance_vector = importance_scores[current_mask[k] == 1] 95 | threshold = np.sort(importance_vector)[number_of_weights_to_prune] 96 | 97 | new_mask_dict[k] = np.where(importance_scores > threshold, current_mask[k], np.zeros_like(v)) 98 | new_mask = Mask(new_mask_dict) 99 | else: 100 | raise ValueError('No such pruning scope: {}'.format(pruning_hparams.pruning_scope)) 101 | 102 | for k in current_mask: 103 | if k not in new_mask: 104 | new_mask[k] = current_mask[k] 105 | 106 | return new_mask 107 | 108 | -------------------------------------------------------------------------------- /pruning/magnitude.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | 6 | import dataclasses 7 | import numpy as np 8 | 9 | from foundations import hparams 10 | import models.base 11 | from pruning import base 12 | from pruning.mask import Mask 13 | 14 | 15 | @dataclasses.dataclass 16 | class PruningHparams(hparams.PruningHparams): 17 | pruning_fraction: float = 0.2 18 | pruning_scope: str = 'global' 19 | pruning_layers_to_ignore: str = None 20 | layers_to_prune: str = None 21 | prune_max_magnitude: bool = False 22 | 23 | _name = 'Hyperparameters for Unstructured Magnitude Pruning' 24 | _description = 'Hyperparameters that modify the way pruning occurs.' 25 | _pruning_fraction = 'The fraction of additional weights to prune from the network.' 26 | _pruning_scope = 'A parameter that enables global pruning or layer-wise pruning, choose from global/layer' 27 | _pruning_layers_to_ignore = 'A comma-separated list of additional tensors that should not be pruned.' 28 | _layers_to_prune = 'Specify the layers that should be pruned, to prune first/last nth layers in all prunable layers, use first_n / last_n .'
29 | _prune_max_magnitude = 'An order that control pruner to prune the weights with max magnitude, or min magnitude' 30 | 31 | class Strategy(base.Strategy): 32 | @staticmethod 33 | def get_pruning_hparams() -> type: 34 | return PruningHparams 35 | 36 | @staticmethod 37 | def prune(pruning_hparams: PruningHparams, trained_model: models.base.Model, current_mask: Mask = None, dataset_hparams: hparams.DatasetHparams = None): 38 | current_mask = Mask.ones_like(trained_model).numpy() if current_mask is None else current_mask.numpy() 39 | 40 | # Determine which layers can be pruned. 41 | def get_pruning_layers_sequence(layers_to_prune) -> int: 42 | if len(layers_to_prune.split('_'))==2 and layers_to_prune.split('_')[-1].isdigit() and int(layers_to_prune.split('_')[-1])>0: 43 | return int(layers_to_prune.split('_')[-1]) 44 | else: 45 | raise ValueError('unrecognized pruning hparameters: {}'.format(layers_to_prune)) 46 | 47 | if pruning_hparams.layers_to_prune is not None and pruning_hparams.pruning_scope == 'layer': 48 | if pruning_hparams.layers_to_prune.startswith('first'): 49 | pruning_layers_sequence = get_pruning_layers_sequence(pruning_hparams.layers_to_prune) 50 | prunable_tensors = set(trained_model.prunable_layer_names[:pruning_layers_sequence]) 51 | elif pruning_hparams.layers_to_prune.startswith('last'): 52 | pruning_layers_sequence = get_pruning_layers_sequence(pruning_hparams.layers_to_prune) 53 | prunable_tensors = set(trained_model.prunable_layer_names[-pruning_layers_sequence:]) 54 | else: raise ValueError('unrecognized pruning hparameters: {}'.format(pruning_hparams.layers_to_prune)) 55 | elif pruning_hparams.layers_to_prune is not None and pruning_hparams.pruning_scope != 'layer': 56 | raise ValueError('pruning hparameters: layers_to_prune={} should be associated with pruning_cope=layer'.format(pruning_hparams.layers_to_prune)) 57 | else: 58 | prunable_tensors = set(trained_model.prunable_layer_names) 59 | 60 | if pruning_hparams.pruning_layers_to_ignore: 61 | prunable_tensors -= set(pruning_hparams.pruning_layers_to_ignore.split(',')) 62 | 63 | # Get the model weights. 64 | weights = {k: v.clone().cpu().detach().numpy() 65 | for k, v in trained_model.state_dict().items() 66 | if k in prunable_tensors} 67 | 68 | if pruning_hparams.pruning_scope == 'global': 69 | 70 | # Determine the number of weights that need to be pruned. 71 | number_of_remaining_weights = np.sum([np.sum(v) for v in current_mask.values()]) 72 | number_of_weights_to_prune = np.ceil( 73 | pruning_hparams.pruning_fraction * number_of_remaining_weights).astype(int) 74 | 75 | # Create a vector of all the unpruned weights in the model. 76 | weight_vector = np.concatenate([v[current_mask[k] == 1] for k, v in weights.items()]) 77 | if pruning_hparams.prune_max_magnitude: 78 | abs_weight_vector = np.flip(np.sort(np.abs(weight_vector))) 79 | else: 80 | abs_weight_vector = np.sort(np.abs(weight_vector)) 81 | 82 | threshold = abs_weight_vector[number_of_weights_to_prune] 83 | 84 | new_mask = Mask({k: np.where(np.abs(v) > threshold, current_mask[k], np.zeros_like(v)) 85 | for k, v in weights.items()}) 86 | 87 | elif pruning_hparams.pruning_scope == 'layer': 88 | new_mask_dict = {} 89 | for k, v in weights.items(): 90 | # Determine the number of weights that need to be pruned. 
91 | number_of_remaining_weights = np.sum(current_mask[k]) 92 | number_of_weights_to_prune = np.ceil( 93 | pruning_hparams.pruning_fraction * number_of_remaining_weights).astype(int) 94 | 95 | # Create a vector of all the unpruned weights in the particular layer. 96 | weight_vector = v[current_mask[k] == 1] 97 | if pruning_hparams.prune_max_magnitude: 98 | abs_weight_vector = np.flip(np.sort(np.abs(weight_vector))) 99 | else: 100 | abs_weight_vector = np.sort(np.abs(weight_vector)) 101 | threshold = abs_weight_vector[number_of_weights_to_prune] 102 | 103 | new_mask_dict[k] = np.where(np.abs(v) > threshold, current_mask[k], np.zeros_like(v)) 104 | new_mask = Mask(new_mask_dict) 105 | else: 106 | raise ValueError('No such pruning scope: {}'.format(pruning_hparams.pruning_scope)) 107 | 108 | for k in current_mask: 109 | if k not in new_mask: 110 | new_mask[k] = current_mask[k] 111 | 112 | return new_mask 113 | -------------------------------------------------------------------------------- /pruning/mask.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | 6 | import json 7 | import numpy as np 8 | import torch 9 | import os 10 | from foundations import paths 11 | from models import base 12 | 13 | 14 | class Mask(dict): 15 | def __init__(self, other_dict=None): 16 | super(Mask, self).__init__() 17 | if other_dict is not None: 18 | for k, v in other_dict.items(): self[k] = v 19 | 20 | def __setitem__(self, key, value): 21 | if not isinstance(key, str) or len(key) == 0: 22 | raise ValueError('Invalid tensor name: {}'.format(key)) 23 | if isinstance(value, np.ndarray): 24 | value = torch.as_tensor(value) 25 | if not isinstance(value, torch.Tensor): 26 | raise ValueError('value for key {} must be torch Tensor or numpy ndarray.'.format(key)) 27 | if ((value != 0) & (value != 1)).any(): raise ValueError('All entries must be 0 or 1.') 28 | 29 | super(Mask, self).__setitem__(key, value) 30 | 31 | @staticmethod 32 | def ones_like(model: base.Model) -> 'Mask': 33 | mask = Mask() 34 | for name in model.prunable_layer_names: 35 | mask[name] = torch.ones(list(model.state_dict()[name].shape)) 36 | return mask 37 | 38 | def save(self, output_location): 39 | if not os.path.exists(output_location): os.makedirs(output_location) 40 | torch.save({k: v.cpu().int() for k, v in self.items()}, paths.mask(output_location)) 41 | 42 | # Create a sparsity report. 
43 | total_weights = np.sum([v.size for v in self.numpy().values()]).item() 44 | total_unpruned = np.sum([np.sum(v) for v in self.numpy().values()]).item() 45 | with open(paths.sparsity_report(output_location), 'w') as fp: 46 | fp.write(json.dumps({'total': float(total_weights), 'unpruned': float(total_unpruned)}, indent=4)) 47 | 48 | @staticmethod 49 | def load(output_location): 50 | if not Mask.exists(output_location): 51 | raise ValueError('Mask not found at {}'.format(output_location)) 52 | return Mask(torch.load(paths.mask(output_location))) 53 | 54 | @staticmethod 55 | def exists(output_location): 56 | return os.path.exists(paths.mask(output_location)) 57 | 58 | def numpy(self): 59 | return {k: v.cpu().numpy() for k, v in self.items()} 60 | 61 | @property 62 | def sparsity(self): 63 | """Return the percent of weights that have been pruned as a decimal.""" 64 | 65 | unpruned = torch.sum(torch.tensor([torch.sum(v) for v in self.values()])) 66 | total = torch.sum(torch.tensor([torch.sum(torch.ones_like(v)) for v in self.values()])) 67 | return 1 - unpruned.float() / total.float() 68 | 69 | @property 70 | def density(self): 71 | return 1 - self.sparsity 72 | -------------------------------------------------------------------------------- /pruning/network_slimming.py: -------------------------------------------------------------------------------- 1 | # Learning Efficient Convolutional Networks Through Network Slimming 2 | # reference: 3 | # https://github.com/Eric-mingjie/network-slimming/blob/master/mask-impl/prune_mask.py 4 | 5 | import dataclasses 6 | import numpy as np 7 | import torch 8 | 9 | from foundations import hparams 10 | import models.base 11 | from pruning import base 12 | from pruning.mask import Mask 13 | 14 | 15 | @dataclasses.dataclass 16 | class PruningHparams(hparams.PruningHparams): 17 | pruning_fraction: float = 0.2 18 | pruning_layers_to_ignore: str = None 19 | # pruning_scope: str = 'layer' 20 | 21 | _name = 'Hyperparameters for Structured Network-Slimming Pruning' 22 | _description = 'Hyperparameters that modify the way pruning occurs.' 23 | _pruning_fraction = 'The fraction of additional weights to prune from the network.' 24 | _pruning_layers_to_ignore = 'A comma-separated list of addititonal tensors that should not be pruned.' 25 | # _pruning_scope = 'A paramter that enables global pruning or layer-wise pruning, choose from global/layer' 26 | 27 | class Strategy(base.Strategy): 28 | @staticmethod 29 | def get_pruning_hparams() -> type: 30 | return PruningHparams 31 | 32 | @staticmethod 33 | def prune(pruning_hparams: PruningHparams, trained_model: models.base.Model, current_mask: Mask = None, dataset_hparams: hparams.DatasetHparams = None): 34 | current_mask = Mask.ones_like(trained_model).numpy() if current_mask is None else current_mask.numpy() 35 | 36 | # Determine which layers can be pruned. 37 | prunable_tensors = set(trained_model.prunable_layer_names) 38 | 39 | if pruning_hparams.pruning_layers_to_ignore: 40 | prunable_tensors -= set(pruning_hparams.pruning_layers_to_ignore.split(',')) 41 | 42 | # Get the weights and biases of batch normalization layers. 43 | bn = {k: v.clone().cpu().detach().numpy() for k, v in trained_model.state_dict().items() 44 | if k in prunable_tensors} 45 | 46 | # if pruning_hparams.pruning_scope == 'global': 47 | 48 | # Determine the number of weights that need to be pruned. 
49 | number_of_remaining_weights = np.sum([np.sum(v) for k, v in current_mask.items() 50 | if 'weight' in k]) 51 | number_of_weights_to_prune = np.ceil( 52 | pruning_hparams.pruning_fraction * number_of_remaining_weights).astype(int) 53 | 54 | # Create a vector of all the unpruned weights in the model. 55 | weight_vector = np.concatenate([v[current_mask[k] == 1] for k, v in bn.items() 56 | if 'weight' in k]) 57 | threshold = np.sort(np.abs(weight_vector))[number_of_weights_to_prune] 58 | 59 | weight_mask_dict = {k: np.where(np.abs(v) > threshold, current_mask[k], np.zeros_like(v)) 60 | for k, v in bn.items() if 'weight' in k} 61 | bias_mask_dict = {k: np.where(np.abs(bn[k.replace('bias','weight')]) > threshold, current_mask[k], np.zeros_like(v)) 62 | for k, v in bn.items() if 'bias' in k} 63 | mask_dict = {k: v for d in [weight_mask_dict, bias_mask_dict] for k, v in d.items()} 64 | new_mask = Mask(mask_dict) 65 | 66 | # elif pruning_hparams.pruning_scope == 'layer': 67 | # for k, v in weights.items(): 68 | # # Determine the number of weights that need to be pruned. 69 | # number_of_remaining_weights = np.sum(current_mask[k]) 70 | # number_of_weights_to_prune = np.ceil( 71 | # pruning_hparams.pruning_fraction * number_of_remaining_weights).astype(int) 72 | 73 | # # Create a vector of all the unpruned weights in the particular layer. 74 | # weight_vector = v[current_mask[k] == 1] 75 | # threshold = np.sort(np.abs(weight_vector))[number_of_weights_to_prune] 76 | 77 | # new_mask= Mask({k: np.where(np.abs(v) > threshold, current_mask[k], np.zeros_like(v))}) 78 | 79 | # else: 80 | # raise ValueError('No such pruning scope: {}'.format(pruning_hparams.pruning_scope)) 81 | 82 | for k in current_mask: 83 | if k not in new_mask: 84 | new_mask[k] = current_mask[k] 85 | 86 | return new_mask 87 | -------------------------------------------------------------------------------- /pruning/pruned_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
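# A toy numpy sketch (hypothetical helper, not part of the original files) of the
# thresholding rule used by pruning/network_slimming.py above: BatchNorm scale
# factors from every prunable layer are pooled, gammas at or below the computed
# threshold are zeroed out, and each bias mask mirrors its layer's weight mask.
# The tensors are made up purely for illustration.
def _example_slimming_threshold(fraction=0.2):
    import numpy as np

    bn_weights = {'bn1.weight': np.array([0.9, 0.05, 0.4, 0.01]),
                  'bn2.weight': np.array([0.3, 0.02, 0.7])}
    mask = {k: np.ones_like(v) for k, v in bn_weights.items()}

    remaining = sum(int(np.sum(m)) for m in mask.values())
    to_prune = int(np.ceil(fraction * remaining))
    pooled = np.sort(np.abs(np.concatenate(
        [v[mask[k] == 1] for k, v in bn_weights.items()])))
    threshold = pooled[to_prune]

    weight_masks = {k: np.where(np.abs(v) > threshold, mask[k], 0.0)
                    for k, v in bn_weights.items()}
    bias_masks = {k.replace('weight', 'bias'): m.copy() for k, m in weight_masks.items()}
    return weight_masks, bias_masks, threshold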
2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | 6 | from utils.tensor_utils import shuffle_model_params 7 | import torch 8 | from models.base import Model 9 | from pruning.mask import Mask 10 | 11 | import numpy as np 12 | 13 | 14 | class PrunedModel(Model): 15 | @staticmethod 16 | def to_mask_name(name): 17 | return 'mask_' + name.replace('.', '___') 18 | 19 | def __init__(self, model: Model, mask: Mask, model_for_reset: Model = None, freeze_pruned_weights: str = 'zero'): 20 | if isinstance(model, PrunedModel): raise ValueError('Cannot nest pruned models.') 21 | super(PrunedModel, self).__init__() 22 | self.model = model 23 | self.model_for_reset = shuffle_model_params(model_for_reset, mask, seed=0) if freeze_pruned_weights=='permuted' and \ 24 | model_for_reset is not None else model_for_reset 25 | self.freeze_type = freeze_pruned_weights 26 | 27 | for k in self.model.prunable_layer_names: 28 | if k not in mask: raise ValueError('Missing mask value {}.'.format(k)) 29 | if not np.array_equal(mask[k].shape, np.array(self.model.state_dict()[k].shape)): 30 | raise ValueError('Incorrect mask shape {} for tensor {}.'.format(mask[k].shape, k)) 31 | 32 | for k in mask: 33 | if k not in self.model.prunable_layer_names: 34 | raise ValueError('Key {} found in mask but is not a valid model tensor.'.format(k)) 35 | 36 | for k, v in mask.items(): self.register_buffer(PrunedModel.to_mask_name(k), v.float()) 37 | self._apply_mask() # reset the parameters 38 | 39 | def _apply_mask(self): 40 | for name, param in self.model.named_parameters(): 41 | if hasattr(self, PrunedModel.to_mask_name(name)): 42 | if self.freeze_type == 'zero' or self.model_for_reset is None: 43 | param.data *= getattr(self, PrunedModel.to_mask_name(name)) 44 | elif self.freeze_type == 'init' or self.freeze_type == 'final' or self.freeze_type == 'permuted': 45 | value = self.model_for_reset.state_dict()[name] 46 | mask_reverse = torch.abs(getattr(self, PrunedModel.to_mask_name(name)) - 1) 47 | param.data = param.data * getattr(self, PrunedModel.to_mask_name(name)) + value * mask_reverse 48 | elif self.freeze_type == 'gaussian': 49 | gen = torch.Generator() 50 | gen.manual_seed(0) 51 | value = torch.normal(mean=0.0, std=0.01, size=param.data.shape, generator=gen) 52 | mask_reverse = torch.abs(getattr(self, PrunedModel.to_mask_name(name)) - 1) 53 | param.data = param.data * getattr(self, PrunedModel.to_mask_name(name)) + value * mask_reverse 54 | else: 55 | raise ValueError('Freezing pruned weights as type {} is not supported.'.format(self.freeze_type)) 56 | 57 | def updateBN(self): 58 | # https://github.com/Eric-mingjie/network-slimming/blob/b395dc07521cbc38f741d971a18fe3f6423c9ab1/main.py#L126 59 | if self.model.prunable_layer_type == 'BN': 60 | scale_sparse_rate = 0.0001 61 | for m in self.model.modules(): 62 | if isinstance(m, torch.nn.BatchNorm2d): 63 | m.weight.grad.data.add_(scale_sparse_rate * torch.sign(m.weight.data)) # L1 Norm 64 | 65 | def forward(self, x): 66 | self._apply_mask() 67 | return self.model.forward(x) 68 | 69 | @property 70 | def prunable_layer_type(self): 71 | return self.model.prunable_layer_type 72 | 73 | @property 74 | def prunable_layer_names(self): 75 | return self.model.prunable_layer_names 76 | 77 | @property 78 | def output_layer_names(self): 79 | return self.model.output_layer_names 80 | 81 | @property 82 | def loss_criterion(self): 83 | return self.model.loss_criterion 84 | 85 | 86 | def save(self, save_location, 
save_step): 87 | self.model.save(save_location, save_step) 88 | 89 | @staticmethod 90 | def default_hparams(): raise NotImplementedError() 91 | @staticmethod 92 | def is_valid_model_name(model_name): raise NotImplementedError() 93 | @staticmethod 94 | def get_model_from_name(model_name, outputs, initializer): raise NotImplementedError() 95 | -------------------------------------------------------------------------------- /pruning/random.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | 6 | import dataclasses 7 | import numpy as np 8 | 9 | from foundations import hparams 10 | import models.base 11 | from pruning import base 12 | from pruning.mask import Mask 13 | 14 | 15 | @dataclasses.dataclass 16 | class PruningHparams(hparams.PruningHparams): 17 | pruning_fraction: float = 0.2 18 | pruning_layers_to_ignore: str = None 19 | pruning_scope: str = 'global' 20 | random_mask_seed: int = None 21 | 22 | _name = 'Hyperparameters for Unstructured Random Pruning' 23 | _description = 'Hyperparameters that modify the way pruning occurs.' 24 | _pruning_fraction = 'The fraction of additional weights to prune from the network.' 25 | _pruning_layers_to_ignore = 'A comma-separated list of addititonal tensors that should not be pruned.' 26 | _pruning_scope = 'A paramter that enables global pruning or layer-wise pruning' 27 | _random_mask_seed = 'The random seed for generating a random mask' 28 | 29 | class Strategy(base.Strategy): 30 | @staticmethod 31 | def get_pruning_hparams() -> type: 32 | return PruningHparams 33 | 34 | @staticmethod 35 | def prune(pruning_hparams: PruningHparams, trained_model: models.base.Model, current_mask: Mask = None, dataset_hparams: hparams.DatasetHparams = None): 36 | current_mask = Mask.ones_like(trained_model).numpy() if current_mask is None else current_mask.numpy() 37 | 38 | # Determine which layers can be pruned. 39 | prunable_tensors = set(trained_model.prunable_layer_names) 40 | if pruning_hparams.pruning_layers_to_ignore: 41 | prunable_tensors -= set(pruning_hparams.pruning_layers_to_ignore.split(',')) 42 | 43 | # Get the model weights. 44 | weights = {k: v.clone().cpu().detach().numpy() 45 | for k, v in trained_model.state_dict().items() 46 | if k in prunable_tensors} 47 | 48 | if pruning_hparams.pruning_scope == 'global': 49 | 50 | # Determine the number of weights that need to be pruned. 51 | number_of_remaining_weights = np.sum([np.sum(v) for v in current_mask.values()]) 52 | number_of_weights_to_prune = np.ceil( 53 | pruning_hparams.pruning_fraction * number_of_remaining_weights).astype(int) 54 | # create the random scores of weights 55 | random_scores = {k: np.random.RandomState(pruning_hparams.random_mask_seed).rand(*v.shape) for k, v in weights.items()} 56 | # Create a vector of all the unpruned weights in the model. 
57 | random_vector = np.concatenate([v[current_mask[k] == 1] for k, v in random_scores.items()]) 58 | threshold = np.sort(np.abs(random_vector))[number_of_weights_to_prune] 59 | 60 | new_mask = Mask({k: np.where(np.abs(v) > threshold, current_mask[k], np.zeros_like(v)) 61 | for k, v in random_scores.items()}) 62 | 63 | elif pruning_hparams.pruning_scope == 'layer': 64 | new_mask_dict = {} 65 | # create the random score of weights 66 | np.random.seed(pruning_hparams.random_mask_seed) 67 | random_scores = {k: np.random.RandomState(pruning_hparams.random_mask_seed).rand(*v.shape) for k, v in weights.items()} 68 | for k, v in weights.items(): 69 | # Determine the number of weights that need to be pruned. 70 | number_of_remaining_weights = np.sum(current_mask[k]) 71 | number_of_weights_to_prune = np.ceil( 72 | pruning_hparams.pruning_fraction * number_of_remaining_weights).astype(int) 73 | # Create a vector of all the unpruned weights in the particular layer. 74 | random_vector = random_scores[k][current_mask[k] == 1] 75 | threshold = np.sort(random_vector)[number_of_weights_to_prune] 76 | 77 | new_mask_dict[k] = np.where(random_scores[k] > threshold, current_mask[k], np.zeros_like(v)) 78 | new_mask = Mask(new_mask_dict) 79 | 80 | else: 81 | raise ValueError('No such pruning scope: {}'.format(pruning_hparams.pruning_scope)) 82 | 83 | for k in current_mask: 84 | if k not in new_mask: 85 | new_mask[k] = current_mask[k] 86 | 87 | return new_mask 88 | -------------------------------------------------------------------------------- /pruning/registry.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | 6 | import copy 7 | from functools import partial 8 | 9 | from foundations.hparams import PruningHparams 10 | from pruning import magnitude, network_slimming, random, gradient 11 | 12 | registered_strategies = {'magnitude': magnitude.Strategy, 'random': random.Strategy, 'gradient': gradient.Strategy, 13 | 'network_slimming': network_slimming.Strategy} 14 | 15 | 16 | def get(pruning_hparams: PruningHparams): 17 | """Get the pruning function.""" 18 | 19 | return partial(registered_strategies[pruning_hparams.pruning_strategy].prune, 20 | copy.deepcopy(pruning_hparams)) 21 | 22 | 23 | def get_pruning_hparams(pruning_strategy: str) -> type: 24 | """Get a complete lottery schema as specialized for a particular pruning strategy.""" 25 | 26 | if pruning_strategy not in registered_strategies: 27 | raise ValueError('Pruning strategy {} not found.'.format(pruning_strategy)) 28 | 29 | return registered_strategies[pruning_strategy].get_pruning_hparams() 30 | -------------------------------------------------------------------------------- /pruning/test/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | -------------------------------------------------------------------------------- /pruning/test/test_magnitude.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
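# An end-to-end sketch (hypothetical helper, not part of the original files) of one
# pruning round driven through the registries above. pruning.registry.get returns
# Strategy.prune with the hyperparameters already bound, so calling it on a model
# yields a Mask that PrunedModel then enforces on every forward pass. Assumes the
# repository root is importable and that the desc returned by
# models.registry.get_default_hparams exposes a `pruning_hparams` field alongside
# `model_hparams` (the FinetuningDesc definition is not shown in this listing).
# In the real runners the model would be trained before it is pruned.
def _example_one_pruning_round():
    import models.registry
    import pruning.registry
    from pruning.pruned_model import PrunedModel

    desc = models.registry.get_default_hparams('cifar_vgg_16')
    model = models.registry.get(desc.model_hparams)

    prune_fn = pruning.registry.get(desc.pruning_hparams)  # magnitude pruning by default
    mask = prune_fn(model)                                  # removes ~20% of the prunable weights
    pruned = PrunedModel(model, mask)
    return pruned, float(mask.sparsity)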
2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | 6 | import numpy as np 7 | import sys 8 | from torch.cuda import device, init 9 | sys.path.append("./") 10 | import models.registry 11 | from pruning.magnitude import Strategy 12 | from pruning.magnitude import PruningHparams 13 | from testing import test_case 14 | 15 | 16 | class TestMagnitude(test_case.TestCase): 17 | def setUp(self): 18 | super(TestMagnitude, self).setUp() 19 | self.hparams_global = PruningHparams('magnitude', pruning_fraction=0.2) 20 | self.hparams_layer = PruningHparams('magnitude', pruning_fraction=0.2, pruning_scope='layer') 21 | model_hparams = models.registry.get_default_hparams('cifar_resnet_20').model_hparams 22 | self.model = models.registry.get(model_hparams) 23 | 24 | def test_get_pruning_hparams(self): 25 | self.assertTrue(issubclass(Strategy.get_pruning_hparams(), PruningHparams)) 26 | 27 | def test_globally_prune(self): 28 | 29 | m = Strategy.prune(self.hparams_global, self.model) 30 | 31 | # Check that the mask only contains entries for the prunable layers. 32 | self.assertEqual(set(m.keys()), set(self.model.prunable_layer_names)) 33 | 34 | # Check that the masks are the same sizes as the tensors. 35 | for k in self.model.prunable_layer_names: 36 | self.assertEqual(list(m[k].shape), list(self.model.state_dict()[k].shape)) 37 | 38 | # Check that the right fraction of weights was pruned among prunable layers. 39 | m = m.numpy() 40 | total_pruned = np.sum([np.sum(1 - v) for v in m.values()]) 41 | total_weights = np.sum([v.size for v in m.values()]) 42 | actual_fraction = float(total_pruned) / total_weights 43 | self.assertGreaterEqual(actual_fraction, self.hparams_global.pruning_fraction) 44 | self.assertGreater(self.hparams_global.pruning_fraction + 0.0001, actual_fraction) 45 | 46 | # Ensure that the right threshold was chosen. 47 | pruned_weights = [self.model.state_dict()[k].numpy()[m[k] == 0] for k in m] 48 | threshold = np.max(np.abs(np.concatenate(pruned_weights))) 49 | unpruned_weights = [self.model.state_dict()[k].numpy()[m[k] == 1] for k in m] 50 | self.assertTrue(np.all(np.abs(np.concatenate(unpruned_weights)) > threshold)) 51 | 52 | def test_globally_iterative_pruning(self): 53 | m = Strategy.prune(self.hparams_global, self.model) 54 | m2 = Strategy.prune(self.hparams_global, self.model, m) 55 | 56 | # Ensure that all weights pruned before are still pruned here. 57 | m, m2 = m.numpy(), m2.numpy() 58 | self.assertEqual(set(m.keys()), set(m2.keys())) 59 | for k in m: 60 | self.assertTrue(np.all(m[k] >= m2[k])) 61 | 62 | total_pruned = np.sum([np.sum(1 - v) for v in m2.values()]) 63 | total_weights = np.sum([v.size for v in m2.values()]) 64 | actual_fraction = float(total_pruned) / total_weights 65 | expected_fraction = 1 - (1 - self.hparams_global.pruning_fraction) ** 2 66 | self.assertGreaterEqual(actual_fraction, expected_fraction) 67 | self.assertGreater(expected_fraction + 0.0001, actual_fraction) 68 | 69 | 70 | def test_layer_wise_prune(self): 71 | m = Strategy.prune(self.hparams_layer, self.model) 72 | 73 | # Check that the mask only contains entries for the prunable layers. 74 | self.assertEqual(set(m.keys()), set(self.model.prunable_layer_names)) 75 | 76 | # Check that the masks are the same sizes as the tensors. 
77 | for k in self.model.prunable_layer_names: 78 | self.assertEqual(list(m[k].shape), list(self.model.state_dict()[k].shape)) 79 | 80 | # Check that the right fraction of weights was pruned among each prunable layer. 81 | m = m.numpy() 82 | for k in m: 83 | layer_pruned = np.sum(1 - m[k]) 84 | layer_weights = np.sum(m[k].size) 85 | layer_fraction = float(layer_pruned) / layer_weights 86 | self.assertGreaterEqual(layer_fraction, self.hparams_layer.pruning_fraction) 87 | self.assertGreater(self.hparams_layer.pruning_fraction + 0.1, layer_fraction) 88 | 89 | # Ensure that the right threshold was chosen. 90 | pruned_weights = self.model.state_dict()[k].numpy()[m[k] == 0] 91 | threshold = np.max(np.abs(pruned_weights)) 92 | unpruned_weights = self.model.state_dict()[k].numpy()[m[k] == 1] 93 | self.assertTrue(np.all(np.abs(unpruned_weights) > threshold)) 94 | 95 | # Check that the right fraction of weights was pruned among all prunable layers. 96 | total_pruned = np.sum([np.sum(1 - v) for v in m.values()]) 97 | total_weights = np.sum([v.size for v in m.values()]) 98 | actual_fraction = float(total_pruned) / total_weights 99 | self.assertGreaterEqual(actual_fraction, self.hparams_layer.pruning_fraction) 100 | self.assertGreater(self.hparams_layer.pruning_fraction + 0.001, actual_fraction) 101 | 102 | def test_layer_wise_iterative_pruning(self): 103 | m = Strategy.prune(self.hparams_layer, self.model) 104 | m2 = Strategy.prune(self.hparams_layer, self.model, m) 105 | 106 | # Ensure that all weights pruned before are still pruned here. 107 | m, m2 = m.numpy(), m2.numpy() 108 | self.assertEqual(set(m.keys()), set(m2.keys())) 109 | for k in m: 110 | self.assertTrue(np.all(m[k] >= m2[k])) 111 | 112 | for k in m: 113 | layer_pruned = np.sum(1 - m2[k]) 114 | layer_weights = np.sum(m2[k].size) 115 | layer_fraction = float(layer_pruned) / layer_weights 116 | expected_fraction = 1 - (1 - self.hparams_layer.pruning_fraction) ** 2 117 | self.assertGreaterEqual(layer_fraction, expected_fraction) 118 | # self.assertGreater(expected_fraction + 0.0001, layer_fraction) 119 | total_pruned = np.sum([np.sum(1 - v) for v in m2.values()]) 120 | total_weights = np.sum([v.size for v in m2.values()]) 121 | actual_fraction = float(total_pruned) / total_weights 122 | expected_fraction = 1 - (1 - self.hparams_layer.pruning_fraction) ** 2 123 | self.assertGreaterEqual(actual_fraction, expected_fraction) 124 | self.assertGreater(expected_fraction + 0.001, actual_fraction) 125 | 126 | def test_globally_prune_layers_to_ignore(self): 127 | layers_to_ignore = sorted(self.model.prunable_layer_names)[:5] 128 | self.hparams_global.pruning_layers_to_ignore = ','.join(layers_to_ignore) 129 | 130 | m = Strategy.prune(self.hparams_global, self.model).numpy() 131 | 132 | # Ensure that the ignored layers were, indeed, ignored. 133 | for k in layers_to_ignore: 134 | self.assertTrue(np.all(m[k] == 1)) 135 | 136 | # Ensure that the expected fraction of parameters was still pruned. 
137 | total_pruned = np.sum([np.sum(1 - v) for v in m.values()]) 138 | total_weights = np.sum([v.size for v in m.values()]) 139 | actual_fraction = float(total_pruned) / total_weights 140 | self.assertGreaterEqual(actual_fraction, self.hparams_global.pruning_fraction) 141 | self.assertGreater(self.hparams_global.pruning_fraction + 0.0001, actual_fraction) 142 | 143 | def test_layer_wise_prune_layers_to_ignore(self): 144 | layers_to_ignore = sorted(self.model.prunable_layer_names)[:5] 145 | self.hparams_layer.pruning_layers_to_ignore = ','.join(layers_to_ignore) 146 | 147 | m = Strategy.prune(self.hparams_layer, self.model).numpy() 148 | 149 | # Ensure that the ignored layers were, indeed, ignored. 150 | for k in layers_to_ignore: 151 | self.assertTrue(np.all(m[k] == 1)) 152 | 153 | # Ensure that the expected fraction of parameters was still pruned. 154 | total_pruned = np.sum([np.sum(1 - v) for v in m.values()]) 155 | total_weights = np.sum([v.size for v in m.values()]) 156 | actual_fraction = float(total_pruned) / total_weights 157 | self.assertGreaterEqual(self.hparams_layer.pruning_fraction, actual_fraction ) 158 | 159 | 160 | 161 | 162 | 163 | test_case.main() 164 | -------------------------------------------------------------------------------- /pruning/test/test_mask.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | 6 | import numpy as np 7 | import os 8 | import torch 9 | 10 | from foundations import paths 11 | import models.registry 12 | from pruning.mask import Mask 13 | from testing import test_case 14 | 15 | 16 | class TestMask(test_case.TestCase): 17 | def test_dict_behavior(self): 18 | m = Mask() 19 | self.assertEqual(len(m), 0) 20 | self.assertEqual(len(m.keys()), 0) 21 | self.assertEqual(len(m.values()), 0) 22 | 23 | m['hello'] = np.ones([2, 3]) 24 | m['world'] = np.zeros([5, 6]) 25 | self.assertEqual(len(m), 2) 26 | self.assertEqual(len(m.keys()), 2) 27 | self.assertEqual(len(m.values()), 2) 28 | self.assertEqual(set(m.keys()), set(['hello', 'world'])) 29 | self.assertTrue(np.array_equal(np.ones([2, 3]), m['hello'])) 30 | self.assertTrue(np.array_equal(np.zeros([5, 6]), m['world'])) 31 | 32 | del m['hello'] 33 | self.assertEqual(len(m), 1) 34 | self.assertEqual(len(m.keys()), 1) 35 | self.assertEqual(len(m.values()), 1) 36 | self.assertEqual(set(m.keys()), set(['world'])) 37 | self.assertTrue(np.array_equal(np.zeros([5, 6]), m['world'])) 38 | 39 | def test_create_mask_from_dict(self): 40 | m = Mask({'hello': np.ones([2, 3]), 'world': np.zeros([5, 6])}) 41 | self.assertEqual(len(m), 2) 42 | self.assertEqual(len(m.keys()), 2) 43 | self.assertEqual(len(m.values()), 2) 44 | self.assertEqual(set(m.keys()), set(['hello', 'world'])) 45 | self.assertTrue(np.array_equal(np.ones([2, 3]), m['hello'])) 46 | self.assertTrue(np.array_equal(np.zeros([5, 6]), m['world'])) 47 | 48 | def test_create_from_tensor(self): 49 | m = Mask({'hello': torch.ones([2, 3]), 'world': torch.zeros([5, 6])}) 50 | self.assertEqual(len(m), 2) 51 | self.assertEqual(len(m.keys()), 2) 52 | self.assertEqual(len(m.values()), 2) 53 | self.assertEqual(set(m.keys()), set(['hello', 'world'])) 54 | self.assertTrue(np.array_equal(np.ones([2, 3]), m['hello'])) 55 | self.assertTrue(np.array_equal(np.zeros([5, 6]), m['world'])) 56 | 57 | def test_bad_inputs(self): 58 | m = Mask() 59 | 60 | with 
self.assertRaises(ValueError): 61 | m[''] = np.ones([2, 3]) 62 | 63 | with self.assertRaises(ValueError): 64 | m[6] = np.ones([2, 3]) 65 | 66 | with self.assertRaises(ValueError): 67 | m['hello'] = [[0, 1], [1, 0]] 68 | 69 | with self.assertRaises(ValueError): 70 | m['hello'] = np.array([[0, 1], [2, 0]]) 71 | 72 | def test_ones_like(self): 73 | model = models.registry.get(models.registry.get_default_hparams('cifar_resnet_20').model_hparams) 74 | m = Mask.ones_like(model) 75 | 76 | for k, v in model.state_dict().items(): 77 | if k in model.prunable_layer_names: 78 | self.assertIn(k, m) 79 | self.assertEqual(list(m[k].shape), list(v.shape)) 80 | self.assertTrue((m[k] == 1).all()) 81 | else: 82 | self.assertNotIn(k, m) 83 | 84 | def test_save_load_exists(self): 85 | self.assertFalse(Mask.exists(self.root)) 86 | self.assertFalse(os.path.exists(paths.mask(self.root))) 87 | 88 | m = Mask({'hello': np.ones([2, 3]), 'world': np.zeros([5, 6])}) 89 | m.save(self.root) 90 | self.assertTrue(os.path.exists(paths.mask(self.root))) 91 | self.assertTrue(Mask.exists(self.root)) 92 | 93 | m2 = Mask.load(self.root) 94 | self.assertEqual(len(m2), 2) 95 | self.assertEqual(len(m2.keys()), 2) 96 | self.assertEqual(len(m2.values()), 2) 97 | self.assertEqual(set(m2.keys()), set(['hello', 'world'])) 98 | self.assertTrue(np.array_equal(np.ones([2, 3]), m2['hello'])) 99 | self.assertTrue(np.array_equal(np.zeros([5, 6]), m2['world'])) 100 | 101 | 102 | test_case.main() 103 | -------------------------------------------------------------------------------- /show_result.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from sys import platform 3 | import numpy as np 4 | import os 5 | from foundations import local 6 | 7 | def read_logger(filename): 8 | 9 | with open(filename, 'r+') as f: 10 | 11 | train_acc = [] 12 | train_loss = [] 13 | train_iter = [] 14 | test_acc = [] 15 | test_loss = [] 16 | test_iter =[] 17 | 18 | while True: 19 | line = f.readline() 20 | if not line: 21 | break 22 | item = line.split(',') 23 | if item[0] == 'train_loss': 24 | train_loss.append(float(item[2])) 25 | train_iter.append(item[1]) 26 | elif item[0] == 'train_accuracy': 27 | train_acc.append(float(item[2])) 28 | elif item[0] == 'test_loss': 29 | test_loss.append(float(item[2])) 30 | test_iter.append(int(item[1])) 31 | elif item[0] == 'test_accuracy': 32 | test_acc.append(float(item[2])) 33 | 34 | train_loss = np.array(train_loss) 35 | train_acc = np.array(train_acc) 36 | test_loss = np.array(test_loss) 37 | test_acc = np.array(test_acc) 38 | train_iter = np.array(train_iter) 39 | test_iter = np.array(test_iter) 40 | 41 | return train_iter, train_loss, train_acc, test_iter, test_loss, test_acc 42 | 43 | 44 | def level_continuous_reader(filepath, file='main'): 45 | replicate = [i for i in os.listdir(filepath) if '.' not in i] 46 | replicate = ['replicate_1'] 47 | last_loss, last_acc, best_acc, train_acc_last= [],[],[], [] 48 | for num in replicate: 49 | replicate_path = os.path.join(filepath, num) 50 | level = os.listdir(replicate_path) 51 | _level = [i for i in level if 'pretrain' not in i and '.' 
not in i] 52 | _level.sort(key = lambda x: int(x[6:])) 53 | l = 0 54 | for dir in _level: 55 | level_path = os.path.join(replicate_path, dir) 56 | filename = os.path.join(level_path, file) +'/logger' 57 | 58 | if os.path.exists(filename): 59 | train_iter, train_loss, train_acc, test_iter, test_loss, test_acc = read_logger(filename) 60 | if l == 0: 61 | loss = test_loss 62 | acc = test_acc 63 | iter = test_iter 64 | else: 65 | loss =np.append(loss, test_loss) 66 | acc = np.append(acc, test_acc) 67 | iter = np.append(iter, test_iter+1+iter[-1]) 68 | last_loss.append(test_loss[-1]) 69 | last_acc.append(test_acc[-1]) 70 | best_acc.append(np.max(test_acc)) 71 | if len(train_acc) != 0: 72 | train_acc_last.append(train_acc[-1]) 73 | else: 74 | if l == 0: 75 | filename_main = os.path.join(level_path, 'main') +'/logger' 76 | train_iter, train_loss, train_acc, test_iter, test_loss, test_acc = read_logger(filename_main) 77 | loss = test_loss 78 | acc = test_acc 79 | iter = test_iter 80 | 81 | l += 1 82 | loss = loss.reshape(len(replicate), -1) 83 | acc = acc.reshape(len(replicate), -1) 84 | iter = iter.reshape(len(replicate), -1)[0] 85 | loss = np.mean(loss, 0) 86 | acc = np.mean(acc, 0) 87 | 88 | last_loss = np.array(last_loss).reshape(len(replicate), -1) 89 | last_acc = np.array(last_acc).reshape(len(replicate), -1) 90 | best_acc = np.array(best_acc).reshape(len(replicate), -1) 91 | last_loss = np.mean(last_loss, 0) 92 | last_acc = np.mean(last_acc, 0) 93 | best_acc = np.mean(best_acc, 0) 94 | if len(train_acc_last) != 0: 95 | train_acc_last = np.array(train_acc_last).reshape(len(replicate), -1) 96 | train_acc_last = np.mean(train_acc_last, 0) 97 | print('\n') 98 | for i in range(last_acc.size): 99 | print('%.2f'%(last_acc[i]*100), end=' ') 100 | print('') 101 | for i in range(best_acc.size): 102 | print('%.2f'%(best_acc[i]*100), end=' ') 103 | print('') 104 | if len(train_acc_last) != 0: 105 | for i in range(train_acc_last.size): 106 | print('%.2f'%(train_acc_last[i]*100), end=' ') 107 | print('\n') 108 | 109 | 110 | def single_level_reader(filepath, file='main'): 111 | replicate = [i for i in os.listdir(filepath) if '.' 
not in i] 112 | replicate = ['replicate_1'] 113 | last_loss, last_acc, best_acc, train_acc_last= [],[],[], [] 114 | for num in replicate: 115 | replicate_path = os.path.join(filepath, num) 116 | filename = os.path.join(replicate_path, file) +'/logger' 117 | 118 | if os.path.exists(filename): 119 | train_iter, train_loss, train_acc, test_iter, test_loss, test_acc = read_logger(filename) 120 | 121 | test_loss = test_loss.reshape(len(replicate), -1) 122 | train_loss = train_loss.reshape(len(replicate), -1) 123 | test_acc = test_acc.reshape(len(replicate), -1) 124 | train_acc = train_acc.reshape(len(replicate), -1) 125 | # iter = iter.reshape(len(replicate), -1)[0] 126 | test_loss = np.mean(test_loss, 0) 127 | train_loss = np.mean(train_loss, 0) 128 | test_acc = np.mean(test_acc, 0) 129 | train_acc = np.mean(train_acc, 0) 130 | 131 | print('last train loss %.2f'%(train_loss[-1]*100)) 132 | print('last train acc %.2f'%(train_acc[-1]*100)) 133 | print('last test acc %.2f'%(test_acc[-1]*100)) 134 | print('best test acc %.2f'%(np.max(test_acc)*100)) 135 | 136 | 137 | if __name__ == '__main__': 138 | parser = argparse.ArgumentParser() 139 | parser.add_argument('--name', type = str, 140 | help='The name of file.') 141 | args = parser.parse_args() 142 | platform = local.Platform() 143 | file_path = os.path.join(platform.root,args.name) 144 | if 'train' in args.name: 145 | file = 'main' 146 | single_level_reader(file_path, ) 147 | else: 148 | file = 'main' 149 | level_continuous_reader(file_path, file) -------------------------------------------------------------------------------- /testing/test_case.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | 6 | import numpy as np 7 | import os 8 | import unittest 9 | import shutil 10 | from foundations.local import Platform 11 | 12 | class Platform(Platform): 13 | @property 14 | def device_str(self): 15 | return 'cpu' 16 | 17 | @property 18 | def is_parallel(self): 19 | return False 20 | 21 | @property 22 | def root(self): 23 | return '/data/hezheng/pruning-robustness/TESTING' 24 | 25 | 26 | class TestCase(unittest.TestCase): 27 | def setUp(self): 28 | platform = Platform() 29 | self.root = platform.root 30 | 31 | # def tearDown(self): 32 | # if os.path.exists(self.root): shutil.rmtree(self.root) 33 | # platforms.platform._PLATFORM = self.saved_platform 34 | 35 | @staticmethod 36 | def get_state(model): 37 | """Get a copy of the state of a model.""" 38 | 39 | return {k: v.clone().detach().cpu().numpy() for k, v in model.state_dict().items()} 40 | 41 | def assertStateEqual(self, state1, state2): 42 | """Assert that two models states are equal.""" 43 | 44 | self.assertEqual(set(state1.keys()), set(state2.keys())) 45 | for k in state1: 46 | self.assertTrue(np.array_equal(state1[k], state2[k])) 47 | 48 | def assertStateAllNotEqual(self, state1, state2): 49 | """Assert that two models states are not equal in any tensor.""" 50 | 51 | self.assertEqual(set(state1.keys()), set(state2.keys())) 52 | for k in state1: 53 | self.assertFalse(np.array_equal(state1[k], state2[k])) 54 | 55 | 56 | def main(): 57 | if __name__ == '__main__': 58 | unittest.main() 59 | -------------------------------------------------------------------------------- /testing/toy_model.py: -------------------------------------------------------------------------------- 1 | # 
Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | 6 | import torch 7 | 8 | from models.base import Model 9 | 10 | 11 | class InnerProductModel(Model): 12 | @staticmethod 13 | def default_hparams(): raise NotImplementedError 14 | 15 | @staticmethod 16 | def is_valid_model_name(model_name): raise NotImplementedError 17 | 18 | @staticmethod 19 | def get_model_from_name(model_name): raise NotImplementedError 20 | 21 | @property 22 | def output_layer_names(self): raise NotImplementedError 23 | 24 | @property 25 | def loss_criterion(self): return torch.nn.MSELoss() 26 | 27 | def __init__(self, n): 28 | super(Model, self).__init__() 29 | self.layer = torch.nn.Linear(n, 1, bias=False) 30 | self.layer.weight.data = torch.arange(n, dtype=torch.float32) 31 | 32 | def forward(self, x): 33 | return self.layer(x) 34 | -------------------------------------------------------------------------------- /training/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | -------------------------------------------------------------------------------- /training/checkpointing.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | 6 | from foundations.local import Platform 7 | import torch 8 | import os 9 | from foundations import paths 10 | from foundations.step import Step 11 | from training.metric_logger import MetricLogger 12 | 13 | 14 | def save_checkpoint_callback(output_location, step, model, optimizer, logger): 15 | torch.save({ 16 | 'ep': step.ep, 17 | 'it': step.it, 18 | 'model_state_dict': model.state_dict(), 19 | 'optimizer_state_dict': optimizer.state_dict(), 20 | 'logger': str(logger), 21 | }, paths.checkpoint(output_location)) 22 | 23 | def restore_checkpoint(output_location, model, optimizer, iterations_per_epoch): 24 | checkpoint_location = paths.checkpoint(output_location) 25 | if not os.path.exists(checkpoint_location): 26 | return None, None 27 | checkpoint = torch.load(checkpoint_location, map_location=torch.device('cpu')) 28 | 29 | # Handle DataParallel. 30 | module_in_name = Platform().is_parallel 31 | if module_in_name and not all(k.startswith('module.') for k in checkpoint['model_state_dict']): 32 | checkpoint['model_state_dict'] = {'module.' 
+ k: v for k, v in checkpoint['model_state_dict'].items()} 33 | elif all(k.startswith('module.') for k in checkpoint['model_state_dict']) and not module_in_name: 34 | checkpoint['model_state_dict'] = {k[len('module.'):]: v for k, v in checkpoint['model_state_dict'].items()} 35 | 36 | model.load_state_dict(checkpoint['model_state_dict']) 37 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 38 | step = Step.from_epoch(checkpoint['ep'], checkpoint['it'], iterations_per_epoch) 39 | logger = MetricLogger.create_from_string(checkpoint['logger']) 40 | 41 | return step, logger 42 | -------------------------------------------------------------------------------- /training/desc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | 6 | import argparse 7 | from sys import platform 8 | from dataclasses import dataclass 9 | import os 10 | 11 | from datasets import registry as datasets_registry 12 | from foundations import desc 13 | from foundations import hparams 14 | from foundations.step import Step 15 | # from lottery.desc import LotteryDesc 16 | from foundations.local import Platform 17 | 18 | @dataclass 19 | class TrainingDesc(desc.Desc): 20 | """The hyperparameters necessary to describe a training run.""" 21 | 22 | model_hparams: hparams.ModelHparams 23 | dataset_hparams: hparams.DatasetHparams 24 | training_hparams: hparams.TrainingHparams 25 | 26 | @staticmethod 27 | def name_prefix(): return 'train' 28 | 29 | @staticmethod 30 | def add_args(parser: argparse.ArgumentParser, defaults: 'TrainingDesc' = None): 31 | hparams.DatasetHparams.add_args(parser, defaults=defaults.dataset_hparams if defaults else None) 32 | hparams.ModelHparams.add_args(parser, defaults=defaults.model_hparams if defaults else None) 33 | hparams.TrainingHparams.add_args(parser, defaults=defaults.training_hparams if defaults else None) 34 | 35 | @staticmethod 36 | def create_from_args(args: argparse.Namespace) -> 'TrainingDesc': 37 | dataset_hparams = hparams.DatasetHparams.create_from_args(args) 38 | model_hparams = hparams.ModelHparams.create_from_args(args) 39 | training_hparams = hparams.TrainingHparams.create_from_args(args) 40 | return TrainingDesc(model_hparams, dataset_hparams, training_hparams) 41 | 42 | @property 43 | def end_step(self): 44 | iterations_per_epoch = datasets_registry.iterations_per_epoch(self.dataset_hparams) 45 | return Step.from_str(self.training_hparams.training_steps, iterations_per_epoch) 46 | 47 | @property 48 | def train_outputs(self): 49 | return datasets_registry.num_classes(self.dataset_hparams) 50 | 51 | def run_path(self, replicate, experiment='main'): 52 | return os.path.join(Platform().root, self.hashname, f'replicate_{replicate}', experiment) 53 | 54 | @property 55 | def display(self): 56 | return '\n'.join([self.dataset_hparams.display, self.model_hparams.display, self.training_hparams.display]) 57 | -------------------------------------------------------------------------------- /training/metric_logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | import os 6 | from foundations import paths 7 | from foundations.step import Step 8 | 9 | class MetricLogger: 10 | def __init__(self): 11 | self.log = {} 12 | 13 | def add(self, name: str, step: Step, value: float): 14 | self.log[(name, step.iteration)] = value 15 | 16 | def __str__(self): 17 | return '\n'.join(['{},{},{}'.format(k[0], k[1], v) for k, v in self.log.items()]) 18 | 19 | @staticmethod 20 | def create_from_string(as_str): 21 | logger = MetricLogger() 22 | if len(as_str.strip()) == 0: 23 | return logger 24 | 25 | rows = [row.split(',') for row in as_str.strip().split('\n')] 26 | logger.log = {(name, int(iteration)): float(value) for name, iteration, value in rows} 27 | return logger 28 | 29 | @staticmethod 30 | def create_from_file(filename): 31 | with open(paths.logger(filename)) as fp: 32 | as_str = fp.read() 33 | return MetricLogger.create_from_string(as_str) 34 | 35 | def save(self, location): 36 | if not os.path.exists(location): 37 | os.makedirs(location) 38 | with open(paths.logger(location), 'w') as fp: 39 | fp.write(str(self)) 40 | 41 | def get_data(self, desired_name): 42 | d = {k[1]: v for k, v in self.log.items() if k[0] == desired_name} 43 | return [(k, d[k]) for k in sorted(d.keys())] 44 | -------------------------------------------------------------------------------- /training/optimizers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | 6 | import bisect 7 | import numpy as np 8 | import torch 9 | 10 | from foundations.hparams import TrainingHparams 11 | from foundations.step import Step 12 | from models.base import Model 13 | 14 | 15 | def get_optimizer(training_hparams: TrainingHparams, model: Model) -> torch.optim.Optimizer: 16 | if training_hparams.optimizer_name == 'sgd': 17 | return torch.optim.SGD( 18 | model.parameters(), 19 | lr=training_hparams.lr, 20 | momentum=training_hparams.momentum or training_hparams.nesterov_momentum or 0, 21 | weight_decay=training_hparams.weight_decay or 0, 22 | nesterov=training_hparams.nesterov_momentum is not None and training_hparams.nesterov_momentum > 0 23 | ) 24 | elif training_hparams.optimizer_name == 'adam': 25 | return torch.optim.Adam( 26 | model.parameters(), 27 | lr=training_hparams.lr, 28 | weight_decay=training_hparams.weight_decay or 0 29 | ) 30 | 31 | raise ValueError('No such optimizer: {}'.format(training_hparams.optimizer_name)) 32 | 33 | 34 | def get_lr_schedule(training_hparams: TrainingHparams, optimizer: torch.optim.Optimizer, iterations_per_epoch: int): 35 | lambdas = [lambda it: 1.0] 36 | 37 | # Drop the learning rate according to gamma at the specified milestones. 38 | if bool(training_hparams.gamma) != bool(training_hparams.milestone_steps): 39 | raise ValueError('milestones and gamma hyperparameters must both be set or not at all.') 40 | if training_hparams.milestone_steps: 41 | milestones = [Step.from_str(x, iterations_per_epoch).iteration 42 | for x in training_hparams.milestone_steps.split(',')] 43 | lambdas.append(lambda it: training_hparams.gamma ** bisect.bisect(milestones, it)) 44 | 45 | # Add linear learning rate warmup if specified. 
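# The factors in `lambdas` compose multiplicatively (see the LambdaLR constructed at the end of this function), so the effective learning rate at iteration `it` is lr * gamma^(number of milestones already passed) * min(1, it / warmup_iters) while warming up, and lr * gamma^(number of milestones already passed) afterwards.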
46 | if training_hparams.warmup_steps: 47 | warmup_iters = Step.from_str(training_hparams.warmup_steps, iterations_per_epoch).iteration 48 | lambdas.append(lambda it: min(1.0, it / warmup_iters)) 49 | 50 | # Combine the lambdas. 51 | return torch.optim.lr_scheduler.LambdaLR(optimizer, lambda it: np.product([l(it) for l in lambdas])) 52 | 53 | def distance_reguralization(model, init_state): 54 | distance = 0 55 | for name, param in model.model.state_dict().items(): 56 | if 'weight' in name or 'bias' in name: 57 | distance += torch.norm(param-init_state[name], p='fro')**2 58 | 59 | return distance 60 | -------------------------------------------------------------------------------- /training/runner.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | 6 | import argparse 7 | from dataclasses import dataclass 8 | 9 | from utils import shared_args 10 | from foundations.runner import Runner 11 | import models.registry 12 | from training import train 13 | from training.desc import TrainingDesc 14 | 15 | 16 | @dataclass 17 | class TrainingRunner(Runner): 18 | replicate: int 19 | desc: TrainingDesc 20 | verbose: bool = True 21 | evaluate_every_epoch: bool = True 22 | 23 | @staticmethod 24 | def description(): 25 | return "Train a model." 26 | 27 | @staticmethod 28 | def add_args(parser: argparse.ArgumentParser) -> None: 29 | shared_args.JobArgs.add_args(parser) 30 | TrainingDesc.add_args(parser, shared_args.maybe_get_default_hparams()) 31 | 32 | @staticmethod 33 | def create_from_args(args: argparse.Namespace) -> 'TrainingRunner': 34 | return TrainingRunner(args.replicate, TrainingDesc.create_from_args(args), 35 | not args.quiet, not args.evaluate_only_at_end) 36 | 37 | def display_output_location(self): 38 | print(self.desc.run_path(self.replicate)) 39 | 40 | def run(self): 41 | if self.verbose: 42 | print('='*82 + f'\nTraining a Model (Replicate {self.replicate})\n' + '-'*82) 43 | print(self.desc.display) 44 | print(f'Output Location: {self.desc.run_path(self.replicate)}' + '\n' + '='*82 + '\n') 45 | self.desc.save(self.desc.run_path(self.replicate)) 46 | train.standard_train( 47 | models.registry.get(self.desc.model_hparams, outputs=self.desc.train_outputs), self.desc.run_path(self.replicate), 48 | self.desc.dataset_hparams, self.desc.training_hparams, verbose=self.verbose, evaluate_every_epoch=self.evaluate_every_epoch) 49 | -------------------------------------------------------------------------------- /training/standard_callbacks.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | 6 | import time 7 | import torch 8 | 9 | from datasets.base import DataLoader 10 | from foundations import hparams 11 | from foundations.step import Step 12 | from training import checkpointing 13 | from foundations.local import Platform 14 | 15 | # Standard callbacks. 
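# Every callback in this file shares the signature callback(output_location, step, model, optimizer, logger), so saving, checkpointing, timing and evaluation can be handled uniformly and wrapped with the frequency helpers (run_every_epoch, run_every_step, run_at_step) defined at the bottom of this file.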
16 | def save_model(output_location, step, model, optimizer, logger): 17 | model.save(output_location, step) 18 | 19 | 20 | def save_logger(output_location, step, model, optimizer, logger): 21 | logger.save(output_location) 22 | 23 | 24 | def create_timekeeper_callback(): 25 | time_of_last_call = None 26 | 27 | def callback(output_location, step, model, optimizer, logger): 28 | nonlocal time_of_last_call 29 | t = 0.0 if time_of_last_call is None else time.time() - time_of_last_call 30 | print(f'Ep {step.ep}\tIt {step.it}\tTime Elapsed {t:.2f}') 31 | time_of_last_call = time.time() 32 | 33 | return callback 34 | 35 | 36 | def create_eval_callback(eval_name: str, loader: DataLoader, verbose=False): 37 | """This function returns a callback.""" 38 | 39 | time_of_last_call = None 40 | 41 | def eval_callback(output_location, step, model, optimizer, logger): 42 | example_count = torch.tensor(0.0).to(Platform().torch_device) 43 | total_loss = torch.tensor(0.0).to(Platform().torch_device) 44 | total_correct = torch.tensor(0.0).to(Platform().torch_device) 45 | 46 | def correct(labels, output): 47 | return torch.sum(torch.eq(labels, output.argmax(dim=1))) 48 | 49 | model.eval() 50 | 51 | with torch.no_grad(): 52 | for examples, labels in loader: 53 | examples = examples.to(Platform().torch_device) 54 | labels = labels.squeeze().to(Platform().torch_device) 55 | output = model(examples) 56 | 57 | labels_size = torch.tensor(len(labels), device=Platform().torch_device) 58 | example_count += labels_size 59 | total_loss += model.loss_criterion(output, labels) * labels_size 60 | total_correct += correct(labels, output) 61 | 62 | total_loss = total_loss.cpu().item() 63 | total_correct = total_correct.cpu().item() 64 | example_count = example_count.cpu().item() 65 | 66 | logger.add('{}_loss'.format(eval_name), step, total_loss / example_count) 67 | logger.add('{}_accuracy'.format(eval_name), step, total_correct / example_count) 68 | logger.add('{}_examples'.format(eval_name), step, example_count) 69 | 70 | if verbose: 71 | nonlocal time_of_last_call 72 | elapsed = 0 if time_of_last_call is None else time.time() - time_of_last_call 73 | print('{}\tep {:03d}\tit {:03d}\tloss {:.3f}\tacc {:.2f}%\tex {:d}\ttime {:.2f}s'.format( 74 | eval_name, step.ep, step.it, total_loss/example_count, 100 * total_correct/example_count, 75 | int(example_count), elapsed)) 76 | time_of_last_call = time.time() 77 | 78 | return eval_callback 79 | 80 | 81 | # Callback frequencies. Each takes a callback as an argument and returns a new callback 82 | # that runs only at the specified frequency. 83 | def run_every_epoch(callback): 84 | def modified_callback(output_location, step, model, optimizer, logger): 85 | if step.it != 0: 86 | return 87 | callback(output_location, step, model, optimizer, logger) 88 | return modified_callback 89 | 90 | 91 | def run_every_step(callback): 92 | return callback 93 | 94 | 95 | def run_at_step(step1, callback): 96 | def modified_callback(output_location, step, model, optimizer, logger): 97 | if step != step1: 98 | return 99 | callback(output_location, step, model, optimizer, logger) 100 | return modified_callback 101 | 102 | 103 | # The standard set of callbacks that should be used for a normal training run. 
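# standard_callbacks() assembles the list of callbacks that train.standard_train presumably invokes at each step: the model is saved at the start and end of training, the logger at the end, a checkpoint is written every epoch, and the test/train evaluation callbacks are prepended so that, when enabled, metrics are logged before the checkpoint and save callbacks fire for the same step.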
104 | def standard_callbacks(training_hparams: hparams.TrainingHparams, train_set_loader: DataLoader, 105 | test_set_loader: DataLoader, eval_on_train: bool = True, verbose: bool = True, 106 | start_step: Step = None, evaluate_every_epoch: bool = True): 107 | start = start_step or Step.zero(train_set_loader.iterations_per_epoch) 108 | end = Step.from_str(training_hparams.training_steps, train_set_loader.iterations_per_epoch) 109 | test_eval_callback = create_eval_callback('test', test_set_loader, verbose=verbose) 110 | train_eval_callback = create_eval_callback('train', train_set_loader, verbose=verbose) 111 | 112 | # Basic checkpointing and state saving at the beginning and end. 113 | result = [ 114 | run_at_step(start, save_model), 115 | run_at_step(end, save_model), 116 | run_at_step(end, save_logger), 117 | run_every_epoch(checkpointing.save_checkpoint_callback), 118 | ] 119 | 120 | # Test every epoch if requested. 121 | if evaluate_every_epoch: result = [run_every_epoch(test_eval_callback)] + result 122 | elif verbose: result.append(run_every_epoch(create_timekeeper_callback())) 123 | 124 | # Ensure that testing occurs at least at the beginning and end of training. 125 | if start.it != 0 or not evaluate_every_epoch: result = [run_at_step(start, test_eval_callback)] + result 126 | if end.it != 0 or not evaluate_every_epoch: result = [run_at_step(end, test_eval_callback)] + result 127 | 128 | # Do the same for the train set if requested. 129 | if eval_on_train: 130 | if evaluate_every_epoch: result = [run_every_epoch(train_eval_callback)] + result 131 | if start.it != 0 or not evaluate_every_epoch: result = [run_at_step(start, train_eval_callback)] + result 132 | if end.it != 0 or not evaluate_every_epoch: result = [run_at_step(end, train_eval_callback)] + result 133 | 134 | return result 135 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | -------------------------------------------------------------------------------- /utils/arg_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | 6 | import argparse 7 | 8 | 9 | def maybe_get_arg(arg_name, positional=False, position=0, boolean_arg=False): 10 | parser = argparse.ArgumentParser(add_help=False) 11 | prefix = '' if positional else '--' 12 | if positional: 13 | for i in range(position): 14 | parser.add_argument(f'arg{i}') 15 | if boolean_arg: 16 | parser.add_argument(prefix + arg_name, action='store_true') 17 | else: 18 | parser.add_argument(prefix + arg_name, type=str, default=None) 19 | try: 20 | args = parser.parse_known_args()[0] 21 | return getattr(args, arg_name) if arg_name in args else None 22 | except: 23 | return None 24 | -------------------------------------------------------------------------------- /utils/shared_args.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | 6 | from dataclasses import dataclass 7 | 8 | from utils import arg_utils 9 | from foundations.hparams import Hparams 10 | import models.registry 11 | 12 | 13 | @dataclass 14 | class JobArgs(Hparams): 15 | """Arguments shared across lottery ticket jobs.""" 16 | 17 | replicate: int = 1 18 | default_hparams: str = None 19 | quiet: bool = False 20 | evaluate_only_at_end: bool = False 21 | 22 | _name: str = 'High-Level Arguments' 23 | _description: str = 'Arguments that determine how the job is run and where it is stored.' 24 | _replicate: str = 'The index of this particular replicate. ' \ 25 | 'Use a different replicate number to run another copy of the same experiment.' 26 | _default_hparams: str = 'Populate all arguments with the default hyperparameters for this model.' 27 | _quiet: str = 'Suppress output logging about the training status.' 28 | _evaluate_only_at_end: str = 'Run the test set only before and after training. Otherwise, will run every epoch.' 29 | 30 | 31 | def maybe_get_default_hparams(): 32 | default_hparams = arg_utils.maybe_get_arg('default_hparams') 33 | return models.registry.get_default_hparams(default_hparams) if default_hparams else None 34 | -------------------------------------------------------------------------------- /utils/tensor_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This code is from OpenLTH repository https://github.com/facebookresearch/open_lth 4 | # licensed under the MIT license 5 | 6 | import random 7 | from pruning.mask import Mask 8 | from models.base import Model 9 | import torch 10 | 11 | import typing 12 | 13 | 14 | def vectorize(state_dict: typing.Dict[str, torch.Tensor]): 15 | """Convert a state dict into a single column Tensor in a repeatable way.""" 16 | 17 | return torch.cat([state_dict[k].reshape(-1) for k in sorted(state_dict.keys())]) 18 | 19 | 20 | def unvectorize(vector: torch.Tensor, reference_state_dict: typing.Dict[str, torch.Tensor]): 21 | """Convert a vector back into a state dict with the same shapes as reference state_dict.""" 22 | 23 | if len(vector.shape) > 1: raise ValueError('vector has more than one dimension.') 24 | 25 | state_dict = {} 26 | for k in sorted(reference_state_dict.keys()): 27 | if vector.nelement() == 0: raise ValueError('Ran out of values.') 28 | 29 | size, shape = reference_state_dict[k].nelement(), reference_state_dict[k].shape 30 | this, vector = vector[:size], vector[size:] 31 | state_dict[k] = this.reshape(shape) 32 | 33 | if vector.nelement() > 0: raise ValueError('Excess values.') 34 | return state_dict 35 | 36 | 37 | def perm(N, seed: int = None): 38 | """Generate a tensor with the numbers 0 through N-1 ordered randomly.""" 39 | 40 | gen = torch.Generator() 41 | if seed is not None: gen.manual_seed(seed) 42 | perm = torch.normal(torch.zeros(N), torch.ones(N), generator=gen) 43 | return torch.argsort(perm) 44 | 45 | 46 | def shuffle_tensor(tensor: torch.Tensor, seed: int = None): 47 | """Randomly shuffle the elements of a tensor.""" 48 | 49 | shape = tensor.shape 50 | return tensor.reshape(-1)[perm(tensor.nelement(), seed=seed)].reshape(shape) 51 | 52 | 53 | def shuffle_state_dict(state_dict: typing.Dict[str, torch.Tensor], seed: int = None): 54 | """Randomly shuffle each of the tensors in a state_dict.""" 55 | 56 | output = {} 57 | for i, k in 
enumerate(sorted(state_dict.keys())): 58 | output[k] = shuffle_tensor(state_dict[k], seed=None if seed is None else seed+i) 59 | return output 60 | 61 | def shuffle_model_params(model: Model, mask: Mask, seed: int = None): 62 | """Randomly shuffle the values of the pruned (mask == 0) weights within each prunable layer.""" 63 | for name, param in model.named_parameters(): 64 | if name in model.prunable_layer_names: 65 | # Shuffle only the entries at pruned positions, with a fixed seed for reproducibility. 66 | data_numpy = param.data.cpu().numpy() 67 | data_shuffle = data_numpy[mask[name] == 0] 68 | random.Random(seed).shuffle(data_shuffle) 69 | data_numpy[mask[name] == 0] = data_shuffle 70 | # Copy the shuffled values back explicitly: on CPU the numpy array aliases param.data, but on 71 | # GPU .cpu() returns a copy, so without this the model's weights would be left unchanged. 72 | param.data.copy_(torch.from_numpy(data_numpy)) 73 | return model 74 | --------------------------------------------------------------------------------
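A minimal usage sketch, assuming main.py dispatches the runner names registered in experiments/runner_registry.py and that the TrainingRunner's subcommand is 'train' (both are assumptions; neither file's contents appear above). --replicate and --default_hparams come from utils/shared_args.py, 'cifar_resnet_20' is one of the model names used by models/registry.py in the tests, and show_result.py reads the logger files written under Platform().root:

    # Train a model with the registered default hyperparameters (hypothetical subcommand name);
    # the runner prints the output directory it writes to.
    python main.py train --default_hparams=cifar_resnet_20 --replicate=1

    # Summarize the last/best test accuracy of a finished run; --name is the run's directory
    # name relative to Platform().root (e.g. the hashed 'train_...' directory created above).
    python show_result.py --name=<run_directory_name>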