├── README.md ├── .gitignore ├── data.py ├── mnist_demo.py ├── novelty_policy.py ├── gym_data.py └── mnist_novelty_sampling.py /README.md: -------------------------------------------------------------------------------- 1 | # Exploration by Random Network Distillation 2 | 3 | Reproduction of some of the results from the aforementioned paper. 4 | 5 | https://arxiv.org/abs/1810.12894 6 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | ### Python template 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | .hypothesis/ 50 | .pytest_cache/ 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | local_settings.py 59 | db.sqlite3 60 | 61 | # Flask stuff: 62 | instance/ 63 | .webassets-cache 64 | 65 | # Scrapy stuff: 66 | .scrapy 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # Jupyter Notebook 75 | .ipynb_checkpoints 76 | 77 | # pyenv 78 | .python-version 79 | 80 | # celery beat schedule file 81 | celerybeat-schedule 82 | 83 | # SageMath parsed files 84 | *.sage.py 85 | 86 | # Environments 87 | .env 88 | .venv 89 | env/ 90 | venv/ 91 | ENV/ 92 | env.bak/ 93 | venv.bak/ 94 | 95 | # Spyder project settings 96 | .spyderproject 97 | .spyproject 98 | 99 | # Rope project settings 100 | .ropeproject 101 | 102 | # mkdocs documentation 103 | /site 104 | 105 | # mypy 106 | .mypy_cache/ 107 | 108 | mnistdata 109 | runs -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Sampler, DataLoader 3 | from tqdm import tqdm 4 | 5 | 6 | class MNISTZeroSampler(Sampler): 7 | """ 8 | Samples only zeros from MNIST 9 | """ 10 | 11 | def __init__(self, datasource): 12 | super().__init__(datasource) 13 | self.datasource = datasource 14 | self.zero_elems = [] 15 | len_datasource_tq = tqdm(range(len(self.datasource))) 16 | len_datasource_tq.set_description('building index of zeroes') 17 | for index in len_datasource_tq: 18 | image, label = self.datasource[index] 19 | if label.item() == 0: 20 | self.zero_elems.append(index) 21 | 22 | def __iter__(self): 23 | return iter(self.zero_elems) 
def compute_covar(dataset, index, batch_size=200):
    """
    Compute the element-wise mean and sample standard deviation of one field of a dataset.

    Despite the name, no covariance matrix is built: the statistics are computed
    independently for every element position of the selected field, in two passes
    over the data (first the mean, then the squared deviations from it).

    :param dataset: a dataset to compute over; each item is indexable (e.g. a tuple)
    :param index: index of the field inside each dataset item to compute statistics for
    :param batch_size: batch size to use for the computation
    :return: (mean, stdev) as float tensors shaped like ``dataset[0][index]``
    """

    dataloader = DataLoader(dataset, batch_size=batch_size)
    count = len(dataset)

    # Accumulate in double precision so summing many small values stays accurate.
    example = dataset[0][index]
    sum_image = torch.zeros_like(example).double()
    sum_squares_dev = torch.zeros_like(example).double()

    mean_tq = tqdm(dataloader)
    mean_tq.set_description('computing covariance matrix - mean')
    for data in mean_tq:
        sum_image = sum_image + torch.sum(data[index].double(), dim=0)

    mean = sum_image / count

    stdev_tq = tqdm(dataloader)
    stdev_tq.set_description('computing covariance matrix - stdev')
    for data in stdev_tq:
        deviation = data[index].double() - mean
        # Square exactly once: the previous code squared an already squared
        # deviation, which produced a fourth-moment quantity instead of a variance.
        sum_squares_dev = sum_squares_dev + torch.sum(deviation ** 2, dim=0)

    # Bessel-corrected (n - 1) estimator; clamp the divisor so a single-item
    # dataset yields 0 (replaced by the epsilon below) instead of NaN from 0 / 0.
    stdev = torch.sqrt(sum_squares_dev / max(count - 1, 1))
    # Constant positions have zero spread; use a tiny epsilon so callers can divide by stdev.
    stdev[stdev == 0.0] = 1e-12

    return mean.float(), stdev.float()
