├── .gitignore
├── README.md
├── data
├── teaser.png
└── wheel.ply
├── example_basic.py
├── example_optimize.py
└── sinkhorn.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | .polyscope.ini
3 | imgui.ini
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Fast, Memory-Efficient Approximate Wasserstein Distances
2 | This repository contains PyTorch code to compute fast p-Wasserstein distances between d-dimensional point clouds using the [Sinkhorn Algorithm](https://arxiv.org/abs/1306.0895).
3 |
4 | This implementation has **linear memory overhead**, is **numerically stable in float32**, **runs on the GPU**, and is **fully differentiable**.
5 |
6 | The image below shows an example of the correspondences between two shapes found by computing the Sinkhorn distance on 200k input points:
7 |
8 | ![Correspondences between two shapes found with the Sinkhorn distance](data/teaser.png)
9 |
10 |
11 | ## How to use:
12 | 1. Copy the `sinkhorn.py` file in this repository to your PyTorch codebase.
13 | 2. `pip install pykeops tqdm`
14 | 3. Import `from sinkhorn import sinkhorn` and use the `sinkhorn` function!
15 |
16 | ### Running the example code
17 | Look at [example_basic.py](example_basic.py) for a basic example and at [example_optimize.py](example_optimize.py) for an example of how to use the Sinkhorn loss inside an optimization loop.
18 |
19 | **NOTE:** To run the examples, you need to first run
20 | ```
21 | pip install pykeops tqdm numpy scipy polyscope point-cloud-utils
22 | ```
23 |
24 | ## `sinkhorn` function documentation
25 | ```
26 | sinkhorn(x: torch.Tensor, y: torch.Tensor, p: float = 2,
27 | w_x: Union[torch.Tensor, None] = None,
28 | w_y: Union[torch.Tensor, None] = None,
29 | eps: float = 1e-3,
30 | max_iters: int = 100, stop_thresh: float = 1e-5,
31 | verbose=False)
32 | ```
33 | Computes the Entropy-Regularized p-Wasserstein Distance between two d-dimensional point clouds
34 | using the Sinkhorn scaling algorithm. This code will use the GPU if you pass in GPU tensors.
35 | Note that this algorithm can be backpropped through
36 | (though this may be slow if using many iterations).
37 |
38 | **Arguments:**
39 |
40 | * `x`: A `[n, d]` shaped tensor representing a d-dimensional point cloud with `n` points (one per row)
41 | * `y`: A `[m, d]` shaped tensor representing a d-dimensional point cloud with `m` points (one per row)
42 | * `p`: Which norm to use. Must be an integer greater than 0.
43 | * `w_x`: A `[n,]` shaped tensor of optional weights for the points `x` (`None` for uniform weights). Note that these must sum to the same value as w_y. Default is `None`.
44 | * `w_y`: A `[m,]` shaped tensor of optional weights for the points `y` (`None` for uniform weights). Note that these must sum to the same value as `w_x`. Default is `None`.
45 | * `eps`: The reciprocal of the Sinkhorn entropy regularization parameter.
46 | * `max_iters`: The maximum number of Sinkhorn iterations to perform.
47 | * `stop_thresh`: Stop if the maximum change in the parameters is below this amount
48 | * `verbose`: If set, print a progress bar
49 |
50 | **Returns:**
51 |
52 | A triple `(d, corrs_x_to_y, corrs_y_to_x)` where:
53 | * `d` is the approximate p-Wasserstein distance between point clouds `x` and `y`
54 | * `corrs_x_to_y` is a `[n,]`-shaped tensor where `corrs_x_to_y[i]` is the index of the approximate correspondence in point cloud `y` of point `x[i]` (i.e. `x[i]` and `y[corrs_x_to_y[i]]` are a corresponding pair)
55 | * `corrs_y_to_x` is a `[m,]`-shaped tensor where `corrs_y_to_x[j]` is the index of the approximate correspondence in point cloud `x` of point `y[j]` (i.e. `y[j]` and `x[corrs_y_to_x[j]]` are a corresponding pair)
56 |
57 |
--------------------------------------------------------------------------------
/data/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fwilliams/scalable-pytorch-sinkhorn/c453e660d57e6abc1e891722793d95622efdbd7e/data/teaser.png
--------------------------------------------------------------------------------
/data/wheel.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fwilliams/scalable-pytorch-sinkhorn/c453e660d57e6abc1e891722793d95622efdbd7e/data/wheel.ply
--------------------------------------------------------------------------------
/example_basic.py:
--------------------------------------------------------------------------------
1 | import point_cloud_utils as pcu
2 | import polyscope as ps
3 | from sinkhorn import sinkhorn
4 | import torch
5 | from scipy.spatial.transform import Rotation
6 | import time
7 | import numpy as np
8 |
9 |
if __name__ == "__main__":
    torch.manual_seed(1234567)  # Fixed seed
    np.random.seed(1234567)  # Rotation.random() below draws from numpy's global RNG

    niters = 500
    # Fall back to the CPU so the example also runs on machines without a CUDA device
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    dtype = torch.float32
    eps = 1e-3
    stop_error = 1e-3

    # Load a 3D point cloud with around 200k points and normalize it to [0, 1]^3
    print("Loading point cloud...")
    pc_load_start_time = time.perf_counter()
    p1 = torch.from_numpy(pcu.load_mesh_v("./data/wheel.ply")).to(device=device, dtype=dtype)
    p1 -= p1.min(dim=0)[0]
    p1 /= p1.max(dim=0)[0].max()
    print(f"Done in {time.perf_counter() - pc_load_start_time}s")

    # Create a second point cloud by applying a random rotation and translation to the first one
    R = torch.from_numpy(Rotation.random().as_matrix()).to(p1)
    p2 = p1 @ R.T + 1.5

    print("Running Sinkhorn...")
    sinkhorn_start_time = time.perf_counter()
    loss, corrs_1_to_2, corrs_2_to_1 = \
        sinkhorn(p1, p2, p=2, eps=eps, max_iters=niters, stop_thresh=stop_error, verbose=True)
    if device == 'cuda':
        # CUDA kernels run asynchronously; wait for them so the reported time is accurate
        torch.cuda.synchronize()
    print(f"Done in {time.perf_counter() - sinkhorn_start_time}s")

    print(f"Sinkhorn loss is {loss.item()}")

    ps.init()
    # Edge k connects p1[k] to its correspondence in p2 (p2's vertices are offset by len(p1))
    edges = torch.stack([torch.arange(p1.shape[0]).to(corrs_1_to_2),
                         corrs_1_to_2 + p1.shape[0]], dim=-1).cpu().numpy()
    verts = torch.cat([p1, p2], dim=0).cpu().numpy()
    p1 = p1.cpu().numpy()
    p2 = p2.cpu().numpy()
    ps.register_point_cloud("p1", p1)
    ps.register_point_cloud("p2", p2)
    ps.register_curve_network("corr1", verts, edges[::200])  # Only plot every 200th correspondence for easier viz
    ps.show()
51 |
52 |
--------------------------------------------------------------------------------
/example_optimize.py:
--------------------------------------------------------------------------------
1 | import point_cloud_utils as pcu
2 | import polyscope as ps
3 | from sinkhorn import sinkhorn
4 | import torch
5 | from scipy.spatial.transform import Rotation
6 | import time
7 | import numpy as np
8 |
9 |
if __name__ == "__main__":
    torch.manual_seed(1234567)  # Fixed seed
    np.random.seed(1234567)  # Rotation.random() below draws from numpy's global RNG

    niters = 500
    # Fall back to the CPU so the example also runs on machines without a CUDA device
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    dtype = torch.float32
    eps = 1e-3
    stop_error = 1e-5
    fast_loss = False

    # Load a 3D point cloud with around 200k points and normalize it to [0, 1]^3
    print("Loading point cloud...")
    pc_load_start_time = time.perf_counter()
    p1 = torch.from_numpy(pcu.load_mesh_v("./data/wheel.ply")).to(device=device, dtype=dtype)
    p1 -= p1.min(dim=0)[0]
    p1 /= p1.max(dim=0)[0].max()
    p1 = p1[::200].contiguous()  # Keep every 200th point so each optimization step is cheap
    print(f"Done in {time.perf_counter() - pc_load_start_time}s")

    # Create a second point cloud by applying a random rotation and translation to the first one.
    # p1 and R do not require grad, so p2 is a leaf tensor that the optimizer can update.
    R = torch.from_numpy(Rotation.random().as_matrix()).to(p1)
    p2 = (p1 @ R.T + 1.5).requires_grad_()

    timeframes = [p2.detach().cpu().numpy()]

    # Optimize the transformed point cloud to match the original point cloud p1 by minimizing the
    # Sinkhorn loss
    optimizer = torch.optim.Adam([p2], lr=1e-2)
    for epoch in range(101):
        optimizer.zero_grad()
        # With fast_loss the Sinkhorn result is only used for its correspondences, so don't
        # record the Sinkhorn iterations for autograd (saves memory and time)
        with torch.set_grad_enabled(not fast_loss):
            loss, corrs_1_to_2, corrs_2_to_1 = \
                sinkhorn(p1, p2, p=2, eps=eps, max_iters=niters, stop_thresh=stop_error, verbose=False)
        print(f"Sinkhorn loss is {loss.item()}")

        # For a faster approximate loss that doesn't need to backprop through Sinkhorn
        # I've found this is basically the same as using the sinkhorn loss directly
        if fast_loss:
            loss = ((p1 - p2[corrs_1_to_2]) ** 2.0).sum(-1).mean() + \
                   ((p2 - p1[corrs_2_to_1]) ** 2.0).sum(-1).mean()

        loss.backward()
        optimizer.step()

        if epoch % 50 == 0 and epoch != 0:
            timeframes.append(p2.detach().cpu().numpy())

    ps.init()
    p1 = p1.cpu().numpy()
    ps.register_point_cloud("p1", p1)
    for i, pi in enumerate(timeframes):
        ps.register_point_cloud(f"traj_p_{i}", pi)
    ps.show()
66 |
67 |
--------------------------------------------------------------------------------
/sinkhorn.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 |
3 | import pykeops.torch as keops
4 | import torch
5 |
6 | import tqdm
7 |
def _check_weights(w: torch.Tensor, pts: torch.Tensor,
                   w_name: str, pts_name: str, size_name: str) -> torch.Tensor:
    """
    Validate a user-supplied weight vector against the point cloud it belongs to.

    :param w: The weights passed by the caller: a [k,] tensor, or any tensor that squeezes to [k,]
              (e.g. [k, 1])
    :param pts: The [k, d] point cloud these weights belong to
    :param w_name: Name of the weight argument, used in error messages (e.g. "w_x")
    :param pts_name: Name of the point cloud argument, used in error messages (e.g. "x")
    :param size_name: Symbol for the number of points, used in error messages (e.g. "n")
    :return: The weights as a 1-D [k,] tensor
    :raises ValueError: If the weights are not 1-D after squeezing, or their length differs from
                        pts.shape[0]
    """
    if len(w.shape) > 1:
        w = w.squeeze()
        if len(w.shape) == 0:
            # A [1, 1] weight for a single-point cloud squeezes all the way down to a scalar
            w = w.reshape(1)
    if len(w.shape) != 1:
        raise ValueError(f"{w_name} must have shape [{size_name},] or [{size_name}, 1] "
                         f"where {pts_name}.shape = [{size_name}, d], "
                         f"but got {w_name}.shape = {w.shape}")
    if w.shape[0] != pts.shape[0]:
        raise ValueError(f"{w_name} must match the shape of {pts_name} in dimension 0 but got "
                         f"{pts_name}.shape = {pts.shape} and {w_name}.shape = {w.shape}")
    return w


def sinkhorn(x: torch.Tensor, y: torch.Tensor, p: float = 2,
             w_x: Union[torch.Tensor, None] = None,
             w_y: Union[torch.Tensor, None] = None,
             eps: float = 1e-3,
             max_iters: int = 100, stop_thresh: float = 1e-5,
             verbose=False):
    """
    Compute the Entropy-Regularized p-Wasserstein Distance between two d-dimensional point clouds
    using the Sinkhorn scaling algorithm. This code will use the GPU if you pass in GPU tensors.
    Note that this algorithm can be backpropped through
    (though this may be slow if using many iterations).

    The ground cost between two points is their p-norm distance ||x_i - y_j||_p (not raised to the
    power p). The Sinkhorn updates are performed on log-domain dual potentials u, v using
    logsumexp, and the [n, m] cost / transport-plan matrices are KeOps LazyTensors that are
    evaluated symbolically rather than stored in memory.

    :param x: A [n, d] tensor representing a d-dimensional point cloud with n points (one per row)
    :param y: A [m, d] tensor representing a d-dimensional point cloud with m points (one per row)
    :param p: Which norm to use. Must be an integer greater than 0.
    :param w_x: A [n,] (or [n, 1]) shaped tensor of optional weights for the points x (None for
                uniform weights summing to 1). If given, w_y must be given too, and the two must
                sum to the same value. Default is None.
    :param w_y: A [m,] (or [m, 1]) shaped tensor of optional weights for the points y (None for
                uniform weights summing to 1). If given, w_x must be given too, and the two must
                sum to the same value. Default is None.
    :param eps: The entropy regularization strength: the cost is divided by eps inside the
                exponential kernel (i.e. the reciprocal of the Sinkhorn regularization parameter).
                Must be > 0.
    :param max_iters: The maximum number of Sinkhorn iterations to perform. Must be an integer > 0.
    :param stop_thresh: Stop once the maximum absolute change of both dual potentials u and v
                        between two consecutive iterations is below this amount. Must be a float.
    :param verbose: If set, show a tqdm progress bar with the current max error
    :return: a triple (d, corrs_x_to_y, corrs_y_to_x) where:
      * d is the approximate p-wasserstein distance between point clouds x and y, computed as
        sum_ij P_ij * ||x_i - y_j||_p for the Sinkhorn transport plan P
      * corrs_x_to_y is a [n,]-shaped tensor where corrs_x_to_y[i] is the index of the approximate
        correspondence in point cloud y of point x[i] (i.e. x[i] and y[corrs_x_to_y[i]] are a
        corresponding pair); it is the argmax of row i of P
      * corrs_y_to_x is a [m,]-shaped tensor where corrs_y_to_x[j] is the index of the approximate
        correspondence in point cloud x of point y[j] (i.e. y[j] and x[corrs_y_to_x[j]] are a
        corresponding pair); it is the argmax of column j of P
    :raises TypeError: If p or max_iters is not an int, or stop_thresh is not a float
    :raises ValueError: If p, eps or max_iters is not positive, if x / y / w_x / w_y have invalid
                        shapes, if only one of w_x, w_y is given, or if the weights do not sum to
                        the same value
    """

    # ---- Argument validation (done before any KeOps work) ----
    if not isinstance(p, int):
        raise TypeError(f"p must be an integer greater than 0, got {p}")
    if p <= 0:
        raise ValueError(f"p must be an integer greater than 0, got {p}")

    if eps <= 0:
        raise ValueError("Entropy regularization term eps must be > 0")

    if not isinstance(max_iters, int):
        raise TypeError(f"max_iters must be an integer > 0, got {max_iters}")
    if max_iters <= 0:
        raise ValueError(f"max_iters must be an integer > 0, got {max_iters}")

    if not isinstance(stop_thresh, float):
        raise TypeError(f"stop_thresh must be a float, got {stop_thresh}")

    if len(x.shape) != 2:
        raise ValueError(f"x must be an [n, d] tensor but got shape {x.shape}")
    if len(y.shape) != 2:
        raise ValueError(f"y must be an [m, d] tensor but got shape {y.shape}")
    if x.shape[1] != y.shape[1]:
        raise ValueError(f"x and y must match in the last dimension (i.e. x.shape=[n, d], "
                         f"y.shape=[m, d]) but got x.shape = {x.shape}, y.shape={y.shape}")

    if w_x is not None and w_y is None:
        raise ValueError("If w_x is not None, w_y must also be not None")
    if w_y is not None and w_x is None:
        raise ValueError("If w_y is not None, w_x must also be not None")

    # ---- Weights [n,] and [m,] ----
    if w_x is None:
        # Uniform weights that each sum to 1, so the two marginals are balanced for any n and m
        w_x = torch.ones(x.shape[0]).to(x) / x.shape[0]
        w_y = torch.ones(y.shape[0]).to(x) / y.shape[0]
    else:
        w_x = _check_weights(w_x, x, "w_x", "x", "n")
        w_y = _check_weights(w_y, y, "w_y", "y", "m")

    sum_w_x = w_x.sum().item()
    sum_w_y = w_y.sum().item()
    if abs(sum_w_x - sum_w_y) > 1e-5:
        raise ValueError(f"Weights w_x and w_y do not sum to the same value, "
                         f"got w_x.sum() = {sum_w_x} and w_y.sum() = {sum_w_y} "
                         f"(absolute difference = {abs(sum_w_x - sum_w_y)}")

    # ---- Cost matrix M_ij = ||x_i - y_j||_p, shape [n, m] ----
    x_i = keops.Vi(x)  # [n, 1, d]
    y_j = keops.Vj(y)  # [1, m, d]
    diff_ij = x_i - y_j  # [n, m, d]
    if p % 2 == 1:
        # Odd powers keep the sign of each coordinate difference, so take |.| first
        # (for even p the power already makes every term non-negative)
        diff_ij = diff_ij.abs()
    if p == 1:
        M_ij = diff_ij.sum(dim=2)  # [n, m]
    else:
        M_ij = (diff_ij ** p).sum(dim=2) ** (1.0 / p)  # [n, m]

    log_a = torch.log(w_x)  # [n]
    log_b = torch.log(w_y)  # [m]

    # Initialize the iteration with the change of variable
    u = torch.zeros_like(w_x)
    v = eps * log_b

    u_i = keops.Vi(u.unsqueeze(-1))
    v_j = keops.Vj(v.unsqueeze(-1))

    pbar = tqdm.trange(max_iters) if verbose else range(max_iters)

    for _ in pbar:
        u_prev = u
        v_prev = v

        # squeeze(-1) (not squeeze()) so that n == 1 or m == 1 still yields a 1-D tensor
        summand_u = (-M_ij + v_j) / eps
        u = eps * (log_a - summand_u.logsumexp(dim=1).squeeze(-1))
        u_i = keops.Vi(u.unsqueeze(-1))

        summand_v = (-M_ij + u_i) / eps
        v = eps * (log_b - summand_v.logsumexp(dim=0).squeeze(-1))
        v_j = keops.Vj(v.unsqueeze(-1))

        max_err_u = torch.max(torch.abs(u_prev - u))
        max_err_v = torch.max(torch.abs(v_prev - v))
        if verbose:
            pbar.set_postfix({"Current Max Error": max(max_err_u, max_err_v).item()})
        if max_err_u < stop_thresh and max_err_v < stop_thresh:
            break

    if verbose:
        pbar.close()

    # Entropic transport plan P_ij = exp((u_i + v_j - M_ij) / eps), shape [n, m]
    P_ij = ((-M_ij + u_i + v_j) / eps).exp()

    approx_corr_1 = P_ij.argmax(dim=1).squeeze(-1)  # [n]
    approx_corr_2 = P_ij.argmax(dim=0).squeeze(-1)  # [m]

    # Total cost sum_ij P_ij * M_ij: KeOps reduces over one axis, torch sums the remaining vector
    if u.shape[0] > v.shape[0]:
        distance = (P_ij * M_ij).sum(dim=1).sum()
    else:
        distance = (P_ij * M_ij).sum(dim=0).sum()
    return distance, approx_corr_1, approx_corr_2
148 |
149 |
--------------------------------------------------------------------------------