├── requirements.txt
├── EXP_GAN
├── .DS_Store
├── data_loading.py
├── load_models.py
├── download_celebA.py
├── torch_lin_sinkhorn.py
├── models.py
├── train_models_celebA.py
└── train_models_cifar.py
├── results
├── celebA_samples.png
├── cifar10_samples.png
└── plot_accuracy_ROT_sphere.jpg
├── .gitignore
├── README.md
└── FastSinkhorn.py
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.5.0
2 | scipy==1.5.0
3 | Pillow==7.0.0
4 |
--------------------------------------------------------------------------------
/EXP_GAN/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/meyerscetbon/LinearSinkhorn/HEAD/EXP_GAN/.DS_Store
--------------------------------------------------------------------------------
/results/celebA_samples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/meyerscetbon/LinearSinkhorn/HEAD/results/celebA_samples.png
--------------------------------------------------------------------------------
/results/cifar10_samples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/meyerscetbon/LinearSinkhorn/HEAD/results/cifar10_samples.png
--------------------------------------------------------------------------------
/results/plot_accuracy_ROT_sphere.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/meyerscetbon/LinearSinkhorn/HEAD/results/plot_accuracy_ROT_sphere.jpg
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # Compiled source #
7 | ###################
8 | *.com
9 | *.class
10 | *.dll
11 | *.exe
12 | *.o
13 | *.so
14 |
15 | # Packages #
16 | ############
17 | # it's better to unpack these files and commit the raw source
18 | # git has its own built in compression methods
19 | *.7z
20 | *.dmg
21 | *.gz
22 | *.iso
23 | *.jar
24 | *.rar
25 | *.tar
26 | *.zip
27 |
28 | # Logs and databases #
29 | ######################
30 | *.log
31 | *.sql
32 | *.sqlite
33 |
34 | # OS generated files #
35 | ######################
36 | .DS_Store
37 | .DS_Store?
38 | ._*
39 | .Spotlight-V100
40 | .Trashes
41 | ehthumbs.db
42 | Thumbs.db
43 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Linear Time Sinkhorn Divergences using Positive Features
2 | Code for the paper by Meyer Scetbon and Marco Cuturi.
3 |
4 | ## Approximation of the Regularized Optimal Transport in Linear Time
5 | In this work, we show that the regularized optimal transport can be approximated in time linear in the number of samples for some common cost functions, e.g. the squared Euclidean distance. We present the time-accuracy tradeoff between different methods for computing the regularized OT when the samples live on the unit sphere.
6 | 
7 |
8 | The implementation of the recursive Nystrom is adapted from the MATLAB implementation (https://github.com/cnmusco/recursive-nystrom).
9 |
10 | ## Generative Adversarial Network
11 | We also show that our method offers a constructive way to build a kernel and then a cost function adapted to the problem in order to compare distributions using optimal transport. We show some visual results of the generative models learned using our method on CIFAR10 (left) and CelebA (right).
12 |
13 |
14 |
15 |
16 |
17 |
18 | The WGAN implementation is adapted from the MMD-GAN implementation (https://github.com/OctoberChang/MMD-GAN).
19 |
20 |
21 |
22 | This repository contains a Python implementation of the algorithms presented in the [paper](https://arxiv.org/pdf/2006.07057.pdf).
23 |
--------------------------------------------------------------------------------
/EXP_GAN/data_loading.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch.utils.data as data
3 | from PIL import Image
4 | from os import listdir
5 | from os.path import join
6 |
7 | import torchvision.transforms as transforms
8 | import torchvision.datasets as dset
9 |
10 |
11 | ### Get Data ###
def is_image_file(filename):
    """Return True when *filename* ends with a recognised image extension.

    The comparison is case-sensitive; callers wanting a case-insensitive
    match must lower-case the name first.
    """
    image_suffixes = (".png", ".jpg", ".jpeg")
    return filename.endswith(image_suffixes)
14 |
15 |
def load_img(filepath):
    """Open the image stored at *filepath* and return an RGB copy of it.

    The source image is opened in a ``with`` block so that its file handle
    is released as soon as the converted copy has been produced, instead of
    being left for the garbage collector.
    """
    with Image.open(filepath) as source:
        return source.convert("RGB")
19 |
20 |
class FolderWithImages(data.Dataset):
    """Dataset serving every image file found directly inside one folder.

    Each item is an ``(input, target)`` pair built from the same image:
    ``input`` goes through ``input_transform`` and ``target`` (a copy of the
    loaded image) goes through ``target_transform`` when these are given.
    """

    def __init__(self, root, input_transform=None, target_transform=None):
        super(FolderWithImages, self).__init__()
        # listdir() returns entries in arbitrary order; sorting makes the
        # index -> file mapping identical from run to run.
        self.image_filenames = sorted(
            join(root, name) for name in listdir(root) if is_image_file(name.lower())
        )

        self.input_transform = input_transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        # "image" rather than "input" so the builtin is not shadowed.
        image = load_img(self.image_filenames[index])
        target = image.copy()
        if self.input_transform:
            image = self.input_transform(image)
        if self.target_transform:
            target = self.target_transform(target)

        return image, target

    def __len__(self):
        return len(self.image_filenames)
43 |
44 |
class ALICropAndScale(object):
    """Callable transform: resize to 64x78, then crop rows 7..71 to get 64x64."""

    def __call__(self, img):
        # Image.LANCZOS is the filter that the Image.ANTIALIAS alias pointed
        # to; the alias itself no longer exists in recent Pillow releases.
        return img.resize((64, 78), Image.LANCZOS).crop((0, 7, 64, 64 + 7))
48 |
49 |
def get_data(image_size, dataset_name, data_root, train_flag=True):
    """Build the dataset object for ``dataset_name``.

    Args:
        image_size: side length images are resized / cropped to. Must be 64
            for CelebA.
        dataset_name: one of ``"cifar10"``, ``"mnist"`` or ``"celeba"``.
        data_root: directory the data is stored in (CIFAR10 / MNIST are
            downloaded there when missing).
        train_flag: select the training split when True, otherwise the
            held-out split.

    Returns:
        The constructed dataset.

    Raises:
        ValueError: for an unknown ``dataset_name`` or a CelebA image size
            different from 64.
    """
    if dataset_name == "cifar10":
        transform = transforms.Compose(
            [
                transforms.Resize(image_size),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )
        dataset = dset.CIFAR10(
            root=data_root, download=True, train=train_flag, transform=transform
        )

    elif dataset_name == "mnist":
        transform = transforms.Compose(
            [
                transforms.Resize(image_size),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
                transforms.Normalize((0.5,), (0.5,)),
            ]
        )
        dataset = dset.MNIST(
            root=data_root, download=True, train=train_flag, transform=transform
        )

    elif dataset_name == "celeba":
        if image_size != 64:
            raise ValueError("the image size for CelebA dataset need to be 64!")

        # download_celebA.py creates the held-out split under "splits/valid".
        imdir = "CelebA/splits/train" if train_flag else "CelebA/splits/valid"
        dataroot = os.path.join(data_root, imdir)

        dataset = FolderWithImages(
            root=dataroot,
            input_transform=transforms.Compose(
                [
                    ALICropAndScale(),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]
            ),
            target_transform=transforms.ToTensor(),
        )

    else:
        # Previously fell through to "return dataset" with the name unbound.
        raise ValueError("unknown dataset name: {!r}".format(dataset_name))

    return dataset
99 |
--------------------------------------------------------------------------------
/EXP_GAN/load_models.py:
--------------------------------------------------------------------------------
1 | import random
2 | from scipy.special import lambertw
3 | import numpy as np
4 | import torch
5 | import torch.backends.cudnn as cudnn
6 | import torch.utils.data
7 | import torchvision.utils as vutils
8 |
9 | import models
10 |
11 |
12 | # CIFAR10
13 | # image_size = 64
14 | # nc = 3
15 | # nz = 128
16 | # dataset_name = 'cifar10'
17 |
18 | # CELEBA
dataset_name = "celeba"
image_size, nc, nz = 64, 3, 128

# Use the first GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(torch.cuda.is_available(), device)

manual_seed = 49
random_ = False
R = 1
batch_size = 8000

num_random_samples = 600
reg = 1

# Aliases used when building the models below.
epsilon = reg
hidden_dim = nz

# Fix the seed of every random number generator in use.
random.seed(manual_seed)
np.random.seed(manual_seed)
torch.manual_seed(manual_seed)
torch.cuda.manual_seed(manual_seed)
cudnn.benchmark = True
46 |
47 |
def compute_constants(reg, device, nz, R=1, num_random_samples=100, seed=49):
    """Compute the constants and anchor points used by the Gaussian feature map.

    Args:
        reg: regularisation strength.
        device: torch device the returned tensors are placed on.
        nz: dimension of the space the anchors are drawn in.
        R: radius entering ``y = R**2 / (reg * nz)``.
        num_random_samples: number of anchor points to draw.
        seed: value used to reseed NumPy's global random generator.

    Returns:
        ``(q, C, U)``: ``q`` and ``C`` are float64 tensors of shape ``(1,)``
        and ``U`` is a float64 tensor of shape ``(num_random_samples, nz)``.

    Note:
        Reseeds the global NumPy generator as a side effect.
    """
    # q = exp(W(y)) / 2 with W the Lambert W function. The former first
    # assignment ``q = 1/2 + R**2/reg`` was overwritten before any use and
    # has been removed.
    y = R ** 2 / (reg * nz)
    q = np.real((1 / 2) * np.exp(lambertw(y)))

    C = (2 * q) ** (nz / 4)

    # Anchors are drawn from N(0, (q * reg / 4) * I).
    np.random.seed(seed)
    var = (q * reg) / 4
    U = np.random.multivariate_normal(
        np.zeros(nz), var * np.eye(nz), num_random_samples
    )

    U_init = torch.from_numpy(U).to(device)
    C_init = torch.tensor([C], dtype=torch.float64, device=device)
    q_init = torch.tensor([q], dtype=torch.float64, device=device)

    return q_init, C_init, U_init
67 |
68 |
# compute_constants already places its outputs on `device`, so no further
# transfer is needed here.
q, C, U_init = compute_constants(
    reg, device, nz, R=R, num_random_samples=num_random_samples, seed=manual_seed
)

G_generator = models.Generator(image_size, nc, k=nz, ngf=64)
D_embedding = models.Embedding(
    image_size,
    nc,
    reg,
    device,
    q,
    C,
    U_init,
    k=hidden_dim,
    num_random_samples=num_random_samples,
    R=R,
    seed=manual_seed,
    ndf=64,
    random=random_,
)

# map_location=device lets the same script run on CPU-only machines and on
# GPU nodes without editing the load call.
netG = models.NetG(G_generator)
path_model_G = "netG_celebA_600_1.pth"
netG.load_state_dict(torch.load(path_model_G, map_location=device))
netG.to(device)

# Load the embedding saved alongside the CelebA generator (the training script
# writes "netE_celebA_<samples>_<reg>.pth"); the CIFAR checkpoint name used
# before did not match the generator loaded above.
netE = models.NetE(D_embedding)
path_model_E = "netE_celebA_600_1.pth"
netE.load_state_dict(torch.load(path_model_E, map_location=device))
netE.to(device)


# Choose a random seed to sample a random image from the generator
manual_seed = 123
np.random.seed(seed=manual_seed)
random.seed(manual_seed)
torch.manual_seed(manual_seed)
torch.cuda.manual_seed(manual_seed)
cudnn.benchmark = True


batch_size_noise = 32
# Drawn in double precision and then cast, so the sampled values stay the
# same as before for a given seed.
fixed_noise = torch.DoubleTensor(batch_size_noise, nz, 1, 1).normal_(0, 1).to(device)
fixed_noise = fixed_noise.float()

# Only a forward pass is needed, so skip building the autograd graph.
with torch.no_grad():
    y_fixed = netG(fixed_noise)  # between -1 and 1

# A sample from the trained model, rescaled from [-1, 1] to [0, 1]
y_fixed = y_fixed.mul(0.5).add(0.5)
vutils.save_image(y_fixed, "celebA_image_vf_600.png")
123 |
--------------------------------------------------------------------------------
/EXP_GAN/download_celebA.py:
--------------------------------------------------------------------------------
1 | """
2 | Modification of
3 | - https://github.com/carpedm20/DCGAN-tensorflow/blob/master/download.py
4 | - http://stackoverflow.com/a/39225039
5 | """
6 | from __future__ import print_function
7 | import os
8 | import zipfile
9 | import requests
10 |
11 |
def download_file_from_google_drive(id, destination):
    """Download the Google Drive file ``id`` and write it to ``destination``.

    A first request is made to obtain the confirmation token from the response
    cookies; when one is present the request is repeated with that token.

    Raises:
        requests.HTTPError: when the final response carries an error status.
    """
    URL = "https://docs.google.com/uc?export=download"

    # One session keeps the cookies between the two requests; the ``with``
    # block closes it once the download is finished.
    with requests.Session() as session:
        response = session.get(URL, params={"id": id}, stream=True)
        token = get_confirm_token(response)

        if token:
            params = {"id": id, "confirm": token}
            response = session.get(URL, params=params, stream=True)

        # Fail loudly instead of saving an error page under the target name.
        response.raise_for_status()
        save_response_content(response, destination)
24 |
25 |
def get_confirm_token(response):
    """Return the value of the first ``download_warning*`` cookie, else None."""
    warning_values = (
        cookie_value
        for cookie_name, cookie_value in response.cookies.items()
        if cookie_name.startswith("download_warning")
    )
    return next(warning_values, None)
31 |
32 |
def save_response_content(response, destination, chunk_size=32 * 1024):
    """Stream the body of ``response`` into the file ``destination``.

    Args:
        response: object exposing ``headers`` (a mapping) and
            ``iter_content(chunk_size)`` yielding byte chunks.
        destination: path of the file to write (opened in binary mode).
        chunk_size: number of bytes requested per chunk.

    Progress is printed as a percentage when the response announces a
    ``content-length``. The previous implementation called ``tqdm``, which
    this file never imports, and therefore failed with a NameError.
    """
    total_size = int(response.headers.get("content-length", 0))
    written = 0
    last_percent = -1
    with open(destination, "wb") as f:
        for chunk in response.iter_content(chunk_size):
            if not chunk:  # filter out keep-alive new chunks
                continue
            f.write(chunk)
            written += len(chunk)
            if total_size:
                percent = min(100, written * 100 // total_size)
                # Print only when the integer percentage moves.
                if percent != last_percent:
                    print("{}: {:3d}%".format(destination, percent), end="\r")
                    last_percent = percent
    if last_percent >= 0:
        print()
45 |
46 |
def unzip(filepath):
    """Extract the archive ``filepath`` into its own directory, then delete it."""
    print("Extracting: " + filepath)
    target_dir = os.path.dirname(filepath)
    archive = zipfile.ZipFile(filepath)
    try:
        archive.extractall(target_dir)
    finally:
        archive.close()
    # Reached only when extraction succeeded.
    os.remove(filepath)
53 |
54 |
def download_celeb_a(base_path):
    """Fetch and unpack the aligned CelebA images under ``base_path``/CelebA.

    Nothing is done when the ``CelebA`` directory already exists. An archive
    already present in ``base_path`` is reused instead of being downloaded.
    The extracted ``img_align_celeba`` folder becomes ``CelebA/images`` and
    the archive is deleted afterwards.
    """
    data_path = os.path.join(base_path, "CelebA")
    if os.path.exists(data_path):
        print("[!] Found Celeb-A - skip")
        return
    images_path = os.path.join(data_path, "images")

    filename = "img_align_celeba.zip"
    drive_id = "0B7EVK8r0v71pZjFTYXZWM3FlRnM"
    save_path = os.path.join(base_path, filename)

    if not os.path.exists(save_path):
        download_file_from_google_drive(drive_id, save_path)
    else:
        print("[*] {} already exists".format(save_path))

    archive = zipfile.ZipFile(save_path)
    try:
        archive.extractall(base_path)
    finally:
        archive.close()

    if not os.path.exists(data_path):
        os.mkdir(data_path)
    extracted_dir = os.path.join(base_path, "img_align_celeba")
    os.rename(extracted_dir, images_path)
    os.remove(save_path)
78 |
79 |
def prepare_data_dir(path="./data"):
    """Create the directory ``path`` unless something already exists there."""
    if os.path.exists(path):
        return
    os.mkdir(path)
83 |
84 |
85 | # check, if file exists, make link
def check_link(in_dir, basename, out_dir):
    """Symlink ``out_dir/basename`` to ``in_dir/basename`` when the source exists.

    The link target is stored relative to ``out_dir``. Nothing happens when
    the source file is missing or when an entry named ``basename`` is already
    present in ``out_dir``, so the function can be called repeatedly.
    """
    in_file = os.path.join(in_dir, basename)
    if not os.path.exists(in_file):
        return

    link_file = os.path.join(out_dir, basename)
    # os.symlink raises when the link name exists; lexists also detects a
    # dangling link left by an earlier run.
    if os.path.lexists(link_file):
        return

    rel_link = os.path.relpath(in_file, out_dir)
    os.symlink(rel_link, link_file)
92 |
93 |
def add_splits(base_path):
    """Create train / valid / test folders of links to the CelebA images.

    The folders live under ``base_path``/CelebA/splits and are created when
    missing. Images are numbered from 000001; every image present in
    ``CelebA/images`` is linked into the split its number belongs to.
    """
    data_path = os.path.join(base_path, "CelebA")
    images_path = os.path.join(data_path, "images")
    splits_root = os.path.join(data_path, "splits")
    train_dir = os.path.join(splits_root, "train")
    valid_dir = os.path.join(splits_root, "valid")
    test_dir = os.path.join(splits_root, "test")

    for split_dir in (train_dir, valid_dir, test_dir):
        if not os.path.exists(split_dir):
            os.makedirs(split_dir)

    # these constants based on the standard CelebA splits
    NUM_EXAMPLES = 202599
    TRAIN_STOP = 162770
    VALID_STOP = 182637

    # (first zero-based index, one-past-last zero-based index, folder)
    split_ranges = (
        (0, TRAIN_STOP, train_dir),
        (TRAIN_STOP, VALID_STOP, valid_dir),
        (VALID_STOP, NUM_EXAMPLES, test_dir),
    )
    for start, stop, split_dir in split_ranges:
        # File names are one-based, hence the shift by one.
        for image_number in range(start + 1, stop + 1):
            check_link(images_path, "{:06d}.jpg".format(image_number), split_dir)
121 |
122 |
if __name__ == "__main__":
    # The training scripts read CelebA from data_root="./data", so the images
    # and splits are placed there rather than in the current directory.
    base_path = "./data"
    prepare_data_dir(base_path)
    download_celeb_a(base_path)
    add_splits(base_path)
128 |
--------------------------------------------------------------------------------
/EXP_GAN/torch_lin_sinkhorn.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
class Lin_Sinkhorn_AD(torch.autograd.Function):
    """Sinkhorn cost on factorised kernels with a hand-written backward pass.

    The kernel is never formed explicitly: every product with it is computed
    as ``phi_x @ (phi_y.T @ v)``, so each iteration is linear in the number
    of samples.
    """

    @staticmethod
    def forward(ctx, x_emb, y_emb, reg, niter_sin, lam=1e-6, tau=1e-9):
        """Run the Sinkhorn iterations and return the resulting cost.

        Args:
            x_emb, y_emb: feature tensors; after ``squeeze()`` their first
                dimension indexes the samples.
            reg: regularisation strength multiplying the cost.
            niter_sin: maximum number of Sinkhorn iterations.
            lam: constant added to every kernel product before dividing.
            tau: threshold on the marginal error that stops the loop early.

        Returns:
            ``reg * (sum(a * log(u)) + sum(b * log(v)) - 1)`` with uniform
            weights ``a`` and ``b``.
        """
        phi_x = x_emb.squeeze().type(torch.DoubleTensor)
        phi_y = y_emb.squeeze().type(torch.DoubleTensor)

        n = phi_x.size()[0]
        m = phi_y.size()[0]

        # Uniform weights over both sample sets.
        a = (1.0 / n) * torch.ones(n)
        a = a.type(torch.DoubleTensor)

        b = (1.0 / m) * torch.ones(m)
        b = b.type(torch.DoubleTensor)

        u = 1.0 * torch.ones(n).type(torch.DoubleTensor)
        v = 1.0 * torch.ones(m).type(torch.DoubleTensor)

        u_trans = torch.matmul(phi_x, torch.matmul(phi_y.t(), v)) + lam

        for _ in range(niter_sin):
            u = a / u_trans
            v_trans = torch.matmul(phi_y, torch.matmul(phi_x.t(), u)) + lam

            v = b / v_trans
            u_trans = torch.matmul(phi_x, torch.matmul(phi_y.t(), v)) + lam

            # Violation of the two marginal constraints. The extra check
            # that used to run every tenth iteration recomputed exactly this
            # value and could never trigger, so it was removed.
            err = torch.sum(torch.abs(u * u_trans - a)) + torch.sum(
                torch.abs(v * v_trans - b)
            )
            if err < tau:
                break

        # Kept on the context for the backward pass.
        ctx.u = u
        ctx.v = v
        ctx.reg = reg
        ctx.phi_x = phi_x
        ctx.phi_y = phi_y

        cost = reg * (torch.sum(a * torch.log(u)) + torch.sum(b * torch.log(v)) - 1)
        return cost

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``x_emb`` and ``y_emb``; None for the rest.

        The gradients are ``-reg * u (phi_y^T v)^T`` and
        ``-reg * v (phi_x^T u)^T`` scaled by ``grad_output``, using the
        ``u`` and ``v`` stored by ``forward``.
        """
        u = ctx.u
        v = ctx.v
        reg = ctx.reg
        phi_x = ctx.phi_x
        phi_y = ctx.phi_y

        grad_input = grad_output.clone()
        grad_phi_x = (
            grad_input
            * torch.matmul(u.view(-1, 1), torch.matmul(phi_y.t(), v).view(1, -1))
            * (-reg)
        )
        grad_phi_y = (
            grad_input
            * torch.matmul(v.view(-1, 1), torch.matmul(phi_x.t(), u).view(1, -1))
            * (-reg)
        )

        # One entry per forward input: x_emb, y_emb, reg, niter_sin, lam, tau
        # (the previous version returned one None too many).
        return grad_phi_x, grad_phi_y, None, None, None, None
82 |
83 |
def Lin_Sinkhorn(
    phi_x, phi_y, reg, niter_sin, device, lam=1e-6, tau=1e-9, stabilize=False
):
    """Sinkhorn cost between two uniformly weighted feature sets.

    The kernel is used only through the products ``phi_x @ (phi_y.T @ v)``
    and ``phi_y @ (phi_x.T @ u)``.

    Args:
        phi_x, phi_y: feature tensors; after ``squeeze()`` their first
            dimension indexes the samples.
        reg: regularisation strength.
        niter_sin: maximum number of iterations.
        device: torch device the computation runs on.
        lam: constant added to every kernel product before log / division.
        tau: early-stopping threshold.
        stabilize: iterate on the potentials ``alpha`` / ``beta`` instead of
            the scalings ``u`` / ``v``.

    Returns:
        With ``stabilize`` false, ``reg * (sum(a*log u) + sum(b*log v) - 1)``.
        With ``stabilize`` true, ``sum(a*alpha) + sum(b*beta)``; this branch
        does not subtract the constant ``reg``.
    """
    phi_x = phi_x.squeeze().type(torch.DoubleTensor).to(device)
    phi_y = phi_y.squeeze().type(torch.DoubleTensor).to(device)

    n = phi_x.size()[0]
    m = phi_y.size()[0]

    a = (1.0 / n) * torch.ones(n)
    a = a.type(torch.DoubleTensor).to(device)

    b = (1.0 / m) * torch.ones(m)
    b = b.type(torch.DoubleTensor).to(device)

    if stabilize:
        # float64 like the features: float32 potentials cannot be multiplied
        # with the float64 feature matrices.
        alpha = torch.zeros(n, dtype=torch.float64, device=device)
        beta = torch.zeros(m, dtype=torch.float64, device=device)
        for _ in range(niter_sin):
            alpha_prev, beta_prev = alpha, beta

            lin_M = torch.exp(alpha / reg) * torch.matmul(
                phi_x, torch.matmul(phi_y.t(), torch.exp(beta / reg))
            )
            lin_M = lin_M + lam
            alpha = reg * (torch.log(a) - torch.log(lin_M)) + alpha

            lin_M_t = torch.exp(beta / reg) * torch.matmul(
                phi_y, torch.matmul(phi_x.t(), torch.exp(alpha / reg))
            )
            # Was ``lin_M + lam``, which threw away the product computed just
            # above and reused the row quantity for the column update.
            lin_M_t = lin_M_t + lam
            beta = reg * (torch.log(b) - torch.log(lin_M_t)) + beta

            # Stop once the potentials no longer move.
            err = (alpha - alpha_prev).abs().sum() + (beta - beta_prev).abs().sum()
            if err < tau:
                break
        cost = torch.sum(a * alpha) + torch.sum(b * beta)

    else:
        u = 1.0 * torch.ones(n).type(torch.DoubleTensor).to(device)
        v = 1.0 * torch.ones(m).type(torch.DoubleTensor).to(device)

        u_trans = torch.matmul(phi_x, torch.matmul(phi_y.t(), v)) + lam

        for _ in range(niter_sin):
            u = a / u_trans
            v_trans = torch.matmul(phi_y, torch.matmul(phi_x.t(), u)) + lam

            v = b / v_trans
            u_trans = torch.matmul(phi_x, torch.matmul(phi_y.t(), v)) + lam

            # Stop once both marginal constraints are met up to tau.
            err = torch.sum(torch.abs(u * u_trans - a)) + torch.sum(
                torch.abs(v * v_trans - b)
            )
            if err < tau:
                break
        cost = reg * (torch.sum(a * torch.log(u)) + torch.sum(b * torch.log(v)) - 1)

    return cost
151 |
--------------------------------------------------------------------------------
/EXP_GAN/models.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 |
4 |
5 | # input: batch_size * nc * image_size * image_size
6 | # output: batch_size * num_random_samples * 1 * 1
class Embedding(nn.Module):
    """Convolutional encoder followed by a Gaussian-type positive feature map.

    ``main`` is a stack of stride-2 convolutions ending in a 4x4 convolution
    with ``k`` output channels; ``Feature_Map_Gaussian`` then turns each
    encoded vector into ``num_random_samples`` features computed against the
    anchor points ``U``.
    """

    def __init__(
        self,
        isize,
        nc,
        reg,
        device,
        q,
        C,
        U_init,
        k=100,
        num_random_samples=100,
        R=1,
        ndf=64,
        seed=49,
        random=False,
    ):
        # ``device``, ``R`` and ``seed`` are accepted but not used in this
        # constructor.
        super(Embedding, self).__init__()
        assert isize % 16 == 0, "isize has to be a multiple of 16"

        # input is nc x isize x isize
        main = nn.Sequential()
        main.add_module(
            "initial_conv_{0}-{1}".format(nc, ndf),
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
        )
        main.add_module("initial_relu_{0}".format(ndf), nn.LeakyReLU(0.2, inplace=True))
        # csize tracks the spatial size after the layers added so far
        # (true division, so it is a float).
        csize, cndf = isize / 2, ndf

        # Halve the spatial size and double the channels until csize <= 4.
        while csize > 4:
            in_feat = cndf
            out_feat = cndf * 2
            main.add_module(
                "pyramid_{0}-{1}_conv".format(in_feat, out_feat),
                nn.Conv2d(in_feat, out_feat, 4, 2, 1, bias=False),
            )
            main.add_module(
                "pyramid_{0}_batchnorm".format(out_feat), nn.BatchNorm2d(out_feat)
            )
            main.add_module(
                "pyramid_{0}_relu".format(out_feat), nn.LeakyReLU(0.2, inplace=True)
            )
            cndf = cndf * 2
            csize = csize / 2

        # The layer outputs k channels although its name is formatted with 1;
        # the name is part of the state_dict keys.
        main.add_module(
            "final_{0}-{1}_conv".format(cndf, 1),
            nn.Conv2d(cndf, k, 4, 1, 0, bias=False),
        )

        self.main = main

        # When ``random`` is False the anchors are wrapped in a Parameter;
        # otherwise U_init is used as given.
        if random == False:
            U = torch.nn.Parameter(U_init)

        else:
            U = U_init

        # NOTE(review): .type(torch.DoubleTensor) may hand back a different,
        # non-Parameter tensor when U is not already a CPU double tensor, in
        # which case U would not be registered on the module - confirm.
        self.U = U.type(torch.DoubleTensor)
        self.q = q.type(torch.DoubleTensor)
        self.C = C.type(torch.DoubleTensor)
        self.reg = reg
        self.num_random_samples = num_random_samples

    # X and Y are 2D tensors
    def Square_Euclidean_Distance(self, X, Y):
        """Return the matrix of squared Euclidean distances between rows.

        Entry ``[i, j]`` is ``sum((X[i] - Y[j]) ** 2)``; both inputs are cast
        with ``.type(torch.DoubleTensor)`` first.
        """
        X_col = X.unsqueeze(1).type(torch.DoubleTensor)
        Y_lin = Y.unsqueeze(0).type(torch.DoubleTensor)
        C = torch.sum((X_col - Y_lin) ** 2, 2)
        return C

    # input: batch_size * k * 1 * 1
    # output: batch_size * num_random_samples * 1 * 1
    def Feature_Map_Gaussian(self, X):
        """Map encoded vectors to features against the anchors ``self.U``.

        For sample ``i`` and anchor ``j`` the feature is
        ``C * exp(|U_j|^2 / (reg * q) - 2 * |x_i - U_j|^2 / reg)``
        divided by ``sqrt(num_random_samples)``.
        """
        # NOTE(review): squeeze() also drops the batch dimension when it is
        # 1, which would break the two-value unpacking below - confirm
        # callers never pass a single sample.
        X = X.squeeze()
        batch_size, dim = X.size()

        SED = self.Square_Euclidean_Distance(X, self.U)
        W = -(2 * SED) / self.reg
        Z = self.U ** 2
        A = torch.sum(Z, 1)  # squared norm of every anchor
        a = self.reg * self.q
        V = A / a

        # V has one entry per anchor and is broadcast over the rows of W.
        res_trans = V + W
        res_trans = self.C * torch.exp(res_trans)

        res = (
            1 / torch.sqrt(torch.DoubleTensor([self.num_random_samples]))
        ) * res_trans
        res = res.view(batch_size, self.num_random_samples, 1, 1)

        return res

    def forward(self, input):
        """Encode ``input`` with ``main`` and apply the feature map."""
        output = self.main(input)
        output = self.Feature_Map_Gaussian(output)

        return output
106 |
107 |
108 | # input: batch_size * k * 1 * 1
109 | # output: batch_size * nc * image_size * image_size
class Generator(nn.Module):
    """Transposed-convolution generator mapping ``k x 1 x 1`` codes to images.

    Args:
        isize: side length of the generated images; a multiple of 16 that is
            reached from 4 by repeated doubling (16, 32, 64, ...).
        nc: number of output channels.
        k: number of channels of the input code.
        ngf: number of feature maps of the last hidden layer.

    Raises:
        ValueError: when ``isize`` cannot be reached from 4 by doubling. The
            previous ``while tisize != isize`` loop never terminated for such
            sizes (e.g. 48).
    """

    def __init__(self, isize, nc, k=100, ngf=64):
        super(Generator, self).__init__()
        assert isize % 16 == 0, "isize has to be a multiple of 16"

        # Width of the first layer: ngf // 2 doubled once per doubling of the
        # spatial size from 4 up to isize.
        cngf, tisize = ngf // 2, 4
        while tisize < isize:
            cngf = cngf * 2
            tisize = tisize * 2
        if tisize != isize:
            raise ValueError(
                "isize has to be a power of two, got {0}".format(isize)
            )

        main = nn.Sequential()
        main.add_module(
            "initial_{0}-{1}_convt".format(k, cngf),
            nn.ConvTranspose2d(k, cngf, 4, 1, 0, bias=False),
        )
        main.add_module("initial_{0}_batchnorm".format(cngf), nn.BatchNorm2d(cngf))
        main.add_module("initial_{0}_relu".format(cngf), nn.ReLU(True))

        # Double the spatial size and halve the channels up to isize // 2.
        csize = 4
        while csize < isize // 2:
            main.add_module(
                "pyramid_{0}-{1}_convt".format(cngf, cngf // 2),
                nn.ConvTranspose2d(cngf, cngf // 2, 4, 2, 1, bias=False),
            )
            main.add_module(
                "pyramid_{0}_batchnorm".format(cngf // 2), nn.BatchNorm2d(cngf // 2)
            )
            main.add_module("pyramid_{0}_relu".format(cngf // 2), nn.ReLU(True))
            cngf = cngf // 2
            csize = csize * 2

        # The last doubling produces nc channels squashed to [-1, 1] by tanh.
        main.add_module(
            "final_{0}-{1}_convt".format(cngf, nc),
            nn.ConvTranspose2d(cngf, nc, 4, 2, 1, bias=False),
        )
        main.add_module("final_{0}_tanh".format(nc), nn.Tanh())

        self.main = main

    def forward(self, input):
        """Run ``input`` through the layer stack."""
        output = self.main(input)
        return output
152 |
153 |
154 | # input: batch_size * nz * 1 * 1
155 | # output: batch_size * nc * image_size * image_size
class NetG(nn.Module):
    """Wrapper that exposes a generator network as the ``decoder`` submodule."""

    def __init__(self, decoder):
        super().__init__()
        self.decoder = decoder

    def forward(self, input):
        """Return the decoder's output for ``input``."""
        return self.decoder(input)
164 |
165 |
166 | # input: batch_size * nc * image_size * image_size
167 | # f_emb: batch_size * k * 1 * 1
class NetE(nn.Module):
    """Wrapper around an embedding network that flattens its output per sample."""

    def __init__(self, embedding):
        super().__init__()
        self.embedding = embedding

    def forward(self, input):
        """Embed ``input`` and reshape the result to ``(input.size(0), -1)``."""
        n_samples = input.size(0)
        return self.embedding(input).view(n_samples, -1)
178 |
--------------------------------------------------------------------------------
/EXP_GAN/train_models_celebA.py:
--------------------------------------------------------------------------------
1 | import random
2 | from scipy.special import lambertw
3 | import numpy as np
4 | import torch
5 | import timeit
6 | import os
7 |
8 | import torch.backends.cudnn as cudnn
9 | import torch.utils.data
10 | import torchvision.utils as vutils
11 |
12 |
13 | import models
14 | import data_loading
15 | import torch_lin_sinkhorn
16 |
17 |
def compute_constants(reg, device, nz, R=1, num_random_samples=100, seed=49):
    """Compute the constants and anchor points used by the Gaussian feature map.

    Args:
        reg: regularisation strength.
        device: torch device the returned tensors are placed on.
        nz: dimension of the space the anchors are drawn in.
        R: radius entering ``y = R**2 / (reg * nz)``.
        num_random_samples: number of anchor points to draw.
        seed: value used to reseed NumPy's global random generator.

    Returns:
        ``(q, C, U)``: ``q`` and ``C`` are float64 tensors of shape ``(1,)``
        and ``U`` is a float64 tensor of shape ``(num_random_samples, nz)``.

    Note:
        Reseeds the global NumPy generator as a side effect.
    """
    # q = exp(W(y)) / 2 with W the Lambert W function. The former first
    # assignment ``q = 1/2 + R**2/reg`` was overwritten before any use and
    # has been removed.
    y = R ** 2 / (reg * nz)
    q = np.real((1 / 2) * np.exp(lambertw(y)))

    C = (2 * q) ** (nz / 4)

    # Anchors are drawn from N(0, (q * reg / 4) * I).
    np.random.seed(seed)
    var = (q * reg) / 4
    U = np.random.multivariate_normal(
        np.zeros(nz), var * np.eye(nz), num_random_samples
    )

    U_init = torch.from_numpy(U).to(device)
    C_init = torch.tensor([C], dtype=torch.float64, device=device)
    q_init = torch.tensor([q], dtype=torch.float64, device=device)

    return q_init, C_init, U_init
37 |
38 |
def training_func(
    num_random_samples,
    reg,
    batch_size,
    niter_sin,
    image_size,
    nc,
    nz,
    dataset_name,
    device,
    manual_seed,
    lr,
    max_iter,
    data_root,
    R,
    random_,
):
    """Train a generator / embedding pair with the linear Sinkhorn loss.

    Everything is written to the current working directory: a training log
    (``results_training_celebA_<samples>_<reg>.csv``), a grid of generated
    samples every 100 generator steps (inside
    ``sampled_images_celebA_<samples>_<reg>``) and the state dicts of both
    networks every tenth epoch.

    Args:
        num_random_samples: number of anchors of the feature map.
        reg: regularisation strength (also used in the output names).
        batch_size: batch size of the loader and of the generator noise.
        niter_sin: maximum number of Sinkhorn iterations per loss term.
        image_size, nc: image side length and number of channels.
        nz: dimension of the generator noise and of the embedding.
        dataset_name, data_root: forwarded to ``data_loading.get_data``.
        device: torch device the networks and batches are moved to.
        manual_seed: seed for the random, NumPy and torch generators.
        lr: learning rate of both RMSprop optimisers.
        max_iter: number of passes over the training loader.
        R, random_: forwarded to the embedding construction.
    """
    run_suffix = "_" + str(num_random_samples) + "_" + str(reg)
    name_dir = "sampled_images_celebA" + run_suffix
    if not os.path.exists(name_dir):
        os.mkdir(name_dir)

    epsilon = reg
    hidden_dim = nz

    # Fix the seed
    np.random.seed(seed=manual_seed)
    random.seed(manual_seed)
    torch.manual_seed(manual_seed)
    torch.cuda.manual_seed(manual_seed)
    cudnn.benchmark = True

    # Initialisation of weights
    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find("Conv") != -1:
            m.weight.data.normal_(0.0, 0.02)
        elif classname.find("BatchNorm") != -1:
            m.weight.data.normal_(1.0, 0.02)
            m.bias.data.fill_(0)
        elif classname.find("Linear") != -1:
            m.weight.data.normal_(0.0, 0.1)
            m.bias.data.fill_(0)

    trn_dataset = data_loading.get_data(
        image_size, dataset_name, data_root, train_flag=True
    )
    trn_loader = torch.utils.data.DataLoader(
        trn_dataset, batch_size=batch_size, shuffle=True, num_workers=1
    )

    # construct Generator and Embedding
    q, C, U_init = compute_constants(
        reg, device, nz, R=R, num_random_samples=num_random_samples, seed=manual_seed
    )
    G_generator = models.Generator(image_size, nc, k=nz, ngf=64)
    D_embedding = models.Embedding(
        image_size,
        nc,
        reg,
        device,
        q,
        C,
        U_init,
        k=hidden_dim,
        num_random_samples=num_random_samples,
        R=R,
        seed=manual_seed,
        ndf=64,
        random=random_,
    )

    netG = models.NetG(G_generator)
    netE = models.NetE(D_embedding)

    netG.apply(weights_init)
    netE.apply(weights_init)

    netG.to(device)
    netE.to(device)

    lin_Sinkhorn_AD = torch_lin_sinkhorn.Lin_Sinkhorn_AD.apply

    def sinkhorn_divergence(x_emb, y_emb):
        # 2 * cost(x, y) - cost(y, y) - cost(x, x)
        return (
            2 * lin_Sinkhorn_AD(x_emb, y_emb, epsilon, niter_sin)
            - lin_Sinkhorn_AD(y_emb, y_emb, epsilon, niter_sin)
            - lin_Sinkhorn_AD(x_emb, x_emb, epsilon, niter_sin)
        )

    # Noise reused for every saved sample grid; drawn in double precision and
    # cast once here instead of at every save.
    fixed_noise = torch.DoubleTensor(64, nz, 1, 1).normal_(0, 1).to(device).float()
    one = torch.tensor(1, dtype=torch.float).double()
    mone = one * -1

    # setup optimizer
    optimizerG = torch.optim.RMSprop(netG.parameters(), lr=lr)
    optimizerE = torch.optim.RMSprop(netE.parameters(), lr=lr)

    time = timeit.default_timer()
    gen_iterations = 0

    # Defined up front so the log line below never meets an unbound name when
    # a loader is exhausted before the first generator step.
    sink_E = float("nan")
    sink_G = float("nan")

    # The ``with`` block closes the log even when training is interrupted.
    with open("results_training_celebA" + run_suffix + ".csv", "w") as file_to_print:
        file_to_print.write(str(device) + "\n")
        file_to_print.flush()

        for t in range(max_iter):
            data_iter = iter(trn_loader)
            i = 0
            while i < len(trn_loader):
                # ---------------------------
                #        Optimize over NetE
                # ---------------------------
                for p in netE.parameters():
                    p.requires_grad = True

                if gen_iterations < 25 or gen_iterations % 500 == 0:
                    Diters = 10
                    Giters = 1
                else:
                    Diters = 1
                    Giters = 1

                for j in range(Diters):
                    if i == len(trn_loader):
                        break

                    for p in netE.parameters():
                        p.data.clamp_(-0.01, 0.01)  # clamp parameters of NetE to a cube

                    # next() works on every DataLoader iterator; the .next()
                    # method is a Python 2 leftover.
                    data = next(data_iter)
                    i += 1
                    netE.zero_grad()

                    x_cpu, _ = data
                    x = x_cpu.to(device)
                    x_emb = netE(x)

                    noise = torch.FloatTensor(batch_size, nz, 1, 1).normal_(0, 1).to(device)
                    # The generator is frozen during the embedding step.
                    with torch.no_grad():
                        y = netG(noise)

                    y_emb = netE(y)

                    # Compute the loss
                    sink_E = sinkhorn_divergence(x_emb, y_emb)

                    # Backward with -1: the embedding ascends the divergence.
                    sink_E.backward(mone)
                    optimizerE.step()

                # ---------------------------
                #        Optimize over NetG
                # ---------------------------
                for p in netE.parameters():
                    p.requires_grad = False

                for j in range(Giters):
                    if i == len(trn_loader):
                        break

                    data = next(data_iter)
                    i += 1
                    netG.zero_grad()

                    x_cpu, _ = data
                    x = x_cpu.to(device)
                    x_emb = netE(x)

                    noise = torch.FloatTensor(batch_size, nz, 1, 1).normal_(0, 1).to(device)
                    y = netG(noise)
                    y_emb = netE(y)

                    # Compute the loss
                    sink_G = sinkhorn_divergence(x_emb, y_emb)

                    sink_G.backward(one)
                    optimizerG.step()

                    gen_iterations += 1

                run_time = (timeit.default_timer() - time) / 60.0

                s = "[%3d / %3d] [%3d / %3d] [%5d] (%.2f m) loss_E: %.6f loss_G: %.6f" % (
                    t,
                    max_iter,
                    i * batch_size,
                    batch_size * len(trn_loader),
                    gen_iterations,
                    run_time,
                    sink_E,
                    sink_G,
                )

                s = s + "\n"
                file_to_print.write(s)
                file_to_print.flush()

                if gen_iterations % 100 == 0:
                    with torch.no_grad():
                        y_fixed = netG(fixed_noise)
                        # Rescale from [-1, 1] to [0, 1] before saving.
                        y_fixed = y_fixed.mul(0.5).add(0.5)
                        vutils.save_image(
                            y_fixed,
                            "{0}/fake_samples_{1}.png".format(name_dir, gen_iterations),
                        )

            if t % 10 == 0:
                torch.save(
                    netG.state_dict(),
                    "netG_celebA" + "_" + str(num_random_samples) + "_" + str(reg) + ".pth",
                )
                torch.save(
                    netE.state_dict(),
                    "netE_celebA" + "_" + str(num_random_samples) + "_" + str(reg) + ".pth",
                )
262 |
263 |
264 | # Dataset
image_size, nc, nz = 64, 3, 128
dataset_name = "celeba"
data_root = "./data"

# Parameters
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
random_ = False
manual_seed = 49
lr = 5 * 1e-5
R = 1
max_iter = 10000
niter_sin = 1000
batch_size = 8000

# Grid explored by the entry point: every (anchor count, regularisation) pair.
num_random_samples_list = [10, 100, 300, 600]
reg_list = [1e-1, 1, 10]


if __name__ == "__main__":
    for num_random_samples in num_random_samples_list:
        for reg in reg_list:
            training_func(
                num_random_samples=num_random_samples,
                reg=reg,
                batch_size=batch_size,
                niter_sin=niter_sin,
                image_size=image_size,
                nc=nc,
                nz=nz,
                dataset_name=dataset_name,
                device=device,
                manual_seed=manual_seed,
                lr=lr,
                max_iter=max_iter,
                data_root=data_root,
                R=R,
                random_=random_,
            )
306 |
--------------------------------------------------------------------------------
/EXP_GAN/train_models_cifar.py:
--------------------------------------------------------------------------------
1 | import random
2 | from scipy.special import lambertw
3 | import numpy as np
4 | import torch
5 | import timeit
6 |
7 | import torch.backends.cudnn as cudnn
8 | import torch.utils.data
9 | import torchvision.utils as vutils
10 |
11 |
12 | import models
13 | import data_loading
14 | import torch_lin_sinkhorn
15 |
16 | import os
17 |
18 |
def compute_constants(reg, device, nz, R=1, num_random_samples=100, seed=49):
    """Build the constants and random anchor points used by the embedding.

    Args:
        reg: Regularisation value; appears in ``y`` and in the sampling variance.
        device: Torch device the returned tensors are moved to.
        nz: Dimension of the sampled anchor vectors.
        R: Radius-like constant entering ``y = R**2 / (reg * nz)``.
        num_random_samples: Number of Gaussian anchor vectors to draw.
        seed: Seed passed to ``np.random.seed`` before sampling (this resets
            NumPy's global RNG state).

    Returns:
        ``(q_init, C_init, U_init)`` where ``q_init`` and ``C_init`` are
        one-element double tensors and ``U_init`` is a
        ``(num_random_samples, nz)`` double tensor, all on ``device``.
    """
    # q = exp(W(y)) / 2 with W the Lambert W function; lambertw returns a
    # complex value, so only the real part is kept.
    y = R ** 2 / (reg * nz)
    q = np.real((1 / 2) * np.exp(lambertw(y)))

    C = (2 * q) ** (nz / 4)

    # Anchors are drawn from an isotropic Gaussian with variance q * reg / 4.
    np.random.seed(seed)
    var = (q * reg) / 4
    U = np.random.multivariate_normal(
        np.zeros(nz), var * np.eye(nz), num_random_samples
    )
    U = torch.from_numpy(U)

    U_init = U.to(device)
    C_init = torch.DoubleTensor([C]).to(device)
    q_init = torch.DoubleTensor([q]).to(device)

    return q_init, C_init, U_init
38 |
39 |
def training_func(
    num_random_samples,
    reg,
    batch_size,
    niter_sin,
    image_size,
    nc,
    nz,
    dataset_name,
    device,
    manual_seed,
    lr,
    max_iter,
    data_root,
    R,
    random_,
):
    """Train the generator / embedding pair with the linear Sinkhorn loss.

    Side effects (all paths are relative to the working directory):
      * creates ``sampled_images_cifar_max_<num_random_samples>_<reg>/`` and
        writes a grid of generated samples there every 100 generator steps;
      * writes one log line per outer step to
        ``results_training_cifar_max_<num_random_samples>_<reg>.csv``;
      * saves both state dicts every 10 epochs.

    Args:
        num_random_samples: Number of random anchors (see ``compute_constants``).
        reg: Regularisation value, also passed as ``epsilon`` to the Sinkhorn op.
        batch_size: Batch size for the data loader and for the noise samples.
        niter_sin: Fourth argument of ``Lin_Sinkhorn_AD`` (presumably the number
            of Sinkhorn iterations -- confirm in ``torch_lin_sinkhorn``).
        image_size, nc: Forwarded to ``models.Generator`` / ``models.Embedding``.
        nz: Latent dimension; also used as the embedding's ``k``.
        dataset_name, data_root: Forwarded to ``data_loading.get_data``.
        device: Torch device for the networks and batches.
        manual_seed: Seed for ``numpy``, ``random`` and ``torch``.
        lr: Learning rate of both RMSprop optimisers.
        max_iter: Number of passes over the training loader.
        R, random_: Forwarded to ``compute_constants`` / ``models.Embedding``.
    """
    name_dir = (
        "sampled_images_cifar_max" + "_" + str(num_random_samples) + "_" + str(reg)
    )
    if not os.path.exists(name_dir):
        os.mkdir(name_dir)

    epsilon = reg
    hidden_dim = nz

    log_path = (
        "results_training_cifar_max"
        + "_"
        + str(num_random_samples)
        + "_"
        + str(reg)
        + ".csv"
    )

    # The context manager guarantees the log is closed even if training raises.
    with open(log_path, "w") as file_to_print:
        file_to_print.write(str(device) + "\n")
        file_to_print.flush()

        # Fix the seed
        np.random.seed(seed=manual_seed)
        random.seed(manual_seed)
        torch.manual_seed(manual_seed)
        torch.cuda.manual_seed(manual_seed)
        cudnn.benchmark = True

        # Initialisation of weights, dispatched on the layer's class name.
        def weights_init(m):
            classname = m.__class__.__name__
            if classname.find("Conv") != -1:
                m.weight.data.normal_(0.0, 0.02)
            elif classname.find("BatchNorm") != -1:
                m.weight.data.normal_(1.0, 0.02)
                m.bias.data.fill_(0)
            elif classname.find("Linear") != -1:
                m.weight.data.normal_(0.0, 0.1)
                m.bias.data.fill_(0)

        trn_dataset = data_loading.get_data(
            image_size, dataset_name, data_root, train_flag=True
        )
        trn_loader = torch.utils.data.DataLoader(
            trn_dataset, batch_size=batch_size, shuffle=True, num_workers=1
        )

        # construct Generator and Embedding:
        q, C, U_init = compute_constants(
            reg,
            device,
            nz,
            R=R,
            num_random_samples=num_random_samples,
            seed=manual_seed,
        )

        G_generator = models.Generator(image_size, nc, k=nz, ngf=64)
        D_embedding = models.Embedding(
            image_size,
            nc,
            reg,
            device,
            q,
            C,
            U_init,
            k=hidden_dim,
            num_random_samples=num_random_samples,
            R=R,
            seed=manual_seed,
            ndf=64,
            random=random_,
        )

        netG = models.NetG(G_generator)
        netE = models.NetE(D_embedding)

        netG.apply(weights_init)
        netE.apply(weights_init)

        netG.to(device)
        netE.to(device)

        lin_Sinkhorn_AD = torch_lin_sinkhorn.Lin_Sinkhorn_AD.apply
        fixed_noise = torch.DoubleTensor(64, nz, 1, 1).normal_(0, 1).to(device)
        # +1 / -1 seeds for backward(): the embedding ascends the loss, the
        # generator descends it.
        one = torch.tensor(1, dtype=torch.float).double()
        mone = one * -1

        # setup optimizer
        optimizerG = torch.optim.RMSprop(netG.parameters(), lr=lr)
        optimizerE = torch.optim.RMSprop(netE.parameters(), lr=lr)

        start_time = timeit.default_timer()
        gen_iterations = 0

        # Logged losses; NaN until the corresponding step has run at least once.
        # Without this the first log line raises when the embedding loop uses
        # up the whole epoch before the generator gets a batch.
        sink_E = float("nan")
        sink_G = float("nan")

        for t in range(max_iter):
            data_iter = iter(trn_loader)
            i = 0
            while i < len(trn_loader):
                # ---------------------------
                #        Optimize over NetE
                # ---------------------------
                for p in netE.parameters():
                    p.requires_grad = True

                # Extra embedding steps early on and every 500 generator steps.
                if gen_iterations < 25 or gen_iterations % 500 == 0:
                    Diters = 10
                    Giters = 1
                else:
                    Diters = 1
                    Giters = 1

                for j in range(Diters):
                    if i == len(trn_loader):
                        break

                    for p in netE.parameters():
                        p.data.clamp_(-0.01, 0.01)  # clamp parameters of NetE to a cube

                    # next() works on every Python 3 iterator; the .next()
                    # method is a Python 2 leftover.
                    data = next(data_iter)
                    i += 1
                    netE.zero_grad()

                    x_cpu, _ = data
                    x = x_cpu.to(device)
                    x_emb = netE(x)

                    noise = (
                        torch.FloatTensor(batch_size, nz, 1, 1).normal_(0, 1).to(device)
                    )
                    with torch.no_grad():
                        y = netG(noise)

                    y_emb = netE(y)

                    ### Compute the loss ###
                    sink_E = (
                        2 * lin_Sinkhorn_AD(x_emb, y_emb, epsilon, niter_sin)
                        - lin_Sinkhorn_AD(y_emb, y_emb, epsilon, niter_sin)
                        - lin_Sinkhorn_AD(x_emb, x_emb, epsilon, niter_sin)
                    )

                    sink_E.backward(mone)
                    optimizerE.step()

                # ---------------------------
                #        Optimize over NetG
                # ---------------------------
                for p in netE.parameters():
                    p.requires_grad = False

                for j in range(Giters):
                    if i == len(trn_loader):
                        break

                    data = next(data_iter)
                    i += 1
                    netG.zero_grad()

                    x_cpu, _ = data
                    x = x_cpu.to(device)
                    x_emb = netE(x)

                    noise = (
                        torch.FloatTensor(batch_size, nz, 1, 1).normal_(0, 1).to(device)
                    )
                    y = netG(noise)
                    y_emb = netE(y)

                    # Compute the loss
                    sink_G = (
                        2 * lin_Sinkhorn_AD(x_emb, y_emb, epsilon, niter_sin)
                        - lin_Sinkhorn_AD(y_emb, y_emb, epsilon, niter_sin)
                        - lin_Sinkhorn_AD(x_emb, x_emb, epsilon, niter_sin)
                    )

                    sink_G.backward(one)
                    optimizerG.step()

                    gen_iterations += 1

                run_time = (timeit.default_timer() - start_time) / 60.0

                s = "[%3d / %3d] [%3d / %3d] [%5d] (%.2f m) loss_E: %.6f loss_G: %.6f" % (
                    t,
                    max_iter,
                    i * batch_size,
                    batch_size * len(trn_loader),
                    gen_iterations,
                    run_time,
                    sink_E,
                    sink_G,
                )

                s = s + "\n"
                file_to_print.write(s)
                file_to_print.flush()

                if gen_iterations % 100 == 0:
                    with torch.no_grad():
                        fixed_noise = fixed_noise.float()
                        y_fixed = netG(fixed_noise)
                        # Map generator output through x -> 0.5 * x + 0.5 before saving.
                        y_fixed = y_fixed.mul(0.5).add(0.5)
                        vutils.save_image(
                            y_fixed,
                            "{0}/fake_samples_{1}.png".format(name_dir, gen_iterations),
                        )

            if t % 10 == 0:
                torch.save(
                    netG.state_dict(),
                    "netG_cifar_max"
                    + "_"
                    + str(num_random_samples)
                    + "_"
                    + str(reg)
                    + ".pth",
                )
                torch.save(
                    netE.state_dict(),
                    "netE_cifar_max"
                    + "_"
                    + str(num_random_samples)
                    + "_"
                    + str(reg)
                    + ".pth",
                )
277 |
278 |
# ---- Data set description -------------------------------------------------
image_size = 64
nc = 3
nz = 128
dataset_name = "cifar10"
data_root = "./data"

# ---- Fixed training hyper-parameters --------------------------------------
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
manual_seed = 49
random_ = False
lr = 5 * 1e-5
R = 1
max_iter = 10000
niter_sin = 1000
batch_size = 8000

# ---- Swept hyper-parameters ------------------------------------------------
num_random_samples_list = [10, 100, 300, 600]
reg_list = [1e-1, 1, 10]

if __name__ == "__main__":
    # One full training run per (num_random_samples, reg) pair.
    for num_random_samples in num_random_samples_list:
        for reg in reg_list:
            training_func(
                num_random_samples=num_random_samples,
                reg=reg,
                batch_size=batch_size,
                niter_sin=niter_sin,
                image_size=image_size,
                nc=nc,
                nz=nz,
                dataset_name=dataset_name,
                device=device,
                manual_seed=manual_seed,
                lr=lr,
                max_iter=max_iter,
                data_root=data_root,
                R=R,
                random_=random_,
            )
319 |
--------------------------------------------------------------------------------
/FastSinkhorn.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy
3 | import time
4 | from scipy import special
5 |
6 | # Here the Regularized version goes from -\infinty to the true OT
def compute_ROT(u, v, a, b, reg):
    """Return ``reg * (<a, log u> + <b, log v>)`` for scaling vectors ``u``, ``v``."""
    log_u = np.log(u)
    log_v = np.log(v)
    return reg * (np.dot(a, log_u) + np.dot(b, log_v))
10 |
11 |
12 | ################ Classical Sinkhorn Algorithm ####################
13 |
def Sinkhorn(C, reg, a, b, delta=1e-9, lam=1e-6):
    """Run Sinkhorn iterations on the Gibbs kernel ``exp(-C / reg)``.

    Alternates the updates of ``u`` and ``v`` until the summed L1 violation of
    both marginals drops to ``delta`` or below. ``lam`` is added to every
    kernel product so that the divisions never hit zero.

    Returns:
        The pair of scaling vectors ``(u, v)``.
    """
    K = np.exp(-C / reg)
    u = np.ones(np.shape(a)[0])
    v = np.ones(np.shape(b)[0])

    Kv = np.dot(K, v) + lam
    Ktu = np.dot(K.T, u) + lam

    row_gap = np.sum(np.abs(u * Kv - a))
    col_gap = np.sum(np.abs(v * Ktu - b))

    while row_gap + col_gap > delta:
        u = a / Kv
        Ktu = np.dot(K.T, u) + lam

        v = b / Ktu
        Kv = np.dot(K, v) + lam

        row_gap = np.sum(np.abs(u * Kv - a))
        col_gap = np.sum(np.abs(v * Ktu - b))

    return u, v
38 |
39 |
40 | # Classical Sinkhorn algorithm: Square Euclidean Distance
def Sinkhorn_RBF(X, Y, reg, a, b, delta=1e-9, num_iter=50, lam=1e-100):
    """Run ``num_iter`` Sinkhorn iterations with the squared-Euclidean cost.

    The cost matrix comes from ``Square_Euclidean_Distance(X, Y)`` and the
    kernel is ``exp(-C / reg)``. After every iteration the value returned by
    ``compute_ROT`` and the elapsed wall-clock time are recorded.

    Args:
        X, Y: Point clouds handed to ``Square_Euclidean_Distance``.
        reg: Regularisation strength.
        a, b: Marginal weight vectors.
        delta: Unused; kept so the signature matches the other solvers.
        num_iter: Number of iterations to run.
        lam: Small constant added to the kernel products to avoid division by 0.

    Returns:
        ``(last_value, values, times)`` with ``values`` and ``times`` as arrays
        of length ``num_iter``, or the string ``"Error"`` as soon as the
        recorded value becomes NaN.

    Raises:
        IndexError: if ``num_iter`` is 0 (there is no last value to return).
    """
    start = time.time()

    acc = []
    times = []

    C = Square_Euclidean_Distance(X, Y)
    K = np.exp(-C / reg)
    v = np.ones(np.shape(b)[0])

    u_trans = np.dot(K, v) + lam

    for k in range(num_iter):

        u = a / u_trans
        v_trans = np.dot(K.T, u) + lam

        v = b / v_trans
        u_trans = np.dot(K, v) + lam

        ROT_trans = compute_ROT(u, v, a, b, reg)
        if np.isnan(ROT_trans):
            return "Error"

        # Reuse the value computed above instead of evaluating it a second time.
        acc.append(ROT_trans)
        times.append(time.time() - start)

    return acc[-1], np.array(acc), np.array(times)
72 |
73 |
74 | ################ Positive Random Features ####################
75 |
76 | # Positive Random Features Sinkhorn: K = AB
def Lin_Sinkhorn(A, B, a, b, delta=1e-9, max_iter=1e5, lam=1e-100):
    """Sinkhorn iterations for a kernel given in factored form ``K = A B``.

    The kernel is never formed: every product is evaluated as
    ``A (B v)`` or ``B^T (A^T u)``. Iteration stops when the summed L1
    marginal violation is at most ``delta`` or after ``max_iter`` updates.

    Returns:
        The pair of scaling vectors ``(u, v)``.
    """
    u = np.ones(np.shape(a)[0])
    v = np.ones(np.shape(b)[0])
    Kv = np.dot(A, np.dot(B, v)) + lam
    Ktu = np.dot(B.T, np.dot(A.T, u)) + lam

    row_gap = np.sum(np.abs(u * Kv - a))
    col_gap = np.sum(np.abs(v * Ktu - b))

    step = 0
    while step < max_iter and row_gap + col_gap > delta:
        u = a / Kv
        Ktu = np.dot(B.T, np.dot(A.T, u)) + lam

        v = b / Ktu
        Kv = np.dot(A, np.dot(B, v)) + lam

        row_gap = np.sum(np.abs(u * Kv - a))
        col_gap = np.sum(np.abs(v * Ktu - b))
        step += 1

    return u, v
100 |
101 |
102 | # Positive Random Features Sinkhorn: Square Euclidean Distance
def Lin_Sinkhorn_RBF(
    X, Y, reg, a, b, num_samples, seed=49, delta=1e-9, num_iter=50, lam=1e-100
):
    """Run ``num_iter`` Sinkhorn iterations on a random-feature kernel.

    Both point clouds are mapped through ``Feature_Map_Gaussian`` with the
    same ``seed`` and the radius returned by ``theoritical_R``; the kernel is
    then used only through its factors. After each iteration the value of
    ``compute_ROT`` and the elapsed time are stored.

    Returns:
        ``(last_value, values, times)``, or the string ``"Error"`` as soon as
        the recorded value is NaN. ``delta`` is not used.
    """
    t0 = time.time()

    history = []
    stamps = []

    radius = theoritical_R(X, Y)
    A = Feature_Map_Gaussian(X, reg, R=radius, num_samples=num_samples, seed=seed)
    B = Feature_Map_Gaussian(Y, reg, R=radius, num_samples=num_samples, seed=seed).T

    u = np.ones(np.shape(a)[0])
    v = np.ones(np.shape(b)[0])
    u_trans = np.dot(A, np.dot(B, v)) + lam
    v_trans = np.dot(B.T, np.dot(A.T, u)) + lam

    for _ in range(num_iter):
        u = a / u_trans
        v_trans = np.dot(B.T, np.dot(A.T, u)) + lam

        v = b / v_trans
        u_trans = np.dot(A, np.dot(B, v)) + lam

        value = compute_ROT(u, v, a, b, reg)
        if np.isnan(value) == True:
            return "Error"

        history.append(value)
        stamps.append(time.time() - t0)

    return history[-1], np.array(history), np.array(stamps)
136 |
137 |
138 | # Random Feature Map: Square Euclidean Distance
def Feature_Map_Gaussian(X, reg, R=1, num_samples=100, seed=49):
    """Evaluate the positive random feature map on the rows of ``X``.

    ``num_samples`` anchors are drawn from a centred isotropic Gaussian with
    variance ``q * reg / 4``, where ``q = exp(W(R**2 / (reg * d))) / 2`` and
    ``W`` is the Lambert W function. Seeding resets NumPy's global RNG.

    Returns:
        Array of shape ``(n, num_samples)`` holding the feature values.
    """
    n, d = np.shape(X)

    ratio = R ** 2 / (reg * d)
    q = np.real((1 / 2) * np.exp(special.lambertw(ratio)))
    scale = (2 * q) ** (d / 4)

    sigma2 = (q * reg) / 4

    np.random.seed(seed)
    anchors = np.random.multivariate_normal(
        np.zeros(d), sigma2 * np.eye(d), num_samples
    )

    # Exponent: ||anchor||^2 / (reg * q) - 2 * ||x - anchor||^2 / reg
    dist_term = -(2 * Square_Euclidean_Distance(X, anchors)) / reg
    norm_term = np.sum(anchors ** 2, axis=1) / (reg * q)

    features = scale * np.exp(norm_term + dist_term)

    return (1 / np.sqrt(num_samples)) * features
162 |
163 |
def theoritical_R(X, Y):
    """Return the largest Euclidean row norm found in ``X`` or ``Y``."""
    largest_x = np.max(np.linalg.norm(X, axis=1))
    largest_y = np.max(np.linalg.norm(Y, axis=1))
    return np.maximum(largest_x, largest_y)
170 |
171 |
172 | # Random Feature Map: Arccos Kernel
def Feature_Map_Arccos(X, s=1, sig=1.5, num_samples=100, kappa=1e-6, seed=49):
    """Evaluate the random feature map for the arc-cosine kernel of order ``s``.

    Anchors are drawn from a centred Gaussian with covariance ``sig**2 * I``
    (seeding resets NumPy's global RNG). The result has one extra, constant
    column equal to ``kappa``.

    Returns:
        Array of shape ``(n, num_samples + 1)``.
    """
    n, d = np.shape(X)
    scale = (sig ** (d / 2)) * np.sqrt(2)

    np.random.seed(seed)
    anchors = np.random.multivariate_normal(
        np.zeros(d), (sig ** 2) * np.eye(d), num_samples
    )

    # Rectified inner products raised to the power s.
    rectified = scale * (np.maximum(Inner_Product(X, anchors), 0) ** s)

    # Per-anchor weight exp(-(sig^2 - 1) / (4 sig^2) * ||anchor||^2).
    coeff = ((sig ** 2) - 1) / (sig ** 2)
    damping = np.exp(-(1 / 4) * coeff * np.sum(anchors ** 2, axis=1))

    out = np.full((n, num_samples + 1), kappa, dtype=float)
    out[:, :num_samples] = (1 / np.sqrt(num_samples)) * rectified * damping

    return out
192 |
193 |
194 | ######################## Nystrom Method #######################
195 |
196 | # Nystrom Sinkhorn: K =VA^{-1}V
def Nys_Sinkhorn(A, V, a, b, delta=1e-9, max_iter=1e3, lam=1e-100):
    """Sinkhorn iterations for a kernel in Nystrom form ``V A^{-1} V^T``.

    Each kernel product is evaluated as ``V @ solve(A, V^T w)`` so the inverse
    is never formed. The same product is used for both marginals. Iteration
    stops when the summed L1 marginal violation is at most ``delta`` or after
    ``max_iter`` updates.

    Returns:
        The pair of scaling vectors ``(u, v)``.
    """

    def apply_kernel(w):
        return np.dot(V, np.linalg.solve(A, np.dot(V.T, w))) + lam

    u = np.ones(np.shape(a)[0])
    v = np.ones(np.shape(b)[0])

    Kv = apply_kernel(v)
    Ku = apply_kernel(u)

    row_gap = np.sum(np.abs(u * Kv - a))
    col_gap = np.sum(np.abs(v * Ku - b))

    step = 0
    while step < max_iter and row_gap + col_gap > delta:
        u = a / Kv
        Ku = apply_kernel(u)

        v = b / Ku
        Kv = apply_kernel(v)

        row_gap = np.sum(np.abs(u * Kv - a))
        col_gap = np.sum(np.abs(v * Ku - b))
        step += 1

    return u, v
221 |
222 |
223 | # Nystrom Sinkhorn: Square Euclidean Distance
def Nys_Sinkhorn_RBF(
    X, Y, reg, a, b, rank, seed=49, delta=1e-9, num_iter=50, lam=1e-100
):
    """Run ``num_iter`` Sinkhorn iterations on a uniform Nystrom kernel.

    ``Nystrom_RBF`` builds the approximation on the stacked points ``[X; Y]``,
    so the marginals are zero-padded to length ``n + m``: ``a`` occupies the
    first ``n`` entries, ``b`` the last ``m``. Only ``u[:n]`` and ``v[n:]``
    enter the value recorded after each iteration.

    Returns:
        ``(last_value, values, times)``, or the string ``"Error"`` as soon as
        the recorded value is NaN. ``delta`` is not used.
    """
    t0 = time.time()

    history = []
    stamps = []

    n = np.shape(X)[0]
    m = np.shape(Y)[0]

    a_nys = np.zeros(n + m)
    a_nys[:n] = a

    b_nys = np.zeros(n + m)
    b_nys[n:] = b

    A, V = Nystrom_RBF(X, Y, reg, rank, seed=seed, stable=1e-10)
    A_inv = np.linalg.inv(A)

    def apply_kernel(w):
        return np.dot(V, np.dot(A_inv, np.dot(V.T, w))) + lam

    u = np.ones(np.shape(a_nys)[0])
    v = np.ones(np.shape(b_nys)[0])

    u_trans = apply_kernel(v)
    v_trans = apply_kernel(u)

    for _ in range(num_iter):
        u = a_nys / u_trans
        v_trans = apply_kernel(u)

        v = b_nys / v_trans
        u_trans = apply_kernel(v)

        value = compute_ROT(u[:n], v[n:], a, b, reg)
        if np.isnan(value) == True:
            return "Error"

        history.append(value)
        stamps.append(time.time() - t0)

    return history[-1], np.array(history), np.array(stamps)
269 |
270 |
271 | # Uniform Nyström: Square Euclidean Distance
def Nystrom_RBF(X, Y, reg, rank, seed=49, stable=1e-100):
    """Uniformly sample landmarks from ``[X; Y]`` and build the Nystrom factors.

    At most ``n + m`` landmarks are drawn without replacement (seeding resets
    NumPy's global RNG) and kept in increasing index order.

    Returns:
        ``(A, V)`` where ``A`` is the landmark-by-landmark RBF kernel plus
        ``stable`` on its diagonal and ``V`` is the RBF kernel between every
        stacked point and the landmarks.
    """
    n, d = np.shape(X)
    m, d = np.shape(Y)
    total = n + m
    stacked = np.concatenate((X, Y), axis=0)

    n_landmarks = int(np.minimum(rank, total))

    np.random.seed(seed)
    chosen = np.sort(np.random.choice(total, n_landmarks, replace=False))

    landmarks = stacked[chosen, :]
    A = np.exp(-Square_Euclidean_Distance(landmarks, landmarks) / reg)
    A = A + stable * np.eye(n_landmarks)
    V = np.exp(-Square_Euclidean_Distance(stacked, landmarks) / reg)

    return A, V
290 |
291 |
292 | # Recursive Nyström Sampling: Square Euclidean Distance
def recursive_Nystrom_RBF(X, Y, rank, reg, seed=49, stable=1e-100):
    """Build Nystrom factors for the RBF kernel using recursive leverage sampling.

    Landmarks are chosen on the stacked points ``[X; Y]`` by recursively
    estimating ridge leverage scores on nested uniform subsamples (coarsest
    level first) and sampling columns according to them.
    NOTE(review): this looks like a port of the reference recursive Nystrom
    procedure of Musco & Musco -- confirm level sizes against it.

    Args:
        X, Y: 2-D point clouds with the same number of columns.
        rank: Number of landmarks returned (``rank >= 2`` is required because
            ``log(rank)`` is used as a divisor).
        reg: Bandwidth of the kernel ``exp(-||x - y||^2 / reg)``.
        seed: Base seed; NumPy's global RNG is reseeded here and once per level.
        stable: Value added to the diagonal of ``A``.

    Returns:
        ``(A, V)``: ``V`` is the kernel between all stacked points and the
        landmarks, ``A`` is its landmark rows plus ``stable`` on the diagonal.
    """
    Z = np.concatenate((X, Y), axis=0)
    n, d = np.shape(Z)

    ## Start of algorithm
    sLevel = rank
    oversamp = np.log(sLevel)
    k = int(sLevel / (4 * oversamp)) + 1
    nLevels = int(np.log(n / sLevel) / np.log(2)) + 1

    # np.random is a module, not a callable: seed it explicitly.
    np.random.seed(seed)
    perm = np.random.permutation(n)

    # set up sizes for recursive levels
    lSize = np.zeros(nLevels)
    lSize[0] = n
    for i in range(1, nLevels):
        lSize[i] = int(lSize[i - 1] / 2) + 1

    # rInd: indices of points selected at previous level of recursion
    # at the base level it's just a uniform sample of ~sLevel points
    samp = np.arange(lSize[-1]).astype(int)
    rInd = perm[samp]
    # Kept 1-D so that np.diag(...) below builds a diagonal matrix and the
    # trace term is an element-wise product rather than an outer broadcast.
    weights = np.ones(np.shape(rInd)[0])

    # Diagonal of the full kernel matrix: exp(-0 / reg) = 1 for every point.
    kDiag = np.ones(n)

    # Main recursion, unrolled for efficiency
    for l in range(nLevels - 1, -1, -1):
        np.random.seed(seed + l)
        # indices of current uniform sample
        rIndCurr = perm[: int(lSize[l])]
        # build sampled kernel
        SED = Square_Euclidean_Distance(Z[rIndCurr, :], Z[rInd, :])
        KS = np.exp(-SED / reg)
        SKS = KS[samp, :]
        SKSn = np.shape(SKS)[0]

        # optimal lambda for taking O(klogk) samples
        if k >= SKSn:
            # for the rare chance we take less than k samples in a round;
            # don't set to exactly 0 to avoid stability issues
            lam = 10e-6
        else:
            # lambda = (trace - sum of the k largest eigenvalues) / k of the
            # weighted kernel diag(w) SKS diag(w), which is symmetric.
            Oper = SKS * np.outer(weights, weights)
            eigen = np.linalg.eigvalsh(Oper)[-k:]
            lam = (
                np.sum(np.diag(SKS) * (weights ** 2)) - np.sum(np.abs(eigen))
            ) / k

        # Residual k(x, x) - k_S(x)^T (SKS + lam W^-2)^-1 k_S(x) for each point
        # of the current level; max(0, .) only guards against round-off.
        R = np.linalg.inv(SKS + np.diag(lam * weights ** (-2)))
        z = np.sum(np.dot(KS, R) * KS, 1)
        z = kDiag[rIndCurr] - z
        z = np.maximum(0, z)

        if l != 0:
            # on intermediate levels, we independently sample each column
            # by its leverage score. the sample size is sLevel in expectation
            levs = np.minimum(1, oversamp * (1 / lam) * z)

            samp = np.where(np.random.rand(int(lSize[l])) < levs)[0]
            # with very low probability, we could accidentally sample no
            # columns. In this case, just take a fixed size uniform sample.
            if len(samp) == 0:
                levs[:] = sLevel / lSize[l]
                samp = np.random.choice(int(lSize[l]), sLevel, replace=False)

            weights = np.sqrt(1.0 / (levs[samp]))
        else:
            # on the top level, we sample exactly `rank` landmark points
            # without replacement, proportionally to the leverage scores
            levs = np.minimum(1, (1 / lam) * z)
            levs_norm = levs / np.sum(levs)
            samp = np.random.choice(
                np.shape(levs)[0], rank, replace=False, p=levs_norm.reshape(-1)
            )

        rInd = perm[samp]

    # build final Nystrom approximation
    # pinv or inversion with slight regularization helps stability
    V = np.exp(-Square_Euclidean_Distance(Z, Z[rInd, :]) / reg)
    A = V[rInd, :]
    A = A + stable * np.eye(rank)

    return A, V
398 |
399 |
# Adaptive Rank Nystrom: Square Euclidean Distance
def Adaptive_Nystrom_RBF(X, Y, reg, tau=1e-1, seed=49):
    """Double the Nystrom rank until the diagonal error estimate is at most ``tau``.

    Starting from rank 2, ``Nystrom_RBF`` is called with doubling rank. For
    the first ``r`` stacked points (``r`` = number of landmarks) the quantity
    ``V[i] @ A @ V[i]`` is computed and the error is ``1 - min`` of these.
    NOTE(review): a Nystrom approximation is usually ``V A^{-1} V^T`` checked
    on all points; this keeps the original formula -- confirm intent.

    Args:
        X, Y: Point clouds forwarded to ``Nystrom_RBF``.
        reg: Kernel bandwidth forwarded to ``Nystrom_RBF``.
        tau: Target value for the error estimate.
        seed: Seed forwarded to ``Nystrom_RBF``.

    Returns:
        The factors ``(A, V)`` of the last approximation built.
    """
    n_tot = np.shape(X)[0] + np.shape(Y)[0]

    err = 1e30
    r = 1
    while err > tau:
        r = 2 * r
        A, V = Nystrom_RBF(X, Y, reg, r, seed=seed)

        # diag[i] = V[i, :] @ A @ V[i, :] for the first `n_land` rows.
        n_land = np.shape(A)[0]
        head = V[:n_land, :]
        diag = np.sum(np.dot(head, A) * head, axis=1)

        err = 1 - np.min(diag)

        # Nystrom_RBF caps the rank at n_tot, so a larger r would rebuild the
        # very same factors; stop instead of looping forever.
        if r >= n_tot:
            break

    return A, V
416 |
417 |
418 | # Square Euclidean Distance
def Square_Euclidean_Distance(X, Y):
    """Return the matrix of squared Euclidean distances between rows.

    Args:
        X: Array of shape ``(n, d)``, or a single vector of length ``d``.
        Y: Array of shape ``(m, d)``, or a single vector of length ``d``.

    Returns:
        Array ``C`` of shape ``(n, m)`` with ``C[i, j] = ||X[i] - Y[j]||^2``.
        A 1-D input is treated as one row, as in ``Inner_Product``.
    """
    X = np.atleast_2d(np.asarray(X))
    Y = np.atleast_2d(np.asarray(Y))

    # Broadcast (n, 1, d) against (1, m, d) and reduce over the feature axis.
    X_col = X[:, np.newaxis]
    Y_lin = Y[np.newaxis, :]
    C = np.sum((X_col - Y_lin) ** 2, 2)
    return C
424 |
425 |
426 | # Arccos Cost
def Arccos_Cost(X, Y, s=1, kappa=1e-6):
    """Return the cost ``-log(k_s(x, y) + kappa)`` for the arc-cosine kernel.

    With ``theta`` the angle between ``x`` and ``y``:
      * ``s == 0``: ``k = (pi - theta) / pi``
      * ``s == 1``: ``k = |x||y| (sin(theta) + (pi - theta) cos(theta)) / pi``
      * ``s == 2``: ``k = |x|^2|y|^2 (3 sin(theta) cos(theta)
        + (pi - theta)(1 + 2 cos(theta)^2)) / pi``
    For any other ``s`` the kernel is taken as 0, so every entry is
    ``-log(kappa)``.

    Args:
        X: Array of shape ``(n, d)`` or a single vector of length ``d``.
        Y: Array of shape ``(m, d)`` or a single vector of length ``d``.
        s: Order of the kernel.
        kappa: Offset added before the logarithm.

    Returns:
        Array of shape ``(n, m)``. Rows of zero norm produce NaN entries.
    """
    X = np.asarray(X)
    Y = np.asarray(Y)
    if X.ndim == 1:
        X = X.reshape(1, -1)
    if Y.ndim == 1:
        Y = Y.reshape(1, -1)

    n, d = np.shape(X)
    m, d = np.shape(Y)

    norm = np.outer(np.linalg.norm(X, axis=1), np.linalg.norm(Y, axis=1))
    # Round-off can push the cosine slightly outside [-1, 1], which would make
    # arccos return NaN (typically when x and y are parallel).
    cos_theta = np.clip(np.dot(X, Y.T) / norm, -1.0, 1.0)
    theta = np.arccos(cos_theta)

    if s == 0:
        M = (1 / np.pi) * (np.pi - theta)
    elif s == 1:
        J = np.sin(theta) + (np.pi - theta) * np.cos(theta)
        M = (1 / np.pi) * norm * J
    elif s == 2:
        J = 3 * np.sin(theta) * np.cos(theta) + (np.pi - theta) * (
            1 + 2 * np.cos(theta) ** 2
        )
        M = (1 / np.pi) * (norm ** 2) * J
    else:
        M = np.zeros((n, m))

    M = M + kappa
    M = -np.log(M)
    return M
455 |
456 |
457 | # Inner Product Cost
def Inner_Product(X, Y):
    """Return the matrix of pairwise inner products between rows.

    Args:
        X: Array of shape ``(n, d)`` or a single vector of length ``d``.
        Y: Array of shape ``(m, d)`` or a single vector of length ``d``.

    Returns:
        Float array ``M`` of shape ``(n, m)`` with ``M[i, j] = <X[i], Y[j]>``.
    """
    X = np.asarray(X)
    Y = np.asarray(Y)
    if X.ndim == 1:
        X = X.reshape(1, -1)
    if Y.ndim == 1:
        Y = Y.reshape(1, -1)

    # Unpacking validates that both inputs are 2-D at this point.
    n, d = np.shape(X)
    m, d = np.shape(Y)

    # One matrix product replaces the element-by-element double loop.
    return np.asarray(np.dot(X, Y.T), dtype=np.float64)
472 |
--------------------------------------------------------------------------------