87 | def __init__(self, min, max): 88 | self.min = min 89 | self.max = max 90 | 91 | def __call__(self, x): 92 | return x.clamp(self.min, self.max) 93 | -------------------------------------------------------------------------------- /mnist_demo.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.nn.modules.loss import MSELoss 3 | from torch.optim import Adam 4 | from torch.utils.data.dataloader import DataLoader 5 | from torchvision.datasets import MNIST 6 | import torchvision.transforms as transforms 7 | 8 | from data import MNISTZeroSampler, MNISTNonZeroSampler, compute_covar 9 | from storm.vis import UniImageViewer 10 | import statistics 11 | 12 | if __name__ == '__main__': 13 | in_features = 28 * 28 14 | h_size = 20 15 | epochs = 20 16 | batch_size = 200 17 | 18 | rand_net = nn.Sequential(nn.Linear(in_features, h_size), nn.BatchNorm1d(h_size), nn.ReLU(), nn.Linear(h_size, 1), 19 | nn.ReLU()) 20 | dist_net = nn.Sequential(nn.Linear(in_features, h_size), nn.BatchNorm1d(h_size), nn.ReLU(), nn.Linear(h_size, 1), 21 | nn.ReLU()) 22 | 23 | mnist = MNIST('mnistdata', download=True, 24 | transform=transforms.Compose([ 25 | transforms.ToTensor() 26 | ])) 27 | 28 | mean, stdev = compute_covar(mnist, index=0, batch_size=200) 29 | 30 | mnist_white = MNIST('mnistdata', download=True, 31 | transform=transforms.Compose([ 32 | transforms.ToTensor(), 33 | transforms.Normalize(mean, stdev) 34 | ])) 35 | 36 | mnist_normal = DataLoader(mnist_white, batch_size=batch_size) 37 | mnist_zeros = DataLoader(mnist_white, batch_size=batch_size, sampler=MNISTZeroSampler(mnist)) 38 | mnist_non_zeros = DataLoader(mnist_white, batch_size=batch_size, sampler=MNISTNonZeroSampler(mnist)) 39 | 40 | viewer = UniImageViewer('mnist') 41 | 42 | criterion = MSELoss() 43 | optimizer = Adam(lr=1e-4, params=dist_net.parameters()) 44 | 45 | for epoch in range(epochs): 46 | train_losses = [] 47 | for image, label in mnist_zeros: 48 | 
def normalize(value, mean, sigma, eps=1e-10):
    """
    Express ``value`` as a z-score relative to a reference distribution.

    :param value: the quantity to normalize
    :param mean: mean of the reference distribution
    :param sigma: standard deviation of the reference distribution
    :param eps: small constant added to ``sigma`` so a zero deviation does not divide by zero
    :return: ``(value - mean) / (sigma + eps)``
    """
    # The previous expression was `(value - mean) + 1e-10 / (sigma + 1e-10)`:
    # division binds tighter than addition, so sigma never scaled the difference.
    return (value - mean) / (sigma + eps)
def conv_output_shape(h_w, kernel_size=1, stride=1, pad=0, dilation=1):
    """
    Utility function for computing the spatial output size of a 2D convolution.

    Each of ``kernel_size``, ``stride``, ``pad`` and ``dilation`` may be a single
    number (applied to both axes) or a (height, width) pair.

    :param h_w: (h, w) of the input
    :param kernel_size: kernel size, scalar or (h, w)
    :param stride: stride, scalar or (h, w)
    :param pad: zero padding added to both sides of each axis, scalar or (h, w)
    :param dilation: spacing between kernel elements, scalar or (h, w)
    :returns: tuple of (h, w) of the output
    """
    from math import floor

    def as_pair(value):
        # Accept both tuples and lists as explicit per-axis values; anything
        # else is treated as a scalar shared by both axes.
        if isinstance(value, (tuple, list)):
            return value[0], value[1]
        return value, value

    kernel_size = as_pair(kernel_size)
    stride = as_pair(stride)
    pad = as_pair(pad)
    dilation = as_pair(dilation)

    def axis_size(axis):
        # floor((in + 2 * pad - dilation * (kernel - 1) - 1) / stride + 1)
        span = h_w[axis] + (2 * pad[axis]) - (dilation[axis] * (kernel_size[axis] - 1)) - 1
        return floor((span / stride[axis]) + 1)

    return axis_size(0), axis_size(1)
