├── .gitignore
├── README.md
├── data
│   ├── teaser.png
│   └── wheel.ply
├── example_basic.py
├── example_optimize.py
└── sinkhorn.py

/.gitignore:
--------------------------------------------------------------------------------
__pycache__
.polyscope.ini
imgui.ini
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Fast, Memory-Efficient Approximate Wasserstein Distances
This repository contains PyTorch code to compute fast p-Wasserstein distances between d-dimensional point clouds using the [Sinkhorn Algorithm](https://arxiv.org/abs/1306.0895).

This implementation uses **linear memory overhead** and is **stable in float32, runs on the GPU, and fully differentiable**.

This shows an example of the correspondences between two shapes found by computing the Sinkhorn distance on 200k input points:

![Correspondences between two shapes found by computing the Sinkhorn distance on 200k input points](data/teaser.png)

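To see why running the updates in the log domain matters for float32 stability, consider the standard (non-log-domain) Sinkhorn iteration, which scales by the Gibbs kernel `exp(-M / eps)`. With `eps = 1e-3` and costs of order 1 (as for point clouds normalized to [0, 1]^3), every kernel entry underflows to exactly zero in float32, while the `logsumexp`-based update used in `sinkhorn.py` stays finite. The NumPy snippet below is a toy 2×2 illustration, not code from this repository; the cost matrix `M` is made up and `logsumexp` is a hand-written helper using the usual max-shift trick.

```python
import numpy as np


def logsumexp(z: np.ndarray, axis: int) -> np.ndarray:
    """Numerically stable log(sum(exp(z))) along `axis` (max-shift trick)."""
    z_max = z.max(axis=axis, keepdims=True)
    shifted_sum = np.exp(z - z_max).sum(axis=axis, keepdims=True)
    return np.squeeze(z_max + np.log(shifted_sum), axis=axis)


eps = np.float32(1e-3)                              # same default eps as sinkhorn()
M = np.array([[0.5, 1.0],
              [1.0, 0.5]], dtype=np.float32)        # toy [n, m] cost matrix (made up)
log_a = np.log(np.full(2, 0.5, dtype=np.float32))   # log of uniform weights on x
v = np.zeros(2, dtype=np.float32)                   # current dual potential for y

# Standard Sinkhorn scales with K = exp(-M / eps): every entry underflows to 0.0 in float32,
# so a kernel-space update such as a / (K @ b) would divide by zero.
K = np.exp(-M / eps)

# Log-domain update for u, mirroring `u = eps * (log_a - logsumexp((-M + v) / eps))` in sinkhorn.py
u = eps * (log_a - logsumexp((-M + v[None, :]) / eps, axis=1))

print(K)  # all zeros
print(u)  # finite, roughly 0.4993 per entry
```
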
## How to use:
1. Copy the `sinkhorn.py` file in this repository to your PyTorch codebase.
2. Install the dependencies: `pip install pykeops tqdm`
3. Import it with `from sinkhorn import sinkhorn` and use the `sinkhorn` function!

### Running the example code
Look at [example_basic.py](example_basic.py) for a basic example and [example_optimize.py](example_optimize.py) for an example of how to use Sinkhorn in your optimization.

**NOTE:** To run the examples, you first need to install their extra dependencies:
```
pip install pykeops tqdm numpy scipy polyscope point-cloud-utils
```

## `sinkhorn` function documentation
```
sinkhorn(x: torch.Tensor, y: torch.Tensor, p: int = 2,
         w_x: Union[torch.Tensor, None] = None,
         w_y: Union[torch.Tensor, None] = None,
         eps: float = 1e-3,
         max_iters: int = 100, stop_thresh: float = 1e-5,
         verbose=False)
```
Computes the Entropy-Regularized p-Wasserstein Distance between two d-dimensional point clouds
using the Sinkhorn scaling algorithm. This code will use the GPU if you pass in GPU tensors.
Note that this algorithm can be backpropped through
(though this may be slow if using many iterations).

**Arguments:**

* `x`: A `[n, d]` shaped tensor representing a d-dimensional point cloud with `n` points (one per row)
* `y`: A `[m, d]` shaped tensor representing a d-dimensional point cloud with `m` points (one per row)
* `p`: Which norm to use (the cost between two points is their `p`-norm distance). Must be an integer greater than 0.
* `w_x`: A `[n,]` shaped tensor of optional weights for the points `x` (`None` for uniform weights). These must sum to the same value as `w_y`, and if `w_x` is given, `w_y` must be given too. Default is `None`.
* `w_y`: A `[m,]` shaped tensor of optional weights for the points `y` (`None` for uniform weights). These must sum to the same value as `w_x`, and if `w_y` is given, `w_x` must be given too. Default is `None`.
* `eps`: The reciprocal of the Sinkhorn entropy regularization parameter.
* `max_iters`: The maximum number of Sinkhorn iterations to perform.
* `stop_thresh`: Stop early once the maximum change in the dual potentials between two iterations is below this amount. Must be a Python `float`.
* `verbose`: If set, print a progress bar.

**Returns:**

A triple `(d, corrs_x_to_y, corrs_y_to_x)` where:
* `d` is the approximate p-Wasserstein distance between point clouds `x` and `y`
* `corrs_x_to_y` is a `[n,]`-shaped tensor where `corrs_x_to_y[i]` is the index of the approximate correspondence in point cloud `y` of point `x[i]` (i.e. `x[i]` and `y[corrs_x_to_y[i]]` are a corresponding pair)
* `corrs_y_to_x` is a `[m,]`-shaped tensor where `corrs_y_to_x[j]` is the index of the approximate correspondence in point cloud `x` of point `y[j]` (i.e. `y[j]` and `x[corrs_y_to_x[j]]` are a corresponding pair)
--------------------------------------------------------------------------------
/data/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fwilliams/scalable-pytorch-sinkhorn/c453e660d57e6abc1e891722793d95622efdbd7e/data/teaser.png
--------------------------------------------------------------------------------
/data/wheel.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fwilliams/scalable-pytorch-sinkhorn/c453e660d57e6abc1e891722793d95622efdbd7e/data/wheel.ply
--------------------------------------------------------------------------------
/example_basic.py:
--------------------------------------------------------------------------------
import point_cloud_utils as pcu
import polyscope as ps
from sinkhorn import sinkhorn
import torch
from scipy.spatial.transform import Rotation
import time
import numpy as np


if __name__ == "__main__":
    torch.manual_seed(1234567)  # Fixed seed
    np.random.seed(1234567)

    niters = 500
    device = 'cuda'
    dtype = torch.float32
    eps = 1e-3
    stop_error = 1e-3

    # Load a 3D point cloud with around 200k points and normalize it to [0, 1]^3
    print("Loading point cloud...")
    pc_load_start_time = time.time()
    p1 = torch.from_numpy(pcu.load_mesh_v("./data/wheel.ply")).to(device=device, dtype=dtype)
    p1 -= p1.min(dim=0)[0]
    p1 /= p1.max(dim=0)[0].max()
    print(f"Done in {time.time() - pc_load_start_time}s")

    # Create a second point cloud by applying a random rotation and translation to the first one
    R = torch.from_numpy(Rotation.random().as_matrix()).to(p1)
    p2 = p1 @ R.T + 1.5

    print("Running Sinkhorn...")
    sinkhorn_start_time = time.time()
    loss, corrs_1_to_2, corrs_2_to_1 = \
        sinkhorn(p1, p2, p=2, eps=eps, max_iters=niters, stop_thresh=stop_error, verbose=True)
    torch.cuda.synchronize()  # Wait for GPU work to finish so the timing is accurate
    print(f"Done in {time.time() - sinkhorn_start_time}s")

    print(f"Sinkhorn loss is {loss.item()}")

    # Visualize both point clouds and a subset of the correspondences as line segments
    ps.init()
    edges = torch.stack([torch.arange(p1.shape[0]).to(corrs_1_to_2),
                         corrs_1_to_2 + p1.shape[0]], dim=-1).cpu().numpy()
    verts = torch.cat([p1, p2], dim=0).cpu().numpy()
    p1 = p1.cpu().numpy()
    p2 = p2.cpu().numpy()
    ps.register_point_cloud("p1", p1)
    ps.register_point_cloud("p2", p2)
    ps.register_curve_network("corr1", verts, edges[::200])  # Only plot every 200th correspondence for easier viz
    ps.show()
--------------------------------------------------------------------------------
/example_optimize.py:
--------------------------------------------------------------------------------
import point_cloud_utils as pcu
import polyscope as ps
from sinkhorn import sinkhorn
import torch
from scipy.spatial.transform import Rotation
import time
import numpy as np


if __name__ == "__main__":
    torch.manual_seed(1234567)  # Fixed seed
    np.random.seed(1234567)

    niters = 500
    device = 'cuda'
    dtype = torch.float32
    eps = 1e-3
    stop_error = 1e-5
    fast_loss = False

    # Load a 3D point cloud with around 200k points and normalize it to [0, 1]^3
    print("Loading point cloud...")
    pc_load_start_time = time.time()
    p1 = torch.from_numpy(pcu.load_mesh_v("./data/wheel.ply")).to(device=device, dtype=dtype)
    p1 -= p1.min(dim=0)[0]
    p1 /= p1.max(dim=0)[0].max()
    p1 = p1[::200].contiguous()  # Keep every 200th point so each optimization step stays cheap
    print(f"Done in {time.time() - pc_load_start_time}s")

    # Create a second point cloud by applying a random rotation and translation to the first one
    R = torch.from_numpy(Rotation.random().as_matrix()).to(p1)
    p2 = p1 @ R.T + 1.5
    p2.requires_grad = True

    timeframes = [p2.detach().cpu().numpy()]

    # Optimize the transformed point cloud p2 to match the original point cloud p1 by minimizing
    # the Sinkhorn loss
    optimizer = torch.optim.Adam([p2], lr=1e-2)
    for epoch in range(101):
        optimizer.zero_grad()
        loss, corrs_1_to_2, corrs_2_to_1 = \
            sinkhorn(p1, p2, p=2, eps=eps, max_iters=niters, stop_thresh=stop_error, verbose=False)
        print(f"Sinkhorn loss is {loss.item()}")

        # For a faster approximate loss that doesn't need to backprop through Sinkhorn,
        # penalize the squared distances between the Sinkhorn correspondences instead.
        # I've found this is basically the same as using the Sinkhorn loss directly.
        if fast_loss:
            loss = ((p1 - p2[corrs_1_to_2]) ** 2.0).sum(-1).mean() + \
                ((p2 - p1[corrs_2_to_1]) ** 2.0).sum(-1).mean()

        loss.backward()
        optimizer.step()

        # Record snapshots of p2 at epochs 50 and 100 for visualization
        if epoch % 50 == 0 and epoch != 0:
            timeframes.append(p2.detach().cpu().numpy())

    ps.init()
    p1 = p1.cpu().numpy()
    p2 = p2.detach().cpu().numpy()
    ps.register_point_cloud("p1", p1)
    for i in range(len(timeframes)):
        pi = timeframes[i]
        ps.register_point_cloud(f"traj_p_{i}", pi)
    ps.show()
--------------------------------------------------------------------------------
/sinkhorn.py:
--------------------------------------------------------------------------------
from typing import Union

import pykeops.torch as keops
import torch

import tqdm

def sinkhorn(x: torch.Tensor, y: torch.Tensor, p: int = 2,
             w_x: Union[torch.Tensor, None] = None,
             w_y: Union[torch.Tensor, None] = None,
             eps: float = 1e-3,
             max_iters: int = 100, stop_thresh: float = 1e-5,
             verbose=False):
    """
    Compute the Entropy-Regularized p-Wasserstein Distance between two d-dimensional point clouds
    using the Sinkhorn scaling algorithm. This code will use the GPU if you pass in GPU tensors.
    Note that this algorithm can be backpropped through
    (though this may be slow if using many iterations).

    :param x: A [n, d] tensor representing a d-dimensional point cloud with n points (one per row)
    :param y: A [m, d] tensor representing a d-dimensional point cloud with m points (one per row)
    :param p: Which norm to use (the cost between two points is their p-norm distance).
              Must be an integer greater than 0.
    :param w_x: A [n,] shaped tensor of optional weights for the points x (None for uniform weights).
                These must sum to the same value as w_y, and w_y must be given if w_x is. Default is None.
    :param w_y: A [m,] shaped tensor of optional weights for the points y (None for uniform weights).
                These must sum to the same value as w_x, and w_x must be given if w_y is. Default is None.
    :param eps: The reciprocal of the Sinkhorn entropy regularization parameter.
    :param max_iters: The maximum number of Sinkhorn iterations to perform.
    :param stop_thresh: Stop once the maximum change in the dual potentials is below this amount (a float)
    :param verbose: Print a progress bar
    :return: a triple (d, corrs_x_to_y, corrs_y_to_x) where:
      * d is the approximate p-Wasserstein distance between point clouds x and y
      * corrs_x_to_y is a [n,]-shaped tensor where corrs_x_to_y[i] is the index of the approximate
        correspondence in point cloud y of point x[i] (i.e. x[i] and y[corrs_x_to_y[i]] are a corresponding pair)
      * corrs_y_to_x is a [m,]-shaped tensor where corrs_y_to_x[j] is the index of the approximate
        correspondence in point cloud x of point y[j] (i.e. y[j] and x[corrs_y_to_x[j]] are a corresponding pair)
    """

    if not isinstance(p, int):
        raise TypeError(f"p must be an integer greater than 0, got {p}")
    if p <= 0:
        raise ValueError(f"p must be an integer greater than 0, got {p}")

    if eps <= 0:
        raise ValueError("Entropy regularization term eps must be > 0")

    if not isinstance(max_iters, int):
        raise TypeError(f"max_iters must be an integer > 0, got {max_iters}")
    if max_iters <= 0:
        raise ValueError(f"max_iters must be an integer > 0, got {max_iters}")

    if not isinstance(stop_thresh, float):
        raise TypeError(f"stop_thresh must be a float, got {stop_thresh}")

    if len(x.shape) != 2:
        raise ValueError(f"x must be an [n, d] tensor but got shape {x.shape}")
    if len(y.shape) != 2:
        raise ValueError(f"y must be an [m, d] tensor but got shape {y.shape}")
    if x.shape[1] != y.shape[1]:
        raise ValueError(f"x and y must match in the last dimension (i.e. x.shape=[n, d], "
                         f"y.shape=[m, d]) but got x.shape = {x.shape}, y.shape={y.shape}")

    if w_x is not None:
        if w_y is None:
            raise ValueError("If w_x is not None, w_y must also be not None")
        if len(w_x.shape) > 1:
            w_x = w_x.squeeze()
        if len(w_x.shape) != 1:
            raise ValueError(f"w_x must have shape [n,] or [n, 1] "
                             f"where x.shape = [n, d], but got w_x.shape = {w_x.shape}")
        if w_x.shape[0] != x.shape[0]:
            raise ValueError(f"w_x must match the shape of x in dimension 0 but got "
                             f"x.shape = {x.shape} and w_x.shape = {w_x.shape}")
    if w_y is not None:
        if w_x is None:
            raise ValueError("If w_y is not None, w_x must also be not None")
        if len(w_y.shape) > 1:
            w_y = w_y.squeeze()
        if len(w_y.shape) != 1:
            raise ValueError(f"w_y must have shape [m,] or [m, 1] "
                             f"where y.shape = [m, d], but got w_y.shape = {w_y.shape}")
        if w_y.shape[0] != y.shape[0]:
            raise ValueError(f"w_y must match the shape of y in dimension 0 but got "
                             f"y.shape = {y.shape} and w_y.shape = {w_y.shape}")

    # Distance matrix [n, m], kept symbolic by KeOps (never materialized in memory)
    x_i = keops.Vi(x)  # [n, 1, d]
    y_j = keops.Vj(y)  # [1, m, d]
    if p == 1:
        M_ij = (x_i - y_j).abs().sum(dim=2)  # [n, m]
    else:
        # abs() keeps odd powers (e.g. p = 3) non-negative before taking the p-th root
        M_ij = ((x_i - y_j).abs() ** p).sum(dim=2) ** (1.0 / p)  # [n, m]

    # Weights [n,] and [m,]
    if w_x is None and w_y is None:
        # Uniform weights, each summing to 1
        w_x = torch.ones(x.shape[0]).to(x) / x.shape[0]
        w_y = torch.ones(y.shape[0]).to(x) / y.shape[0]

    sum_w_x = w_x.sum().item()
    sum_w_y = w_y.sum().item()
    if abs(sum_w_x - sum_w_y) > 1e-5:
        raise ValueError(f"Weights w_x and w_y do not sum to the same value, "
                         f"got w_x.sum() = {sum_w_x} and w_y.sum() = {sum_w_y} "
                         f"(absolute difference = {abs(sum_w_x - sum_w_y)})")

    log_a = torch.log(w_x)  # [n]
    log_b = torch.log(w_y)  # [m]

    # Initialize the iteration with the change of variable
    u = torch.zeros_like(w_x)
    v = eps * torch.log(w_y)

    u_i = keops.Vi(u.unsqueeze(-1))
    v_j = keops.Vj(v.unsqueeze(-1))

    if verbose:
        pbar = tqdm.trange(max_iters)
    else:
        pbar = range(max_iters)

    for _ in pbar:
        u_prev = u
        v_prev = v

        # Log-domain update of u: logsumexp over j
        summand_u = (-M_ij + v_j) / eps
        u = eps * (log_a - summand_u.logsumexp(dim=1).squeeze())
        u_i = keops.Vi(u.unsqueeze(-1))

        # Log-domain update of v: logsumexp over i
        summand_v = (-M_ij + u_i) / eps
        v = eps * (log_b - summand_v.logsumexp(dim=0).squeeze())
        v_j = keops.Vj(v.unsqueeze(-1))

        max_err_u = torch.max(torch.abs(u_prev - u))
        max_err_v = torch.max(torch.abs(v_prev - v))
        if verbose:
            pbar.set_postfix({"Current Max Error": max(max_err_u, max_err_v).item()})
        if max_err_u < stop_thresh and max_err_v < stop_thresh:
            break

    # Symbolic transport plan P_ij = exp((-M_ij + u_i + v_j) / eps)
    P_ij = ((-M_ij + u_i + v_j) / eps).exp()

    approx_corr_1 = P_ij.argmax(dim=1).squeeze(-1)
    approx_corr_2 = P_ij.argmax(dim=0).squeeze(-1)

    if u.shape[0] > v.shape[0]:
        distance = (P_ij * M_ij).sum(dim=1).sum()
    else:
        distance = (P_ij * M_ij).sum(dim=0).sum()
    return distance, approx_corr_1, approx_corr_2
--------------------------------------------------------------------------------
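For debugging on small inputs, it can help to compare `sinkhorn` against a dense implementation of the same log-domain updates. The sketch below is not part of this repository: `sinkhorn_dense` is a hypothetical NumPy function that assumes uniform weights (`1/n` and `1/m`), works in float64, and builds the full `[n, m]` cost matrix, so it gives up the linear memory overhead of the KeOps version and is only practical for small point clouds.

```python
import numpy as np


def _logsumexp(z: np.ndarray, axis: int) -> np.ndarray:
    """Numerically stable log(sum(exp(z))) along `axis`."""
    z_max = z.max(axis=axis, keepdims=True)
    return np.squeeze(z_max + np.log(np.exp(z - z_max).sum(axis=axis, keepdims=True)), axis=axis)


def sinkhorn_dense(x: np.ndarray, y: np.ndarray, p: int = 2, eps: float = 1e-3,
                   max_iters: int = 100, stop_thresh: float = 1e-5):
    """Hypothetical dense reference: O(n*m) memory, uniform weights, cost M_ij = ||x_i - y_j||_p."""
    n, m = x.shape[0], y.shape[0]
    # Full [n, m] cost matrix (the KeOps version keeps this symbolic instead)
    M = (np.abs(x[:, None, :] - y[None, :, :]) ** p).sum(axis=-1) ** (1.0 / p)
    log_a = np.full(n, -np.log(n))  # log of uniform weights 1/n
    log_b = np.full(m, -np.log(m))  # log of uniform weights 1/m
    u = np.zeros(n)
    v = eps * log_b  # same initialization as sinkhorn.py
    for _ in range(max_iters):
        u_prev, v_prev = u, v
        u = eps * (log_a - _logsumexp((-M + v[None, :]) / eps, axis=1))
        v = eps * (log_b - _logsumexp((-M + u[:, None]) / eps, axis=0))
        if max(np.abs(u - u_prev).max(), np.abs(v - v_prev).max()) < stop_thresh:
            break
    P = np.exp((-M + u[:, None] + v[None, :]) / eps)  # entropic transport plan
    distance = float((P * M).sum())
    return distance, P.argmax(axis=1), P.argmax(axis=0)


# Usage: y is x shifted by (0.1, 0.1), so each point should be matched to its shifted copy
x = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
d, corrs_x_to_y, corrs_y_to_x = sinkhorn_dense(x, x + 0.1)
print(d, corrs_x_to_y, corrs_y_to_x)
```

Since both versions share the same updates and initialization, one would expect `sinkhorn` on the same inputs (as torch tensors) to give a close distance and the same correspondences, up to float32 rounding; treat this as a sanity check rather than a guarantee.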