├── LICENSE
├── Readme.md
├── losses.py
├── networks.py
├── requirements.txt
├── train_student_teacher.py
└── trainer.py

/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 | 
3 | Copyright (c) 2019, Michael Arbel
4 | All rights reserved.
5 | 
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 | 
9 | * Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 | 
12 | * Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 | 
16 | * Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 | 
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/Readme.md:
--------------------------------------------------------------------------------
1 | ## Table of contents
2 | 
3 | * [Introduction](#introduction)
4 | * [Requirements](#requirements)
5 | * [How to use](#how-to-use)
6 | * [Student Teacher network](#student-teacher-network)
7 | * [Resources](#resources)
8 | * [Hardware](#hardware)
9 | * [Full documentation](#full-documentation)
10 | * [Reference](#reference)
11 | * [License](#license)
12 | 
13 | ## Introduction
14 | 
15 | This repository contains an implementation of the Wasserstein gradient flow of the Maximum Mean Discrepancy from the [Maximum Mean Discrepancy Gradient Flow paper](https://arxiv.org/abs/1906.04370) published at NeurIPS 2019. It can be used to reproduce the experiments in the paper.
16 | 
17 | 
18 | ## Requirements
19 | 
20 | 
21 | This is a PyTorch implementation which requires the following packages:
22 | 
23 | ```
24 | python==3.6.2 or newer
25 | torch==1.2.0 or newer
26 | torchvision==0.4.0 or newer
27 | numpy==1.17.2 or newer
28 | ```
29 | 
30 | All dependencies can be installed using:
31 | 
32 | ```
33 | pip install -r requirements.txt
34 | ```
35 | 
36 | 
37 | 
38 | 
39 | ## How to use
40 | 
41 | ### Student Teacher network:
42 | ```
43 | python train_student_teacher.py --device=-1
44 | ```
45 | 
46 | ## Resources
47 | 
48 | ### Hardware
49 | 
50 | To use a particular GPU, set `--device=#gpu_id`.
51 | To run on the CPU, set `--device=-1` (any negative value selects the CPU).
52 | If CUDA is not available, the code falls back to the CPU automatically.
53 | 
54 | 
55 | ## Full documentation
56 | 
57 | ```
58 | # Optimizer parameters
59 | --lr learning rate [1.]
60 | --batch_size batch size [100]
61 | --total_epochs total number of epochs [10000]
62 | --optimizer Optimizer ['SGD']
63 | --use_scheduler By default uses the ReduceLROnPlateau scheduler [False]
64 | 
65 | # Loss parameters
66 | --loss loss to optimize: mmd_noise_injection, mmd_diffusion, sobolev ['mmd_noise_injection']
67 | --with_noise to use noise injection set to true [True]
68 | --noise_level variance of the injected noise [1.]
69 | --noise_decay_freq frequency (in epochs) at which the variance of the injected noise is decayed by a factor "noise_decay" [1000]
70 | --noise_decay factor for decreasing the variance of the injected noise [0.5]
71 | 
72 | # Hardware parameters
73 | --device gpu device, set -1 for cpu [0]
74 | --dtype precision: single: float32 or double: float64 ['float32']
75 | 
76 | # Reproducibility parameters
77 | --seed seed for the random number generator on pytorch [1]
78 | --log_dir log directory ['']
79 | --log_name log name ['mmd']
80 | --log_in_file to log output on a file [False]
81 | 
82 | # Network parameters
83 | --bias set to include bias in the network parameters [False]
84 | --teacher_net teacher network ['OneHidden']
85 | --student_net student network ['NoisyOneHidden']
86 | --d_int dim input data [50]
87 | --d_out dim out feature [1]
88 | --H num of hidden units in the teacher network [3]
89 | --num_particles num_particles*H = number of hidden units in the student network [1000]
90 | 
91 | # Initialization parameters
92 | --mean_student mean initial value for the student weights [0.001]
93 | --std_student std initial value for the student weights [1.]
94 | --mean_teacher mean initial value for the teacher weights [0.]
95 | --std_teacher std initial value for the teacher weights [1.]
96 | 
97 | # Data parameters
98 | --input_data input data distribution ['Spherical']
99 | --N_train num samples for training [1000]
100 | --N_valid num samples for validation [1000]
101 | 
102 | --config config file for non default parameters ['']
103 | 
104 | ```
105 | 
106 | ## Reference
107 | 
108 | If using this code for research purposes, please cite:
109 | 
110 | [1] M. Arbel, A. Korba, A. Salim, A. Gretton. [*Maximum Mean Discrepancy Gradient Flow*](https://arxiv.org/abs/1906.04370), NeurIPS 2019.
111 | 
112 | ```
113 | @inproceedings{Arbel:2019,
114 | author = {Michael Arbel and Anna Korba and Adil Salim and Arthur Gretton},
115 | title = {Maximum Mean Discrepancy Gradient Flow},
116 | booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
117 | year = {2019},
118 | url = {https://arxiv.org/abs/1906.04370},
119 | }
120 | ```
121 | 
122 | 
123 | ## License
124 | 
125 | This code is under a BSD license.
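
The command line script above is the intended entry point, but the building blocks can also be wired together directly. The snippet below is a minimal sketch (it is not a file from this repository) of what `trainer.py` sets up for the student-teacher experiment, using the default sizes from the documentation above; the variable names, CPU device and batch settings are illustrative, and the Gaussian weight initialisation applied by the trainer is omitted for brevity.

```
# Minimal sketch: wire networks.py and losses.py together the way trainer.py does.
import torch as tr
from torch.utils import data

from losses import MMD
from networks import NoisyOneHiddenLayer, OneHiddenLayer, SphericalTeacher

d_int, H, d_out, n_particles = 50, 3, 1, 1000                # defaults documented above

teacher = OneHiddenLayer(d_int, H, d_out)                    # fixed target network
student = NoisyOneHiddenLayer(d_int, H, d_out, n_particles)  # particle parametrisation

dataset = SphericalTeacher(teacher, 1000, tr.float32, 'cpu') # inputs on the unit sphere
loader = data.DataLoader(dataset, batch_size=100, shuffle=True)

loss_fn = MMD(student, with_noise=True)                      # noise-injection MMD loss
optimizer = tr.optim.SGD(student.parameters(), lr=0.1)

for x, y in loader:                                          # one epoch of the flow
    student.zero_grad()
    loss = loss_fn(x, y)
    loss.backward()
    optimizer.step()
```

Switching to the `mmd_diffusion` or `sobolev` losses only changes the loss module (`MMD_Diffusion(student)` or `Sobolev(student)`), mirroring `Trainer.get_loss` in `trainer.py`.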
126 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import torch as tr 2 | import torch.nn as nn 3 | from torch import autograd 4 | 5 | 6 | class mmd2_noise_injection(autograd.Function): 7 | 8 | @staticmethod 9 | def forward(ctx,true_feature,fake_feature,noisy_feature): 10 | b_size,d, n_particles = noisy_feature.shape 11 | with tr.enable_grad(): 12 | 13 | mmd2 = tr.mean((true_feature-fake_feature)**2) 14 | mean_noisy_feature = tr.mean(noisy_feature,dim = -1 ) 15 | 16 | mmd2_for_grad = (n_particles/b_size)*(tr.einsum('nd,nd->',fake_feature,mean_noisy_feature) - tr.einsum('nd,nd->',true_feature,mean_noisy_feature)) 17 | 18 | ctx.save_for_backward(mmd2_for_grad,noisy_feature) 19 | 20 | return mmd2 21 | 22 | @staticmethod 23 | def backward(ctx, grad_output): 24 | mmd2_for_grad, noisy_feature = ctx.saved_tensors 25 | with tr.enable_grad(): 26 | gradients = autograd.grad(outputs=mmd2_for_grad, inputs=noisy_feature, 27 | grad_outputs=grad_output, 28 | create_graph=True, only_inputs=True)[0] 29 | 30 | return None, None, gradients 31 | 32 | 33 | class mmd2_func(autograd.Function): 34 | 35 | @staticmethod 36 | def forward(ctx,true_feature,fake_feature): 37 | 38 | b_size,d, n_particles = fake_feature.shape 39 | 40 | with tr.enable_grad(): 41 | 42 | mmd2 = (n_particles/b_size)*tr.sum((true_feature-tr.mean(fake_feature,dim=-1))**2) 43 | 44 | ctx.save_for_backward(mmd2,fake_feature) 45 | 46 | return (1./n_particles)*mmd2 47 | 48 | @staticmethod 49 | def backward(ctx, grad_output): 50 | 51 | mmd2, fake_feature = ctx.saved_tensors 52 | with tr.enable_grad(): 53 | gradients = autograd.grad(outputs=mmd2, inputs=fake_feature, 54 | grad_outputs=grad_output, 55 | create_graph=True, only_inputs=True)[0] 56 | 57 | return None, gradients 58 | 59 | 60 | class sobolev(autograd.Function): 61 | @staticmethod 62 | def forward(ctx,true_feature,fake_feature,matrix): 63 | 64 | b_size,_, n_particles = fake_feature.shape 65 | 66 | m = tr.mean(fake_feature,dim=-1) - true_feature 67 | 68 | alpha = tr.solve(m,matrix)[0].clone().detach() 69 | 70 | with tr.enable_grad(): 71 | 72 | mmd2 = (0.5*n_particles/b_size)*tr.sum((true_feature-tr.mean(fake_feature,dim=-1))**2) 73 | mmd2_for_grad = (1./b_size)*tr.einsum('id,idm->',alpha,fake_feature) 74 | 75 | ctx.save_for_backward(mmd2_for_grad,fake_feature) 76 | 77 | return (1./n_particles)*mmd2 78 | 79 | @staticmethod 80 | def backward(ctx, grad_output): 81 | mmd2, fake_feature = ctx.saved_tensors 82 | with tr.enable_grad(): 83 | gradients = autograd.grad(outputs=mmd2, inputs=fake_feature, 84 | grad_outputs=grad_output, 85 | create_graph=True, only_inputs=True)[0] 86 | 87 | return None, gradients,None 88 | 89 | 90 | class MMD(nn.Module): 91 | def __init__(self,student,with_noise): 92 | super(MMD, self).__init__() 93 | self.student = student 94 | self.mmd2 = mmd2_noise_injection.apply 95 | self.with_noise=with_noise 96 | def forward(self,x,y): 97 | if self.with_noise: 98 | out = tr.mean(self.student(x),dim = -1).clone().detach() 99 | self.student.set_noisy_mode(True) 100 | noisy_out = self.student(x) 101 | loss = 0.5*self.mmd2(y,out,noisy_out) 102 | else: 103 | fake_feature = tr.mean(self.student(x),dim=-1) 104 | loss = 0.5*tr.mean((y-fake_feature)**2) 105 | return loss 106 | 107 | class MMD_Diffusion(nn.Module): 108 | def __init__(self,student): 109 | super(MMD_Diffusion, self).__init__() 110 | self.student = student 111 | self.mmd2 = mmd2_func.apply 112 | 
def forward(self,x,y): 113 | self.student.add_noise() 114 | noisy_out = self.student(x) 115 | 116 | loss = 0.5*self.mmd2(y,noisy_out) 117 | return loss 118 | 119 | class Sobolev(nn.Module): 120 | def __init__(self,student): 121 | super(Sobolev, self).__init__() 122 | self.student = student 123 | self.sobolev = sobolev.apply 124 | self.lmbda = 1e-6 125 | def forward(self,x,y): 126 | self.student.zero_grad() 127 | out = self.student(x) 128 | b_size,_,num_particles = out.shape 129 | grad_out = compute_grad(self.student,x) 130 | matrix = (1./(num_particles*b_size))*tr.einsum('im,jm->ij',grad_out,grad_out)+self.lmbda*tr.eye(b_size, dtype= x.dtype, device=x.device) 131 | matrix = matrix.clone().detach() 132 | loss = self.sobolev(y,out,matrix) 133 | return loss 134 | 135 | def compute_grad(net,x): 136 | J = [] 137 | F = net(x) 138 | F = tr.einsum('idm->i',F) 139 | b_size = F.shape[0] 140 | for i in range(b_size): 141 | if i==b_size-1: 142 | grads = autograd.grad(F[i], net.parameters(),retain_graph=False) 143 | else: 144 | grads = autograd.grad(F[i], net.parameters(),retain_graph=True) 145 | grads = [x.view(-1) for x in grads] 146 | grads = tr.cat(grads) 147 | J.append(grads) 148 | 149 | return tr.stack(J,dim=0) 150 | 151 | -------------------------------------------------------------------------------- /networks.py: -------------------------------------------------------------------------------- 1 | import torch as tr 2 | from torch.autograd import Variable 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | class quadexp(nn.Module): 6 | def __init__(self, sigma = 2.): 7 | super(quadexp,self).__init__() 8 | self.sigma = sigma 9 | def forward(self,x): 10 | return tr.exp(-x**2/(self.sigma**2)) 11 | 12 | class NoisyLinear(nn.Linear): 13 | def __init__(self, in_features, out_features, noise_level=1., noise_decay = 0.1, bias=False): 14 | super(NoisyLinear, self).__init__(in_features, out_features, bias=bias) 15 | self.noise_level = noise_level 16 | self.register_buffer("epsilon_weight", tr.zeros(out_features, in_features)) 17 | if bias: 18 | self.register_buffer("epsilon_bias", tr.zeros(out_features)) 19 | self.noisy_mode = False 20 | self.noise_decay = noise_decay 21 | 22 | def update_noise_level(self): 23 | self.noise_level = self.noise_decay * self.noise_level 24 | def set_noisy_mode(self,is_noisy): 25 | self.noisy_mode = is_noisy 26 | 27 | def forward(self, input): 28 | if self.noisy_mode: 29 | tr.randn(self.epsilon_weight.size(), out=self.epsilon_weight) 30 | bias = self.bias 31 | if bias is not None: 32 | tr.randn(self.epsilon_bias.size(), out=self.epsilon_bias) 33 | bias = bias + self.noise_level * Variable(self.epsilon_bias, requires_grad = False) 34 | self.noisy_mode = False 35 | return F.linear(input, self.weight + self.noise_level * Variable(self.epsilon_weight, requires_grad=False), bias) 36 | else: 37 | return F.linear(input, self.weight , self.bias) 38 | def add_noise(self): 39 | tr.randn(self.epsilon_weight.size(), out=self.epsilon_weight) 40 | self.weight.data += self.noise_level * Variable(self.epsilon_weight, requires_grad=False) 41 | bias = self.bias 42 | if bias is not None: 43 | tr.randn(self.epsilon_bias.size(), out=self.epsilon_bias) 44 | self.bias.data += self.noise_level * Variable(self.epsilon_bias, requires_grad = False) 45 | 46 | class OneHiddenLayer(nn.Module): 47 | def __init__(self,d_int, H, d_out,non_linearity = quadexp(),bias=False): 48 | super(OneHiddenLayer,self).__init__() 49 | self.linear1 = tr.nn.Linear(d_int, H,bias=bias) 50 | self.linear2 = 
tr.nn.Linear(H, d_out,bias=bias)
51 |         self.non_linearity = non_linearity
52 |         self.d_int = d_int
53 |         self.d_out = d_out
54 | 
55 |     def weights_init(self,center, std):
56 |         self.linear1.weights_init(center,std)
57 |         self.linear2.weights_init(center,std)
58 | 
59 | 
60 |     def forward(self, x):
61 |         h1_relu = self.linear1(x).clamp(min=0)
62 |         h2_relu = self.linear2(h1_relu)
63 |         h2_relu = self.non_linearity(h2_relu)
64 | 
65 |         return h2_relu
66 | 
67 | 
68 | class NoisyOneHiddenLayer(nn.Module):
69 |     def __init__(self,d_int, H, d_out, n_particles,non_linearity = quadexp(),noise_level=1., noise_decay = 0.1,bias=False):
70 |         super(NoisyOneHiddenLayer,self).__init__()
71 | 
72 |         self.linear1 = NoisyLinear(d_int, H*n_particles,noise_level = noise_level,noise_decay=noise_decay,bias=bias)
73 |         self.linear2 = NoisyLinear(H*n_particles, n_particles*d_out,noise_level = noise_level,noise_decay=noise_decay,bias= bias)
74 | 
75 |         self.non_linearity = non_linearity
76 |         self.n_particles = n_particles
77 |         self.d_out = d_out
78 | 
79 |     def set_noisy_mode(self,is_noisy):
80 |         self.linear1.set_noisy_mode(is_noisy)
81 |         self.linear2.set_noisy_mode(is_noisy)
82 | 
83 |     def update_noise_level(self):
84 |         self.linear1.update_noise_level()
85 |         self.linear2.update_noise_level()
86 | 
87 |     def weights_init(self,center, std):
88 |         self.linear1.weights_init(center,std)
89 |         self.linear2.weights_init(center,std)
90 | 
91 |     def forward(self, x):
92 |         h1_relu = self.linear1(x).clamp(min=0)
93 |         h2_relu = self.linear2(h1_relu)
94 |         h2_relu = h2_relu.view(-1,self.d_out, self.n_particles)
95 |         h2_relu = self.non_linearity(h2_relu)
96 | 
97 |         return h2_relu
98 |     def add_noise(self):
99 |         self.linear1.add_noise()
100 |         self.linear2.add_noise()
101 | 
102 | class SphericalTeacher(tr.utils.data.Dataset):
103 | 
104 |     def __init__(self,network, N_samples, dtype, device):
105 |         D = network.d_int
106 |         self.device = device
107 |         self.source = tr.distributions.multivariate_normal.MultivariateNormal(tr.zeros(D ,dtype=dtype,device=device), tr.eye(D,dtype=dtype,device=device))
108 |         source_samples = self.source.sample([N_samples])
109 |         inv_norm = 1./tr.norm(source_samples,dim=1)
110 |         self.X = tr.einsum('nd,n->nd',source_samples,inv_norm)
111 |         self.total_size = N_samples
112 |         self.network = network
113 | 
114 |         with tr.no_grad():
115 |             self.Y = self.network(self.X)
116 | 
117 |     def __len__(self):
118 |         return self.total_size
119 |     def __getitem__(self,index):
120 |         return self.X[index,:],self.Y[index,:]
121 | 
122 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | python==3.6.2
2 | torch==1.2.0
3 | torchvision==0.4.0
4 | numpy==1.17.2
5 | pyyaml
6 | tensorboardX==1.8
--------------------------------------------------------------------------------
/train_student_teacher.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import yaml
3 | import argparse
4 | 
5 | from trainer import Trainer
6 | 
7 | torch.backends.cudnn.benchmark = True
8 | 
9 | def make_flags(args,config_file):
10 |     if config_file:
11 |         config = yaml.safe_load(open(config_file))
12 |         dic = vars(args)
13 |         all(map(dic.pop, config))
14 |         dic.update(config)
15 |     return args
16 | 
17 | 
18 | 
19 | parser = argparse.ArgumentParser(description='Student-teacher training with MMD gradient flows')
20 | 
21 | # Optimizer parameters
22 | parser.add_argument('--lr', default=.1, type=float, help='learning rate')
23 | parser.add_argument('--batch_size', default = 100 ,type= int, help='batch size')
24 | parser.add_argument('--total_epochs', default=10000, type=int, help='total number of epochs')
25 | parser.add_argument('--optimizer', default = 'SGD' ,type= str, help='Optimizer')
26 | parser.add_argument('--use_scheduler', action='store_true', help=' By default uses the ReduceLROnPlateau scheduler ')
27 | 
28 | # Loss parameters
29 | parser.add_argument('--loss', default = 'mmd_noise_injection',type= str, help='loss to optimize: mmd_noise_injection, mmd_diffusion, sobolev')
30 | parser.add_argument('--with_noise', default = True ,type= bool, help='to use noise injection set to true')
31 | parser.add_argument('--noise_level', default = 1. ,type= float, help=' variance of the injected noise ')
32 | parser.add_argument('--noise_decay_freq', default = 1000 ,type= int, help='decays the variance of the injected noise every noise_decay_freq epochs by a factor "noise_decay"')
33 | parser.add_argument('--noise_decay', default = 0.5 ,type= float, help='factor for decreasing the variance of the injected noise')
34 | 
35 | # Hardware parameters
36 | parser.add_argument('--device', default = 0 ,type= int, help='gpu device, set -1 for cpu')
37 | parser.add_argument('--dtype', default = 'float32' ,type= str, help='precision: single: float32 or double: float64')
38 | 
39 | # Reproducibility parameters
40 | parser.add_argument('--seed', default = 1 ,type= int, help='seed for the random number generator on pytorch')
41 | parser.add_argument('--log_dir', default = '',type= str, help='log directory ')
42 | parser.add_argument('--log_name', default = 'mmd',type= str, help='log name')
43 | parser.add_argument('--log_in_file', action='store_true', help='to log output on a file')
44 | 
45 | # Network parameters
46 | parser.add_argument('--bias', action='store_true', help='set to include bias in the network parameters')
47 | parser.add_argument('--teacher_net', default = 'OneHidden' ,type= str, help='teacher network')
48 | parser.add_argument('--student_net', default = 'NoisyOneHidden' ,type= str, help='student network')
49 | parser.add_argument('--d_int', default = 50 ,type= int, help='dim input data')
50 | parser.add_argument('--d_out', default = 1 ,type= int, help='dim out feature')
51 | parser.add_argument('--H', default = 3 ,type= int, help='num of hidden units in the teacher network')
52 | parser.add_argument('--num_particles', default = 1000 ,type= int, help='num_particles*H = number of hidden units in the student network ')
53 | 
54 | # Initialization parameters
55 | parser.add_argument('--mean_student', default = 0.001 ,type= float, help='mean initial value for the student weights')
56 | parser.add_argument('--std_student', default = 1. ,type= float, help='std initial value for the student weights')
57 | parser.add_argument('--mean_teacher', default = 0. ,type= float, help='mean initial value for the teacher weights')
58 | parser.add_argument('--std_teacher', default = 1.
,type= float, help='std initial value for the teacher weights') 59 | 60 | # Data parameters 61 | parser.add_argument('--input_data', default = 'Spherical' ,type= str, help='input data distribution') 62 | parser.add_argument('--N_train', default = 1000 ,type= int, help='num samples for training') 63 | parser.add_argument('--N_valid', default = 1000 ,type= int, help='num samples for validation') 64 | 65 | parser.add_argument('--config', default = '' ,type= str, help='config file for non default parameters') 66 | 67 | args = parser.parse_args() 68 | args = make_flags(args,args.config) 69 | 70 | 71 | exp = Trainer(args) 72 | exp.train() 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | from torch.utils import data 7 | 8 | 9 | 10 | import numpy as np 11 | from functools import partial 12 | import sys,os,time,itertools 13 | 14 | from networks import * 15 | from losses import * 16 | 17 | class Trainer(object): 18 | def __init__(self,args): 19 | self.args = args 20 | self.device = 'cuda:'+str(args.device) if torch.cuda.is_available() and args.device>-1 else 'cpu' 21 | 22 | 23 | self.dtype = get_dtype(args) 24 | self.log_dir = os.path.join(args.log_dir, args.log_name+'_loss_' + args.loss +'_noise_level_'+str(args.noise_level) ) 25 | 26 | if not os.path.isdir(self.log_dir): 27 | os.mkdir(self.log_dir) 28 | 29 | if args.log_in_file: 30 | self.log_file = open(os.path.join(self.log_dir, 'log.txt'), 'w', buffering=1) 31 | sys.stdout = self.log_file 32 | print('==> Building model..') 33 | self.build_model() 34 | 35 | 36 | def build_model(self): 37 | torch.manual_seed(self.args.seed) 38 | if not self.args.with_noise: 39 | self.args.noise_level = 0. 
40 | self.teacherNet = get_net(self.args,self.dtype,self.device,'teacher') 41 | self.student = get_net(self.args,self.dtype,self.device,'student') 42 | self.data_train = get_data_gen(self.teacherNet,self.args,self.dtype,self.device) 43 | self.data_valid = get_data_gen(self.teacherNet,self.args,self.dtype,self.device) 44 | 45 | self.loss = self.get_loss() 46 | 47 | self.optimizer = self.get_optimizer(self.args.lr) 48 | self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, 'min',patience = 50,verbose=True, factor = 0.9) 49 | #self.get_reg_dist() 50 | 51 | def get_loss(self): 52 | if self.args.loss=='mmd_noise_injection': 53 | return MMD(self.student,self.args.with_noise) 54 | elif self.args.loss=='mmd_diffusion': 55 | return MMD_Diffusion(self.student) 56 | elif self.args.loss=='sobolev': 57 | return Sobolev(self.student) 58 | def get_optimizer(self,lr): 59 | if self.args.optimizer=='SGD': 60 | return optim.SGD(self.student.parameters(), lr=lr) 61 | 62 | def init_student(self,mean,std): 63 | weights_init_student = partial(weights_init,{'mean':mean,'std':std}) 64 | self.student.apply(weights_init_student) 65 | 66 | def train(self,start_epoch=0,total_iters=0): 67 | print("Starting Training Loop...") 68 | start_time = time.time() 69 | best_valid_loss = np.inf 70 | for epoch in range(start_epoch, start_epoch+self.args.total_epochs): 71 | total_iters,train_loss = train_epoch(epoch,total_iters,self.loss,self.data_train,self.optimizer,'train', device=self.device) 72 | total_iters,valid_loss = train_epoch(epoch, total_iters, self.loss,self.data_valid,self.optimizer,'valid', device=self.device) 73 | if not np.isfinite(train_loss): 74 | break 75 | 76 | if valid_loss < best_valid_loss: 77 | best_valid_loss = valid_loss 78 | if self.args.use_scheduler: 79 | self.scheduler.step(train_loss) 80 | if np.mod(epoch,self.args.noise_decay_freq)==0 and epoch>0: 81 | self.loss.student.update_noise_level() 82 | if np.mod(epoch,10)==0: 83 | new_time = time.time() 84 | 85 | start_time = new_time 86 | return train_loss,valid_loss,best_valid_loss 87 | 88 | 89 | def get_data_gen(net,args,dtype,device): 90 | params = {'batch_size': args.batch_size, 91 | 'shuffle': True, 92 | 'num_workers': 0} 93 | if args.input_data=='Spherical': 94 | teacher = SphericalTeacher(net,args.N_train,dtype,device) 95 | return data.DataLoader(teacher, **params) 96 | 97 | def get_net(args,dtype,device,net_type): 98 | non_linearity = quadexp() 99 | if net_type=='teacher': 100 | weights_init_net = partial(weights_init,{'mean':args.mean_teacher,'std':args.std_teacher}) 101 | if args.teacher_net=='OneHidden': 102 | Net = OneHiddenLayer(args.d_int,args.H,args.d_out,non_linearity = non_linearity,bias=args.bias) 103 | if net_type=='student': 104 | weights_init_net = partial(weights_init,{'mean':args.mean_student,'std':args.std_student}) 105 | if args.student_net=='NoisyOneHidden': 106 | Net = NoisyOneHiddenLayer(args.d_int, args.H, args.d_out, args.num_particles, non_linearity = non_linearity, noise_level = args.noise_level,noise_decay=args.noise_decay,bias=args.bias) 107 | 108 | Net.to(device) 109 | if args.dtype=='float64': 110 | Net.double() 111 | 112 | Net.apply(weights_init_net) 113 | return Net 114 | 115 | def get_dtype(args): 116 | if args.dtype=='float32': 117 | return torch.float32 118 | else: 119 | return torch.float64 120 | 121 | 122 | def weights_init(args,m): 123 | if isinstance(m, nn.Linear): 124 | m.weight.data.normal_(mean=args['mean'],std=args['std']) 125 | if m.bias: 126 | 
m.bias.data.normal_(mean=args['mean'],std=args['std']) 127 | 128 | def train_epoch(epoch,total_iters,Loss,data_loader, optimizer,phase, device="cuda"): 129 | 130 | # Training Loop 131 | # Lists to keep track of progress 132 | 133 | if phase == 'train': 134 | Loss.student.train(True) # Set model to training mode 135 | else: 136 | Loss.student.train(False) # Set model to evaluate mode 137 | 138 | cum_loss = 0 139 | # For each epoch 140 | 141 | # For each batch in the dataloader 142 | for batch_idx, (inputs, targets) in enumerate(data_loader): 143 | if phase=="train": 144 | total_iters += 1 145 | Loss.student.zero_grad() 146 | loss = Loss(inputs, targets) 147 | # Calculate the gradients for this batch 148 | loss.backward() 149 | optimizer.step() 150 | loss = loss.item() 151 | cum_loss += loss 152 | 153 | elif phase=='valid': 154 | loss = Loss(inputs, targets).item() 155 | cum_loss += loss 156 | total_loss = cum_loss/(batch_idx+1) 157 | if np.mod(epoch,10)==0: 158 | 159 | print('Epoch: '+ str(epoch) + ' | ' + phase + ' loss: ' + str(round(total_loss,3)) ) 160 | return total_iters, total_loss 161 | --------------------------------------------------------------------------------
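
A closing note on the custom `autograd.Function` classes in `losses.py` (`mmd2_noise_injection`, `mmd2_func`, `sobolev`): they all follow the same surrogate-gradient pattern, in which `forward` returns the value to be reported while `backward` differentiates a separate surrogate expression that was built under `torch.enable_grad()` and saved in the context. The toy example below is a self-contained sketch of that pattern only; the class name, the reported value and the surrogate are made up for illustration and do not appear in the repository.

```
# Standalone sketch of the surrogate-gradient pattern used in losses.py.
import torch as tr
from torch import autograd

class surrogate_loss(autograd.Function):

    @staticmethod
    def forward(ctx, w):
        # forward() runs with autograd disabled, so the graph for the
        # surrogate expression has to be re-enabled explicitly.
        with tr.enable_grad():
            reported = (w ** 2).mean()   # value returned to the caller
            surrogate = (w ** 3).sum()   # expression whose gradient is propagated
        ctx.save_for_backward(surrogate, w)
        return reported

    @staticmethod
    def backward(ctx, grad_output):
        surrogate, w = ctx.saved_tensors
        with tr.enable_grad():
            grad_w = autograd.grad(outputs=surrogate, inputs=w,
                                   grad_outputs=grad_output,
                                   create_graph=True, only_inputs=True)[0]
        return grad_w

w = tr.randn(5, requires_grad=True)
loss = surrogate_loss.apply(w)
loss.backward()
# w.grad now holds d(surrogate)/dw = 3*w**2, not the gradient of `loss` itself.
```

In `mmd2_noise_injection`, for instance, the reported value is the squared distance between teacher and student features, while the surrogate pairs the (detached) teacher and mean student outputs with the noise-perturbed student output, so that, roughly, the resulting parameter update evaluates the gradient at noisy particles as in the noise-injection scheme of the paper.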