class NoveltyPolicy(Policy):
    """
    Policy that picks the action whose outcome the distillation network predicts worst.

    Two identically shaped networks are kept: ``rand_net`` is never optimized and
    serves as the target, ``dist_net`` is trained to imitate it. The absolute
    prediction error for a (screen, action) pair is used as the novelty score.
    """

    def __init__(self, env):
        """
        :param env: environment exposing a discrete ``action_space`` with attribute ``n``
        """
        self.env = env
        action_shape = env.action_space.n
        self.action_embed = ActionEmbedding(env)
        # Both networks are built for a hard-coded 210 x 160 screen.
        self.rand_net = AtariNoveltyNet((210, 160), action_shape)
        self.dist_net = AtariNoveltyNet((210, 160), action_shape)
        # Only the distillation network is handed to the optimizer.
        self.optimizer = Adam(lr=1e-3, params=self.dist_net.parameters())
        self.criterion = MSELoss()
        self.device = 'cpu'

    def to(self, device):
        """
        Move both networks to ``device`` and remember it for later inputs.

        :param device: target device
        :return: self, so the call can be chained
        """
        self.device = device
        self.rand_net = self.rand_net.to(device)
        self.dist_net = self.dist_net.to(device)
        return self

    def action(self, screen, observation):
        """
        Select the most novel action for ``screen`` and train on that choice.

        :param screen: batch, channels, height, width
        :param observation: accepted for interface compatibility; not used here
        :return: index of the selected action
        """
        screen = screen.to(self.device)
        # Use the environment stored on the instance; the previous code read a
        # module-level `env`, which only exists when this file runs as a script.
        num_actions = self.env.action_space.n

        # Score every action in this state. No gradients are needed for the
        # selection pass, so skip building the autograd graph.
        with torch.no_grad():
            actions = torch.eye(num_actions).to(self.device)
            screen_exp = screen.expand(num_actions, -1, -1, -1)

            target = self.rand_net(screen_exp, actions)
            prediction = self.dist_net(screen_exp, actions)
            error = prediction - target
            print(error.transpose(0, 1))

            # select the most novel action
            weighted_novelty = F.softmax(error.abs().squeeze(-1), dim=0)
            _, indices = torch.topk(weighted_novelty, 1)

        most_novel_action = indices.item()

        # train distillation network on selected action/state
        action = self.action_embed.tensor(most_novel_action).to(self.device).unsqueeze(0)
        # detach() returns a new tensor; the previous bare `novel_targets.detach()`
        # discarded it, so the target stayed attached to rand_net's graph.
        novel_targets = self.rand_net(screen, action).detach()

        self.optimizer.zero_grad()
        prediction = self.dist_net(screen, action)
        loss = self.criterion(prediction, novel_targets)
        loss.backward()
        self.optimizer.step()

        return most_novel_action
class ActionEmbedding:
    """
    One-hot encoder for the discrete action space of an environment.

    The width of every produced vector is ``env.action_space.n``.
    """

    def __init__(self, env):
        self.env = env

    def start_tensor(self):
        """Return an all-zero torch vector, i.e. the embedding of "no action yet"."""
        return torch.zeros(self.env.action_space.n)

    def start_numpy(self):
        """Return an all-zero numpy vector, i.e. the embedding of "no action yet"."""
        return np.zeros(self.env.action_space.n)

    def tensor(self, action):
        """Return the one-hot torch vector with a 1.0 at position ``action``."""
        one_hot = torch.zeros(self.env.action_space.n)
        one_hot[action] = 1.0
        return one_hot

    def numpy(self, action):
        """Return the one-hot numpy vector with a 1.0 at position ``action``."""
        one_hot = np.zeros(self.env.action_space.n)
        one_hot[action] = 1.0
        return one_hot

    def embedding_to_action(self, index):
        """Map an embedding index back to an action; the mapping is the identity."""
        return index
