├── .gitignore
├── README.md
├── datasets
│   ├── __init__.py
│   ├── image_dataset.py
│   └── synthetic_dataset.py
├── exp_proc
│   └── run_ntk.py
├── models
│   └── __init__.py
├── optims
│   └── __init__.py
├── result.txt
├── utils
│   └── __init__.py
└── vis
    ├── curve_10.png
    ├── curve_100.png
    ├── curve_1000.png
    ├── curve_10000.png
    ├── curve_100000.png
    ├── curve_50000.png
    ├── linear_regression_data.png
    └── sin_regression_data.png

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
data/

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# DL_theory_exp

**Among all the mysterious findings and observations in deep learning, which are the real clues that can potentially open the black box of neural networks?**

This repository is a sandbox for concept-validating experiments on **any** interesting and insightful observation in deep learning. The primary goal of this project is to gain a comprehensive understanding of the details, conditions, and potentially unexplored aspects of well-known phenomena and theoretical claims.

I personally believe the current state of deep learning theory resembles where physicists stood before General Relativity or Maxwell's equations were proposed. We need substantial, novel, and intricate observations as "clues" that intersect and combine to reveal the true story. Only through extensive empirical exploration can we finally arrive at an elegant and principled theory of deep learning.

Therefore, this repository will work through interesting findings and observations that researchers have made over the past years. For each paper, it will delve into every detail and main claim, reproducing the results in order to validate the claims and to understand their conditions and implications comprehensively.

Maybe we will find some novel insights along the way. Who knows!

## The List of Phenomena To Be Explored

| Name | Introduction | Paper | Author |
| :-------:| :----: | :-----: | :-----: |
| 1. Information Bottleneck | | Deep Learning and the Information Bottleneck Principle | Tishby *et al.* |
| 2. Edge of Stability | | Gradient Descent on Neural Networks Typically Occurs at the Edge of Stability | Cohen *et al.* |
| 3. Lottery Ticket Hypothesis | | The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks | Frankle *et al.* |

To be continued...

## Need Your Help!
You can open a pull request to propose mechanisms and phenomena that you find interesting and would like to see explored! We will add detailed experimental logs and reports to this repo!
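
## Quick Start

As a rough sketch of how the pieces fit together (assuming `torch`, `torchvision`, `numpy`, and `matplotlib` are installed and you run from the repository root), you can build one of the synthetic datasets and inspect it like this:

```python
# Minimal usage sketch; the class and methods below are defined in
# datasets/synthetic_dataset.py.
from datasets.synthetic_dataset import LinearRegressionDataset

# 100 noisy samples from a random 1-D linear model y = Wx + b + noise.
dataset = LinearRegressionDataset(n_samples=100, n_features=1, dim_out=1,
                                  noise_scale=3e-2)
X, y = dataset.get_data()    # tensors of shape (100, 1)
W, b = dataset.get_params()  # ground-truth weights and bias
print(X.shape, y.shape, W, b)
```

The NTK width-sweep experiment (which produced `result.txt` and the `vis/curve_*.png` plots) can be run directly with `python exp_proc/run_ntk.py`.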
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
from datasets.synthetic_dataset import *  # the package directory is "datasets", not "dataset"

--------------------------------------------------------------------------------
/datasets/image_dataset.py:
--------------------------------------------------------------------------------
import os

import torch
import torch.nn as nn
import torchvision
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

class Cifar10(Dataset):
    def __init__(self, datapath='./data',
                 trunc_len=None,
                 return_label=True,
                 random_flip=False,
                 random_crop=False,
                 class_filter=None,
                 is_train=True,
                 transform=None
                 ):

        self.data = []
        self.targets = []
        self.return_label = return_label
        self.transform = transforms.Compose([
            transforms.RandomHorizontalFlip() if random_flip else nn.Identity(),
            transforms.RandomCrop(32, padding=4) if random_crop else nn.Identity(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]) if transform is None else transform

        # Load the torchvision CIFAR-10 dataset. Images are cached eagerly,
        # so the transform (including any random augmentation) is applied
        # exactly once per image, at load time.
        raw_dataset = torchvision.datasets.CIFAR10(root=datapath, train=is_train, download=True, transform=self.transform)
        cnt = 0
        print("Loading CIFAR10 dataset...")
        for data, target in raw_dataset:
            if trunc_len is not None and cnt >= trunc_len:
                break
            if class_filter is None or target in class_filter:
                self.data.append(data)
                self.targets.append(target)
                cnt += 1
        print("Done!")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if self.return_label:
            return self.data[idx], self.targets[idx]
        else:
            return self.data[idx]

    def get_data(self):
        """ Return images as a flattened tensor X and labels as a tensor y. """
        return torch.stack(self.data).reshape(len(self.data), -1), torch.tensor(self.targets)

class ImageDataset(Dataset):
    def __init__(self, datapath,
                 image_size=224,
                 trunc_len=None,
                 return_label=True,
                 random_flip=False,
                 random_crop=False,
                 class_filter=None,
                 is_train=True,
                 transform=None):

        self.data = []
        self.targets = []
        self.return_label = return_label
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.RandomHorizontalFlip() if random_flip else nn.Identity(),
            transforms.RandomCrop(image_size, padding=4) if random_crop else nn.Identity(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]) if transform is None else transform

        print(f"Loading image-folder dataset from {datapath}...")
        # ImageFolder has no train flag; datapath should already point at the
        # desired split (is_train is kept only for interface consistency).
        raw_dataset = torchvision.datasets.ImageFolder(root=datapath, transform=self.transform)
        cnt = 0
        for data, target in raw_dataset:
            if trunc_len is not None and cnt >= trunc_len:
                break
            if class_filter is None or target in class_filter:
                self.data.append(data)
                self.targets.append(target)
                cnt += 1
        print("Done!")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if self.return_label:
            return self.data[idx], self.targets[idx]
        else:
            return self.data[idx]

    def get_data(self):
        """ Return images as a flattened tensor X and labels as a tensor y. """
        return torch.stack(self.data).reshape(len(self.data), -1), torch.tensor(self.targets)

class GeneralImageDataset(Dataset):
    def __init__(self, datapath,
                 image_size=224,
                 trunc_len=None,
                 return_label=True,
                 random_flip=False,
                 random_crop=False,
                 transform=None,
                 ):

        self.data = []
        self.targets = []
        self.return_label = return_label
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.RandomHorizontalFlip() if random_flip else nn.Identity(),
            transforms.RandomCrop(image_size, padding=4) if random_crop else nn.Identity(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]) if transform is None else transform

        cnt = 0
        print(f"Loading general image dataset from {datapath}...")
        assert os.path.exists(datapath), f"Path {datapath} does not exist!"
        for root, dirs, files in os.walk(datapath):
            if trunc_len is not None and cnt >= trunc_len:
                break  # also stop walking further directories once truncated
            for file in files:
                if trunc_len is not None and cnt >= trunc_len:
                    break
                # Convert to RGB so grayscale/RGBA images stack into one tensor.
                img = Image.open(os.path.join(root, file)).convert("RGB")
                img = self.transform(img)
                self.data.append(img)
                cnt += 1

        print("Done!")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

    def get_data(self):
        """ Return images as a flattened tensor X and labels as a tensor y
        (this dataset is unlabeled, so the label tensor is empty). """
        return torch.stack(self.data).reshape(len(self.data), -1), torch.tensor(self.targets)

if __name__ == "__main__":
    # dataset = Cifar10()
    # print(len(dataset))
    # print(dataset[0])
    # print(dataset[0][0].shape)
    # print(dataset[0][1])
    # print("Done!")
    dataset = GeneralImageDataset(datapath='/cluster/home1/lurui/ffhq128',
                                  trunc_len=1000, image_size=64, random_flip=True)
    print(len(dataset))
    print(dataset[3].shape)
    print("Done!")

    print(dataset.get_data()[0].shape)

--------------------------------------------------------------------------------
/datasets/synthetic_dataset.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
from torch.utils.data import Dataset

class SyntheticDataset(Dataset):
    def __init__(self,
                 dim_in,
                 dim_out,
                 n_samples,
                 func,
                 data_generator=None,
                 noise_scale=0.1,
                 **kwargs):
        super().__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.n_samples = n_samples
        self.func = func
        self.noise_scale = noise_scale

        if data_generator is None:
            # Default: standard Gaussian inputs of shape (n_samples, dim_in).
            # (The inputs live in the input space, so the shape uses dim_in,
            # not dim_out.)
            self.data_generator = lambda: np.random.normal(0, 1, (n_samples, dim_in))
        else:
            self.data_generator = data_generator

        self.X = self.data_generator()
        self.y = [self.func(x) for x in self.X]
        self.y = np.array(self.y)
        self.y += np.random.normal(0, noise_scale, (n_samples, dim_out))
        self.X = torch.tensor(self.X, dtype=torch.float32)
        self.y = torch.tensor(self.y, dtype=torch.float32)

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

    def get_data(self):
        return self.X, self.y

    def get_true_data(self):
        # func expects numpy inputs, so convert the cached tensor back.
        return self.X, self.func(self.X.numpy())

    def get_true_predictions(self):
        return self.func(self.X.numpy())

    def get_noisy_predictions(self):
        return self.func(self.X.numpy()) + np.random.normal(0, self.noise_scale, (self.n_samples, self.dim_out))

    def get_params(self):
        # Subclasses with explicit ground-truth parameters override this.
        raise NotImplementedError

class LinearRegressionDataset(SyntheticDataset):
    def __init__(self,
                 n_samples=1000,
                 n_features=1,
                 dim_out=1,
                 bias_scale=0,
                 noise_scale=0.1,
                 data_distribution="Gaussian",
                 set_weights=None,
                 set_bias=None,
                 ):

        assert set_weights is None or set_weights.shape == (n_features, dim_out)
        assert set_bias is None or set_bias.shape == (dim_out,)
        self.W = np.random.normal(0, 1, (n_features, dim_out)) if set_weights is None else set_weights
        self.b = np.random.normal(0, bias_scale, (dim_out,)) if set_bias is None else set_bias

        # Only Gaussian and Uniform input distributions are supported.
        assert data_distribution in ["Gaussian", "Uniform"]
        if data_distribution == "Gaussian":
            self.data_generator = lambda: np.random.normal(0, 1, (n_samples, n_features))
        elif data_distribution == "Uniform":
            self.data_generator = lambda: np.random.uniform(-1, 1, (n_samples, n_features))

        def linearfunc(x, W=self.W, b=self.b):
            return np.dot(x, W) + b

        super().__init__(n_features, dim_out, n_samples, linearfunc, self.data_generator, noise_scale)

    def get_params(self):
        return self.W, self.b

class SinusoidalDataset(SyntheticDataset):
    def __init__(self,
                 n_samples=1000,
                 n_features=1,
                 dim_out=1,
                 noise_scale=0.1,
                 bias_scale=0,
                 data_distribution="Gaussian",
                 set_weights=None,
                 set_bias=None,
                 ):
        self.noise_scale = noise_scale

        # Only Gaussian and Uniform input distributions are supported.
        assert data_distribution in ["Gaussian", "Uniform"]
        if data_distribution == "Gaussian":
            data_generator = lambda: np.random.normal(0, 1, (n_samples, n_features))
        elif data_distribution == "Uniform":
            data_generator = lambda: np.random.uniform(-1, 1, (n_samples, n_features))

        # W has shape (n_features, 1) so that np.dot(x, W) has shape (1,).
        assert set_weights is None or set_weights.shape == (n_features, 1)
        assert set_bias is None or set_bias.shape == (1,)
        self.W = np.random.normal(0, 1, (n_features, 1)) if set_weights is None else set_weights
        self.b = np.random.normal(0, bias_scale, (1,)) if set_bias is None else set_bias

        def sinusoidalfunc(x, W=self.W, b=self.b):
            return np.sin(np.dot(x, W) + b)

        # Pass noise_scale through so it is not silently reset to the default.
        super().__init__(n_features, dim_out, n_samples, sinusoidalfunc, data_generator, noise_scale)

    def get_params(self):
        return self.W, self.b

if __name__ == "__main__":
    dataset = LinearRegressionDataset(
        n_samples=100,
        n_features=1,
        dim_out=1,
        bias_scale=0,
        noise_scale=3e-2,
    )
    # dataset = SinusoidalDataset(
    #     n_samples=100,
    #     n_features=1,
    #     dim_out=1,
    #     noise_scale=3e-2,
    #     data_distribution="Uniform",
    #     set_weights=np.array([[3.0]]),
    # )
    print(dataset[0])
    import matplotlib.pyplot as plt
    X, y = dataset.get_data()
    print(dataset.get_params())
    # plt.scatter(X, y)
    # plt.savefig("vis/sin_regression_data.png")

--------------------------------------------------------------------------------
/exp_proc/run_ntk.py:
--------------------------------------------------------------------------------
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np

torch.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def relative_change(initial_parameters, new_parameters):
    """Ratio of the total parameter movement to the total initial norm."""
    with torch.no_grad():
        total_change = 0.0
        total_initial_norm = 0.0
        for initial, new in zip(initial_parameters, new_parameters):
            total_change += torch.norm(new - initial).item()
            total_initial_norm += torch.norm(initial).item()
        return total_change / total_initial_norm if total_initial_norm > 0 else 0


def is_similar(x, y):
    """Return True if x and y agree up to a small absolute tolerance."""
    return abs(x - y) < 1e-5


batch_size = 1024

dataset = torchvision.datasets.MNIST('~/data/', train=True, download=True,
                                     transform=torchvision.transforms.ToTensor())
test_dataset = torchvision.datasets.MNIST('~/data/', train=False, download=True,
                                          transform=torchvision.transforms.ToTensor())
# Only the first batch_size training images are used, for faster training.
# Note that indexing dataset.data directly bypasses the ToTensor transform,
# so the inputs are raw uint8 pixel values cast to float (range 0-255).

# Base learning rate; it is scaled per width below.
learning_rate = 1e-2
widths = [10, 100, 1000, 10000, 50000, 100000]

results = []

for width in widths:
    print("Width: ", width)
    # Scale the learning rate down as 1/sqrt(width).
    scaled_learning_rate = learning_rate / np.sqrt(width)
    epochs = 500
    model = nn.Sequential(
        nn.Linear(28 * 28, width),
        nn.ReLU(),
        nn.Linear(width, 10)
    ).to(device)

    optimizer = optim.SGD(model.parameters(), lr=scaled_learning_rate)
    loss_fn = nn.CrossEntropyLoss()

    initial_parameters = [p.clone() for p in model.parameters()]

    final_loss = 0.0

    losses = []
    acc = 0

    # tqdm progress bar, updated with the current loss
    with tqdm(total=epochs, unit='epoch') as pbar:
        for ep in range(epochs):
            model.train()
            batch_X, batch_y = dataset.data[:batch_size], dataset.targets[:batch_size]
            optimizer.zero_grad()
            batch_X = batch_X.view(-1, 28 * 28).to(device).float()
            output = model(batch_X)
            loss = loss_fn(output, batch_y.to(device))
            loss.backward()
            optimizer.step()
            final_loss = loss.item()
            losses.append(final_loss)
            pbar.set_description(f'epoch: {ep + 1}')
            pbar.set_postfix({'loss': final_loss})
            pbar.update(1)

    with torch.no_grad():
        total, correct = 0, 0
        model.eval()
        x, y = test_dataset.data, test_dataset.targets
        x = x.view(-1, 28 * 28).to(device).float()
        output = model(x)
        _, predicted = torch.max(output.data, dim=1)
        total += y.size(0)
        correct += (predicted.cpu() == y).sum().item()
        acc = correct / total
        print("acc: ", acc * 100)
    new_parameters = list(model.parameters())
    change = relative_change(initial_parameters, new_parameters)
    results.append((width, change, final_loss, acc))
    # torch.save(model, str(width) + '.pth')

    plt.plot(range(len(losses)), losses)
    plt.ylabel("train_loss")
    plt.xlabel("epoch")
    plt.savefig("curve_" + str(width) + ".png")
    plt.clf()
    print(results)  # progress print: results accumulated so far

print(results)
with open('result.txt', 'w') as f:
    for item in results:
        f.write(str(item) + '\n')

--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LR32768/DL_theory_exp/b7ade5fca2f83bd86ca8fbec5d9c4e1155dbe57d/models/__init__.py
--------------------------------------------------------------------------------
/optims/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LR32768/DL_theory_exp/b7ade5fca2f83bd86ca8fbec5d9c4e1155dbe57d/optims/__init__.py
--------------------------------------------------------------------------------
/result.txt:
--------------------------------------------------------------------------------
# (width, relative_parameter_change, final_train_loss, test_accuracy)
(10, 0.41282473981877554, 2.071605682373047, 0.1818)
(100, 0.12273145267661775, 0.009883943013846874, 0.7627)
(1000, 0.024971781517513322, 0.0019121915102005005, 0.7755)
(10000, 0.005204312326438836, 0.0005671746912412345, 0.8366)
(50000, 0.0019906164072203783, 0.0001354488922515884, 0.8449)
(100000, 0.0015169701612683072, 6.648951966781169e-05, 0.8485)
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LR32768/DL_theory_exp/b7ade5fca2f83bd86ca8fbec5d9c4e1155dbe57d/utils/__init__.py
--------------------------------------------------------------------------------
/vis/curve_10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LR32768/DL_theory_exp/b7ade5fca2f83bd86ca8fbec5d9c4e1155dbe57d/vis/curve_10.png
--------------------------------------------------------------------------------
/vis/curve_100.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LR32768/DL_theory_exp/b7ade5fca2f83bd86ca8fbec5d9c4e1155dbe57d/vis/curve_100.png
--------------------------------------------------------------------------------
/vis/curve_1000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LR32768/DL_theory_exp/b7ade5fca2f83bd86ca8fbec5d9c4e1155dbe57d/vis/curve_1000.png
--------------------------------------------------------------------------------
/vis/curve_10000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LR32768/DL_theory_exp/b7ade5fca2f83bd86ca8fbec5d9c4e1155dbe57d/vis/curve_10000.png
--------------------------------------------------------------------------------
/vis/curve_100000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LR32768/DL_theory_exp/b7ade5fca2f83bd86ca8fbec5d9c4e1155dbe57d/vis/curve_100000.png
--------------------------------------------------------------------------------
/vis/curve_50000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LR32768/DL_theory_exp/b7ade5fca2f83bd86ca8fbec5d9c4e1155dbe57d/vis/curve_50000.png
--------------------------------------------------------------------------------
/vis/linear_regression_data.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LR32768/DL_theory_exp/b7ade5fca2f83bd86ca8fbec5d9c4e1155dbe57d/vis/linear_regression_data.png
--------------------------------------------------------------------------------
/vis/sin_regression_data.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LR32768/DL_theory_exp/b7ade5fca2f83bd86ca8fbec5d9c4e1155dbe57d/vis/sin_regression_data.png
--------------------------------------------------------------------------------