├── README.md ├── matlab └── w2.c └── python ├── example.ipynb ├── example.py ├── pyproject.toml ├── setup.py └── src └── main.cpp /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # The back-and-forth method in optimal transport 4 | 5 | This repository contains the source code used in the paper [A fast approach to optimal transport: The back-and-forth method](https://arxiv.org/pdf/1905.12154.pdf) [1]. The original code was written in C and we provide here a Python and a MATLAB wrapper to the C code. 6 | 7 | 8 | 9 | # Documentation 10 | Available here: . 11 | 12 | 13 | # Python 14 | 15 | ## Installation 16 | 17 | The simplest way to use the Python code is to [run this notebook on Google Colab](https://colab.research.google.com/drive/1Uml2n4MIVDZnviEHEMFrJIdMwYDOPHax?usp=sharing). 18 | 19 | The notebook is also available here as `example.ipynb`. 20 | 21 | Alternatively, to install the Python bindings on your machine, first clone the GitHub repository and then install the Python bindings by running 22 | ``` 23 | pip install ./bfm/python 24 | ``` 25 | 26 | ## Usage 27 | See the Jupyter notebook `example.ipynb` or directly run `example.py`. 28 | 29 | 30 | 31 | 32 | # MATLAB 33 | 34 | ## Installation 35 | 36 | Requirements: FFTW ([download here](http://www.fftw.org/)), MATLAB. 37 | 38 | Download the C MEX file `w2.c` [here](https://raw.githubusercontent.com/Math-Jacobs/bfm/main/matlab/w2.c) or clone the GitHub repository and navigate to the `matlab/` folder. 39 | 40 | Compilation: in a MATLAB session run 41 | ```matlab 42 | mex -O CFLAGS="\$CFLAGS -std=c99" -lfftw3 -lm w2.c 43 | ``` 44 | This will produce a MEX function `w2` that you can use in MATLAB. You may need to use flags `-I` and `-L` to link to the FFTW3 library, e.g. `mex -O CFLAGS="\$CFLAGS -std=c99" w2.c -lfftw3 -I/usr/local/include`. 
See [this page](https://www.mathworks.com/help/matlab/matlab_external/build-an-executable-mex-file.html) for more information on how to compile MEX files. 45 | 46 | 47 | 48 | ## Usage 49 | 50 | In a MATLAB session, run the command 51 | ```matlab 52 | [phi, psi] = w2(mu, nu, numIters, sigma); 53 | ``` 54 | 55 | Input: 56 | 57 | * `mu` and `nu` are two arrays of nonnegative values which sum up to the same value. 58 | * `numIters` is the total number of iterations. 59 | * `sigma` is the initial step size of the gradient ascent iterations. 60 | 61 | Output: 62 | 63 | * `phi` and `psi` are arrays corresponding to the Kantorovich potentials. 64 | 65 | 66 | 67 | 68 | 69 | # References 70 | 71 | 72 | [1] Matt Jacobs and Flavien Léger. [A fast approach to optimal transport: The back-and-forth method](https://arxiv.org/pdf/1905.12154.pdf). *Numerische Mathematik* (2020): 1-32. 73 | 74 | 75 | 76 | 77 | -------------------------------------------------------------------------------- /matlab/w2.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include "mex.h" 8 | 9 | 10 | 11 | 12 | 13 | double compute_l2_ot(double *mu, double *nu, double *phi, double *dual, double totalMass, double sigma, int maxIters, int n1, int n2); 14 | 15 | 16 | 17 | void mexFunction( int nlhs, mxArray *plhs[], 18 | int nrhs, const mxArray *prhs[]){ 19 | 20 | 21 | double *mu=mxGetPr(prhs[0]); 22 | double *nu=mxGetPr(prhs[1]); 23 | int maxIters=(int) mxGetScalar(prhs[2]); 24 | double sigma =(double) mxGetScalar(prhs[3]); 25 | 26 | int n1=mxGetM(prhs[0]); 27 | int n2=mxGetN(prhs[0]); 28 | 29 | int pcount=n1*n2; 30 | 31 | plhs[0] = mxCreateDoubleMatrix(n1,n2,mxREAL); 32 | plhs[1] = mxCreateDoubleMatrix(n1,n2,mxREAL); 33 | 34 | double *phi=mxGetPr(plhs[0]); 35 | double *psi=mxGetPr(plhs[1]); 36 | 37 | 38 | double sum=0; 39 | 40 | for(int i=0;i0)-(x<0); 148 | return truth; 149 | 150 | } 151 | 152 
| 153 | void init_hull(convex_hull *hull, int n){ 154 | hull->indices=calloc(n,sizeof(double)); 155 | hull->hullCount=0; 156 | 157 | } 158 | 159 | void destroy_hull(convex_hull *hull){ 160 | free(hull->indices); 161 | } 162 | 163 | void transpose_doubles(double *transpose, double *data, int n1, int n2){ 164 | 165 | for(int i=0;ihullCount<2){ 206 | hull->indices[1]=i; 207 | hull->hullCount++; 208 | }else{ 209 | int hc=hull->hullCount; 210 | int ic1=hull->indices[hc-1]; 211 | int ic2=hull->indices[hc-2]; 212 | 213 | double oldSlope=(u[ic1]-u[ic2])/(ic1-ic2); 214 | double slope=(u[i]-u[ic1])/(i-ic1); 215 | 216 | if(slope>=oldSlope){ 217 | int hc=hull->hullCount; 218 | hull->indices[hc]=i; 219 | hull->hullCount++; 220 | }else{ 221 | hull->hullCount--; 222 | add_point(u, hull, i); 223 | } 224 | } 225 | } 226 | 227 | 228 | void get_convex_hull(double *u, convex_hull *hull, int n){ 229 | 230 | hull->indices[0]=0; 231 | hull->indices[1]=1; 232 | hull->hullCount=2; 233 | 234 | for(int i=2;ihullCount; 247 | 248 | for(int i=0;iindices[counter]; 252 | int ic2=hull->indices[counter-1]; 253 | 254 | double slope=n*(u[ic1]-u[ic2])/(ic1-ic2); 255 | while(s>slope&&counterindices[counter]; 258 | ic2=hull->indices[counter-1]; 259 | slope=n*(u[ic1]-u[ic2])/(ic1-ic2); 260 | } 261 | dualIndicies[i]=hull->indices[counter-1]; 262 | 263 | } 264 | } 265 | 266 | 267 | void compute_dual(double *dual, double *u, int *dualIndicies, convex_hull *hull, int n){ 268 | 269 | get_convex_hull(u, hull, n); 270 | 271 | 272 | compute_dual_indices(dualIndicies, u, hull, n); 273 | 274 | for(int i=0;iv2){ 281 | dual[i]=v1; 282 | }else{ 283 | dualIndicies[i]=n-1; 284 | dual[i]=v2; 285 | } 286 | 287 | } 288 | 289 | } 290 | 291 | 292 | 293 | 294 | void compute_2d_dual(double *dual, double *u, convex_hull *hull, int n1, int n2){ 295 | 296 | int pcount=n1*n2; 297 | 298 | int n=fmax(n1,n2); 299 | 300 | int *argmin=calloc(n,sizeof(int)); 301 | 302 | double *temp=calloc(pcount,sizeof(double)); 303 | 304 | 
memcpy(temp, u, pcount*sizeof(double)); 305 | 306 | 307 | for(int i=0;i0){ 387 | 388 | double xStretch0=fabs(xMap[i*(n1+1)+j+1]-xMap[i*(n1+1)+j]); 389 | double xStretch1=fabs(xMap[(i+1)*(n1+1)+j+1]-xMap[(i+1)*(n1+1)+j]); 390 | 391 | double yStretch0=fabs(yMap[(i+1)*(n1+1)+j]-yMap[i*(n1+1)+j]); 392 | double yStretch1=fabs(yMap[(i+1)*(n1+1)+j+1]-yMap[i*(n1+1)+j+1]); 393 | 394 | double xStretch=fmax(xStretch0, xStretch1); 395 | double yStretch=fmax(yStretch0, yStretch1); 396 | 397 | int xSamples=2*fmax(n1*xStretch,1); 398 | int ySamples=2*fmax(n2*yStretch,1); 399 | 400 | if(xStretchgradSq*sigma*upper){ 530 | return sigma*scaleUp; 531 | }else if(diff gradSq * sigma * upper: 66 | return sigma * scaleUp 67 | elif diff < gradSq * sigma * lower: 68 | return sigma * scaleDown 69 | return sigma 70 | 71 | 72 | # Back-and-forth solver 73 | def compute_ot(phi, psi, bf, sigma): 74 | 75 | kernel = initialize_kernel(n1, n2) 76 | rho = np.copy(mu) 77 | 78 | oldValue = compute_w2(phi, psi, mu, nu, x, y) 79 | 80 | for k in range(numIters+1): 81 | 82 | gradSq = update_potential(phi, rho, nu, kernel, sigma) 83 | 84 | bf.ctransform(psi, phi) 85 | bf.ctransform(phi, psi) 86 | 87 | value = compute_w2(phi, psi, mu, nu, x, y) 88 | sigma = stepsize_update(sigma, value, oldValue, gradSq) 89 | oldValue = value 90 | 91 | bf.pushforward(rho, phi, nu) 92 | 93 | gradSq = update_potential(psi, rho, mu, kernel, sigma) 94 | 95 | bf.ctransform(phi, psi) 96 | bf.ctransform(psi, phi) 97 | 98 | bf.pushforward(rho, psi, mu) 99 | 100 | value = compute_w2(phi, psi, mu, nu, x, y) 101 | sigma = stepsize_update(sigma, value, oldValue, gradSq) 102 | oldValue = value 103 | 104 | if k % 5 == 0: 105 | print(f'iter {k:4d}, W2 value: {value:.6e}, H1 err: {gradSq:.2e}') 106 | 107 | # %% Example: Caffarelli's counterexample 108 | 109 | # Caffarelli's counterexample illustrates that the optimal map can be discontinous when the target domain is nonconvex. 110 | # Reference: Luis A. Caffarelli. 
The regularity of mappings with a convex potential. J. Amer. Math. Soc. 5, 1 (1992), 99–104. 111 | 112 | 113 | # Define the problem data and initial values 114 | 115 | # Grid of size n1 x n2 116 | n1 = 1024 # x axis 117 | n2 = 1024 # y axis 118 | 119 | x, y = np.meshgrid(np.linspace(0.5/n1,1-0.5/n1,n1), np.linspace(0.5/n2,1-0.5/n1,n2)) 120 | 121 | phi = 0.5 * (x*x + y*y) 122 | psi = 0.5 * (x*x + y*y) 123 | 124 | # Initialize densities 125 | mu = np.zeros((n2, n1)) 126 | r = 0.125 127 | mu[(x-0.5)**2 + (y-0.5)**2 < r**2] = 1 128 | nu = np.zeros((n2, n1)) 129 | idx = (((x-0.25)**2 + (y-0.5)**2 < r**2) & (x < 0.25) ) 130 | idx = idx | (((x-0.75)**2 + (y-0.5)**2 < r**2) & (x > 0.75)) 131 | idx = idx | ((x < 0.751) & (x > 0.249) & (y < 0.51) & (y > 0.49)) 132 | nu[idx] = 1 133 | 134 | # Normalize 135 | mu *= n1*n2 / np.sum(mu) 136 | nu *= n1*n2 / np.sum(nu) 137 | 138 | 139 | # Plot mu and nu 140 | fig, ax = plt.subplots(1, 2) 141 | ax[0].imshow(mu, origin='lower', extent=(0,1,0,1)) 142 | ax[0].set_title("$\\mu$") 143 | ax[1].imshow(nu, origin='lower', extent=(0,1,0,1)) 144 | ax[1].set_title("$\\nu$"); 145 | 146 | 147 | # %% Run the back-and-forth solver 148 | 149 | # Number of iterations for BFM 150 | numIters = 50 151 | 152 | # Initial step size 153 | sigma = 4/np.maximum(mu.max(), nu.max()) 154 | 155 | tic = time() 156 | 157 | # Initialize BFM method 158 | bf = BFM(n1, n2, mu) 159 | compute_ot(phi, psi, bf, sigma) 160 | 161 | toc = time() 162 | print(f'\nElapsed time: {toc-tic:.2f}s') 163 | 164 | 165 | 166 | # %% Visualizations 167 | 168 | 169 | my, mx = ma.masked_array(np.gradient(psi-0.5*(x*x+y*y), 1/n2, 1/n1), mask=((mu==0), (mu==0))) 170 | 171 | fig, ax = plt.subplots() 172 | ax.contourf(x, y, mu+nu) 173 | ax.set_aspect('equal') 174 | skip = (slice(None,None,n1//50), slice(None,None,n2//50)) 175 | ax.quiver(x[skip], y[skip], mx[skip], my[skip], color='yellow', angles='xy', scale_units='xy', scale=1); 176 | 177 | # %% 178 | # The discontinuity of the optimal map is 
hard to see as a quiver plot. So let's instead display only the x-component of the map. 179 | 180 | fig, ax = plt.subplots(1, 2) 181 | ax[0].imshow(x + mx, origin='lower', extent=(0,1,0,1), cmap='plasma') 182 | 183 | x_masked = ma.masked_array(x, mask=(nu==0)) 184 | ax[1].imshow(x_masked, origin='lower', extent=(0,1,0,1), cmap='plasma') 185 | 186 | 187 | 188 | # %% Displacement interpolation 189 | 190 | # Plotting interpolation 191 | def plot_interpolation(mu, nu, phi, psi, n_fig=6): 192 | fig, ax = plt.subplots(1, n_fig, figsize=(20,8)) 193 | [axi.axis('off') for axi in ax.ravel()] 194 | vmax = mu.max() 195 | ax[0].imshow(mu, vmax=vmax) 196 | ax[0].set_title("$t=0$") 197 | ax[n_fig-1].imshow(nu, vmax=vmax) 198 | ax[n_fig-1].set_title("$t=1$") 199 | 200 | interpolate = np.zeros_like(mu) 201 | rho_fwd = np.zeros_like(mu) 202 | rho_bwd = np.zeros_like(mu) 203 | 204 | for i in range(1,n_fig-1): 205 | t = i / (n_fig - 1) 206 | psi_t = (1-t) * 0.5 * (x*x + y*y) + t * psi 207 | phi_t = t * 0.5 * (x*x + y*y) + (1-t) * phi 208 | 209 | bf.pushforward(rho_fwd, psi_t, mu) 210 | bf.pushforward(rho_bwd, phi_t, nu) 211 | interpolate = (1-t) * rho_fwd + t * rho_bwd 212 | ax[i].imshow(interpolate, vmax=vmax) 213 | ax[i].set_title(f"$t={i}/{n_fig-1}$") 214 | 215 | plot_interpolation(mu, nu, phi, psi) 216 | -------------------------------------------------------------------------------- /python/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools>=42", 4 | "wheel", 5 | "pybind11>=2.6.0", 6 | ] 7 | 8 | build-backend = "setuptools.build_meta" 9 | -------------------------------------------------------------------------------- /python/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | from pybind11.setup_helpers import Pybind11Extension, build_ext 4 | from pybind11 import get_cmake_dir 5 | 6 | import sys 7 | 8 | 
__version__ = "0.0.1" 9 | 10 | ext_modules = [ 11 | Pybind11Extension("w2", 12 | ["src/main.cpp"], 13 | define_macros = [('VERSION_INFO', __version__)], 14 | ), 15 | ] 16 | 17 | setup( 18 | name="w2", 19 | version=__version__, 20 | author="Wonjun Lee", 21 | author_email="wlee@math.ucla.edu", 22 | description="Python wrapper for the back-and-forth method for optimal transport", 23 | long_description=""" 24 | The code is based on C code of the back-and-forth method https://github.com/Math-Jacobs/bfm. 25 | Link to the paper: https://arxiv.org/pdf/1905.12154.pdf 26 | """, 27 | long_description_content_type="text/markdown", 28 | ext_modules=ext_modules, 29 | cmdclass={"build_ext": build_ext}, 30 | zip_safe=False, 31 | ) -------------------------------------------------------------------------------- /python/src/main.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | namespace py = pybind11; 5 | 6 | 7 | 8 | class convex_hull{ 9 | public: 10 | int* indices; 11 | int hullCount; 12 | 13 | convex_hull(int n){ 14 | indices=new int[n]; 15 | hullCount=0; 16 | } 17 | 18 | ~convex_hull(){ 19 | delete[] indices; 20 | } 21 | }; 22 | 23 | 24 | 25 | // --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- 26 | 27 | 28 | 29 | class BFM{ 30 | public: 31 | int n1; 32 | int n2; 33 | 34 | double totalMass; 35 | 36 | double *xMap; 37 | double *yMap; 38 | 39 | double *rho; 40 | 41 | int *argmin; 42 | double *temp; 43 | 44 | convex_hull* hull; 45 | 46 | BFM(int n1, int n2, py::array_t & mu_np){ 47 | 48 | py::buffer_info mu_buf = mu_np.request(); 49 | double *mu = static_cast(mu_buf.ptr); 50 | 51 | this->n1 = n1; 52 | this->n2 = n2; 53 | 54 | int n=fmax(n1,n2); 55 | hull = new convex_hull(n); 56 | argmin = new int[n]; 57 | temp = new double[n1*n2]; 58 | 59 | xMap=new double[(n1+1)*(n2+1)]; 60 | yMap=new double[(n1+1)*(n2+1)]; 61 | 62 | for(int i=0;i & dual_np, py::array_t & 
phi_np){ 93 | 94 | py::buffer_info phi_buf = phi_np.request(); 95 | py::buffer_info dual_buf = dual_np.request(); 96 | 97 | double *phi = static_cast (phi_buf.ptr); 98 | double *dual = static_cast (dual_buf.ptr); 99 | 100 | compute_2d_dual_inside(dual, phi, hull, n1, n2); 101 | } 102 | 103 | void pushforward(py::array_t & rho_np, py::array_t & phi_np, py::array_t & nu_np){ 104 | 105 | py::buffer_info phi_buf = phi_np.request(); 106 | py::buffer_info nu_buf = nu_np.request(); 107 | 108 | double *phi = static_cast (phi_buf.ptr); 109 | double *nu = static_cast (nu_buf.ptr); 110 | 111 | calc_pushforward_map(phi); 112 | sampling_pushforward(nu); 113 | 114 | py::buffer_info rho_buf = rho_np.request(); 115 | memcpy(static_cast (rho_buf.ptr), rho, n1*n2*sizeof(double)); 116 | } 117 | 118 | 119 | double compute_w2(py::array_t & phi_np, py::array_t & dual_np, py::array_t & mu_np, py::array_t & nu_np){ 120 | 121 | py::buffer_info phi_buf = phi_np.request(); 122 | py::buffer_info dual_buf = dual_np.request(); 123 | py::buffer_info mu_buf = mu_np.request(); 124 | py::buffer_info nu_buf = nu_np.request(); 125 | 126 | double *phi = static_cast (phi_buf.ptr); 127 | double *dual = static_cast (dual_buf.ptr); 128 | double *mu = static_cast (mu_buf.ptr); 129 | double *nu = static_cast (nu_buf.ptr); 130 | 131 | int pcount=n1*n2; 132 | 133 | double value=0; 134 | 135 | for(int i=0;iv2){ 190 | dual[i]=v1; 191 | }else{ 192 | dualIndicies[i]=n-1; 193 | dual[i]=v2; 194 | } 195 | 196 | } 197 | 198 | } 199 | 200 | 201 | int sgn(double x){ 202 | 203 | int truth=(x>0)-(x<0); 204 | return truth; 205 | 206 | } 207 | 208 | 209 | void transpose_doubles(double *transpose, double *data, int n1, int n2){ 210 | 211 | for(int i=0;iindices[0]=0; 223 | hull->indices[1]=1; 224 | hull->hullCount=2; 225 | 226 | for(int i=2;ihullCount<2){ 239 | hull->indices[1]=i; 240 | hull->hullCount++; 241 | }else{ 242 | int hc=hull->hullCount; 243 | int ic1=hull->indices[hc-1]; 244 | int ic2=hull->indices[hc-2]; 245 | 
246 | double oldSlope=(u[ic1]-u[ic2])/(ic1-ic2); 247 | double slope=(u[i]-u[ic1])/(i-ic1); 248 | 249 | if(slope>=oldSlope){ 250 | int hc=hull->hullCount; 251 | hull->indices[hc]=i; 252 | hull->hullCount++; 253 | }else{ 254 | hull->hullCount--; 255 | add_point(u, hull, i); 256 | } 257 | } 258 | } 259 | 260 | 261 | 262 | 263 | double interpolate_function(double *function, double x, double y, int n1, int n2){ 264 | 265 | int xIndex=fmin(fmax(x*n1-.5 ,0),n1-1); 266 | int yIndex=fmin(fmax(y*n2-.5 ,0),n2-1); 267 | 268 | double xfrac=x*n1-xIndex-.5; 269 | double yfrac=y*n2-yIndex-.5; 270 | 271 | int xOther=xIndex+sgn(xfrac); 272 | int yOther=yIndex+sgn(yfrac); 273 | 274 | xOther=fmax(fmin(xOther, n1-1),0); 275 | yOther=fmax(fmin(yOther, n2-1),0); 276 | 277 | double v1=(1-fabs(xfrac))*(1-fabs(yfrac))*function[yIndex*n1+xIndex]; 278 | double v2=fabs(xfrac)*(1-fabs(yfrac))*function[yIndex*n1+xOther]; 279 | double v3=(1-fabs(xfrac))*fabs(yfrac)*function[yOther*n1+xIndex]; 280 | double v4=fabs(xfrac)*fabs(yfrac)*function[yOther*n1+xOther]; 281 | 282 | double v=v1+v2+v3+v4; 283 | 284 | return v; 285 | 286 | } 287 | 288 | 289 | 290 | 291 | void compute_dual_indices(int *dualIndicies, double *u, convex_hull *hull, int n){ 292 | 293 | int counter=1; 294 | int hc=hull->hullCount; 295 | 296 | for(int i=0;iindices[counter]; 300 | int ic2=hull->indices[counter-1]; 301 | 302 | double slope=n*(u[ic1]-u[ic2])/(ic1-ic2); 303 | while(s>slope&&counterindices[counter]; 306 | ic2=hull->indices[counter-1]; 307 | slope=n*(u[ic1]-u[ic2])/(ic1-ic2); 308 | } 309 | dualIndicies[i]=hull->indices[counter-1]; 310 | 311 | } 312 | } 313 | 314 | 315 | 316 | void calc_pushforward_map(double *dual){ 317 | 318 | 319 | double xStep=1.0/n1; 320 | double yStep=1.0/n2; 321 | 322 | 323 | for(int i=0;i0){ 358 | 359 | double xStretch0=fabs(xMap[i*(n1+1)+j+1]-xMap[i*(n1+1)+j]); 360 | double xStretch1=fabs(xMap[(i+1)*(n1+1)+j+1]-xMap[(i+1)*(n1+1)+j]); 361 | 362 | double 
yStretch0=fabs(yMap[(i+1)*(n1+1)+j]-yMap[i*(n1+1)+j]); 363 | double yStretch1=fabs(yMap[(i+1)*(n1+1)+j+1]-yMap[i*(n1+1)+j+1]); 364 | 365 | double xStretch=fmax(xStretch0, xStretch1); 366 | double yStretch=fmax(yStretch0, yStretch1); 367 | 368 | int xSamples=fmax(n1*xStretch,1); 369 | int ySamples=fmax(n2*yStretch,1); 370 | 371 | double factor=1/(xSamples*ySamples*1.0); 372 | 373 | for(int l=0;l(m, "BFM") 434 | .def(py::init &>()) 435 | .def("ctransform", &BFM::ctransform) 436 | .def("pushforward", &BFM::pushforward); 437 | } 438 | --------------------------------------------------------------------------------