class Rollout:
    """Runs a policy inside a gym-style environment for a single episode."""

    def __init__(self, env):
        self.env = env

    def rollout(self, policy, episode, max_timesteps=100):
        """
        Step the environment with actions chosen by ``policy``.

        Stops when the environment reports ``done`` or after ``max_timesteps`` steps.

        :param policy: object whose ``action(screen, observation)`` returns the next action
        :param episode: accepted but not used by this method
        :param max_timesteps: upper bound on the number of environment steps
        """
        env = self.env
        observation = env.reset()
        screen = env.render(mode='rgb_array')

        steps_taken = 0
        while steps_taken < max_timesteps:
            chosen = policy.action(screen, observation)
            observation, _reward, finished, _info = env.step(chosen)
            screen = env.render(mode='rgb_array')
            steps_taken += 1

            if finished:
                print("Episode finished after {} timesteps".format(steps_taken))
                return
class GymSimulatorDataset(torch.utils.data.Dataset):
    """
    Dataset facade over a live simulator rollout.

    Every ``__getitem__`` call advances the underlying rollout by one step, so the
    requested ``index`` is ignored and items are produced strictly in order.
    """

    def __init__(self, env, policy, length, action_embedding, output_in_numpy_format=False, render_to_window=False):
        """
        :param env: gym environment
        :param policy: policy to select actions in the environment
        :param length: value reported by ``len()``
        :param action_embedding: embedding used to turn actions into tensors
        :param output_in_numpy_format: if True, return the raw rollout values instead of tensors
        :param render_to_window: render the output to a window
        """
        super().__init__()
        self.length = length
        self.count = 0
        self.policy = policy
        self.rollout = iter(RolloutGen(env, policy, action_embedding, render_to_window=render_to_window))
        self.output_in_numpy_format = output_in_numpy_format
        self.to_tensor = ToTensor(action_embedding=action_embedding)

    def __getitem__(self, index):
        """
        Advance the rollout by one step.

        :param index: ignored; the rollout is sequential
        :return: (screen, observation, action, reward, done)
        """
        screen, observation, reward, done, info, action = next(self.rollout)

        if not self.output_in_numpy_format:
            # ToTensor returns its values in the same order it receives them.
            # The previous unpacking used a different order, so the slots
            # labelled action/reward/done carried reward/done/action instead.
            screen, observation, reward, done, info, action = \
                self.to_tensor(screen, observation, reward, done, info, action)

        self.count += 1

        return screen, observation, action, reward, done

    def __len__(self):
        return self.length
