├── LICENSE ├── README.md ├── compareTrace.py ├── config.py ├── data └── miniboone │ └── data.npy ├── datasets ├── __init__.py ├── bsds300.py ├── gas.py ├── hepmass.py ├── miniboone.py ├── mnist.py └── power.py ├── detailedSetup.md ├── evaluateLargeOTflow.py ├── evaluateToyOTflow.py ├── experiments └── cnf │ ├── large │ └── pretrained │ │ ├── pretrained_interp_mnist_checkpt.pth │ │ └── pretrained_miniboone_checkpt.pth │ └── toy │ └── pretrained │ └── pretrained_swissroll_alph30_15_m32_checkpt.pth ├── interpMnist.py ├── lib ├── dataloader.py ├── toy_data.py ├── transform.py └── utils.py ├── requirements.txt ├── src ├── Autoencoder.py ├── OTFlowProblem.py ├── Phi.py ├── PhiHC.py ├── mmd.py ├── plotTraceComparison.py └── plotter.py ├── test ├── gradTestOTFlowProblem.py ├── gradTestTrHess.py ├── testPhiGradx.py └── testPhiOpt.py ├── trainLargeOTflow.py ├── trainMnistOTflow.py └── trainToyOTflow.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 EmoryMLIP 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # OT-Flow 2 | Pytorch implementation of our continuous normalizing flows regularized with optimal transport. 3 | 4 | ## Associated Publication 5 | 6 | OT-Flow: Fast and Accurate Continuous Normalizing Flows via Optimal Transport 7 | 8 | Paper: https://ojs.aaai.org/index.php/AAAI/article/view/17113 9 | 10 | Supplemental: https://arxiv.org/abs/2006.00104 11 | 12 | Please cite as 13 | 14 | @inproceedings{onken2021otflow, 15 | title={{OT-Flow}: Fast and Accurate Continuous Normalizing Flows via Optimal Transport}, 16 | author={Derek Onken and Samy Wu Fung and Xingjian Li and Lars Ruthotto}, 17 | volume={35}, 18 | number={10}, 19 | booktitle={AAAI Conference on Artificial Intelligence}, 20 | year={2021}, 21 | month={May}, 22 | pages={9223--9232}, 23 | url={https://ojs.aaai.org/index.php/AAAI/article/view/17113}, 24 | } 25 | 26 | ## Set-up 27 | 28 | Install all the requirements: 29 | ``` 30 | pip install -r requirements.txt 31 | ``` 32 | 33 | For the large data sets, you'll need to download the preprocessed data from Papamakarios's MAF paper found at https://zenodo.org/record/1161203#.XbiVGUVKhgi. Place the data in the data folder. We've done miniboone for you since it's small (and provide a pre-trained miniboone model). 34 | 35 | To run some files (e.g. the tests), you may need to add them to the path via 36 | ``` 37 | export PYTHONPATH="${PYTHONPATH}:." 38 | ``` 39 | 40 | A more in-depth setup is provided in [detailedSetup.md](detailedSetup.md). 
41 | 42 | ## Trace Comparison 43 | 44 | Compare our trace with the AD estimation of the trace 45 | ``` 46 | python compareTrace.py 47 | ``` 48 | 49 | For Figure 2, we averaged over 20 runs with the following results 50 | ``` 51 | python src/plotTraceComparison.py 52 | ``` 53 | 54 | 55 | 56 | ## Toy problems 57 | 58 | Train a toy example 59 | ``` 60 | python trainToyOTflow.py 61 | ``` 62 | 63 | Plot results of a pre-trained example 64 | ``` 65 | python evaluateToyOTflow.py 66 | ``` 67 | 68 | 69 | ## Large CNFs 70 | 71 | ``` 72 | python trainLargeOTflow.py 73 | ``` 74 | 75 | Evaluate a pre-trained model 76 | ``` 77 | python evaluateLargeOTflow.py 78 | ``` 79 | 80 | 81 | 82 | #### Hyperparameters 83 | Train and Evaluate using our hyperparameters ([see detailedSetup.md](detailedSetup.md)) 84 | 85 | | Data set | Train Time Steps | Val Time Steps | Batch Size | Hidden Dim | alpha on C term | alpha on R term | Test Time Steps | Test Batch Size | 86 | |------------------- |----------------- |--------------- |----------- |----------- |---------------- |---------------- |---------------- |---------------- | 87 | | Power | 10 | 22 | 10,000 | 128 | 500 | 5 | 24 | 120,000 | 88 | | Gas | 10 | 24 | 2,000 | 350 | 1,200 | 40 | 30 | 55,000 | 89 | | Hepmass | 12 | 24 | 2,000 | 256 | 500 | 40 | 24 | 50,000 | 90 | | Miniboone | 6 | 10 | 2,000 | 256 | 100 | 15 | 18 | 5,000 | 91 | | BSDS300 | 14 | 30 | 300 | 512 | 2,000 | 800 | 40 | 10,000 | 92 | 93 | 94 | 95 | 96 | ### MNIST 97 | 98 | Train an MNIST model 99 | ``` 100 | python trainMnistOTflow.py 101 | ``` 102 | 103 | Run a pre-trained MNIST 104 | ``` 105 | python interpMnist.py 106 | ``` 107 | 108 | ## Acknowledgements 109 | 110 | This material is in part based upon work supported by the US National Science Foundation Grant DMS-1751636, the US AFOSR Grants 20RT0237 and FA9550-18-1-0167, AFOSR 111 | MURI FA9550-18-1-050, and ONR Grant No. N00014-18-1- 112 | 2527. 
Any opinions, findings, and conclusions or recommendations expressed in this material are those of the author(s) and do not necessarily reflect the views of the funding agencies. 113 | 114 | 115 | 116 | 117 | -------------------------------------------------------------------------------- /compareTrace.py: -------------------------------------------------------------------------------- 1 | # compareTrace.py 2 | # compare the exact trace in Phi with the hutchinsons estimator using atomatic differentiation 3 | 4 | import math 5 | from src.OTFlowProblem import * 6 | 7 | gpu = 0 8 | if not torch.cuda.is_available(): 9 | print("No gpu found. If you wish to run on a CPU, remove the cuda specific lines, then run again.") 10 | exit(1) 11 | 12 | 13 | device = torch.device('cuda:' + str(gpu) if torch.cuda.is_available() else 'cpu') 14 | 15 | # ---------------------------------------------------------------------------------------------------------------------- 16 | # compare timings with AD 17 | # ---------------------------------------------------------------------------------------------------------------------- 18 | 19 | def compareTrace(domain,d, seed=0): 20 | """ 21 | domain: list of integers specificying the number of hutchinson vectors to use 22 | d: dimensionality of the problem 23 | :param domain: 24 | :return: 25 | """ 26 | 27 | start = torch.cuda.Event(enable_timing=True) 28 | end = torch.cuda.Event(enable_timing=True) 29 | 30 | torch.manual_seed(seed) # for reproducibility 31 | 32 | # set up model 33 | m = 64 34 | alph = [1.0,1.0,1.0,1.0,1.0] 35 | nTh = 2 36 | net = Phi(nTh=nTh, m=m, d=d, alph=alph) 37 | net = net.to(device) 38 | 39 | n_samples = 512 40 | x = torch.randn(n_samples, d+1).to(device) 41 | x.requires_grad = True 42 | 43 | 44 | # dry-run / warm-up 45 | start.record() 46 | a = torch.Tensor(2000,3000) 47 | b = torch.Tensor(3000,4000) 48 | c = torch.mm(a,b) 49 | end.record() 50 | torch.cuda.synchronize() 51 | _ = start.elapsed_time(end) / 1000.0 # 
convert to seconds 52 | 53 | 54 | # --------------------------------------------- 55 | # time the exact trace 56 | # --------------------------------------------- 57 | 58 | start.record() 59 | grad, exact_trace = net.trHess(x) 60 | end.record() 61 | torch.cuda.synchronize() 62 | exact_time = start.elapsed_time(end) / 1000.0 # convert to seconds 63 | print("Exact Trace Computation time= {:9.6f}".format(exact_time)) 64 | 65 | # --------------------------------------------- 66 | # time hutchinson's estimator using AD 67 | # compute an estimate for each value in domain 68 | # aka domain=[1,10,20] will run an estiamte with 1 hutch vector, one with 10 hutch vectors, and one with 20 hutch vectors 69 | # --------------------------------------------- 70 | 71 | # where to hold results 72 | resTime = torch.zeros(1,len(domain)) 73 | resErr = torch.zeros(1,len(domain)) 74 | 75 | for iDomain, num_hutchinsons in enumerate(domain): 76 | torch.manual_seed(seed+1) 77 | trace_acc = torch.zeros(n_samples).to(device) # accumulated trace 78 | 79 | 80 | # create the num_hutchinsons rademacher vectors...these "vectors" are each stored as a matrix 81 | # we have num_hutchinsons of them, so that makes a tensor called rad 82 | # compute vector-Jacobian Product using AD with the rademacher vector 83 | 84 | 85 | start.record() 86 | rad = (1 / math.sqrt(num_hutchinsons)) * ((torch.rand(n_samples, d+1, num_hutchinsons,device=device) < 0.5).float() * 2 - 1) 87 | rad[:,d,:] = 0 # set time position to 0, leave space values as rademacher 88 | # rad = rad.to(device) 89 | for i in range(num_hutchinsons): 90 | e = rad[:,:,i] # the "random vector" 91 | grad = net.trHess(x, justGrad=True) 92 | trace_est = torch.autograd.grad(outputs=grad, inputs=x, create_graph=False,retain_graph=False, grad_outputs=e)[0] 93 | trace_est = trace_est * e 94 | trace_est = trace_est.view(grad.shape[0], -1).sum(dim=1) 95 | trace_acc += trace_est 96 | end.record() 97 | torch.cuda.synchronize() 98 | ad_time = 
start.elapsed_time(end) / 1000.0 # convert to seconds 99 | 100 | trace_error = torch.norm(exact_trace-trace_acc)/torch.norm(exact_trace) # compute error 101 | print("{:4d} hutchinson vectors. time= {:9.6f} , rel. error = {:9.7f}".format(num_hutchinsons, ad_time, trace_error )) 102 | resTime[0, iDomain] = ad_time 103 | resErr[0, iDomain] = trace_error 104 | 105 | # return timings nad errors for plotting/analysis 106 | return resTime, resErr, exact_time 107 | 108 | if __name__ == '__main__': 109 | 110 | from src.plotTraceComparison import * 111 | 112 | domainMini = [1, 10, 20, 30, 43] 113 | domainBSDS = [1, 10, 20, 30, 40, 50, 63] 114 | domainMNIST = [1, 100, 200, 300, 400, 500, 600, 700, 784] 115 | 116 | nRepeats = 2 # average over 2 runs. For publication figure, we set this to 20 117 | 118 | # arrays to hold all the results...in case we want to use error bounds 119 | resTimeBSDSArray = torch.zeros(nRepeats, len(domainBSDS)) 120 | traceErrorBSDSArray = torch.zeros(nRepeats, len(domainBSDS)) 121 | exactTimingBSDSArray = torch.zeros(nRepeats, 1) 122 | resTimeMiniArray = torch.zeros(nRepeats, len(domainMini)) 123 | traceErrorMiniArray = torch.zeros(nRepeats, len(domainMini)) 124 | exactTimingMiniArray = torch.zeros(nRepeats, 1) 125 | resTimeMNISTArray = torch.zeros(nRepeats, len(domainMNIST)) 126 | traceErrorMNISTArray = torch.zeros(nRepeats, len(domainMNIST)) 127 | exactTimingMNISTArray = torch.zeros(nRepeats, 1) 128 | 129 | for i in range(nRepeats): 130 | print('\n\n ITER ', i) 131 | _ = compareTrace(domainMini, 50) # dry-run 132 | a , b, c = compareTrace(domainBSDS, 63, seed=i); 133 | resTimeBSDSArray[i,:] = a 134 | traceErrorBSDSArray[i,:] = b 135 | exactTimingBSDSArray[i] = c 136 | _ = compareTrace(domainMini, 50) # dry-run 137 | a , b, c = compareTrace(domainMini, 43, seed=i) 138 | resTimeMiniArray[i, :] = a 139 | traceErrorMiniArray[i, :] = b 140 | exactTimingMiniArray[i] = c 141 | _ = compareTrace(domainMini, 50) # dry-run 142 | a , b, c = 
compareTrace(domainMNIST, 784, seed=i) 143 | resTimeMNISTArray[i, :] = a 144 | traceErrorMNISTArray[i, :] = b 145 | exactTimingMNISTArray[i] = c 146 | 147 | 148 | lTimeExact = [ exactTimingMiniArray , exactTimingBSDSArray , exactTimingMNISTArray] 149 | plotTraceCompare(domainMini , domainBSDS , domainMNIST, 150 | resTimeMiniArray , resTimeBSDSArray , resTimeMNISTArray, 151 | traceErrorMiniArray, traceErrorBSDSArray, traceErrorMNISTArray, 152 | lTimeExact, 'image/traceComparison/') 153 | 154 | 155 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # config.py 2 | # simplistic config file to make code platform-agnostic 3 | 4 | def getconfig(): 5 | return ConfigOT() 6 | 7 | class ConfigOT: 8 | """ 9 | gpu - True means GPU available on plaform , False means it's not; this is used for default values 10 | os - 'mac' , 'linux' 11 | """ 12 | gpu = True 13 | os = 'linux' 14 | 15 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /data/miniboone/data.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EmoryMLIP/OT-Flow/4d66618e2a2f4d8e8ce080cf1b3c769c78b2590d/data/miniboone/data.npy -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | root = 'data/' 2 | 3 | from .power import POWER 4 | from .gas import GAS 5 | from .hepmass import HEPMASS 6 | from .miniboone import MINIBOONE 7 | from .bsds300 import BSDS300 8 | 9 | -------------------------------------------------------------------------------- /datasets/bsds300.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import h5py 3 | 4 | import datasets 5 | 6 | 7 | class BSDS300: 8 | """ 9 | A 
dataset of patches from BSDS300. 10 | """ 11 | 12 | class Data: 13 | """ 14 | Constructs the dataset. 15 | """ 16 | 17 | def __init__(self, data): 18 | 19 | self.x = data[:] 20 | self.N = self.x.shape[0] 21 | 22 | def __init__(self): 23 | 24 | # load dataset 25 | f = h5py.File(datasets.root + 'BSDS300/BSDS300.hdf5', 'r') 26 | 27 | self.trn = self.Data(f['train']) 28 | self.val = self.Data(f['validation']) 29 | self.tst = self.Data(f['test']) 30 | 31 | self.n_dims = self.trn.x.shape[1] 32 | self.image_size = [int(np.sqrt(self.n_dims + 1))] * 2 33 | 34 | f.close() 35 | -------------------------------------------------------------------------------- /datasets/gas.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | 4 | import datasets 5 | 6 | 7 | class GAS: 8 | 9 | class Data: 10 | 11 | def __init__(self, data): 12 | 13 | self.x = data.astype(np.float32) 14 | self.N = self.x.shape[0] 15 | 16 | def __init__(self): 17 | 18 | file = datasets.root + 'gas/ethylene_CO.pickle' 19 | trn, val, tst = load_data_and_clean_and_split(file) 20 | 21 | self.trn = self.Data(trn) 22 | self.val = self.Data(val) 23 | self.tst = self.Data(tst) 24 | 25 | self.n_dims = self.trn.x.shape[1] 26 | 27 | 28 | def load_data(file): 29 | 30 | data = pd.read_pickle(file) 31 | # data = pd.read_pickle(file).sample(frac=0.25) 32 | # data.to_pickle(file) 33 | data.drop("Meth", axis=1, inplace=True) 34 | data.drop("Eth", axis=1, inplace=True) 35 | data.drop("Time", axis=1, inplace=True) 36 | return data 37 | 38 | 39 | def get_correlation_numbers(data): 40 | C = data.corr() 41 | A = C > 0.98 42 | # B = A.as_matrix().sum(axis=1) 43 | B = A.values.sum(axis=1) 44 | return B 45 | 46 | 47 | def load_data_and_clean(file): 48 | 49 | data = load_data(file) 50 | B = get_correlation_numbers(data) 51 | 52 | while np.any(B > 1): 53 | col_to_remove = np.where(B > 1)[0][0] 54 | col_name = data.columns[col_to_remove] 55 | 
data.drop(col_name, axis=1, inplace=True) 56 | B = get_correlation_numbers(data) 57 | # print(data.corr()) 58 | data = (data - data.mean()) / data.std() 59 | 60 | return data 61 | 62 | 63 | def load_data_and_clean_and_split(file): 64 | 65 | # data = load_data_and_clean(file).as_matrix() # OUTDATED 66 | data = load_data_and_clean(file).values 67 | N_test = int(0.1 * data.shape[0]) 68 | data_test = data[-N_test:] 69 | data_train = data[0:-N_test] 70 | N_validate = int(0.1 * data_train.shape[0]) 71 | data_validate = data_train[-N_validate:] 72 | data_train = data_train[0:-N_validate] 73 | 74 | return data_train, data_validate, data_test 75 | -------------------------------------------------------------------------------- /datasets/hepmass.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | from collections import Counter 4 | from os.path import join 5 | 6 | import datasets 7 | 8 | 9 | class HEPMASS: 10 | """ 11 | The HEPMASS data set. 12 | http://archive.ics.uci.edu/ml/datasets/HEPMASS 13 | """ 14 | 15 | class Data: 16 | 17 | def __init__(self, data): 18 | 19 | self.x = data.astype(np.float32) 20 | self.N = self.x.shape[0] 21 | 22 | def __init__(self): 23 | 24 | path = datasets.root + 'hepmass/' 25 | trn, val, tst = load_data_no_discrete_normalised_as_array(path) 26 | 27 | self.trn = self.Data(trn) 28 | self.val = self.Data(val) 29 | self.tst = self.Data(tst) 30 | 31 | self.n_dims = self.trn.x.shape[1] 32 | 33 | 34 | def load_data(path): 35 | 36 | data_train = pd.read_csv(filepath_or_buffer=join(path, "1000_train.csv"), index_col=False) 37 | data_test = pd.read_csv(filepath_or_buffer=join(path, "1000_test.csv"), index_col=False) 38 | 39 | return data_train, data_test 40 | 41 | 42 | def load_data_no_discrete(path): 43 | """ 44 | Loads the positive class examples from the first 10 percent of the dataset. 
45 | """ 46 | data_train, data_test = load_data(path) 47 | 48 | # Gets rid of any background noise examples i.e. class label 0. 49 | data_train = data_train[data_train[data_train.columns[0]] == 1] 50 | data_train = data_train.drop(data_train.columns[0], axis=1) 51 | data_test = data_test[data_test[data_test.columns[0]] == 1] 52 | data_test = data_test.drop(data_test.columns[0], axis=1) 53 | # Because the data set is messed up! 54 | data_test = data_test.drop(data_test.columns[-1], axis=1) 55 | 56 | return data_train, data_test 57 | 58 | 59 | def load_data_no_discrete_normalised(path): 60 | 61 | data_train, data_test = load_data_no_discrete(path) 62 | mu = data_train.mean() 63 | s = data_train.std() 64 | data_train = (data_train - mu) / s 65 | data_test = (data_test - mu) / s 66 | 67 | return data_train, data_test 68 | 69 | 70 | def load_data_no_discrete_normalised_as_array(path): 71 | 72 | data_train, data_test = load_data_no_discrete_normalised(path) 73 | data_train, data_test = data_train.values, data_test.values 74 | 75 | i = 0 76 | # Remove any features that have too many re-occurring real values. 
77 | features_to_remove = [] 78 | for feature in data_train.T: 79 | c = Counter(feature) 80 | max_count = np.array([v for k, v in sorted(c.items())])[0] 81 | if max_count > 5: 82 | features_to_remove.append(i) 83 | i += 1 84 | data_train = data_train[:, np.array([i for i in range(data_train.shape[1]) if i not in features_to_remove])] 85 | data_test = data_test[:, np.array([i for i in range(data_test.shape[1]) if i not in features_to_remove])] 86 | 87 | N = data_train.shape[0] 88 | N_validate = int(N * 0.1) 89 | data_validate = data_train[-N_validate:] 90 | data_train = data_train[0:-N_validate] 91 | 92 | return data_train, data_validate, data_test 93 | -------------------------------------------------------------------------------- /datasets/miniboone.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import datasets 4 | 5 | 6 | class MINIBOONE: 7 | 8 | class Data: 9 | 10 | def __init__(self, data): 11 | 12 | self.x = data.astype(np.float32) 13 | self.N = self.x.shape[0] 14 | 15 | def __init__(self): 16 | 17 | file = datasets.root + 'miniboone/data.npy' 18 | trn, val, tst = load_data_normalised(file) 19 | 20 | self.trn = self.Data(trn) 21 | self.val = self.Data(val) 22 | self.tst = self.Data(tst) 23 | 24 | self.n_dims = self.trn.x.shape[1] 25 | 26 | 27 | def load_data(root_path): 28 | # NOTE: To remember how the pre-processing was done. 29 | # data = pd.read_csv(root_path, names=[str(x) for x in range(50)], delim_whitespace=True) 30 | # print data.head() 31 | # data = data.as_matrix() 32 | # # Remove some random outliers 33 | # indices = (data[:, 0] < -100) 34 | # data = data[~indices] 35 | # 36 | # i = 0 37 | # # Remove any features that have too many re-occuring real values. 
38 | # features_to_remove = [] 39 | # for feature in data.T: 40 | # c = Counter(feature) 41 | # max_count = np.array([v for k, v in sorted(c.iteritems())])[0] 42 | # if max_count > 5: 43 | # features_to_remove.append(i) 44 | # i += 1 45 | # data = data[:, np.array([i for i in range(data.shape[1]) if i not in features_to_remove])] 46 | # np.save("~/data/miniboone/data.npy", data) 47 | 48 | data = np.load(root_path) 49 | N_test = int(0.1 * data.shape[0]) 50 | data_test = data[-N_test:] 51 | data = data[0:-N_test] 52 | N_validate = int(0.1 * data.shape[0]) 53 | data_validate = data[-N_validate:] 54 | data_train = data[0:-N_validate] 55 | 56 | return data_train, data_validate, data_test 57 | 58 | 59 | def load_data_normalised(root_path): 60 | 61 | data_train, data_validate, data_test = load_data(root_path) 62 | data = np.vstack((data_train, data_validate)) 63 | mu = data.mean(axis=0) 64 | s = data.std(axis=0) 65 | data_train = (data_train - mu) / s 66 | data_validate = (data_validate - mu) / s 67 | data_test = (data_test - mu) / s 68 | 69 | return data_train, data_validate, data_test 70 | -------------------------------------------------------------------------------- /datasets/mnist.py: -------------------------------------------------------------------------------- 1 | # mnist.py 2 | 3 | from torchvision import transforms, datasets 4 | import numpy as np 5 | import torch 6 | from torch.utils.data.sampler import SubsetRandomSampler 7 | 8 | def getLoader(name, batch, test_batch, augment=False, hasGPU=False, conditional=-1): 9 | 10 | if name == 'mnist': 11 | val_size = 1.0/6.0 12 | random_seed = 0 13 | 14 | # define transforms 15 | normalize = transforms.Normalize((0.1307,), (0.3081,)) # MNIST 16 | 17 | val_transform = transforms.Compose([ 18 | transforms.ToTensor(), 19 | normalize 20 | ]) 21 | if augment: 22 | train_transform = transforms.Compose([ 23 | transforms.RandomCrop(32, padding=4), 24 | transforms.RandomHorizontalFlip(), 25 | transforms.ToTensor(), 26 | 
normalize 27 | ]) 28 | else: 29 | train_transform = transforms.Compose([ 30 | transforms.ToTensor(), 31 | normalize 32 | ]) 33 | 34 | kwargs = {'num_workers': 0, 'pin_memory': True} if hasGPU else {} 35 | 36 | # load the dataset 37 | # from https://gist.github.com/MattKleinsmith/5226a94bad5dd12ed0b871aed98cb123 38 | data = datasets.MNIST(root='../data', train=True, 39 | download=True, transform=train_transform) 40 | 41 | # val_dataset = datasets.MNIST(root='../data', train=True, 42 | # download=True, transform=val_transform) 43 | 44 | test_data = datasets.MNIST(root='../data', 45 | train=False,download=True,transform=val_transform) 46 | 47 | 48 | if conditional >= 0 and conditional <= 9: 49 | idx = data.targets == conditional 50 | data.data = data.data[idx, :] 51 | data.targets = data.targets[idx] 52 | nTot = torch.sum(idx).item() 53 | nTrain = int((5.0 / 6.0) * nTot) 54 | nVal = nTot - nTrain 55 | train_data, valid_data = torch.utils.data.random_split(data, [nTrain, nVal]) 56 | 57 | idx = test_data.targets == conditional 58 | test_data.data = test_data.data[idx,:] 59 | test_data.targets = test_data.targets[idx] 60 | else: 61 | train_data, valid_data = torch.utils.data.random_split(data, [50000, 10000]) 62 | 63 | 64 | # num_train = len(train_dataset) 65 | # indices = list(range(num_train)) 66 | # split = int(np.floor(val_size * num_train)) 67 | 68 | # set up random samplers 69 | # np.random.seed(random_seed) 70 | # np.random.shuffle(indices) 71 | # train_idx, val_idx = indices[split:], indices[:split] 72 | # train_sampler = SubsetRandomSampler(train_idx) 73 | # val_sampler = SubsetRandomSampler(val_idx) 74 | 75 | train_loader = torch.utils.data.DataLoader(train_data, 76 | batch_size=batch, shuffle=True, **kwargs) 77 | 78 | val_loader = torch.utils.data.DataLoader(valid_data, 79 | batch_size=test_batch, shuffle=False, **kwargs) 80 | 81 | 82 | test_loader = torch.utils.data.DataLoader(test_data, shuffle=False, 83 | batch_size=test_batch, **kwargs) 84 | 85 | return 
train_loader, val_loader, test_loader 86 | 87 | 88 | 89 | 90 | else: 91 | raise ValueError('Unknown dataset') 92 | exit(1) 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | -------------------------------------------------------------------------------- /datasets/power.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import datasets 4 | 5 | 6 | class POWER: 7 | 8 | class Data: 9 | 10 | def __init__(self, data): 11 | 12 | self.x = data.astype(np.float32) 13 | self.N = self.x.shape[0] 14 | 15 | def __init__(self): 16 | 17 | trn, val, tst = load_data_normalised() 18 | 19 | self.trn = self.Data(trn) 20 | self.val = self.Data(val) 21 | self.tst = self.Data(tst) 22 | 23 | self.n_dims = self.trn.x.shape[1] 24 | 25 | 26 | def load_data(): 27 | return np.load(datasets.root + 'power/data.npy') 28 | 29 | 30 | def load_data_split_with_noise(): 31 | 32 | rng = np.random.RandomState(42) 33 | 34 | data = load_data() 35 | rng.shuffle(data) 36 | N = data.shape[0] 37 | 38 | data = np.delete(data, 3, axis=1) 39 | data = np.delete(data, 1, axis=1) 40 | ############################ 41 | # Add noise 42 | ############################ 43 | # global_intensity_noise = 0.1*rng.rand(N, 1) 44 | voltage_noise = 0.01 * rng.rand(N, 1) 45 | # grp_noise = 0.001*rng.rand(N, 1) 46 | gap_noise = 0.001 * rng.rand(N, 1) 47 | sm_noise = rng.rand(N, 3) 48 | time_noise = np.zeros((N, 1)) 49 | # noise = np.hstack((gap_noise, grp_noise, voltage_noise, global_intensity_noise, sm_noise, time_noise)) 50 | # noise = np.hstack((gap_noise, grp_noise, voltage_noise, sm_noise, time_noise)) 51 | noise = np.hstack((gap_noise, voltage_noise, sm_noise, time_noise)) 52 | data = data + noise 53 | 54 | N_test = int(0.1 * data.shape[0]) 55 | data_test = data[-N_test:] 56 | data = data[0:-N_test] 57 | N_validate = int(0.1 * data.shape[0]) 58 | data_validate = data[-N_validate:] 59 | data_train = data[0:-N_validate] 60 | 61 | return data_train, data_validate, 
data_test 62 | 63 | 64 | def load_data_normalised(): 65 | 66 | data_train, data_validate, data_test = load_data_split_with_noise() 67 | data = np.vstack((data_train, data_validate)) 68 | mu = data.mean(axis=0) 69 | s = data.std(axis=0) 70 | data_train = (data_train - mu) / s 71 | data_validate = (data_validate - mu) / s 72 | data_test = (data_test - mu) / s 73 | 74 | return data_train, data_validate, data_test 75 | -------------------------------------------------------------------------------- /detailedSetup.md: -------------------------------------------------------------------------------- 1 | # Detailed Set-up and Running Instructions 2 | 3 | ## Set-up 4 | 5 | go to local folder where you want the files and type 6 | ``` 7 | git init 8 | ``` 9 | 10 | copy all files from the remote repository into this local folder: 11 | ``` 12 | git pull git@github.com:EmoryMLIP/OT-Flow 13 | ``` 14 | 15 | Set vim as the default editor (this step often just helps on linux): 16 | ``` 17 | git config --global core.editor "vim" 18 | ``` 19 | 20 | Create a virtual environment (may need to install virtualenv command) to hold all the python package versions for this project: 21 | ``` 22 | virtualenv -p python3 otEnv 23 | ``` 24 | 25 | Start up the virtual environment: 26 | ``` 27 | source otEnv/bin/activate 28 | ``` 29 | 30 | If you're running python 3.7 (and pip 21.3.1), install all the requirements: 31 | ``` 32 | pip install -r requirements.txt 33 | ``` 34 | 35 | We used Python 3.5 and CUDA 9.2, so we installed pytorch separately via 36 | ``` 37 | pip install torch==1.4.0+cu92 torchvision==0.5.0+cu92 -f https://download.pytorch.org/whl/torch_stable.html 38 | ``` 39 | 40 | For full capabilities, set the values in config.py to match your architecture. 
41 | 42 | 43 | ### Training and Evaluating Toy Data Sets 44 | commands with hyperparameters 45 | 46 | 47 | ``` 48 | python trainToyOTflow.py --data 8gaussians --nt 8 --nt_val 12 --batch_size 5000 --prec double --alph 1.0,30.0,1 --niters 5000 --lr 1e-1 --val_freq 50 --drop_freq 500 --sample_freq 25 --m 32 49 | 50 | python trainToyOTflow.py --data checkerboard --nt 12 --nt_val 16 --batch_size 10000 --prec double --alph 1.0,15.0,2.0 --niters 20000 --lr 5e-2 --val_freq 50 --drop_freq 1000 --sample_freq 25 --m 32 51 | 52 | python trainToyOTflow.py --data swissroll --nt 8 --nt_val 16 --batch_size 5000 --prec double --alph 1.0,30.0,15.0 --niters 5000 --lr 5e-2 --val_freq 50 --drop_freq 1000 --sample_freq 25 --m 32 53 | 54 | python trainToyOTflow.py --data circles --nt 8 --nt_val 12 --batch_size 5000 --prec double --alph 1.0,5.0,1.0 --niters 5000 --lr 5e-2 --val_freq 50 --drop_freq 1000 --sample_freq 25 --m 32 55 | 56 | python trainToyOTflow.py --data moons --nt 8 --nt_val 12 --batch_size 5000 --prec double --alph 1.0,8.0,1.0 --niters 5000 --lr 5e-2 --val_freq 50 --drop_freq 1000 --sample_freq 25 --m 32 57 | 58 | python trainToyOTflow.py --data pinwheel --nt 8 --nt_val 12 --batch_size 5000 --prec double --alph 1.0,30.0,15.0 --niters 5000 --lr 5e-2 --val_freq 50 --drop_freq 1000 --sample_freq 25 --m 32 59 | 60 | python trainToyOTflow.py --data 2spirals --nt 8 --nt_val 12 --batch_size 5000 --prec double --alph 1.0,10.0,1.0 --niters 5000 --lr 5e-2 --val_freq 50 --drop_freq 1000 --sample_freq 25 --m 32 61 | ``` 62 | 63 | 64 | ### Training and Evaluating Large Data Sets 65 | commands with hyperparameters 66 | 67 | ``` 68 | python trainLargeOTflow.py --data power --niters 36000 --alph 1.0,500.0,5.0 --m 128 --batch_size 10000 --lr 0.03 --nt 10 --nt_val 22 --test_batch_size 120000 --val_freq 30 --weight_decay 0.0 --drop_freq 0 69 | 70 | python evaluateLargeOTflow.py --data power --nt 24 --batch_size 120000 --resume yourPowerCheckpt.pth 71 | 72 | 73 | python trainLargeOTflow.py 
--data gas --niters 60000 --alph 1.0,1200.0,40.0 --m 350 --batch_size 2000 --drop_freq 0 --lr 0.01 --nt 10 --nt_val 28 --test_batch_size 55000 --val_freq 50 --weight_decay 0.0 --viz_freq 1000 --prec single --early_stopping 20 74 | 75 | python evaluateLargeOTflow.py --data gas --nt 30 --batch_size 55000 --resume yourGasCheckpt.pth 76 | 77 | 78 | python trainLargeOTflow.py --data hepmass --niters 40000 --alph 1.0,500.0,40.0 --m 256 --nTh 2 --batch_size 2000 --drop_freq 0 --lr 0.02 --nt 12 --nt_val 24 --test_batch_size 20000 --val_freq 50 --weight_decay 0.0 --viz_freq 500 --prec single --early_stopping 15 79 | 80 | python evaluateLargeOTflow.py --data hepmass --nt 24 --batch_size 50000 --resume yourHepmassCheckpt.pth 81 | 82 | 83 | python trainLargeOTflow.py --data miniboone --niters 8000 --alph 1.0,100.0,15.0 --batch_size 2000 --nt 6 --nt_val 10 --lr 0.02 --val_freq 20 --drop_freq 0 --weight_decay 0.0 --m 256 --viz_freq 500 --test_batch_size 5000 --early_stopping 15 84 | 85 | python evaluateLargeOTflow.py --data miniboone --nt 18 --batch_size 5000 --resume yourMinibooneCheckpt.pth 86 | 87 | 88 | python trainLargeOTflow.py --data bsds300 --niters 120000 --alph 1.0,2000.0,800.0 --batch_size 300 --nt 14 --nt_val 30 --lr 0.001 --val_freq 100 --drop_freq 0 --weight_decay 0.0 --m 512 --lr_drop 3.3 --viz_freq 500 --test_batch_size 1000 --prec single --early_stopping 15 89 | 90 | python evaluateLargeOTflow.py --data bsds300 --nt 40 --batch_size 10000 --resume yourBSDSCheckpt.pth 91 | 92 | ``` 93 | 94 | -------------------------------------------------------------------------------- /evaluateLargeOTflow.py: -------------------------------------------------------------------------------- 1 | # evaluateLargeOTflow.py 2 | # run model on testing set, calculate MMD, and plot 3 | import argparse 4 | import os 5 | import time 6 | import numpy as np 7 | import lib.utils as utils 8 | from lib.utils import count_parameters 9 | from src.plotter import * 10 | from src.OTFlowProblem import
* 11 | import h5py 12 | import datasets 13 | from src.mmd import mmd 14 | import config 15 | 16 | cf = config.getconfig() 17 | plt.rcParams.update({'font.size': 22}) 18 | 19 | parser = argparse.ArgumentParser('OT-Flow') 20 | parser.add_argument( 21 | '--data', choices=['power', 'gas', 'hepmass', 'miniboone', 'bsds300'], type=str, default='miniboone' 22 | ) 23 | parser.add_argument('--resume', type=str, default="experiments/cnf/large/pretrained/pretrained_miniboone_checkpt.pth") 24 | 25 | parser.add_argument("--nt" , type=int, default=18, help="number of integration time steps") 26 | parser.add_argument('--batch_size', type=int, default=5000) 27 | parser.add_argument('--prec', type=str, default='single', choices=['None', 'single','double'], help="overwrite trained precision") 28 | parser.add_argument('--gpu' , type=int, default=0) 29 | parser.add_argument('--long_version' , action='store_true') 30 | # default is: args.long_version=False , passing --long_version will take a long time to run to get values for paper 31 | args = parser.parse_args() 32 | 33 | # logger 34 | args.save, sPath = os.path.split(args.resume) 35 | utils.makedirs(args.save) 36 | logger = utils.get_logger(logpath=os.path.join(args.save, 'logs'), filepath=os.path.abspath(__file__)) 37 | logger.info(args) 38 | 39 | device = torch.device("cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu") 40 | 41 | 42 | def batch_iter(X, batch_size=args.batch_size, shuffle=False): 43 | """ 44 | X: feature tensor (shape: num_instances x num_features) 45 | """ 46 | if shuffle: 47 | idxs = torch.randperm(X.shape[0]) 48 | else: 49 | idxs = torch.arange(X.shape[0]) 50 | if X.is_cuda: 51 | idxs = idxs.cuda() 52 | for batch_idxs in idxs.split(batch_size): 53 | yield X[batch_idxs] 54 | 55 | def load_data(name): 56 | 57 | if name == 'bsds300': 58 | return datasets.BSDS300() 59 | 60 | elif name == 'power': 61 | return datasets.POWER() 62 | 63 | elif name == 'gas': 64 | return datasets.GAS() 65 | 66 | elif name == 
'hepmass': 67 | return datasets.HEPMASS() 68 | 69 | elif name == 'miniboone': 70 | return datasets.MINIBOONE() 71 | 72 | else: 73 | raise ValueError('Unknown dataset') 74 | 75 | def compute_loss(net, x, nt): 76 | Jc , cs = OTFlowProblem(x, net, [0,1], nt=nt, stepper="rk4", alph=net.alph) 77 | return Jc, cs 78 | 79 | 80 | if __name__ == '__main__': 81 | 82 | if args.long_version: 83 | sH5ffjord = 'ffjordResults/' + args.data + 'TestFFJORD.h5' 84 | hf = h5py.File(sH5ffjord, 'r') # open FFJORD results for plotting 85 | """ 86 | FFJORD results were saved in an h5 file with initial data (copied so that ordering is preserved) 87 | hf.keys() 88 | x - the test data from dataset (miniboone, power, etc.) 89 | fx - f(x) , FFJORD's forward transformation of x to the standard normal 90 | finvfx - f^{-1} (f(x)) , FFJORD's backward transformation of fx 91 | invErr - inverse error, avg. norm of difference between x and finvfx ; computed using a weighted avg 92 | nWeights - number of weights in the FFJORD model 93 | testTime - how long FFJORD took to compute the testing loss on 1 gpu for the dataset's testing data 94 | testBatchSize - the batch size used to achieve testTime 95 | normSamples - 100K samples drawn from the standard normal 96 | genSamples - f^{-1} (normSamples) , generated points by applying FFJORD backward transformation to the normal dist. 
pts 97 | """ 98 | 99 | testData = torch.from_numpy(np.array(hf['x'])) 100 | ffjordFx = np.array(hf['fx']) 101 | ffjordFinvfx = np.array(hf['finvfx']) 102 | ffjordTime = np.array(hf['testTime']).item() 103 | ffjordWeights = np.array(hf['nWeights']).item() 104 | normSamples = torch.from_numpy(np.array(hf['normSamples'])) # 10^5 samples 105 | ffjordGen = np.array(hf['genSamples']) 106 | 107 | else: 108 | logger.info("\nABBREVIATED VERSION\n") 109 | data = load_data(args.data) 110 | testData = torch.from_numpy(data.tst.x) # x sampled from unknown rho_0 111 | nSamples = 3000 # 100000 112 | normSamples = torch.randn(nSamples, testData.shape[1]) # y sampled from rho_1 113 | 114 | logger.info("test data shape: {:}".format(testData.shape)) 115 | 116 | nex = testData.shape[0] 117 | d = testData.shape[1] 118 | nt_test = args.nt 119 | 120 | # reload model 121 | checkpt = torch.load(args.resume, map_location=lambda storage, loc: storage) 122 | print(checkpt['args']) 123 | m = checkpt['args'].m 124 | alph = checkpt['args'].alph 125 | nTh = checkpt['args'].nTh 126 | net = Phi(nTh=nTh, m=m, d=d, alph=alph) 127 | argPrec = checkpt['state_dict']['A'].dtype 128 | net = net.to(argPrec) 129 | net.load_state_dict(checkpt["state_dict"]) 130 | net = net.to(device) 131 | 132 | # if specified precision supplied, override the loaded precision 133 | if args.prec != 'None': 134 | if args.prec == 'single': 135 | argPrec = torch.float32 136 | if args.prec == 'double': 137 | argPrec = torch.float64 138 | net = net.to(argPrec) 139 | 140 | cvt = lambda x: x.type(argPrec).to(device, non_blocking=True) 141 | 142 | logger.info(net) 143 | logger.info("----------TESTING---------------") 144 | logger.info("DIMENSION={:} m={:} nTh={:} alpha={:}".format(d,m,nTh,net.alph)) 145 | logger.info("nt_test={:}".format(nt_test)) 146 | logger.info("Number of trainable parameters: {}".format(count_parameters(net))) 147 | logger.info("Number of testing examples: {}".format(nex)) 148 | 
logger.info("-------------------------") 149 | logger.info("data={:} batch_size={:} gpu={:}".format(args.data, args.batch_size, args.gpu)) 150 | logger.info("saveLocation = {:}".format(args.save)) 151 | logger.info("-------------------------\n") 152 | 153 | end = time.time() 154 | 155 | log_msg = ( 156 | '{:4s} {:9s} {:9s} {:11s} {:9s}'.format( 157 | 'itr', 'loss', 'L (L_2)', 'C (loss)', 'R (HJB)' 158 | ) 159 | ) 160 | logger.info(log_msg) 161 | 162 | if not cf.gpu: 163 | # assume debugging and run a subset 164 | nSamples = 1000 165 | testData = testData[:nSamples, :] 166 | normSamples = normSamples[:nSamples, :] 167 | if args.long_version: 168 | ffjordFx = ffjordFx[:nSamples, :] 169 | ffjordFinvfx = ffjordFinvfx[:nSamples, :] 170 | ffjordGen = ffjordGen[:nSamples, :] 171 | 172 | net.eval() 173 | with torch.no_grad(): 174 | 175 | # meters to hold testing results 176 | testLossMeter = utils.AverageMeter() 177 | testAlphMeterL = utils.AverageMeter() 178 | testAlphMeterC = utils.AverageMeter() 179 | testAlphMeterR = utils.AverageMeter() 180 | 181 | 182 | itr = 1 183 | for x0 in batch_iter(testData, batch_size=args.batch_size): 184 | 185 | x0 = cvt(x0) 186 | nex = x0.shape[0] 187 | test_loss, test_cs = compute_loss(net, x0, nt=nt_test) 188 | testLossMeter.update(test_loss.item(), nex) 189 | testAlphMeterL.update(test_cs[0].item(), nex) 190 | testAlphMeterC.update(test_cs[1].item(), nex) 191 | testAlphMeterR.update(test_cs[2].item(), nex) 192 | log_message = 'batch {:4d}: {:9.3e} {:9.3e} {:11.5e} {:9.3e}'.format( 193 | itr, test_loss, test_cs[0], test_cs[1], test_cs[2] 194 | ) 195 | logger.info(log_message) # print batch 196 | itr+=1 197 | 198 | # add to print message 199 | log_message = '[TEST] {:9.3e} {:9.3e} {:11.5e} {:9.3e} '.format( 200 | testLossMeter.avg, testAlphMeterL.avg, testAlphMeterC.avg, testAlphMeterR.avg 201 | ) 202 | 203 | logger.info(log_message) # print total 204 | logger.info("Testing Time: {:.2f} seconds with {:} parameters".format( time.time() - 
end, count_parameters(net) )) 205 | if args.long_version: 206 | logger.info("FFJORD's Testing Time: {:.2f} seconds with {:} parameters".format( ffjordTime , ffjordWeights )) 207 | 208 | 209 | # computing inverse 210 | logger.info("computing inverse...") 211 | nGen = normSamples.shape[0] 212 | 213 | modelFx = np.zeros(testData.shape) 214 | modelFinvfx = np.zeros(testData.shape) 215 | modelGen = np.zeros(normSamples.shape) 216 | 217 | idx = 0 218 | for i , x0 in enumerate(batch_iter(testData, batch_size=args.batch_size)): 219 | x0 = cvt(x0) 220 | fx = integrate(x0[:, 0:d], net, [0.0, 1.0], nt_test, stepper="rk4", alph=net.alph) 221 | finvfx = integrate(fx[:, 0:d], net, [1.0, 0.0], nt_test, stepper="rk4", alph=net.alph) 222 | 223 | # consolidate fx and finvfx into one spot 224 | batchSz = x0.shape[0] 225 | modelFx[ idx:idx+batchSz , 0:d ] = fx[:,0:d].detach().cpu().numpy() 226 | modelFinvfx[ idx:idx+batchSz , 0:d ] = finvfx[:,0:d].detach().cpu().numpy() 227 | idx = idx + batchSz 228 | 229 | # logger.info("model inv error: {:.3e}".format(np.linalg.norm(testData.numpy() - modelFinvfx) / nex)) # initial bug 230 | logger.info("model inv error: {:.3e}".format( np.mean(np.linalg.norm(testData.numpy() - modelFinvfx, ord=2, axis=1)))) 231 | if args.long_version: 232 | logger.info("FFJORD inv error: {:.3e}".format( np.array(hf['invErr']).item() )) 233 | 234 | 235 | # this portion can take a long time 236 | # generate samples 237 | logger.info("generating samples...") 238 | idx = 0 239 | for i, y in enumerate(batch_iter(normSamples, batch_size=args.batch_size)): 240 | y = cvt(y) 241 | finvy = integrate(y[:, 0:d], net, [1.0, 0.0], nt_test, stepper="rk4",alph=net.alph) 242 | 243 | batchSz = y.shape[0] 244 | modelGen[ idx:idx+batchSz , 0:d ] = finvy[:,0:d].detach().cpu().numpy() 245 | idx = idx + batchSz 246 | 247 | # plotting 248 | sPath = os.path.join(args.save, 'figs', sPath[:-12] + '_test') 249 | if not os.path.exists(os.path.dirname(sPath)): 250 | 
os.makedirs(os.path.dirname(sPath)) 251 | 252 | testData = testData.detach().cpu().numpy() # make to numpy 253 | normSamples = normSamples.detach().cpu().numpy() 254 | 255 | if not args.long_version: # when running abbreviated style, use smaller sample sizes to compute mmd so its quicker 256 | nSamples = min(testData.shape[0], modelGen.shape[0], 3000) # number of samples for the MMD 257 | testSamps = testData[0:nSamples, :] 258 | modelSamps = modelGen[0:nSamples, 0:d] 259 | else: 260 | testSamps = testData 261 | modelSamps = modelGen[:,0:d] 262 | ffjordSamps = ffjordGen 263 | 264 | print("MMD( ourGen , rho_0 ), num(ourGen)={:d} , num(rho_0)={:d} : {:.5e}".format( modelSamps.shape[0] , testSamps.shape[0] , mmd(modelSamps , testSamps ))) 265 | if args.long_version: 266 | ffjordSamps = ffjordGen 267 | print("MMD( FFJORDGen, rho_0 ), num(FFJORDGen)={:d} , num(rho_0)={:d} : {:.5e}".format( ffjordSamps.shape[0] , testSamps.shape[0] , mmd(ffjordSamps , testSamps ))) 268 | 269 | logger.info("plotting...") 270 | nBins = 33 271 | LOW = -4 272 | HIGH = 4 273 | 274 | if args.data == 'gas': 275 | # the gas data set has different bounds 276 | LOWrho0 = -2 277 | HIGHrho0 = 2 278 | nBins = 33 279 | else: 280 | LOWrho0 = LOW 281 | HIGHrho0 = HIGH 282 | 283 | bounds = [[LOW, HIGH], [LOW, HIGH]] 284 | boundsRho0 = [[LOWrho0, HIGHrho0], [LOWrho0, HIGHrho0]] 285 | 286 | for d1 in range(0, d-1, 2): # plot 2-D slices of the multivariate distribution 287 | d2 = d1 + 1 288 | fig, axs = plt.subplots(2,3) # (2, 2) 289 | fig.set_size_inches(20,12) # (14,10) 290 | fig.suptitle(args.data + " dims: {:d} vs {:d}".format(d1, d2)) 291 | 292 | # hist, xbins, ybins, im = axs[0, 0].hist2d(x.numpy()[:,0],x.numpy()[:,1], range=[[LOW, HIGH], [LOW, HIGH]], bins = nBins) 293 | im1, _, _, map1 = axs[0, 0].hist2d(testData[:, d1], testData[:, d2], range=boundsRho0, bins=nBins) 294 | axs[0, 0].set_title(r'$x \sim \rho_0(x)$') 295 | 296 | im2, _, _, map2 = axs[0, 1].hist2d(modelFx[:, d1], modelFx[:, d2], 
range=bounds, bins=nBins) 297 | axs[0, 1].set_title(r'$f(x)$') 298 | 299 | im3, _, _, map3 = axs[1, 0].hist2d(normSamples[:, d1], normSamples[:, d2], range=bounds, bins=nBins) 300 | axs[1, 0].set_title(r'$y \sim \rho_1(y)$') 301 | 302 | im4, _, _, map4 = axs[1, 1].hist2d(modelGen[:, d1],modelGen[:, d2], range=boundsRho0, bins=nBins) 303 | axs[1, 1].set_title(r'$f^{-1}(y)$') 304 | 305 | if args.long_version: 306 | im5, _, _, map5 = axs[0, 2].hist2d(ffjordFx[:, d1], ffjordFx[:, d2], range=bounds, bins=nBins) 307 | axs[0, 2].set_title(r'FFJORD $f(x)$') 308 | 309 | im6, _, _, map6 = axs[1, 2].hist2d(ffjordGen[:, d1], ffjordGen[:, d2], range=boundsRho0, bins=nBins) 310 | axs[1, 2].set_title(r'FFJORD $f^{-1}(y)$') 311 | else: 312 | placeholder = 100.0*np.ones_like(testData[:, 0]) 313 | im5, _, _, map5 = axs[0, 2].hist2d(placeholder, placeholder, range=bounds, bins=nBins) 314 | axs[0, 2].set_title('placeholder') 315 | 316 | im6, _, _, map6 = axs[1, 2].hist2d(placeholder, placeholder, range=boundsRho0, bins=nBins) 317 | axs[1, 2].set_title('placeholder') 318 | 319 | 320 | # each has its own colorbar 321 | fig.colorbar(map1, cax=fig.add_axes([0.35 , 0.53, 0.01, 0.35]) ) 322 | fig.colorbar(map2, cax=fig.add_axes([0.625, 0.53, 0.01, 0.35]) ) 323 | fig.colorbar(map3, cax=fig.add_axes([0.35 , 0.11, 0.01, 0.35]) ) 324 | fig.colorbar(map4, cax=fig.add_axes([0.625, 0.11, 0.01, 0.35]) ) 325 | fig.colorbar(map5, cax=fig.add_axes([0.90 , 0.53, 0.01, 0.35]) ) 326 | fig.colorbar(map6, cax=fig.add_axes([0.90 , 0.11, 0.01, 0.35]) ) 327 | 328 | 329 | for i in range(axs.shape[0]): 330 | for j in range(axs.shape[1]): 331 | axs[i, j].get_yaxis().set_visible(False) 332 | axs[i, j].get_xaxis().set_visible(False) 333 | axs[i, j].set_aspect('equal') 334 | 335 | # plt.show() 336 | plt.savefig(sPath + "_{:d}v{:d}.pdf".format(d1, d2), dpi=400) 337 | plt.close() 338 | 339 | if args.long_version: 340 | hf.close() # close the h5 file 341 | logger.info('Testing has finished. 
' + sPath) 342 | 343 | 344 | 345 | 346 | 347 | 348 | 349 | -------------------------------------------------------------------------------- /evaluateToyOTflow.py: -------------------------------------------------------------------------------- 1 | # evaluateToyOTflow.py 2 | # plotting toy CNF results 3 | try: 4 | import matplotlib 5 | matplotlib.use('TkAgg') 6 | import matplotlib.pyplot as plt 7 | except: 8 | import matplotlib 9 | matplotlib.use('agg') # for linux server with no tkinter 10 | import matplotlib.pyplot as plt 11 | plt.rcParams['image.cmap'] = 'inferno' 12 | plt.rcParams.update({'font.size': 22}) 13 | 14 | import argparse 15 | import os 16 | import time 17 | import datetime 18 | import numpy as np 19 | import math 20 | import lib.toy_data as toy_data 21 | import lib.utils as utils 22 | from src.OTFlowProblem import * 23 | from src.mmd import * 24 | 25 | 26 | def_resume = 'experiments/cnf/toy/pretrained/pretrained_swissroll_alph30_15_m32_checkpt.pth' 27 | 28 | parser = argparse.ArgumentParser('Continuous Normalizing Flow') 29 | parser.add_argument( 30 | '--data', choices=['swissroll', '8gaussians', 'pinwheel', 'circles', 'moons', '2spirals', 'checkerboard', 'rings'], 31 | type=str, default='swissroll' 32 | ) 33 | parser.add_argument("--nt" , type=int, default=12, help="number of time steps") 34 | parser.add_argument('--batch_size', type=int, default=20000) 35 | parser.add_argument('--resume' , type=str, default=def_resume) 36 | parser.add_argument('--save' , type=str, default='image/') 37 | parser.add_argument('--gpu' , type=int, default=0) 38 | args = parser.parse_args() 39 | 40 | # logger 41 | _ , sPath = os.path.split(args.resume) 42 | utils.makedirs(args.save) 43 | 44 | device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu') 45 | 46 | # loss function 47 | def compute_loss(net, x, nt): 48 | Jc , cs = OTFlowProblem(x, net, [0,1], nt=nt, stepper="rk4", alph=net.alph) 49 | return Jc, cs 50 | 51 | if __name__ == '__main__': 
52 | 53 | # reload model 54 | checkpt = torch.load(args.resume, map_location=lambda storage, loc: storage) 55 | m = checkpt['args'].m 56 | alph = checkpt['args'].alph 57 | nTh = checkpt['args'].nTh 58 | d = checkpt['state_dict']['A'].size(1) - 1 59 | net = Phi(nTh=nTh, m=m, d=d, alph=alph) 60 | prec = checkpt['state_dict']['A'].dtype 61 | net = net.to(prec) 62 | net.load_state_dict(checkpt['state_dict']) 63 | net = net.to(device) 64 | 65 | args.data = checkpt['args'].data 66 | 67 | torch.set_default_dtype(prec) 68 | cvt = lambda x: x.type(prec).to(device, non_blocking=True) 69 | 70 | nSamples = args.batch_size 71 | p_samples = cvt(torch.Tensor(toy_data.inf_train_gen(args.data, batch_size=nSamples))) 72 | y = cvt(torch.randn(nSamples,d)) 73 | 74 | net.eval() 75 | with torch.no_grad(): 76 | 77 | test_loss, test_cs = compute_loss(net, p_samples, args.nt) 78 | 79 | # sample_fn, density_fn = get_transforms(model) 80 | modelFx = integrate(p_samples[:, 0:d], net, [0.0, 1.0], args.nt, stepper="rk4", alph=net.alph) 81 | modelFinvfx = integrate(modelFx[:, 0:d] , net, [1.0, 0.0], args.nt, stepper="rk4", alph=net.alph) 82 | modelGen = integrate(y[:, 0:d] , net, [1.0, 0.0], args.nt, stepper="rk4", alph=net.alph) 83 | 84 | print(" {:9s} {:9s} {:11s} {:9s}".format( "loss", "L (L_2)", "C (loss)", "R (HJB)")) 85 | print("[TEST]: {:9.3e} {:9.3e} {:11.5e} {:9.3e}".format(test_loss, test_cs[0], test_cs[1], test_cs[2])) 86 | 87 | print("Using ", utils.count_parameters(net), " parameters") 88 | invErr = np.mean(np.linalg.norm(p_samples.detach().cpu().numpy() - modelFinvfx[:,:d].detach().cpu().numpy(), ord=2, axis=1)) 89 | # invErr = (torch.norm(p_samples-modelFinvfx[:,:d]) / p_samples.size(0)).item() 90 | print("inv error: ", invErr ) 91 | 92 | modelGen = modelGen[:, 0:d].detach().cpu().numpy() 93 | p_samples = p_samples.detach().cpu().numpy() 94 | 95 | nBins = 80 96 | LOW = -4 97 | HIGH = 4 98 | extent = [[LOW, HIGH], [LOW, HIGH]] 99 | 100 | d1 = 0 101 | d2 = 1 102 | 103 | # density 
function of the standard normal 104 | def normpdf(x): 105 | mu = torch.zeros(1, d, device=x.device, dtype=x.dtype) 106 | cov = torch.ones(1, d, device=x.device, dtype = x.dtype) # diagonal of the covariance matrix 107 | 108 | denom = (2 * math.pi) ** (0.5 * d) * torch.sqrt(torch.prod(cov)) 109 | num = torch.exp(-0.5 * torch.sum((x - mu) ** 2 / cov, 1, keepdims=True)) 110 | return num / denom 111 | 112 | print("plotting...") 113 | # ---------------------------------------------------------------------------------------------------------- 114 | # Plot Density 115 | # ---------------------------------------------------------------------------------------------------------- 116 | title = "$density$" 117 | 118 | fig = plt.figure(figsize=(7, 7)) 119 | ax = plt.subplot(1, 1, 1, aspect="equal") 120 | 121 | npts = 100 122 | 123 | side = np.linspace(LOW, HIGH, npts) 124 | xx, yy = np.meshgrid(side, side) 125 | x = np.hstack([xx.reshape(-1, 1), yy.reshape(-1, 1)]) 126 | with torch.no_grad(): 127 | x = cvt(torch.from_numpy(x)) 128 | nt_val = args.nt 129 | z = integrate(x, net, [0.0, 1.0], nt_val, stepper="rk4", alph=net.alph) 130 | logqx = z[:, d] 131 | z = z[:, 0:d] 132 | 133 | qz = np.exp(logqx.cpu().numpy()).reshape(npts, npts) 134 | normpdfz = normpdf(z) 135 | rho0 = normpdfz.cpu().numpy().reshape(npts, npts) * qz 136 | 137 | im = plt.pcolormesh(xx, yy, rho0) 138 | vmin = np.min(rho0) 139 | vmax = np.max(rho0) 140 | im.set_clim(vmin, vmax) 141 | ax.axis('off') 142 | 143 | sSaveLoc = os.path.join(args.save, sPath[:-12] + '_density.png') 144 | plt.savefig(sSaveLoc,bbox_inches='tight') 145 | plt.close(fig) 146 | 147 | # ---------------------------------------------------------------------------------------------------------- 148 | # Plot Original Samples 149 | # ---------------------------------------------------------------------------------------------------------- 150 | 151 | x0 = toy_data.inf_train_gen(args.data, batch_size=nSamples) # load data batch 152 | fig = 
plt.figure(figsize=(7, 7)) 153 | ax = plt.subplot(1, 1, 1, aspect="equal") 154 | h2, _, _, map2 = ax.hist2d(x0[:, d1], x0[:, d2], range=extent, bins=nBins) 155 | # vmax: 15 for swissroll, 8gaussians, moons, 20 for pinwheel, 10 for circles, 8 for checkerboards 156 | h2 = h2 / (nSamples) 157 | im2 = ax.imshow(h2); 158 | ax.axis('off') 159 | im2.set_clim(vmin, vmax) 160 | sSaveLoc = os.path.join(args.save, sPath[:-12] + '_rho0Samples.png') 161 | plt.savefig(sSaveLoc,bbox_inches='tight') 162 | plt.close(fig) 163 | 164 | # ---------------------------------------------------------------------------------------------------------- 165 | # Plot Generated Samples 166 | # ---------------------------------------------------------------------------------------------------------- 167 | fig = plt.figure(figsize=(7, 7)) 168 | ax = plt.subplot(1, 1, 1, aspect="equal") 169 | y = cvt(torch.randn(nSamples, d)) 170 | genModel = integrate(y[:, 0:d], net, [1.0, 0.0], args.nt, stepper="rk4", alph=net.alph) 171 | h3, _, _, map3 = ax.hist2d(genModel.detach().cpu().numpy()[:, d1], genModel.detach().cpu().numpy()[:, d2], 172 | range=extent, bins=nBins) 173 | h3 = h3/(nSamples) 174 | im3 = ax.imshow(h3) 175 | im3.set_clim(vmin, vmax) 176 | ax.axis('off') 177 | sSaveLoc = os.path.join(args.save, sPath[:-12] + '_genSamples.png') 178 | plt.savefig(sSaveLoc,bbox_inches='tight') 179 | plt.close(fig) 180 | print("finished plotting to folder", args.save) 181 | 182 | print("testing complete") 183 | 184 | 185 | 186 | 187 | 188 | 189 | -------------------------------------------------------------------------------- /experiments/cnf/large/pretrained/pretrained_interp_mnist_checkpt.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EmoryMLIP/OT-Flow/4d66618e2a2f4d8e8ce080cf1b3c769c78b2590d/experiments/cnf/large/pretrained/pretrained_interp_mnist_checkpt.pth -------------------------------------------------------------------------------- 
/experiments/cnf/large/pretrained/pretrained_miniboone_checkpt.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EmoryMLIP/OT-Flow/4d66618e2a2f4d8e8ce080cf1b3c769c78b2590d/experiments/cnf/large/pretrained/pretrained_miniboone_checkpt.pth -------------------------------------------------------------------------------- /experiments/cnf/toy/pretrained/pretrained_swissroll_alph30_15_m32_checkpt.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EmoryMLIP/OT-Flow/4d66618e2a2f4d8e8ce080cf1b3c769c78b2590d/experiments/cnf/toy/pretrained/pretrained_swissroll_alph30_15_m32_checkpt.pth -------------------------------------------------------------------------------- /interpMnist.py: -------------------------------------------------------------------------------- 1 | # interpMnist.py 2 | # 3 | # grab two images, encode them and flow them to rho_1, interpolate between and flow back and decode 4 | # plot many of these interpolations in the latent space 5 | import matplotlib 6 | matplotlib.use('TkAgg') 7 | import matplotlib.pyplot as plt 8 | import matplotlib.patches as patches 9 | 10 | import argparse 11 | import os 12 | from src.OTFlowProblem import * 13 | import config 14 | import datasets 15 | from datasets.mnist import getLoader 16 | from src.Autoencoder import * 17 | 18 | cf = config.getconfig() 19 | def_resume = 'experiments/cnf/large/pretrained/pretrained_interp_mnist_checkpt.pth' 20 | 21 | parser = argparse.ArgumentParser('Continuous Normalizing Flow') 22 | parser.add_argument( 23 | '--data', type=str, default='mnist' 24 | ) 25 | 26 | parser.add_argument("--nt" , type=int, default=16, help="number of time steps") 27 | parser.add_argument('--batch_size', type=int, default=200) 28 | parser.add_argument('--resume' , type=str, default=def_resume) 29 | parser.add_argument('--save' , type=str, default='image/') 30 | 
parser.add_argument('--gpu' , type=int, default=0) 31 | args = parser.parse_args() 32 | 33 | device = torch.device("cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu") 34 | 35 | 36 | if __name__ == '__main__': 37 | 38 | if args.resume is None: 39 | print("have to provide path to saved model via --resume commandline argument") 40 | exit(1) 41 | 42 | _ , _ , test_loader = getLoader(args.data, args.batch_size, args.batch_size, augment=False, hasGPU=cf.gpu) 43 | 44 | nt = args.nt 45 | # --------------------------LOADING------------------------------------ 46 | # reload model 47 | checkpt = torch.load(args.resume, map_location=lambda storage, loc: storage) 48 | m = checkpt['args'].m 49 | alph = checkpt['args'].alph 50 | d = checkpt['state_dict']['A'].size(1) - 1 51 | eps = checkpt['args'].eps 52 | net = Phi(nTh=2, m=m, d=d, alph=alph) # the phi aka the value function 53 | net.load_state_dict(checkpt["state_dict"]) 54 | 55 | # get expected type 56 | prec = net.A.dtype 57 | cvt = lambda x: x.type(prec).to(device, non_blocking=True) 58 | net = net.to(prec).to(device) 59 | 60 | # load the trained autoencoder 61 | autoEnc = Autoencoder(d) 62 | autoEnc.mu = checkpt['autoencoder']["mu"] 63 | autoEnc.std = checkpt['autoencoder']["std"] 64 | autoEnc.load_state_dict(checkpt['autoencoder'], strict=False) # doesnt load the buffers 65 | autoEnc = autoEnc.to(prec).to(device) 66 | # --------------------------------------------------------------------- 67 | 68 | nInterp = 5 69 | 70 | net.eval() 71 | with torch.no_grad(): 72 | 73 | torch.manual_seed(0) # for reproducibility 74 | 75 | images, labels = next(iter(test_loader)) 76 | 77 | # vectorize each image 78 | images = cvt(images.view(images.size(0), -1)) 79 | 80 | # grab a few of the class 9 81 | idx9 = labels == 9 82 | x9 = images[idx9,:] 83 | cosIdx = [3,5,6,7,8] 84 | x0 = x9[cosIdx] 85 | nSamples = 4 86 | x0orig = x0[0:nSamples,:] 87 | 88 | # grab one image of an mnist 1 and use it 89 | idx1 = labels == 1 90 | x1 = 
images[idx1,:] 91 | x0orig[nSamples-1,:] = x1[4,:] 92 | x0 = autoEnc.encode(x0orig) # encode 93 | x0 = (x0 - autoEnc.mu) / (autoEnc.std + eps) # normalize 94 | z1 = integrate(x0[:, 0:d], net, [0.0, 1.0], nt, stepper="rk4", alph=net.alph)[0:d] # flow to rho_1 95 | z1 = z1[:,0:d] 96 | 97 | recastZ = cvt(torch.zeros((nInterp+1)**2, z1.shape[1])) 98 | 99 | # will make a nInterp+1-by-nInterp+1 image with the four corners as the original images 100 | # upper left, upper right, lower left, lower right 101 | ul = z1[0, :] 102 | ur = z1[1, :] 103 | ll = z1[2, :] 104 | lr = z1[3, :] 105 | 106 | # assume nInterp = 5 107 | # hard coded 108 | # first row 109 | recastZ[0, :] = ul 110 | recastZ[1, :] = ul + 0.2 * (ur - ul) 111 | recastZ[2, :] = ul + 0.4 * (ur - ul) 112 | recastZ[3, :] = ul + 0.6 * (ur - ul) 113 | recastZ[4, :] = ul + 0.8 * (ur - ul) 114 | recastZ[nInterp, :] = ur 115 | # last row 116 | recastZ[nInterp*(nInterp+1) , :] = ll 117 | recastZ[nInterp*(nInterp+1)+1 , :] = ll + 0.2 * (lr - ll) 118 | recastZ[nInterp*(nInterp+1)+2 , :] = ll + 0.4 * (lr - ll) 119 | recastZ[nInterp*(nInterp+1)+3 , :] = ll + 0.6 * (lr - ll) 120 | recastZ[nInterp*(nInterp+1)+4 , :] = ll + 0.8 * (lr - ll) 121 | recastZ[(nInterp+1)**2 - 1 , :] = lr 122 | 123 | # for each column, interpolate between the top image and the bottom 124 | for col in range(nInterp+1): 125 | top = recastZ[ col , :] 126 | bot = recastZ[nInterp*(nInterp+1) + col , :] 127 | for row in range(1,nInterp): 128 | recastZ[row*(nInterp+1)+col , :] = top + 1.0/nInterp * row * (bot-top) 129 | 130 | gen = integrate(recastZ[:, 0:d], net, [1.0, 0.0], nt, stepper="rk4", alph=net.alph)[:,0:d] 131 | gen = autoEnc.decode(gen * (autoEnc.std + eps) + autoEnc.mu) 132 | 133 | # place originals in the corner spots 134 | gen[0, :] = x0orig[0, :] 135 | gen[nInterp, :] = x0orig[1, :] 136 | gen[nInterp*(1+nInterp), :] = x0orig[2, :] 137 | gen[(nInterp+1)**2 - 1, :] = x0orig[3, :] 138 | 139 | # plot them 140 | nex = 48 141 | fig, axs = 
plt.subplots(nInterp+1, nInterp+1) 142 | fig.set_size_inches(6, 6.1) 143 | fig.suptitle("red boxed values are original; others are interpolated in rho_1 space") 144 | gen = gen.detach().cpu().numpy() 145 | 146 | k = 0 147 | for i in range(nInterp+1): 148 | for j in range(nInterp+1): 149 | axs[i, j].imshow(gen[k,:].reshape(28,28), cmap='gray') 150 | # box the originals 151 | if (i==0 and j==0) or (i==nInterp and j==0) or (i==0 and j==nInterp) or (i==nInterp and j==nInterp): 152 | # Create a Rectangle patch 153 | rect = patches.Rectangle((0, 0), 27, 27, linewidth=2, edgecolor='r', facecolor='none') 154 | # Add the patch to the Axes 155 | axs[i, j].add_patch(rect) 156 | k+=1 157 | 158 | for i in range(axs.shape[0]): 159 | for j in range(axs.shape[1]): 160 | axs[i, j].get_yaxis().set_visible(False) 161 | axs[i, j].get_xaxis().set_visible(False) 162 | axs[i ,j].set_aspect('equal') 163 | 164 | plt.subplots_adjust(wspace=0.0, hspace=0.0) 165 | 166 | # save figure 167 | sPath = args.save + 'interpMNISTGen.pdf' 168 | if not os.path.exists(os.path.dirname(sPath)): 169 | os.makedirs(os.path.dirname(sPath)) 170 | plt.savefig(sPath, dpi=300) 171 | plt.close() 172 | print('figure saved to ', sPath) 173 | -------------------------------------------------------------------------------- /lib/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import datasets, transforms 3 | from lib.transform import AddUniformNoise, ToTensor, HorizontalFlip, Transpose, Resize 4 | 5 | 6 | dataFolder = './data/' 7 | 8 | def add_noise(x): 9 | """ 10 | [0, 1] -> [0, 255] -> add noise -> [0, 1] 11 | """ 12 | noise = x.new().resize_as_(x).uniform_() 13 | x = x * 255 + noise 14 | x = x / 256 15 | return x 16 | 17 | def dataloader(dataset, batch_size, cuda, conditional=-1, im_size=64): 18 | 19 | if dataset == 'mnist': 20 | data = datasets.MNIST(dataFolder+'MNIST', train=True, download=True, 21 | transform=transforms.Compose([ 
22 | AddUniformNoise(), 23 | ToTensor() 24 | ])) 25 | 26 | test_data = datasets.MNIST(dataFolder+'MNIST', train=False, download=True, 27 | transform=transforms.Compose([ 28 | AddUniformNoise(), 29 | ToTensor() 30 | ])) 31 | 32 | if conditional >= 0 and conditional <= 9: 33 | idx = data.targets == conditional 34 | data.data = data.data[idx, :] 35 | data.targets = data.targets[idx] 36 | nTot = torch.sum(idx).item() 37 | nTrain = int((5.0 / 6.0) * nTot) 38 | nVal = nTot - nTrain 39 | train_data, valid_data = torch.utils.data.random_split(data, [nTrain, nVal]) 40 | 41 | idx = test_data.targets == conditional 42 | test_data.data = test_data.data[idx,:] 43 | test_data.targets = test_data.targets[idx] 44 | else: 45 | train_data, valid_data = torch.utils.data.random_split(data, [50000, 10000]) 46 | 47 | else: 48 | print ('what network ?', dataset) 49 | sys.exit(1) 50 | 51 | #load data 52 | kwargs = {'num_workers': 0, 'pin_memory': True} if cuda > -1 else {} 53 | 54 | train_loader = torch.utils.data.DataLoader( 55 | train_data, 56 | batch_size=batch_size, shuffle=True, **kwargs) 57 | 58 | valid_loader = torch.utils.data.DataLoader( 59 | valid_data, 60 | batch_size=batch_size, shuffle=True, **kwargs) 61 | 62 | test_loader = torch.utils.data.DataLoader(test_data, 63 | batch_size=batch_size, shuffle=True, **kwargs) 64 | 65 | return train_loader, valid_loader, test_loader 66 | 67 | 68 | 69 | 70 | 71 | 72 | if __name__ == '__main__': 73 | 74 | argPrec = torch.float32 75 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 76 | cvt = lambda x: x.type(argPrec).to(device, non_blocking=True) 77 | 78 | train_loader, val_loader, test_loader = dataloader('mnist', 10000, cuda=-1) 79 | 80 | d = 784 81 | # compute mean and std for MNIST dataloader 82 | mu = torch.zeros((1, d), dtype=argPrec, device=device) 83 | musqrd = torch.zeros((1, d), dtype=argPrec, device=device) 84 | totImages = 0 85 | 86 | i = 0 87 | 88 | for data in train_loader: 89 | # _ stands in for 
labels, here 90 | images, _ = data 91 | images = images.view(images.size(0), -1) 92 | images = cvt(images) 93 | nImages = images.shape[0] 94 | totImages += nImages 95 | mu += torch.mean(images, dim=0, keepdims=True) # *nImages 96 | musqrd += torch.mean(images ** 2, dim=0, keepdims=True) # *nImages 97 | i += 1 98 | 99 | mu = mu / i 100 | musqrd = musqrd / i 101 | std = torch.sqrt(torch.abs(mu ** 2 - musqrd)) 102 | 103 | 104 | print('mu: ', mu) 105 | print('std: ', std) 106 | 107 | -------------------------------------------------------------------------------- /lib/toy_data.py: -------------------------------------------------------------------------------- 1 | # From FFJORD 2 | import numpy as np 3 | import sklearn 4 | import sklearn.datasets 5 | from sklearn.utils import shuffle as util_shuffle 6 | 7 | 8 | # Dataset iterator 9 | def inf_train_gen(data, rng=None, batch_size=200): 10 | if rng is None: 11 | rng = np.random.RandomState() 12 | 13 | if data == "swissroll": 14 | data = sklearn.datasets.make_swiss_roll(n_samples=batch_size, noise=1.0)[0] 15 | data = data.astype("float32")[:, [0, 2]] 16 | data /= 5 17 | return data 18 | 19 | elif data == "circles": 20 | data = sklearn.datasets.make_circles(n_samples=batch_size, factor=.5, noise=0.08)[0] 21 | data = data.astype("float32") 22 | data *= 3 23 | return data 24 | 25 | elif data == "rings": 26 | n_samples4 = n_samples3 = n_samples2 = batch_size // 4 27 | n_samples1 = batch_size - n_samples4 - n_samples3 - n_samples2 28 | 29 | # so as not to have the first point = last point, we set endpoint=False 30 | linspace4 = np.linspace(0, 2 * np.pi, n_samples4, endpoint=False) 31 | linspace3 = np.linspace(0, 2 * np.pi, n_samples3, endpoint=False) 32 | linspace2 = np.linspace(0, 2 * np.pi, n_samples2, endpoint=False) 33 | linspace1 = np.linspace(0, 2 * np.pi, n_samples1, endpoint=False) 34 | 35 | circ4_x = np.cos(linspace4) 36 | circ4_y = np.sin(linspace4) 37 | circ3_x = np.cos(linspace4) * 0.75 38 | circ3_y = 
np.sin(linspace3) * 0.75 39 | circ2_x = np.cos(linspace2) * 0.5 40 | circ2_y = np.sin(linspace2) * 0.5 41 | circ1_x = np.cos(linspace1) * 0.25 42 | circ1_y = np.sin(linspace1) * 0.25 43 | 44 | X = np.vstack([ 45 | np.hstack([circ4_x, circ3_x, circ2_x, circ1_x]), 46 | np.hstack([circ4_y, circ3_y, circ2_y, circ1_y]) 47 | ]).T * 3.0 48 | X = util_shuffle(X, random_state=rng) 49 | 50 | # Add noise 51 | X = X + rng.normal(scale=0.08, size=X.shape) 52 | 53 | return X.astype("float32") 54 | 55 | elif data == "moons": 56 | data = sklearn.datasets.make_moons(n_samples=batch_size, noise=0.1)[0] 57 | data = data.astype("float32") 58 | data = data * 2 + np.array([-1, -0.2]) 59 | return data 60 | 61 | elif data == "8gaussians": 62 | scale = 4. 63 | centers = [(1, 0), (-1, 0), (0, 1), (0, -1), (1. / np.sqrt(2), 1. / np.sqrt(2)), 64 | (1. / np.sqrt(2), -1. / np.sqrt(2)), (-1. / np.sqrt(2), 65 | 1. / np.sqrt(2)), (-1. / np.sqrt(2), -1. / np.sqrt(2))] 66 | centers = [(scale * x, scale * y) for x, y in centers] 67 | 68 | dataset = [] 69 | for i in range(batch_size): 70 | point = rng.randn(2) * 0.5 71 | idx = rng.randint(8) 72 | center = centers[idx] 73 | point[0] += center[0] 74 | point[1] += center[1] 75 | dataset.append(point) 76 | dataset = np.array(dataset, dtype="float32") 77 | dataset /= 1.414 78 | return dataset 79 | 80 | elif data == "pinwheel": 81 | radial_std = 0.3 82 | tangential_std = 0.1 83 | num_classes = 5 84 | num_per_class = batch_size // 5 85 | rate = 0.25 86 | rads = np.linspace(0, 2 * np.pi, num_classes, endpoint=False) 87 | 88 | features = rng.randn(num_classes*num_per_class, 2) \ 89 | * np.array([radial_std, tangential_std]) 90 | features[:, 0] += 1. 
91 | labels = np.repeat(np.arange(num_classes), num_per_class) 92 | 93 | angles = rads[labels] + rate * np.exp(features[:, 0]) 94 | rotations = np.stack([np.cos(angles), -np.sin(angles), np.sin(angles), np.cos(angles)]) 95 | rotations = np.reshape(rotations.T, (-1, 2, 2)) 96 | 97 | return 2 * rng.permutation(np.einsum("ti,tij->tj", features, rotations)) 98 | 99 | elif data == "2spirals": 100 | n = np.sqrt(np.random.rand(batch_size // 2, 1)) * 540 * (2 * np.pi) / 360 101 | d1x = -np.cos(n) * n + np.random.rand(batch_size // 2, 1) * 0.5 102 | d1y = np.sin(n) * n + np.random.rand(batch_size // 2, 1) * 0.5 103 | x = np.vstack((np.hstack((d1x, d1y)), np.hstack((-d1x, -d1y)))) / 3 104 | x += np.random.randn(*x.shape) * 0.1 105 | return x 106 | 107 | elif data == "checkerboard": 108 | x1 = np.random.rand(batch_size) * 4 - 2 109 | x2_ = np.random.rand(batch_size) - np.random.randint(0, 2, batch_size) * 2 110 | x2 = x2_ + (np.floor(x1) % 2) 111 | return np.concatenate([x1[:, None], x2[:, None]], 1) * 2 112 | 113 | elif data == "line": 114 | x = rng.rand(batch_size) * 5 - 2.5 115 | y = x 116 | return np.stack((x, y), 1) 117 | elif data == "cos": 118 | x = rng.rand(batch_size) * 5 - 2.5 119 | y = np.sin(x) * 2.5 120 | return np.stack((x, y), 1) 121 | else: 122 | return inf_train_gen("8gaussians", rng, batch_size) 123 | -------------------------------------------------------------------------------- /lib/transform.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torchvision 4 | 5 | def logit(x, alpha=1E-6): 6 | y = alpha + (1.-2*alpha)*x 7 | return np.log(y) - np.log(1. 
- y) 8 | 9 | def logit_back(x, alpha=1E-6): 10 | y = torch.sigmoid(x) 11 | return (y - alpha)/(1.-2*alpha) 12 | 13 | class AddUniformNoise(object): 14 | def __init__(self, alpha=1E-6): 15 | self.alpha = alpha 16 | def __call__(self,samples): 17 | samples = np.array(samples,dtype = np.float32) 18 | samples += np.random.uniform(size = samples.shape) 19 | samples = logit(samples/256., self.alpha) 20 | return samples 21 | 22 | class ToTensor(object): 23 | def __init__(self): 24 | pass 25 | def __call__(self,samples): 26 | samples = torch.from_numpy(samples).float() 27 | return samples 28 | 29 | class ZeroPadding(object): 30 | def __init__(self,num): 31 | self.num = num 32 | def __call__(self,samples): 33 | samples = np.array(samples,dtype = np.float32) 34 | tmp = np.zeros((32,32)) 35 | tmp[self.num:samples.shape[0]+self.num,self.num:samples.shape[1]+self.num] = samples 36 | return tmp 37 | 38 | class Crop(object): 39 | def __init__(self,num): 40 | self.num = num 41 | def __call__(self,samples): 42 | samples = np.array(samples,dtype = np.float32) 43 | return samples[self.num:-self.num,self.num:-self.num] 44 | 45 | class HorizontalFlip(object): 46 | def __init__(self): 47 | pass 48 | def __call__(self,samples): 49 | return torchvision.transforms.functional.hflip(samples) 50 | 51 | class Transpose(object): 52 | def __init__(self): 53 | pass 54 | def __call__(self,samples): 55 | return np.transpose(samples, (2, 0, 1)) 56 | 57 | class Resize(object): 58 | def __init__(self): 59 | pass 60 | def __call__(self, samples): 61 | return torchvision.transforms.functional.resize(samples, [32, 32]) 62 | -------------------------------------------------------------------------------- /lib/utils.py: -------------------------------------------------------------------------------- 1 | # utils.py 2 | # 3 | # some of the original utilities used by FFJORD 4 | import os 5 | import math 6 | from numbers import Number 7 | import logging 8 | import torch 9 | 10 | 11 | def makedirs(dirname): 12 
| if not os.path.exists(dirname): 13 | os.makedirs(dirname) 14 | 15 | 16 | def get_logger(logpath, filepath, package_files=[], displaying=True, saving=True, debug=False): 17 | logger = logging.getLogger() 18 | if debug: 19 | level = logging.DEBUG 20 | else: 21 | level = logging.INFO 22 | logger.setLevel(level) 23 | if saving: 24 | info_file_handler = logging.FileHandler(logpath, mode="a") 25 | info_file_handler.setLevel(level) 26 | logger.addHandler(info_file_handler) 27 | if displaying: 28 | console_handler = logging.StreamHandler() 29 | console_handler.setLevel(level) 30 | logger.addHandler(console_handler) 31 | logger.info(filepath) 32 | with open(filepath, "r") as f: 33 | logger.info(f.read()) 34 | 35 | for f in package_files: 36 | logger.info(f) 37 | with open(f, "r") as package_f: 38 | logger.info(package_f.read()) 39 | 40 | return logger 41 | 42 | 43 | class AverageMeter(object): 44 | """Computes and stores the average and current value""" 45 | 46 | def __init__(self): 47 | self.reset() 48 | 49 | def reset(self): 50 | self.val = 0 51 | self.avg = 0 52 | self.sum = 0 53 | self.count = 0 54 | 55 | def update(self, val, n=1): 56 | self.val = val 57 | self.sum += val * n 58 | self.count += n 59 | self.avg = self.sum / self.count 60 | 61 | 62 | class RunningAverageMeter(object): 63 | """Computes and stores the average and current value""" 64 | 65 | def __init__(self, momentum=0.99): 66 | self.momentum = momentum 67 | self.reset() 68 | 69 | def reset(self): 70 | self.val = None 71 | self.avg = 0 72 | self.sum = 0 # 73 | 74 | def update(self, val): 75 | if self.val is None: 76 | self.avg = val 77 | else: 78 | self.avg = self.avg * self.momentum + val * (1 - self.momentum) 79 | self.sum += val 80 | self.val = val 81 | 82 | def count_parameters(model): 83 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 84 | 85 | # 86 | # def inf_generator(iterable): 87 | # """Allows training with DataLoaders in a single infinite loop: 88 | # for i, (x, y) in 
enumerate(inf_generator(train_loader)): 89 | # """ 90 | # iterator = iterable.__iter__() 91 | # while True: 92 | # try: 93 | # yield iterator.__next__() 94 | # except StopIteration: 95 | # iterator = iterable.__iter__() 96 | # 97 | # 98 | # def save_checkpoint(state, save, epoch): 99 | # if not os.path.exists(save): 100 | # os.makedirs(save) 101 | # filename = os.path.join(save, 'checkpt-%04d.pth' % epoch) 102 | # torch.save(state, filename) 103 | # 104 | # 105 | # def isnan(tensor): 106 | # return (tensor != tensor) 107 | # 108 | # 109 | # def logsumexp(value, dim=None, keepdim=False): 110 | # """Numerically stable implementation of the operation 111 | # value.exp().sum(dim, keepdim).log() 112 | # """ 113 | # if dim is not None: 114 | # m, _ = torch.max(value, dim=dim, keepdim=True) 115 | # value0 = value - m 116 | # if keepdim is False: 117 | # m = m.squeeze(dim) 118 | # return m + torch.log(torch.sum(torch.exp(value0), dim=dim, keepdim=keepdim)) 119 | # else: 120 | # m = torch.max(value) 121 | # sum_exp = torch.sum(torch.exp(value - m)) 122 | # if isinstance(sum_exp, Number): 123 | # return m + math.log(sum_exp) 124 | # else: 125 | # return m + torch.log(sum_exp) 126 | # 127 | # 128 | # 129 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | h5py==2.10.0 2 | matplotlib==3.0.3 3 | scikit-learn==0.22.2.post1 4 | torch==1.4.0 5 | torchvision==0.5.0 6 | pandas==0.24.2 7 | -------------------------------------------------------------------------------- /src/Autoencoder.py: -------------------------------------------------------------------------------- 1 | # Autoencoder.py 2 | # encoder-decoder used for MNIST experiments 3 | # 4 | # from: https://medium.com/analytics-vidhya/dimension-manipulation-using-autoencoder-in-pytorch-on-mnist-dataset-7454578b018 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 
9 | import lib.utils as utils 10 | import os 11 | from src.plotter import * 12 | 13 | 14 | # define the encoder-decoder architecture 15 | class Autoencoder(nn.Module): 16 | def __init__(self, encoding_dim): 17 | super(Autoencoder, self).__init__() 18 | ## encoder ## 19 | 20 | self.d = encoding_dim 21 | 22 | # linear layer (784 -> encoding_dim) 23 | self.layer1 = nn.Linear(28 * 28, encoding_dim) 24 | 25 | ## decoder ## 26 | # linear layer (encoding_dim -> input size) 27 | self.layer2 = nn.Linear(encoding_dim, 28 * 28) 28 | 29 | # register these as buffers 30 | self.register_buffer('mu', None) 31 | self.register_buffer('std', None) 32 | 33 | def forward(self, x): 34 | 35 | x = self.encode(x) 36 | x = self.decode(x) 37 | 38 | return x 39 | 40 | def encode(self,x): 41 | # add layer, with relu activation function 42 | return F.relu(self.layer1(x)) 43 | 44 | def decode(self,x): 45 | # output layer (sigmoid for scaling from 0 to 1) 46 | return torch.sigmoid(self.layer2(x)) 47 | 48 | 49 | 50 | def trainAE(net, train_loader, val_loader, saveDir, sStartTime, argType=torch.float32, device=torch.device('cpu')): 51 | """ 52 | 53 | :param net: AutoEncoder 54 | :param train_loader: MNIST loader of training data 55 | :param val_loader: MNIST loader of validation data 56 | :param saveDir: string, path 57 | :param sStartTime: string, start time 58 | :param argType: torch type 59 | :param device: torch device 60 | :return: 61 | """ 62 | print("training auto_encoder") 63 | 64 | cvt = lambda x: x.type(argType).to(device, non_blocking=True) 65 | utils.makedirs(saveDir) 66 | 67 | # specify loss function 68 | criterion = nn.MSELoss() 69 | 70 | # specify loss function 71 | optimizer = torch.optim.Adam(net.parameters(), lr=0.001) 72 | 73 | best_loss = float('inf') 74 | bestParams = None 75 | 76 | # number of epochs to train the model 77 | n_epochs = 600 78 | 79 | for epoch in range(1, n_epochs + 1): 80 | 81 | # train the encoder-decoder 82 | net.train() 83 | train_loss = 0.0 84 | for data 
in train_loader: 85 | # _ stands in for labels, here 86 | images, _ = data 87 | # flatten images 88 | images = images.view(images.size(0), -1) 89 | images = cvt(images) 90 | 91 | optimizer.zero_grad() 92 | outputs = net(images) 93 | loss = criterion(outputs, images) 94 | loss.backward() 95 | optimizer.step() 96 | train_loss += loss.item() * images.size(0) 97 | 98 | # validate the encoder-decoder 99 | net.eval() 100 | val_loss = 0.0 101 | for data in val_loader: 102 | images, _ = data 103 | images = images.view(images.size(0), -1) 104 | images = cvt(images) 105 | 106 | outputs = net(images) 107 | loss = criterion(outputs, images) 108 | loss.backward() 109 | optimizer.step() 110 | val_loss += loss.item() * images.size(0) 111 | 112 | # print avg training statistics...different batch_sizes will scale these differnetly 113 | train_loss = train_loss / len(train_loader) 114 | val_loss = val_loss / len(val_loader) 115 | print('Epoch: {} \tTraining Loss: {:.6f} \t Validation Loss: {:.6f}'.format( 116 | epoch, 117 | train_loss, 118 | val_loss 119 | )) 120 | 121 | # save best set of parameters 122 | if val_loss < best_loss: 123 | best_loss = val_loss 124 | bestParams = net.state_dict() 125 | 126 | # plot 127 | if epoch % 20 == 0: 128 | net.eval() 129 | sSavePath = os.path.join(saveDir, 'figs', sStartTime + '_autoencoder{:d}.png'.format(epoch)) 130 | xRecreate = net(images) 131 | plotAutoEnc(images, xRecreate, sSavePath) 132 | 133 | # shrink step size 134 | if epoch % 150 == 0: 135 | for p in optimizer.param_groups: 136 | p['lr'] /= 10.0 137 | print("lr: ", p['lr']) 138 | 139 | d = net.d 140 | 141 | # compute mean and std for normalization 142 | mu = torch.zeros((1, d), dtype=argType, device=device) 143 | musqrd = torch.zeros((1, d), dtype=argType, device=device) 144 | totImages = 0 145 | 146 | net.load_state_dict(bestParams) 147 | 148 | i = 0 149 | net.eval() 150 | with torch.no_grad(): 151 | for data in train_loader: 152 | # _ stands in for labels, here 153 | images, _ = 
data 154 | images = images.view(images.size(0), -1) 155 | images = cvt(images) 156 | outputs = net.encode(images) 157 | nImages = outputs.shape[0] 158 | totImages += nImages 159 | mu += torch.mean(outputs, dim=0, keepdims=True) # *nImages 160 | musqrd += torch.mean(outputs ** 2, dim=0, keepdims=True) # *nImages 161 | 162 | # check quality 163 | if i == 0: 164 | sSavePath = os.path.join(saveDir, 'figs', sStartTime + '_autoencoder.png') 165 | outputs = (net.encode(images) - 2.34) / 0.005 166 | xRecreate = net.decode(outputs * 0.005 + 2.34) 167 | plotAutoEnc(images, xRecreate, sSavePath) 168 | 169 | sSavePath = os.path.join(saveDir, 'figs', sStartTime + '_noise_autoencoder.png') 170 | xRecreate = net.decode(outputs + 1.0 * torch.randn_like(outputs)) 171 | plotAutoEnc(images, xRecreate, sSavePath) 172 | 173 | i += 1 174 | 175 | mu = mu / i 176 | musqrd = musqrd / i 177 | std = torch.sqrt(torch.abs(mu ** 2 - musqrd)) 178 | 179 | mu.requires_grad = False 180 | std.requires_grad = False 181 | net.mu = mu 182 | net.std = std 183 | 184 | torch.save({ 185 | 'state_dict': net.state_dict(), 186 | }, os.path.join(saveDir, sStartTime + '_autoenc_checkpt.pth')) 187 | 188 | return net 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | -------------------------------------------------------------------------------- /src/OTFlowProblem.py: -------------------------------------------------------------------------------- 1 | # OTFlowProblem.py 2 | import math 3 | import torch 4 | from torch.nn.functional import pad 5 | from src.Phi import * 6 | 7 | 8 | def vec(x): 9 | """vectorize torch tensor x""" 10 | return x.view(-1,1) 11 | 12 | def OTFlowProblem(x, Phi, tspan , nt, stepper="rk4", alph =[1.0,1.0,1.0] ): 13 | """ 14 | 15 | Evaluate objective function of OT Flow problem; see Eq. (8) in the paper. 16 | 17 | :param x: input data tensor nex-by-d 18 | :param Phi: neural network 19 | :param tspan: time range to integrate over, ex. 
[0.0 , 1.0] 20 | :param nt: number of time steps 21 | :param stepper: string "rk1" or "rk4" Runge-Kutta schemes 22 | :param alph: list of length 3, the alpha value multipliers 23 | :return: 24 | Jc - float, objective function value dot(alph,cs) 25 | cs - list length 5, the five computed costs 26 | """ 27 | h = (tspan[1]-tspan[0]) / nt 28 | 29 | # initialize "hidden" vector to propogate with all the additional dimensions for all the ODEs 30 | z = pad(x, (0, 3, 0, 0), value=0) 31 | 32 | tk = tspan[0] 33 | 34 | if stepper=='rk4': 35 | for k in range(nt): 36 | z = stepRK4(odefun, z, Phi, alph, tk, tk + h) 37 | tk += h 38 | elif stepper=='rk1': 39 | for k in range(nt): 40 | z = stepRK1(odefun, z, Phi, alph, tk, tk + h) 41 | tk += h 42 | 43 | # ASSUME all examples are equally weighted 44 | costL = torch.mean(z[:,-2]) 45 | costC = torch.mean(C(z)) 46 | costR = torch.mean(z[:,-1]) 47 | 48 | cs = [costL, costC, costR] 49 | 50 | # return dot(cs, alph) , cs 51 | return sum(i[0] * i[1] for i in zip(cs, alph)) , cs 52 | 53 | 54 | 55 | def stepRK4(odefun, z, Phi, alph, t0, t1): 56 | """ 57 | Runge-Kutta 4 integration scheme 58 | :param odefun: function to apply at every time step 59 | :param z: tensor nex-by-d+4, inputs 60 | :param Phi: Module, the Phi potential function 61 | :param alph: list, the 3 alpha values for the OT-Flow Problem 62 | :param t0: float, starting time 63 | :param t1: float, end time 64 | :return: tensor nex-by-d+4, features at time t1 65 | """ 66 | 67 | h = t1 - t0 # step size 68 | z0 = z 69 | 70 | K = h * odefun(z0, t0, Phi, alph=alph) 71 | z = z0 + (1.0/6.0) * K 72 | 73 | K = h * odefun( z0 + 0.5*K , t0+(h/2) , Phi, alph=alph) 74 | z += (2.0/6.0) * K 75 | 76 | K = h * odefun( z0 + 0.5*K , t0+(h/2) , Phi, alph=alph) 77 | z += (2.0/6.0) * K 78 | 79 | K = h * odefun( z0 + K , t0+h , Phi, alph=alph) 80 | z += (1.0/6.0) * K 81 | 82 | return z 83 | 84 | def stepRK1(odefun, z, Phi, alph, t0, t1): 85 | """ 86 | Runge-Kutta 1 / Forward Euler integration scheme. 
Added for comparison, but we recommend stepRK4. 87 | :param odefun: function to apply at every time step 88 | :param z: tensor nex-by-d+4, inputs 89 | :param Phi: Module, the Phi potential function 90 | :param alph: list, the 3 alpha values for the mean field game problem 91 | :param t0: float, starting time 92 | :param t1: float, end time 93 | :return: tensor nex-by-d+4, features at time t1 94 | """ 95 | z += (t1 - t0) * odefun(z, t0, Phi, alph=alph) 96 | return z 97 | 98 | 99 | def integrate(x, net, tspan , nt, stepper="rk4", alph =[1.0,1.0,1.0], intermediates=False ): 100 | """ 101 | perform the time integration in the d-dimensional space 102 | :param x: input data tensor nex-by-d 103 | :param net: neural network Phi 104 | :param tspan: time range to integrate over, ex. [0.0 , 1.0] 105 | :param nt: number of time steps 106 | :param stepper: string "rk1" or "rk4" Runge-Kutta schemes 107 | :param alph: list of length 3, the alpha value multipliers 108 | :param intermediates: bool, True means save all intermediate time points along trajectories 109 | :return: 110 | z - tensor nex-by-d+4, features at time t1 111 | OR zFull - tensor nex-by-d+3-by-nt+1 , trajectories from time t0 to t1 (when intermediates=True) 112 | """ 113 | 114 | h = (tspan[1]-tspan[0]) / nt 115 | 116 | # initialize "hidden" vector to propagate with all the additional dimensions for all the ODEs 117 | z = pad(x, (0, 3, 0, 0), value=tspan[0]) 118 | 119 | tk = tspan[0] 120 | 121 | if intermediates: # save the intermediate values as well 122 | zFull = torch.zeros( *z.shape , nt+1, device=x.device, dtype=x.dtype) # make tensor of size z.shape[0], z.shape[1], nt 123 | zFull[:,:,0] = z 124 | 125 | if stepper == 'rk4': 126 | for k in range(nt): 127 | zFull[:,:,k+1] = stepRK4(odefun, zFull[:,:,k] , net, alph, tk, tk+h) 128 | tk += h 129 | elif stepper == 'rk1': 130 | for k in range(nt): 131 | zFull[:,:,k+1] = stepRK1(odefun, zFull[:,:,k] , net, alph, tk, tk+h) 132 | tk += h 133 | 134 | return zFull 135 | 
136 | else: 137 | if stepper == 'rk4': 138 | for k in range(nt): 139 | z = stepRK4(odefun,z,net, alph,tk,tk+h) 140 | tk += h 141 | elif stepper == 'rk1': 142 | for k in range(nt): 143 | z = stepRK1(odefun,z,net, alph,tk,tk+h) 144 | tk += h 145 | 146 | return z 147 | 148 | # return in case of error 149 | return -1 150 | 151 | 152 | 153 | def C(z): 154 | """Expected negative log-likelihood; see Eq.(3) in the paper""" 155 | d = z.shape[1]-3 156 | l = z[:,d] # log-det 157 | 158 | return -( torch.sum( -0.5 * math.log(2*math.pi) - torch.pow(z[:,0:d],2) / 2 , 1 , keepdims=True ) + l.unsqueeze(1) ) 159 | 160 | 161 | def odefun(x, t, net, alph=[1.0,1.0,1.0]): 162 | """ 163 | neural ODE combining the characteristics and log-determinant (see Eq. (2)), the transport costs (see Eq. (5)), and 164 | the HJB regularizer (see Eq. (7)). 165 | 166 | d_t [x ; l ; v ; r] = odefun( [x ; l ; v ; r] , t ) 167 | 168 | x - particle position 169 | l - log determinant 170 | v - accumulated transport costs (Lagrangian) 171 | r - accumulates violation of HJB condition along trajectory 172 | """ 173 | nex, d_extra = x.shape 174 | d = d_extra - 3 175 | 176 | z = pad(x[:, :d], (0, 1, 0, 0), value=t) # concatenate with the time t 177 | 178 | gradPhi, trH = net.trHess(z) 179 | 180 | dx = -(1.0/alph[0]) * gradPhi[:,0:d] 181 | dl = -(1.0/alph[0]) * trH.unsqueeze(1) 182 | dv = 0.5 * torch.sum(torch.pow(dx, 2) , 1 ,keepdims=True) 183 | dr = torch.abs( -gradPhi[:,-1].unsqueeze(1) + alph[0] * dv ) 184 | 185 | return torch.cat( (dx,dl,dv,dr) , 1 ) 186 | 187 | 188 | -------------------------------------------------------------------------------- /src/Phi.py: -------------------------------------------------------------------------------- 1 | # Phi.py 2 | # neural network to model the potential function 3 | import torch 4 | import torch.nn as nn 5 | import copy 6 | import math 7 | 8 | def antiderivTanh(x): # activation function aka the antiderivative of tanh 9 | return torch.abs(x) + 
torch.log(1+torch.exp(-2.0*torch.abs(x))) 10 | 11 | def derivTanh(x): # act'' aka the second derivative of the activation function antiderivTanh 12 | return 1 - torch.pow( torch.tanh(x) , 2 ) 13 | 14 | class ResNN(nn.Module): 15 | def __init__(self, d, m, nTh=2): 16 | """ 17 | ResNet N portion of Phi 18 | :param d: int, dimension of space input (expect inputs to be d+1 for space-time) 19 | :param m: int, hidden dimension 20 | :param nTh: int, number of resNet layers , (number of theta layers) 21 | """ 22 | super().__init__() 23 | 24 | if nTh < 2: 25 | print("nTh must be an integer >= 2") 26 | exit(1) 27 | 28 | self.d = d 29 | self.m = m 30 | self.nTh = nTh 31 | self.layers = nn.ModuleList([]) 32 | self.layers.append(nn.Linear(d + 1, m, bias=True)) # opening layer 33 | self.layers.append(nn.Linear(m,m, bias=True)) # resnet layers 34 | for i in range(nTh-2): 35 | self.layers.append(copy.deepcopy(self.layers[1])) 36 | self.act = antiderivTanh 37 | self.h = 1.0 / (self.nTh-1) # step size for the ResNet 38 | 39 | def forward(self, x): 40 | """ 41 | N(s;theta). the forward propogation of the ResNet 42 | :param x: tensor nex-by-d+1, inputs 43 | :return: tensor nex-by-m, outputs 44 | """ 45 | 46 | x = self.act(self.layers[0].forward(x)) 47 | 48 | for i in range(1,self.nTh): 49 | x = x + self.h * self.act(self.layers[i](x)) 50 | 51 | return x 52 | 53 | 54 | 55 | class Phi(nn.Module): 56 | def __init__(self, nTh, m, d, r=10, alph=[1.0] * 5): 57 | """ 58 | neural network approximating Phi (see Eq. 
(9) in our paper) 59 | 60 | Phi( x,t ) = w'*ResNet( [x;t]) + 0.5*[x' t] * A'A * [x;t] + b'*[x;t] + c 61 | 62 | :param nTh: int, number of resNet layers , (number of theta layers) 63 | :param m: int, hidden dimension 64 | :param d: int, dimension of space input (expect inputs to be d+1 for space-time) 65 | :param r: int, rank r for the A matrix 66 | :param alph: list, alpha values / weighted multipliers for the optimization problem 67 | """ 68 | super().__init__() 69 | 70 | self.m = m 71 | self.nTh = nTh 72 | self.d = d 73 | self.alph = alph 74 | 75 | r = min(r,d+1) # if number of dimensions is smaller than default r, use that 76 | 77 | self.A = nn.Parameter(torch.zeros(r, d+1) , requires_grad=True) 78 | self.A = nn.init.xavier_uniform_(self.A) 79 | self.c = nn.Linear( d+1 , 1 , bias=True) # b'*[x;t] + c 80 | self.w = nn.Linear( m , 1 , bias=False) 81 | 82 | self.N = ResNN(d, m, nTh=nTh) 83 | 84 | # set initial values 85 | self.w.weight.data = torch.ones(self.w.weight.data.shape) 86 | self.c.weight.data = torch.zeros(self.c.weight.data.shape) 87 | self.c.bias.data = torch.zeros(self.c.bias.data.shape) 88 | 89 | 90 | 91 | def forward(self, x): 92 | """ calculating Phi(s, theta)...not used in OT-Flow """ 93 | 94 | # force A to be symmetric 95 | symA = torch.matmul(torch.t(self.A), self.A) # A'A 96 | 97 | return self.w( self.N(x)) + 0.5 * torch.sum( torch.matmul(x , symA) * x , dim=1, keepdims=True) + self.c(x) 98 | 99 | 100 | def trHess(self,x, justGrad=False ): 101 | """ 102 | compute gradient of Phi wrt x and trace(Hessian of Phi); see Eq. (11) and Eq. 
(13), respectively 103 | recomputes the forward propogation portions of Phi 104 | 105 | :param x: input data, torch Tensor nex-by-d 106 | :param justGrad: boolean, if True only return gradient, if False return (grad, trHess) 107 | :return: gradient , trace(hessian) OR just gradient 108 | """ 109 | 110 | # code in E = eye(d+1,d) as index slicing instead of matrix multiplication 111 | # assumes specific N.act as the antiderivative of tanh 112 | 113 | N = self.N 114 | m = N.layers[0].weight.shape[0] 115 | nex = x.shape[0] # number of examples in the batch 116 | d = x.shape[1]-1 117 | symA = torch.matmul(self.A.t(), self.A) 118 | 119 | u = [] # hold the u_0,u_1,...,u_M for the forward pass 120 | z = N.nTh*[None] # hold the z_0,z_1,...,z_M for the backward pass 121 | # preallocate z because we will store in the backward pass and we want the indices to match the paper 122 | 123 | # Forward of ResNet N and fill u 124 | opening = N.layers[0].forward(x) # K_0 * S + b_0 125 | u.append(N.act(opening)) # u0 126 | feat = u[0] 127 | 128 | for i in range(1,N.nTh): 129 | feat = feat + N.h * N.act(N.layers[i](feat)) 130 | u.append(feat) 131 | 132 | # going to be used more than once 133 | tanhopen = torch.tanh(opening) # act'( K_0 * S + b_0 ) 134 | 135 | # compute gradient and fill z 136 | for i in range(N.nTh-1,0,-1): # work backwards, placing z_i in appropriate spot 137 | if i == N.nTh-1: 138 | term = self.w.weight.t() 139 | else: 140 | term = z[i+1] 141 | 142 | # z_i = z_{i+1} + h K_i' diag(...) z_{i+1} 143 | z[i] = term + N.h * torch.mm( N.layers[i].weight.t() , torch.tanh( N.layers[i].forward(u[i-1]) ).t() * term) 144 | 145 | # z_0 = K_0' diag(...) 
z_1 146 | z[0] = torch.mm( N.layers[0].weight.t() , tanhopen.t() * z[1] ) 147 | grad = z[0] + torch.mm(symA, x.t() ) + self.c.weight.t() 148 | 149 | if justGrad: 150 | return grad.t() 151 | 152 | # ----------------- 153 | # trace of Hessian 154 | #----------------- 155 | 156 | # t_0, the trace of the opening layer 157 | Kopen = N.layers[0].weight[:,0:d] # indexed version of Kopen = torch.mm( N.layers[0].weight, E ) 158 | temp = derivTanh(opening.t()) * z[1] 159 | trH = torch.sum(temp.reshape(m, -1, nex) * torch.pow(Kopen.unsqueeze(2), 2), dim=(0, 1)) # trH = t_0 160 | 161 | # grad_s u_0 ^ T 162 | temp = tanhopen.t() # act'( K_0 * S + b_0 ) 163 | Jac = Kopen.unsqueeze(2) * temp.unsqueeze(1) # K_0' * act'( K_0 * S + b_0 ) 164 | # Jac is shape m by d by nex 165 | 166 | # t_i, trace of the resNet layers 167 | # KJ is the K_i^T * grad_s u_{i-1}^T 168 | for i in range(1,N.nTh): 169 | KJ = torch.mm(N.layers[i].weight , Jac.reshape(m,-1) ) 170 | KJ = KJ.reshape(m,-1,nex) 171 | if i == N.nTh-1: 172 | term = self.w.weight.t() 173 | else: 174 | term = z[i+1] 175 | 176 | temp = N.layers[i].forward(u[i-1]).t() # (K_i * u_{i-1} + b_i) 177 | t_i = torch.sum( ( derivTanh(temp) * term ).reshape(m,-1,nex) * torch.pow(KJ,2) , dim=(0, 1) ) 178 | trH = trH + N.h * t_i # add t_i to the accumulate trace 179 | Jac = Jac + N.h * torch.tanh(temp).reshape(m, -1, nex) * KJ # update Jacobian 180 | 181 | return grad.t(), trH + torch.trace(symA[0:d,0:d]) 182 | # indexed version of: return grad.t() , trH + torch.trace( torch.mm( E.t() , torch.mm( symA , E) ) ) 183 | 184 | 185 | 186 | if __name__ == "__main__": 187 | 188 | import time 189 | import math 190 | 191 | # test case 192 | d = 2 193 | m = 5 194 | 195 | net = Phi(nTh=2, m=m, d=d) 196 | net.N.layers[0].weight.data = 0.1 + 0.0 * net.N.layers[0].weight.data 197 | net.N.layers[0].bias.data = 0.2 + 0.0 * net.N.layers[0].bias.data 198 | net.N.layers[1].weight.data = 0.3 + 0.0 * net.N.layers[1].weight.data 199 | net.N.layers[1].weight.data = 0.3 
+ 0.0 * net.N.layers[1].weight.data 200 | 201 | # number of samples-by-(d+1) 202 | x = torch.Tensor([[1.0 ,4.0 , 0.5],[2.0,5.0,0.6],[3.0,6.0,0.7],[0.0,0.0,0.0]]) 203 | y = net(x) 204 | print(y) 205 | 206 | # test timings 207 | d = 400 208 | m = 32 209 | nex = 1000 210 | 211 | net = Phi(nTh=5, m=m, d=d) 212 | net.eval() 213 | x = torch.randn(nex,d+1) 214 | y = net(x) 215 | 216 | end = time.time() 217 | g,h = net.trHess(x) 218 | print('traceHess takes ', time.time()-end) 219 | 220 | end = time.time() 221 | g = net.trHess(x, justGrad=True) 222 | print('JustGrad takes ', time.time()-end) 223 | 224 | 225 | 226 | 227 | 228 | 229 | 230 | 231 | 232 | 233 | 234 | -------------------------------------------------------------------------------- /src/PhiHC.py: -------------------------------------------------------------------------------- 1 | # PhiHC.py 2 | # Phi Hardcoded version 3 | # hard coded nTh = 2 4 | import torch 5 | import torch.nn as nn 6 | import copy 7 | import math 8 | 9 | def antiderivTanh(x): # activation function aka the antiderivative of tanh 10 | return torch.log( torch.exp(x) + torch.exp(-x) ) 11 | 12 | def derivTanh(x): # act'' aka the second derivative of the activation function antiderivTanh 13 | return 1 - torch.pow( torch.tanh(x) , 2 ) 14 | 15 | class ResNN(nn.Module): 16 | def __init__(self, d, m, nTh=2): 17 | """ 18 | ResNet N portion of Phi 19 | :param d: int, dimension of space input (expect inputs to be d+1 for space-time) 20 | :param m: int, hidden dimension 21 | :param nTh: 2, hard-coded number of ResNet layers 22 | """ 23 | super().__init__() 24 | 25 | nTh = 2 26 | self.opening = nn.Linear(d+1 , m , bias=True) 27 | self.layer1 = nn.Linear(m,m, bias=True) 28 | self.act = antiderivTanh 29 | self.h = 1.0 30 | self.d = d 31 | self.m = m 32 | 33 | def forward(self, x): 34 | """ 35 | N(s;theta). 
the forward propogation of the ResNet 36 | :param x: tensor nex-by-d+1, inputs 37 | :return: tensor nex-by-m, outputs 38 | """ 39 | 40 | x = self.act(self.opening(x)) 41 | x = x + self.h * self.act(self.layer1(x)) 42 | 43 | return x 44 | 45 | 46 | 47 | class PhiHC(nn.Module): 48 | def __init__(self, nTh, m, d, r=10, alph=[1.0]*5): 49 | """ 50 | neural network approximating Phi 51 | Phi( x,t ) = w'*ResNet( [x;t]) + 0.5*[x' t] * A'A * [x;t] + b'*[x;t] + c 52 | 53 | :param nTh: int, number of resNet layers, hardcoded as 2 54 | :param m: int, hidden dimension 55 | :param d: int, dimension of space input (expect inputs to be d+1 for space-time) 56 | :param r: int, rank r for the A matrix 57 | :param alph: list, alpha values / weighted multipliers for the optimization problem 58 | """ 59 | super().__init__() 60 | 61 | self.m = m 62 | self.nTh = nTh 63 | self.d = d 64 | self.alph = alph 65 | 66 | r = min(r,d+1) # if number of dimensions is smaller than default r, use that 67 | 68 | 69 | self.A = nn.Parameter(torch.zeros(r, d+1) , requires_grad=True) 70 | self.A = nn.init.xavier_uniform_(self.A) 71 | self.c = nn.Linear( d+1 , 1 , bias=True) # b'*[x;t] + c 72 | self.w = nn.Linear( m , 1 , bias=False) 73 | 74 | self.N = ResNN(d,m, nTh=nTh) 75 | 76 | # set start values 77 | self.w.weight.data = torch.ones(self.w.weight.data.shape) 78 | self.c.weight.data = torch.zeros(self.c.weight.data.shape) 79 | self.c.bias.data = torch.zeros(self.c.bias.data.shape) 80 | 81 | def forward(self, x): 82 | 83 | # force A to be symmetric 84 | symA = torch.matmul(torch.t(self.A), self.A) 85 | 86 | return self.w( self.N(x)) + 0.5 * torch.sum( torch.matmul(x , symA) * x , dim=1, keepdims=True) + self.c(x) 87 | 88 | def trHess(self, x , justGrad=False): 89 | """ 90 | compute gradient of Phi wrt x and trace(Hessian of Phi) 91 | recomputes the forward propogation portions of Phi 92 | 93 | :param x: input data, torch Tensor nex-by-d 94 | :param justGrad: boolean, if True only return gradient, if False 
return (grad, trHess) 95 | :return: gradient , trace(hessian) OR just gradient 96 | """ 97 | 98 | # code in E = eye(d+1,d) as index slicing instead of matrix multiplication 99 | # assumes specific N.act as the antiderivative of tanh 100 | N = self.N 101 | m = N.opening.weight.shape[0] 102 | nex = x.shape[0] # number of examples in the batch 103 | d = x.shape[1] - 1 104 | symA = torch.matmul(self.A.t(), self.A) 105 | 106 | # Forward of ResNet N 107 | opening = N.opening(x) # K_0 * S + b_0 108 | u0 = N.act(opening) 109 | tanhopen = torch.tanh(opening) 110 | out1 = N.layer1(u0).t() 111 | 112 | # compute gradient 113 | z1 = self.w.weight.t() + N.h * torch.mm( N.layer1.weight.t() , torch.tanh(out1) ) 114 | z0 = torch.mm( N.opening.weight.t() , tanhopen.t() * z1 ) 115 | grad = z0 + torch.mm(symA, x.t() ) + self.c.weight.t() 116 | if justGrad: 117 | return grad.t() 118 | 119 | Kopen = N.opening.weight[:,0:d] # Kopen = torch.mm( N.opening.weight, E ) 120 | trH1 = torch.sum((derivTanh(opening.t())*z1).view(m, -1, nex) * torch.pow(Kopen.unsqueeze(2), 2), dim=(0, 1)) 121 | 122 | Jac = Kopen.unsqueeze(2) * tanhopen.t().unsqueeze(1) 123 | # Jac is shape m by d by nex 124 | 125 | Jac = torch.mm(N.layer1.weight , Jac.view(m,-1) ).view(m,-1,nex) 126 | trH2 = torch.sum( (derivTanh(out1) * self.w.weight.t()).view(m,-1,nex) * torch.pow(Jac,2) , dim=(0, 1) ) 127 | 128 | return grad.t(), trH1 + trH2 + torch.trace(symA[0:d,0:d]) 129 | 130 | 131 | 132 | if __name__ == "__main__": 133 | 134 | import time 135 | 136 | # test case 137 | d = 2 138 | m = 16 139 | 140 | net = PhiHC(nTh=2, m=m, d=d) 141 | net.N.opening.weight.data = 0.1 + 0.0 * net.N.opening.weight.data 142 | net.N.opening.bias.data = 0.2 + 0.0 * net.N.opening.bias.data 143 | net.N.layer1.weight.data = 0.3 + 0.0 * net.N.layer1.weight.data 144 | net.N.layer1.bias.data = 0.4 + 0.0 * net.N.layer1.bias.data 145 | 146 | # number of samples-by-(d+1) 147 | x = torch.Tensor([[1.0 ,4.0 , 0.5],[2.0,5.0,0.6],[3.0,6.0,0.7],[0.0,0.0,0.0]]) 
148 | y = net(x) 149 | print(y) 150 | 151 | 152 | # test timings 153 | d = 400 154 | m = 32 155 | nex = 1000 156 | 157 | net = PhiHC(nTh=2, m=m, d=d) 158 | x = torch.randn(nex,d+1) 159 | y = net(x) 160 | 161 | end = time.time() 162 | g,h = net.trHess(x) 163 | print('traceHess takes ', time.time()-end) 164 | 165 | end = time.time() 166 | g = net.trHess(x, justGrad=True) 167 | print('JustGrad takes ', time.time()-end) 168 | 169 | 170 | 171 | -------------------------------------------------------------------------------- /src/mmd.py: -------------------------------------------------------------------------------- 1 | # mmd.py 2 | # Maximum Mean Discrepancy 3 | 4 | import torch 5 | import numpy as np 6 | 7 | # from https://github.com/josipd/torch-two-sample/blob/master/torch_two_sample/statistics_diff.py 8 | def pdist(sample_1, sample_2, norm=2, eps=1e-5): 9 | r"""Compute the matrix of all squared pairwise distances. 10 | Arguments 11 | --------- 12 | sample_1 : torch.Tensor or Variable 13 | The first sample, should be of shape ``(n_1, d)``. 14 | sample_2 : torch.Tensor or Variable 15 | The second sample, should be of shape ``(n_2, d)``. 16 | norm : float 17 | The l_p norm to be used. 18 | Returns 19 | ------- 20 | torch.Tensor or Variable 21 | Matrix of shape (n_1, n_2). 
The [i, j]-th entry is equal to 22 | ``|| sample_1[i, :] - sample_2[j, :] ||_p``.""" 23 | n_1, n_2 = sample_1.size(0), sample_2.size(0) 24 | norm = float(norm) 25 | if norm == 2.: 26 | norms_1 = torch.sum(sample_1**2, dim=1, keepdim=True) 27 | norms_2 = torch.sum(sample_2**2, dim=1, keepdim=True) 28 | norms = (norms_1.expand(n_1, n_2) + 29 | norms_2.transpose(0, 1).expand(n_1, n_2)) 30 | distances_squared = norms - 2 * sample_1.mm(sample_2.t()) 31 | return torch.sqrt(eps + torch.abs(distances_squared)) 32 | else: 33 | dim = sample_1.size(1) 34 | expanded_1 = sample_1.unsqueeze(1).expand(n_1, n_2, dim) 35 | expanded_2 = sample_2.unsqueeze(0).expand(n_1, n_2, dim) 36 | differences = torch.abs(expanded_1 - expanded_2) ** norm 37 | inner = torch.sum(differences, dim=2, keepdim=False) 38 | return (eps + inner) ** (1. / norm) 39 | 40 | class MMDStatistic: 41 | r"""The *unbiased* MMD test of :cite:`gretton2012kernel`. 42 | The kernel used is equal to: 43 | .. math :: 44 | k(x, x') = \sum_{j=1}^k e^{-\alpha_j\|x - x'\|^2}, 45 | for the :math:`\alpha_j` proved in :py:meth:`~.MMDStatistic.__call__`. 46 | Arguments 47 | --------- 48 | n_1: int 49 | The number of points in the first sample. 50 | n_2: int 51 | The number of points in the second sample.""" 52 | 53 | def __init__(self, n_1, n_2): 54 | self.n_1 = n_1 55 | self.n_2 = n_2 56 | 57 | # The three constants used in the test. 58 | self.a00 = 1. / (n_1 * (n_1 - 1)) 59 | self.a11 = 1. / (n_2 * (n_2 - 1)) 60 | self.a01 = - 1. / (n_1 * n_2) 61 | 62 | def __call__(self, sample_1, sample_2, alphas, ret_matrix=False): 63 | r"""Evaluate the statistic. 64 | The kernel used is 65 | .. math:: 66 | k(x, x') = \sum_{j=1}^k e^{-\alpha_j \|x - x'\|^2}, 67 | for the provided ``alphas``. 68 | Arguments 69 | --------- 70 | sample_1: :class:`torch:torch.autograd.Variable` 71 | The first sample, of size ``(n_1, d)``. 72 | sample_2: variable of shape (n_2, d) 73 | The second sample, of size ``(n_2, d)``. 
74 | alphas : list of :class:`float` 75 | The kernel parameters. 76 | ret_matrix: bool 77 | If set, the call with also return a second variable. 78 | This variable can be then used to compute a p-value using 79 | :py:meth:`~.MMDStatistic.pval`. 80 | Returns 81 | ------- 82 | :class:`float` 83 | The test statistic. 84 | :class:`torch:torch.autograd.Variable` 85 | Returned only if ``ret_matrix`` was set to true.""" 86 | sample_12 = torch.cat((sample_1, sample_2), 0) 87 | distances = pdist(sample_12, sample_12, norm=2) 88 | 89 | kernels = None 90 | for alpha in alphas: 91 | kernels_a = torch.exp(- alpha * distances ** 2) 92 | if kernels is None: 93 | kernels = kernels_a 94 | else: 95 | kernels = kernels + kernels_a 96 | 97 | k_1 = kernels[:self.n_1, :self.n_1] 98 | k_2 = kernels[self.n_1:, self.n_1:] 99 | k_12 = kernels[:self.n_1, self.n_1:] 100 | 101 | mmd = (2 * self.a01 * k_12.sum() + 102 | self.a00 * (k_1.sum() - torch.trace(k_1)) + 103 | self.a11 * (k_2.sum() - torch.trace(k_2))) 104 | if ret_matrix: 105 | return mmd, kernels 106 | else: 107 | return mmd 108 | 109 | 110 | def mmd(x,y, indepth=False, alph=1.0): 111 | """ 112 | from Li et al. 
Generative Moment Matching Networks 2015 113 | 114 | Gaussian kernel 115 | 116 | :param x: numpy matrix of size (nex, :) 117 | :param y: numpy matrix of size (nex,:) 118 | :return: MMD(x,y) 119 | """ 120 | 121 | # convert to numpy 122 | if type(x) is torch.Tensor: 123 | x = x.numpy() 124 | if type(y) is torch.Tensor: 125 | y = y.numpy() 126 | 127 | if max(x.size,y.size) > 20000: 128 | indepth = True 129 | 130 | 131 | 132 | # there's a quick method, that uses a lot of memory, that can be run on pointclouds of a few thousand samples 133 | # and there's a long and slow way that can be run on pointclouds with 10^5 samples 134 | if not indepth: 135 | # make torch tensor 136 | if type(x) is np.ndarray: 137 | x = torch.from_numpy(x).to(torch.float32) 138 | if type(y) is np.ndarray: 139 | y = torch.from_numpy(y).to(torch.float32) 140 | mmdObj = MMDStatistic(x.shape[0],y.shape[0]) 141 | return mmdObj( x , y , [alph] ).item() # just use alpha = 1.0 142 | 143 | else: 144 | # lots of examples, do a long approach 145 | # very slow 146 | # kernel = exp( 1/(2*sig) * || x - xj ||^2 ) 147 | 148 | 149 | # sig = 0.5 150 | # alpha = -1.0 / (2*sig) 151 | alpha = - alph 152 | 153 | xx = 0.0 154 | yy = 0.0 155 | xy = 0.0 156 | N = x.shape[0] 157 | M = y.shape[0] 158 | 159 | NsqrTerm = 1/N**2 160 | MsqrTerm = 1/M**2 161 | crossTerm = -2/(N*M) 162 | 163 | for i in range(N): 164 | xi = x[i,:] 165 | diff = xi - x 166 | power = alpha * np.linalg.norm(diff, ord=2, axis=1, keepdims=True)**2 # nex-by-1 167 | xx += np.exp(power).sum() 168 | 169 | diff = xi - y 170 | power = alpha * np.linalg.norm(diff, ord=2, axis=1, keepdims=True) ** 2 # nex-by-1 171 | xy += np.exp(power).sum() 172 | 173 | for i in range(M): 174 | yi = y[i,:] 175 | diff = yi - y 176 | power = alpha * np.linalg.norm(diff, ord=2, axis=1, keepdims=True)**2 # nex-by-1 177 | yy += np.exp(power).sum() 178 | 179 | return NsqrTerm*xx + crossTerm*xy + MsqrTerm*yy 180 | 181 | 182 | 183 | 184 | if __name__ == "__main__": 185 | 186 | for N 
in [2000,20000]: 187 | M = N-9 188 | d = 10 189 | x = 50.0 + np.random.rand(N,d) 190 | y = np.random.randn(M,d) 191 | ret = mmd(x,y) 192 | print('mmd: {:.3e}'.format(ret)) 193 | -------------------------------------------------------------------------------- /src/plotTraceComparison.py: -------------------------------------------------------------------------------- 1 | # plotTraceComparison.py 2 | 3 | try: 4 | import matplotlib 5 | matplotlib.use('TkAgg') 6 | import matplotlib.pyplot as plt 7 | except: 8 | import matplotlib 9 | matplotlib.use('agg') # for linux server with no tkinter 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | import torch 13 | import os 14 | import matplotlib.gridspec as gridspec 15 | import sklearn 16 | 17 | 18 | 19 | def bootstrap(values, nIter, alpha): 20 | """ 21 | bootstrapping to create error bounds, uses resmapling with replacement of sample size n-4 22 | :param values: n-by-m matrix, n = number of runs to resmaple from, m = observations per run 23 | :param nIter: int, number of resamples 24 | :param alpha: float, percentile bounds 25 | :return: lower bounds, mean, upper bounds 26 | """ 27 | p1 = ((1.0 - alpha) / 2.0) * 100 28 | p2 = (alpha + ((1.0 - alpha) / 2.0)) * 100 29 | print('%.1f confidence interval %.1f%% and %.1f%%' % (alpha * 100, p1, p2)) 30 | 31 | stats = list() 32 | nSize = values.shape[0] - 4 33 | for i in range(nIter): 34 | sample = sklearn.utils.resample(values, n_samples=nSize) 35 | stats.append(np.mean(sample, 0)) 36 | stats = np.array(stats) 37 | lower = np.percentile(stats, p1, axis=0) 38 | upper = np.percentile(stats, p2, axis=0) 39 | avg = np.mean(values, axis=0) 40 | print('lower :', lower) 41 | print('mean :', avg) 42 | print('upper :', upper) 43 | return lower, avg, upper 44 | return lower, avg, upper 45 | 46 | 47 | 48 | 49 | def plotTraceCompare(domainMiniboone,domainBSDS,domainMNIST, 50 | approxTimingMiniboone, approxTimingBDS, approxTimingMNIST, 51 | traceErrorMiniboone, traceErrorBDS, 
traceErrorMNIST, 52 | lTimeExact,sPath = "../image/traceComparison/", bErrBar=False): 53 | """ 54 | 55 | :param domainMiniboone: list, number of hutchinson vectors Miniboone 56 | :param domainBSDS: " BSDS300 57 | :param domainMNIST: " MNIST 58 | :param approxTimingMiniboone: list, same length of domain, timings Miniboone 59 | :param approxTimingBDS: " BSDS300 60 | :param approxTimingMNIST: " MNIST 61 | :param traceErrorMiniboone: list, same length of domain, trace estimation rel. errors Miniboone 62 | :param traceErrorBDS: " BSDS300 63 | :param traceErrorMNIST: " MNIST 64 | :param lTimeExact: list of 3 timings of exact trace for Miniboone, BSDS, MNIST 65 | :param bErrBar: boolean, True means to plot the error bounds 66 | :return: void. plot the figures 67 | """ 68 | 69 | shade=0.3 70 | 71 | exactTimingMiniboone = lTimeExact[0] * torch.ones(len(domainMiniboone)) 72 | exactTimingBDS = lTimeExact[1] * torch.ones(len(domainBSDS)) 73 | exactTimingMNIST = lTimeExact[2] * torch.ones(len(domainMNIST)) 74 | 75 | exactTimingMiniboone = exactTimingMiniboone.cpu().detach().numpy() 76 | approxTimingMiniboone= approxTimingMiniboone.cpu().detach().numpy() 77 | traceErrorMiniboone = traceErrorMiniboone.cpu().detach().numpy() 78 | exactTimingBDS = exactTimingBDS.cpu().detach().numpy() 79 | approxTimingBDS = approxTimingBDS.cpu().detach().numpy() 80 | traceErrorBDS = traceErrorBDS.cpu().detach().numpy() 81 | exactTimingMNIST = exactTimingMNIST.cpu().detach().numpy() 82 | approxTimingMNIST = approxTimingMNIST.cpu().detach().numpy() 83 | traceErrorMNIST = traceErrorMNIST.cpu().detach().numpy() 84 | 85 | 86 | if bErrBar: 87 | # calculate error bars by bootstrapping 88 | # compute mean of 'nIter' samples with replacement of size n-4 89 | nIter = 4000 90 | alpha = 0.99 91 | # miniboone 92 | exactMiniLower , exactTimingMiniboone , exactMiniUpper = bootstrap(exactTimingMiniboone, nIter, alpha) 93 | approxMiniLower, approxTimingMiniboone, approxMiniUpper = bootstrap(approxTimingMiniboone, 
nIter, alpha) 94 | errMiniLower , traceErrorMiniboone , errMiniUpper = bootstrap(traceErrorMiniboone, nIter, alpha) 95 | # BSDS 96 | exactBSDSLower , exactTimingBDS , exactBSDSUpper = bootstrap(exactTimingBDS, nIter, alpha) 97 | approxBSDSLower, approxTimingBDS, approxBSDSUpper = bootstrap(approxTimingBDS, nIter, alpha) 98 | errBSDSLower , traceErrorBDS , errBSDSUpper = bootstrap(traceErrorBDS, nIter, alpha) 99 | # MNIST 100 | exactMNISTLower , exactTimingMNIST , exactMNISTUpper = bootstrap(exactTimingMNIST, nIter, alpha) 101 | approxMNISTLower, approxTimingMNIST, approxMNISTUpper = bootstrap(approxTimingMNIST, nIter, alpha) 102 | errMNISTLower , traceErrorMNIST , errMNISTUpper = bootstrap(traceErrorMNIST, nIter, alpha) 103 | 104 | else: 105 | # just calculate the mean 106 | exactTimingMiniboone = np.mean( exactTimingMiniboone , axis=0) 107 | approxTimingMiniboone = np.mean(approxTimingMiniboone, axis=0) 108 | traceErrorMiniboone = np.mean(traceErrorMiniboone, axis=0) 109 | exactTimingBDS = np.mean(exactTimingBDS, axis=0) 110 | approxTimingBDS = np.mean(approxTimingBDS, axis=0) 111 | traceErrorBDS = np.mean(traceErrorBDS, axis=0) 112 | exactTimingMNIST = np.mean(exactTimingMNIST, axis=0) 113 | approxTimingMNIST = np.mean(approxTimingMNIST, axis=0) 114 | traceErrorMNIST = np.mean(traceErrorMNIST, axis=0) 115 | 116 | 117 | ylim_min = torch.ones(1) * 8e-4 118 | ylim_max = torch.max(torch.FloatTensor(approxTimingMNIST)) 119 | print("ylim_max = ", ylim_max) 120 | 121 | ylim_min_err = 0.8*torch.min(torch.FloatTensor(traceErrorMNIST)) 122 | ylim_max_err = 1.2*torch.max(torch.FloatTensor(traceErrorMNIST)) 123 | 124 | # path to save figure 125 | if not os.path.exists(os.path.dirname(sPath)): 126 | os.makedirs(os.path.dirname(sPath)) 127 | 128 | # Plots 129 | fontsize = 20 130 | title_fontsize = 22 131 | 132 | # we do four plots. 
the first three share an axis and the last is the relative errors 133 | fig = plt.figure() 134 | plt.clf() 135 | fig.set_size_inches(20, 4.1) 136 | outer = gridspec.GridSpec(1, 2, wspace=0.20, width_ratios= [2.7,1.0]) 137 | # the first three that I want to share an axis 138 | inner = gridspec.GridSpecFromSubplotSpec(1, 3, subplot_spec=outer[0], wspace=0.08) 139 | 140 | # yticks values 141 | rangeAll = [10**(-3),10**(-2),10**(-1),1] 142 | 143 | 144 | # fig1 Miniboone timings 145 | ax = plt.Subplot(fig, inner[0]) 146 | if bErrBar: 147 | ax.fill_between(domainMiniboone, exactMiniLower, exactMiniUpper, alpha=shade, color='black') # for exact 148 | ax.fill_between(domainMiniboone, approxMiniLower, approxMiniUpper, alpha=shade, color='tab:blue') # for hutch 149 | ax.semilogy(domainMiniboone, exactTimingMiniboone, linewidth=3, linestyle='dashed', color='black') 150 | ax.semilogy(domainMiniboone, approxTimingMiniboone, marker="o", markersize=12, linestyle=':', color='tab:blue') 151 | ax.set_xticks(domainMiniboone) 152 | ax.set_yticks(rangeAll) 153 | ax.set_ylabel("Runtime (s)", fontsize=title_fontsize) 154 | ax.set_ylim((ylim_min, ylim_max)) 155 | # ax.set_xlabel("# of Hutchinson vectors", fontsize=title_fontsize) 156 | ax.set_title("(a) MINIBOONE, d=43", fontsize=title_fontsize) 157 | # try to force tick font o be large 158 | ax.tick_params(labelsize=fontsize, which='both', direction='in') 159 | fig.add_subplot(ax) 160 | 161 | 162 | # fig2 BSDS300 timings 163 | ax = plt.Subplot(fig, inner[1]) 164 | if bErrBar: 165 | ax.fill_between(domainBSDS, exactBSDSLower, exactBSDSUpper, alpha=shade, color='black') # for exact 166 | ax.fill_between(domainBSDS, approxBSDSLower, approxBSDSUpper, alpha=shade, color='tab:green') # for hutch 167 | ax.semilogy(domainBSDS, exactTimingBDS, linewidth=3, linestyle='dashed', color='black') 168 | ax.semilogy(domainBSDS, approxTimingBDS, marker=">", markersize=12, linestyle=':', color='tab:green') 169 | ax.set_xticks(domainBSDS) 170 | 
ax.set_yticks(rangeAll) 171 | ax.set_ylim((ylim_min, ylim_max)) 172 | ax.set_xlabel("Number of Hutchinson Vectors", fontsize=title_fontsize) 173 | ax.set_title("(b) BSDS300, d=63", fontsize=title_fontsize) 174 | # try to force tick's font size to be large 175 | ax.tick_params(labelsize=fontsize, which='both', direction='in') 176 | ax.tick_params(left=True, labelleft=False) 177 | fig.add_subplot(ax) 178 | 179 | 180 | 181 | # fig3 MNIST timings 182 | ax = plt.Subplot(fig, inner[2]) 183 | if bErrBar: 184 | ax.fill_between(domainMNIST, exactMNISTLower, exactMNISTUpper, alpha=shade, color='black') # for exact 185 | ax.fill_between(domainMNIST, approxMNISTLower, approxMNISTUpper, alpha=shade, color='tab:red') # for hutch 186 | ax.semilogy(domainMNIST, exactTimingMNIST, linewidth=3, linestyle='dashed', color='black') 187 | ax.semilogy(domainMNIST, approxTimingMNIST, marker="x", markersize=12, linestyle=':', color='tab:red') 188 | ax.set_xticks([1, 200, 400, 600, 784]) 189 | ax.set_yticks(rangeAll) 190 | ax.set_ylim((ylim_min, ylim_max)) 191 | # ax.set_xlabel("# of Hutchinson vectors", fontsize=title_fontsize) 192 | ax.set_title("(c) MNIST, d=784", fontsize=title_fontsize) 193 | # try to force tick's font size to be large 194 | ax.tick_params(labelsize=fontsize, which='both', direction='in') 195 | ax.tick_params(left=True, labelleft=False) 196 | fig.add_subplot(ax) 197 | 198 | # fig 4 relative errors 199 | ax = plt.Subplot(fig, outer[1]) 200 | if bErrBar: 201 | ax.fill_between(domainMiniboone, errMiniLower , errMiniUpper, color='tab:blue', alpha=shade) 202 | ax.fill_between(domainBSDS , errBSDSLower , errBSDSUpper, color='tab:green', alpha=shade) 203 | ax.fill_between(domainMNIST , errMNISTLower, errMNISTUpper, color='tab:red', alpha=shade) 204 | ax.semilogy(domainMiniboone, traceErrorMiniboone, marker="o", markersize=12, linestyle=':',color='tab:blue') 205 | ax.semilogy(domainBSDS, traceErrorBDS, marker=">", markersize=12, linestyle=':', color='tab:green') 206 | 
ax.semilogy(domainMNIST, traceErrorMNIST, marker="x", markersize=12, linestyle=':', color='tab:red') 207 | # fake line to add to the legend 208 | ax.plot(domainMNIST, 1e-16*torch.ones(len(domainMNIST)), linewidth=3, linestyle='dashed', color='black') 209 | ax.set_ylim((ylim_min_err, ylim_max_err)) 210 | ax.set_ylabel("Relative Error", fontsize=title_fontsize) 211 | ax.set_xlabel("Number of Hutchinson Vectors", fontsize=title_fontsize) 212 | ax.legend(['Hutchinson d=43', 'Hutchinson d=63', 'Hutchinson d=784', 'Exact'],fontsize=fontsize,bbox_to_anchor=(0.45,0.3)) 213 | ax.set_title("(d) Accuracy of Estimators", fontsize=title_fontsize) 214 | ax.tick_params(labelsize=fontsize, which='both', direction='in') 215 | fig.add_subplot(ax) 216 | 217 | plt.subplots_adjust(top=0.9,bottom=0.2) 218 | 219 | 220 | # plt.show() 221 | fig.savefig(sPath+'all4Trace.pdf', dpi=600) 222 | print("figure saved to ", sPath+'all4Trace.pdf') 223 | 224 | 225 | 226 | if __name__ == '__main__': 227 | 228 | # results from our runs. Values are an average from 20 runs. 
229 | 230 | # Miniboone (d=43): approxTiming/traceError tensors are 20 runs (rows) by len(domainMiniboone) (columns); exactTimingMiniboone is 20-by-1 231 | domainMiniboone = [1, 10, 20, 30, 43] 232 | 233 | approxTimingMiniboone = torch.tensor([[0.00204698, 0.02248192, 0.05925171, 0.09269555, 0.13181952], 234 | [0.00197325, 0.02188800, 0.05658624, 0.08444519, 0.13243289], 235 | [0.00198451, 0.02560000, 0.05685863, 0.09898496, 0.13945447], 236 | [0.00199782, 0.02196685, 0.05347942, 0.09409127, 0.13435186], 237 | [0.00207565, 0.02258125, 0.05532877, 0.09219276, 0.13214721], 238 | [0.00200294, 0.02304615, 0.06069043, 0.09288090, 0.13366579], 239 | [0.00198656, 0.02010317, 0.05643674, 0.09239347, 0.13222298], 240 | [0.00201728, 0.02313114, 0.05497242, 0.09327411, 0.13496320], 241 | [0.00195686, 0.02510848, 0.05944423, 0.09257779, 0.12938856], 242 | [0.00200090, 0.01983590, 0.05545882, 0.09132032, 0.13388595], 243 | [0.00230605, 0.02441216, 0.06083686, 0.09413837, 0.13533390], 244 | [0.00202854, 0.02355405, 0.05499494, 0.09336935, 0.13517620], 245 | [0.00253235, 0.02976358, 0.05309030, 0.09212723, 0.13202228], 246 | [0.00200090, 0.02030694, 0.05293158, 0.09495040, 0.13553973], 247 | [0.00194970, 0.02337178, 0.05714330, 0.09082573, 0.13207449], 248 | [0.00200499, 0.01968742, 0.05493453, 0.09481011, 0.14179942], 249 | [0.00207258, 0.02444391, 0.05693235, 0.09386291, 0.13639885], 250 | [0.00216883, 0.02145178, 0.05614080, 0.09291776, 0.13407233], 251 | [0.00296346, 0.02311270, 0.05964186, 0.09435239, 0.13308007], 252 | [0.00210944, 0.02242970, 0.06016410, 0.09175961, 0.13343847]]) 253 | 254 | traceErrorMiniboone = torch.tensor([[0.22399592, 0.07008030, 0.05251038, 0.04112926, 0.03358838], 255 | [0.21175259, 0.06972103, 0.04838064, 0.04045253, 0.03306658], 256 | [0.21496193, 0.06821524, 0.04756567, 0.03840676, 0.03136125], 257 | [0.19982451, 0.06403704, 0.04720526, 0.03827408, 0.03041435], 258 | [0.20888096, 0.06603654, 0.04901867, 0.03706970, 0.03287060], 259 | [0.21129309, 0.06779046, 0.04861399, 0.03830324, 0.03135169], 260 | [0.20985791, 0.06680115, 0.04792466, 0.03696685, 0.03245528], 261 |
[0.21677414, 0.06837571, 0.04581433, 0.04125229, 0.03183835], 262 | [0.21320291, 0.07022005, 0.04878217, 0.04066725, 0.03123112], 263 | [0.21459921, 0.07032831, 0.04619167, 0.04150819, 0.03356071], 264 | [0.22348940, 0.06709059, 0.04981212, 0.04132729, 0.03208493], 265 | [0.21790127, 0.06918178, 0.04933133, 0.03818259, 0.03425195], 266 | [0.22167198, 0.06969723, 0.04703050, 0.03707674, 0.03329653], 267 | [0.24082674, 0.07491565, 0.04908932, 0.04273553, 0.03606861], 268 | [0.19499324, 0.06544592, 0.04834570, 0.04035131, 0.03164196], 269 | [0.23307209, 0.07598016, 0.05314977, 0.04511942, 0.03460030], 270 | [0.19148827, 0.06681481, 0.04595243, 0.03640851, 0.03197488], 271 | [0.19533448, 0.06416275, 0.04717834, 0.03703568, 0.03247492], 272 | [0.22210748, 0.06772004, 0.04818948, 0.04124676, 0.03337538], 273 | [0.20792861, 0.06343383, 0.04265359, 0.03540466, 0.03049848]]) 274 | 275 | exactTimingMiniboone = torch.tensor([[0.00189805], [0.00188099], [0.00189171], [0.00211219], [0.00200957], 276 | [0.00199389], [0.00214022], [0.00192339], [0.00192688], [0.00191530], [0.00189725], [0.00189731], 277 | [0.00195821], [0.00210774], [0.00200566], [0.00228646], [0.00196464], [0.00193050], [0.00244557], 278 | [0.00233907]]) 279 | 280 | 281 | # BSDS300 (d=63): approxTiming/traceError tensors are 20 runs (rows) by len(domainBSDS) (columns); exactTimingBDS is 20-by-1 282 | domainBSDS = [1, 10, 20, 30, 40, 50, 63] 283 | 284 | approxTimingBDS = torch.tensor([[0.00196096, 0.01992499, 0.05914829, 0.09389056, 0.12557927, 0.15675801, 0.19541503], 285 | [0.00237568, 0.02242867, 0.05595239, 0.08900916, 0.12594381, 0.15561421, 0.19873485], 286 | [0.00209510, 0.02482176, 0.05906535, 0.09127732, 0.12411802, 0.15319961, 0.19669504], 287 | [0.00195174, 0.02213274, 0.05481574, 0.09024819, 0.12026573, 0.15681741, 0.19851059], 288 | [0.00200602, 0.02337587, 0.05639168, 0.09198797, 0.12295783, 0.15485133, 0.19691212], 289 | [0.00244122, 0.02244096, 0.06035251, 0.09647001, 0.12460339, 0.15235789, 0.19470440], 290 | [0.00245248, 0.02241946, 0.05773619, 0.09608295, 0.12370022, 0.15242343, 0.19164263], 291 | [0.00202035,
0.02782413, 0.06032794, 0.09088410, 0.12083405, 0.15557325, 0.19412991], 292 | [0.00268902, 0.02401485, 0.06104678, 0.09072845, 0.11465625, 0.15696485, 0.19905843], 293 | [0.00197837, 0.02306765, 0.05657805, 0.09488384, 0.11945062, 0.15574835, 0.19231643], 294 | [0.00201626, 0.02254848, 0.05652992, 0.08919347, 0.12453479, 0.15620403, 0.19640626], 295 | [0.00205107, 0.02238976, 0.06089932, 0.09200948, 0.12544614, 0.15656449, 0.19646567], 296 | [0.00204902, 0.02313216, 0.05572813, 0.09355161, 0.12560281, 0.15204352, 0.19627623], 297 | [0.00199168, 0.02330726, 0.05735526, 0.09469235, 0.12275814, 0.15653887, 0.19439821], 298 | [0.00218010, 0.02524467, 0.05821849, 0.09218150, 0.12721151, 0.15646823, 0.20083918], 299 | [0.00196403, 0.02051174, 0.05868442, 0.09549414, 0.13213082, 0.15206195, 0.20078899], 300 | [0.00204288, 0.02245222, 0.05068083, 0.09364992, 0.12617932, 0.15401575, 0.19684762], 301 | [0.00255693, 0.02551603, 0.05791744, 0.09331609, 0.12521063, 0.15728435, 0.19802830], 302 | [0.00206643, 0.02171597, 0.05827277, 0.08563917, 0.12383948, 0.15793253, 0.19470029], 303 | [0.00250368, 0.02256077, 0.05704192, 0.09220915, 0.13220045, 0.15089153, 0.19225907]]) 304 | 305 | traceErrorBDS = torch.tensor([[0.22224732, 0.07398839, 0.04687803, 0.04231799, 0.03719367, 0.03312020, 0.02866898], 306 | [0.22899885, 0.06637849, 0.05120244, 0.03810278, 0.03313917, 0.02972283, 0.02768398], 307 | [0.22437161, 0.06693737, 0.04891766, 0.04001211, 0.03618139, 0.03179351, 0.02886512], 308 | [0.23279119, 0.06872935, 0.04922526, 0.03858438, 0.03428061, 0.03056928, 0.02815522], 309 | [0.22483462, 0.06861812, 0.04778925, 0.04283748, 0.03553044, 0.03134139, 0.02784217], 310 | [0.24206249, 0.07523844, 0.05557442, 0.04315356, 0.03542253, 0.03241287, 0.02874050], 311 | [0.21107937, 0.06687794, 0.04701358, 0.04142660, 0.03364031, 0.03075318, 0.02824883], 312 | [0.22436637, 0.06825909, 0.04680301, 0.03896724, 0.03722268, 0.03046616, 0.02864970], 313 | [0.22186267, 0.06819139, 0.04920222,
0.03770087, 0.03415415, 0.03086751, 0.02698998], 314 | [0.22473182, 0.06998120, 0.04632138, 0.04015036, 0.03232786, 0.03110034, 0.02825256], 315 | [0.20607713, 0.06801096, 0.04883099, 0.03788529, 0.03505293, 0.03012344, 0.02555213], 316 | [0.21236718, 0.07406570, 0.04891364, 0.03851217, 0.03553124, 0.03287421, 0.02734543], 317 | [0.21273059, 0.06926261, 0.04765013, 0.04096797, 0.03521183, 0.03151200, 0.02731446], 318 | [0.20699681, 0.06645167, 0.05195054, 0.03938620, 0.03577903, 0.03134636, 0.02641797], 319 | [0.21683121, 0.06875283, 0.04659092, 0.03796173, 0.03430184, 0.03016153, 0.02521137], 320 | [0.23120171, 0.07039053, 0.05147018, 0.03970526, 0.03499187, 0.03093808, 0.02673643], 321 | [0.22178663, 0.07050845, 0.04779807, 0.03904643, 0.03483230, 0.03055800, 0.02810326], 322 | [0.23234105, 0.07264428, 0.04754635, 0.04069406, 0.03513086, 0.03031974, 0.02696989], 323 | [0.23352592, 0.07050001, 0.05279808, 0.04384740, 0.03565584, 0.03113800, 0.02937748], 324 | [0.22431205, 0.06739795, 0.04708927, 0.04052063, 0.03416919, 0.03165910, 0.02772815]]) 325 | 326 | exactTimingBDS = torch.tensor([[0.00243334], [0.00240138], [0.00223862], [0.00191088], [0.00186822], [0.00238966], 327 | [0.00229366], [0.00193875], [0.00300925], [0.00189766], [0.00211056], [0.00193008], 328 | [0.00190336], [0.00192000], [0.00198621], [0.00190147], [0.00194582], [0.00254944], 329 | [0.00196250], [0.00239798]]) 330 | 331 | # MNIST (d=784): approxTiming/traceError tensors are 20 runs (rows) by len(domainMNIST) (columns); exactTimingMNIST is 20-by-1 332 | domainMNIST = [1, 100, 200, 300, 400, 500, 600, 700, 784] 333 | 334 | approxTimingMNIST = torch.tensor([[2.08383985e-03, 3.03277105e-01, 6.34796023e-01, 9.73454297e-01, 335 | 1.26875138e+00, 1.58261967e+00, 1.90680981e+00, 2.20069671e+00, 2.48040366e+00], 336 | [2.18419195e-03, 3.02417904e-01, 6.34218514e-01, 9.57537293e-01, 1.27308381e+00, 337 | 1.60399556e+00, 1.89956093e+00, 2.21912169e+00, 2.48672962e+00], 338 | [2.16268795e-03, 3.08395028e-01, 6.40161812e-01, 9.49754894e-01, 1.25998187e+00, 339 | 1.59414876e+00, 1.91681635e+00, 2.22077632e+00,
2.48915648e+00], 340 | [2.31219199e-03, 2.99216896e-01, 6.27590120e-01, 9.52667117e-01, 1.26950192e+00, 341 | 1.59411609e+00, 1.88158262e+00, 2.22522688e+00, 2.48260713e+00], 342 | [2.14118417e-03, 3.02231610e-01, 6.31080925e-01, 9.50872004e-01, 1.27247977e+00, 343 | 1.58668697e+00, 1.88744187e+00, 2.23091602e+00, 2.47214985e+00], 344 | [2.04185606e-03, 3.02155763e-01, 6.25741839e-01, 9.47705865e-01, 1.27614772e+00, 345 | 1.58635414e+00, 1.90940773e+00, 2.22117901e+00, 2.51428246e+00], 346 | [2.44326401e-03, 3.10180873e-01, 6.38756812e-01, 9.42061603e-01, 1.27602589e+00, 347 | 1.57963061e+00, 1.91455340e+00, 2.22590876e+00, 2.46318984e+00], 348 | [2.09100801e-03, 3.21886212e-01, 6.28022254e-01, 9.50401008e-01, 1.27358270e+00, 349 | 1.59494138e+00, 1.91619575e+00, 2.22641444e+00, 2.50357032e+00], 350 | [2.26099207e-03, 3.01608980e-01, 6.25809371e-01, 9.56489742e-01, 1.26890695e+00, 351 | 1.56973970e+00, 1.90321040e+00, 2.20976448e+00, 2.48652816e+00], 352 | [2.08998402e-03, 3.05045515e-01, 6.38055444e-01, 9.51280653e-01, 1.27333987e+00, 353 | 1.57056820e+00, 1.92144990e+00, 2.23278165e+00, 2.49454784e+00], 354 | [2.05004821e-03, 3.10762554e-01, 6.33757710e-01, 9.59020019e-01, 1.25874281e+00, 355 | 1.59699655e+00, 1.90024710e+00, 2.21750259e+00, 2.47988105e+00], 356 | [2.14630389e-03, 3.16698641e-01, 6.24750614e-01, 9.55690980e-01, 1.27454925e+00, 357 | 1.59527731e+00, 1.91894329e+00, 2.22634172e+00, 2.50078821e+00], 358 | [2.00806395e-03, 3.14042360e-01, 6.37971461e-01, 9.40902412e-01, 1.28086519e+00, 359 | 1.59389281e+00, 1.91719842e+00, 2.23712349e+00, 2.49246931e+00], 360 | [1.99577608e-03, 3.02376956e-01, 6.46013975e-01, 9.51083004e-01, 1.28144896e+00, 361 | 1.60936344e+00, 1.91367579e+00, 2.23405361e+00, 2.50937223e+00], 362 | [3.06073599e-03, 3.16940308e-01, 6.41007602e-01, 9.51192558e-01, 1.27902007e+00, 363 | 1.58712018e+00, 1.92633653e+00, 2.21752119e+00, 2.50338197e+00], 364 | [2.09817593e-03, 3.12108040e-01, 6.56091154e-01, 9.85183239e-01, 1.28402126e+00,
365 | 1.61969662e+00, 1.90231442e+00, 2.23895645e+00, 2.51221514e+00], 366 | [2.06336007e-03, 3.09849083e-01, 6.33301973e-01, 9.46757615e-01, 1.27894533e+00, 367 | 1.57923329e+00, 1.91604125e+00, 2.21328068e+00, 2.50347829e+00], 368 | [2.06336007e-03, 3.05488884e-01, 6.38731241e-01, 9.60516095e-01, 1.25757539e+00, 369 | 1.60113358e+00, 1.90383101e+00, 2.24165583e+00, 2.48367000e+00], 370 | [2.08076811e-03, 3.10771763e-01, 6.35176957e-01, 9.41579223e-01, 1.27896261e+00, 371 | 1.59053111e+00, 1.90248144e+00, 2.23310137e+00, 2.49199414e+00], 372 | [2.05414393e-03, 3.12963068e-01, 6.46444023e-01, 9.58285809e-01, 1.27928627e+00, 373 | 1.59574628e+00, 1.89732456e+00, 2.22769356e+00, 2.50136900e+00]]) 374 | 375 | traceErrorMNIST = torch.tensor([[0.24018253, 0.02369082, 0.01655071, 0.01396034, 0.01144823, 0.01047849, 376 | 0.00952990, 0.00854061, 0.00797484], 377 | [0.23493132, 0.02441182, 0.01750353, 0.01347727, 0.01155000, 0.01005868, 378 | 0.00924403, 0.00861205, 0.00908603], 379 | [0.23430178, 0.02245663, 0.01697846, 0.01301396, 0.01177753, 0.01026242, 380 | 0.00972797, 0.00851057, 0.00796853], 381 | [0.23591031, 0.02403234, 0.01629571, 0.01434168, 0.01231583, 0.01095246, 382 | 0.00969691, 0.00893384, 0.00820101], 383 | [0.22689532, 0.02337767, 0.01722954, 0.01429517, 0.01214339, 0.01030727, 384 | 0.00986281, 0.00866913, 0.00863269], 385 | [0.23081490, 0.02363457, 0.01650987, 0.01375823, 0.01188187, 0.01073340, 386 | 0.00933890, 0.00959375, 0.00921414], 387 | [0.25124684, 0.02325045, 0.01749201, 0.01358495, 0.01248278, 0.01140859, 388 | 0.00950113, 0.00907979, 0.00792553], 389 | [0.23507045, 0.02369924, 0.01681624, 0.01345407, 0.01110588, 0.01080476, 390 | 0.00938544, 0.00899983, 0.00792988], 391 | [0.23164381, 0.02406703, 0.01658302, 0.01367255, 0.01114006, 0.00976226, 392 | 0.00966519, 0.00825947, 0.00789442], 393 | [0.23012693, 0.02313585, 0.01634406, 0.01420588, 0.01146220, 0.01042613, 394 | 0.00921410, 0.00794939, 0.00808467], 395 | [0.22919241, 0.02389409,
0.01679990, 0.01348217, 0.01199206, 0.01094741, 396 | 0.00984883, 0.00837392, 0.00826933], 397 | [0.23617810, 0.02385402, 0.01777303, 0.01367427, 0.01193456, 0.01086325, 398 | 0.00972210, 0.00953787, 0.00868386], 399 | [0.22457570, 0.02384206, 0.01601511, 0.01331227, 0.01195829, 0.00981491, 400 | 0.00935723, 0.00853830, 0.00810188], 401 | [0.22569421, 0.02354341, 0.01674058, 0.01408734, 0.01217772, 0.01048605, 402 | 0.00958908, 0.00875280, 0.00847875], 403 | [0.23752135, 0.02474259, 0.01603253, 0.01287877, 0.01180999, 0.01017651, 404 | 0.00992901, 0.00850735, 0.00869310], 405 | [0.23726499, 0.02293676, 0.01636312, 0.01349324, 0.01208540, 0.01093696, 406 | 0.00915058, 0.00888785, 0.00856469], 407 | [0.22124900, 0.02183123, 0.01629734, 0.01310080, 0.01180154, 0.01033192, 408 | 0.00984676, 0.00878596, 0.00819684], 409 | [0.23405622, 0.02429814, 0.01630274, 0.01293633, 0.01139560, 0.01045285, 410 | 0.00971966, 0.00889976, 0.00813898], 411 | [0.24091603, 0.02339127, 0.01599100, 0.01363607, 0.01187595, 0.01049466, 412 | 0.00934176, 0.00942123, 0.00768393], 413 | [0.23984380, 0.02154738, 0.01641268, 0.01327604, 0.01174884, 0.01027578, 414 | 0.00978789, 0.00868630, 0.00847579]]) 415 | 416 | exactTimingMNIST = torch.tensor([[0.00443920], [0.00433254], [0.00421174], [0.00446614], [0.00457475], [0.00422867], 417 | [0.00455546], [0.00507485], [0.00423210], [0.00486810], [0.00440400], [0.00489318], 418 | [0.00428192], [0.00420458], [0.00457610], [0.00430054], [0.00442006], [0.00423517], 419 | [0.00457062], [0.00422797]]) 420 | 421 | # each exact timing is 20-by-1; plotTraceCompare multiplies it by torch.ones(len(domain)) to broadcast across the domain 422 | lTimeExact = [exactTimingMiniboone, exactTimingBDS, exactTimingMNIST] 423 | 424 | plotTraceCompare(domainMiniboone,domainBSDS,domainMNIST, 425 | approxTimingMiniboone, approxTimingBDS, approxTimingMNIST, 426 | traceErrorMiniboone, traceErrorBDS, traceErrorMNIST, 427 | lTimeExact, bErrBar=True) 428 | 429 | 430 | 431 | 432 | -------------------------------------------------------------------------------- /src/plotter.py: 
-------------------------------------------------------------------------------- 1 | # plotter.py 2 | try: 3 | import matplotlib 4 | matplotlib.use('TkAgg') 5 | import matplotlib.pyplot as plt 6 | except: 7 | import matplotlib 8 | matplotlib.use('agg') # for linux server with no tkinter 9 | import matplotlib.pyplot as plt 10 | plt.rcParams['image.cmap'] = 'inferno' 11 | from src.OTFlowProblem import * 12 | import numpy as np 13 | import os 14 | import h5py 15 | import datasets 16 | from torch.nn.functional import pad 17 | from matplotlib import colors # for evaluateLarge 18 | 19 | 20 | 21 | def plot4(net, x, y, nt_val, sPath, sTitle="", doPaths=False): 22 | """ 23 | x - samples from rho_0 24 | y - samples from rho_1 25 | nt_val - number of time steps 26 | """ 27 | 28 | d = net.d 29 | nSamples = x.shape[0] 30 | 31 | 32 | fx = integrate(x[:, 0:d], net, [0.0, 1.0], nt_val, stepper="rk4", alph=net.alph) 33 | finvfx = integrate(fx[:, 0:d], net, [1.0, 0.0], nt_val, stepper="rk4", alph=net.alph) 34 | genModel = integrate(y[:, 0:d], net, [1.0, 0.0], nt_val, stepper="rk4", alph=net.alph) 35 | 36 | invErr = torch.norm(x[:,0:d] - finvfx[:,0:d]) / x.shape[0] 37 | 38 | nBins = 33 39 | LOWX = -4 40 | HIGHX = 4 41 | LOWY = -4 42 | HIGHY = 4 43 | 44 | if d > 50: # assuming bsds 45 | # plot dimensions d1 vs d2 46 | d1=0 47 | d2=1 48 | LOWX = -0.15 # note: there's a hard coded 4 and -4 in axs 2 49 | HIGHX = 0.15 50 | LOWY = -0.15 51 | HIGHY = 0.15 52 | if d > 700: # assuming MNIST 53 | d1=0 54 | d2=1 55 | LOWX = -10 # note: there's a hard coded 4 and -4 in axs 2 56 | HIGHX = 10 57 | LOWY = -10 58 | HIGHY = 10 59 | elif d==8: # assuming gas 60 | LOWX = -2 # note: there's a hard coded 4 and -4 in axs 2 61 | HIGHX = 2 62 | LOWY = -2 63 | HIGHY = 2 64 | d1=2 65 | d2=3 66 | nBins = 100 67 | else: 68 | d1=0 69 | d2=1 70 | 71 | fig, axs = plt.subplots(2, 2) 72 | fig.set_size_inches(12, 10) 73 | fig.suptitle(sTitle + ', inv err {:.2e}'.format(invErr)) 74 | 75 | # hist, xbins, ybins, im = 
axs[0, 0].hist2d(x.numpy()[:,0],x.numpy()[:,1], range=[[LOW, HIGH], [LOW, HIGH]], bins = nBins) 76 | im1 , _, _, map1 = axs[0, 0].hist2d(x.detach().cpu().numpy()[:, d1], x.detach().cpu().numpy()[:, d2], range=[[LOWX, HIGHX], [LOWY, HIGHY]], bins=nBins) 77 | axs[0, 0].set_title('x from rho_0') 78 | im2 , _, _, map2 = axs[0, 1].hist2d(fx.detach().cpu().numpy()[:, d1], fx.detach().cpu().numpy()[:, d2], range=[[-4, 4], [-4, 4]], bins = nBins) 79 | axs[0, 1].set_title('f(x)') 80 | im3 , _, _, map3 = axs[1, 0].hist2d(finvfx.detach().cpu().numpy()[: ,d1] ,finvfx.detach().cpu().numpy()[: ,d2], range=[[LOWX, HIGHX], [LOWY, HIGHY]], bins = nBins) 81 | axs[1, 0].set_title('finv( f(x) )') 82 | im4 , _, _, map4 = axs[1, 1].hist2d(genModel.detach().cpu().numpy()[:, d1], genModel.detach().cpu().numpy()[:, d2], range=[[LOWX, HIGHX], [LOWY, HIGHY]], bins = nBins) 83 | axs[1, 1].set_title('finv( y from rho1 )') 84 | 85 | fig.colorbar(map1, cax=fig.add_axes([0.47, 0.53, 0.02, 0.35]) ) 86 | fig.colorbar(map2, cax=fig.add_axes([0.89, 0.53, 0.02, 0.35]) ) 87 | fig.colorbar(map3, cax=fig.add_axes([0.47, 0.11, 0.02, 0.35]) ) 88 | fig.colorbar(map4, cax=fig.add_axes([0.89, 0.11, 0.02, 0.35]) ) 89 | 90 | 91 | # plot paths 92 | if doPaths: 93 | forwPath = integrate(x[:, 0:d], net, [0.0, 1.0], nt_val, stepper="rk4", alph=net.alph, intermediates=True) 94 | backPath = integrate(fx[:, 0:d], net, [1.0, 0.0], nt_val, stepper="rk4", alph=net.alph, intermediates=True) 95 | 96 | # plot the forward and inverse trajectories of several points; white is forward, red is inverse 97 | nPts = 10 98 | pts = np.unique(np.random.randint(nSamples, size=nPts)) 99 | for pt in pts: 100 | axs[0, 0].plot(forwPath[pt, 0, :].detach().cpu().numpy(), forwPath[pt, 1, :].detach().cpu().numpy(), color='white', linewidth=4) 101 | axs[0, 0].plot(backPath[pt, 0, :].detach().cpu().numpy(), backPath[pt, 1, :].detach().cpu().numpy(), color='red', linewidth=2) 102 | 103 | for i in range(axs.shape[0]): 104 | for j in 
range(axs.shape[1]): 105 | # axs[i, j].get_yaxis().set_visible(False) 106 | # axs[i, j].get_xaxis().set_visible(False) 107 | axs[i ,j].set_aspect('equal') 108 | 109 | # sPath = os.path.join(args.save, 'figs', sStartTime + '_{:04d}.png'.format(itr)) 110 | if not os.path.exists(os.path.dirname(sPath)): 111 | os.makedirs(os.path.dirname(sPath)) 112 | plt.savefig(sPath, dpi=300) 113 | plt.close() 114 | 115 | 116 | 117 | def plotAutoEnc(x, xRecreate, sPath): 118 | 119 | # assume square image 120 | s = int(math.sqrt(x.shape[1])) 121 | 122 | 123 | nex = 8 124 | 125 | fig, axs = plt.subplots(4, nex//2) 126 | fig.set_size_inches(9, 9) 127 | fig.suptitle("first 2 rows originals. Rows 3 and 4 are generations.") 128 | 129 | for i in range(nex//2): 130 | axs[0, i].imshow(x[i,:].reshape(s,s).detach().cpu().numpy()) 131 | axs[1, i].imshow(x[ nex//2 + i , : ].reshape(s,s).detach().cpu().numpy()) 132 | axs[2, i].imshow(xRecreate[i,:].reshape(s,s).detach().cpu().numpy()) 133 | axs[3, i].imshow(xRecreate[ nex//2 + i , : ].reshape(s, s).detach().cpu().numpy()) 134 | 135 | 136 | for i in range(axs.shape[0]): 137 | for j in range(axs.shape[1]): 138 | axs[i, j].get_yaxis().set_visible(False) 139 | axs[i, j].get_xaxis().set_visible(False) 140 | axs[i ,j].set_aspect('equal') 141 | 142 | plt.subplots_adjust(wspace=0.0, hspace=0.0) 143 | 144 | if not os.path.exists(os.path.dirname(sPath)): 145 | os.makedirs(os.path.dirname(sPath)) 146 | plt.savefig(sPath, dpi=300) 147 | plt.close() 148 | 149 | 150 | def plotAutoEnc3D(x, xRecreate, sPath): 151 | 152 | nex = 8 153 | 154 | fig, axs = plt.subplots(4, nex//2) 155 | fig.set_size_inches(9, 9) 156 | fig.suptitle("first 2 rows originals. 
Rows 3 and 4 are generations.") 157 | 158 | for i in range(nex//2): 159 | axs[0, i].imshow(x[i,:].permute(1,2,0).detach().cpu().numpy()) 160 | axs[1, i].imshow(x[ nex//2 + i , : ].permute(1,2,0).detach().cpu().numpy()) 161 | axs[2, i].imshow(xRecreate[i,:].permute(1,2,0).detach().cpu().numpy()) 162 | axs[3, i].imshow(xRecreate[ nex//2 + i , : ].permute(1,2,0).detach().cpu().numpy()) 163 | 164 | 165 | for i in range(axs.shape[0]): 166 | for j in range(axs.shape[1]): 167 | axs[i, j].get_yaxis().set_visible(False) 168 | axs[i, j].get_xaxis().set_visible(False) 169 | axs[i ,j].set_aspect('equal') 170 | 171 | plt.subplots_adjust(wspace=0.0, hspace=0.0) 172 | 173 | if not os.path.exists(os.path.dirname(sPath)): 174 | os.makedirs(os.path.dirname(sPath)) 175 | plt.savefig(sPath, dpi=300) 176 | plt.close() 177 | 178 | 179 | 180 | def plotImageGen(x, xRecreate, sPath): 181 | 182 | # assume square image 183 | s = int(math.sqrt(x.shape[1])) 184 | 185 | nex = 80 186 | nCols = nex//5 187 | 188 | 189 | fig, axs = plt.subplots(7, nCols) 190 | fig.set_size_inches(16, 7) 191 | fig.suptitle("first 2 rows originals. 
Rows 3 and 4 are generations.") 192 | 193 | for i in range(nCols): 194 | axs[0, i].imshow(x[i,:].reshape(s,s).detach().cpu().numpy()) 195 | # axs[1, i].imshow(x[ nex//3 + i , : ].reshape(s,s).detach().cpu().numpy()) 196 | # axs[2, i].imshow(x[ 2*nex//3 + i , : ].reshape(s,s).detach().cpu().numpy()) 197 | axs[2, i].imshow(xRecreate[i,:].reshape(s,s).detach().cpu().numpy()) 198 | axs[3, i].imshow(xRecreate[nCols + i,:].reshape(s,s).detach().cpu().numpy()) 199 | 200 | axs[4, i].imshow(xRecreate[2*nCols + i,:].reshape(s,s).detach().cpu().numpy()) 201 | axs[5, i].imshow(xRecreate[3*nCols + i , : ].reshape(s, s).detach().cpu().numpy()) 202 | axs[6, i].imshow(xRecreate[4*nCols + i , : ].reshape(s, s).detach().cpu().numpy()) 203 | 204 | for i in range(axs.shape[0]): 205 | for j in range(axs.shape[1]): 206 | axs[i, j].get_yaxis().set_visible(False) 207 | axs[i, j].get_xaxis().set_visible(False) 208 | axs[i ,j].set_aspect('equal') 209 | 210 | plt.subplots_adjust(wspace=0.0, hspace=0.0) 211 | 212 | if not os.path.exists(os.path.dirname(sPath)): 213 | os.makedirs(os.path.dirname(sPath)) 214 | plt.savefig(sPath, dpi=300) 215 | plt.close() 216 | 217 | 218 | def plot4mnist(x, sPath, sTitle=""): 219 | """ 220 | x - tensor (>4, 28,28) 221 | """ 222 | fig, axs = plt.subplots(2, 2) 223 | fig.set_size_inches(12, 10) 224 | fig.suptitle(sTitle) 225 | 226 | im1 = axs[0, 0].imshow(x[0,:,:].detach().cpu().numpy()) 227 | im2 = axs[0, 1].imshow(x[1,:,:].detach().cpu().numpy()) 228 | im3 = axs[1, 0].imshow(x[2,:,:].detach().cpu().numpy()) 229 | im4 = axs[1, 1].imshow(x[3,:,:].detach().cpu().numpy()) 230 | 231 | fig.colorbar(im1, cax=fig.add_axes([0.47, 0.53, 0.02, 0.35]) ) 232 | fig.colorbar(im2, cax=fig.add_axes([0.89, 0.53, 0.02, 0.35]) ) 233 | fig.colorbar(im3, cax=fig.add_axes([0.47, 0.11, 0.02, 0.35]) ) 234 | fig.colorbar(im4, cax=fig.add_axes([0.89, 0.11, 0.02, 0.35]) ) 235 | 236 | for i in range(axs.shape[0]): 237 | for j in range(axs.shape[1]): 238 | axs[i, 
j].get_yaxis().set_visible(False) 239 | axs[i, j].get_xaxis().set_visible(False) 240 | axs[i ,j].set_aspect('equal') 241 | 242 | # sPath = os.path.join(args.save, 'figs', sStartTime + '_{:04d}.png'.format(itr)) 243 | if not os.path.exists(os.path.dirname(sPath)): 244 | os.makedirs(os.path.dirname(sPath)) 245 | plt.savefig(sPath, dpi=300) 246 | plt.close() 247 | 248 | 249 | 250 | 251 | 252 | -------------------------------------------------------------------------------- /test/gradTestOTFlowProblem.py: -------------------------------------------------------------------------------- 1 | # testOTFlowProblem.py 2 | # 3 | 4 | # gradient check of OTFlowProblem 5 | import matplotlib 6 | matplotlib.use('TkAgg') 7 | import matplotlib.pyplot as plt 8 | 9 | from src.Phi import * 10 | from src.OTFlowProblem import * 11 | import torch.nn.utils 12 | 13 | 14 | doPlots = True 15 | 16 | d = 5 17 | m = 16 18 | 19 | net = Phi(nTh=2, m=m, d=d) 20 | net.double() 21 | 22 | # vecParams = nn.utils.convert_parameters.parameters_to_vector(net.parameters()) 23 | x = torch.randn(1,d+1).type(torch.double) 24 | # net(x) 25 | 26 | v = torch.randn(x.shape).type(torch.double) 27 | # ------------------------------------------------ 28 | # f is the full OTFlowProblem 29 | # OTFlowProblem(x, Phi, tspan , nt, stepper="rk1", alph =[1.0,1.0,1.0,1.0,1.0] ) 30 | 31 | input = torch.randn(1,d+1).type(torch.double) 32 | 33 | fx = torch.sum(OTFlowProblem(input[:,0:d], net, [0.0, 1.0] , nt=2, stepper="rk4", alph =[1.0,1.0,1.0,1.0,1.0] )[0]) 34 | vecX = copy.copy(nn.utils.convert_parameters.parameters_to_vector(net.parameters())) 35 | v = torch.randn(vecX.shape).type(torch.double) 36 | 37 | netV = Phi(nTh=2, m=m, d=d).double() # make another copy for shape info 38 | 39 | # jacobian of fx wrt x 40 | # g = torch.autograd.grad(fx,net.w.weight, retain_graph=True, create_graph=True, allow_unused=True) 41 | g = torch.autograd.grad(fx,net.parameters(), retain_graph=True, create_graph=True, allow_unused=True) 42 | 43 | 
nn.utils.convert_parameters.vector_to_parameters(v, netV.parameters()) # structure v into the tensors 44 | gv = 0.0 45 | for gi, vi in zip(g, netV.parameters()): 46 | if gi is not None: # if gi is None, then that means the gradient there is 0 47 | gv += torch.matmul(gi.view(1, -1), vi.view(-1, 1)) 48 | 49 | 50 | 51 | niter = 20 52 | h0 = 0.1 53 | E0 = [] 54 | E1 = [] 55 | hlist = [] 56 | 57 | 58 | for i in range(1,niter): 59 | h = h0**i 60 | hlist.append(h) 61 | 62 | newVec = vecX + h*v 63 | 64 | nn.utils.convert_parameters.vector_to_parameters(newVec, net.parameters()) # set parameters 65 | fxhv = torch.sum(OTFlowProblem(input[:,0:d], net, [0.0, 1.0] , nt=2, stepper="rk4", alph =[1.0,1.0,1.0,1.0,1.0] )[0]) 66 | 67 | # print(newVec[0:3]) 68 | # print(vecX[0:3]) 69 | # print("{:.6f} {:.6f} {:.6e}".format(fxhv.item(), fx.item() , torch.norm(net.w.weight - torch.ones(5).type(torch.double)).item())) 70 | 71 | fdiff = fxhv - fx 72 | 73 | res0 = torch.norm(fdiff) 74 | E0.append( res0 ) 75 | 76 | res1 = torch.norm(fdiff - h * gv) 77 | E1.append( res1 ) 78 | 79 | print(" ") 80 | for i in range(niter-1): 81 | print("{:e} {:.6e} {:.6e}".format( hlist[i] , E0[i].item() , E1[i].item() )) 82 | 83 | 84 | if doPlots: 85 | plt.plot(hlist,E0, label='E0') 86 | plt.plot(hlist,E1, label='E1') 87 | plt.yscale('log') 88 | plt.xscale('log') 89 | plt.legend() 90 | plt.show() 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | # CHECK JUST ONE PARAMETER TENSOR 106 | 107 | # d = 5 108 | # m = 16 109 | # 110 | # net = Phi(nTh=2, m=m, d=d) 111 | # net.double() 112 | 113 | 114 | # ------------------------------------------------ 115 | # f is the full OTFlowProblem 116 | # OTFlowProblem(x, Phi, tspan , nt, stepper="rk1", alph =[1.0,1.0,1.0,1.0,1.0] ) 117 | 118 | input = torch.randn(1,d+1).type(torch.double) 119 | 120 | fx = torch.sum(OTFlowProblem(input[:,0:d], net, [0.0, 1.0] , nt=2, stepper="rk4", alph =[1.0,1.0,1.0,1.0,1.0] )[0]) 121 | x = 
net.N.layers[0].weight.data 122 | v = torch.randn(net.N.layers[0].weight.data.shape).type(torch.double) 123 | 124 | # netV = Phi(nTh=2, m=m, d=d).double() # make another copy for shape info 125 | 126 | # jacobian of fx wrt x 127 | # g = torch.autograd.grad(fx,net.w.weight, retain_graph=True, create_graph=True, allow_unused=True) 128 | g = torch.autograd.grad(fx,net.N.layers[0].weight, retain_graph=True, create_graph=True, allow_unused=True)[0] 129 | 130 | # nn.utils.convert_parameters.vector_to_parameters(v, netV.parameters()) # structure v into the tensors 131 | # gv = 0.0 132 | # for gi, vi in zip(g, netV.parameters()): 133 | # if gi is not None: # if gi is None, then that means the gradient there is 0 134 | # gv += torch.matmul(gi.view(1, -1), vi.view(-1, 1)) 135 | 136 | gv = torch.matmul(g.view(1, -1), v.view(-1, 1)) 137 | 138 | 139 | niter = 20 140 | h0 = 0.1 141 | E0 = [] 142 | E1 = [] 143 | hlist = [] 144 | 145 | 146 | for i in range(1,niter): 147 | h = h0**i 148 | hlist.append(h) 149 | 150 | net.N.layers[0].weight.data = x + h*v 151 | 152 | fxhv = torch.sum(OTFlowProblem(input[:,0:d], net, [0.0, 1.0] , nt=2, stepper="rk4", alph =[1.0,1.0,1.0,1.0,1.0] )[0]) 153 | 154 | # print(newVec[0:3]) 155 | # print(vecX[0:3]) 156 | # print("{:.6f} {:.6f} {:.6e}".format(fxhv.item(), fx.item() , torch.norm(net.w.weight - torch.ones(5).type(torch.double)).item())) 157 | 158 | fdiff = fxhv - fx 159 | 160 | res0 = torch.norm(fdiff) 161 | E0.append( res0 ) 162 | 163 | res1 = torch.norm(fdiff - h * gv) 164 | E1.append( res1 ) 165 | 166 | print(" ") 167 | for i in range(niter-1): 168 | print("{:e} {:.6e} {:.6e}".format( hlist[i] , E0[i].item() , E1[i].item() )) 169 | 170 | 171 | if doPlots: 172 | plt.plot(hlist,E0, label='E0') 173 | plt.plot(hlist,E1, label='E1') 174 | plt.yscale('log') 175 | plt.xscale('log') 176 | plt.legend() 177 | plt.show() 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 
-------------------------------------------------------------------------------- /test/gradTestTrHess.py: -------------------------------------------------------------------------------- 1 | # gradTestTrHess.py 2 | 3 | import matplotlib 4 | matplotlib.use('TkAgg') 5 | import matplotlib.pyplot as plt 6 | from src.Phi import Phi 7 | import torch.nn.utils 8 | import copy 9 | import torch.nn as nn 10 | 11 | 12 | 13 | doPlots = True 14 | 15 | 16 | d = 2 17 | m = 5 18 | nTh = 3 19 | 20 | net = Phi(nTh=nTh, m=m, d=d) 21 | net.double() 22 | 23 | # vecParams = nn.utils.convert_parameters.parameters_to_vector(net.parameters()) 24 | x = torch.randn(1,3).type(torch.double) 25 | x.requires_grad = True 26 | y = net(x) 27 | v = torch.randn(x.shape).type(torch.double) 28 | 29 | # ------------------------------------------------ 30 | # f 31 | # nablaPhi = net.trHess(x)[0] 32 | 33 | g = net.trHess(x)[0] 34 | 35 | niter = 20 36 | h0 = 0.1 37 | E0 = [] 38 | E1 = [] 39 | hlist = [] 40 | 41 | 42 | for i in range(niter): 43 | h = h0**i 44 | hlist.append(h) 45 | E0.append( torch.norm(net( x + h * v ) - net(x)) ) 46 | E1.append( torch.norm(net( x + h * v ) - net(x) - h * torch.matmul(g , v.t())) ) 47 | 48 | for i in range(niter): 49 | print("{:f} {:.6e} {:.6e}".format( hlist[i] , E0[i].item() , E1[i].item() )) 50 | 51 | if doPlots: 52 | plt.plot(hlist,E0, label='E0') 53 | plt.plot(hlist,E1, label='E1') 54 | plt.yscale('log') 55 | plt.xscale('log') 56 | plt.legend() 57 | plt.show() 58 | 59 | print("\n") 60 | # ------------------------------------------------ 61 | # f is the gradient wrt x computation 62 | # trH = net.trHess(x)[1] 63 | 64 | input = torch.randn(1,3).type(torch.double) 65 | 66 | fx = net.trHess(input)[1] 67 | vecX = copy.copy(nn.utils.convert_parameters.parameters_to_vector(net.parameters())) 68 | v = torch.randn(vecX.shape).type(torch.double) 69 | 70 | netV = Phi(nTh=nTh, m=m, d=d).double() # make another copy for shape info 71 | 72 | # jacobian of fx wrt x 73 | # g = 
torch.autograd.grad(fx,net.w.weight, retain_graph=True, create_graph=True, allow_unused=True) 74 | g = torch.autograd.grad(fx,net.parameters(), retain_graph=True, create_graph=True, allow_unused=True) 75 | 76 | nn.utils.convert_parameters.vector_to_parameters(v, netV.parameters()) # structure v into the tensors 77 | gv = 0.0 78 | for gi, vi in zip(g, netV.parameters()): 79 | if gi is not None: # if gi is None, then that means the gradient there is 0 80 | gv += torch.matmul(gi.view(1, -1), vi.view(-1, 1)) 81 | 82 | niter = 20 83 | h0 = 0.1 84 | E0 = [] 85 | E1 = [] 86 | hlist = [] 87 | 88 | for i in range(1,niter): 89 | h = h0**i 90 | hlist.append(h) 91 | 92 | newVec = vecX + h*v 93 | 94 | nn.utils.convert_parameters.vector_to_parameters(newVec, net.parameters()) # set parameters 95 | fxhv = net.trHess(input)[1] 96 | 97 | # print(newVec[0:3]) 98 | # print(vecX[0:3]) 99 | # print("{:.6f} {:.6f} {:.6e}".format(fxhv.item(), fx.item() , torch.norm(net.w.weight - torch.ones(5).type(torch.double)).item())) 100 | 101 | fdiff = fxhv - fx 102 | 103 | res0 = torch.norm(fdiff) 104 | E0.append( res0 ) 105 | 106 | res1 = torch.norm(fdiff - h * gv) 107 | E1.append( res1 ) 108 | 109 | print(" ") 110 | for i in range(niter-1): 111 | print("{:e} {:.6e} {:.6e}".format( hlist[i] , E0[i].item() , E1[i].item() )) 112 | 113 | 114 | if doPlots: 115 | plt.plot(hlist,E0, label='E0') 116 | plt.plot(hlist,E1, label='E1') 117 | plt.yscale('log') 118 | plt.xscale('log') 119 | plt.legend() 120 | plt.show() 121 | 122 | 123 | print("\n") 124 | # ------------------------------------------------ 125 | # f is the trace of the Hessian computation 126 | # trH = net.trHess(x)[0] 127 | 128 | input = torch.randn(1,3).type(torch.double) 129 | 130 | fx = torch.sum(net.trHess(input)[0]) 131 | vecX = copy.copy(nn.utils.convert_parameters.parameters_to_vector(net.parameters())) 132 | v = torch.randn(vecX.shape).type(torch.double) 133 | 134 | netV = Phi(nTh=nTh, m=m, d=d).double() # make another copy for shape 
info 135 | 136 | # jacobian of fx wrt x 137 | # g = torch.autograd.grad(fx,net.w.weight, retain_graph=True, create_graph=True, allow_unused=True) 138 | g = torch.autograd.grad(fx,net.parameters(), retain_graph=True, create_graph=True, allow_unused=True) 139 | 140 | nn.utils.convert_parameters.vector_to_parameters(v, netV.parameters()) # structure v into the tensors 141 | gv = 0.0 142 | for gi, vi in zip(g, netV.parameters()): 143 | if gi is not None: # if gi is None, then that means the gradient there is 0 144 | gv += torch.matmul(gi.view(1, -1), vi.view(-1, 1)) 145 | 146 | 147 | 148 | niter = 20 149 | h0 = 0.1 150 | E0 = [] 151 | E1 = [] 152 | hlist = [] 153 | 154 | 155 | for i in range(1,niter): 156 | h = h0**i 157 | hlist.append(h) 158 | 159 | newVec = vecX + h*v 160 | 161 | nn.utils.convert_parameters.vector_to_parameters(newVec, net.parameters()) # set parameters 162 | fxhv = torch.sum(net.trHess(input)[0]) 163 | 164 | # print(newVec[0:3]) 165 | # print(vecX[0:3]) 166 | # print("{:.6f} {:.6f} {:.6e}".format(fxhv.item(), fx.item() , torch.norm(net.w.weight - torch.ones(5).type(torch.double)).item())) 167 | 168 | fdiff = fxhv - fx 169 | 170 | res0 = torch.norm(fdiff) 171 | E0.append( res0 ) 172 | 173 | res1 = torch.norm(fdiff - h * gv) 174 | E1.append( res1 ) 175 | 176 | print(" ") 177 | for i in range(niter-1): 178 | print("{:e} {:.6e} {:.6e}".format( hlist[i] , E0[i].item() , E1[i].item() )) 179 | 180 | 181 | if doPlots: 182 | plt.plot(hlist,E0, label='E0') 183 | plt.plot(hlist,E1, label='E1') 184 | plt.yscale('log') 185 | plt.xscale('log') 186 | plt.legend() 187 | plt.show() 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | -------------------------------------------------------------------------------- /test/testPhiGradx.py: -------------------------------------------------------------------------------- 1 | # testPhiGradx.py 2 | # test the grad wrt x returned by trHess when nTh > 2 3 | 4 | import matplotlib 5 | matplotlib.use('TkAgg') 
6 | import matplotlib.pyplot as plt 7 | from src.Phi import * 8 | import torch.nn.utils 9 | 10 | doPlots = True 11 | 12 | d = 2 13 | m = 5 14 | nTh = 4 15 | 16 | net = Phi(nTh=nTh, m=m, d=d) 17 | net.double() 18 | 19 | # vecParams = nn.utils.convert_parameters.parameters_to_vector(net.parameters()) 20 | x = torch.randn(1,3).type(torch.double) 21 | # AD grad 22 | x.requires_grad = True 23 | y = net(x) 24 | 25 | v = torch.randn(x.shape).type(torch.double) 26 | 27 | # ------------------------------------------------ 28 | # f 29 | # nablaPhi = net.trHess(x)[0] 30 | 31 | g = net.trHess(x)[0] 32 | 33 | 34 | niter = 20 35 | h0 = 0.5 36 | E0 = [] 37 | E1 = [] 38 | hlist = [] 39 | 40 | 41 | for i in range(niter): 42 | h = h0**i 43 | hlist.append(h) 44 | E0.append( torch.norm(net( x + h * v ) - net(x)) ) 45 | E1.append( torch.norm(net( x + h * v ) - net(x) - h * torch.matmul(g , v.t())) ) 46 | 47 | for i in range(niter): 48 | print("{:f} {:.6e} {:.6e}".format( hlist[i] , E0[i].item() , E1[i].item() )) 49 | 50 | if doPlots: 51 | plt.plot(hlist,E0, label='E0') 52 | plt.plot(hlist,E1, label='E1') 53 | plt.yscale('log') 54 | plt.xscale('log') 55 | plt.legend() 56 | plt.show() 57 | 58 | 59 | 60 | print("\n") -------------------------------------------------------------------------------- /test/testPhiOpt.py: -------------------------------------------------------------------------------- 1 | # testPhiOpt.py 2 | 3 | from src.Phi import Phi 4 | from src.PhiHC import PhiHC 5 | from src.OTFlowProblem import * 6 | import torch 7 | import time 8 | 9 | # we know PhiHardCoded is accurate, nTh is hardcoded in there 10 | # we want to generalize it so nTh > 2 11 | 12 | if __name__ == '__main__': 13 | 14 | d = 8 15 | m = 16 16 | nTh = 2 17 | alph = [1.0,1.0,3.0,5.0,1.0] 18 | 19 | torch.manual_seed(0) 20 | net = Phi(nTh, m, d, alph=alph) 21 | net = net.to(torch.double) 22 | 23 | torch.manual_seed(0) 24 | netLoop = PhiHC(nTh, m,d,alph=alph) 25 | netLoop = netLoop.to(torch.double) 26 | 27 | nex 
= 10000 28 | x = torch.randn(nex,d+1).to(torch.double) 29 | 30 | end = time.time() 31 | y = net(x) 32 | print("time: ", time.time()-end) 33 | 34 | end = time.time() 35 | yLoop = netLoop(x) 36 | print("time: ", time.time()-end) 37 | 38 | 39 | print("Phi err: ", torch.norm(y-yLoop).item()) 40 | 41 | end = time.time() 42 | y1,y2 = net.trHess(x) 43 | print("time: ", time.time()-end) 44 | end = time.time() 45 | yLoop1,yLoop2 = netLoop.trHess(x) 46 | print("time: ", time.time()-end) 47 | print("grad err: ", torch.norm(y1 - yLoop1).item()) 48 | print("traceHess err: ", torch.norm(y2 - yLoop2).item()) 49 | 50 | 51 | -------------------------------------------------------------------------------- /trainLargeOTflow.py: -------------------------------------------------------------------------------- 1 | # trainLargeOTflow.py 2 | # train OT-Flow for the large density estimation data sets 3 | import argparse 4 | import os 5 | import time 6 | import datetime 7 | import torch.optim as optim 8 | import numpy as np 9 | import math 10 | import lib.toy_data as toy_data 11 | import lib.utils as utils 12 | from lib.utils import count_parameters 13 | 14 | from src.plotter import plot4 15 | from src.OTFlowProblem import * 16 | from src.Phi import * 17 | import config 18 | import datasets 19 | 20 | cf = config.getconfig() 21 | 22 | if cf.gpu: 23 | def_viz_freq = 200 24 | def_batch = 2000 25 | def_niter = 8000 26 | def_m = 256 27 | def_val_freq = 0 28 | else: # if no gpu on platform, assume debugging on a local cpu 29 | def_viz_freq = 20 30 | def_val_freq = 20 31 | def_batch = 200 32 | def_niter = 2000 33 | def_m = 16 34 | 35 | parser = argparse.ArgumentParser('OT-Flow') 36 | parser.add_argument( 37 | '--data', choices=['power', 'gas', 'hepmass', 'miniboone', 'bsds300','mnist'], type=str, default='miniboone' 38 | ) 39 | 40 | parser.add_argument("--nt" , type=int, default=6, help="number of time steps") 41 | parser.add_argument("--nt_val", type=int, default=10, help="number of time steps 
for validation") 42 | parser.add_argument('--alph' , type=str, default='1.0,100.0,15.0') 43 | parser.add_argument('--m' , type=int, default=def_m) 44 | parser.add_argument('--nTh' , type=int, default=2) 45 | 46 | parser.add_argument('--lr' , type=float, default=0.01) 47 | parser.add_argument("--drop_freq", type=int , default=0, help="how often to decrease learning rate; 0 lets the mdoel choose") 48 | parser.add_argument("--lr_drop" , type=float, default=10.0, help="how much to decrease learning rate (divide by)") 49 | parser.add_argument('--weight_decay', type=float, default=0.0) 50 | 51 | parser.add_argument('--prec' , type=str, default='single', choices=['single','double'], help="single or double precision") 52 | parser.add_argument('--niters' , type=int, default=def_niter) 53 | parser.add_argument('--batch_size', type=int, default=def_batch) 54 | parser.add_argument('--test_batch_size', type=int, default=def_batch) 55 | 56 | parser.add_argument('--resume', type=str, default=None) 57 | parser.add_argument('--evaluate', action='store_true') 58 | parser.add_argument('--early_stopping', type=int, default=20) 59 | 60 | parser.add_argument('--save', type=str, default='experiments/cnf/large') 61 | parser.add_argument('--viz_freq', type=int, default=def_viz_freq) 62 | parser.add_argument('--val_freq', type=int, default=def_val_freq) # validation frequency needs to be less than viz_freq or equal to viz_freq 63 | parser.add_argument('--log_freq', type=int, default=10) 64 | parser.add_argument('--gpu', type=int, default=0) 65 | args = parser.parse_args() 66 | 67 | args.alph = [float(item) for item in args.alph.split(',')] 68 | 69 | # add timestamp to save path 70 | start_time = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S") 71 | 72 | # logger 73 | utils.makedirs(args.save) 74 | logger = utils.get_logger(logpath=os.path.join(args.save, 'logs'), filepath=os.path.abspath(__file__)) 75 | logger.info("start time: " + start_time) 76 | logger.info(args) 77 | 78 | 
test_batch_size = args.test_batch_size if args.test_batch_size else args.batch_size 79 | 80 | device = torch.device("cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu") 81 | 82 | if args.prec =='double': 83 | prec = torch.float64 84 | else: 85 | prec = torch.float32 86 | 87 | 88 | 89 | def batch_iter(X, batch_size=args.batch_size, shuffle=False): 90 | """ 91 | X: feature tensor (shape: num_instances x num_features) 92 | """ 93 | if shuffle: 94 | idxs = torch.randperm(X.shape[0]) 95 | else: 96 | idxs = torch.arange(X.shape[0]) 97 | if X.is_cuda: 98 | idxs = idxs.cuda() 99 | for batch_idxs in idxs.split(batch_size): 100 | yield X[batch_idxs] 101 | 102 | 103 | # decrease the learning rate based on validation 104 | ndecs = 0 105 | n_vals_wo_improve=0 106 | def update_lr(optimizer, n_vals_without_improvement): 107 | global ndecs 108 | if ndecs == 0 and n_vals_without_improvement > args.early_stopping: 109 | for param_group in optimizer.param_groups: 110 | param_group["lr"] = args.lr / args.lr_drop 111 | ndecs = 1 112 | elif ndecs == 1 and n_vals_without_improvement > args.early_stopping: 113 | for param_group in optimizer.param_groups: 114 | param_group["lr"] = args.lr / args.lr_drop**2 115 | ndecs = 2 116 | else: 117 | ndecs += 1 118 | for param_group in optimizer.param_groups: 119 | param_group["lr"] = args.lr / args.lr_drop**ndecs 120 | 121 | 122 | def load_data(name): 123 | 124 | if name == 'bsds300': 125 | return datasets.BSDS300() 126 | 127 | elif name == 'power': 128 | return datasets.POWER() 129 | 130 | elif name == 'gas': 131 | return datasets.GAS() 132 | 133 | elif name == 'hepmass': 134 | return datasets.HEPMASS() 135 | 136 | elif name == 'miniboone': 137 | return datasets.MINIBOONE() 138 | 139 | else: 140 | raise ValueError('Unknown dataset') 141 | 142 | 143 | def compute_loss(net, x, nt): 144 | Jc , cs = OTFlowProblem(x, net, [0,1], nt=nt, stepper="rk4", alph=net.alph) 145 | return Jc, cs 146 | 147 | 148 | 149 | if __name__ == '__main__': 150 | 
151 | cvt = lambda x: x.type(prec).to(device, non_blocking=True) 152 | 153 | data = load_data(args.data) 154 | data.trn.x = torch.from_numpy(data.trn.x) 155 | print(data.trn.x.shape) 156 | data.val.x = torch.from_numpy(data.val.x) 157 | 158 | # hyperparameters of model 159 | d = data.trn.x.shape[1] 160 | nt = args.nt 161 | nt_val = args.nt_val 162 | nTh = args.nTh 163 | m = args.m 164 | 165 | # set up neural network to model potential function Phi 166 | net = Phi(nTh=nTh, m=m, d=d, alph=args.alph) 167 | net = net.to(prec).to(device) 168 | 169 | 170 | # resume training on a model that's already had some training 171 | if args.resume is not None: 172 | # reload model 173 | checkpt = torch.load(args.resume, map_location=lambda storage, loc: storage) 174 | m = checkpt['args'].m 175 | alph = args.alph # overwrite saved alpha 176 | nTh = checkpt['args'].nTh 177 | args.hutch = checkpt['args'].hutch 178 | net = Phi(nTh=nTh, m=m, d=d, alph=alph) # the phi aka the value function 179 | prec = checkpt['state_dict']['A'].dtype 180 | net = net.to(prec) 181 | net.load_state_dict(checkpt["state_dict"]) 182 | net = net.to(device) 183 | 184 | if args.val_freq == 0: 185 | # if val_freq set to 0, then validate after every epoch 186 | args.val_freq = math.ceil(data.trn.x.shape[0]/args.batch_size) 187 | 188 | # ADAM optimizer 189 | optim = torch.optim.Adam(net.parameters(), lr=args.lr, weight_decay=args.weight_decay) 190 | 191 | logger.info(net) 192 | logger.info("-------------------------") 193 | logger.info("DIMENSION={:} m={:} nTh={:} alpha={:}".format(d,m,nTh,net.alph)) 194 | logger.info("nt={:} nt_val={:}".format(nt,nt_val)) 195 | logger.info("Number of trainable parameters: {}".format(count_parameters(net))) 196 | logger.info("-------------------------") 197 | logger.info(str(optim)) # optimizer info 198 | logger.info("data={:} batch_size={:} gpu={:}".format(args.data, args.batch_size, args.gpu)) 199 | logger.info("maxIters={:} val_freq={:} viz_freq={:}".format(args.niters, 
args.val_freq, args.viz_freq)) 200 | logger.info("saveLocation = {:}".format(args.save)) 201 | logger.info("-------------------------\n") 202 | 203 | begin = time.time() 204 | end = begin 205 | best_loss = float('inf') 206 | best_cs = [0.0]*3 207 | bestParams = None 208 | 209 | log_msg = ( 210 | '{:5s} {:6s} {:7s} {:9s} {:9s} {:9s} {:9s} {:9s} {:9s} {:9s} {:9s} '.format( 211 | 'iter', ' time','lr','loss', 'L (L2)', 'C (loss)', 'R (HJB)', 'valLoss', 'valL', 'valC', 'valR', 212 | ) 213 | ) 214 | logger.info(log_msg) 215 | 216 | timeMeter = utils.AverageMeter() 217 | 218 | # box constraints / acceptable range for parameter values 219 | clampMax = 1.5 220 | clampMin = -1.5 221 | 222 | net.train() 223 | itr = 1 224 | while itr < args.niters: 225 | # train 226 | for x0 in batch_iter(data.trn.x, shuffle=True): 227 | x0 = cvt(x0) 228 | optim.zero_grad() 229 | 230 | # clip parameters 231 | for p in net.parameters(): 232 | p.data = torch.clamp(p.data, clampMin, clampMax) 233 | 234 | currParams = net.state_dict() 235 | loss,cs = compute_loss(net, x0, nt=nt) 236 | loss.backward() 237 | 238 | optim.step() 239 | timeMeter.update(time.time() - end) 240 | 241 | log_message = ( 242 | '{:05d} {:6.3f} {:7.1e} {:9.3e} {:9.3e} {:9.3e} {:9.3e} '.format( 243 | itr, timeMeter.val, optim.param_groups[0]['lr'], loss, cs[0], cs[1], cs[2] 244 | ) 245 | ) 246 | 247 | if torch.isnan(loss): # catch NaNs when hyperparameters are poorly chosen 248 | logger.info(log_message) 249 | logger.info("NaN encountered....exiting prematurely") 250 | logger.info("Training Time: {:} seconds".format(timeMeter.sum)) 251 | logger.info('File: ' + start_time + '_{:}_alph{:}_{:}_m{:}_checkpt.pth'.format( 252 | args.data, int(net.alph[1]), int(net.alph[2]), m) 253 | ) 254 | exit(1) 255 | 256 | # validation 257 | if itr % args.val_freq == 0 or itr == args.niters: 258 | net.eval() 259 | with torch.no_grad(): 260 | 261 | valLossMeter = utils.AverageMeter() 262 | valAlphMeterL = utils.AverageMeter() 263 | valAlphMeterC = 
utils.AverageMeter() 264 | valAlphMeterR = utils.AverageMeter() 265 | 266 | for x0 in batch_iter(data.val.x, batch_size=test_batch_size): 267 | x0 = cvt(x0) 268 | nex = x0.shape[0] 269 | val_loss, val_cs = compute_loss(net, x0, nt=nt_val) 270 | valLossMeter.update(val_loss.item(), nex) 271 | valAlphMeterL.update(val_cs[0].item(), nex) 272 | valAlphMeterC.update(val_cs[1].item(), nex) 273 | valAlphMeterR.update(val_cs[2].item(), nex) 274 | 275 | 276 | # add to print message 277 | log_message += ' {:9.3e} {:9.3e} {:9.3e} {:9.3e} '.format( 278 | valLossMeter.avg, valAlphMeterL.avg, valAlphMeterC.avg, valAlphMeterR.avg 279 | ) 280 | 281 | # save best set of parameters 282 | if valLossMeter.avg < best_loss: 283 | n_vals_wo_improve = 0 284 | best_loss = valLossMeter.avg 285 | best_cs = [ valAlphMeterL.avg, valAlphMeterC.avg, valAlphMeterR.avg ] 286 | utils.makedirs(args.save) 287 | bestParams = net.state_dict() 288 | torch.save({ 289 | 'args': args, 290 | 'state_dict': bestParams, 291 | }, os.path.join(args.save, start_time + '_{:}_alph{:}_{:}_m{:}_checkpt.pth'.format(args.data,int(net.alph[1]),int(net.alph[2]),m))) 292 | else: 293 | n_vals_wo_improve+=1 294 | 295 | net.train() 296 | log_message += ' no improve: {:d}/{:d}'.format(n_vals_wo_improve, args.early_stopping) 297 | logger.info(log_message) # print iteration 298 | 299 | # create plots for assessment mid-training 300 | if itr % args.viz_freq == 0: 301 | with torch.no_grad(): 302 | net.eval() 303 | currState = net.state_dict() 304 | net.load_state_dict(bestParams) 305 | 306 | # plot one batch 307 | p_samples = cvt(data.val.x[0:test_batch_size,:]) 308 | nSamples = p_samples.shape[0] 309 | y = cvt(torch.randn(nSamples,d)) # sampling from rho_1 / standard normal 310 | 311 | sPath = os.path.join(args.save, 'figs', start_time + '_{:04d}.png'.format(itr)) 312 | plot4(net, p_samples, y, nt_val, sPath, sTitle='loss {:.2f} , C {:.2f}'.format(best_loss, best_cs[1] )) 313 | 314 | net.load_state_dict(currState) 315 | 
net.train() 316 | 317 | if args.drop_freq == 0: # if set to the code setting 0 , the lr drops based on validation 318 | if n_vals_wo_improve > args.early_stopping: 319 | if ndecs>2: 320 | logger.info("early stopping engaged") 321 | logger.info("Training Time: {:} seconds".format(timeMeter.sum)) 322 | logger.info('File: ' + start_time + '_{:}_alph{:}_{:}_m{:}_checkpt.pth'.format( 323 | args.data, int(net.alph[1]), int(net.alph[2]), m) 324 | ) 325 | exit(0) 326 | else: 327 | update_lr(optim, n_vals_wo_improve) 328 | n_vals_wo_improve = 0 329 | else: 330 | # shrink step size 331 | if itr % args.drop_freq == 0: 332 | for p in optim.param_groups: 333 | p['lr'] /= args.lr_drop 334 | print("lr: ", p['lr']) 335 | 336 | itr += 1 337 | end = time.time() 338 | # end batch_iter 339 | 340 | logger.info("Training Time: {:} seconds".format(timeMeter.sum)) 341 | logger.info('Training has finished. ' + start_time + '_{:}_alph{:}_{:}_m{:}_checkpt.pth'.format(args.data,int(net.alph[1]),int(net.alph[2]),m)) 342 | 343 | 344 | 345 | 346 | 347 | 348 | -------------------------------------------------------------------------------- /trainMnistOTflow.py: -------------------------------------------------------------------------------- 1 | # trainMnistOTflow.py 2 | # train the MNIST model with the encoder-decoder structure 3 | import argparse 4 | import os 5 | import time 6 | import datetime 7 | import torch.optim as optim 8 | import math 9 | from lib import dataloader as dl 10 | import lib.utils as utils 11 | from lib.utils import count_parameters 12 | import datasets 13 | from datasets.mnist import getLoader 14 | from src.plotter import * 15 | from src.OTFlowProblem import * 16 | from src.Autoencoder import * 17 | import config 18 | 19 | cf = config.getconfig() 20 | 21 | if cf.gpu: 22 | def_viz_freq = 100 23 | def_batch = 800 24 | def_niters = 50000 25 | def_m = 128 26 | def_val_freq = 20 27 | else: # if no gpu on platform, assume debugging on a local cpu 28 | def_viz_freq = 4 29 | 
def_batch = 20 30 | def_niters = 40 31 | def_val_freq = 1 32 | def_m = 16 33 | 34 | parser = argparse.ArgumentParser('OT-Flow') 35 | parser.add_argument( 36 | '--data', choices=['mnist'], type=str, default='mnist' 37 | ) 38 | parser.add_argument("--nt" , type=int, default=8, help="number of time steps") 39 | parser.add_argument("--nt_val", type=int, default=16, help="number of time steps for validation") 40 | parser.add_argument('--alph' , type=str, default='1.0,80.0,500.0') 41 | parser.add_argument('--m' , type=int, default=def_m) 42 | parser.add_argument('--d' , type=int, default=128) # encoded dimension 43 | 44 | parser.add_argument('--weight_decay', type=float, default=0.0) 45 | parser.add_argument('--lr' , type=float, default=0.008) 46 | parser.add_argument('--drop_freq' , type=int, default=5000, help="how often to decrease learning rate") 47 | parser.add_argument('--lr_drop' , type=float, default=10.0**(0.5), help="how much to decrease learning rate (divide by)") 48 | parser.add_argument('--eps' , type=float, default=10**-6) 49 | 50 | parser.add_argument('--niters' , type=int, default=def_niters) 51 | parser.add_argument('--batch_size' , type=int, default=def_batch) 52 | parser.add_argument('--val_batch_size', type=int, default=def_batch) 53 | parser.add_argument('--resume' , type=str, default=None) 54 | parser.add_argument('--autoenc' , type=str, default=None) 55 | parser.add_argument('--save' , type=str, default='experiments/cnf/large') 56 | parser.add_argument('--viz_freq' , type=int, default=def_viz_freq) 57 | parser.add_argument('--val_freq' , type=int, default=def_val_freq) 58 | parser.add_argument('--gpu' , type=int, default=0) 59 | parser.add_argument('--conditional', type=int, default=-1) # -1 means unconditioned 60 | args = parser.parse_args() 61 | 62 | args.alph = [float(item) for item in args.alph.split(',')] 63 | 64 | # add timestamp to save path 65 | start_time = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S") 66 | 67 | # logger 68 | 
utils.makedirs(args.save) 69 | logger = utils.get_logger(logpath=os.path.join(args.save, 'logs'), filepath=os.path.abspath(__file__)) 70 | logger.info("start time: " + start_time) 71 | logger.info(args) 72 | 73 | val_batch_size = args.val_batch_size if args.val_batch_size else args.batch_size 74 | device = torch.device("cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu") 75 | 76 | def compute_loss(net, x, nt): 77 | Jc , costs = OTFlowProblem(x, net, [0,1], nt=nt, stepper="rk4", alph=net.alph) 78 | return Jc, costs 79 | 80 | if __name__ == '__main__': 81 | 82 | prec = torch.float64 83 | cvt = lambda x: x.type(prec).to(device, non_blocking=True) 84 | print("device: ", device) 85 | 86 | train_loader, val_loader, _ = getLoader(args.data, args.batch_size, args.val_batch_size, augment=False, hasGPU=cf.gpu, conditional=args.conditional) 87 | 88 | d = args.d # encoded dimensions 89 | # -----------AutoEncoder ------------------------------------------------------------ 90 | if args.autoenc is None: # if no trained encoder-decoder is provided, then train one 91 | # initialize the encoder-decoder 92 | autoEnc = Autoencoder(d) 93 | autoEnc = autoEnc.type(prec).to(device) 94 | print(autoEnc) 95 | 96 | autoEnc = trainAE(autoEnc, train_loader, val_loader, args.save, start_time, argType=prec, device=device) 97 | 98 | else: 99 | # load the trained autoencoder 100 | checkpt = torch.load(args.autoenc, map_location=lambda storage, loc: storage) 101 | autoEnc = Autoencoder(d) 102 | autoEnc.mu = checkpt["state_dict"]["mu"] # checkpt['AEmu'].to(prec) 103 | autoEnc.std = checkpt["state_dict"]["std"] #checkpt['AEstd'].to(prec) 104 | autoEnc.load_state_dict(checkpt["state_dict"], strict=False) # doesnt load the buffers 105 | autoEnc = autoEnc.to(prec).to(device) 106 | # ----------------------------------------------------------------------- 107 | 108 | nt = args.nt 109 | nt_val = args.nt_val 110 | nTh = 2 111 | m = args.m 112 | 113 | net = Phi(nTh=nTh, m=m, d=d, alph=args.alph) 
# the phi aka the value function 114 | net = net.to(prec).to(device) 115 | 116 | if args.val_freq == 0: 117 | # if val_freq set to 0, then validate after every epoch....assume mnist train 50000 118 | args.val_freq = math.ceil( 50000 /args.batch_size) 119 | 120 | # ADAM optimizer 121 | optim = torch.optim.Adam(net.parameters(), lr=args.lr, weight_decay=args.weight_decay) 122 | 123 | logger.info(net) 124 | logger.info("-------------------------") 125 | logger.info("DIMENSION={:} m={:} nTh={:} alpha={:}".format(d,m,nTh,net.alph)) 126 | logger.info("nt={:} nt_val={:}".format(nt,nt_val)) 127 | logger.info("Number of trainable parameters: {}".format(count_parameters(net))) 128 | logger.info("-------------------------") 129 | logger.info(str(optim)) # optimizer info 130 | logger.info("data={:} batch_size={:} gpu={:}".format(args.data, args.batch_size, args.gpu)) 131 | logger.info("maxIters={:} val_freq={:} viz_freq={:}".format(args.niters, args.val_freq, args.viz_freq)) 132 | logger.info("saveLocation = {:}".format(args.save)) 133 | logger.info("-------------------------\n") 134 | 135 | begin = time.time() 136 | end = begin 137 | best_loss = float('inf') 138 | best_costs = [0.0]*3 139 | best_params = None 140 | 141 | log_msg = ( 142 | '{:5s} {:6s} {:7s} {:9s} {:9s} {:9s} {:9s} {:9s} {:9s} {:9s} {:9s} '.format( 143 | 'iter', ' time','lr','loss', 'L (L_2)', 'C (loss)', 'R (HJB)', 'valLoss', 'valL', 'valC', 'valR' 144 | ) 145 | ) 146 | logger.info(log_msg) 147 | 148 | timeMeter = utils.AverageMeter() 149 | clampMax = 2.0 150 | clampMin = -2.0 151 | 152 | net.train() 153 | itr = 1 154 | while itr < args.niters: 155 | # train 156 | for data in train_loader: 157 | images, _ = data 158 | # flatten images 159 | x0 = images.view(images.size(0), -1) 160 | x0 = cvt(x0) 161 | x0 = autoEnc.encode(x0) # encode 162 | x0 = (x0 - autoEnc.mu) / (autoEnc.std + args.eps) # normalize 163 | 164 | optim.zero_grad() 165 | 166 | # clip parameters 167 | for p in net.parameters(): 168 | p.data = 
torch.clamp(p.data, clampMin, clampMax) 169 | 170 | loss,costs = compute_loss(net, x0, nt=nt) 171 | loss.backward() 172 | optim.step() 173 | timeMeter.update(time.time() - end) 174 | 175 | log_message = ( 176 | '{:05d} {:6.3f} {:7.1e} {:9.3e} {:9.3e} {:9.3e} {:9.3e} '.format( 177 | itr, timeMeter.val, optim.param_groups[0]['lr'], loss, costs[0], costs[1], costs[2] 178 | ) 179 | ) 180 | 181 | if torch.isnan(loss): 182 | logger.info(log_message) 183 | logger.info("NaN encountered....exiting prematurely") 184 | logger.info("Training Time: {:} seconds".format(timeMeter.sum)) 185 | logger.info('File: ' + start_time + '_{:}_alph{:}_{:}_m{:}_checkpt.pth'.format( 186 | args.data, int(net.alph[1]), int(net.alph[2]), m) 187 | ) 188 | exit(1) 189 | 190 | # validation 191 | if itr == 1 or itr % args.val_freq == 0 or itr == args.niters: 192 | net.eval() 193 | with torch.no_grad(): 194 | 195 | valLossMeter = utils.AverageMeter() 196 | valAlphMeterL = utils.AverageMeter() 197 | valAlphMeterC = utils.AverageMeter() 198 | valAlphMeterR = utils.AverageMeter() 199 | 200 | for data in val_loader: 201 | images, _ = data 202 | # flatten images 203 | x0 = images.view(images.size(0), -1) 204 | x0 = cvt(x0) 205 | x0 = autoEnc.encode(x0) # encode 206 | x0 = (x0 - autoEnc.mu) / (autoEnc.std + args.eps ) # normalize 207 | 208 | nex = x0.shape[0] 209 | val_loss, val_costs = compute_loss(net, x0, nt=nt_val) 210 | valLossMeter.update(val_loss.item(), nex) 211 | valAlphMeterL.update(val_costs[0].item(), nex) 212 | valAlphMeterC.update(val_costs[1].item(), nex) 213 | valAlphMeterR.update(val_costs[2].item(), nex) 214 | 215 | if not cf.gpu: # for debugging 216 | break 217 | 218 | # add to print message 219 | log_message += ' {:9.3e} {:9.3e} {:9.3e} {:9.3e} '.format( 220 | valLossMeter.avg, valAlphMeterL.avg, valAlphMeterC.avg, valAlphMeterR.avg 221 | ) 222 | 223 | # save best set of parameters 224 | if valLossMeter.avg < best_loss: 225 | logger.info('saving new best') 226 | best_loss = 
valLossMeter.avg 227 | best_costs = [ valAlphMeterL.avg, valAlphMeterC.avg, valAlphMeterR.avg ] 228 | utils.makedirs(args.save) 229 | best_params = net.state_dict() 230 | torch.save({ 231 | 'args': args, 232 | 'state_dict': best_params, 233 | 'autoencoder': autoEnc.state_dict(), 234 | }, os.path.join(args.save, start_time + '_{:}_alph{:}_{:}_m{:}_checkpt.pth'.format(args.data,int(net.alph[1]),int(net.alph[2]),m))) 235 | net.train() 236 | 237 | logger.info(log_message) # print iteration 238 | 239 | # create plots 240 | if itr % args.viz_freq == 0: 241 | with torch.no_grad(): 242 | net.eval() 243 | currState = net.state_dict() 244 | net.load_state_dict(best_params) 245 | 246 | # plot one batch in R^d space 247 | p_samples = next(iter(val_loader))[0] 248 | p_samples = p_samples.view(p_samples.size(0), -1) 249 | p_samples = cvt(p_samples) 250 | p_samples = autoEnc.encode(p_samples) # encode 251 | p_samples = (p_samples - autoEnc.mu) / (autoEnc.std + args.eps ) # normalize 252 | 253 | nSamples = p_samples.shape[0] 254 | y = cvt(torch.randn(nSamples,d)) # sampling from rho_1 255 | sPath = os.path.join(args.save, 'figs', start_time + '_{:04d}.png'.format(itr)) 256 | plot4(net, p_samples, y, nt_val, sPath, sTitle='loss {:.2f} , C {:.2f}'.format(best_loss, best_costs[1] )) 257 | 258 | # plot the Mnist images 259 | nSamples = 8 # overwrite 260 | p_samples = p_samples[0:nSamples,:] 261 | y = y[0:nSamples,:] 262 | 263 | sPath = os.path.join(args.save, 'figs', start_time + '_class{:d}_imshow{:04d}.png'.format(args.conditional, itr)) 264 | genModel = integrate(y[:, 0:d], net, [1.0, 0.0], nt_val, stepper="rk4", alph=net.alph) 265 | genModel = genModel[:, 0:d] 266 | genDecoded = autoEnc.decode( genModel * (autoEnc.std + args.eps ) + autoEnc.mu ) # de-normalize and decode 267 | pDecoded = autoEnc.decode(p_samples * (autoEnc.std + args.eps) + autoEnc.mu) # de-normalize and decode 268 | plotAutoEnc(pDecoded, genDecoded, sPath) 269 | net.load_state_dict(currState) 270 | net.train() 
271 | 272 | # shrink step size 273 | if itr % args.drop_freq == 0: 274 | for p in optim.param_groups: 275 | p['lr'] /= args.lr_drop # 10.0**(0.5) 276 | print("lr: ", p['lr']) 277 | 278 | itr += 1 279 | end = time.time() 280 | # end batch_iter 281 | 282 | logger.info("Training Time: {:} seconds".format(timeMeter.sum)) 283 | logger.info('Training has finished. ' + start_time + '_{:}_alph{:}_{:}_m{:}_checkpt.pth'.format(args.data,int(net.alph[1]),int(net.alph[2]),m)) 284 | 285 | 286 | 287 | 288 | 289 | 290 | -------------------------------------------------------------------------------- /trainToyOTflow.py: -------------------------------------------------------------------------------- 1 | # trainToyOTflow.py 2 | # training driver for the two-dimensional toy problems 3 | import argparse 4 | import os 5 | import time 6 | import datetime 7 | import torch.optim as optim 8 | import numpy as np 9 | import math 10 | import lib.toy_data as toy_data 11 | import lib.utils as utils 12 | from lib.utils import count_parameters 13 | from src.plotter import plot4 14 | from src.OTFlowProblem import * 15 | import config 16 | 17 | cf = config.getconfig() 18 | 19 | if cf.gpu: # if gpu on platform 20 | def_viz_freq = 100 21 | def_batch = 4096 22 | def_niter = 1500 23 | else: # if no gpu on platform, assume debugging on a local cpu 24 | def_viz_freq = 100 25 | def_batch = 2048 26 | def_niter = 1000 27 | 28 | parser = argparse.ArgumentParser('OT-Flow') 29 | parser.add_argument( 30 | '--data', choices=['swissroll', '8gaussians', 'pinwheel', 'circles', 'moons', '2spirals', 'checkerboard', 'rings'], 31 | type=str, default='8gaussians' 32 | ) 33 | 34 | parser.add_argument("--nt" , type=int, default=8, help="number of time steps") 35 | parser.add_argument("--nt_val", type=int, default=8, help="number of time steps for validation") 36 | parser.add_argument('--alph' , type=str, default='1.0,100.0,5.0') 37 | parser.add_argument('--m' , type=int, default=32) 38 | parser.add_argument('--nTh' , 
type=int, default=2) 39 | 40 | parser.add_argument('--niters' , type=int , default=def_niter) 41 | parser.add_argument('--batch_size' , type=int , default=def_batch) 42 | parser.add_argument('--val_batch_size', type=int , default=def_batch) 43 | 44 | parser.add_argument('--lr' , type=float, default=0.1) 45 | parser.add_argument("--drop_freq" , type=int , default=100, help="how often to decrease learning rate") 46 | parser.add_argument('--weight_decay', type=float, default=0.0) 47 | parser.add_argument('--lr_drop' , type=float, default=2.0) 48 | parser.add_argument('--optim' , type=str , default='adam', choices=['adam']) 49 | parser.add_argument('--prec' , type=str , default='single', choices=['single','double'], help="single or double precision") 50 | 51 | parser.add_argument('--save' , type=str, default='experiments/cnf/toy') 52 | parser.add_argument('--viz_freq', type=int, default=def_viz_freq) 53 | parser.add_argument('--val_freq', type=int, default=1) 54 | parser.add_argument('--gpu' , type=int, default=0) 55 | parser.add_argument('--sample_freq', type=int, default=25) 56 | 57 | 58 | args = parser.parse_args() 59 | 60 | args.alph = [float(item) for item in args.alph.split(',')] 61 | 62 | # get precision type 63 | if args.prec =='double': 64 | prec = torch.float64 65 | else: 66 | prec = torch.float32 67 | 68 | # get timestamp for saving models 69 | start_time = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S") 70 | 71 | # logger 72 | utils.makedirs(args.save) 73 | logger = utils.get_logger(logpath=os.path.join(args.save, 'logs'), filepath=os.path.abspath(__file__)) 74 | logger.info("start time: " + start_time) 75 | logger.info(args) 76 | 77 | device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu') 78 | 79 | 80 | def compute_loss(net, x, nt): 81 | Jc , cs = OTFlowProblem(x, net, [0,1], nt=nt, stepper="rk4", alph=net.alph) 82 | return Jc, cs 83 | 84 | 85 | if __name__ == '__main__': 86 | 87 | torch.set_default_dtype(prec) 88 | 
cvt = lambda x: x.type(prec).to(device, non_blocking=True) 89 | 90 | # neural network for the potential function Phi 91 | d = 2 92 | alph = args.alph 93 | nt = args.nt 94 | nt_val = args.nt_val 95 | nTh = args.nTh 96 | m = args.m 97 | net = Phi(nTh=nTh, m=args.m, d=d, alph=alph) 98 | net = net.to(prec).to(device) 99 | 100 | optim = torch.optim.Adam(net.parameters(), lr=args.lr, weight_decay=args.weight_decay ) # lr=0.04 good 101 | 102 | logger.info(net) 103 | logger.info("-------------------------") 104 | logger.info("DIMENSION={:} m={:} nTh={:} alpha={:}".format(d,m,nTh,alph)) 105 | logger.info("nt={:} nt_val={:}".format(nt,nt_val)) 106 | logger.info("Number of trainable parameters: {}".format(count_parameters(net))) 107 | logger.info("-------------------------") 108 | logger.info(str(optim)) # optimizer info 109 | logger.info("data={:} batch_size={:} gpu={:}".format(args.data, args.batch_size, args.gpu)) 110 | logger.info("maxIters={:} val_freq={:} viz_freq={:}".format(args.niters, args.val_freq, args.viz_freq)) 111 | logger.info("saveLocation = {:}".format(args.save)) 112 | logger.info("-------------------------\n") 113 | 114 | end = time.time() 115 | best_loss = float('inf') 116 | bestParams = None 117 | 118 | # setup data [nSamples, d] 119 | # use one batch as the entire data set 120 | x0 = toy_data.inf_train_gen(args.data, batch_size=args.batch_size) 121 | x0 = cvt(torch.from_numpy(x0)) 122 | 123 | x0val = toy_data.inf_train_gen(args.data, batch_size=args.val_batch_size) 124 | x0val = cvt(torch.from_numpy(x0val)) 125 | 126 | log_msg = ( 127 | '{:5s} {:6s} {:9s} {:9s} {:9s} {:9s} {:9s} {:9s} {:9s} {:9s} '.format( 128 | 'iter', ' time','loss', 'L (L_2)', 'C (loss)', 'R (HJB)', 'valLoss', 'valL', 'valC', 'valR' 129 | ) 130 | ) 131 | logger.info(log_msg) 132 | 133 | time_meter = utils.AverageMeter() 134 | 135 | net.train() 136 | for itr in range(1, args.niters + 1): 137 | # train 138 | optim.zero_grad() 139 | loss, costs = compute_loss(net, x0, nt=nt) 140 | 
loss.backward() 141 | optim.step() 142 | 143 | time_meter.update(time.time() - end) 144 | 145 | log_message = ( 146 | '{:05d} {:6.3f} {:9.3e} {:9.3e} {:9.3e} {:9.3e} '.format( 147 | itr, time_meter.val , loss, costs[0], costs[1], costs[2] 148 | ) 149 | ) 150 | 151 | # validate 152 | if itr % args.val_freq == 0 or itr == args.niters: 153 | with torch.no_grad(): 154 | net.eval() 155 | test_loss, test_costs = compute_loss(net, x0val, nt=nt_val) 156 | 157 | # add to print message 158 | log_message += ' {:9.3e} {:9.3e} {:9.3e} {:9.3e} '.format( 159 | test_loss, test_costs[0], test_costs[1], test_costs[2] 160 | ) 161 | 162 | # save best set of parameters 163 | if test_loss.item() < best_loss: 164 | best_loss = test_loss.item() 165 | best_costs = test_costs 166 | utils.makedirs(args.save) 167 | best_params = net.state_dict() 168 | torch.save({ 169 | 'args': args, 170 | 'state_dict': best_params, 171 | }, os.path.join(args.save, start_time + '_{:}_alph{:}_{:}_m{:}_checkpt.pth'.format(args.data,int(alph[1]),int(alph[2]),m))) 172 | net.train() 173 | 174 | logger.info(log_message) # print iteration 175 | 176 | # create plots 177 | if itr % args.viz_freq == 0: 178 | with torch.no_grad(): 179 | net.eval() 180 | curr_state = net.state_dict() 181 | net.load_state_dict(best_params) 182 | 183 | nSamples = 20000 184 | p_samples = cvt(torch.Tensor( toy_data.inf_train_gen(args.data, batch_size=nSamples) )) 185 | y = cvt(torch.randn(nSamples,d)) # sampling from the standard normal (rho_1) 186 | 187 | sPath = os.path.join(args.save, 'figs', start_time + '_{:04d}.png'.format(itr)) 188 | plot4(net, p_samples, y, nt_val, sPath, doPaths=True, sTitle='{:s} - loss {:.2f} , C {:.2f} , alph {:.1f} {:.1f} ' 189 | ' nt {:d} m {:d} nTh {:d} '.format(args.data, best_loss, best_costs[1], alph[1], alph[2], nt, m, nTh)) 190 | 191 | net.load_state_dict(curr_state) 192 | net.train() 193 | 194 | # shrink step size 195 | if itr % args.drop_freq == 0: 196 | for p in optim.param_groups: 197 | p['lr'] /= 
args.lr_drop 198 | print("lr: ", p['lr']) 199 | 200 | # resample data 201 | if itr % args.sample_freq == 0: 202 | # resample data [nSamples, d+1] 203 | logger.info("resampling") 204 | x0 = toy_data.inf_train_gen(args.data, batch_size=args.batch_size) # load data batch 205 | x0 = cvt(torch.from_numpy(x0)) # convert to torch, type and gpu 206 | 207 | end = time.time() 208 | 209 | logger.info("Training Time: {:} seconds".format(time_meter.sum)) 210 | logger.info('Training has finished. ' + start_time + '_{:}_alph{:}_{:}_m{:}_checkpt.pth'.format(args.data,int(alph[1]),int(alph[2]),m)) 211 | 212 | 213 | 214 | 215 | 216 | --------------------------------------------------------------------------------