├── requirements.txt ├── EXP_GAN ├── .DS_Store ├── data_loading.py ├── load_models.py ├── download_celebA.py ├── torch_lin_sinkhorn.py ├── models.py ├── train_models_celebA.py └── train_models_cifar.py ├── results ├── celebA_samples.png ├── cifar10_samples.png └── plot_accuracy_ROT_sphere.jpg ├── .gitignore ├── README.md └── FastSinkhorn.py /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.5.0 2 | scipy==1.5.0 3 | PIL==7.0.0 4 | -------------------------------------------------------------------------------- /EXP_GAN/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/meyerscetbon/LinearSinkhorn/HEAD/EXP_GAN/.DS_Store -------------------------------------------------------------------------------- /results/celebA_samples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/meyerscetbon/LinearSinkhorn/HEAD/results/celebA_samples.png -------------------------------------------------------------------------------- /results/cifar10_samples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/meyerscetbon/LinearSinkhorn/HEAD/results/cifar10_samples.png -------------------------------------------------------------------------------- /results/plot_accuracy_ROT_sphere.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/meyerscetbon/LinearSinkhorn/HEAD/results/plot_accuracy_ROT_sphere.jpg -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # Compiled source # 7 | ################### 8 | *.com 9 | *.class 10 | *.dll 11 | *.exe 12 | *.o 13 | *.so 14 | 15 | # Packages # 16 | ############ 17 | # it's better to unpack these files and commit the raw source 18 | # git has its own built in compression methods 19 | *.7z 20 | *.dmg 21 | *.gz 22 | *.iso 23 | *.jar 24 | *.rar 25 | *.tar 26 | *.zip 27 | 28 | # Logs and databases # 29 | ###################### 30 | *.log 31 | *.sql 32 | *.sqlite 33 | 34 | # OS generated files # 35 | ###################### 36 | .DS_Store 37 | .DS_Store? 38 | ._* 39 | .Spotlight-V100 40 | .Trashes 41 | ehthumbs.db 42 | Thumbs.db 43 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Linear Time Sinkhorn Divergences using Positive Features 2 | Code of the paper by Meyer Scetbon and Marco Cuturi 3 | 4 | ## Approximation of the Regularized Optimal Transport in Linear Time 5 | In this work, we show that one can approximate the regularized optimal transport in linear time with respect to the number of samples for some usual cost functions, e.g. the square Euclidean distance. We present the time-accuracy tradeoff between different methods to compute the regularized OT when the samples live on the unit sphere. 6 | ![figure](results/plot_accuracy_ROT_sphere.jpg) 7 | 8 | The implementation of the recursive Nystrom is adapted from the MATLAB implementation (https://github.com/cnmusco/recursive-nystrom). 
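Below is a minimal usage sketch (illustrative only, not part of the repository) of the NumPy routines defined in `FastSinkhorn.py` at the end of this repository. The synthetic unit-sphere data, weights and parameter values are assumptions made for the example; the function names and signatures are taken from that file.

```python
import numpy as np
import FastSinkhorn as fs  # FastSinkhorn.py from this repository

# Synthetic points on the unit sphere with uniform weights (the setting of the plot above).
rng = np.random.RandomState(0)
n, m, d = 500, 500, 3
X = rng.randn(n, d); X /= np.linalg.norm(X, axis=1, keepdims=True)
Y = rng.randn(m, d); Y /= np.linalg.norm(Y, axis=1, keepdims=True)
a, b = np.ones(n) / n, np.ones(m) / m
reg = 1.0  # entropic regularization

# Each routine returns (final ROT value, accuracy curve, timings),
# or the string "Error" if the iterations produce NaNs.
out_exact = fs.Sinkhorn_RBF(X, Y, reg, a, b, num_iter=100)                      # full cost matrix
out_feat = fs.Lin_Sinkhorn_RBF(X, Y, reg, a, b, num_samples=100, num_iter=100)  # positive features
out_nys = fs.Nys_Sinkhorn_RBF(X, Y, reg, a, b, rank=100, num_iter=100)          # uniform Nystrom

for name, out in [("Sinkhorn", out_exact), ("Positive features", out_feat), ("Nystrom", out_nys)]:
    if out == "Error":
        print(name, "failed (NaN encountered)")
    else:
        print(name, "ROT:", out[0], "time (s):", out[2][-1])
```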
9 | 10 | ## Generative Adversarial Network 11 | We also show that our method provides a constructive way to build a kernel, and from it a cost function adapted to the problem at hand, in order to compare distributions using optimal transport. Below are samples from the generative models trained with our method on CIFAR10 and CelebA. 12 | 13 |

14 | ![CIFAR10 samples](results/cifar10_samples.png) 15 | ![CelebA samples](results/celebA_samples.png) 16 |
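In the GAN experiments the convolutional embedding in `EXP_GAN/models.py` produces the features; the hedged NumPy sketch below (not part of the repository) only illustrates the underlying mechanics with the plain helpers from `FastSinkhorn.py`: a positive feature map induces a kernel k(x, y) ≈ φ(x)·φ(y), and the associated ground cost is recovered as c(x, y) = −reg · log k(x, y). The data and sample sizes are assumptions made for the example.

```python
import numpy as np
import FastSinkhorn as fs

rng = np.random.RandomState(0)
X = rng.randn(200, 3); X /= np.linalg.norm(X, axis=1, keepdims=True)
Y = rng.randn(150, 3); Y /= np.linalg.norm(Y, axis=1, keepdims=True)
reg = 1.0

# Positive random features for the Gaussian kernel exp(-||x - y||^2 / reg).
phi_X = fs.Feature_Map_Gaussian(X, reg, R=1, num_samples=500, seed=49)
phi_Y = fs.Feature_Map_Gaussian(Y, reg, R=1, num_samples=500, seed=49)

K_feat = phi_X @ phi_Y.T                                    # low-rank, strictly positive kernel
K_true = np.exp(-fs.Square_Euclidean_Distance(X, Y) / reg)  # Gibbs kernel of the exact cost
cost_from_kernel = -reg * np.log(K_feat)                    # ground cost induced by the kernel

print("max kernel approximation error:", np.abs(K_feat - K_true).max())
```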

17 | 18 | The implementation of the WGAN is a code adapted from the MMD-GAN implementation (https://github.com/OctoberChang/MMD-GAN). 19 | 20 | 21 | 22 | This repository contains a Python implementation of the algorithms presented in the [paper](https://arxiv.org/pdf/2006.07057.pdf). 23 | -------------------------------------------------------------------------------- /EXP_GAN/data_loading.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch.utils.data as data 3 | from PIL import Image 4 | from os import listdir 5 | from os.path import join 6 | 7 | import torchvision.transforms as transforms 8 | import torchvision.datasets as dset 9 | 10 | 11 | ### Get Data ### 12 | def is_image_file(filename): 13 | return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg"]) 14 | 15 | 16 | def load_img(filepath): 17 | img = Image.open(filepath).convert("RGB") 18 | return img 19 | 20 | 21 | class FolderWithImages(data.Dataset): 22 | def __init__(self, root, input_transform=None, target_transform=None): 23 | super(FolderWithImages, self).__init__() 24 | self.image_filenames = [ 25 | join(root, x) for x in listdir(root) if is_image_file(x.lower()) 26 | ] 27 | 28 | self.input_transform = input_transform 29 | self.target_transform = target_transform 30 | 31 | def __getitem__(self, index): 32 | input = load_img(self.image_filenames[index]) 33 | target = input.copy() 34 | if self.input_transform: 35 | input = self.input_transform(input) 36 | if self.target_transform: 37 | target = self.target_transform(target) 38 | 39 | return input, target 40 | 41 | def __len__(self): 42 | return len(self.image_filenames) 43 | 44 | 45 | class ALICropAndScale(object): 46 | def __call__(self, img): 47 | return img.resize((64, 78), Image.ANTIALIAS).crop((0, 7, 64, 64 + 7)) 48 | 49 | 50 | def get_data(image_size, dataset_name, data_root, train_flag=True): 51 | if dataset_name == "cifar10": 52 | 53 | transform = transforms.Compose( 54 | [ 55 | transforms.Resize(image_size), 56 | transforms.CenterCrop(image_size), 57 | transforms.ToTensor(), 58 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 59 | ] 60 | ) 61 | 62 | dataset = dset.CIFAR10( 63 | root=data_root, download=True, train=train_flag, transform=transform 64 | ) 65 | 66 | elif dataset_name == "mnist": 67 | transform = transforms.Compose( 68 | [ 69 | transforms.Resize(image_size), 70 | transforms.CenterCrop(image_size), 71 | transforms.ToTensor(), 72 | transforms.Normalize((0.5,), (0.5,)), 73 | ] 74 | ) 75 | 76 | dataset = dset.MNIST( 77 | root=data_root, download=True, train=train_flag, transform=transform 78 | ) 79 | 80 | elif dataset_name == "celeba": 81 | imdir = "CelebA/splits/train" if train_flag else "CelebA/splits/val" 82 | dataroot = os.path.join(data_root, imdir) 83 | if image_size != 64: 84 | raise ValueError("the image size for CelebA dataset need to be 64!") 85 | 86 | dataset = FolderWithImages( 87 | root=dataroot, 88 | input_transform=transforms.Compose( 89 | [ 90 | ALICropAndScale(), 91 | transforms.ToTensor(), 92 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 93 | ] 94 | ), 95 | target_transform=transforms.ToTensor(), 96 | ) 97 | 98 | return dataset 99 | -------------------------------------------------------------------------------- /EXP_GAN/load_models.py: -------------------------------------------------------------------------------- 1 | import random 2 | from scipy.special import lambertw 3 | import numpy as np 4 | import torch 5 | import 
torch.backends.cudnn as cudnn 6 | import torch.utils.data 7 | import torchvision.utils as vutils 8 | 9 | import models 10 | 11 | 12 | # CIFAR10 13 | # image_size = 64 14 | # nc = 3 15 | # nz = 128 16 | # dataset_name = 'cifar10' 17 | 18 | # CELEBA 19 | dataset_name = "celeba" 20 | image_size = 64 21 | nc = 3 22 | nz = 128 23 | 24 | 25 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 26 | print(torch.cuda.is_available(), device) 27 | 28 | manual_seed = 49 29 | random_ = False 30 | R = 1 31 | batch_size = 8000 32 | 33 | 34 | num_random_samples = 600 35 | reg = 1 36 | 37 | epsilon = reg 38 | hidden_dim = nz 39 | 40 | # Fix the seed 41 | np.random.seed(seed=manual_seed) 42 | random.seed(manual_seed) 43 | torch.manual_seed(manual_seed) 44 | torch.cuda.manual_seed(manual_seed) 45 | cudnn.benchmark = True 46 | 47 | 48 | def compute_constants(reg, device, nz, R=1, num_random_samples=100, seed=49): 49 | q = (1 / 2) + (R ** 2) / reg 50 | y = R ** 2 / (reg * nz) 51 | q = np.real((1 / 2) * np.exp(lambertw(y))) 52 | 53 | C = (2 * q) ** (nz / 4) 54 | 55 | np.random.seed(seed) 56 | var = (q * reg) / 4 57 | U = np.random.multivariate_normal( 58 | np.zeros(nz), var * np.eye(nz), num_random_samples 59 | ) 60 | U = torch.from_numpy(U) 61 | 62 | U_init = U.to(device) 63 | C_init = torch.DoubleTensor([C]).to(device) 64 | q_init = torch.DoubleTensor([q]).to(device) 65 | 66 | return q_init, C_init, U_init 67 | 68 | 69 | q, C, U_init = compute_constants( 70 | reg, device, nz, R=R, num_random_samples=num_random_samples, seed=manual_seed 71 | ) 72 | q, C, U_init = q.to(device), C.to(device), U_init.to(device) 73 | 74 | G_generator = models.Generator(image_size, nc, k=nz, ngf=64) 75 | D_embedding = models.Embedding( 76 | image_size, 77 | nc, 78 | reg, 79 | device, 80 | q, 81 | C, 82 | U_init, 83 | k=hidden_dim, 84 | num_random_samples=num_random_samples, 85 | R=R, 86 | seed=manual_seed, 87 | ndf=64, 88 | random=random_, 89 | ) 90 | 91 | netG = models.NetG(G_generator) 92 | path_model_G = "netG_celebA_600_1.pth" 93 | netG.load_state_dict( 94 | torch.load(path_model_G, map_location="cpu") 95 | ) # If on cluster comment map_location 96 | netG.to(device) 97 | 98 | netE = models.NetE(D_embedding) 99 | path_model_E = "netE_cifar_max_600_1.pth" 100 | netE.load_state_dict( 101 | torch.load(path_model_E, map_location="cpu") 102 | ) # If on cluster comment map_location 103 | netE.to(device) 104 | 105 | 106 | # Choose a random seed to sample a random image from the generator 107 | manual_seed = 123 108 | np.random.seed(seed=manual_seed) 109 | random.seed(manual_seed) 110 | torch.manual_seed(manual_seed) 111 | torch.cuda.manual_seed(manual_seed) 112 | cudnn.benchmark = True 113 | 114 | 115 | batch_size_noise = 32 116 | fixed_noise = torch.DoubleTensor(batch_size_noise, nz, 1, 1).normal_(0, 1).to(device) 117 | fixed_noise = fixed_noise.float() 118 | y_fixed = netG(fixed_noise) # between -1 and 1 119 | 120 | # A sample from the trained model 121 | y_fixed = y_fixed.mul(0.5).add(0.5) 122 | vutils.save_image(y_fixed, "celebA_image_vf_600.png") 123 | -------------------------------------------------------------------------------- /EXP_GAN/download_celebA.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modification of 3 | - https://github.com/carpedm20/DCGAN-tensorflow/blob/master/download.py 4 | - http://stackoverflow.com/a/39225039 5 | """ 6 | from __future__ import print_function 7 | import os 8 | import zipfile 9 | import requests 10 | 11 | 12 | def 
download_file_from_google_drive(id, destination): 13 | URL = "https://docs.google.com/uc?export=download" 14 | session = requests.Session() 15 | 16 | response = session.get(URL, params={"id": id}, stream=True) 17 | token = get_confirm_token(response) 18 | 19 | if token: 20 | params = {"id": id, "confirm": token} 21 | response = session.get(URL, params=params, stream=True) 22 | 23 | save_response_content(response, destination) 24 | 25 | 26 | def get_confirm_token(response): 27 | for key, value in response.cookies.items(): 28 | if key.startswith("download_warning"): 29 | return value 30 | return None 31 | 32 | 33 | def save_response_content(response, destination, chunk_size=32 * 1024): 34 | total_size = int(response.headers.get("content-length", 0)) 35 | with open(destination, "wb") as f: 36 | for chunk in tqdm( 37 | response.iter_content(chunk_size), 38 | total=total_size, 39 | unit="B", 40 | unit_scale=True, 41 | desc=destination, 42 | ): 43 | if chunk: # filter out keep-alive new chunks 44 | f.write(chunk) 45 | 46 | 47 | def unzip(filepath): 48 | print("Extracting: " + filepath) 49 | base_path = os.path.dirname(filepath) 50 | with zipfile.ZipFile(filepath) as zf: 51 | zf.extractall(base_path) 52 | os.remove(filepath) 53 | 54 | 55 | def download_celeb_a(base_path): 56 | data_path = os.path.join(base_path, "CelebA") 57 | images_path = os.path.join(data_path, "images") 58 | if os.path.exists(data_path): 59 | print("[!] Found Celeb-A - skip") 60 | return 61 | 62 | filename, drive_id = "img_align_celeba.zip", "0B7EVK8r0v71pZjFTYXZWM3FlRnM" 63 | save_path = os.path.join(base_path, filename) 64 | 65 | if os.path.exists(save_path): 66 | print("[*] {} already exists".format(save_path)) 67 | else: 68 | download_file_from_google_drive(drive_id, save_path) 69 | 70 | # zip_dir = '' 71 | with zipfile.ZipFile(save_path) as zf: 72 | # zip_dir = zf.namelist()[0] 73 | zf.extractall(base_path) 74 | if not os.path.exists(data_path): 75 | os.mkdir(data_path) 76 | os.rename(os.path.join(base_path, "img_align_celeba"), images_path) 77 | os.remove(save_path) 78 | 79 | 80 | def prepare_data_dir(path="./data"): 81 | if not os.path.exists(path): 82 | os.mkdir(path) 83 | 84 | 85 | # check, if file exists, make link 86 | def check_link(in_dir, basename, out_dir): 87 | in_file = os.path.join(in_dir, basename) 88 | if os.path.exists(in_file): 89 | link_file = os.path.join(out_dir, basename) 90 | rel_link = os.path.relpath(in_file, out_dir) 91 | os.symlink(rel_link, link_file) 92 | 93 | 94 | def add_splits(base_path): 95 | data_path = os.path.join(base_path, "CelebA") 96 | images_path = os.path.join(data_path, "images") 97 | train_dir = os.path.join(data_path, "splits", "train") 98 | valid_dir = os.path.join(data_path, "splits", "valid") 99 | test_dir = os.path.join(data_path, "splits", "test") 100 | if not os.path.exists(train_dir): 101 | os.makedirs(train_dir) 102 | if not os.path.exists(valid_dir): 103 | os.makedirs(valid_dir) 104 | if not os.path.exists(test_dir): 105 | os.makedirs(test_dir) 106 | 107 | # these constants based on the standard CelebA splits 108 | NUM_EXAMPLES = 202599 109 | TRAIN_STOP = 162770 110 | VALID_STOP = 182637 111 | 112 | for i in range(0, TRAIN_STOP): 113 | basename = "{:06d}.jpg".format(i + 1) 114 | check_link(images_path, basename, train_dir) 115 | for i in range(TRAIN_STOP, VALID_STOP): 116 | basename = "{:06d}.jpg".format(i + 1) 117 | check_link(images_path, basename, valid_dir) 118 | for i in range(VALID_STOP, NUM_EXAMPLES): 119 | basename = "{:06d}.jpg".format(i + 1) 120 | 
check_link(images_path, basename, test_dir) 121 | 122 | 123 | if __name__ == "__main__": 124 | base_path = "./" 125 | prepare_data_dir() 126 | download_celeb_a(base_path) 127 | add_splits(base_path) 128 | -------------------------------------------------------------------------------- /EXP_GAN/torch_lin_sinkhorn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Lin_Sinkhorn_AD(torch.autograd.Function): 5 | @staticmethod 6 | def forward(ctx, x_emb, y_emb, reg, niter_sin, lam=1e-6, tau=1e-9): 7 | phi_x = x_emb.squeeze().type(torch.DoubleTensor) 8 | phi_y = y_emb.squeeze().type(torch.DoubleTensor) 9 | 10 | n = phi_x.size()[0] 11 | m = phi_y.size()[0] 12 | 13 | a = (1.0 / n) * torch.ones(n) 14 | a = a.type(torch.DoubleTensor) 15 | 16 | b = (1.0 / m) * torch.ones(m) 17 | b = b.type(torch.DoubleTensor) 18 | 19 | actual_nits = 0 20 | 21 | u = 1.0 * torch.ones(n).type(torch.DoubleTensor) 22 | v = 1.0 * torch.ones(m).type(torch.DoubleTensor) 23 | err = 0.0 24 | 25 | u_trans = torch.matmul(phi_x, torch.matmul(phi_y.t(), v)) + lam 26 | v_trans = torch.matmul(phi_y, torch.matmul(phi_x.t(), u)) + lam 27 | 28 | for k in range(niter_sin): 29 | u = a / u_trans 30 | v_trans = torch.matmul(phi_y, torch.matmul(phi_x.t(), u)) + lam 31 | 32 | v = b / v_trans 33 | u_trans = torch.matmul(phi_x, torch.matmul(phi_y.t(), v)) + lam 34 | 35 | err = torch.sum(torch.abs(u * u_trans - a)) + torch.sum( 36 | torch.abs(v * v_trans - b) 37 | ) 38 | 39 | actual_nits += 1 40 | if err < tau: 41 | break 42 | 43 | if k % 10 == 0: 44 | ### Stpping Criteria ###s 45 | with torch.no_grad(): 46 | err = torch.sum(torch.abs(u * u_trans - a)) + torch.sum( 47 | torch.abs(v * v_trans - b) 48 | ) 49 | if err < tau: 50 | break 51 | 52 | ctx.u = u 53 | ctx.v = v 54 | ctx.reg = reg 55 | ctx.phi_x = phi_x 56 | ctx.phi_y = phi_y 57 | 58 | cost = reg * (torch.sum(a * torch.log(u)) + torch.sum(b * torch.log(v)) - 1) 59 | return cost 60 | 61 | @staticmethod 62 | def backward(ctx, grad_output): 63 | u = ctx.u 64 | v = ctx.v 65 | reg = ctx.reg 66 | phi_x = ctx.phi_x 67 | phi_y = ctx.phi_y 68 | 69 | grad_input = grad_output.clone() 70 | grad_phi_x = ( 71 | grad_input 72 | * torch.matmul(u.view(-1, 1), torch.matmul(phi_y.t(), v).view(1, -1)) 73 | * (-reg) 74 | ) 75 | grad_phi_y = ( 76 | grad_input 77 | * torch.matmul(v.view(-1, 1), torch.matmul(phi_x.t(), u).view(1, -1)) 78 | * (-reg) 79 | ) 80 | 81 | return grad_phi_x, grad_phi_y, None, None, None, None, None 82 | 83 | 84 | def Lin_Sinkhorn( 85 | phi_x, phi_y, reg, niter_sin, device, lam=1e-6, tau=1e-9, stabilize=False 86 | ): 87 | phi_x = phi_x.squeeze().type(torch.DoubleTensor).to(device) 88 | phi_y = phi_y.squeeze().type(torch.DoubleTensor).to(device) 89 | 90 | n = phi_x.size()[0] 91 | m = phi_y.size()[0] 92 | 93 | a = (1.0 / n) * torch.ones(n) 94 | a = a.type(torch.DoubleTensor).to(device) 95 | 96 | b = (1.0 / m) * torch.ones(m) 97 | b = b.type(torch.DoubleTensor).to(device) 98 | 99 | actual_nits = 0 100 | if stabilize == True: 101 | alpha, beta, err = torch.zeros(n).to(device), torch.zeros(m).to(device), 0.0 102 | for i in range(niter_sin): 103 | alpha_res = alpha 104 | beta_res = beta 105 | 106 | lin_M = torch.exp(alpha / reg) * torch.matmul( 107 | phi_x, torch.matmul(phi_y.t(), torch.exp(beta / reg)) 108 | ) 109 | lin_M = lin_M + lam 110 | alpha = reg * (torch.log(a) - torch.log(lin_M)) + alpha 111 | 112 | lin_M_t = torch.exp(beta / reg) * torch.matmul( 113 | phi_y, torch.matmul(phi_x.t(), torch.exp(alpha / reg)) 114 | ) 115 
| lin_M_t = lin_M + lam 116 | beta = reg * (torch.log(b) - torch.log(lin_M_t)) + beta 117 | 118 | err = (alpha - alpha_res).abs().sum() + (beta - beta_res).abs().sum() 119 | 120 | actual_nits += 1 121 | if err < tau: 122 | break 123 | cost = torch.sum(a * alpha) + torch.sum(b * beta) 124 | print(cost) 125 | 126 | else: 127 | u = 1.0 * torch.ones(n).type(torch.DoubleTensor).to(device) 128 | v = 1.0 * torch.ones(m).type(torch.DoubleTensor).to(device) 129 | err = 0.0 130 | 131 | u_trans = torch.matmul(phi_x, torch.matmul(phi_y.t(), v)) + lam 132 | v_trans = torch.matmul(phi_y, torch.matmul(phi_x.t(), u)) + lam 133 | 134 | for k in range(niter_sin): 135 | u = a / u_trans 136 | v_trans = torch.matmul(phi_y, torch.matmul(phi_x.t(), u)) + lam 137 | 138 | v = b / v_trans 139 | u_trans = torch.matmul(phi_x, torch.matmul(phi_y.t(), v)) + lam 140 | 141 | err = torch.sum(torch.abs(u * u_trans - a)) + torch.sum( 142 | torch.abs(v * v_trans - b) 143 | ) 144 | 145 | actual_nits += 1 146 | if err < tau: 147 | break 148 | cost = reg * (torch.sum(a * torch.log(u)) + torch.sum(b * torch.log(v)) - 1) 149 | 150 | return cost 151 | -------------------------------------------------------------------------------- /EXP_GAN/models.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | 5 | # input: batch_size * nc * image_size * image_size 6 | # output: batch_size * num_random_samples * 1 * 1 7 | class Embedding(nn.Module): 8 | def __init__( 9 | self, 10 | isize, 11 | nc, 12 | reg, 13 | device, 14 | q, 15 | C, 16 | U_init, 17 | k=100, 18 | num_random_samples=100, 19 | R=1, 20 | ndf=64, 21 | seed=49, 22 | random=False, 23 | ): 24 | super(Embedding, self).__init__() 25 | assert isize % 16 == 0, "isize has to be a multiple of 16" 26 | 27 | # input is nc x isize x isize 28 | main = nn.Sequential() 29 | main.add_module( 30 | "initial_conv_{0}-{1}".format(nc, ndf), 31 | nn.Conv2d(nc, ndf, 4, 2, 1, bias=False), 32 | ) 33 | main.add_module("initial_relu_{0}".format(ndf), nn.LeakyReLU(0.2, inplace=True)) 34 | csize, cndf = isize / 2, ndf 35 | 36 | while csize > 4: 37 | in_feat = cndf 38 | out_feat = cndf * 2 39 | main.add_module( 40 | "pyramid_{0}-{1}_conv".format(in_feat, out_feat), 41 | nn.Conv2d(in_feat, out_feat, 4, 2, 1, bias=False), 42 | ) 43 | main.add_module( 44 | "pyramid_{0}_batchnorm".format(out_feat), nn.BatchNorm2d(out_feat) 45 | ) 46 | main.add_module( 47 | "pyramid_{0}_relu".format(out_feat), nn.LeakyReLU(0.2, inplace=True) 48 | ) 49 | cndf = cndf * 2 50 | csize = csize / 2 51 | 52 | main.add_module( 53 | "final_{0}-{1}_conv".format(cndf, 1), 54 | nn.Conv2d(cndf, k, 4, 1, 0, bias=False), 55 | ) 56 | 57 | self.main = main 58 | 59 | if random == False: 60 | U = torch.nn.Parameter(U_init) 61 | 62 | else: 63 | U = U_init 64 | 65 | self.U = U.type(torch.DoubleTensor) 66 | self.q = q.type(torch.DoubleTensor) 67 | self.C = C.type(torch.DoubleTensor) 68 | self.reg = reg 69 | self.num_random_samples = num_random_samples 70 | 71 | # X and Y are 2D tensors 72 | def Square_Euclidean_Distance(self, X, Y): 73 | X_col = X.unsqueeze(1).type(torch.DoubleTensor) 74 | Y_lin = Y.unsqueeze(0).type(torch.DoubleTensor) 75 | C = torch.sum((X_col - Y_lin) ** 2, 2) 76 | return C 77 | 78 | # input: batch_size * k * 1 * 1 79 | # output: batch_size * num_random_samples * 1 * 1 80 | def Feature_Map_Gaussian(self, X): 81 | X = X.squeeze() 82 | batch_size, dim = X.size() 83 | 84 | SED = self.Square_Euclidean_Distance(X, self.U) 85 | W = -(2 * SED) / self.reg 86 
| Z = self.U ** 2 87 | A = torch.sum(Z, 1) 88 | a = self.reg * self.q 89 | V = A / a 90 | 91 | res_trans = V + W 92 | res_trans = self.C * torch.exp(res_trans) 93 | 94 | res = ( 95 | 1 / torch.sqrt(torch.DoubleTensor([self.num_random_samples])) 96 | ) * res_trans 97 | res = res.view(batch_size, self.num_random_samples, 1, 1) 98 | 99 | return res 100 | 101 | def forward(self, input): 102 | output = self.main(input) 103 | output = self.Feature_Map_Gaussian(output) 104 | 105 | return output 106 | 107 | 108 | # input: batch_size * k * 1 * 1 109 | # output: batch_size * nc * image_size * image_size 110 | class Generator(nn.Module): 111 | def __init__(self, isize, nc, k=100, ngf=64): 112 | super(Generator, self).__init__() 113 | assert isize % 16 == 0, "isize has to be a multiple of 16" 114 | 115 | cngf, tisize = ngf // 2, 4 116 | while tisize != isize: 117 | cngf = cngf * 2 118 | tisize = tisize * 2 119 | 120 | main = nn.Sequential() 121 | main.add_module( 122 | "initial_{0}-{1}_convt".format(k, cngf), 123 | nn.ConvTranspose2d(k, cngf, 4, 1, 0, bias=False), 124 | ) 125 | main.add_module("initial_{0}_batchnorm".format(cngf), nn.BatchNorm2d(cngf)) 126 | main.add_module("initial_{0}_relu".format(cngf), nn.ReLU(True)) 127 | 128 | csize = 4 129 | while csize < isize // 2: 130 | main.add_module( 131 | "pyramid_{0}-{1}_convt".format(cngf, cngf // 2), 132 | nn.ConvTranspose2d(cngf, cngf // 2, 4, 2, 1, bias=False), 133 | ) 134 | main.add_module( 135 | "pyramid_{0}_batchnorm".format(cngf // 2), nn.BatchNorm2d(cngf // 2) 136 | ) 137 | main.add_module("pyramid_{0}_relu".format(cngf // 2), nn.ReLU(True)) 138 | cngf = cngf // 2 139 | csize = csize * 2 140 | 141 | main.add_module( 142 | "final_{0}-{1}_convt".format(cngf, nc), 143 | nn.ConvTranspose2d(cngf, nc, 4, 2, 1, bias=False), 144 | ) 145 | main.add_module("final_{0}_tanh".format(nc), nn.Tanh()) 146 | 147 | self.main = main 148 | 149 | def forward(self, input): 150 | output = self.main(input) 151 | return output 152 | 153 | 154 | # input: batch_size * nz * 1 * 1 155 | # output: batch_size * nc * image_size * image_size 156 | class NetG(nn.Module): 157 | def __init__(self, decoder): 158 | super(NetG, self).__init__() 159 | self.decoder = decoder 160 | 161 | def forward(self, input): 162 | output = self.decoder(input) 163 | return output 164 | 165 | 166 | # input: batch_size * nc * image_size * image_size 167 | # f_emb: batch_size * k * 1 * 1 168 | class NetE(nn.Module): 169 | def __init__(self, embedding): 170 | super(NetE, self).__init__() 171 | self.embedding = embedding 172 | 173 | def forward(self, input): 174 | f_emb = self.embedding(input) 175 | f_emb = f_emb.view(input.size(0), -1) 176 | 177 | return f_emb 178 | -------------------------------------------------------------------------------- /EXP_GAN/train_models_celebA.py: -------------------------------------------------------------------------------- 1 | import random 2 | from scipy.special import lambertw 3 | import numpy as np 4 | import torch 5 | import timeit 6 | import os 7 | 8 | import torch.backends.cudnn as cudnn 9 | import torch.utils.data 10 | import torchvision.utils as vutils 11 | 12 | 13 | import models 14 | import data_loading 15 | import torch_lin_sinkhorn 16 | 17 | 18 | def compute_constants(reg, device, nz, R=1, num_random_samples=100, seed=49): 19 | q = (1 / 2) + (R ** 2) / reg 20 | y = R ** 2 / (reg * nz) 21 | q = np.real((1 / 2) * np.exp(lambertw(y))) 22 | 23 | C = (2 * q) ** (nz / 4) 24 | 25 | np.random.seed(seed) 26 | var = (q * reg) / 4 27 | U = 
np.random.multivariate_normal( 28 | np.zeros(nz), var * np.eye(nz), num_random_samples 29 | ) 30 | U = torch.from_numpy(U) 31 | 32 | U_init = U.to(device) 33 | C_init = torch.DoubleTensor([C]).to(device) 34 | q_init = torch.DoubleTensor([q]).to(device) 35 | 36 | return q_init, C_init, U_init 37 | 38 | 39 | def training_func( 40 | num_random_samples, 41 | reg, 42 | batch_size, 43 | niter_sin, 44 | image_size, 45 | nc, 46 | nz, 47 | dataset_name, 48 | device, 49 | manual_seed, 50 | lr, 51 | max_iter, 52 | data_root, 53 | R, 54 | random_, 55 | ): 56 | name_dir = "sampled_images_celebA" + "_" + str(num_random_samples) + "_" + str(reg) 57 | if os.path.exists(name_dir) == 0: 58 | os.mkdir(name_dir) 59 | 60 | epsilon = reg 61 | hidden_dim = nz 62 | 63 | # Create an output file 64 | file_to_print = open( 65 | "results_training_celebA" 66 | + "_" 67 | + str(num_random_samples) 68 | + "_" 69 | + str(reg) 70 | + ".csv", 71 | "w", 72 | ) 73 | file_to_print.write(str(device) + "\n") 74 | file_to_print.flush() 75 | 76 | # Fix the seed 77 | np.random.seed(seed=manual_seed) 78 | random.seed(manual_seed) 79 | torch.manual_seed(manual_seed) 80 | torch.cuda.manual_seed(manual_seed) 81 | cudnn.benchmark = True 82 | 83 | # Initialisation of weights 84 | def weights_init(m): 85 | classname = m.__class__.__name__ 86 | if classname.find("Conv") != -1: 87 | m.weight.data.normal_(0.0, 0.02) 88 | elif classname.find("BatchNorm") != -1: 89 | m.weight.data.normal_(1.0, 0.02) 90 | m.bias.data.fill_(0) 91 | elif classname.find("Linear") != -1: 92 | m.weight.data.normal_(0.0, 0.1) 93 | m.bias.data.fill_(0) 94 | 95 | trn_dataset = data_loading.get_data( 96 | image_size, dataset_name, data_root, train_flag=True 97 | ) 98 | trn_loader = torch.utils.data.DataLoader( 99 | trn_dataset, batch_size=batch_size, shuffle=True, num_workers=1 100 | ) 101 | 102 | # construct Generator and Embedding 103 | q, C, U_init = compute_constants( 104 | reg, device, nz, R=R, num_random_samples=num_random_samples, seed=manual_seed 105 | ) 106 | G_generator = models.Generator(image_size, nc, k=nz, ngf=64) 107 | D_embedding = models.Embedding( 108 | image_size, 109 | nc, 110 | reg, 111 | device, 112 | q, 113 | C, 114 | U_init, 115 | k=hidden_dim, 116 | num_random_samples=num_random_samples, 117 | R=R, 118 | seed=manual_seed, 119 | ndf=64, 120 | random=random_, 121 | ) 122 | 123 | netG = models.NetG(G_generator) 124 | netE = models.NetE(D_embedding) 125 | 126 | netG.apply(weights_init) 127 | netE.apply(weights_init) 128 | 129 | netG.to(device) 130 | netE.to(device) 131 | 132 | lin_Sinkhorn_AD = torch_lin_sinkhorn.Lin_Sinkhorn_AD.apply 133 | fixed_noise = torch.DoubleTensor(64, nz, 1, 1).normal_(0, 1).to(device) 134 | one = torch.tensor(1, dtype=torch.float).double() 135 | mone = one * -1 136 | 137 | # setup optimizer 138 | optimizerG = torch.optim.RMSprop(netG.parameters(), lr=lr) 139 | optimizerE = torch.optim.RMSprop(netE.parameters(), lr=lr) 140 | 141 | time = timeit.default_timer() 142 | gen_iterations = 0 143 | 144 | for t in range(max_iter): 145 | data_iter = iter(trn_loader) 146 | i = 0 147 | while i < len(trn_loader): 148 | # --------------------------- 149 | # Optimize over NetE 150 | # --------------------------- 151 | for p in netE.parameters(): 152 | p.requires_grad = True 153 | 154 | if gen_iterations < 25 or gen_iterations % 500 == 0: 155 | Diters = 10 # 10 156 | Giters = 1 157 | else: 158 | Diters = 1 # 5 159 | Giters = 1 160 | 161 | for j in range(Diters): 162 | if i == len(trn_loader): 163 | break 164 | 165 | for p in 
netE.parameters(): 166 | p.data.clamp_(-0.01, 0.01) # clamp parameters of NetE to a cube 167 | 168 | data = data_iter.next() 169 | i += 1 170 | netE.zero_grad() 171 | 172 | x_cpu, _ = data 173 | x = x_cpu.to(device) 174 | x_emb = netE(x) 175 | 176 | noise = torch.FloatTensor(batch_size, nz, 1, 1).normal_(0, 1).to(device) 177 | with torch.no_grad(): 178 | y = netG(noise) 179 | 180 | y_emb = netE(y) 181 | 182 | # Compute the loss 183 | sink_E = ( 184 | 2 * lin_Sinkhorn_AD(x_emb, y_emb, epsilon, niter_sin) 185 | - lin_Sinkhorn_AD(y_emb, y_emb, epsilon, niter_sin) 186 | - lin_Sinkhorn_AD(x_emb, x_emb, epsilon, niter_sin) 187 | ) 188 | 189 | sink_E.backward(mone) 190 | optimizerE.step() 191 | 192 | # --------------------------- 193 | # Optimize over NetG 194 | # --------------------------- 195 | for p in netE.parameters(): 196 | p.requires_grad = False 197 | 198 | for j in range(Giters): 199 | if i == len(trn_loader): 200 | break 201 | 202 | data = data_iter.next() 203 | i += 1 204 | netG.zero_grad() 205 | 206 | x_cpu, _ = data 207 | x = x_cpu.to(device) 208 | x_emb = netE(x) 209 | 210 | noise = torch.FloatTensor(batch_size, nz, 1, 1).normal_(0, 1).to(device) 211 | y = netG(noise) 212 | y_emb = netE(y) 213 | 214 | # Compute the loss 215 | sink_G = ( 216 | 2 * lin_Sinkhorn_AD(x_emb, y_emb, epsilon, niter_sin) 217 | - lin_Sinkhorn_AD(y_emb, y_emb, epsilon, niter_sin) 218 | - lin_Sinkhorn_AD(x_emb, x_emb, epsilon, niter_sin) 219 | ) 220 | 221 | sink_G.backward(one) 222 | optimizerG.step() 223 | 224 | gen_iterations += 1 225 | 226 | run_time = (timeit.default_timer() - time) / 60.0 227 | 228 | s = "[%3d / %3d] [%3d / %3d] [%5d] (%.2f m) loss_E: %.6f loss_G: %.6f" % ( 229 | t, 230 | max_iter, 231 | i * batch_size, 232 | batch_size * len(trn_loader), 233 | gen_iterations, 234 | run_time, 235 | sink_E, 236 | sink_G, 237 | ) 238 | 239 | s = s + "\n" 240 | file_to_print.write(s) 241 | file_to_print.flush() 242 | 243 | if gen_iterations % 100 == 0: 244 | with torch.no_grad(): 245 | fixed_noise = fixed_noise.float() 246 | y_fixed = netG(fixed_noise) 247 | y_fixed = y_fixed.mul(0.5).add(0.5) 248 | vutils.save_image( 249 | y_fixed, 250 | "{0}/fake_samples_{1}.png".format(name_dir, gen_iterations), 251 | ) 252 | 253 | if t % 10 == 0: 254 | torch.save( 255 | netG.state_dict(), 256 | "netG_celebA" + "_" + str(num_random_samples) + "_" + str(reg) + ".pth", 257 | ) 258 | torch.save( 259 | netE.state_dict(), 260 | "netE_celebA" + "_" + str(num_random_samples) + "_" + str(reg) + ".pth", 261 | ) 262 | 263 | 264 | # Dataset 265 | image_size = 64 266 | nc = 3 267 | nz = 128 268 | dataset_name = "celeba" 269 | data_root = "./data" 270 | 271 | # Parameters 272 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 273 | random_ = False 274 | manual_seed = 49 275 | lr = 5 * 1e-5 276 | R = 1 277 | max_iter = 10000 278 | niter_sin = 1000 279 | batch_size = 8000 280 | 281 | 282 | num_random_samples_list = [10, 100, 300, 600] 283 | reg_list = [1e-1, 1, 10] 284 | 285 | 286 | if __name__ == "__main__": 287 | for num_random_samples in num_random_samples_list: 288 | for reg in reg_list: 289 | training_func( 290 | num_random_samples, 291 | reg, 292 | batch_size, 293 | niter_sin, 294 | image_size, 295 | nc, 296 | nz, 297 | dataset_name, 298 | device, 299 | manual_seed, 300 | lr, 301 | max_iter, 302 | data_root, 303 | R, 304 | random_, 305 | ) 306 | -------------------------------------------------------------------------------- /EXP_GAN/train_models_cifar.py: 
-------------------------------------------------------------------------------- 1 | import random 2 | from scipy.special import lambertw 3 | import numpy as np 4 | import torch 5 | import timeit 6 | 7 | import torch.backends.cudnn as cudnn 8 | import torch.utils.data 9 | import torchvision.utils as vutils 10 | 11 | 12 | import models 13 | import data_loading 14 | import torch_lin_sinkhorn 15 | 16 | import os 17 | 18 | 19 | def compute_constants(reg, device, nz, R=1, num_random_samples=100, seed=49): 20 | q = (1 / 2) + (R ** 2) / reg 21 | y = R ** 2 / (reg * nz) 22 | q = np.real((1 / 2) * np.exp(lambertw(y))) 23 | 24 | C = (2 * q) ** (nz / 4) 25 | 26 | np.random.seed(seed) 27 | var = (q * reg) / 4 28 | U = np.random.multivariate_normal( 29 | np.zeros(nz), var * np.eye(nz), num_random_samples 30 | ) 31 | U = torch.from_numpy(U) 32 | 33 | U_init = U.to(device) 34 | C_init = torch.DoubleTensor([C]).to(device) 35 | q_init = torch.DoubleTensor([q]).to(device) 36 | 37 | return q_init, C_init, U_init 38 | 39 | 40 | def training_func( 41 | num_random_samples, 42 | reg, 43 | batch_size, 44 | niter_sin, 45 | image_size, 46 | nc, 47 | nz, 48 | dataset_name, 49 | device, 50 | manual_seed, 51 | lr, 52 | max_iter, 53 | data_root, 54 | R, 55 | random_, 56 | ): 57 | 58 | name_dir = ( 59 | "sampled_images_cifar_max" + "_" + str(num_random_samples) + "_" + str(reg) 60 | ) 61 | if os.path.exists(name_dir) == 0: 62 | os.mkdir(name_dir) 63 | 64 | epsilon = reg 65 | hidden_dim = nz 66 | 67 | # Create an output file 68 | file_to_print = open( 69 | "results_training_cifar_max" 70 | + "_" 71 | + str(num_random_samples) 72 | + "_" 73 | + str(reg) 74 | + ".csv", 75 | "w", 76 | ) 77 | file_to_print.write(str(device) + "\n") 78 | file_to_print.flush() 79 | 80 | # Fix the seed 81 | np.random.seed(seed=manual_seed) 82 | random.seed(manual_seed) 83 | torch.manual_seed(manual_seed) 84 | torch.cuda.manual_seed(manual_seed) 85 | cudnn.benchmark = True 86 | 87 | # Initialisation of weights 88 | def weights_init(m): 89 | classname = m.__class__.__name__ 90 | if classname.find("Conv") != -1: 91 | m.weight.data.normal_(0.0, 0.02) 92 | elif classname.find("BatchNorm") != -1: 93 | m.weight.data.normal_(1.0, 0.02) 94 | m.bias.data.fill_(0) 95 | elif classname.find("Linear") != -1: 96 | m.weight.data.normal_(0.0, 0.1) 97 | m.bias.data.fill_(0) 98 | 99 | trn_dataset = data_loading.get_data( 100 | image_size, dataset_name, data_root, train_flag=True 101 | ) 102 | trn_loader = torch.utils.data.DataLoader( 103 | trn_dataset, batch_size=batch_size, shuffle=True, num_workers=1 104 | ) 105 | 106 | # construct Generator and Embedding: 107 | q, C, U_init = compute_constants( 108 | reg, device, nz, R=R, num_random_samples=num_random_samples, seed=manual_seed 109 | ) 110 | 111 | G_generator = models.Generator(image_size, nc, k=nz, ngf=64) 112 | D_embedding = models.Embedding( 113 | image_size, 114 | nc, 115 | reg, 116 | device, 117 | q, 118 | C, 119 | U_init, 120 | k=hidden_dim, 121 | num_random_samples=num_random_samples, 122 | R=R, 123 | seed=manual_seed, 124 | ndf=64, 125 | random=random_, 126 | ) 127 | 128 | netG = models.NetG(G_generator) 129 | netE = models.NetE(D_embedding) 130 | 131 | netG.apply(weights_init) 132 | netE.apply(weights_init) 133 | 134 | netG.to(device) 135 | netE.to(device) 136 | 137 | lin_Sinkhorn_AD = torch_lin_sinkhorn.Lin_Sinkhorn_AD.apply 138 | fixed_noise = torch.DoubleTensor(64, nz, 1, 1).normal_(0, 1).to(device) 139 | one = torch.tensor(1, dtype=torch.float).double() 140 | mone = one * -1 141 | 142 | # setup 
optimizer 143 | optimizerG = torch.optim.RMSprop(netG.parameters(), lr=lr) 144 | optimizerE = torch.optim.RMSprop(netE.parameters(), lr=lr) 145 | 146 | time = timeit.default_timer() 147 | gen_iterations = 0 148 | 149 | for t in range(max_iter): 150 | data_iter = iter(trn_loader) 151 | i = 0 152 | while i < len(trn_loader): 153 | # --------------------------- 154 | # Optimize over NetE 155 | # --------------------------- 156 | for p in netE.parameters(): 157 | p.requires_grad = True 158 | 159 | if gen_iterations < 25 or gen_iterations % 500 == 0: 160 | Diters = 10 # 10 161 | Giters = 1 162 | else: 163 | Diters = 1 # 5 164 | Giters = 1 165 | 166 | for j in range(Diters): 167 | if i == len(trn_loader): 168 | break 169 | 170 | for p in netE.parameters(): 171 | p.data.clamp_(-0.01, 0.01) # clamp parameters of NetE to a cube 172 | 173 | data = data_iter.next() 174 | i += 1 175 | netE.zero_grad() 176 | 177 | x_cpu, _ = data 178 | x = x_cpu.to(device) 179 | x_emb = netE(x) 180 | 181 | noise = torch.FloatTensor(batch_size, nz, 1, 1).normal_(0, 1).to(device) 182 | with torch.no_grad(): 183 | y = netG(noise) 184 | 185 | y_emb = netE(y) 186 | 187 | ### Compute the loss ### 188 | sink_E = ( 189 | 2 * lin_Sinkhorn_AD(x_emb, y_emb, epsilon, niter_sin) 190 | - lin_Sinkhorn_AD(y_emb, y_emb, epsilon, niter_sin) 191 | - lin_Sinkhorn_AD(x_emb, x_emb, epsilon, niter_sin) 192 | ) 193 | 194 | sink_E.backward(mone) 195 | optimizerE.step() 196 | 197 | # --------------------------- 198 | # Optimize over NetG 199 | # --------------------------- 200 | for p in netE.parameters(): 201 | p.requires_grad = False 202 | 203 | for j in range(Giters): 204 | if i == len(trn_loader): 205 | break 206 | 207 | data = data_iter.next() 208 | i += 1 209 | netG.zero_grad() 210 | 211 | x_cpu, _ = data 212 | x = x_cpu.to(device) 213 | x_emb = netE(x) 214 | 215 | noise = torch.FloatTensor(batch_size, nz, 1, 1).normal_(0, 1).to(device) 216 | y = netG(noise) 217 | y_emb = netE(y) 218 | 219 | # Compute the loss 220 | sink_G = ( 221 | 2 * lin_Sinkhorn_AD(x_emb, y_emb, epsilon, niter_sin) 222 | - lin_Sinkhorn_AD(y_emb, y_emb, epsilon, niter_sin) 223 | - lin_Sinkhorn_AD(x_emb, x_emb, epsilon, niter_sin) 224 | ) 225 | 226 | sink_G.backward(one) 227 | optimizerG.step() 228 | 229 | gen_iterations += 1 230 | 231 | run_time = (timeit.default_timer() - time) / 60.0 232 | 233 | s = "[%3d / %3d] [%3d / %3d] [%5d] (%.2f m) loss_E: %.6f loss_G: %.6f" % ( 234 | t, 235 | max_iter, 236 | i * batch_size, 237 | batch_size * len(trn_loader), 238 | gen_iterations, 239 | run_time, 240 | sink_E, 241 | sink_G, 242 | ) 243 | 244 | s = s + "\n" 245 | file_to_print.write(s) 246 | file_to_print.flush() 247 | 248 | if gen_iterations % 100 == 0: 249 | with torch.no_grad(): 250 | fixed_noise = fixed_noise.float() 251 | y_fixed = netG(fixed_noise) 252 | y_fixed = y_fixed.mul(0.5).add(0.5) 253 | vutils.save_image( 254 | y_fixed, 255 | "{0}/fake_samples_{1}.png".format(name_dir, gen_iterations), 256 | ) 257 | 258 | if t % 10 == 0: 259 | torch.save( 260 | netG.state_dict(), 261 | "netG_cifar_max" 262 | + "_" 263 | + str(num_random_samples) 264 | + "_" 265 | + str(reg) 266 | + ".pth", 267 | ) 268 | torch.save( 269 | netE.state_dict(), 270 | "netE_cifar_max" 271 | + "_" 272 | + str(num_random_samples) 273 | + "_" 274 | + str(reg) 275 | + ".pth", 276 | ) 277 | 278 | 279 | # Dataset 280 | image_size = 64 281 | nc = 3 282 | nz = 128 283 | dataset_name = "cifar10" 284 | data_root = "./data" 285 | 286 | ### Fixed parameters ### 287 | device = torch.device("cuda:0" if 
torch.cuda.is_available() else "cpu") 288 | manual_seed = 49 289 | random_ = False 290 | lr = 5 * 1e-5 291 | R = 1 292 | max_iter = 10000 293 | niter_sin = 1000 294 | batch_size = 8000 295 | 296 | num_random_samples_list = [10, 100, 300, 600] 297 | reg_list = [1e-1, 1, 10] 298 | 299 | if __name__ == "__main__": 300 | for num_random_samples in num_random_samples_list: 301 | for reg in reg_list: 302 | training_func( 303 | num_random_samples, 304 | reg, 305 | batch_size, 306 | niter_sin, 307 | image_size, 308 | nc, 309 | nz, 310 | dataset_name, 311 | device, 312 | manual_seed, 313 | lr, 314 | max_iter, 315 | data_root, 316 | R, 317 | random_, 318 | ) 319 | -------------------------------------------------------------------------------- /FastSinkhorn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy 3 | import time 4 | from scipy import special 5 | 6 | # Here the Regularized version goes from -\infinty to the true OT 7 | def compute_ROT(u, v, a, b, reg): 8 | res = reg * (np.dot(a, np.log(u)) + np.dot(b, np.log(v))) 9 | return res 10 | 11 | 12 | ################ Classical Sinkhorn Algorithm #################### 13 | 14 | def Sinkhorn(C, reg, a, b, delta=1e-9, lam=1e-6): 15 | 16 | K = np.exp(-C / reg) 17 | u = np.ones(np.shape(a)[0]) 18 | v = np.ones(np.shape(b)[0]) 19 | 20 | u_trans = np.dot(K, v) + lam # add regularization to avoid divide 0 21 | v_trans = np.dot(K.T, u) + lam # add regularization to avoid divide 0 22 | 23 | err_1 = np.sum(np.abs(u * u_trans - a)) 24 | err_2 = np.sum(np.abs(v * v_trans - b)) 25 | 26 | while True: 27 | if err_1 + err_2 > delta: 28 | u = a / u_trans 29 | v_trans = np.dot(K.T, u) + lam 30 | 31 | v = b / v_trans 32 | u_trans = np.dot(K, v) + lam 33 | 34 | err_1 = np.sum(np.abs(u * u_trans - a)) 35 | err_2 = np.sum(np.abs(v * v_trans - b)) 36 | else: 37 | return u, v 38 | 39 | 40 | # Classical Sinkhorn algorithm: Square Euclidean Distance 41 | def Sinkhorn_RBF(X, Y, reg, a, b, delta=1e-9, num_iter=50, lam=1e-100): 42 | start = time.time() 43 | 44 | acc = [] 45 | times = [] 46 | 47 | C = Square_Euclidean_Distance(X, Y) 48 | K = np.exp(-C / reg) 49 | u = np.ones(np.shape(a)[0]) 50 | v = np.ones(np.shape(b)[0]) 51 | 52 | u_trans = np.dot(K, v) + lam 53 | v_trans = np.dot(K.T, u) + lam 54 | 55 | for k in range(num_iter): 56 | 57 | u = a / u_trans 58 | v_trans = np.dot(K.T, u) + lam 59 | 60 | v = b / v_trans 61 | u_trans = np.dot(K, v) + lam 62 | 63 | ROT_trans = compute_ROT(u, v, a, b, reg) 64 | if np.isnan(ROT_trans) == True: 65 | return "Error" 66 | else: 67 | acc.append(compute_ROT(u, v, a, b, reg)) 68 | end = time.time() 69 | times.append(end - start) 70 | 71 | return acc[-1], np.array(acc), np.array(times) 72 | 73 | 74 | ################ Positive Random Features #################### 75 | 76 | # Positive Random Features Sinkhorn: K = AB 77 | def Lin_Sinkhorn(A, B, a, b, delta=1e-9, max_iter=1e5, lam=1e-100): 78 | u = np.ones(np.shape(a)[0]) 79 | v = np.ones(np.shape(b)[0]) 80 | u_trans = np.dot(A, np.dot(B, v)) + lam 81 | v_trans = np.dot(B.T, np.dot(A.T, u)) + lam 82 | 83 | err_1 = np.sum(np.abs(u * u_trans - a)) 84 | err_2 = np.sum(np.abs(v * v_trans - b)) 85 | k = 0 86 | while True and k < max_iter: 87 | if err_1 + err_2 > delta: 88 | u = a / u_trans 89 | v_trans = np.dot(B.T, np.dot(A.T, u)) + lam 90 | 91 | v = b / v_trans 92 | u_trans = np.dot(A, np.dot(B, v)) + lam 93 | 94 | err_1 = np.sum(np.abs(u * u_trans - a)) 95 | err_2 = np.sum(np.abs(v * v_trans - b)) 96 | k = k + 1 97 | else: 98 
| return u, v 99 | return u, v 100 | 101 | 102 | # Positive Random Features Sinkhorn: Square Euclidean Distance 103 | def Lin_Sinkhorn_RBF( 104 | X, Y, reg, a, b, num_samples, seed=49, delta=1e-9, num_iter=50, lam=1e-100 105 | ): 106 | start = time.time() 107 | 108 | acc = [] 109 | times = [] 110 | 111 | R = theoritical_R(X, Y) 112 | A = Feature_Map_Gaussian(X, reg, R=R, num_samples=num_samples, seed=seed) 113 | B = Feature_Map_Gaussian(Y, reg, R=R, num_samples=num_samples, seed=seed).T 114 | 115 | u = np.ones(np.shape(a)[0]) 116 | v = np.ones(np.shape(b)[0]) 117 | u_trans = np.dot(A, np.dot(B, v)) + lam 118 | v_trans = np.dot(B.T, np.dot(A.T, u)) + lam 119 | 120 | for k in range(num_iter): 121 | u = a / u_trans 122 | v_trans = np.dot(B.T, np.dot(A.T, u)) + lam 123 | 124 | v = b / v_trans 125 | u_trans = np.dot(A, np.dot(B, v)) + lam 126 | 127 | ROT_trans = compute_ROT(u, v, a, b, reg) 128 | if np.isnan(ROT_trans) == True: 129 | return "Error" 130 | else: 131 | acc.append(ROT_trans) 132 | end = time.time() 133 | times.append(end - start) 134 | 135 | return acc[-1], np.array(acc), np.array(times) 136 | 137 | 138 | # Random Feature Map: Square Euclidean Distance 139 | def Feature_Map_Gaussian(X, reg, R=1, num_samples=100, seed=49): 140 | n, d = np.shape(X) 141 | 142 | # q = (1/2) + (R**2) / reg 143 | y = R ** 2 / (reg * d) 144 | q = np.real((1 / 2) * np.exp(special.lambertw(y))) 145 | C = (2 * q) ** (d / 4) 146 | 147 | var = (q * reg) / 4 148 | 149 | np.random.seed(seed) 150 | U = np.random.multivariate_normal(np.zeros(d), var * np.eye(d), num_samples) 151 | 152 | SED = Square_Euclidean_Distance(X, U) 153 | W = -(2 * SED) / reg 154 | V = np.sum(U ** 2, axis=1) / (reg * q) 155 | 156 | res_trans = V + W 157 | res_trans = C * np.exp(res_trans) 158 | 159 | res = (1 / np.sqrt(num_samples)) * res_trans 160 | 161 | return res 162 | 163 | 164 | def theoritical_R(X, Y): 165 | norm_X = np.linalg.norm(X, axis=1) 166 | norm_Y = np.linalg.norm(Y, axis=1) 167 | norm_max = np.maximum(np.max(norm_X), np.max(norm_Y)) 168 | 169 | return norm_max 170 | 171 | 172 | # Random Feature Map: Arccos Kernel 173 | def Feature_Map_Arccos(X, s=1, sig=1.5, num_samples=100, kappa=1e-6, seed=49): 174 | n, d = np.shape(X) 175 | C = (sig ** (d / 2)) * np.sqrt(2) 176 | 177 | np.random.seed(seed) 178 | U = np.random.multivariate_normal(np.zeros(d), (sig ** 2) * np.eye(d), num_samples) 179 | 180 | IP = Inner_Product(X, U) 181 | res_trans = C * (np.maximum(IP, 0) ** s) 182 | 183 | V = ((sig ** 2) - 1) / (sig ** 2) 184 | V = -(1 / 4) * V * np.sum(U ** 2, axis=1) 185 | V = np.exp(V) 186 | 187 | res = np.zeros((n, num_samples + 1)) 188 | res[:, :num_samples] = (1 / np.sqrt(num_samples)) * res_trans * V 189 | res[:, -1] = kappa 190 | 191 | return res 192 | 193 | 194 | ######################## Nystrom Method ####################### 195 | 196 | # Nystrom Sinkhorn: K =VA^{-1}V 197 | def Nys_Sinkhorn(A, V, a, b, delta=1e-9, max_iter=1e3, lam=1e-100): 198 | u = np.ones(np.shape(a)[0]) 199 | v = np.ones(np.shape(b)[0]) 200 | 201 | u_trans = np.dot(V, np.linalg.solve(A, np.dot(V.T, v))) + lam 202 | v_trans = np.dot(V, np.linalg.solve(A, np.dot(V.T, u))) + lam 203 | 204 | err_1 = np.sum(np.abs(u * u_trans - a)) 205 | err_2 = np.sum(np.abs(v * v_trans - b)) 206 | k = 0 207 | while True and k < max_iter: 208 | if err_1 + err_2 > delta: 209 | u = a / u_trans 210 | v_trans = np.dot(V, np.linalg.solve(A, np.dot(V.T, u))) + lam 211 | 212 | v = b / v_trans 213 | u_trans = np.dot(V, np.linalg.solve(A, np.dot(V.T, v))) + lam 214 | 215 | err_1 = 
np.sum(np.abs(u * u_trans - a)) 216 | err_2 = np.sum(np.abs(v * v_trans - b)) 217 | k = k + 1 218 | else: 219 | return u, v 220 | return u, v 221 | 222 | 223 | # Nystrom Sinkhorn: Square Euclidean Distance 224 | def Nys_Sinkhorn_RBF( 225 | X, Y, reg, a, b, rank, seed=49, delta=1e-9, num_iter=50, lam=1e-100 226 | ): 227 | start = time.time() 228 | 229 | acc = [] 230 | times = [] 231 | 232 | n = np.shape(X)[0] 233 | m = np.shape(Y)[0] 234 | 235 | a_nys = np.zeros(n + m) 236 | a_nys[:n] = a 237 | 238 | b_nys = np.zeros(n + m) 239 | b_nys[n:] = b 240 | 241 | A, V = Nystrom_RBF(X, Y, reg, rank, seed=seed, stable=1e-10) 242 | A_inv = np.linalg.inv(A) 243 | 244 | u = np.ones(np.shape(a_nys)[0]) 245 | v = np.ones(np.shape(b_nys)[0]) 246 | 247 | u_trans = np.dot(V, np.dot(A_inv, np.dot(V.T, v))) + lam 248 | v_trans = np.dot(V, np.dot(A_inv, np.dot(V.T, u))) + lam 249 | 250 | for k in range(num_iter): 251 | 252 | u = a_nys / u_trans 253 | v_trans = np.dot(V, np.dot(A_inv, np.dot(V.T, u))) + lam 254 | 255 | v = b_nys / v_trans 256 | u_trans = np.dot(V, np.dot(A_inv, np.dot(V.T, v))) + lam 257 | 258 | u_rot, v_rot = u[:n], v[n:] 259 | 260 | ROT_trans = compute_ROT(u_rot, v_rot, a, b, reg) 261 | if np.isnan(ROT_trans) == True: 262 | return "Error" 263 | else: 264 | acc.append(ROT_trans) 265 | end = time.time() 266 | times.append(end - start) 267 | 268 | return acc[-1], np.array(acc), np.array(times) 269 | 270 | 271 | # Uniform Nyström: Square Euclidean Distance 272 | def Nystrom_RBF(X, Y, reg, rank, seed=49, stable=1e-100): 273 | n, d = np.shape(X) 274 | m, d = np.shape(Y) 275 | n_tot = n + m 276 | Z = np.concatenate((X, Y), axis=0) 277 | 278 | rank_trans = int(np.minimum(rank, n_tot)) 279 | 280 | np.random.seed(seed) 281 | ind = np.random.choice(n_tot, rank_trans, replace=False) 282 | ind = np.sort(ind) 283 | 284 | Z_1 = Z[ind, :] 285 | A = np.exp(-Square_Euclidean_Distance(Z_1, Z_1) / reg) 286 | A = A + stable * np.eye(rank_trans) 287 | V = np.exp(-Square_Euclidean_Distance(Z, Z_1) / reg) 288 | 289 | return A, V 290 | 291 | 292 | # Recursive Nyström Sampling: Square Euclidean Distance 293 | def recursive_Nystrom_RBF(X, Y, rank, reg, seed=49, stable=1e-100): 294 | Z = np.concatenate((X, Y), axis=0) 295 | n, d = np.shape(Z) 296 | 297 | ## Start of algorithm 298 | sLevel = rank 299 | oversamp = np.log(sLevel) 300 | k = int(sLevel / (4 * oversamp)) + 1 301 | nLevels = int(np.log(n / sLevel) / np.log(2)) + 1 302 | 303 | np.random(seed) 304 | perm = np.random.permutation(n) 305 | 306 | # set up sizes for recursive levels 307 | lSize = np.zeros(nLevels) 308 | lSize[0] = n 309 | for i in range(1, nLevels): 310 | lSize[i] = int(lSize[i - 1] / 2) + 1 311 | 312 | # rInd: indices of points selected at previous level of recursion 313 | # at the base level it's just a uniform sample of ~sLevel points 314 | samp = np.arange(lSize[-1]).astype(int) 315 | rInd = perm[samp] 316 | weights = np.ones((np.shape(rInd)[0], 1)) 317 | 318 | # we need the diagonal of the whole kernel matrix 319 | kDiag = np.zeros(n) 320 | for i in range(n): 321 | kDiag[i] = np.exp(-Square_Euclidean_Distance(Z[i, :], Z[i, :]) / reg) 322 | 323 | # Main recursion, unrolled for efficiency 324 | for l in range(nLevels - 1, -1, -1): 325 | np.random(seed + l) 326 | # indices of current uniform sample 327 | rIndCurr = perm[: int(lSize[l])] 328 | # build sampled kernel 329 | SED = Square_Euclidean_Distance(Z[rIndCurr, :], Z[rInd, :]) 330 | KS = np.exp(-SED / reg) 331 | SKS = KS[samp, :] 332 | SKSn = np.shape(SKS)[0] 333 | 334 | # optimal lambda for 
taking O(klogk) samples 335 | if k >= SKSn: 336 | # for the rare chance we take less than k samples in a round 337 | lam = 10e-6 338 | # don't set to exactly 0 to avoid stability issues 339 | else: 340 | ###### 341 | Q = np.diag(weights.reshape(SKSn)) 342 | Q = np.dot(Q, SKS) 343 | Oper = Q * weights.reshape(SKSn, 1) 344 | eigen = np.sort(np.linalg.eig(Oper)[1])[-k:] 345 | lam = ( 346 | np.sum(np.diag(SKS) * (weights ** 2)) - np.sum(np.abs(np.real(eigen))) 347 | ) / k 348 | 349 | # compute and sample by lambda ridge leverage scores 350 | if l != 0: 351 | # on intermediate levels, we independently sample each column 352 | # by its leverage score. the sample size is sLevel in expectation 353 | R = np.linalg.inv(SKS + np.diag(np.dot(lam, weights ** (-2)))) 354 | # max(0,.) helps avoid numerical issues, unnecessary in theory 355 | z = np.sum(np.dot(KS, R) * KS, 1) 356 | z = kDiag[rIndCurr] - z 357 | z = np.maximum(0, z) 358 | z = oversamp * (1 / lam) * z 359 | levs = np.minimum(1, z) 360 | 361 | M = np.random.rand(1, int(lSize[l])) - levs 362 | ind_matrix = M < 0 363 | ind_matrix = ind_matrix.reshape(int(lSize[l])) 364 | samp = np.where(ind_matrix == 1)[0] 365 | # with very low probability, we could accidentally sample no 366 | # columns. In this case, just take a fixed size uniform sample. 367 | samp_list = np.ndarray.tolist(samp) 368 | if len(samp_list) == 0: 369 | levs[:] = sLevel / lSize[l] 370 | samp = np.random.choice(int(lSize[l]), sLevel, replace=False) 371 | 372 | weights = np.sqrt(1.0 / (levs[samp])) 373 | 374 | else: 375 | # on the top level, we sample exactly s landmark points without replacement 376 | R = np.linalg.inv(SKS + np.diag(np.dot(lam, weights ** (-2)))) 377 | z = np.sum(np.dot(KS, R) * KS, 1) 378 | z = kDiag[rIndCurr] - z 379 | z = np.maximum(0, z) 380 | levs = np.minimum(1, (1 / lam) * z) 381 | ######## 382 | total_sum = np.sum(levs) 383 | levs_norm = levs / total_sum 384 | samp = np.random.choice( 385 | np.shape(levs)[0], rank, replace=False, p=levs_norm.reshape(-1) 386 | ) 387 | 388 | rInd = perm[samp] 389 | 390 | # build final Nystrom approximation 391 | # pinv or inversion with slight regularization helps stability 392 | V = np.exp(-Square_Euclidean_Distance(Z, Z[rInd, :]) / reg) 393 | A = V[rInd, :] 394 | A = A + stable * np.eye(rank) 395 | # A_inv = np.linalg.inv(A) 396 | 397 | return A, V 398 | 399 | 400 | # Adaptative Rank Nystrom: Square Euclidean Distance 401 | def Adaptive_Nystrom_RBF(X, Y, reg, tau=1e-1, seed=49): 402 | err = 1e30 403 | r = 1 404 | while err > tau: 405 | r = 2 * r 406 | A, V = Nystrom_RBF(X, Y, reg, r) 407 | 408 | diag = np.zeros(np.shape(A)[0]) 409 | for i in range(np.shape(A)[0]): 410 | M = np.dot(V, A) 411 | diag[i] = np.dot(M[i, :], V.T[:, i]) 412 | 413 | err = 1 - np.min(diag) 414 | 415 | return A, V 416 | 417 | 418 | # Square Euclidean Distance 419 | def Square_Euclidean_Distance(X, Y): 420 | X_col = X[:, np.newaxis] 421 | Y_lin = Y[np.newaxis, :] 422 | C = np.sum((X_col - Y_lin) ** 2, 2) 423 | return C 424 | 425 | 426 | # Arccos Cost 427 | def Arccos_Cost(X, Y, s=1, kappa=1e-6): 428 | if len(np.shape(X)) == 1: 429 | X = X.reshape(1, -1) 430 | 431 | if len(np.shape(Y)) == 1: 432 | Y = Y.reshape(1, -1) 433 | 434 | n, d = np.shape(X) 435 | m, d = np.shape(Y) 436 | M = np.zeros((n, m)) 437 | for i in range(n): 438 | for j in range(m): 439 | norm = np.linalg.norm(X[i, :]) * np.linalg.norm(Y[j, :]) 440 | theta = np.arccos(Inner_Product(X[i, :], Y[j, :]) / norm) 441 | if s == 0: 442 | M[i, j] = (1 / np.pi) * (np.pi - theta) 443 | if s == 1: 
444 | J = np.sin(theta) + (np.pi - theta) * np.cos(theta) 445 | M[i, j] = (1 / np.pi) * norm * J 446 | if s == 2: 447 | J = 3 * np.sin(theta) * np.cos(theta) + (np.pi - theta) * ( 448 | 1 + 2 * np.cos(theta) ** 2 449 | ) 450 | M[i, j] = (1 / np.pi) * (norm ** 2) * J 451 | 452 | M = M + kappa 453 | M = -np.log(M) 454 | return M 455 | 456 | 457 | # Inner Product Cost 458 | def Inner_Product(X, Y): 459 | if len(np.shape(X)) == 1: 460 | X = X.reshape(1, -1) 461 | 462 | if len(np.shape(Y)) == 1: 463 | Y = Y.reshape(1, -1) 464 | 465 | n, d = np.shape(X) 466 | m, d = np.shape(Y) 467 | M = np.zeros((n, m)) 468 | for i in range(n): 469 | for j in range(m): 470 | M[i, j] = np.sum(X[i, :] * Y[j, :]) 471 | return M 472 | --------------------------------------------------------------------------------
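Closing usage note (not part of the repository): `Lin_Sinkhorn_RBF` above builds its factorization with `theoritical_R` and `Feature_Map_Gaussian` and then runs the same updates as `Lin_Sinkhorn`; the hedged sketch below spells that pipeline out explicitly, which is also where an alternative positive feature map such as `Feature_Map_Arccos` could be plugged in. The data and parameter values are assumptions made for the example.

```python
import numpy as np
import FastSinkhorn as fs

rng = np.random.RandomState(2)
n, m, d = 300, 300, 3
X = rng.randn(n, d); X /= np.linalg.norm(X, axis=1, keepdims=True)
Y = rng.randn(m, d); Y /= np.linalg.norm(Y, axis=1, keepdims=True)
a, b = np.ones(n) / n, np.ones(m) / m
reg = 1.0

# Positive feature factorization of the Gibbs kernel: exp(-||x - y||^2 / reg) ~ A @ B.
R = fs.theoritical_R(X, Y)
A = fs.Feature_Map_Gaussian(X, reg, R=R, num_samples=200, seed=49)
B = fs.Feature_Map_Gaussian(Y, reg, R=R, num_samples=200, seed=49).T

# Sinkhorn iterations that only ever touch A and B, hence linear in n and m.
u, v = fs.Lin_Sinkhorn(A, B, a, b, max_iter=2000)

# Regularized OT value for these scalings (see compute_ROT above).
print(fs.compute_ROT(u, v, a, b, reg))
```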