if __name__ == '__main__':
    # Experiment: train a distillation network only on the samples it currently
    # predicts worst, and log how the label mix of those "novel" samples differs
    # from the label mix of the (deliberately biased) input stream.
    in_features = 28 * 28
    h_size = 50
    epochs = 40
    batch_size = 1000
    novel_batch_size = 200
    sample = False
    magnitude_norm = False  # currently not read anywhere below
    L1_distance = True

    # todo run test with a very very small learning rate

    class MNISTBiasedSampler(Sampler):
        """
        Samples unevenly from MNIST
        the selected labels are reduced in frequency
        the idea is this will make them more "novel"
        """

        def __init__(self, datasource, index, drop_freq):
            """
            :param datasource: dataset yielding (image, label) items
            :param index: collection of labels whose items are dropped with probability ``drop_freq``
            :param drop_freq: probability of dropping an item carrying one of those labels
            """
            super().__init__(datasource)
            self.datasource = datasource
            self.sampled_elems = []
            # Previously `index` was overwritten by the loop variable and the
            # label set was hard-coded to {1, 2, 3}; honour the argument instead.
            rare_labels = set(index)
            for position in range(len(self.datasource)):
                image, label = self.datasource[position]
                # Draw for every item so the random stream does not depend on the labels.
                drop = random.random() < drop_freq
                # int() accepts both a 0-dim tensor label and a plain int label.
                if int(label) in rare_labels and drop:
                    continue
                self.sampled_elems.append(position)

        def __iter__(self):
            return iter(self.sampled_elems)

        def __len__(self):
            return len(self.sampled_elems)

    def make_net():
        # a simple network; the target and the distillation net share this architecture
        return nn.Sequential(nn.Linear(in_features, h_size),
                             nn.BatchNorm1d(h_size),
                             nn.ReLU(),
                             nn.Linear(h_size, h_size),
                             nn.BatchNorm1d(h_size),
                             nn.ReLU(),
                             nn.Linear(h_size, 1),
                             nn.ReLU())

    # initialization function, first checks the module type
    # NOTE: despite the name, the weights are drawn from a uniform distribution.
    def init_normal(m):
        if isinstance(m, nn.Linear):
            nn.init.uniform_(m.weight)

    mnist = MNIST('mnistdata', download=True,
                  transform=transforms.Compose([
                      transforms.ToTensor()
                  ]))

    mean, stdev = compute_covar(mnist, index=0, batch_size=200)

    mnist = MNIST('mnistdata', download=True,
                  transform=transforms.Compose([
                      transforms.ToTensor(),
                      transforms.Normalize(mean, stdev)
                  ]))
    mnist_biased = DataLoader(mnist, batch_size=batch_size, sampler=MNISTBiasedSampler(mnist, {1, 2, 3}, 0.2))
    mnist = DataLoader(mnist, batch_size=batch_size)

    for run in range(10):
        tb = SummaryWriter(f'runs/normal_init/{run:2d}/')

        rand_net = make_net()
        # use the modules apply function to recursively apply the initialization
        rand_net.apply(init_normal)

        dist_net = make_net()

        optimizer = Adam(lr=1e-5, params=dist_net.parameters())

        criterion = MSELoss()

        base_counts = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
        base_freq = None
        novelty_counts = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
        novelty_freq = None
        global_step = 0

        for epoch in range(epochs):
            for image, labels in mnist_biased:
                flat_image = image.squeeze().view(-1, in_features)

                # running label frequencies of everything offered so far
                for label in labels:
                    base_counts[label.item()] += 1
                total = sum(base_counts)
                base_freq = [x / total for x in base_counts]

                # compute the novelty value of the image; this pass only ranks
                # samples, so no autograd graph is needed
                with torch.no_grad():
                    target = rand_net(flat_image)
                    prediction = dist_net(flat_image)
                    if L1_distance:
                        error = (prediction - target).abs()
                    else:
                        error = (prediction - target) ** 2
                    weighted_novelty = softmax(error.squeeze(-1), dim=0)

                # the final batch may hold fewer samples than novel_batch_size
                num_novel = min(novel_batch_size, weighted_novelty.numel())
                if sample:
                    # sample based on novelty score (more novel = more likely)
                    indices = torch.multinomial(weighted_novelty, num_novel, replacement=False)
                else:
                    values, indices = torch.topk(weighted_novelty, num_novel)

                novel_images = flat_image[indices]
                novel_labels = labels[indices]

                # train distillation network on selected images
                # detach() returns a new tensor; the previous bare call discarded
                # it and left the target attached to rand_net's graph
                novel_targets = rand_net(novel_images).detach()

                optimizer.zero_grad()
                prediction = dist_net(novel_images)
                loss = criterion(prediction, novel_targets)
                loss.backward()
                optimizer.step()

                # running label frequencies of everything selected as novel so far
                for label in novel_labels:
                    novelty_counts[label.item()] += 1
                novelty_total = sum(novelty_counts)
                novelty_freq = [x / novelty_total for x in novelty_counts]

                freq_diff = [base - novelty for base, novelty in zip(base_freq, novelty_freq)]

                for i, freq in enumerate(freq_diff):
                    tb.add_scalar(f'freq_diff_{i}', freq, global_step)
                tb.add_scalar('loss', loss.item(), global_step)
                tb.add_histogram('labels', labels.cpu().numpy(), epoch)
                tb.add_histogram('novel_labels', novel_labels.cpu().numpy(), epoch)
                global_step += 1

        # flush and release the event file before the next run opens its own
        tb.close()