├── .flake8 ├── .github └── workflows │ ├── pre-commit.yml │ └── python-publish.yml ├── .gitignore ├── LICENSE.md ├── README.rst ├── TrajectoryNet ├── __init__.py ├── dataset.py ├── eval.py ├── eval_utils.py ├── lib │ ├── __init__.py │ ├── growth_net.py │ ├── layers │ │ ├── __init__.py │ │ ├── cnf.py │ │ ├── container.py │ │ ├── coupling.py │ │ ├── diffeq_layers │ │ │ ├── __init__.py │ │ │ ├── basic.py │ │ │ ├── container.py │ │ │ ├── resnet.py │ │ │ └── wrappers.py │ │ ├── elemwise.py │ │ ├── glow.py │ │ ├── norm_flows.py │ │ ├── normalization.py │ │ ├── odefunc.py │ │ ├── resnet.py │ │ ├── squeeze.py │ │ └── wrappers │ │ │ ├── __init__.py │ │ │ └── cnf_regularization.py │ ├── spectral_norm.py │ ├── utils.py │ ├── visualize_flow.py │ └── viz_scrna.py ├── main.py ├── optimal_transport │ ├── MMD.py │ ├── __init__.py │ ├── emd.py │ ├── gcs.npy │ ├── growth │ │ └── traj.mp4 │ ├── model │ ├── plot_UOT_1D.py │ ├── sinkhorn_knopp_unbalanced.py │ └── train_growth.py ├── parse.py ├── train_growth.py ├── train_misc.py └── version.py ├── data ├── eb_genes.txt └── eb_velocity_v5.npz ├── figures ├── EB-Trajectory.gif └── eb_high_quality.png ├── notebooks ├── DentateGyrus-Load.ipynb ├── EmbryoidBody_TrajectoryInference.ipynb ├── Example_Anndata_to_TrajectoryNet.ipynb └── WOT-Schiebinger-load.ipynb ├── requirements.txt ├── results └── fig8_results │ ├── backward_trajectories.npy │ ├── checkpt.pth │ ├── logs │ └── train_eval.csv ├── setup.cfg └── setup.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 88 3 | -------------------------------------------------------------------------------- /.github/workflows/pre-commit.yml: -------------------------------------------------------------------------------- 1 | name: pre-commit 2 | on: 3 | push: 4 | branches-ignore: 5 | - 'master' 6 | 7 | jobs: 8 | pre-commit: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - name: Cancel Previous Runs 12 | uses: 
styfle/cancel-workflow-action@0.6.0 13 | with: 14 | access_token: ${{ github.token }} 15 | - uses: actions/checkout@v2 16 | with: 17 | fetch-depth: 0 18 | 19 | - uses: actions/setup-python@v2 20 | with: 21 | python-version: "3.7" 22 | architecture: "x64" 23 | 24 | - uses: actions/cache@v2 25 | with: 26 | path: ~/.cache/pre-commit 27 | key: pre-commit-${{ hashFiles('.pre-commit-config.yaml') }}- 28 | 29 | - uses: pre-commit/action@v2.0.0 30 | continue-on-error: true 31 | 32 | - name: Commit files 33 | run: | 34 | if [[ `git status --porcelain --untracked-files=no` ]]; then 35 | git config --local user.email "41898282+github-actions[bot]@users.noreply.github.com" 36 | git config --local user.name "github-actions[bot]" 37 | git commit -m "pre-commit" -a 38 | fi 39 | 40 | - name: Push changes 41 | uses: ad-m/github-push-action@master 42 | with: 43 | github_token: ${{ secrets.GITHUB_TOKEN }} 44 | branch: ${{ github.ref }} 45 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 
8 | 9 | name: Publish Python 🐍 distributions 📦 to PyPI 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | jobs: 16 | deploy: 17 | runs-on: ubuntu-latest 18 | 19 | steps: 20 | - uses: actions/checkout@v2 21 | - name: Set up Python 22 | uses: actions/setup-python@v2 23 | with: 24 | python-version: '3.x' 25 | - name: Install dependencies 26 | run: | 27 | python -m pip install --upgrade pip 28 | pip install build 29 | - name: Build package 30 | run: python -m build 31 | - name: Publish distribution 📦 to Test PyPI 32 | uses: pypa/gh-action-pypi-publish@release/v1 33 | with: 34 | skip_existing: true 35 | user: __token__ 36 | password: ${{ secrets.TEST_PYPI_API_TOKEN }} 37 | repository_url: https://test.pypi.org/legacy/ 38 | - name: Publish distribution 📦 to PyPI 39 | uses: pypa/gh-action-pypi-publish@release/v1 40 | with: 41 | user: __token__ 42 | password: ${{ secrets.PYPI_API_TOKEN }} 43 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | results/tmp/ 2 | *.jpg 3 | 4 | #Vim 5 | *.sw? 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | env/ 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | 59 | # Sphinx documentation 60 | docs/_build/ 61 | 62 | # PyBuilder 63 | target/ 64 | 65 | # DotEnv configuration 66 | .env 67 | 68 | # Database 69 | *.db 70 | *.rdb 71 | 72 | # Pycharm 73 | .idea 74 | 75 | # VS Code 76 | .vscode/ 77 | 78 | # Spyder 79 | .spyproject/ 80 | 81 | # Jupyter NB Checkpoints 82 | .ipynb_checkpoints/ 83 | 84 | # exclude data from source control by default 85 | /data/ 86 | 87 | # exclude old folder by default 88 | /old/ 89 | 90 | # Mac OS-specific storage files 91 | .DS_Store 92 | 93 | # vim 94 | *.swp 95 | *.swo 96 | 97 | # Mypy cache 98 | .mypy_cache/ 99 | 100 | # Snakemake cache 101 | .snakemake/ 102 | *.out 103 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | ---------------------------------- 2 | 3 | Non-Commercial License 4 | Yale Copyright © 2024 Yale University. 5 | 6 | Permission is hereby granted to use, copy, modify, and distribute this Software for any non-commercial purpose. Any distribution or modification or derivations of the Software (together “Derivative Works”) must be made available on GitHub and shall include this copyright notice and this permission notice in all copies or substantial portions of the Software. For the purposes of this license, "non-commercial" means not intended for or directed towards commercial advantage or monetary compensation either via the Software itself or Derivative Works or uses of either which lead to or generate any commercial products. 
In any event, the use and modification of the Software or Derivative Works shall remain governed by the terms and conditions of this Agreement; Any commercial use of the Software requires a separate commercial license from the copyright holder at Yale University. Direct any requests for commercial licenses to Yale Ventures at yaleventures@yale.edu. 7 | 8 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 9 | 10 | ---------------------------------- 11 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | PyTorch Implementation of TrajectoryNet 2 | ======================================= 3 | 4 | This library runs code associated with the TrajectoryNet paper [1]. 5 | 6 | In brief, TrajectoryNet is a Continuous Normalizing Flow model which can 7 | perform dynamic optimal transport using energy regularization and/or a 8 | combination of velocity, density, and growth regularizations to better match 9 | cellular trajectories. 10 | 11 | Our setting is similar to that of `WaddingtonOT 12 | `_, in that we have access to a series of 13 | population measurements of cells over time and would like to model the dynamics 14 | of cells over that time period. TrajectoryNet is trained end-to-end and is 15 | continuous both in gene space and in time. 16 | 17 | 18 | Installation 19 | ------------ 20 | 21 | TrajectoryNet is available on PyPI. Install it by running the following: 22 | 23 | ..
code-block:: bash 24 | 25 | pip install TrajectoryNet 26 | 27 | This code was tested with Python 3.7 and 3.8. 28 | 29 | Example 30 | ------- 31 | 32 | .. image:: figures/eb_high_quality.png 33 | :alt: EB PHATE Scatterplot 34 | :height: 300 35 | 36 | .. image:: figures/EB-Trajectory.gif 37 | :alt: Trajectory of density over time 38 | :height: 300 39 | 40 | 41 | Basic Usage 42 | ----------- 43 | 44 | Run with 45 | 46 | .. code-block:: bash 47 | 48 | python -m TrajectoryNet.main --dataset SCURVE 49 | 50 | This runs TrajectoryNet on the ``S Curve`` example from the paper. To use a 51 | custom dataset, expose the coordinates and timepoint information according 52 | to the example Jupyter notebooks in the ``/notebooks/`` folder. 53 | 54 | If you have an `AnnData `_ object, take a look at 55 | `notebooks/Example_Anndata_to_TrajectoryNet.ipynb 56 | `_, 57 | which shows how to load one of the example `scvelo `_ AnnData objects into 58 | TrajectoryNet. Alternatively, you can use the custom (compressed) format for 59 | TrajectoryNet as described below. 60 | 61 | For this format, TrajectoryNet requires the following: 62 | 63 | 1. An embedding matrix named ``[embedding_name]`` (Cells x Dimensions) 64 | 2. A sample labels array named ``sample_labels`` (Cells) 65 | 66 | To train and then evaluate TrajectoryNet on a custom dataset, use: 67 | 68 | .. code-block:: bash 69 | 70 | python -m TrajectoryNet.main --dataset [PATH_TO_NPZ_FILE] --embedding_name [EMBEDDING_NAME] 71 | python -m TrajectoryNet.eval --dataset [PATH_TO_NPZ_FILE] --embedding_name [EMBEDDING_NAME] 72 | 73 | 74 | See ``notebooks/EB-Eval.ipynb`` for an example of how to use TrajectoryNet on 75 | a PCA embedding to get trajectories in the gene space. 76 | 77 | 78 | References 79 | ---------- 80 | [1] Tong, A., Huang, J., Wolf, G., van Dijk, D., and Krishnaswamy, S. TrajectoryNet: A Dynamic Optimal Transport Network for Modeling Cellular Dynamics. In International Conference on Machine Learning, 2020.
`arxiv `_ `ICML `_ 81 | 82 | --- 83 | 84 | If you found this library useful, please consider citing:: 85 | 86 | @inproceedings{tong2020trajectorynet, 87 | title = {TrajectoryNet: A Dynamic Optimal Transport Network for Modeling Cellular Dynamics}, 88 | shorttitle = {TrajectoryNet}, 89 | booktitle = {Proceedings of the 37th International Conference on Machine Learning}, 90 | author = {Tong, Alexander and Huang, Jessie and Wolf, Guy and {van Dijk}, David and Krishnaswamy, Smita}, 91 | year = {2020} 92 | } 93 | -------------------------------------------------------------------------------- /TrajectoryNet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KrishnaswamyLab/TrajectoryNet/810c89b081f95405bc0ad42bfb3cb41038e13700/TrajectoryNet/__init__.py -------------------------------------------------------------------------------- /TrajectoryNet/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import matplotlib.pyplot as plt 5 | import matplotlib 6 | 7 | from TrajectoryNet import dataset, eval_utils 8 | from TrajectoryNet.parse import parser 9 | from TrajectoryNet.lib.growth_net import GrowthNet 10 | from TrajectoryNet.lib.viz_scrna import trajectory_to_video, save_vectors 11 | from TrajectoryNet.lib.viz_scrna import ( 12 | save_trajectory_density, 13 | save_2d_trajectory, 14 | save_2d_trajectory_v2, 15 | ) 16 | 17 | from TrajectoryNet.train_misc import ( 18 | set_cnf_options, 19 | count_nfe, 20 | count_parameters, 21 | count_total_time, 22 | add_spectral_norm, 23 | spectral_norm_power_iteration, 24 | create_regularization_fns, 25 | get_regularization, 26 | append_regularization_to_log, 27 | build_model_tabular, 28 | ) 29 | 30 | 31 | def makedirs(dirname): 32 | if not os.path.exists(dirname): 33 | os.makedirs(dirname) 34 | 35 | 36 | def save_trajectory( 37 | prior_logdensity, 38 | prior_sampler, 
39 | model, 40 | data_samples, 41 | savedir, 42 | ntimes=101, 43 | end_times=None, 44 | memory=0.01, 45 | device="cpu", 46 | ): 47 | model.eval() 48 | 49 | # Sample from prior 50 | z_samples = prior_sampler(1000, 2).to(device) 51 | 52 | # sample from a grid 53 | npts = 100 54 | side = np.linspace(-4, 4, npts) 55 | xx, yy = np.meshgrid(side, side) 56 | xx = torch.from_numpy(xx).type(torch.float32).to(device) 57 | yy = torch.from_numpy(yy).type(torch.float32).to(device) 58 | z_grid = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1)], 1) 59 | 60 | with torch.no_grad(): 61 | # We expect the model is a chain of CNF layers wrapped in a SequentialFlow container. 62 | logp_samples = prior_logdensity(z_samples) 63 | logp_grid = prior_logdensity(z_grid) 64 | t = 0 65 | for cnf in model.chain: 66 | 67 | # Construct integration_list 68 | if end_times is None: 69 | end_times = [(cnf.sqrt_end_time * cnf.sqrt_end_time)] 70 | integration_list = [torch.linspace(0, end_times[0], ntimes).to(device)] 71 | for i, et in enumerate(end_times[1:]): 72 | integration_list.append( 73 | torch.linspace(end_times[i], et, ntimes).to(device) 74 | ) 75 | full_times = torch.cat(integration_list, 0) 76 | print(full_times.shape) 77 | 78 | # Integrate over evenly spaced samples 79 | z_traj, logpz = cnf( 80 | z_samples, 81 | logp_samples, 82 | integration_times=integration_list[0], 83 | reverse=True, 84 | ) 85 | full_traj = [(z_traj, logpz)] 86 | for int_times in integration_list[1:]: 87 | prev_z, prev_logp = full_traj[-1] 88 | z_traj, logpz = cnf( 89 | prev_z[-1], prev_logp[-1], integration_times=int_times, reverse=True 90 | ) 91 | full_traj.append((z_traj[1:], logpz[1:])) 92 | full_zip = list(zip(*full_traj)) 93 | z_traj = torch.cat(full_zip[0], 0) 94 | # z_logp = torch.cat(full_zip[1], 0) 95 | z_traj = z_traj.cpu().numpy() 96 | 97 | grid_z_traj, grid_logpz_traj = [], [] 98 | inds = torch.arange(0, z_grid.shape[0]).to(torch.int64) 99 | for ii in torch.split(inds, int(z_grid.shape[0] * memory)): 100 | 
_grid_z_traj, _grid_logpz_traj = cnf( 101 | z_grid[ii], 102 | logp_grid[ii], 103 | integration_times=integration_list[0], 104 | reverse=True, 105 | ) 106 | full_traj = [(_grid_z_traj, _grid_logpz_traj)] 107 | for int_times in integration_list[1:]: 108 | prev_z, prev_logp = full_traj[-1] 109 | _grid_z_traj, _grid_logpz_traj = cnf( 110 | prev_z[-1], 111 | prev_logp[-1], 112 | integration_times=int_times, 113 | reverse=True, 114 | ) 115 | full_traj.append((_grid_z_traj, _grid_logpz_traj)) 116 | full_zip = list(zip(*full_traj)) 117 | _grid_z_traj = torch.cat(full_zip[0], 0).cpu().numpy() 118 | _grid_logpz_traj = torch.cat(full_zip[1], 0).cpu().numpy() 119 | print(_grid_z_traj.shape) 120 | grid_z_traj.append(_grid_z_traj) 121 | grid_logpz_traj.append(_grid_logpz_traj) 122 | 123 | grid_z_traj = np.concatenate(grid_z_traj, axis=1) 124 | grid_logpz_traj = np.concatenate(grid_logpz_traj, axis=1) 125 | 126 | plt.figure(figsize=(8, 8)) 127 | for _ in range(z_traj.shape[0]): 128 | 129 | plt.clf() 130 | 131 | # plot target potential function 132 | ax = plt.subplot(1, 1, 1, aspect="equal") 133 | 134 | """ 135 | ax.hist2d(data_samples[:, 0], data_samples[:, 1], range=[[-4, 4], [-4, 4]], bins=200) 136 | ax.invert_yaxis() 137 | ax.get_xaxis().set_ticks([]) 138 | ax.get_yaxis().set_ticks([]) 139 | ax.set_title("Target", fontsize=32) 140 | 141 | """ 142 | # plot the density 143 | # ax = plt.subplot(2, 2, 2, aspect="equal") 144 | 145 | z, logqz = grid_z_traj[t], grid_logpz_traj[t] 146 | 147 | xx = z[:, 0].reshape(npts, npts) 148 | yy = z[:, 1].reshape(npts, npts) 149 | qz = np.exp(logqz).reshape(npts, npts) 150 | rgb = plt.cm.Spectral(t / z_traj.shape[0]) 151 | print(t, rgb) 152 | background_color = "white" 153 | cvals = [0, np.percentile(qz, 0.1)] 154 | colors = [ 155 | background_color, 156 | rgb, 157 | ] 158 | norm = plt.Normalize(min(cvals), max(cvals)) 159 | tuples = list(zip(map(norm, cvals), colors)) 160 | cmap = matplotlib.colors.LinearSegmentedColormap.from_list("", tuples) 
161 | from matplotlib.colors import LogNorm 162 | 163 | plt.pcolormesh( 164 | xx, 165 | yy, 166 | qz, 167 | # norm=LogNorm(vmin=qz.min(), vmax=qz.max()), 168 | cmap=cmap, 169 | ) 170 | ax.set_xlim(-4, 4) 171 | ax.set_ylim(-4, 4) 172 | cmap = matplotlib.cm.get_cmap(None) 173 | ax.set_facecolor(background_color) 174 | ax.invert_yaxis() 175 | ax.get_xaxis().set_ticks([]) 176 | ax.get_yaxis().set_ticks([]) 177 | ax.set_title("Density", fontsize=32) 178 | 179 | """ 180 | # plot the samples 181 | ax = plt.subplot(2, 2, 3, aspect="equal") 182 | 183 | zk = z_traj[t] 184 | ax.hist2d(zk[:, 0], zk[:, 1], range=[[-4, 4], [-4, 4]], bins=200) 185 | ax.invert_yaxis() 186 | ax.get_xaxis().set_ticks([]) 187 | ax.get_yaxis().set_ticks([]) 188 | ax.set_title("Samples", fontsize=32) 189 | 190 | # plot vector field 191 | ax = plt.subplot(2, 2, 4, aspect="equal") 192 | 193 | K = 13j 194 | y, x = np.mgrid[-4:4:K, -4:4:K] 195 | K = int(K.imag) 196 | zs = torch.from_numpy(np.stack([x, y], -1).reshape(K * K, 2)).to(device, torch.float32) 197 | logps = torch.zeros(zs.shape[0], 1).to(device, torch.float32) 198 | dydt = cnf.odefunc(full_times[t], (zs, logps))[0] 199 | dydt = -dydt.cpu().detach().numpy() 200 | dydt = dydt.reshape(K, K, 2) 201 | 202 | logmag = 2 * np.log(np.hypot(dydt[:, :, 0], dydt[:, :, 1])) 203 | ax.quiver( 204 | x, y, dydt[:, :, 0], -dydt[:, :, 1], 205 | # x, y, dydt[:, :, 0], dydt[:, :, 1], 206 | np.exp(logmag), cmap="coolwarm", scale=20., width=0.015, pivot="mid" 207 | ) 208 | ax.set_xlim(-4, 4) 209 | ax.set_ylim(4, -4) 210 | #ax.set_ylim(-4, 4) 211 | ax.axis("off") 212 | ax.set_title("Vector Field", fontsize=32) 213 | """ 214 | 215 | makedirs(savedir) 216 | plt.savefig(os.path.join(savedir, f"viz-{t:05d}.jpg")) 217 | t += 1 218 | 219 | 220 | def get_trajectory_samples(device, model, data, n=2000): 221 | ntimes = 5 222 | model.eval() 223 | z_samples = data.base_sample()(n, 2).to(device) 224 | 225 | integration_list = [torch.linspace(0, args.int_tps[0], ntimes).to(device)] 
226 | for i, et in enumerate(args.int_tps[1:]): 227 | integration_list.append(torch.linspace(args.int_tps[i], et, ntimes).to(device)) 228 | print(integration_list) 229 | 230 | 231 | def plot_output(device, args, model, data): 232 | # logger.info('Plotting trajectory to {}'.format(save_traj_dir)) 233 | data_samples = data.get_data()[data.sample_index(2000, 0)] 234 | start_points = data.base_sample()(1000, 2) 235 | # start_points = data.get_data()[idx] 236 | # start_points = torch.from_numpy(start_points).type(torch.float32) 237 | """ 238 | save_vectors( 239 | data.base_density(), 240 | model, 241 | start_points, 242 | data.get_data()[data.get_times() == 1], 243 | data.get_times()[data.get_times() == 1], 244 | args.save, 245 | device=device, 246 | end_times=args.int_tps, 247 | ntimes=100, 248 | memory=1.0, 249 | lim=1.5, 250 | ) 251 | save_traj_dir = os.path.join(args.save, "trajectory_2d") 252 | save_2d_trajectory_v2( 253 | data.base_density(), 254 | data.base_sample(), 255 | model, 256 | data_samples, 257 | save_traj_dir, 258 | device=device, 259 | end_times=args.int_tps, 260 | ntimes=3, 261 | memory=1.0, 262 | limit=2.5, 263 | ) 264 | """ 265 | 266 | density_dir = os.path.join(args.save, "density2") 267 | save_trajectory_density( 268 | data.base_density(), 269 | model, 270 | data_samples, 271 | density_dir, 272 | device=device, 273 | end_times=args.int_tps, 274 | ntimes=100, 275 | memory=1, 276 | ) 277 | trajectory_to_video(density_dir) 278 | 279 | 280 | def integrate_backwards( 281 | end_samples, model, savedir, ntimes=100, memory=0.1, device="cpu" 282 | ): 283 | """Integrate some samples backwards and save the results.""" 284 | with torch.no_grad(): 285 | z = torch.from_numpy(end_samples).type(torch.float32).to(device) 286 | zero = torch.zeros(z.shape[0], 1).to(z) 287 | cnf = model.chain[0] 288 | 289 | zs = [z] 290 | deltas = [] 291 | int_tps = np.linspace(args.int_tps[0], args.int_tps[-1], ntimes) 292 | for i, itp in enumerate(int_tps[::-1][:-1]): 293 | # tp 
counts down from last 294 | timescale = int_tps[1] - int_tps[0] 295 | integration_times = torch.tensor([itp - timescale, itp]) 296 | # integration_times = torch.tensor([np.linspace(itp - args.time_scale, itp, ntimes)]) 297 | integration_times = integration_times.type(torch.float32).to(device) 298 | 299 | # transform to previous timepoint 300 | z, delta_logp = cnf(zs[-1], zero, integration_times=integration_times) 301 | zs.append(z) 302 | deltas.append(delta_logp) 303 | zs = torch.stack(zs, 0) 304 | zs = zs.cpu().numpy() 305 | np.save(os.path.join(savedir, "backward_trajectories.npy"), zs) 306 | 307 | 308 | def main(args): 309 | device = torch.device( 310 | "cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu" 311 | ) 312 | if args.use_cpu: 313 | device = torch.device("cpu") 314 | 315 | data = dataset.SCData.factory(args.dataset, args) 316 | 317 | args.timepoints = data.get_unique_times() 318 | 319 | # Use maximum timepoint to establish integration_times 320 | # as some timepoints may be left out for validation etc. 
321 | args.int_tps = (np.arange(max(args.timepoints) + 1) + 1.0) * args.time_scale 322 | 323 | regularization_fns, regularization_coeffs = create_regularization_fns(args) 324 | model = build_model_tabular(args, data.get_shape()[0], regularization_fns).to( 325 | device 326 | ) 327 | if args.use_growth: 328 | growth_model_path = data.get_growth_net_path() 329 | # growth_model_path = "/home/atong/TrajectoryNet/data/externel/growth_model_v2.ckpt" 330 | growth_model = torch.load(growth_model_path, map_location=device) 331 | if args.spectral_norm: 332 | add_spectral_norm(model) 333 | set_cnf_options(args, model) 334 | 335 | state_dict = torch.load(args.save + "/checkpt.pth", map_location=device) 336 | model.load_state_dict(state_dict["state_dict"]) 337 | 338 | # plot_output(device, args, model, data) 339 | # exit() 340 | # get_trajectory_samples(device, model, data) 341 | 342 | args.data = data 343 | args.timepoints = args.data.get_unique_times() 344 | args.int_tps = (np.arange(max(args.timepoints) + 1) + 1.0) * args.time_scale 345 | 346 | print("integrating backwards") 347 | # end_time_data = data.data_dict[args.embedding_name] 348 | end_time_data = data.get_data()[ 349 | args.data.get_times() == np.max(args.data.get_times()) 350 | ] 351 | # np.random.permutation(end_time_data) 352 | # rand_idx = np.random.randint(end_time_data.shape[0], size=5000) 353 | # end_time_data = end_time_data[rand_idx,:] 354 | integrate_backwards(end_time_data, model, args.save, ntimes=100, device=device) 355 | exit() 356 | losses_list = [] 357 | # for factor in np.linspace(0.05, 0.95, 19): 358 | # for factor in np.linspace(0.91, 0.99, 9): 359 | if args.dataset == "CHAFFER": # Do timepoint adjustment 360 | print("adjusting_timepoints") 361 | lt = args.leaveout_timepoint 362 | if lt == 1: 363 | factor = 0.6799872494335812 364 | factor = 0.95 365 | elif lt == 2: 366 | factor = 0.2905983814032348 367 | factor = 0.01 368 | else: 369 | raise RuntimeError("Unknown timepoint %d" % 
args.leaveout_timepoint) 370 | args.int_tps[lt] = (1 - factor) * args.int_tps[lt - 1] + factor * args.int_tps[ 371 | lt + 1 372 | ] 373 | losses = eval_utils.evaluate_kantorovich_v2(device, args, model) 374 | losses_list.append(losses) 375 | print(np.array(losses_list)) 376 | np.save(os.path.join(args.save, "emd_list"), np.array(losses_list)) 377 | # zs = np.load(os.path.join(args.save, 'backward_trajectories')) 378 | # losses = eval_utils.evaluate_mse(device, args, model) 379 | # losses = eval_utils.evaluate_kantorovich(device, args, model) 380 | # print(losses) 381 | # eval_utils.generate_samples(device, args, model, growth_model, timepoint=args.timepoints[-1]) 382 | # eval_utils.calculate_path_length(device, args, model, data, args.int_tps[-1]) 383 | 384 | 385 | if __name__ == "__main__": 386 | args = parser.parse_args() 387 | main(args) 388 | -------------------------------------------------------------------------------- /TrajectoryNet/eval_utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import os 4 | import torch 5 | 6 | from .optimal_transport.emd import earth_mover_distance 7 | 8 | 9 | def generate_samples(device, args, model, growth_model, n=10000, timepoint=None): 10 | """generates samples using model and base density 11 | 12 | This is useful for measuring the wasserstein distance between the 13 | predicted distribution and the true distribution for evaluation 14 | purposes against other types of models. We should use 15 | negative log likelihood if possible as it is deterministic and 16 | more discriminative for this model type. 17 | 18 | TODO: Is this biased??? 
19 | """ 20 | z_samples = args.data.base_sample()(n, *args.data.get_shape()).to(device) 21 | # Forward pass through the model / growth model 22 | with torch.no_grad(): 23 | int_list = [ 24 | torch.tensor([it - args.time_scale, it]).type(torch.float32).to(device) 25 | for it in args.int_tps[: timepoint + 1] 26 | ] 27 | 28 | logpz = args.data.base_density()(z_samples) 29 | z = z_samples 30 | for it in int_list: 31 | z, logpz = model(z, logpz, integration_times=it, reverse=True) 32 | z = z.cpu().numpy() 33 | np.save(os.path.join(args.save, "samples_%0.2f.npy" % timepoint), z) 34 | logpz = logpz.cpu().numpy() 35 | plt.scatter(z[:, 0], z[:, 1], s=0.1, alpha=0.5) 36 | original_data = args.data.get_data()[args.data.get_times() == timepoint] 37 | idx = np.random.randint(original_data.shape[0], size=n) 38 | samples = original_data[idx, :] 39 | plt.scatter(samples[:, 0], samples[:, 1], alpha=0.5) 40 | plt.savefig(os.path.join(args.save, "samples%d.png" % timepoint)) 41 | plt.close() 42 | 43 | pz = np.exp(logpz) 44 | pz = pz / np.sum(pz) 45 | print(pz) 46 | 47 | print( 48 | earth_mover_distance( 49 | original_data, samples + np.random.randn(*samples.shape) * 0.1 50 | ) 51 | ) 52 | 53 | print(earth_mover_distance(z, original_data)) 54 | print(earth_mover_distance(z, samples)) 55 | # print(earth_mover_distance(z, original_data, weights1=pz.flatten())) 56 | # print( 57 | # earth_mover_distance( 58 | # args.data.get_data()[args.data.get_times() == (timepoint - 1)], 59 | # original_data, 60 | # ) 61 | # ) 62 | 63 | if args.use_growth and growth_model is not None: 64 | raise NotImplementedError( 65 | "generating samples with growth model is not yet implemented" 66 | ) 67 | 68 | 69 | def calculate_path_length(device, args, model, data, end_time, n_pts=10000): 70 | """Calculates the total length of the path from time 0 to timepoint""" 71 | # z_samples = torch.tensor(data.get_data()).type(torch.float32).to(device) 72 | z_samples = data.base_sample()(n_pts, 
*data.get_shape()).to(device) 73 | model.eval() 74 | n = 1001 75 | with torch.no_grad(): 76 | integration_times = ( 77 | torch.tensor(np.linspace(0, end_time, n)).type(torch.float32).to(device) 78 | ) 79 | # z, _ = model(z_samples, torch.zeros_like(z_samples), integration_times=integration_times, reverse=False) 80 | z, _ = model( 81 | z_samples, 82 | torch.zeros_like(z_samples), 83 | integration_times=integration_times, 84 | reverse=True, 85 | ) 86 | z = z.cpu().numpy() 87 | z_diff = np.diff(z, axis=0) 88 | z_lengths = np.sum(np.linalg.norm(z_diff, axis=-1), axis=0) 89 | total_length = np.mean(z_lengths) 90 | import ot as pot 91 | from scipy.spatial.distance import cdist 92 | 93 | emd = pot.emd2( 94 | np.ones(n_pts) / n_pts, 95 | np.ones(n_pts) / n_pts, 96 | cdist(z[-1, :, :], data.get_data()), 97 | ) 98 | print(total_length, emd) 99 | plt.scatter(z[-1, :, 0], z[-1, :, 1]) 100 | plt.savefig("test.png") 101 | plt.close() 102 | 103 | 104 | def evaluate_mse(device, args, model, growth_model=None): 105 | if args.use_growth or growth_model is not None: 106 | print("WARNING: Ignoring growth model and computing anyway") 107 | 108 | paths = args.data.get_paths() 109 | 110 | z_samples = torch.tensor(paths[:, 0, :]).type(torch.float32).to(device) 111 | # Forward pass through the model / growth model 112 | with torch.no_grad(): 113 | int_list = [ 114 | torch.tensor([it - args.time_scale, it]).type(torch.float32).to(device) 115 | for it in args.int_tps 116 | ] 117 | 118 | logpz = args.data.base_density()(z_samples) 119 | z = z_samples 120 | zs = [] 121 | for it in int_list: 122 | z, _ = model(z, logpz, integration_times=it, reverse=True) 123 | zs.append(z.cpu().numpy()) 124 | zs = np.stack(zs) 125 | np.save(os.path.join(args.save, "path_samples.npy"), zs) 126 | 127 | # logpz = logpz.cpu().numpy() 128 | # plt.scatter(z[:, 0], z[:, 1], s=0.1, alpha=0.5) 129 | mses = [] 130 | print(zs.shape, paths[:, 1, :].shape) 131 | for tpi in range(len(args.timepoints)): 132 | 
mses.append(np.mean((paths[:, tpi + 1, :] - zs[tpi]) ** 2, axis=(-2, -1))) 133 | mses = np.array(mses) 134 | print(mses.shape) 135 | np.save(os.path.join(args.save, "mses.npy"), mses) 136 | return mses 137 | 138 | 139 | def evaluate_kantorovich_v2(device, args, model, growth_model=None): 140 | """Eval the model via kantorovich distance on leftout timepoint 141 | 142 | v2 computes samples from subsequent timepoint instead of base distribution. 143 | this is arguably a fairer comparison to other methods such as WOT which are 144 | not model based; this should accumulate much less numerical error in the 145 | integration procedure. However, it fixes the number of samples to the number of 146 | cells in the neighbouring (previous or next) timepoint. 147 | 148 | If we have a growth model we should use this to modify the weighting of the 149 | points over time. 150 | """ 151 | if args.use_growth or growth_model is not None: 152 | # raise NotImplementedError( 153 | # "generating samples with growth model is not yet implemented" 154 | # ) 155 | print("WARNING: Ignoring growth model and computing anyway")  # growth_model is never used below, so this warning is accurate here. 156 | 157 | # Backward pass through the model / growth model 158 | with torch.no_grad(): 159 | int_times = torch.tensor( 160 | [ 161 | args.int_tps[args.leaveout_timepoint], 162 | args.int_tps[args.leaveout_timepoint + 1], 163 | ] 164 | ) 165 | int_times = int_times.type(torch.float32).to(device) 166 | next_z = args.data.get_data()[ 167 | args.data.get_times() == args.leaveout_timepoint + 1  # NOTE(review): raw label here vs. args.timepoints[tpi] in the EMD loop below; equivalent only if timepoints are 0..n-1 -- confirm. 168 | ] 169 | next_z = torch.from_numpy(next_z).type(torch.float32).to(device) 170 | prev_z = args.data.get_data()[ 171 | args.data.get_times() == args.leaveout_timepoint - 1 172 | ] 173 | prev_z = torch.from_numpy(prev_z).type(torch.float32).to(device) 174 | zero = torch.zeros(next_z.shape[0], 1).to(device) 175 | z_backward, _ = model.chain[0](next_z, zero, integration_times=int_times)  # only the first chain element is used: next-timepoint cells over [int_tps[leaveout], int_tps[leaveout + 1]]. 176 | z_backward = z_backward.cpu().numpy() 177 | int_times = torch.tensor( 178 | [ 179 | args.int_tps[args.leaveout_timepoint - 1],
180 | args.int_tps[args.leaveout_timepoint], 181 | ] 182 | ) 183 | zero = torch.zeros(prev_z.shape[0], 1).to(device) 184 | z_forward, _ = model.chain[0]( 185 | prev_z, zero, integration_times=int_times, reverse=True 186 | ) 187 | z_forward = z_forward.cpu().numpy()  # previous-timepoint cells integrated with reverse=True over [int_tps[leaveout - 1], int_tps[leaveout]]. 188 | 189 | emds = [] 190 | for tpi in [args.leaveout_timepoint]: 191 | original_data = args.data.get_data()[ 192 | args.data.get_times() == args.timepoints[tpi] 193 | ] 194 | emds.append(earth_mover_distance(z_backward, original_data)) 195 | emds.append(earth_mover_distance(z_forward, original_data))  # emds == [EMD(z_backward), EMD(z_forward)] against the left-out timepoint's data. 196 | 197 | emds = np.array(emds) 198 | np.save(os.path.join(args.save, "emds_v2.npy"), emds) 199 | return emds 200 | 201 | 202 | def evaluate_kantorovich(device, args, model, growth_model=None, n=10000): 203 | """Eval the model via kantorovich distance on all timepoints 204 | 205 | Compute n samples forward from the starting parametric distribution, keeping track 206 | of growth rate to scale the final distribution. 207 | 208 | The growth model is a single model of time independent cell growth / 209 | death rate defined as a variation from uniform. 210 | 211 | If we have a growth model we should use this to modify the weighting of the 212 | points over time.
213 | """ 214 | if args.use_growth or growth_model is not None: 215 | # raise NotImplementedError( 216 | # "generating samples with growth model is not yet implemented" 217 | # ) 218 | print("WARNING: Ignoring growth model and computing anyway")  # NOTE(review): misleading -- growth_model IS used below whenever args.use_growth is set, and args.use_growth with growth_model=None fails at the growth_model(full_state) call; evaluate() guards with `args.use_growth and growth_model is not None` instead. 219 | 220 | z_samples = args.data.base_sample()(n, *args.data.get_shape()).to(device) 221 | # Forward pass through the model / growth model 222 | with torch.no_grad(): 223 | int_list = [] 224 | for i, it in enumerate(args.int_tps): 225 | if i == 0: 226 | prev = 0.0 227 | else: 228 | prev = args.int_tps[i - 1] 229 | int_list.append(torch.tensor([prev, it]).type(torch.float32).to(device))  # segment i spans [int_tps[i - 1] (0.0 for i == 0), int_tps[i]]. 230 | 231 | # int_list = [ 232 | # torch.tensor([it - args.time_scale, it]).type(torch.float32).to(device) 233 | # for it in args.int_tps 234 | # ] 235 | print(args.int_tps) 236 | 237 | logpz = args.data.base_density()(z_samples) 238 | z = z_samples 239 | zs = [] 240 | growthrates = [torch.ones(z_samples.shape[0], 1).to(device)] 241 | for it, tp in zip(int_list, args.timepoints): 242 | z, _ = model(z, logpz, integration_times=it, reverse=True)  # z is carried from segment to segment; the returned log-density is discarded. 243 | zs.append(z.cpu().numpy()) 244 | if args.use_growth: 245 | time_state = tp * torch.ones(z.shape[0], 1).to(device) 246 | full_state = torch.cat([z, time_state], 1) 247 | # Multiply growth rates together to get total mass along path 248 | growthrates.append( 249 | torch.clamp(growth_model(full_state), 1e-4, 1e4) * growthrates[-1] 250 | ) 251 | zs = np.stack(zs) 252 | if args.use_growth: 253 | growthrates = growthrates[1:]  # drop the initial all-ones entry so growthrates[i] lines up with zs[i]. 254 | growthrates = torch.stack(growthrates) 255 | growthrates = growthrates.cpu().numpy() 256 | np.save(os.path.join(args.save, "sample_weights.npy"), growthrates) 257 | np.save(os.path.join(args.save, "samples.npy"), zs) 258 | 259 | # logpz = logpz.cpu().numpy() 260 | # plt.scatter(z[:, 0], z[:, 1], s=0.1, alpha=0.5) 261 | emds = [] 262 | for tpi in range(len(args.timepoints)): 263 | original_data = args.data.get_data()[ 264 | args.data.get_times() == args.timepoints[tpi] 265 |
] 266 | if args.use_growth: 267 | emds.append( 268 | earth_mover_distance( 269 | zs[tpi], original_data, weights1=growthrates[tpi].flatten() 270 | ) 271 | ) 272 | else: 273 | emds.append(earth_mover_distance(zs[tpi], original_data)) 274 | 275 | # Add validation point kantorovich distance evaluation 276 | if args.data.has_validation_samples(): 277 | for tpi in np.unique(args.data.val_labels):  # NOTE(review): tpi is a label value here but is used as an index into args.timepoints, zs and growthrates; correct only if labels are 0..n-1 -- confirm. 278 | original_data = args.data.val_data[ 279 | args.data.val_labels == args.timepoints[tpi] 280 | ] 281 | if args.use_growth: 282 | emds.append( 283 | earth_mover_distance( 284 | zs[tpi], original_data, weights1=growthrates[tpi].flatten() 285 | ) 286 | ) 287 | else: 288 | emds.append(earth_mover_distance(zs[tpi], original_data)) 289 | 290 | emds = np.array(emds) 291 | print(emds) 292 | np.save(os.path.join(args.save, "emds.npy"), emds) 293 | return emds 294 | 295 | 296 | def evaluate(device, args, model, growth_model=None): 297 | """Eval the model via negative log likelihood on all timepoints 298 | 299 | Compute loss by integrating backwards from the last time step. 300 | At each time step integrate back one time step, and concatenate that 301 | to samples of the empirical distribution at that previous timestep 302 | repeating over and over to calculate the likelihood of samples in 303 | later timepoints iteratively, making sure that the ODE is evaluated 304 | at every time step to calculate those later points. 305 | 306 | The growth model is a single model of time independent cell growth / 307 | death rate defined as a variation from uniform.
308 | """ 309 | use_growth = args.use_growth and growth_model is not None  # growth is applied only when a growth model is actually supplied. 310 | 311 | # Backward pass accumulating losses, previous state and deltas 312 | deltas = [] 313 | zs = [] 314 | z = None 315 | for i, (itp, tp) in enumerate(zip(args.int_tps[::-1], args.timepoints[::-1])): 316 | # tp counts down from last 317 | integration_times = torch.tensor([itp - args.time_scale, itp]) 318 | integration_times = integration_times.type(torch.float32).to(device) 319 | 320 | x = args.data.get_data()[args.data.get_times() == tp] 321 | x = torch.from_numpy(x).type(torch.float32).to(device) 322 | 323 | if i > 0: 324 | x = torch.cat((z, x))  # carried states come first, this timepoint's cells last. 325 | zs.append(z) 326 | zero = torch.zeros(x.shape[0], 1).to(x) 327 | 328 | # transform to previous timepoint 329 | z, delta_logp = model(x, zero, integration_times=integration_times)  # NOTE(review): unlike the kantorovich evaluators this is not under torch.no_grad(), so autograd graphs are kept -- confirm intended. 330 | deltas.append(delta_logp) 331 | 332 | logpz = args.data.base_density()(z) 333 | 334 | # build growth rates 335 | if use_growth: 336 | growthrates = [torch.ones_like(logpz)] 337 | for z_state, tp in zip(zs[::-1], args.timepoints[::-1][1:]): 338 | # Full state includes time parameter to growth_model 339 | time_state = tp * torch.ones(z_state.shape[0], 1).to(z_state) 340 | full_state = torch.cat([z_state, time_state], 1) 341 | growthrates.append(growth_model(full_state)) 342 | 343 | # Accumulate losses 344 | losses = [] 345 | logps = [logpz] 346 | for i, (delta_logp, tp) in enumerate(zip(deltas[::-1], args.timepoints)): 347 | n_cells_in_tp = np.sum(args.data.get_times() == tp) 348 | logpx = logps[-1] - delta_logp 349 | if use_growth: 350 | logpx += torch.log(growthrates[i]) 351 | logps.append(logpx[:-n_cells_in_tp]) 352 | losses.append(-torch.sum(logpx[-n_cells_in_tp:]))  # the trailing n_cells_in_tp rows are tp's own cells (see the torch.cat order above); the leading rows carry into the next iteration. 353 | losses = torch.stack(losses).cpu().numpy() 354 | np.save(os.path.join(args.save, "nll.npy"), losses) 355 | return losses 356 | -------------------------------------------------------------------------------- /TrajectoryNet/lib/__init__.py:
# --------------------------------------------------------------------------------
# (/TrajectoryNet/lib/__init__.py: empty file)
# --------------------------------------------------------------------------------
# /TrajectoryNet/lib/growth_net.py
# --------------------------------------------------------------------------------
""" Implements a very simple growth network

"""
import torch.nn as nn
import torch.nn.functional as F


class GrowthNet(nn.Module):
    """Three-layer leaky-ReLU MLP that maps each input row to one scalar.

    Evaluation code in this project feeds it ``torch.cat([z, time_state], 1)``
    (a state with its time appended) and treats the output as a growth rate.

    Args:
        input_dim: number of input features. Defaults to 3, the previously
            hard-coded size; pass ``state_dim + 1`` for other data.
        hidden_dim: width of both hidden layers (default 64).
    """

    def __init__(self, input_dim=3, hidden_dim=64):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """Map ``x`` of shape ``(..., input_dim)`` to ``(..., 1)``."""
        x = F.leaky_relu(self.fc1(x))
        x = F.leaky_relu(self.fc2(x))
        return self.fc3(x)


# --------------------------------------------------------------------------------
# /TrajectoryNet/lib/layers/__init__.py
# --------------------------------------------------------------------------------
from .elemwise import *
from .container import *
from .cnf import *
from .odefunc import *
from .squeeze import *
from .normalization import *
from . import diffeq_layers
from .coupling import *
from .glow import *
from .norm_flows import *

# --------------------------------------------------------------------------------
# /TrajectoryNet/lib/layers/cnf.py
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn

from torchdiffeq import odeint_adjoint as odeint

from .wrappers.cnf_regularization import RegularizedODEfunc

__all__ = ["CNF"]


class CNF(nn.Module):
    """Continuous normalizing flow driven by ``odefunc``.

    The state ``(z, logp(z))`` -- plus one extra state per regularization
    function while training -- is integrated with
    ``torchdiffeq.odeint_adjoint``.

    Args:
        odefunc: ODE right-hand side; must provide ``before_odeint()`` and a
            ``_num_evals`` counter (both are used below).
        T: default end time, used when ``forward`` receives no
            ``integration_times``. Stored as ``sqrt(T)`` and squared on use.
        train_T: if True, ``sqrt(T)`` is a trainable parameter, otherwise a
            buffer.
        regularization_fns: optional regularizers; when given, ``odefunc`` is
            wrapped in ``RegularizedODEfunc``.
        solver, atol, rtol: solver settings used in training mode; they also
            initialise ``test_solver`` / ``test_atol`` / ``test_rtol`` used in
            eval mode.
    """

    def __init__(self, odefunc, T=1.0, train_T=False, regularization_fns=None, solver='dopri5', atol=1e-5, rtol=1e-5):
        super(CNF, self).__init__()
        if train_T:
            self.register_parameter("sqrt_end_time", nn.Parameter(torch.sqrt(torch.tensor(T))))
        else:
            self.register_buffer("sqrt_end_time", torch.sqrt(torch.tensor(T)))

        nreg = 0
        if regularization_fns is not None:
            odefunc = RegularizedODEfunc(odefunc, regularization_fns)
            nreg = len(regularization_fns)
        self.odefunc = odefunc
        self.nreg = nreg
        self.regularization_states = None
        self.solver = solver
        self.atol = atol
        self.rtol = rtol
        # Eval-mode solver settings start out equal to the training ones.
        self.test_solver = solver
        self.test_atol = atol
        self.test_rtol = rtol
        self.solver_options = {}

    def forward(self, z, logpz=None, integration_times=None, reverse=False):
        """Integrate ``z`` (and its log-density) over ``integration_times``.

        Args:
            z: state tensor; ``z.shape[0]`` is the batch size.
            logpz: optional log-density of ``z``; zeros of shape
                ``(batch, 1)`` are used when omitted.
            integration_times: 1-D tensor of times; defaults to ``[0, T]``.
            reverse: integrate over the flipped time sequence.

        Returns:
            ``(z_t, logpz_t)`` if ``logpz`` was given, otherwise ``z_t``. With
            exactly two integration times only the final state is returned;
            otherwise the solver output at every requested time.
        """
        if logpz is None:
            _logpz = torch.zeros(z.shape[0], 1).to(z)
        else:
            _logpz = logpz

        if integration_times is None:
            # torch.stack (unlike torch.tensor, which always builds a fresh leaf
            # tensor) keeps the autograd link to sqrt_end_time, so a trainable T
            # (train_T=True) actually receives gradients.
            end_time = self.sqrt_end_time * self.sqrt_end_time
            integration_times = torch.stack([torch.zeros_like(end_time), end_time]).to(z)
        if reverse:
            integration_times = _flip(integration_times, 0)

        # Refresh the odefunc statistics.
        self.odefunc.before_odeint()

        # Add regularization states.
        reg_states = tuple(torch.tensor(0).to(z) for _ in range(self.nreg))

        if self.training:
            # With dopri5 the regularization states get tolerances of 1e20,
            # effectively excluding them from adaptive step-size error control.
            state_t = odeint(
                self.odefunc,
                (z, _logpz) + reg_states,
                integration_times.to(z),
                atol=[self.atol, self.atol] + [1e20] * len(reg_states) if self.solver == 'dopri5' else self.atol,
                rtol=[self.rtol, self.rtol] + [1e20] * len(reg_states) if self.solver == 'dopri5' else self.rtol,
                method=self.solver,
                options=self.solver_options,
            )
        else:
            # Eval mode integrates only (z, logp); no regularization states.
            state_t = odeint(
                self.odefunc,
                (z, _logpz),
                integration_times.to(z),
                atol=self.test_atol,
                rtol=self.test_rtol,
                method=self.test_solver,
            )

        if len(integration_times) == 2:
            # Keep only the state at the final time.
            state_t = tuple(s[1] for s in state_t)

        z_t, logpz_t = state_t[:2]
        self.regularization_states = state_t[2:]

        if logpz is not None:
            return z_t, logpz_t
        return z_t

    def get_regularization_states(self):
        """Return the regularization states of the last forward pass and clear them."""
        reg_states = self.regularization_states
        self.regularization_states = None
        return reg_states

    def num_evals(self):
        """Return the odefunc's evaluation counter as a Python number."""
        return self.odefunc._num_evals.item()


def _flip(x, dim):
    """Return ``x`` with its entries reversed along dimension ``dim``."""
    indices = [slice(None)] * x.dim()
    indices[dim] = torch.arange(x.size(dim) - 1, -1, -1, dtype=torch.long, device=x.device)
    return x[tuple(indices)]


# --------------------------------------------------------------------------------
# /TrajectoryNet/lib/layers/container.py
# --------------------------------------------------------------------------------
import torch.nn as nn


class SequentialFlow(nn.Module):
    """A generalized nn.Sequential container for normalizing flows.
    """

    def __init__(self, layersList):
        super(SequentialFlow, self).__init__()
        self.chain = nn.ModuleList(layersList)

    def forward(self, x, logpx=None, integration_times=None, reverse=False, inds=None):
        """Apply the layers in ``self.chain`` to ``x``.

        Args:
            x: input tensor.
            logpx: optional log-density; when given, every layer must accept
                ``(x, logpx, integration_times=..., reverse=...)`` and return
                ``(x, logpx)``.
            integration_times: passed to every layer when ``logpx`` is given,
                and also when ``logpx`` is None as long as it is not None.
            reverse: run the chain last-to-first and pass ``reverse`` on.
            inds: explicit layer indices to apply (overrides the default order).

        Returns:
            ``x`` if ``logpx`` is None, else ``(x, logpx)``.
        """
        if inds is None:
            if reverse:
                inds = range(len(self.chain) - 1, -1, -1)
            else:
                inds = range(len(self.chain))

        if logpx is None:
            # Forward integration_times only when the caller supplied it: layers
            # without that keyword keep working when it is None, and CNF layers
            # no longer silently fall back to their default [0, T].
            extra = {} if integration_times is None else {"integration_times": integration_times}
            for i in inds:
                x = self.chain[i](x, reverse=reverse, **extra)
            return x
        else:
            for i in inds:
                x, logpx = self.chain[i](x, logpx, integration_times=integration_times, reverse=reverse)
            return x, logpx


# --------------------------------------------------------------------------------
# /TrajectoryNet/lib/layers/coupling.py
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn

__all__ = ['CouplingLayer', 'MaskedCouplingLayer']


class CouplingLayer(nn.Module):
    """Affine coupling layer, used in 2D experiments.

    After an optional swap of the two halves, the first ``self.d = d - d // 2``
    features pass through unchanged and feed an MLP that outputs a scale
    ``sigmoid(s + 2)`` and a shift for the remaining features.
    """

    def __init__(self, d, intermediate_dim=64, swap=False):
        nn.Module.__init__(self)
        self.d = d - (d // 2)
        self.swap = swap
        self.net_s_t = nn.Sequential(
            nn.Linear(self.d, intermediate_dim),
            nn.ReLU(inplace=True),
            nn.Linear(intermediate_dim, intermediate_dim),
            nn.ReLU(inplace=True),
            nn.Linear(intermediate_dim, (d - self.d) * 2),
        )

    def forward(self, x, logpx=None, reverse=False):
        """Return ``y``, or ``(y, logpx + delta_logp)`` when ``logpx`` is given.

        ``delta_logp`` is ``-log|det J|`` in the forward direction and
        ``+log|det J|`` in reverse.
        """
        if self.swap:
            x = torch.cat([x[:, self.d:], x[:, :self.d]], 1)

        in_dim = self.d
        out_dim = x.shape[1] - self.d

        s_t = self.net_s_t(x[:, :in_dim])
        # The +2 offset starts the scale near sigmoid(2) ~ 0.88.
        scale = torch.sigmoid(s_t[:, :out_dim] + 2.)
        shift = s_t[:, out_dim:]

        logdetjac = torch.sum(torch.log(scale).view(scale.shape[0], -1), 1, keepdim=True)

        if not reverse:
            y1 = x[:, self.d:] * scale + shift
            delta_logp = -logdetjac
        else:
            y1 = (x[:, self.d:] - shift) / scale
            delta_logp = logdetjac

        y = torch.cat([x[:, :self.d], y1], 1) if not self.swap else torch.cat([y1, x[:, :self.d]], 1)

        if logpx is None:
            return y
        else:
            return y, logpx + delta_logp


class MaskedCouplingLayer(nn.Module):
    """Masked affine coupling layer, used in the tabular experiments.

    Features where ``mask == 1`` are kept and condition a scale
    ``exp(net_scale(.))`` and shift ``net_shift(.)`` applied to the others.
    """

    def __init__(self, d, hidden_dims, mask_type='alternate', swap=False):
        nn.Module.__init__(self)
        self.d = d
        self.register_buffer('mask', sample_mask(d, mask_type, swap).view(1, d))
        self.net_scale = build_net(d, hidden_dims, activation="tanh")
        self.net_shift = build_net(d, hidden_dims, activation="relu")

    def forward(self, x, logpx=None, reverse=False):
        """Return ``y``, or ``(y, logpx + delta_logp)`` when ``logpx`` is given."""
        scale = torch.exp(self.net_scale(x * self.mask))
        shift = self.net_shift(x * self.mask)

        # Masked (conditioning) features get scale 1 and shift 0.
        masked_scale = scale * (1 - self.mask) + torch.ones_like(scale) * self.mask
        masked_shift = shift * (1 - self.mask)

        logdetjac = torch.sum(torch.log(masked_scale).view(scale.shape[0], -1), 1, keepdim=True)

        if not reverse:
            y = x * masked_scale + masked_shift
            delta_logp = -logdetjac
        else:
            y = (x - masked_shift) / masked_scale
            delta_logp = logdetjac

        if logpx is None:
            return y
        else:
            return y, logpx + delta_logp


def sample_mask(dim, mask_type, swap):
    """Build a 0/1 mask of length ``dim``.

    ``'alternate'`` sets the even indices, ``'channel'`` sets the first
    ``dim // 2`` entries; ``swap`` inverts the mask.

    Raises:
        ValueError: for any other ``mask_type``.
    """
    if mask_type == 'alternate':
        # Index-based masking in MAF paper.
        mask = torch.zeros(dim)
        mask[::2] = 1
        if swap:
            mask = 1 - mask
        return mask
    elif mask_type == 'channel':
        # Masking type used in Real NVP paper.
        mask = torch.zeros(dim)
        mask[:dim // 2] = 1
        if swap:
            mask = 1 - mask
        return mask
    else:
        raise ValueError('Unknown mask_type {}'.format(mask_type))


def build_net(input_dim, hidden_dims, activation="relu"):
    """MLP ``input_dim -> *hidden_dims -> input_dim`` with ``activation`` after each hidden layer."""
    dims = (input_dim,) + tuple(hidden_dims) + (input_dim,)
    activation_modules = {"relu": nn.ReLU(inplace=True), "tanh": nn.Tanh()}

    chain = []
    for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
        chain.append(nn.Linear(in_dim, out_dim))
        if i < len(hidden_dims):
            chain.append(activation_modules[activation])
    return nn.Sequential(*chain)


# --------------------------------------------------------------------------------
# /TrajectoryNet/lib/layers/diffeq_layers/__init__.py
# --------------------------------------------------------------------------------
from .container import *
from .resnet import *
from .basic import *
from .wrappers import *

# --------------------------------------------------------------------------------
# /TrajectoryNet/lib/layers/diffeq_layers/basic.py
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


def weights_init(m):
    """Zero the weight and draw the bias from a normal (mean 0, std 0.01).

    Meant for ``Module.apply``; matches modules whose class name contains
    ``Linear`` or ``Conv``. Modules built without a bias only get the weight
    zeroed (previously ``normal_`` was called on ``None`` and crashed).
    """
    classname = m.__class__.__name__
    if classname.find('Linear') != -1 or classname.find('Conv') != -1:
        nn.init.constant_(m.weight, 0)
        if m.bias is not None:
            nn.init.normal_(m.bias, 0, 0.01)


class HyperLinear(nn.Module):
    """Linear layer whose weight and bias are generated from ``t``.

    A hypernetwork (``n_hidden`` hidden layers of width ``hypernet_dim``) maps
    ``t`` to a vector holding ``dim_out`` bias entries followed by
    ``dim_out * dim_in`` weight entries.
    """

    def __init__(self, dim_in, dim_out, hypernet_dim=8, n_hidden=1, activation=nn.Tanh):
        super(HyperLinear, self).__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.params_dim = self.dim_in * self.dim_out + self.dim_out

        layers = []
        dims = [1] + [hypernet_dim] * n_hidden + [self.params_dim]
        for i in range(1, len(dims)):
            layers.append(nn.Linear(dims[i - 1], dims[i]))
            if i < len(dims) - 1:
                layers.append(activation())
        self._hypernet = nn.Sequential(*layers)
        self._hypernet.apply(weights_init)

    def forward(self, t, x):
        # t is reshaped to (1, 1), so it must hold exactly one value.
        params = self._hypernet(t.view(1, 1)).view(-1)
        b = params[:self.dim_out].view(self.dim_out)
        w = params[self.dim_out:].view(self.dim_out, self.dim_in)
        return F.linear(x, w, b)


class IgnoreLinear(nn.Module):
    """Plain linear layer that ignores ``t``."""

    def __init__(self, dim_in, dim_out):
        super(IgnoreLinear, self).__init__()
        self._layer = nn.Linear(dim_in, dim_out)

    def forward(self, t, x):
        return self._layer(x)


class ConcatLinear(nn.Module):
    """Linear layer over ``[t, x]`` (``t`` broadcast as an extra first feature)."""

    def __init__(self, dim_in, dim_out):
        super(ConcatLinear, self).__init__()
        self._layer = nn.Linear(dim_in + 1, dim_out)

    def forward(self, t, x):
        tt = torch.ones_like(x[:, :1]) * t
        ttx = torch.cat([tt, x], 1)
        return self._layer(ttx)


class ConcatLinear_v2(nn.Module):
    """Linear layer over ``x`` plus a bias that is linear in ``t``."""

    def __init__(self, dim_in, dim_out):
        # Must name this class: super(ConcatLinear, self) raised TypeError
        # because ConcatLinear_v2 is not a subclass of ConcatLinear.
        super(ConcatLinear_v2, self).__init__()
        self._layer = nn.Linear(dim_in, dim_out)
        self._hyper_bias = nn.Linear(1, dim_out, bias=False)

    def forward(self, t, x):
        return self._layer(x) + self._hyper_bias(t.view(1, 1))


class SquashLinear(nn.Module):
    """Linear layer over ``x`` gated by ``sigmoid(Linear(t))``."""

    def __init__(self, dim_in, dim_out):
        super(SquashLinear, self).__init__()
        self._layer = nn.Linear(dim_in, dim_out)
        self._hyper = nn.Linear(1, dim_out)

    def forward(self, t, x):
        return self._layer(x) * torch.sigmoid(self._hyper(t.view(1, 1)))


class ConcatSquashLinear(nn.Module):
    """Linear layer gated by ``sigmoid(Linear(t))`` plus a bias linear in ``t``."""

    def __init__(self, dim_in, dim_out):
        super(ConcatSquashLinear, self).__init__()
        self._layer = nn.Linear(dim_in, dim_out)
        self._hyper_bias = nn.Linear(1, dim_out, bias=False)
        self._hyper_gate = nn.Linear(1, dim_out)

    def forward(self, t, x):
        return self._layer(x) * torch.sigmoid(self._hyper_gate(t.view(1, 1))) \
            + self._hyper_bias(t.view(1, 1))
class HyperConv2d(nn.Module):
    """2-D (transposed) convolution whose weight and bias are generated from ``t``.

    A single ``nn.Linear(1, params_dim)`` maps ``t`` to a flat parameter vector
    laid out as ``[weight (weight_size entries), bias (dim_out entries, only
    when bias=True)]``.
    """

    def __init__(self, dim_in, dim_out, ksize=3, stride=1, padding=0, dilation=1, groups=1, bias=True, transpose=False):
        super(HyperConv2d, self).__init__()
        assert dim_in % groups == 0 and dim_out % groups == 0, "dim_in and dim_out must both be divisible by groups."
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.ksize = ksize
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.bias = bias
        self.transpose = transpose

        self.params_dim = self._weight_size()
        if self.bias:
            self.params_dim += dim_out
        self._hypernet = nn.Linear(1, self.params_dim)
        self.conv_fn = F.conv_transpose2d if transpose else F.conv2d

        self._hypernet.apply(weights_init)

    def _weight_size(self):
        """Number of entries in the generated convolution weight."""
        # Exact integer division: dim_in is divisible by groups (asserted in __init__).
        return self.dim_in * self.dim_out * self.ksize * self.ksize // self.groups

    def forward(self, t, x):
        params = self._hypernet(t.view(1, 1)).view(-1)
        weight_size = self._weight_size()
        if self.transpose:
            weight = params[:weight_size].view(self.dim_in, self.dim_out // self.groups, self.ksize, self.ksize)
        else:
            weight = params[:weight_size].view(self.dim_out, self.dim_in // self.groups, self.ksize, self.ksize)
        # The bias entries follow the weight entries; slicing from 0 (as before)
        # reused weight values and left the reserved bias outputs unused.
        bias = params[weight_size:weight_size + self.dim_out].view(self.dim_out) if self.bias else None
        return self.conv_fn(
            x, weight=weight, bias=bias, stride=self.stride, padding=self.padding, groups=self.groups,
            dilation=self.dilation
        )


class IgnoreConv2d(nn.Module):
    """2-D (transposed) convolution that ignores ``t``."""

    def __init__(self, dim_in, dim_out, ksize=3, stride=1, padding=0, dilation=1, groups=1, bias=True, transpose=False):
        super(IgnoreConv2d, self).__init__()
        module = nn.ConvTranspose2d if transpose else nn.Conv2d
        self._layer = module(
            dim_in, dim_out, kernel_size=ksize, stride=stride, padding=padding, dilation=dilation, groups=groups,
            bias=bias
        )

    def forward(self, t, x):
        return self._layer(x)


class SquashConv2d(nn.Module):
    """Convolution whose output channels are gated by ``sigmoid(Linear(t))``."""

    def __init__(self, dim_in, dim_out, ksize=3, stride=1, padding=0, dilation=1, groups=1, bias=True, transpose=False):
        super(SquashConv2d, self).__init__()
        module = nn.ConvTranspose2d if transpose else nn.Conv2d
        # forward() feeds x straight into the conv (t is not concatenated as a
        # channel), so the conv takes dim_in input channels, not dim_in + 1.
        self._layer = module(
            dim_in, dim_out, kernel_size=ksize, stride=stride, padding=padding, dilation=dilation, groups=groups,
            bias=bias
        )
        self._hyper = nn.Linear(1, dim_out)

    def forward(self, t, x):
        return self._layer(x) * torch.sigmoid(self._hyper(t.view(1, 1))).view(1, -1, 1, 1)


class ConcatConv2d(nn.Module):
    """Convolution over ``x`` with ``t`` broadcast and prepended as an extra channel."""

    def __init__(self, dim_in, dim_out, ksize=3, stride=1, padding=0, dilation=1, groups=1, bias=True, transpose=False):
        super(ConcatConv2d, self).__init__()
        module = nn.ConvTranspose2d if transpose else nn.Conv2d
        self._layer = module(
            dim_in + 1, dim_out, kernel_size=ksize, stride=stride, padding=padding, dilation=dilation, groups=groups,
            bias=bias
        )

    def forward(self, t, x):
        tt = torch.ones_like(x[:, :1, :, :]) * t
        ttx = torch.cat([tt, x], 1)
        return self._layer(ttx)


class ConcatConv2d_v2(nn.Module):
    """Convolution over ``x`` plus a per-channel bias that is linear in ``t``."""

    def __init__(self, dim_in, dim_out, ksize=3, stride=1, padding=0, dilation=1, groups=1, bias=True, transpose=False):
        # Must name this class: super(ConcatConv2d, self) raised TypeError
        # because ConcatConv2d_v2 is not a subclass of ConcatConv2d.
        super(ConcatConv2d_v2, self).__init__()
        module = nn.ConvTranspose2d if transpose else nn.Conv2d
        self._layer = module(
            dim_in, dim_out, kernel_size=ksize, stride=stride, padding=padding, dilation=dilation, groups=groups,
            bias=bias
        )
        self._hyper_bias = nn.Linear(1, dim_out, bias=False)

    def forward(self, t, x):
        return self._layer(x) + self._hyper_bias(t.view(1, 1)).view(1, -1, 1, 1)


class ConcatSquashConv2d(nn.Module):
    """Convolution gated by ``sigmoid(Linear(t))`` plus a per-channel bias linear in ``t``."""

    def __init__(self, dim_in, dim_out, ksize=3, stride=1, padding=0, dilation=1, groups=1, bias=True, transpose=False):
        super(ConcatSquashConv2d, self).__init__()
        module = nn.ConvTranspose2d if transpose else nn.Conv2d
        self._layer = module(
            dim_in, dim_out, kernel_size=ksize, stride=stride, padding=padding, dilation=dilation, groups=groups,
            bias=bias
        )
        self._hyper_gate = nn.Linear(1, dim_out)
        self._hyper_bias = nn.Linear(1, dim_out, bias=False)

    def forward(self, t, x):
        return self._layer(x) * torch.sigmoid(self._hyper_gate(t.view(1, 1))).view(1, -1, 1, 1) \
            + self._hyper_bias(t.view(1, 1)).view(1, -1, 1, 1)


class ConcatCoordConv2d(nn.Module):
    """Convolution over ``x`` concatenated with ``t``, row-index and column-index channels."""

    def __init__(self, dim_in, dim_out, ksize=3, stride=1, padding=0, dilation=1, groups=1, bias=True, transpose=False):
        super(ConcatCoordConv2d, self).__init__()
        module = nn.ConvTranspose2d if transpose else nn.Conv2d
        self._layer = module(
            dim_in + 3, dim_out, kernel_size=ksize, stride=stride, padding=padding, dilation=dilation, groups=groups,
            bias=bias
        )

    def forward(self, t, x):
        b, c, h, w = x.shape
        hh = torch.arange(h).to(x).view(1, 1, h, 1).expand(b, 1, h, w)
        ww = torch.arange(w).to(x).view(1, 1, 1, w).expand(b, 1, h, w)
        tt = t.to(x).view(1, 1, 1, 1).expand(b, 1, h, w)
        x_aug = torch.cat([x, tt, hh, ww], 1)
        return self._layer(x_aug)


class GatedLinear(nn.Module):
    """``f(x) * sigmoid(g(x))`` with two linear layers (no ``t`` argument)."""

    def __init__(self, in_features, out_features):
        super(GatedLinear, self).__init__()
        self.layer_f = nn.Linear(in_features, out_features)
        self.layer_g = nn.Linear(in_features, out_features)

    def forward(self, x):
        f = self.layer_f(x)
        g = torch.sigmoid(self.layer_g(x))
        return f * g


class GatedConv(nn.Module):
    """``f(x) * sigmoid(g(x))`` with two convolutions (no ``t`` argument)."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, groups=1):
        super(GatedConv, self).__init__()
        self.layer_f = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=1, groups=groups
        )
        self.layer_g = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=1, groups=groups
        )

    def forward(self, x):
        f = self.layer_f(x)
        g = torch.sigmoid(self.layer_g(x))
        return f * g


class GatedConvTranspose(nn.Module):
    """``f(x) * sigmoid(g(x))`` with two transposed convolutions (no ``t`` argument)."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1):
        super(GatedConvTranspose, self).__init__()
        self.layer_f = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size, stride=stride, padding=padding, output_padding=output_padding,
            groups=groups
        )
        self.layer_g = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size, stride=stride, padding=padding, output_padding=output_padding,
            groups=groups
        )

    def forward(self, x):
        f = self.layer_f(x)
        g = torch.sigmoid(self.layer_g(x))
        return f * g


class BlendLinear(nn.Module):
    """Linear interpolation in ``t`` between two layers: ``y0 + (y1 - y0) * t``."""

    def __init__(self, dim_in, dim_out, layer_type=nn.Linear, **unused_kwargs):
        super(BlendLinear, self).__init__()
        self._layer0 = layer_type(dim_in, dim_out)
        self._layer1 = layer_type(dim_in, dim_out)

    def forward(self, t, x):
        y0 = self._layer0(x)
        y1 = self._layer1(x)
        return y0 + (y1 - y0) * t


class BlendConv2d(nn.Module):
    """Linear interpolation in ``t`` between two convolutions: ``y0 + (y1 - y0) * t``."""

    def __init__(
        self, dim_in, dim_out, ksize=3, stride=1, padding=0, dilation=1, groups=1, bias=True, transpose=False,
        **unused_kwargs
    ):
        super(BlendConv2d, self).__init__()
        module = nn.ConvTranspose2d if transpose else nn.Conv2d
        self._layer0 = module(
            dim_in, dim_out, kernel_size=ksize, stride=stride, padding=padding, dilation=dilation, groups=groups,
            bias=bias
        )
        self._layer1 = module(
            dim_in, dim_out, kernel_size=ksize, stride=stride, padding=padding, dilation=dilation, groups=groups,
            bias=bias
        )

    def forward(self, t, x):
        y0 = self._layer0(x)
        y1 = self._layer1(x)
        return y0 + (y1 - y0) * t


# --------------------------------------------------------------------------------
# /TrajectoryNet/lib/layers/diffeq_layers/container.py
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn

from .wrappers import diffeq_wrapper


class SequentialDiffEq(nn.Module):
    """A container for a sequential chain of layers. Supports both regular and diffeq layers.

    Every layer is wrapped with ``diffeq_wrapper`` so it can be called as
    ``layer(t, x)``.
    """

    def __init__(self, *layers):
        super(SequentialDiffEq, self).__init__()
        self.layers = nn.ModuleList([diffeq_wrapper(layer) for layer in layers])

    def forward(self, t, x):
        for layer in self.layers:
            x = layer(t, x)
        return x


class MixtureODELayer(nn.Module):
    """Produces a mixture of experts where output = sigma(t) * f(t, x).
    Time-dependent weights sigma(t) help learn to blend the experts without resorting to a highly stiff f.
    Supports both regular and diffeq experts.
    """

    def __init__(self, experts):
        super(MixtureODELayer, self).__init__()
        assert len(experts) > 1
        wrapped_experts = [diffeq_wrapper(ex) for ex in experts]
        self.experts = nn.ModuleList(wrapped_experts)
        self.mixture_weights = nn.Linear(1, len(self.experts))

    def forward(self, t, y):
        dys = torch.stack([f(t, y) for f in self.experts], 0)
        # t is reshaped to (1, 1), as the other diffeq layers do, so the
        # nn.Linear(1, n_experts) also accepts a 0-d t; the weights are then
        # reshaped to (n_experts, 1, ..., 1) to broadcast over each expert output.
        weights = self.mixture_weights(t.view(1, 1)).view(-1, *([1] * (dys.ndimension() - 1)))
        return torch.sum(dys * weights, dim=0, keepdim=False)


# --------------------------------------------------------------------------------
# /TrajectoryNet/lib/layers/diffeq_layers/resnet.py
# --------------------------------------------------------------------------------
import torch.nn as nn

from . import basic
from . import container

NGROUPS = 16


class ResNet(container.SequentialDiffEq):
    """Convolutional ResNet expressed as a ``SequentialDiffEq`` of ``(t, x)`` layers.

    Layout: conv (dim -> intermediate_dim), ``n_resblocks`` BasicBlocks,
    GroupNorm, ReLU, 1x1 conv (intermediate_dim -> dim).
    """

    def __init__(self, dim, intermediate_dim, n_resblocks, conv_block=None):
        # NOTE: the base __init__ runs twice -- once here, before the plain
        # attributes below are set, and again at the end with the built layers.
        super(ResNet, self).__init__()

        if conv_block is None:
            conv_block = basic.ConcatCoordConv2d

        self.dim = dim
        self.intermediate_dim = intermediate_dim
        self.n_resblocks = n_resblocks

        layers = []
        layers.append(conv_block(dim, intermediate_dim, ksize=3, stride=1, padding=1, bias=False))
        for _ in range(n_resblocks):
            layers.append(BasicBlock(intermediate_dim, conv_block))
        layers.append(nn.GroupNorm(NGROUPS, intermediate_dim, eps=1e-4))
        layers.append(nn.ReLU(inplace=True))
        layers.append(conv_block(intermediate_dim, dim, ksize=1, bias=False))

        super(ResNet, self).__init__(*layers)

    def __repr__(self):
        return (
            '{name}({dim}, intermediate_dim={intermediate_dim}, n_resblocks={n_resblocks})'.format(
                name=self.__class__.__name__, **self.__dict__
            )
        )


class BasicBlock(nn.Module):
    """Pre-activation residual block: (GroupNorm, ReLU, conv) twice, plus the input."""

    expansion = 1

    def __init__(self, dim, conv_block=None):
        super(BasicBlock, self).__init__()

        if conv_block is None:
            conv_block = basic.ConcatCoordConv2d

        self.norm1 = nn.GroupNorm(NGROUPS, dim, eps=1e-4)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = conv_block(dim, dim, ksize=3, stride=1, padding=1, bias=False)
        self.norm2 = nn.GroupNorm(NGROUPS, dim, eps=1e-4)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = conv_block(dim, dim, ksize=3, stride=1, padding=1, bias=False)

    def forward(self, t, x):
        residual = x

        out = self.norm1(x)
        out = self.relu1(out)
        out = self.conv1(t, out)

        out = self.norm2(out)
        out = self.relu2(out)
        out = self.conv2(t, out)

        out += residual

        return out


# --------------------------------------------------------------------------------
# /TrajectoryNet/lib/layers/diffeq_layers/wrappers.py
# --------------------------------------------------------------------------------
from inspect import signature
import torch.nn as nn

__all__ = ["diffeq_wrapper", "reshape_wrapper"]


class DiffEqWrapper(nn.Module):
    """Adapt a module to the ``(t, y)`` calling convention.

    A module whose ``forward`` takes one argument is wrapped to ignore ``t``;
    one taking ``(t, y)`` is used as-is; anything else raises ``ValueError``.
    """

    def __init__(self, module):
        super(DiffEqWrapper, self).__init__()
        self.module = module
        if len(signature(self.module.forward).parameters) == 1:
            self.diffeq = lambda t, y: self.module(y)
        elif len(signature(self.module.forward).parameters) == 2:
            self.diffeq = self.module
        else:
            raise ValueError("Differential equation needs to either take (t, y) or (y,) as input.")

    def forward(self, t, y):
        return self.diffeq(t, y)

    def __repr__(self):
        # Show the wrapped module; for one-argument modules self.diffeq is a
        # lambda whose repr carries no information.
        return self.module.__repr__()


def diffeq_wrapper(layer):
    """Return ``layer`` wrapped in a ``DiffEqWrapper``."""
    return DiffEqWrapper(layer)


class ReshapeDiffEq(nn.Module):
    """Reshape flat ``(batch, -1)`` inputs to ``input_shape`` for ``net`` and flatten its output."""

    def __init__(self, input_shape, net):
        super(ReshapeDiffEq, self).__init__()
        assert len(signature(net.forward).parameters) == 2, "use diffeq_wrapper before reshape_wrapper."
        self.input_shape = input_shape
        self.net = net

    def forward(self, t, x):
        batchsize = x.shape[0]
        x = x.view(batchsize, *self.input_shape)
        return self.net(t, x).view(batchsize, -1)

    def __repr__(self):
        # This class stores the wrapped callable as self.net; self.diffeq does
        # not exist here, and referencing it raised AttributeError.
        return self.net.__repr__()


def reshape_wrapper(input_shape, layer):
    """Return ``layer`` wrapped in a ``ReshapeDiffEq`` for ``input_shape``."""
    return ReshapeDiffEq(input_shape, layer)


# --------------------------------------------------------------------------------
# /TrajectoryNet/lib/layers/elemwise.py
# --------------------------------------------------------------------------------
import math
import torch
import torch.nn as nn

_DEFAULT_ALPHA = 1e-6


class ZeroMeanTransform(nn.Module):
    """Shift by -0.5 (forward) or +0.5 (reverse); the log-density is unchanged."""

    def __init__(self):
        nn.Module.__init__(self)

    def forward(self, x, logpx=None, reverse=False):
        if reverse:
            x = x + .5
            if logpx is None:
                return x
            return x, logpx
        else:
            x = x - .5
            if logpx is None:
                return x
            return x, logpx


class LogitTransform(nn.Module):
    """
    The preprocessing step used in Real NVP:
    y = (sigmoid(x) - a) / (1 - 2a)
    x = logit(a + (1 - 2a)*y)
    """

    def __init__(self, alpha=_DEFAULT_ALPHA):
        nn.Module.__init__(self)
        self.alpha = alpha

    def forward(self, x, logpx=None, reverse=False):
        if reverse:
            return _sigmoid(x, logpx, self.alpha)
        else:
            return _logit(x, logpx, self.alpha)


class SigmoidTransform(nn.Module):
    """Reverse of LogitTransform."""

    def __init__(self, alpha=_DEFAULT_ALPHA):
        nn.Module.__init__(self)
        self.alpha = alpha

    def forward(self, x, logpx=None, reverse=False):
        if reverse:
            return _logit(x, logpx, self.alpha)
        else:
            return _sigmoid(x, logpx, self.alpha)


def _logit(x, logpx=None, alpha=_DEFAULT_ALPHA):
    """Map ``x`` to ``logit(alpha + (1 - 2 alpha) x)``, updating ``logpx`` if given."""
    s = alpha + (1 - 2 * alpha) * x
    y = torch.log(s) - torch.log(1 - s)
    if logpx is None:
        return y
    return y, logpx - _logdetgrad(x, alpha).view(x.size(0), -1).sum(1, keepdim=True)


def _sigmoid(y, logpy=None, alpha=_DEFAULT_ALPHA):
    """Inverse of ``_logit``: ``(sigmoid(y) - alpha) / (1 - 2 alpha)``, updating ``logpy`` if given."""
    x = (torch.sigmoid(y) - alpha) / (1 - 2 * alpha)
    if logpy is None:
        return x
    return x, logpy + _logdetgrad(x, alpha).view(x.size(0), -1).sum(1, keepdim=True)


def _logdetgrad(x, alpha):
    """Elementwise log-derivative of ``_logit`` at ``x``: ``-log(s(1 - s)) + log(1 - 2 alpha)``."""
    s = alpha + (1 - 2 * alpha) * x
    logdetgrad = -torch.log(s - s * s) + math.log(1 - 2 * alpha)
    return logdetgrad


# --------------------------------------------------------------------------------
# /TrajectoryNet/lib/layers/glow.py
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class BruteForceLayer(nn.Module):
    """Invertible linear map ``y = x W^T`` with a full ``(dim, dim)`` weight.

    ``W`` starts as the identity; its inverse and ``log|det W|`` are computed
    directly in double precision.
    """

    def __init__(self, dim):
        super(BruteForceLayer, self).__init__()
        self.weight = nn.Parameter(torch.eye(dim))

    def forward(self, x, logpx=None, reverse=False):

        if not reverse:
            y = F.linear(x, self.weight)
            if logpx is None:
                return y
            else:
                return y, logpx - self._logdetgrad.expand_as(logpx)

        else:
            y = F.linear(x, self.weight.double().inverse().float())
            if logpx is None:
                return y
            else:
                return y, logpx + self._logdetgrad.expand_as(logpx)

    @property
    def _logdetgrad(self):
        """``log|det W|`` computed in float64 and returned as float32."""
        return torch.log(torch.abs(torch.det(self.weight.double()))).float()


# --------------------------------------------------------------------------------
# /TrajectoryNet/lib/layers/norm_flows.py
# --------------------------------------------------------------------------------
import math
import torch
import torch.nn as nn
from torch.autograd import grad


class PlanarFlow(nn.Module):
    """Planar flow ``f(z) = z + u * tanh(z @ w + b)``.

    Args:
        nd: dimensionality of ``z``.
    """

    def __init__(self, nd=1):
        super(PlanarFlow, self).__init__()
        self.nd = nd
        self.activation = torch.tanh

        self.register_parameter('u', nn.Parameter(torch.randn(self.nd)))
        self.register_parameter('w', nn.Parameter(torch.randn(self.nd)))
        self.register_parameter('b', nn.Parameter(torch.randn(1)))
        self.reset_parameters()

    def reset_parameters(self):
        """Draw ``u``, ``w`` from U(-1/sqrt(nd), 1/sqrt(nd)), zero ``b``, then project ``u``."""
        stdv = 1. / math.sqrt(self.nd)
        self.u.data.uniform_(-stdv, stdv)
        self.w.data.uniform_(-stdv, stdv)
        self.b.data.fill_(0)
        self.make_invertible()

    def make_invertible(self):
        """Project ``u`` so that ``w . u > -1``, which keeps the planar map invertible.

        Uses ``u_hat = u + (m(w.u) - w.u) * w / ||w||^2`` with
        ``m(a) = -1 + softplus(a)``, so that ``w . u_hat = m(w . u) > -1``.
        """
        u = self.u.data
        w = self.w.data
        dot = torch.dot(u, w)
        # softplus(a) == log(1 + exp(a)), computed without overflow for large a.
        m = -1 + torch.nn.functional.softplus(dot)
        # Divide by ||w||^2 (not ||w||) so the projected dot product equals m.
        du = (m - dot) / torch.dot(w, w) * w
        self.u.data = u + du

    def forward(self, z, logp=None, reverse=False):
        """Computes f(z) and, when ``logp`` is given, log q(f(z)).

        Args:
            z: input of shape ``(batch, nd)``.
            logp: optional log-density of ``z``.
            reverse: must be False; reversal is not supported.

        Returns:
            ``f(z)``, or ``(f(z), logp - log|det df/dz|)`` when ``logp`` is given.
        """
        assert not reverse, 'Planar normalizing flow cannot be reversed.'

        # Apply the map exactly once and evaluate the Jacobian at the *input* z.
        # Previously z was transformed, then transformed again by sample(), and
        # a discarded `logp - ...` expression raised TypeError when logp was None.
        f = self.sample(z)
        if logp is not None:
            return f, self.log_density(z, logp)
        return f

    def sample(self, z):
        """Computes f(z)"""
        h = self.activation(torch.mm(z, self.w.view(self.nd, 1)) + self.b)
        output = z + self.u.expand_as(z) * h
        return output

    def _detgrad(self, z):
        """Computes |det df/dz| = |1 + u . psi(z)| with psi(z) = d tanh(z @ w + b) / dz."""
        with torch.enable_grad():
            z = z.requires_grad_(True)
            h = self.activation(torch.mm(z, self.w.view(self.nd, 1)) + self.b)
            psi = grad(h, z, grad_outputs=torch.ones_like(h), create_graph=True, only_inputs=True)[0]
            u_dot_psi = torch.mm(psi, self.u.view(self.nd, 1))
            # abs() honours the |det| contract: u is only projected in
            # reset_parameters, so 1 + u.psi can become negative during training.
            detgrad = torch.abs(1 + u_dot_psi)
        return detgrad

    def log_density(self, z, logqz):
        """Computes log density of the flow given the log density of z"""
        return logqz - torch.log(self._detgrad(z) + 1e-8)


# --------------------------------------------------------------------------------
# /TrajectoryNet/lib/layers/normalization.py
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import Parameter 4 | 5 | __all__ = ['MovingBatchNorm1d', 'MovingBatchNorm2d'] 6 | 7 | 8 | class MovingBatchNormNd(nn.Module): 9 | def __init__(self, num_features, eps=1e-4, decay=0.1, bn_lag=0., affine=True): 10 | super(MovingBatchNormNd, self).__init__() 11 | self.num_features = num_features 12 | self.affine = affine 13 | self.eps = eps 14 | self.decay = decay 15 | self.bn_lag = bn_lag 16 | self.register_buffer('step', torch.zeros(1)) 17 | if self.affine: 18 | self.weight = Parameter(torch.Tensor(num_features)) 19 | self.bias = Parameter(torch.Tensor(num_features)) 20 | else: 21 | self.register_parameter('weight', None) 22 | self.register_parameter('bias', None) 23 | self.register_buffer('running_mean', torch.zeros(num_features)) 24 | self.register_buffer('running_var', torch.ones(num_features)) 25 | self.reset_parameters() 26 | 27 | @property 28 | def shape(self): 29 | raise NotImplementedError 30 | 31 | def reset_parameters(self): 32 | self.running_mean.zero_() 33 | self.running_var.fill_(1) 34 | if self.affine: 35 | self.weight.data.zero_() 36 | self.bias.data.zero_() 37 | 38 | def forward(self, x, logpx=None, reverse=False): 39 | if reverse: 40 | return self._reverse(x, logpx) 41 | else: 42 | return self._forward(x, logpx) 43 | 44 | def _forward(self, x, logpx=None): 45 | c = x.size(1) 46 | used_mean = self.running_mean.clone().detach() 47 | used_var = self.running_var.clone().detach() 48 | 49 | if self.training: 50 | # compute batch statistics 51 | x_t = x.transpose(0, 1).contiguous().view(c, -1) 52 | batch_mean = torch.mean(x_t, dim=1) 53 | batch_var = torch.var(x_t, dim=1) 54 | 55 | # moving average 56 | if self.bn_lag > 0: 57 | used_mean = batch_mean - (1 - self.bn_lag) * (batch_mean - used_mean.detach()) 58 | used_mean /= (1. 
- self.bn_lag**(self.step[0] + 1)) 59 | used_var = batch_var - (1 - self.bn_lag) * (batch_var - used_var.detach()) 60 | used_var /= (1. - self.bn_lag**(self.step[0] + 1)) 61 | 62 | # update running estimates 63 | self.running_mean -= self.decay * (self.running_mean - batch_mean.data) 64 | self.running_var -= self.decay * (self.running_var - batch_var.data) 65 | self.step += 1 66 | 67 | # perform normalization 68 | used_mean = used_mean.view(*self.shape).expand_as(x) 69 | used_var = used_var.view(*self.shape).expand_as(x) 70 | 71 | y = (x - used_mean) * torch.exp(-0.5 * torch.log(used_var + self.eps)) 72 | 73 | if self.affine: 74 | weight = self.weight.view(*self.shape).expand_as(x) 75 | bias = self.bias.view(*self.shape).expand_as(x) 76 | y = y * torch.exp(weight) + bias 77 | 78 | if logpx is None: 79 | return y 80 | else: 81 | return y, logpx - self._logdetgrad(x, used_var).view(x.size(0), -1).sum(1, keepdim=True) 82 | 83 | def _reverse(self, y, logpy=None): 84 | used_mean = self.running_mean 85 | used_var = self.running_var 86 | 87 | if self.affine: 88 | weight = self.weight.view(*self.shape).expand_as(y) 89 | bias = self.bias.view(*self.shape).expand_as(y) 90 | y = (y - bias) * torch.exp(-weight) 91 | 92 | used_mean = used_mean.view(*self.shape).expand_as(y) 93 | used_var = used_var.view(*self.shape).expand_as(y) 94 | x = y * torch.exp(0.5 * torch.log(used_var + self.eps)) + used_mean 95 | 96 | if logpy is None: 97 | return x 98 | else: 99 | return x, logpy + self._logdetgrad(x, used_var).view(x.size(0), -1).sum(1, keepdim=True) 100 | 101 | def _logdetgrad(self, x, used_var): 102 | logdetgrad = -0.5 * torch.log(used_var + self.eps) 103 | if self.affine: 104 | weight = self.weight.view(*self.shape).expand(*x.size()) 105 | logdetgrad += weight 106 | return logdetgrad 107 | 108 | def __repr__(self): 109 | return ( 110 | '{name}({num_features}, eps={eps}, decay={decay}, bn_lag={bn_lag},' 111 | ' affine={affine})'.format(name=self.__class__.__name__, **self.__dict__) 
112 | ) 113 | 114 | 115 | def stable_var(x, mean=None, dim=1): 116 | if mean is None: 117 | mean = x.mean(dim, keepdim=True) 118 | mean = mean.view(-1, 1) 119 | res = torch.pow(x - mean, 2) 120 | max_sqr = torch.max(res, dim, keepdim=True)[0] 121 | var = torch.mean(res / max_sqr, 1, keepdim=True) * max_sqr 122 | var = var.view(-1) 123 | # change nan to zero 124 | var[var != var] = 0 125 | return var 126 | 127 | 128 | class MovingBatchNorm1d(MovingBatchNormNd): 129 | @property 130 | def shape(self): 131 | return [1, -1] 132 | 133 | 134 | class MovingBatchNorm2d(MovingBatchNormNd): 135 | @property 136 | def shape(self): 137 | return [1, -1, 1, 1] 138 | -------------------------------------------------------------------------------- /TrajectoryNet/lib/layers/odefunc.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from . import diffeq_layers 8 | from .squeeze import squeeze, unsqueeze 9 | 10 | __all__ = ["ODEnet", "AutoencoderDiffEqNet", "ODEfunc", "AutoencoderODEfunc"] 11 | 12 | 13 | def divergence_bf(dx, y, **unused_kwargs): 14 | sum_diag = 0. 15 | # print(dx.shape, dx.requires_grad, y.shape, y.requires_grad) 16 | for i in range(y.shape[1]): 17 | sum_diag += torch.autograd.grad(dx[:, i].sum(), y, create_graph=True)[0].contiguous()[:, i].contiguous() 18 | return sum_diag.contiguous() 19 | 20 | 21 | # def divergence_bf(f, y, **unused_kwargs): 22 | # jac = _get_minibatch_jacobian(f, y) 23 | # diagonal = jac.view(jac.shape[0], -1)[:, ::jac.shape[1]] 24 | # return torch.sum(diagonal, 1) 25 | 26 | 27 | def _get_minibatch_jacobian(y, x): 28 | """Computes the Jacobian of y wrt x assuming minibatch-mode. 29 | 30 | Args: 31 | y: (N, ...) with a total of D_y elements in ... 32 | x: (N, ...) with a total of D_x elements in ... 
33 | Returns: 34 | The minibatch Jacobian matrix of shape (N, D_y, D_x) 35 | """ 36 | assert y.shape[0] == x.shape[0] 37 | y = y.view(y.shape[0], -1) 38 | 39 | # Compute Jacobian row by row. 40 | jac = [] 41 | for j in range(y.shape[1]): 42 | dy_j_dx = torch.autograd.grad(y[:, j], x, torch.ones_like(y[:, j]), retain_graph=True, 43 | create_graph=True)[0].view(x.shape[0], -1) 44 | jac.append(torch.unsqueeze(dy_j_dx, 1)) 45 | jac = torch.cat(jac, 1) 46 | return jac 47 | 48 | 49 | def divergence_approx(f, y, e=None): 50 | e_dzdx = torch.autograd.grad(f, y, e, create_graph=True)[0] 51 | e_dzdx_e = e_dzdx * e 52 | approx_tr_dzdx = e_dzdx_e.view(y.shape[0], -1).sum(dim=1) 53 | return approx_tr_dzdx 54 | 55 | 56 | def sample_rademacher_like(y): 57 | return torch.randint(low=0, high=2, size=y.shape).to(y) * 2 - 1 58 | 59 | 60 | def sample_gaussian_like(y): 61 | return torch.randn_like(y) 62 | 63 | 64 | class Swish(nn.Module): 65 | 66 | def __init__(self): 67 | super(Swish, self).__init__() 68 | self.beta = nn.Parameter(torch.tensor(1.0)) 69 | 70 | def forward(self, x): 71 | return x * torch.sigmoid(self.beta * x) 72 | 73 | 74 | class Lambda(nn.Module): 75 | 76 | def __init__(self, f): 77 | super(Lambda, self).__init__() 78 | self.f = f 79 | 80 | def forward(self, x): 81 | return self.f(x) 82 | 83 | 84 | NONLINEARITIES = { 85 | "tanh": nn.Tanh(), 86 | "relu": nn.ReLU(), 87 | "softplus": nn.Softplus(), 88 | "elu": nn.ELU(), 89 | "swish": Swish(), 90 | "square": Lambda(lambda x: x**2), 91 | "identity": Lambda(lambda x: x), 92 | } 93 | 94 | 95 | class ODEnet(nn.Module): 96 | """ 97 | Helper class to make neural nets for use in continuous normalizing flows 98 | """ 99 | 100 | def __init__( 101 | self, hidden_dims, input_shape, strides, conv, layer_type="concat", nonlinearity="softplus", num_squeeze=0 102 | ): 103 | super(ODEnet, self).__init__() 104 | self.num_squeeze = num_squeeze 105 | if conv: 106 | assert len(strides) == len(hidden_dims) + 1 107 | base_layer = { 108 | 
"ignore": diffeq_layers.IgnoreConv2d, 109 | "hyper": diffeq_layers.HyperConv2d, 110 | "squash": diffeq_layers.SquashConv2d, 111 | "concat": diffeq_layers.ConcatConv2d, 112 | "concat_v2": diffeq_layers.ConcatConv2d_v2, 113 | "concatsquash": diffeq_layers.ConcatSquashConv2d, 114 | "blend": diffeq_layers.BlendConv2d, 115 | "concatcoord": diffeq_layers.ConcatCoordConv2d, 116 | }[layer_type] 117 | else: 118 | strides = [None] * (len(hidden_dims) + 1) 119 | base_layer = { 120 | "ignore": diffeq_layers.IgnoreLinear, 121 | "hyper": diffeq_layers.HyperLinear, 122 | "squash": diffeq_layers.SquashLinear, 123 | "concat": diffeq_layers.ConcatLinear, 124 | "concat_v2": diffeq_layers.ConcatLinear_v2, 125 | "concatsquash": diffeq_layers.ConcatSquashLinear, 126 | "blend": diffeq_layers.BlendLinear, 127 | "concatcoord": diffeq_layers.ConcatLinear, 128 | }[layer_type] 129 | 130 | # build layers and add them 131 | layers = [] 132 | activation_fns = [] 133 | hidden_shape = input_shape 134 | 135 | for dim_out, stride in zip(hidden_dims + (input_shape[0],), strides): 136 | if stride is None: 137 | layer_kwargs = {} 138 | elif stride == 1: 139 | layer_kwargs = {"ksize": 3, "stride": 1, "padding": 1, "transpose": False} 140 | elif stride == 2: 141 | layer_kwargs = {"ksize": 4, "stride": 2, "padding": 1, "transpose": False} 142 | elif stride == -2: 143 | layer_kwargs = {"ksize": 4, "stride": 2, "padding": 1, "transpose": True} 144 | else: 145 | raise ValueError('Unsupported stride: {}'.format(stride)) 146 | 147 | layer = base_layer(hidden_shape[0], dim_out, **layer_kwargs) 148 | layers.append(layer) 149 | activation_fns.append(NONLINEARITIES[nonlinearity]) 150 | 151 | hidden_shape = list(copy.copy(hidden_shape)) 152 | hidden_shape[0] = dim_out 153 | if stride == 2: 154 | hidden_shape[1], hidden_shape[2] = hidden_shape[1] // 2, hidden_shape[2] // 2 155 | elif stride == -2: 156 | hidden_shape[1], hidden_shape[2] = hidden_shape[1] * 2, hidden_shape[2] * 2 157 | 158 | self.layers = 
nn.ModuleList(layers) 159 | self.activation_fns = nn.ModuleList(activation_fns[:-1]) 160 | 161 | def forward(self, t, y): 162 | dx = y 163 | # squeeze 164 | for _ in range(self.num_squeeze): 165 | dx = squeeze(dx, 2) 166 | for l, layer in enumerate(self.layers): 167 | dx = layer(t, dx) 168 | # if not last layer, use nonlinearity 169 | if l < len(self.layers) - 1: 170 | dx = self.activation_fns[l](dx) 171 | # unsqueeze 172 | for _ in range(self.num_squeeze): 173 | dx = unsqueeze(dx, 2) 174 | return dx 175 | 176 | 177 | class AutoencoderDiffEqNet(nn.Module): 178 | """ 179 | Helper class to make neural nets for use in continuous normalizing flows 180 | """ 181 | 182 | def __init__(self, hidden_dims, input_shape, strides, conv, layer_type="concat", nonlinearity="softplus"): 183 | super(AutoencoderDiffEqNet, self).__init__() 184 | assert layer_type in ("ignore", "hyper", "concat", "concatcoord", "blend") 185 | assert nonlinearity in ("tanh", "relu", "softplus", "elu") 186 | 187 | self.nonlinearity = {"tanh": F.tanh, "relu": F.relu, "softplus": F.softplus, "elu": F.elu}[nonlinearity] 188 | if conv: 189 | assert len(strides) == len(hidden_dims) + 1 190 | base_layer = { 191 | "ignore": diffeq_layers.IgnoreConv2d, 192 | "hyper": diffeq_layers.HyperConv2d, 193 | "squash": diffeq_layers.SquashConv2d, 194 | "concat": diffeq_layers.ConcatConv2d, 195 | "blend": diffeq_layers.BlendConv2d, 196 | "concatcoord": diffeq_layers.ConcatCoordConv2d, 197 | }[layer_type] 198 | else: 199 | strides = [None] * (len(hidden_dims) + 1) 200 | base_layer = { 201 | "ignore": diffeq_layers.IgnoreLinear, 202 | "hyper": diffeq_layers.HyperLinear, 203 | "squash": diffeq_layers.SquashLinear, 204 | "concat": diffeq_layers.ConcatLinear, 205 | "blend": diffeq_layers.BlendLinear, 206 | "concatcoord": diffeq_layers.ConcatLinear, 207 | }[layer_type] 208 | 209 | # build layers and add them 210 | encoder_layers = [] 211 | decoder_layers = [] 212 | hidden_shape = input_shape 213 | for i, (dim_out, stride) in 
enumerate(zip(hidden_dims + (input_shape[0],), strides)): 214 | if i <= len(hidden_dims) // 2: 215 | layers = encoder_layers 216 | else: 217 | layers = decoder_layers 218 | 219 | if stride is None: 220 | layer_kwargs = {} 221 | elif stride == 1: 222 | layer_kwargs = {"ksize": 3, "stride": 1, "padding": 1, "transpose": False} 223 | elif stride == 2: 224 | layer_kwargs = {"ksize": 4, "stride": 2, "padding": 1, "transpose": False} 225 | elif stride == -2: 226 | layer_kwargs = {"ksize": 4, "stride": 2, "padding": 1, "transpose": True} 227 | else: 228 | raise ValueError('Unsupported stride: {}'.format(stride)) 229 | 230 | layers.append(base_layer(hidden_shape[0], dim_out, **layer_kwargs)) 231 | 232 | hidden_shape = list(copy.copy(hidden_shape)) 233 | hidden_shape[0] = dim_out 234 | if stride == 2: 235 | hidden_shape[1], hidden_shape[2] = hidden_shape[1] // 2, hidden_shape[2] // 2 236 | elif stride == -2: 237 | hidden_shape[1], hidden_shape[2] = hidden_shape[1] * 2, hidden_shape[2] * 2 238 | 239 | self.encoder_layers = nn.ModuleList(encoder_layers) 240 | self.decoder_layers = nn.ModuleList(decoder_layers) 241 | 242 | def forward(self, t, y): 243 | h = y 244 | for layer in self.encoder_layers: 245 | h = self.nonlinearity(layer(t, h)) 246 | 247 | dx = h 248 | for i, layer in enumerate(self.decoder_layers): 249 | dx = layer(t, dx) 250 | # if not last layer, use nonlinearity 251 | if i < len(self.decoder_layers) - 1: 252 | dx = self.nonlinearity(dx) 253 | return h, dx 254 | 255 | 256 | class ODEfunc(nn.Module): 257 | 258 | def __init__(self, diffeq, divergence_fn="approximate", residual=False, rademacher=False): 259 | super(ODEfunc, self).__init__() 260 | assert divergence_fn in ("brute_force", "approximate") 261 | 262 | # self.diffeq = diffeq_layers.wrappers.diffeq_wrapper(diffeq) 263 | self.diffeq = diffeq 264 | self.residual = residual 265 | self.rademacher = rademacher 266 | 267 | if divergence_fn == "brute_force": 268 | self.divergence_fn = divergence_bf 269 | elif 
divergence_fn == "approximate": 270 | self.divergence_fn = divergence_approx 271 | 272 | self.register_buffer("_num_evals", torch.tensor(0.)) 273 | 274 | def before_odeint(self, e=None): 275 | self._e = e 276 | self._num_evals.fill_(0) 277 | 278 | def forward(self, t, states): 279 | assert len(states) >= 2 280 | y = states[0] 281 | 282 | # increment num evals 283 | self._num_evals += 1 284 | 285 | # convert to tensor 286 | # t = torch.tensor(t).type_as(y) 287 | batchsize = y.shape[0] 288 | 289 | # Sample and fix the noise. 290 | if self._e is None: 291 | if self.rademacher: 292 | self._e = sample_rademacher_like(y) 293 | else: 294 | self._e = sample_gaussian_like(y) 295 | 296 | with torch.set_grad_enabled(True): 297 | y.requires_grad_(True) 298 | t.requires_grad_(True) 299 | for s_ in states[2:]: 300 | s_.requires_grad_(True) 301 | dy = self.diffeq(t, y, *states[2:]) 302 | # Hack for 2D data to use brute force divergence computation. 303 | if not self.training and dy.view(dy.shape[0], -1).shape[1] == 2: 304 | divergence = divergence_bf(dy, y).view(batchsize, 1) 305 | else: 306 | divergence = self.divergence_fn(dy, y, e=self._e).view(batchsize, 1) 307 | if self.residual: 308 | dy = dy - y 309 | divergence -= torch.ones_like(divergence) * torch.tensor(np.prod(y.shape[1:]), dtype=torch.float32 310 | ).to(divergence) 311 | return tuple([dy, -divergence] + [torch.zeros_like(s_).requires_grad_(True) for s_ in states[2:]]) 312 | 313 | 314 | class AutoencoderODEfunc(nn.Module): 315 | 316 | def __init__(self, autoencoder_diffeq, divergence_fn="approximate", residual=False, rademacher=False): 317 | assert divergence_fn in ("approximate"), "Only approximate divergence supported at the moment. 
(TODO)" 318 | assert isinstance(autoencoder_diffeq, AutoencoderDiffEqNet) 319 | super(AutoencoderODEfunc, self).__init__() 320 | self.residual = residual 321 | self.autoencoder_diffeq = autoencoder_diffeq 322 | self.rademacher = rademacher 323 | 324 | self.register_buffer("_num_evals", torch.tensor(0.)) 325 | 326 | def before_odeint(self, e=None): 327 | self._e = e 328 | self._num_evals.fill_(0) 329 | 330 | def forward(self, t, y_and_logpy): 331 | y, _ = y_and_logpy # remove logpy 332 | 333 | # increment num evals 334 | self._num_evals += 1 335 | 336 | # convert to tensor 337 | t = torch.tensor(t).type_as(y) 338 | batchsize = y.shape[0] 339 | 340 | with torch.set_grad_enabled(True): 341 | y.requires_grad_(True) 342 | t.requires_grad_(True) 343 | h, dy = self.autoencoder_diffeq(t, y) 344 | 345 | # Sample and fix the noise. 346 | if self._e is None: 347 | if self.rademacher: 348 | self._e = sample_rademacher_like(h) 349 | else: 350 | self._e = sample_gaussian_like(h) 351 | 352 | e_vjp_dhdy = torch.autograd.grad(h, y, self._e, create_graph=True)[0] 353 | e_vjp_dfdy = torch.autograd.grad(dy, h, e_vjp_dhdy, create_graph=True)[0] 354 | divergence = torch.sum((e_vjp_dfdy * self._e).view(batchsize, -1), 1, keepdim=True) 355 | 356 | if self.residual: 357 | dy = dy - y 358 | divergence -= torch.ones_like(divergence) * torch.tensor(np.prod(y.shape[1:]), dtype=torch.float32 359 | ).to(divergence) 360 | 361 | return dy, -divergence 362 | -------------------------------------------------------------------------------- /TrajectoryNet/lib/layers/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class BasicBlock(nn.Module): 6 | expansion = 1 7 | 8 | def __init__(self, dim): 9 | super(BasicBlock, self).__init__() 10 | self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, padding=1, bias=False) 11 | self.bn1 = nn.GroupNorm(2, dim, eps=1e-4) 12 | self.relu = 
nn.ReLU(inplace=True) 13 | self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, padding=1, bias=False) 14 | self.bn2 = nn.GroupNorm(2, dim, eps=1e-4) 15 | 16 | def forward(self, x): 17 | residual = x 18 | 19 | out = self.conv1(x) 20 | out = self.bn1(out) 21 | out = self.relu(out) 22 | 23 | out = self.conv2(out) 24 | out = self.bn2(out) 25 | 26 | out += residual 27 | out = self.relu(out) 28 | 29 | return out 30 | 31 | 32 | class ResNeXtBottleneck(nn.Module): 33 | """ 34 | RexNeXt bottleneck type C (https://github.com/facebookresearch/ResNeXt/blob/master/models/resnext.lua) 35 | """ 36 | 37 | def __init__(self, dim, cardinality=4, base_depth=32): 38 | """ Constructor 39 | Args: 40 | in_channels: input channel dimensionality 41 | out_channels: output channel dimensionality 42 | stride: conv stride. Replaces pooling layer. 43 | cardinality: num of convolution groups. 44 | base_width: base number of channels in each group. 45 | widen_factor: factor to reduce the input dimensionality before convolution. 
46 | """ 47 | super(ResNeXtBottleneck, self).__init__() 48 | D = cardinality * base_depth 49 | self.conv_reduce = nn.Conv2d(dim, D, kernel_size=1, stride=1, padding=0, bias=False) 50 | self.bn_reduce = nn.BatchNorm2d(D) 51 | self.conv_grp = nn.Conv2d(D, D, kernel_size=3, stride=1, padding=1, groups=cardinality, bias=False) 52 | self.bn = nn.BatchNorm2d(D) 53 | self.conv_expand = nn.Conv2d(D, dim, kernel_size=1, stride=1, padding=0, bias=False) 54 | self.bn_expand = nn.BatchNorm2d(dim) 55 | 56 | def forward(self, x): 57 | bottleneck = self.conv_reduce.forward(x) 58 | bottleneck = F.relu(self.bn_reduce.forward(bottleneck), inplace=True) 59 | bottleneck = self.conv_grp.forward(bottleneck) 60 | bottleneck = F.relu(self.bn.forward(bottleneck), inplace=True) 61 | bottleneck = self.conv_expand.forward(bottleneck) 62 | bottleneck = self.bn_expand.forward(bottleneck) 63 | return F.relu(x + bottleneck, inplace=True) 64 | -------------------------------------------------------------------------------- /TrajectoryNet/lib/layers/squeeze.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | __all__ = ['SqueezeLayer'] 4 | 5 | 6 | class SqueezeLayer(nn.Module): 7 | def __init__(self, downscale_factor): 8 | super(SqueezeLayer, self).__init__() 9 | self.downscale_factor = downscale_factor 10 | 11 | def forward(self, x, logpx=None, reverse=False): 12 | if reverse: 13 | return self._upsample(x, logpx) 14 | else: 15 | return self._downsample(x, logpx) 16 | 17 | def _downsample(self, x, logpx=None): 18 | squeeze_x = squeeze(x, self.downscale_factor) 19 | if logpx is None: 20 | return squeeze_x 21 | else: 22 | return squeeze_x, logpx 23 | 24 | def _upsample(self, y, logpy=None): 25 | unsqueeze_y = unsqueeze(y, self.downscale_factor) 26 | if logpy is None: 27 | return unsqueeze_y 28 | else: 29 | return unsqueeze_y, logpy 30 | 31 | 32 | def unsqueeze(input, upscale_factor=2): 33 | ''' 34 | [:, C*r^2, H, W] -> [:, C, H*r, W*r] 
35 | ''' 36 | batch_size, in_channels, in_height, in_width = input.size() 37 | out_channels = in_channels // (upscale_factor**2) 38 | 39 | out_height = in_height * upscale_factor 40 | out_width = in_width * upscale_factor 41 | 42 | input_view = input.contiguous().view(batch_size, out_channels, upscale_factor, upscale_factor, in_height, in_width) 43 | 44 | output = input_view.permute(0, 1, 4, 2, 5, 3).contiguous() 45 | return output.view(batch_size, out_channels, out_height, out_width) 46 | 47 | 48 | def squeeze(input, downscale_factor=2): 49 | ''' 50 | [:, C, H*r, W*r] -> [:, C*r^2, H, W] 51 | ''' 52 | batch_size, in_channels, in_height, in_width = input.size() 53 | out_channels = in_channels * (downscale_factor**2) 54 | 55 | out_height = in_height // downscale_factor 56 | out_width = in_width // downscale_factor 57 | 58 | input_view = input.contiguous().view( 59 | batch_size, in_channels, out_height, downscale_factor, out_width, downscale_factor 60 | ) 61 | 62 | output = input_view.permute(0, 1, 3, 5, 2, 4).contiguous() 63 | return output.view(batch_size, out_channels, out_height, out_width) 64 | -------------------------------------------------------------------------------- /TrajectoryNet/lib/layers/wrappers/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /TrajectoryNet/lib/layers/wrappers/cnf_regularization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class RegularizedODEfunc(nn.Module): 6 | def __init__(self, odefunc, regularization_fns): 7 | super(RegularizedODEfunc, self).__init__() 8 | self.odefunc = odefunc 9 | self.regularization_fns = regularization_fns 10 | 11 | def before_odeint(self, *args, **kwargs): 12 | self.odefunc.before_odeint(*args, **kwargs) 13 | 14 | def forward(self, t, state): 15 | class SharedContext(object): 16 | 
pass 17 | 18 | with torch.enable_grad(): 19 | x, logp = state[:2] 20 | x.requires_grad_(True) 21 | logp.requires_grad_(True) 22 | t.requires_grad_(True) 23 | dstate = self.odefunc(t, (x, logp)) 24 | if len(state) > 2: 25 | dx, dlogp = dstate[:2] 26 | reg_states = tuple( 27 | reg_fn(x, logp, dx, dlogp, t, SharedContext) 28 | for reg_fn in self.regularization_fns 29 | ) 30 | return dstate + reg_states 31 | else: 32 | return dstate 33 | 34 | @property 35 | def _num_evals(self): 36 | return self.odefunc._num_evals 37 | 38 | 39 | def _batch_root_mean_squared(tensor): 40 | tensor = tensor.view(tensor.shape[0], -1) 41 | return torch.mean(torch.norm(tensor, p=2, dim=1) / tensor.shape[1] ** 0.5) 42 | 43 | 44 | def l1_regularzation_fn(x, logp, dx, dlogp, t, unused_context): 45 | del x, logp, dlogp 46 | return torch.mean(torch.abs(dx)) 47 | 48 | 49 | def l2_regularzation_fn(x, logp, dx, dlogp, t, unused_context): 50 | del x, logp, dlogp 51 | return _batch_root_mean_squared(dx) 52 | 53 | 54 | def squared_l2_regularization_fn(x, logp, dx, dlogp, t, unused_context): 55 | del x, logp, dlogp 56 | to_return = dx.view(dx.shape[0], -1) 57 | # print(t) 58 | return torch.mean(torch.pow(torch.norm(to_return, p=2, dim=1), 2)) 59 | 60 | 61 | def directional_l2_regularization_fn(x, logp, dx, dlogp, t, unused_context): 62 | del logp, dlogp 63 | directional_dx = torch.autograd.grad(dx, x, dx, create_graph=True)[0] 64 | # print(directional_dx.shape) 65 | # exit() 66 | return _batch_root_mean_squared(directional_dx) 67 | 68 | 69 | def directional_l2_change_penalty_fn(x, logp, dx, dlogp, t, context): 70 | del logp, dlogp 71 | # For now we ignore the directional dx penalty as this complicates things 72 | directional_dx = torch.autograd.grad(dx, x, dx, create_graph=True)[0] 73 | dfdt = _get_minibatch_jacobian(dx, t) 74 | dfdt_full = dfdt + torch.sum(directional_dx, axis=0) 75 | return torch.mean(torch.norm(dfdt_full, p=2) / dfdt_full.shape[0] ** 0.5) 76 | 77 | 78 | def 
jacobian_frobenius_regularization_fn(x, logp, dx, dlogp, t, context): 79 | del logp, dlogp, t 80 | if hasattr(context, "jac"): 81 | jac = context.jac 82 | else: 83 | jac = _get_minibatch_jacobian(dx, x) 84 | context.jac = jac 85 | return _batch_root_mean_squared(jac) 86 | 87 | 88 | def jacobian_diag_frobenius_regularization_fn(x, logp, dx, dlogp, t, context): 89 | del logp, dlogp, t 90 | if hasattr(context, "jac"): 91 | jac = context.jac 92 | else: 93 | jac = _get_minibatch_jacobian(dx, x) 94 | context.jac = jac 95 | diagonal = jac.view(jac.shape[0], -1)[ 96 | :, :: jac.shape[1] 97 | ] # assumes jac is minibatch square, ie. (N, M, M). 98 | return _batch_root_mean_squared(diagonal) 99 | 100 | 101 | def jacobian_offdiag_frobenius_regularization_fn(x, logp, dx, dlogp, t, context): 102 | del logp, dlogp, t 103 | if hasattr(context, "jac"): 104 | jac = context.jac 105 | else: 106 | jac = _get_minibatch_jacobian(dx, x) 107 | context.jac = jac 108 | diagonal = jac.view(jac.shape[0], -1)[ 109 | :, :: jac.shape[1] 110 | ] # assumes jac is minibatch square, ie. (N, M, M). 111 | ss_offdiag = torch.sum(jac.view(jac.shape[0], -1) ** 2, dim=1) - torch.sum( 112 | diagonal ** 2, dim=1 113 | ) 114 | ms_offdiag = ss_offdiag / (diagonal.shape[1] * (diagonal.shape[1] - 1)) 115 | return torch.mean(ms_offdiag) 116 | 117 | 118 | def _get_minibatch_jacobian(y, x, create_graph=True): 119 | """Computes the Jacobian of y wrt x assuming minibatch-mode. 120 | 121 | Args: 122 | y: (N, ...) with a total of D_y elements in ... 123 | x: (N, ...) with a total of D_x elements in ... 124 | Returns: 125 | The minibatch Jacobian matrix of shape (N, D_y, D_x) 126 | """ 127 | # assert y.shape[0] == x.shape[0] 128 | y = y.view(y.shape[0], -1) 129 | 130 | # Compute Jacobian row by row. 
131 | jac = [] 132 | for j in range(y.shape[1]): 133 | dy_j_dx = torch.autograd.grad( 134 | y[:, j], 135 | x, 136 | torch.ones_like(y[:, j]), 137 | retain_graph=True, 138 | create_graph=create_graph, 139 | )[0] 140 | jac.append(torch.unsqueeze(dy_j_dx, -1)) 141 | jac = torch.cat(jac, -1) 142 | return jac 143 | -------------------------------------------------------------------------------- /TrajectoryNet/lib/spectral_norm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Spectral Normalization from https://arxiv.org/abs/1802.05957 3 | """ 4 | import types 5 | import torch 6 | from torch.nn.functional import normalize 7 | 8 | POWER_ITERATION_FN = "spectral_norm_power_iteration" 9 | 10 | 11 | class SpectralNorm(object): 12 | def __init__(self, name='weight', dim=0, eps=1e-12): 13 | self.name = name 14 | self.dim = dim 15 | self.eps = eps 16 | 17 | def compute_weight(self, module, n_power_iterations): 18 | if n_power_iterations < 0: 19 | raise ValueError( 20 | 'Expected n_power_iterations to be non-negative, but ' 21 | 'got n_power_iterations={}'.format(n_power_iterations) 22 | ) 23 | 24 | weight = getattr(module, self.name + '_orig') 25 | u = getattr(module, self.name + '_u') 26 | v = getattr(module, self.name + '_v') 27 | weight_mat = weight 28 | if self.dim != 0: 29 | # permute dim to front 30 | weight_mat = weight_mat.permute(self.dim, * [d for d in range(weight_mat.dim()) if d != self.dim]) 31 | height = weight_mat.size(0) 32 | weight_mat = weight_mat.reshape(height, -1) 33 | with torch.no_grad(): 34 | for _ in range(n_power_iterations): 35 | # Spectral norm of weight equals to `u^T W v`, where `u` and `v` 36 | # are the first left and right singular vectors. 37 | # This power iteration produces approximations of `u` and `v`. 
38 | v = normalize(torch.matmul(weight_mat.t(), u), dim=0, eps=self.eps) 39 | u = normalize(torch.matmul(weight_mat, v), dim=0, eps=self.eps) 40 | setattr(module, self.name + '_u', u) 41 | setattr(module, self.name + '_v', v) 42 | 43 | sigma = torch.dot(u, torch.matmul(weight_mat, v)) 44 | weight = weight / sigma 45 | setattr(module, self.name, weight) 46 | 47 | def remove(self, module): 48 | weight = getattr(module, self.name) 49 | delattr(module, self.name) 50 | delattr(module, self.name + '_u') 51 | delattr(module, self.name + '_orig') 52 | module.register_parameter(self.name, torch.nn.Parameter(weight)) 53 | 54 | def get_update_method(self, module): 55 | def update_fn(module, n_power_iterations): 56 | self.compute_weight(module, n_power_iterations) 57 | 58 | return update_fn 59 | 60 | def __call__(self, module, unused_inputs): 61 | del unused_inputs 62 | self.compute_weight(module, n_power_iterations=0) 63 | 64 | # requires_grad might be either True or False during inference. 65 | if not module.training: 66 | r_g = getattr(module, self.name + '_orig').requires_grad 67 | setattr(module, self.name, getattr(module, self.name).detach().requires_grad_(r_g)) 68 | 69 | @staticmethod 70 | def apply(module, name, dim, eps): 71 | fn = SpectralNorm(name, dim, eps) 72 | weight = module._parameters[name] 73 | height = weight.size(dim) 74 | 75 | u = normalize(weight.new_empty(height).normal_(0, 1), dim=0, eps=fn.eps) 76 | v = normalize(weight.new_empty(int(weight.numel() / height)).normal_(0, 1), dim=0, eps=fn.eps) 77 | delattr(module, fn.name) 78 | module.register_parameter(fn.name + "_orig", weight) 79 | # We still need to assign weight back as fn.name because all sorts of 80 | # things may assume that it exists, e.g., when initializing weights. 81 | # However, we can't directly assign as it could be an nn.Parameter and 82 | # gets added as a parameter. 
Instead, we register weight.data as a 83 | # buffer, which will cause weight to be included in the state dict 84 | # and also supports nn.init due to shared storage. 85 | module.register_buffer(fn.name, weight.data) 86 | module.register_buffer(fn.name + "_u", u) 87 | module.register_buffer(fn.name + "_v", v) 88 | 89 | setattr(module, POWER_ITERATION_FN, types.MethodType(fn.get_update_method(module), module)) 90 | 91 | module.register_forward_pre_hook(fn) 92 | return fn 93 | 94 | 95 | def inplace_spectral_norm(module, name='weight', dim=None, eps=1e-12): 96 | r"""Applies spectral normalization to a parameter in the given module. 97 | 98 | .. math:: 99 | \mathbf{W} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})} \\ 100 | \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2} 101 | 102 | Spectral normalization stabilizes the training of discriminators (critics) 103 | in Generaive Adversarial Networks (GANs) by rescaling the weight tensor 104 | with spectral norm :math:`\sigma` of the weight matrix calculated using 105 | power iteration method. If the dimension of the weight tensor is greater 106 | than 2, it is reshaped to 2D in power iteration method to get spectral 107 | norm. This is implemented via a hook that calculates spectral norm and 108 | rescales weight before every :meth:`~Module.forward` call. 109 | 110 | See `Spectral Normalization for Generative Adversarial Networks`_ . 111 | 112 | .. 
_`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957 113 | 114 | Args: 115 | module (nn.Module): containing module 116 | name (str, optional): name of weight parameter 117 | n_power_iterations (int, optional): number of power iterations to 118 | calculate spectal norm 119 | dim (int, optional): dimension corresponding to number of outputs, 120 | the default is 0, except for modules that are instances of 121 | ConvTranspose1/2/3d, when it is 1 122 | eps (float, optional): epsilon for numerical stability in 123 | calculating norms 124 | 125 | Returns: 126 | The original module with the spectal norm hook 127 | 128 | Example:: 129 | 130 | >>> m = spectral_norm(nn.Linear(20, 40)) 131 | Linear (20 -> 40) 132 | >>> m.weight_u.size() 133 | torch.Size([20]) 134 | 135 | """ 136 | if dim is None: 137 | if isinstance(module, (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d)): 138 | dim = 1 139 | else: 140 | dim = 0 141 | SpectralNorm.apply(module, name, dim=dim, eps=eps) 142 | return module 143 | 144 | 145 | def remove_spectral_norm(module, name='weight'): 146 | r"""Removes the spectral normalization reparameterization from a module. 
147 | 148 | Args: 149 | module (nn.Module): containing module 150 | name (str, optional): name of weight parameter 151 | 152 | Example: 153 | >>> m = spectral_norm(nn.Linear(40, 10)) 154 | >>> remove_spectral_norm(m) 155 | """ 156 | for k, hook in module._forward_pre_hooks.items(): 157 | if isinstance(hook, SpectralNorm) and hook.name == name: 158 | hook.remove(module) 159 | del module._forward_pre_hooks[k] 160 | return module 161 | 162 | raise ValueError("spectral_norm of '{}' not found in {}".format(name, module)) 163 | -------------------------------------------------------------------------------- /TrajectoryNet/lib/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | from numbers import Number 4 | import logging 5 | import torch 6 | 7 | 8 | def makedirs(dirname): 9 | if not os.path.exists(dirname): 10 | os.makedirs(dirname) 11 | 12 | 13 | def get_logger(logpath, filepath, package_files=[], displaying=True, saving=True, debug=False): 14 | logger = logging.getLogger() 15 | if debug: 16 | level = logging.DEBUG 17 | else: 18 | level = logging.INFO 19 | logger.setLevel(level) 20 | if saving: 21 | info_file_handler = logging.FileHandler(logpath, mode="a") 22 | info_file_handler.setLevel(level) 23 | logger.addHandler(info_file_handler) 24 | if displaying: 25 | console_handler = logging.StreamHandler() 26 | console_handler.setLevel(level) 27 | logger.addHandler(console_handler) 28 | logger.info(filepath) 29 | with open(filepath, "r") as f: 30 | logger.info(f.read()) 31 | 32 | for f in package_files: 33 | logger.info(f) 34 | with open(f, "r") as package_f: 35 | logger.info(package_f.read()) 36 | 37 | return logger 38 | 39 | 40 | class AverageMeter(object): 41 | """Computes and stores the average and current value""" 42 | 43 | def __init__(self): 44 | self.reset() 45 | 46 | def reset(self): 47 | self.val = 0 48 | self.avg = 0 49 | self.sum = 0 50 | self.count = 0 51 | 52 | def update(self, val, 
n=1): 53 | self.val = val 54 | self.sum += val * n 55 | self.count += n 56 | self.avg = self.sum / self.count 57 | 58 | 59 | class RunningAverageMeter(object): 60 | """Computes and stores the average and current value""" 61 | 62 | def __init__(self, momentum=0.99): 63 | self.momentum = momentum 64 | self.reset() 65 | 66 | def reset(self): 67 | self.val = None 68 | self.avg = 0 69 | 70 | def update(self, val): 71 | if self.val is None: 72 | self.avg = val 73 | else: 74 | self.avg = self.avg * self.momentum + val * (1 - self.momentum) 75 | self.val = val 76 | 77 | 78 | def inf_generator(iterable): 79 | """Allows training with DataLoaders in a single infinite loop: 80 | for i, (x, y) in enumerate(inf_generator(train_loader)): 81 | """ 82 | iterator = iterable.__iter__() 83 | while True: 84 | try: 85 | yield iterator.__next__() 86 | except StopIteration: 87 | iterator = iterable.__iter__() 88 | 89 | 90 | def save_checkpoint(state, save, epoch): 91 | if not os.path.exists(save): 92 | os.makedirs(save) 93 | filename = os.path.join(save, 'checkpt-%04d.pth' % epoch) 94 | torch.save(state, filename) 95 | 96 | 97 | def isnan(tensor): 98 | return (tensor != tensor) 99 | 100 | 101 | def logsumexp(value, dim=None, keepdim=False): 102 | """Numerically stable implementation of the operation 103 | value.exp().sum(dim, keepdim).log() 104 | """ 105 | if dim is not None: 106 | m, _ = torch.max(value, dim=dim, keepdim=True) 107 | value0 = value - m 108 | if keepdim is False: 109 | m = m.squeeze(dim) 110 | return m + torch.log(torch.sum(torch.exp(value0), dim=dim, keepdim=keepdim)) 111 | else: 112 | m = torch.max(value) 113 | sum_exp = torch.sum(torch.exp(value - m)) 114 | if isinstance(sum_exp, Number): 115 | return m + math.log(sum_exp) 116 | else: 117 | return m + torch.log(sum_exp) 118 | -------------------------------------------------------------------------------- /TrajectoryNet/lib/visualize_flow.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib 3 | matplotlib.use("Agg") 4 | import matplotlib.pyplot as plt 5 | import torch 6 | 7 | LOW = -4 8 | HIGH = 4 9 | 10 | 11 | def plt_potential_func(potential, ax, npts=100, title="$p(x)$"): 12 | """ 13 | Args: 14 | potential: computes U(z_k) given z_k 15 | """ 16 | xside = np.linspace(LOW, HIGH, npts) 17 | yside = np.linspace(LOW, HIGH, npts) 18 | xx, yy = np.meshgrid(xside, yside) 19 | z = np.hstack([xx.reshape(-1, 1), yy.reshape(-1, 1)]) 20 | 21 | z = torch.Tensor(z) 22 | u = potential(z).cpu().numpy() 23 | p = np.exp(-u).reshape(npts, npts) 24 | 25 | plt.pcolormesh(xx, yy, p) 26 | ax.invert_yaxis() 27 | ax.get_xaxis().set_ticks([]) 28 | ax.get_yaxis().set_ticks([]) 29 | ax.set_title(title) 30 | 31 | 32 | def plt_flow(prior_logdensity, transform, ax, npts=100, title="$q(x)$", device="cpu"): 33 | """ 34 | Args: 35 | transform: computes z_k and log(q_k) given z_0 36 | """ 37 | side = np.linspace(LOW, HIGH, npts) 38 | xx, yy = np.meshgrid(side, side) 39 | z = np.hstack([xx.reshape(-1, 1), yy.reshape(-1, 1)]) 40 | 41 | z = torch.tensor(z, requires_grad=True).type(torch.float32).to(device) 42 | logqz = prior_logdensity(z) 43 | z, logqz = transform(z, logqz) 44 | logqz = torch.sum(logqz, dim=1)[:, None] 45 | 46 | xx = z[:, 0].cpu().numpy().reshape(npts, npts) 47 | yy = z[:, 1].cpu().numpy().reshape(npts, npts) 48 | qz = np.exp(logqz.cpu().numpy()).reshape(npts, npts) 49 | 50 | plt.pcolormesh(xx, yy, qz) 51 | ax.set_xlim(LOW, HIGH) 52 | ax.set_ylim(LOW, HIGH) 53 | cmap = matplotlib.cm.get_cmap(None) 54 | ax.set_facecolor(cmap(0.)) 55 | ax.invert_yaxis() 56 | ax.get_xaxis().set_ticks([]) 57 | ax.get_yaxis().set_ticks([]) 58 | ax.set_title(title) 59 | 60 | 61 | def plt_flow_density(prior_logdensity, inverse_transform, ax, npts=100, memory=100, title="$q(x)$", device="cpu"): 62 | side = np.linspace(LOW, HIGH, npts) 63 | xx, yy = np.meshgrid(side, 
side) 64 | x = np.hstack([xx.reshape(-1, 1), yy.reshape(-1, 1)]) 65 | 66 | x = torch.from_numpy(x).type(torch.float32).to(device) 67 | zeros = torch.zeros(x.shape[0], 1).to(x) 68 | 69 | z, delta_logp = [], [] 70 | inds = torch.arange(0, x.shape[0]).to(torch.int64) 71 | for ii in torch.split(inds, int(memory**2)): 72 | z_, delta_logp_ = inverse_transform(x[ii], zeros[ii]) 73 | z.append(z_) 74 | delta_logp.append(delta_logp_) 75 | z = torch.cat(z, 0) 76 | delta_logp = torch.cat(delta_logp, 0) 77 | 78 | logpz = prior_logdensity(z) 79 | logpx = logpz - delta_logp 80 | 81 | px = np.exp(logpx.cpu().numpy()).reshape(npts, npts) 82 | 83 | ax.imshow(px) 84 | ax.get_xaxis().set_ticks([]) 85 | ax.get_yaxis().set_ticks([]) 86 | ax.set_title(title) 87 | 88 | 89 | def plt_flow_samples(prior_sample, transform, ax, npts=100, memory=100, title="$x ~ q(x)$", device="cpu"): 90 | z = prior_sample(npts * npts, 2).type(torch.float32).to(device) 91 | zk = [] 92 | inds = torch.arange(0, z.shape[0]).to(torch.int64) 93 | for ii in torch.split(inds, int(memory**2)): 94 | zk.append(transform(z[ii])) 95 | zk = torch.cat(zk, 0).cpu().numpy() 96 | ax.hist2d(zk[:, 0], zk[:, 1], range=[[LOW, HIGH], [LOW, HIGH]], bins=npts) 97 | ax.invert_yaxis() 98 | ax.get_xaxis().set_ticks([]) 99 | ax.get_yaxis().set_ticks([]) 100 | ax.set_title(title) 101 | 102 | 103 | def plt_samples(samples, ax, npts=100, title="$x ~ p(x)$"): 104 | ax.hist2d(samples[:, 0], samples[:, 1], range=[[LOW, HIGH], [LOW, HIGH]], bins=npts) 105 | ax.invert_yaxis() 106 | ax.get_xaxis().set_ticks([]) 107 | ax.get_yaxis().set_ticks([]) 108 | ax.set_title(title) 109 | 110 | def visualize_growth(growth_model, full_data, labels, npts=200, memory=100, device='cpu'): 111 | with torch.no_grad(): 112 | fig, ax = plt.subplots(1,1) 113 | side = np.linspace(LOW, HIGH, npts) 114 | xx, yy = np.meshgrid(side, side) 115 | x = np.hstack([xx.reshape(-1, 1), yy.reshape(-1, 1)]) 116 | x = torch.from_numpy(x).type(torch.float32).to(device) 117 | 
output_growth = growth_model(x).cpu().numpy() 118 | output_growth = np.reshape(output_growth, (npts, npts)) 119 | im = ax.imshow(output_growth, cmap = 'bwr') 120 | ax.get_xaxis().set_ticks([]) 121 | ax.get_yaxis().set_ticks([]) 122 | fig.colorbar(im, ax=ax) 123 | ax.set_title('Growth Rate') 124 | 125 | # rescale full data to image coordinates 126 | full_data = full_data * npts / 8 + npts / 2 127 | #ax.scatter(full_data[:,0], full_data[:,1], c=labels / 5, cmap='Spectral', s=10,alpha=0.5) 128 | 129 | 130 | def visualize_transform( 131 | potential_or_samples, prior_sample, prior_density, transform=None, inverse_transform=None, samples=True, npts=100, 132 | memory=100, device="cpu" 133 | ): 134 | """Produces visualization for the model density and samples from the model.""" 135 | plt.clf() 136 | ax = plt.subplot(1, 3, 1, aspect="equal") 137 | if samples: 138 | plt_samples(potential_or_samples, ax, npts=npts) 139 | else: 140 | plt_potential_func(potential_or_samples, ax, npts=npts) 141 | 142 | ax = plt.subplot(1, 3, 2, aspect="equal") 143 | if inverse_transform is None: 144 | plt_flow(prior_density, transform, ax, npts=npts, device=device) 145 | else: 146 | plt_flow_density(prior_density, inverse_transform, ax, npts=npts, memory=memory, device=device) 147 | 148 | ax = plt.subplot(1, 3, 3, aspect="equal") 149 | if transform is not None: 150 | plt_flow_samples(prior_sample, transform, ax, npts=npts, memory=memory, device=device) 151 | -------------------------------------------------------------------------------- /TrajectoryNet/lib/viz_scrna.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import numpy as np 4 | import matplotlib 5 | matplotlib.use('Agg') 6 | import matplotlib.pyplot as plt 7 | import torch 8 | 9 | 10 | # def standard_normal_logprob(z): 11 | # logZ = -0.5 * math.log(2 * math.pi) 12 | # return torch.sum(logZ - z.pow(2) / 2, 1, keepdim=True) 13 | 14 | 15 | def makedirs(dirname): 16 | if 
not os.path.exists(dirname): 17 | os.makedirs(dirname) 18 | 19 | def save_2d_trajectory_v2(prior_logdensity, prior_sampler, model, data_samples, savedir, ntimes=5, end_times=None, memory=0.01, device='cpu', limit=4): 20 | """ Save the trajectory as a series of photos such that we can easily display on paper / poster """ 21 | model.eval() 22 | 23 | # Sample from prior 24 | z_samples = prior_sampler(2000, 2).to(device) 25 | 26 | with torch.no_grad(): 27 | # We expect the model is a chain of CNF layers wrapped in a SequentialFlow container. 28 | logp_samples = prior_logdensity(z_samples) 29 | t = 0 30 | for cnf in model.chain: 31 | 32 | # Construct integration_list 33 | if end_times is None: 34 | end_times = [(cnf.sqrt_end_time * cnf.sqrt_end_time)] 35 | integration_list = [torch.linspace(0, end_times[0], ntimes).to(device)] 36 | for i, et in enumerate(end_times[1:]): 37 | integration_list.append(torch.linspace(end_times[i], et, ntimes).to(device)) 38 | full_times = torch.cat(integration_list, 0) 39 | print('integration_list', integration_list) 40 | 41 | 42 | # Integrate over evenly spaced samples 43 | z_traj, logpz = cnf(z_samples, logp_samples, integration_times=integration_list[0], reverse=True) 44 | full_traj = [(z_traj, logpz)] 45 | for i, int_times in enumerate(integration_list[1:]): 46 | prev_z, prev_logp = full_traj[-1] 47 | z_traj, logpz = cnf(prev_z[-1], prev_logp[-1], integration_times=int_times, reverse=True) 48 | full_traj.append((z_traj[1:], logpz[1:])) 49 | full_zip = list(zip(*full_traj)) 50 | z_traj = torch.cat(full_zip[0], 0) 51 | #z_logp = torch.cat(full_zip[1], 0) 52 | z_traj = z_traj.cpu().numpy() 53 | 54 | width = z_traj.shape[0] 55 | plt.figure(figsize=(8, 8)) 56 | fig, axes = plt.subplots(1, width, figsize=(4*width, 4), sharex=True, sharey=True) 57 | axes = axes.flatten() 58 | for w in range(width): 59 | # plot the density 60 | ax = axes[w] 61 | K = 13j 62 | y, x = np.mgrid[-0.5:2.5:K, -1.5:1.5:K] 63 | #y, x = np.mgrid[-limit:limit:K, 
-limit:limit:K] 64 | K = int(K.imag) 65 | zs = torch.from_numpy(np.stack([x, y], -1).reshape(K * K, 2)).to(device, torch.float32) 66 | logps = torch.zeros(zs.shape[0], 1).to(device, torch.float32) 67 | dydt = cnf.odefunc(full_times[t], (zs, logps))[0] 68 | dydt = -dydt.cpu().detach().numpy() 69 | dydt = dydt.reshape(K, K, 2) 70 | 71 | logmag = 2 * np.log(np.hypot(dydt[:, :, 0], dydt[:, :, 1])) 72 | ax.quiver( 73 | #x, y, dydt[:, :, 0], -dydt[:, :, 1], 74 | x, y, dydt[:, :, 0], dydt[:, :, 1], 75 | np.exp(logmag), cmap="coolwarm", scale=20., width=0.015, pivot="mid" 76 | ) 77 | ax.set_xlim(-limit, limit) 78 | ax.set_ylim(limit, -limit) 79 | ax.set_xlim(-1.5, 1.5) 80 | ax.set_ylim(-0.5, 2.5) 81 | ax.axis("off") 82 | 83 | ax.scatter(z_traj[w,:,0], z_traj[w,:,1], c='k', s=0.5) 84 | #ax.set_title("Vector Field", fontsize=32) 85 | t += 1 86 | 87 | makedirs(savedir) 88 | plt.tight_layout(pad=0.0) 89 | plt.savefig(os.path.join(savedir, "vector_plot.jpg")) 90 | plt.close 91 | def save_2d_trajectory(prior_logdensity, prior_sampler, model, data_samples, savedir, ntimes=5, end_times=None, memory=0.01, device='cpu'): 92 | """ Save the trajectory as a series of photos such that we can easily display on paper / poster """ 93 | model.eval() 94 | 95 | # Sample from prior 96 | z_samples = prior_sampler(2000, 2).to(device) 97 | 98 | # sample from a grid 99 | npts = 100 100 | limit = 1.5 101 | side = np.linspace(-limit, limit, npts) 102 | xx, yy = np.meshgrid(side, side) 103 | xx = torch.from_numpy(xx).type(torch.float32).to(device) 104 | yy = torch.from_numpy(yy).type(torch.float32).to(device) 105 | z_grid = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1)], 1) 106 | 107 | with torch.no_grad(): 108 | # We expect the model is a chain of CNF layers wrapped in a SequentialFlow container. 
109 | logp_samples = prior_logdensity(z_samples) 110 | logp_grid = prior_logdensity(z_grid) 111 | t = 0 112 | for cnf in model.chain: 113 | 114 | # Construct integration_list 115 | if end_times is None: 116 | end_times = [(cnf.sqrt_end_time * cnf.sqrt_end_time)] 117 | integration_list = [torch.linspace(0, end_times[0], ntimes).to(device)] 118 | for i, et in enumerate(end_times[1:]): 119 | integration_list.append(torch.linspace(end_times[i], et, ntimes).to(device)) 120 | full_times = torch.cat(integration_list, 0) 121 | print('integration_list', integration_list) 122 | 123 | 124 | # Integrate over evenly spaced samples 125 | z_traj, logpz = cnf(z_samples, logp_samples, integration_times=integration_list[0], reverse=True) 126 | full_traj = [(z_traj, logpz)] 127 | for i, int_times in enumerate(integration_list[1:]): 128 | prev_z, prev_logp = full_traj[-1] 129 | z_traj, logpz = cnf(prev_z[-1], prev_logp[-1], integration_times=int_times, reverse=True) 130 | full_traj.append((z_traj[1:], logpz[1:])) 131 | full_zip = list(zip(*full_traj)) 132 | z_traj = torch.cat(full_zip[0], 0) 133 | #z_logp = torch.cat(full_zip[1], 0) 134 | z_traj = z_traj.cpu().numpy() 135 | 136 | grid_z_traj, grid_logpz_traj = [], [] 137 | inds = torch.arange(0, z_grid.shape[0]).to(torch.int64) 138 | for ii in torch.split(inds, int(z_grid.shape[0] * memory)): 139 | _grid_z_traj, _grid_logpz_traj = cnf( 140 | z_grid[ii], logp_grid[ii], integration_times=integration_list[0], reverse=True 141 | ) 142 | full_traj = [(_grid_z_traj, _grid_logpz_traj)] 143 | for int_times in integration_list[1:]: 144 | prev_z, prev_logp = full_traj[-1] 145 | _grid_z_traj, _grid_logpz_traj = cnf( 146 | prev_z[-1], prev_logp[-1], integration_times=int_times, reverse=True 147 | ) 148 | full_traj.append((_grid_z_traj, _grid_logpz_traj)) 149 | full_zip = list(zip(*full_traj)) 150 | _grid_z_traj = torch.cat(full_zip[0], 0).cpu().numpy() 151 | _grid_logpz_traj = torch.cat(full_zip[1], 0).cpu().numpy() 152 | 
grid_z_traj.append(_grid_z_traj) 153 | grid_logpz_traj.append(_grid_logpz_traj) 154 | 155 | grid_z_traj = np.concatenate(grid_z_traj, axis=1) 156 | grid_logpz_traj = np.concatenate(grid_logpz_traj, axis=1) 157 | 158 | width = z_traj.shape[0] 159 | plt.figure(figsize=(8, 8)) 160 | fig, axes = plt.subplots(2, width, figsize=(4*width, 8), sharex=True, sharey=True) 161 | axes = axes.flatten() 162 | for w in range(width): 163 | # plot the density 164 | ax = axes[w] 165 | 166 | z, logqz = grid_z_traj[t], grid_logpz_traj[t] 167 | 168 | xx = z[:, 0].reshape(npts, npts) 169 | yy = z[:, 1].reshape(npts, npts) 170 | qz = np.exp(logqz).reshape(npts, npts) 171 | 172 | ax.pcolormesh(xx, yy, qz) 173 | ax.set_xlim(-limit, limit) 174 | ax.set_ylim(-limit, limit) 175 | cmap = matplotlib.cm.get_cmap(None) 176 | ax.set_facecolor(cmap(0.)) 177 | ax.invert_yaxis() 178 | ax.get_xaxis().set_ticks([]) 179 | ax.get_yaxis().set_ticks([]) 180 | #ax.set_title("Density", fontsize=32) 181 | 182 | # plot vector field 183 | ax = axes[w+width] 184 | 185 | K = 13j 186 | y, x = np.mgrid[-limit:limit:K, -limit:limit:K] 187 | K = int(K.imag) 188 | zs = torch.from_numpy(np.stack([x, y], -1).reshape(K * K, 2)).to(device, torch.float32) 189 | logps = torch.zeros(zs.shape[0], 1).to(device, torch.float32) 190 | dydt = cnf.odefunc(full_times[t], (zs, logps))[0] 191 | dydt = -dydt.cpu().detach().numpy() 192 | dydt = dydt.reshape(K, K, 2) 193 | 194 | logmag = 2 * np.log(np.hypot(dydt[:, :, 0], dydt[:, :, 1])) 195 | ax.quiver( 196 | x, y, dydt[:, :, 0], -dydt[:, :, 1], 197 | # x, y, dydt[:, :, 0], dydt[:, :, 1], 198 | np.exp(logmag), cmap="coolwarm", scale=20., width=0.015, pivot="mid" 199 | ) 200 | ax.set_xlim(-limit, limit) 201 | ax.set_ylim(limit, -limit) 202 | ax.axis("off") 203 | #ax.set_title("Vector Field", fontsize=32) 204 | t += 1 205 | 206 | makedirs(savedir) 207 | plt.tight_layout(pad=0.0) 208 | plt.savefig(os.path.join(savedir, "vector_plot.jpg")) 209 | plt.close 210 | 211 | 212 | def 
save_vectors(prior_logdensity, model, data_samples, full_data, labels, savedir, skip_first=False, ntimes=101, end_times=None, memory=0.01, device='cpu', lim=4): 213 | model.eval() 214 | 215 | # Sample from prior 216 | z_samples = data_samples.to(device) 217 | 218 | with torch.no_grad(): 219 | # We expect the model is a chain of CNF layers wrapped in a SequentialFlow container. 220 | logp_samples = prior_logdensity(z_samples) 221 | t = 0 222 | for cnf in model.chain: 223 | # Construct integration_list 224 | if end_times is None: 225 | end_times = [(cnf.sqrt_end_time * cnf.sqrt_end_time)] 226 | # integration_list = [] 227 | integration_list = [torch.linspace(0, end_times[0], ntimes).to(device)] 228 | 229 | # Start integration at first end_time 230 | for i, et in enumerate(end_times[1:]): 231 | integration_list.append(torch.linspace(end_times[i], et, ntimes).to(device)) 232 | # if len(end_times) == 1: 233 | # integration_list = [torch.linspace(0, end_times[0], ntimes).to(device)] 234 | # print(integration_list) 235 | 236 | 237 | # Integrate over evenly spaced samples 238 | z_traj, logpz = cnf(z_samples, logp_samples, integration_times=integration_list[0], reverse=True) 239 | full_traj = [(z_traj, logpz)] 240 | for int_times in integration_list[1:]: 241 | prev_z, prev_logp = full_traj[-1] 242 | z_traj, logpz = cnf(prev_z[-1], prev_logp[-1], integration_times=int_times, reverse=True) 243 | full_traj.append((z_traj, logpz)) 244 | full_zip = list(zip(*full_traj)) 245 | z_traj = torch.cat(full_zip[0], 0) 246 | z_traj = z_traj.cpu().numpy() 247 | 248 | # mask out stray negative points 249 | pos_mask = full_data[:,1] >=0 250 | full_data = full_data[pos_mask] 251 | labels = labels[pos_mask] 252 | print(np.unique(labels)) 253 | 254 | plt.figure(figsize=(8, 8)) 255 | ax = plt.subplot(aspect="equal") 256 | ax.scatter(full_data[:,0], full_data[:,1], c=labels.astype(np.int32), cmap='tab10', s=0.5, alpha=1) 257 | # If we do not have a known base density then skip vectors for the 
first integration. 258 | 259 | z_traj = np.swapaxes(z_traj, 0, 1) 260 | if skip_first: 261 | z_traj = z_traj[:, ntimes:, :] 262 | ax.scatter(z_traj[:,0,0], z_traj[:,0,1], s=20, c='k') 263 | for zk in z_traj: 264 | #for zk in z_traj[:,ntimes:,:]: 265 | ax.scatter(zk[:,0], zk[:,1], s=1, c = np.linspace(0,1,zk.shape[0]), cmap='Spectral') 266 | ax.set_xlim(-lim, lim) 267 | ax.set_ylim(-lim, lim) 268 | #ax.set_ylim(4, -4) 269 | makedirs(savedir) 270 | plt.xticks([]) 271 | plt.yticks([]) 272 | plt.savefig(os.path.join(savedir, f"vectors.jpg"), dpi=300) 273 | t += 1 274 | 275 | def save_trajectory_density(prior_logdensity, model, data_samples, savedir, ntimes=101, end_times=None, memory=0.01, device='cpu'): 276 | model.eval() 277 | 278 | # sample from a grid 279 | #Jnpts = 100 280 | npts = 800 281 | side = np.linspace(-4, 4, npts) 282 | xx, yy = np.meshgrid(side, side) 283 | xx = torch.from_numpy(xx).type(torch.float32).to(device) 284 | yy = torch.from_numpy(yy).type(torch.float32).to(device) 285 | z_grid = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1)], 1) 286 | 287 | with torch.no_grad(): 288 | # We expect the model is a chain of CNF layers wrapped in a SequentialFlow container. 
289 | logp_grid = prior_logdensity(z_grid) 290 | t = 0 291 | for cnf in model.chain: 292 | # Construct integration_list 293 | if end_times is None: 294 | end_times = [(cnf.sqrt_end_time * cnf.sqrt_end_time)] 295 | integration_list = [torch.linspace(0, end_times[0], ntimes).to(device)] 296 | for i, et in enumerate(end_times[1:]): 297 | integration_list.append(torch.linspace(end_times[i], et, ntimes).to(device)) 298 | full_times = torch.cat(integration_list, 0) 299 | 300 | grid_z_traj, grid_logpz_traj = [], [] 301 | inds = torch.arange(0, z_grid.shape[0]).to(torch.int64) 302 | for ii in torch.split(inds, int(z_grid.shape[0] * memory)): 303 | _grid_z_traj, _grid_logpz_traj = cnf( 304 | z_grid[ii], logp_grid[ii], integration_times=integration_list[0], reverse=True 305 | ) 306 | full_traj = [(_grid_z_traj, _grid_logpz_traj)] 307 | for int_times in integration_list[1:]: 308 | prev_z, prev_logp = full_traj[-1] 309 | _grid_z_traj, _grid_logpz_traj = cnf( 310 | prev_z[-1], prev_logp[-1], integration_times=int_times, reverse=True 311 | ) 312 | full_traj.append((_grid_z_traj, _grid_logpz_traj)) 313 | full_zip = list(zip(*full_traj)) 314 | _grid_z_traj = torch.cat(full_zip[0], 0).cpu().numpy() 315 | _grid_logpz_traj = torch.cat(full_zip[1], 0).cpu().numpy() 316 | print(_grid_z_traj.shape) 317 | grid_z_traj.append(_grid_z_traj) 318 | grid_logpz_traj.append(_grid_logpz_traj) 319 | 320 | grid_z_traj = np.concatenate(grid_z_traj, axis=1)[ntimes:] 321 | grid_logpz_traj = np.concatenate(grid_logpz_traj, axis=1)[ntimes:] 322 | 323 | 324 | #plt.figure(figsize=(8, 8)) 325 | #fig, axes = plt.subplots(2,1, gridspec_kw={'height_ratios': [7, 1]}, figsize=(5,7)) 326 | for _ in range(grid_z_traj.shape[0]): 327 | fig, axes = plt.subplots(2,1, gridspec_kw={'height_ratios': [7, 1]}, figsize=(8,8)) 328 | #plt.clf() 329 | ax = axes[0] 330 | # Density 331 | z, logqz = grid_z_traj[t], grid_logpz_traj[t] 332 | 333 | xx = z[:, 0].reshape(npts, npts) 334 | yy = z[:, 1].reshape(npts, npts) 335 | qz = 
np.exp(logqz).reshape(npts, npts) 336 | 337 | ax.pcolormesh(xx, yy, qz) 338 | ax.set_xlim(-4, 4) 339 | ax.set_ylim(-4, 4) 340 | cmap = matplotlib.cm.get_cmap(None) 341 | ax.set_facecolor(cmap(0.)) 342 | #ax.invert_yaxis() 343 | ax.get_xaxis().set_ticks([]) 344 | ax.get_yaxis().set_ticks([]) 345 | ax.set_title("Density", fontsize=32) 346 | 347 | ax=axes[1] 348 | 349 | # Colorbar 350 | cb = matplotlib.colorbar.ColorbarBase(ax, 351 | #cmap='Spectral', 352 | cmap=plt.cm.Spectral, 353 | orientation='horizontal') 354 | #cb.set_ticks(np.linspace(0,1,4)) 355 | #cb.set_ticklabels(['48HR', 'Day 12', 'Day 18', 'Day 30']) 356 | #cb.set_ticklabels(['E12.5', 'E14.5', 'E16.0', 'E17.5']) 357 | #cb.set_ticks(np.linspace(0,1,5)) 358 | #cb.set_ticklabels(np.arange(len(end_times))) 359 | ax.axvline(t / grid_z_traj.shape[0], c='k', linewidth=15) 360 | ax.set_title('Time') 361 | 362 | print('making dir: %s' % savedir) 363 | makedirs(savedir) 364 | plt.savefig(os.path.join(savedir, f"viz-{t:05d}.jpg")) 365 | plt.close() 366 | t += 1 367 | 368 | 369 | def save_trajectory(prior_logdensity, prior_sampler, model, data_samples, savedir, ntimes=101, end_times=None, memory=0.01, device='cpu'): 370 | model.eval() 371 | 372 | # Sample from prior 373 | z_samples = prior_sampler(2000, 2).to(device) 374 | 375 | # sample from a grid 376 | npts = 800 377 | side = np.linspace(-4, 4, npts) 378 | xx, yy = np.meshgrid(side, side) 379 | xx = torch.from_numpy(xx).type(torch.float32).to(device) 380 | yy = torch.from_numpy(yy).type(torch.float32).to(device) 381 | z_grid = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1)], 1) 382 | 383 | with torch.no_grad(): 384 | # We expect the model is a chain of CNF layers wrapped in a SequentialFlow container. 
385 | logp_samples = prior_logdensity(z_samples) 386 | logp_grid = prior_logdensity(z_grid) 387 | t = 0 388 | for cnf in model.chain: 389 | 390 | # Construct integration_list 391 | if end_times is None: 392 | end_times = [(cnf.sqrt_end_time * cnf.sqrt_end_time)] 393 | integration_list = [torch.linspace(0, end_times[0], ntimes).to(device)] 394 | for i, et in enumerate(end_times[1:]): 395 | integration_list.append(torch.linspace(end_times[i], et, ntimes).to(device)) 396 | full_times = torch.cat(integration_list, 0) 397 | print(full_times.shape) 398 | 399 | # Integrate over evenly spaced samples 400 | z_traj, logpz = cnf(z_samples, logp_samples, integration_times=integration_list[0], reverse=True) 401 | full_traj = [(z_traj, logpz)] 402 | for int_times in integration_list[1:]: 403 | prev_z, prev_logp = full_traj[-1] 404 | z_traj, logpz = cnf(prev_z[-1], prev_logp[-1], integration_times=int_times, reverse=True) 405 | full_traj.append((z_traj, logpz)) 406 | full_zip = list(zip(*full_traj)) 407 | z_traj = torch.cat(full_zip[0], 0) 408 | #z_logp = torch.cat(full_zip[1], 0) 409 | z_traj = z_traj.cpu().numpy() 410 | 411 | grid_z_traj, grid_logpz_traj = [], [] 412 | inds = torch.arange(0, z_grid.shape[0]).to(torch.int64) 413 | for ii in torch.split(inds, int(z_grid.shape[0] * memory)): 414 | _grid_z_traj, _grid_logpz_traj = cnf( 415 | z_grid[ii], logp_grid[ii], integration_times=integration_list[0], reverse=True 416 | ) 417 | full_traj = [(_grid_z_traj, _grid_logpz_traj)] 418 | for int_times in integration_list[1:]: 419 | prev_z, prev_logp = full_traj[-1] 420 | _grid_z_traj, _grid_logpz_traj = cnf( 421 | prev_z[-1], prev_logp[-1], integration_times=int_times, reverse=True 422 | ) 423 | full_traj.append((_grid_z_traj, _grid_logpz_traj)) 424 | full_zip = list(zip(*full_traj)) 425 | _grid_z_traj = torch.cat(full_zip[0], 0).cpu().numpy() 426 | _grid_logpz_traj = torch.cat(full_zip[1], 0).cpu().numpy() 427 | print(_grid_z_traj.shape) 428 | grid_z_traj.append(_grid_z_traj) 429 | 
grid_logpz_traj.append(_grid_logpz_traj) 430 | 431 | grid_z_traj = np.concatenate(grid_z_traj, axis=1) 432 | grid_logpz_traj = np.concatenate(grid_logpz_traj, axis=1) 433 | 434 | plt.figure(figsize=(8, 8)) 435 | for _ in range(z_traj.shape[0]): 436 | 437 | plt.clf() 438 | 439 | # plot target potential function 440 | ax = plt.subplot(2, 2, 1, aspect="equal") 441 | 442 | ax.hist2d(data_samples[:, 0], data_samples[:, 1], range=[[-4, 4], [-4, 4]], bins=200) 443 | ax.invert_yaxis() 444 | ax.get_xaxis().set_ticks([]) 445 | ax.get_yaxis().set_ticks([]) 446 | ax.set_title("Target", fontsize=32) 447 | 448 | # plot the density 449 | ax = plt.subplot(2, 2, 2, aspect="equal") 450 | 451 | z, logqz = grid_z_traj[t], grid_logpz_traj[t] 452 | 453 | xx = z[:, 0].reshape(npts, npts) 454 | yy = z[:, 1].reshape(npts, npts) 455 | qz = np.exp(logqz).reshape(npts, npts) 456 | 457 | plt.pcolormesh(xx, yy, qz) 458 | ax.set_xlim(-4, 4) 459 | ax.set_ylim(-4, 4) 460 | cmap = matplotlib.cm.get_cmap(None) 461 | ax.set_facecolor(cmap(0.)) 462 | ax.invert_yaxis() 463 | ax.get_xaxis().set_ticks([]) 464 | ax.get_yaxis().set_ticks([]) 465 | ax.set_title("Density", fontsize=32) 466 | 467 | # plot the samples 468 | ax = plt.subplot(2, 2, 3, aspect="equal") 469 | 470 | zk = z_traj[t] 471 | ax.hist2d(zk[:, 0], zk[:, 1], range=[[-4, 4], [-4, 4]], bins=200) 472 | ax.invert_yaxis() 473 | ax.get_xaxis().set_ticks([]) 474 | ax.get_yaxis().set_ticks([]) 475 | ax.set_title("Samples", fontsize=32) 476 | 477 | # plot vector field 478 | ax = plt.subplot(2, 2, 4, aspect="equal") 479 | 480 | K = 13j 481 | y, x = np.mgrid[-4:4:K, -4:4:K] 482 | K = int(K.imag) 483 | zs = torch.from_numpy(np.stack([x, y], -1).reshape(K * K, 2)).to(device, torch.float32) 484 | logps = torch.zeros(zs.shape[0], 1).to(device, torch.float32) 485 | dydt = cnf.odefunc(full_times[t], (zs, logps))[0] 486 | dydt = -dydt.cpu().detach().numpy() 487 | dydt = dydt.reshape(K, K, 2) 488 | 489 | logmag = 2 * np.log(np.hypot(dydt[:, :, 0], dydt[:, :, 
1])) 490 | ax.quiver( 491 | x, y, dydt[:, :, 0], -dydt[:, :, 1], 492 | # x, y, dydt[:, :, 0], dydt[:, :, 1], 493 | np.exp(logmag), cmap="coolwarm", scale=20., width=0.015, pivot="mid" 494 | ) 495 | ax.set_xlim(-4, 4) 496 | ax.set_ylim(4, -4) 497 | #ax.set_ylim(-4, 4) 498 | ax.axis("off") 499 | ax.set_title("Vector Field", fontsize=32) 500 | 501 | makedirs(savedir) 502 | plt.savefig(os.path.join(savedir, f"viz-{t:05d}.jpg")) 503 | t += 1 504 | 505 | 506 | def trajectory_to_video(savedir): 507 | import subprocess 508 | bashCommand = 'ffmpeg -y -i {} {}'.format(os.path.join(savedir, 'viz-%05d.jpg'), os.path.join(savedir, 'traj.mp4')) 509 | process = subprocess.Popen(bashCommand.split(), stdout=subprocess.PIPE) 510 | output, error = process.communicate() 511 | 512 | 513 | if __name__ == '__main__': 514 | import argparse 515 | import sys 516 | 517 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'))) 518 | 519 | import lib.toy_data as toy_data 520 | from train_misc import count_parameters 521 | from train_misc import set_cnf_options, add_spectral_norm, create_regularization_fns 522 | from train_misc import build_model_tabular 523 | 524 | def get_ckpt_model_and_data(args): 525 | # Load checkpoint. 526 | checkpt = torch.load(args.checkpt, map_location=lambda storage, loc: storage) 527 | ckpt_args = checkpt['args'] 528 | state_dict = checkpt['state_dict'] 529 | 530 | # Construct model and restore checkpoint. 
531 | regularization_fns, regularization_coeffs = create_regularization_fns(ckpt_args) 532 | model = build_model_tabular(ckpt_args, 2, regularization_fns).to(device) 533 | if ckpt_args.spectral_norm: add_spectral_norm(model) 534 | set_cnf_options(ckpt_args, model) 535 | 536 | model.load_state_dict(state_dict) 537 | model.to(device) 538 | 539 | print(model) 540 | print("Number of trainable parameters: {}".format(count_parameters(model))) 541 | 542 | # Load samples from dataset 543 | data_samples = toy_data.inf_train_gen(ckpt_args.data, batch_size=2000) 544 | 545 | return model, data_samples 546 | 547 | parser = argparse.ArgumentParser() 548 | parser.add_argument('--checkpt', type=str, required=True) 549 | parser.add_argument('--ntimes', type=int, default=101) 550 | parser.add_argument('--memory', type=float, default=0.01, help='Higher this number, the more memory is consumed.') 551 | parser.add_argument('--save', type=str, default='trajectory') 552 | args = parser.parse_args() 553 | 554 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 555 | model, data_samples = get_ckpt_model_and_data(args) 556 | save_trajectory(model, data_samples, args.save, ntimes=args.ntimes, memory=args.memory, device=device) 557 | trajectory_to_video(args.save) 558 | -------------------------------------------------------------------------------- /TrajectoryNet/main.py: -------------------------------------------------------------------------------- 1 | """ main.py 2 | 3 | Learns ODE from scrna data 4 | 5 | """ 6 | import os 7 | import matplotlib 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | import time 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | import torch.optim as optim 15 | 16 | from TrajectoryNet.lib.growth_net import GrowthNet 17 | from TrajectoryNet.lib import utils 18 | from TrajectoryNet.lib.visualize_flow import visualize_transform 19 | from TrajectoryNet.lib.viz_scrna import ( 20 | save_trajectory, 21 | 
trajectory_to_video, 22 | save_vectors, 23 | ) 24 | from TrajectoryNet.lib.viz_scrna import save_trajectory_density 25 | 26 | 27 | # from train_misc import standard_normal_logprob 28 | from TrajectoryNet.train_misc import ( 29 | set_cnf_options, 30 | count_nfe, 31 | count_parameters, 32 | count_total_time, 33 | add_spectral_norm, 34 | spectral_norm_power_iteration, 35 | create_regularization_fns, 36 | get_regularization, 37 | append_regularization_to_log, 38 | build_model_tabular, 39 | ) 40 | 41 | from TrajectoryNet import dataset 42 | from TrajectoryNet.parse import parser 43 | 44 | matplotlib.use("Agg") 45 | 46 | 47 | def get_transforms(device, args, model, integration_times): 48 | """ 49 | Given a list of integration points, 50 | returns a function giving integration times 51 | """ 52 | 53 | def sample_fn(z, logpz=None): 54 | int_list = [ 55 | torch.tensor([it - args.time_scale, it]).type(torch.float32).to(device) 56 | for it in integration_times 57 | ] 58 | if logpz is not None: 59 | # TODO this works right? 
60 | for it in int_list: 61 | z, logpz = model(z, logpz, integration_times=it, reverse=True) 62 | return z, logpz 63 | else: 64 | for it in int_list: 65 | z = model(z, integration_times=it, reverse=True) 66 | return z 67 | 68 | def density_fn(x, logpx=None): 69 | int_list = [ 70 | torch.tensor([it - args.time_scale, it]).type(torch.float32).to(device) 71 | for it in integration_times[::-1] 72 | ] 73 | if logpx is not None: 74 | for it in int_list: 75 | x, logpx = model(x, logpx, integration_times=it, reverse=False) 76 | return x, logpx 77 | else: 78 | for it in int_list: 79 | x = model(x, integration_times=it, reverse=False) 80 | return x 81 | 82 | return sample_fn, density_fn 83 | 84 | 85 | def compute_loss(device, args, model, growth_model, logger, full_data): 86 | """ 87 | Compute loss by integrating backwards from the last time step 88 | At each time step integrate back one time step, and concatenate that 89 | to samples of the empirical distribution at that previous timestep 90 | repeating over and over to calculate the likelihood of samples in 91 | later timepoints iteratively, making sure that the ODE is evaluated 92 | at every time step to calculate those later points. 93 | 94 | The growth model is a single model of time independent cell growth / 95 | death rate defined as a variation from uniform. 
96 | """ 97 | 98 | # Backward pass accumulating losses, previous state and deltas 99 | deltas = [] 100 | zs = [] 101 | z = None 102 | interp_loss = 0.0 103 | for i, (itp, tp) in enumerate(zip(args.int_tps[::-1], args.timepoints[::-1])): 104 | # tp counts down from last 105 | integration_times = torch.tensor([itp - args.time_scale, itp]) 106 | integration_times = integration_times.type(torch.float32).to(device) 107 | # integration_times.requires_grad = True 108 | 109 | # load data and add noise 110 | idx = args.data.sample_index(args.batch_size, tp) 111 | x = args.data.get_data()[idx] 112 | if args.training_noise > 0.0: 113 | x += np.random.randn(*x.shape) * args.training_noise 114 | x = torch.from_numpy(x).type(torch.float32).to(device) 115 | 116 | if i > 0: 117 | x = torch.cat((z, x)) 118 | zs.append(z) 119 | zero = torch.zeros(x.shape[0], 1).to(x) 120 | 121 | # transform to previous timepoint 122 | z, delta_logp = model(x, zero, integration_times=integration_times) 123 | deltas.append(delta_logp) 124 | 125 | # Straightline regularization 126 | # Integrate to random point at time t and assert close to (1 - t) * end + t * start 127 | if args.interp_reg: 128 | t = np.random.rand() 129 | int_t = torch.tensor([itp - t * args.time_scale, itp]) 130 | int_t = int_t.type(torch.float32).to(device) 131 | int_x = model(x, integration_times=int_t) 132 | int_x = int_x.detach() 133 | actual_int_x = x * (1 - t) + z * t 134 | interp_loss += F.mse_loss(int_x, actual_int_x) 135 | if args.interp_reg: 136 | print("interp_loss", interp_loss) 137 | 138 | logpz = args.data.base_density()(z) 139 | 140 | # build growth rates 141 | if args.use_growth: 142 | growthrates = [torch.ones_like(logpz)] 143 | for z_state, tp in zip(zs[::-1], args.timepoints[:-1]): 144 | # Full state includes time parameter to growth_model 145 | time_state = tp * torch.ones(z_state.shape[0], 1).to(z_state) 146 | full_state = torch.cat([z_state, time_state], 1) 147 | growthrates.append(growth_model(full_state)) 148 
| 149 | # Accumulate losses 150 | losses = [] 151 | logps = [logpz] 152 | for i, delta_logp in enumerate(deltas[::-1]): 153 | logpx = logps[-1] - delta_logp 154 | if args.use_growth: 155 | logpx += torch.log(torch.clamp(growthrates[i], 1e-4, 1e4)) 156 | logps.append(logpx[: -args.batch_size]) 157 | losses.append(-torch.mean(logpx[-args.batch_size :])) 158 | losses = torch.stack(losses) 159 | weights = torch.ones_like(losses).to(logpx) 160 | if args.leaveout_timepoint >= 0: 161 | weights[args.leaveout_timepoint] = 0 162 | losses = torch.mean(losses * weights) 163 | 164 | # Direction regularization 165 | if args.vecint: 166 | similarity_loss = 0 167 | for i, (itp, tp) in enumerate(zip(args.int_tps, args.timepoints)): 168 | itp = torch.tensor(itp).type(torch.float32).to(device) 169 | idx = args.data.sample_index(args.batch_size, tp) 170 | x = args.data.get_data()[idx] 171 | v = args.data.get_velocity()[idx] 172 | x = torch.from_numpy(x).type(torch.float32).to(device) 173 | v = torch.from_numpy(v).type(torch.float32).to(device) 174 | x += torch.randn_like(x) * 0.1 175 | # Only penalizes at the time / place of visible samples 176 | direction = -model.chain[0].odefunc.odefunc.diffeq(itp, x) 177 | if args.use_magnitude: 178 | similarity_loss += torch.mean(F.mse_loss(direction, v)) 179 | else: 180 | similarity_loss -= torch.mean(F.cosine_similarity(direction, v)) 181 | logger.info(similarity_loss) 182 | losses += similarity_loss * args.vecint 183 | 184 | # Density regularization 185 | if args.top_k_reg > 0: 186 | density_loss = 0 187 | tp_z_map = dict(zip(args.timepoints[:-1], zs[::-1])) 188 | if args.leaveout_timepoint not in tp_z_map: 189 | idx = args.data.sample_index(args.batch_size, tp) 190 | x = args.data.get_data()[idx] 191 | if args.training_noise > 0.0: 192 | x += np.random.randn(*x.shape) * args.training_noise 193 | x = torch.from_numpy(x).type(torch.float32).to(device) 194 | t = np.random.rand() 195 | int_t = torch.tensor([itp - t * args.time_scale, itp]) 196 | 
int_t = int_t.type(torch.float32).to(device) 197 | int_x = model(x, integration_times=int_t) 198 | samples_05 = int_x 199 | else: 200 | # If we are leaving out a timepoint the regularize there 201 | samples_05 = tp_z_map[args.leaveout_timepoint] 202 | 203 | # Calculate distance to 5 closest neighbors 204 | # WARNING: This currently fails in the backward pass with cuda on pytorch < 1.4.0 205 | # works on CPU. Fixed in pytorch 1.5.0 206 | # RuntimeError: CUDA error: invalid configuration argument 207 | # The workaround is to run on cpu on pytorch <= 1.4.0 or upgrade 208 | cdist = torch.cdist(samples_05, full_data) 209 | values, _ = torch.topk(cdist, 5, dim=1, largest=False, sorted=False) 210 | # Hinge loss 211 | hinge_value = 0.1 212 | values -= hinge_value 213 | values[values < 0] = 0 214 | density_loss = torch.mean(values) 215 | print("Density Loss", density_loss.item()) 216 | losses += density_loss * args.top_k_reg 217 | losses += interp_loss 218 | return losses 219 | 220 | 221 | def train( 222 | device, args, model, growth_model, regularization_coeffs, regularization_fns, logger 223 | ): 224 | optimizer = optim.Adam( 225 | model.parameters(), lr=args.lr, weight_decay=args.weight_decay 226 | ) 227 | 228 | time_meter = utils.RunningAverageMeter(0.93) 229 | loss_meter = utils.RunningAverageMeter(0.93) 230 | nfef_meter = utils.RunningAverageMeter(0.93) 231 | nfeb_meter = utils.RunningAverageMeter(0.93) 232 | tt_meter = utils.RunningAverageMeter(0.93) 233 | 234 | full_data = ( 235 | torch.from_numpy( 236 | args.data.get_data()[args.data.get_times() != args.leaveout_timepoint] 237 | ) 238 | .type(torch.float32) 239 | .to(device) 240 | ) 241 | 242 | best_loss = float("inf") 243 | if args.use_growth: 244 | growth_model.eval() 245 | end = time.time() 246 | for itr in range(1, args.niters + 1): 247 | model.train() 248 | optimizer.zero_grad() 249 | 250 | # Train 251 | if args.spectral_norm: 252 | spectral_norm_power_iteration(model, 1) 253 | 254 | loss = 
compute_loss(device, args, model, growth_model, logger, full_data) 255 | loss_meter.update(loss.item()) 256 | 257 | if len(regularization_coeffs) > 0: 258 | # Only regularize on the last timepoint 259 | reg_states = get_regularization(model, regularization_coeffs) 260 | reg_loss = sum( 261 | reg_state * coeff 262 | for reg_state, coeff in zip(reg_states, regularization_coeffs) 263 | if coeff != 0 264 | ) 265 | loss = loss + reg_loss 266 | total_time = count_total_time(model) 267 | nfe_forward = count_nfe(model) 268 | 269 | loss.backward() 270 | optimizer.step() 271 | 272 | # Eval 273 | nfe_total = count_nfe(model) 274 | nfe_backward = nfe_total - nfe_forward 275 | nfef_meter.update(nfe_forward) 276 | nfeb_meter.update(nfe_backward) 277 | time_meter.update(time.time() - end) 278 | tt_meter.update(total_time) 279 | 280 | log_message = ( 281 | "Iter {:04d} | Time {:.4f}({:.4f}) | Loss {:.6f}({:.6f}) |" 282 | " NFE Forward {:.0f}({:.1f})" 283 | " | NFE Backward {:.0f}({:.1f})".format( 284 | itr, 285 | time_meter.val, 286 | time_meter.avg, 287 | loss_meter.val, 288 | loss_meter.avg, 289 | nfef_meter.val, 290 | nfef_meter.avg, 291 | nfeb_meter.val, 292 | nfeb_meter.avg, 293 | ) 294 | ) 295 | if len(regularization_coeffs) > 0: 296 | log_message = append_regularization_to_log( 297 | log_message, regularization_fns, reg_states 298 | ) 299 | logger.info(log_message) 300 | 301 | if itr % args.val_freq == 0 or itr == args.niters: 302 | with torch.no_grad(): 303 | train_eval( 304 | device, args, model, growth_model, itr, best_loss, logger, full_data 305 | ) 306 | 307 | if itr % args.viz_freq == 0: 308 | if args.data.get_shape()[0] > 2: 309 | logger.warning("Skipping vis as data dimension is >2") 310 | else: 311 | with torch.no_grad(): 312 | visualize(device, args, model, itr) 313 | if itr % args.save_freq == 0: 314 | chkpt = { 315 | "state_dict": model.state_dict(), 316 | } 317 | if args.use_growth: 318 | chkpt.update({"growth_state_dict": growth_model.state_dict()}) 319 | 
utils.save_checkpoint( 320 | chkpt, 321 | args.save, 322 | epoch=itr, 323 | ) 324 | end = time.time() 325 | logger.info("Training has finished.") 326 | 327 | 328 | def train_eval(device, args, model, growth_model, itr, best_loss, logger, full_data): 329 | model.eval() 330 | test_loss = compute_loss(device, args, model, growth_model, logger, full_data) 331 | test_nfe = count_nfe(model) 332 | log_message = "[TEST] Iter {:04d} | Test Loss {:.6f} |" " NFE {:.0f}".format( 333 | itr, test_loss, test_nfe 334 | ) 335 | logger.info(log_message) 336 | utils.makedirs(args.save) 337 | with open(os.path.join(args.save, "train_eval.csv"), "a") as f: 338 | import csv 339 | 340 | writer = csv.writer(f) 341 | writer.writerow((itr, test_loss)) 342 | 343 | if test_loss.item() < best_loss: 344 | best_loss = test_loss.item() 345 | chkpt = { 346 | "state_dict": model.state_dict(), 347 | } 348 | if args.use_growth: 349 | chkpt.update({"growth_state_dict": growth_model.state_dict()}) 350 | torch.save( 351 | chkpt, 352 | os.path.join(args.save, "checkpt.pth"), 353 | ) 354 | 355 | 356 | def visualize(device, args, model, itr): 357 | model.eval() 358 | for i, tp in enumerate(args.timepoints): 359 | idx = args.data.sample_index(args.viz_batch_size, tp) 360 | p_samples = args.data.get_data()[idx] 361 | sample_fn, density_fn = get_transforms( 362 | device, args, model, args.int_tps[: i + 1] 363 | ) 364 | plt.figure(figsize=(9, 3)) 365 | visualize_transform( 366 | p_samples, 367 | args.data.base_sample(), 368 | args.data.base_density(), 369 | transform=sample_fn, 370 | inverse_transform=density_fn, 371 | samples=True, 372 | npts=100, 373 | device=device, 374 | ) 375 | fig_filename = os.path.join( 376 | args.save, "figs", "{:04d}_{:01d}.jpg".format(itr, i) 377 | ) 378 | utils.makedirs(os.path.dirname(fig_filename)) 379 | plt.savefig(fig_filename) 380 | plt.close() 381 | 382 | 383 | def plot_output(device, args, model): 384 | save_traj_dir = os.path.join(args.save, "trajectory") 385 | # 
logger.info('Plotting trajectory to {}'.format(save_traj_dir)) 386 | data_samples = args.data.get_data()[args.data.sample_index(2000, 0)] 387 | np.random.seed(42) 388 | start_points = args.data.base_sample()(1000, 2) 389 | # idx = args.data.sample_index(50, 0) 390 | # start_points = args.data.get_data()[idx] 391 | # start_points = torch.from_numpy(start_points).type(torch.float32) 392 | save_vectors( 393 | args.data.base_density(), 394 | model, 395 | start_points, 396 | args.data.get_data(), 397 | args.data.get_times(), 398 | args.save, 399 | skip_first=(not args.data.known_base_density()), 400 | device=device, 401 | end_times=args.int_tps, 402 | ntimes=100, 403 | ) 404 | 405 | save_trajectory( 406 | args.data.base_density(), 407 | args.data.base_sample(), 408 | model, 409 | data_samples, 410 | save_traj_dir, 411 | device=device, 412 | end_times=args.int_tps, 413 | ntimes=25, 414 | ) 415 | 416 | density_dir = os.path.join(args.save, "density2") 417 | save_trajectory_density( 418 | args.data.base_density(), 419 | model, 420 | data_samples, 421 | density_dir, 422 | device=device, 423 | end_times=args.int_tps, 424 | ntimes=25, 425 | memory=0.1, 426 | ) 427 | 428 | if args.save_movie: 429 | trajectory_to_video(save_traj_dir) 430 | trajectory_to_video(density_dir) 431 | 432 | 433 | def main(args): 434 | # logger 435 | print(args.no_display_loss) 436 | utils.makedirs(args.save) 437 | logger = utils.get_logger( 438 | logpath=os.path.join(args.save, "logs"), 439 | filepath=os.path.abspath(__file__), 440 | displaying=~args.no_display_loss, 441 | ) 442 | 443 | if args.layer_type == "blend": 444 | logger.info("!! 
Setting time_scale from None to 1.0 for Blend layers.") 445 | args.time_scale = 1.0 446 | 447 | logger.info(args) 448 | 449 | device = torch.device( 450 | "cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu" 451 | ) 452 | if args.use_cpu: 453 | device = torch.device("cpu") 454 | 455 | args.data = dataset.SCData.factory(args.dataset, args) 456 | 457 | args.timepoints = args.data.get_unique_times() 458 | # Use maximum timepoint to establish integration_times 459 | # as some timepoints may be left out for validation etc. 460 | args.int_tps = (np.arange(max(args.timepoints) + 1) + 1.0) * args.time_scale 461 | 462 | regularization_fns, regularization_coeffs = create_regularization_fns(args) 463 | model = build_model_tabular(args, args.data.get_shape()[0], regularization_fns).to( 464 | device 465 | ) 466 | growth_model = None 467 | if args.use_growth: 468 | if args.leaveout_timepoint == -1: 469 | growth_model_path = "../data/externel/growth_model_v2.ckpt" 470 | elif args.leaveout_timepoint in [1, 2, 3]: 471 | assert args.max_dim == 5 472 | growth_model_path = "../data/growth/model_%d" % args.leaveout_timepoint 473 | else: 474 | print("WARNING: Cannot use growth with this timepoint") 475 | 476 | growth_model = torch.load(growth_model_path, map_location=device) 477 | if args.spectral_norm: 478 | add_spectral_norm(model) 479 | set_cnf_options(args, model) 480 | 481 | if args.test: 482 | state_dict = torch.load(args.save + "/checkpt.pth", map_location=device) 483 | model.load_state_dict(state_dict["state_dict"]) 484 | # if "growth_state_dict" not in state_dict: 485 | # print("error growth model note in save") 486 | # growth_model = None 487 | # else: 488 | # checkpt = torch.load(args.save + "/checkpt.pth", map_location=device) 489 | # growth_model.load_state_dict(checkpt["growth_state_dict"]) 490 | # TODO can we load the arguments from the save? 
491 | # eval_utils.generate_samples( 492 | # device, args, model, growth_model, timepoint=args.leaveout_timepoint 493 | # ) 494 | # with torch.no_grad(): 495 | # evaluate(device, args, model, growth_model) 496 | # exit() 497 | else: 498 | logger.info(model) 499 | n_param = count_parameters(model) 500 | logger.info("Number of trainable parameters: {}".format(n_param)) 501 | 502 | train( 503 | device, 504 | args, 505 | model, 506 | growth_model, 507 | regularization_coeffs, 508 | regularization_fns, 509 | logger, 510 | ) 511 | 512 | if args.data.data.shape[1] == 2: 513 | plot_output(device, args, model) 514 | 515 | 516 | if __name__ == "__main__": 517 | 518 | args = parser.parse_args() 519 | main(args) 520 | -------------------------------------------------------------------------------- /TrajectoryNet/optimal_transport/MMD.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | 4 | 5 | import torch 6 | 7 | min_var_est = 1e-8 8 | 9 | 10 | # Consider linear time MMD with a linear kernel: 11 | # K(f(x), f(y)) = f(x)^Tf(y) 12 | # h(z_i, z_j) = k(x_i, x_j) + k(y_i, y_j) - k(x_i, y_j) - k(x_j, y_i) 13 | # = [f(x_i) - f(y_i)]^T[f(x_j) - f(y_j)] 14 | # 15 | # f_of_X: batch_size * k 16 | # f_of_Y: batch_size * k 17 | def linear_mmd2(f_of_X, f_of_Y): 18 | loss = 0.0 19 | delta = f_of_X - f_of_Y 20 | loss = torch.mean((delta[:-1] * delta[1:]).sum(1)) 21 | return loss 22 | 23 | 24 | # Consider linear time MMD with a polynomial kernel: 25 | # K(f(x), f(y)) = (alpha*f(x)^Tf(y) + c)^d 26 | # f_of_X: batch_size * k 27 | # f_of_Y: batch_size * k 28 | def poly_mmd2(f_of_X, f_of_Y, d=2, alpha=1.0, c=2.0): 29 | K_XX = alpha * (f_of_X[:-1] * f_of_X[1:]).sum(1) + c 30 | K_XX_mean = torch.mean(K_XX.pow(d)) 31 | 32 | K_YY = alpha * (f_of_Y[:-1] * f_of_Y[1:]).sum(1) + c 33 | K_YY_mean = torch.mean(K_YY.pow(d)) 34 | 35 | K_XY = alpha * (f_of_X[:-1] * f_of_Y[1:]).sum(1) + c 36 | K_XY_mean = 
torch.mean(K_XY.pow(d)) 37 | 38 | K_YX = alpha * (f_of_Y[:-1] * f_of_X[1:]).sum(1) + c 39 | K_YX_mean = torch.mean(K_YX.pow(d)) 40 | 41 | return K_XX_mean + K_YY_mean - K_XY_mean - K_YX_mean 42 | 43 | 44 | def _mix_rbf_kernel(X, Y, sigma_list): 45 | assert X.size(0) == Y.size(0) 46 | m = X.size(0) 47 | 48 | Z = torch.cat((X, Y), 0) 49 | ZZT = torch.mm(Z, Z.t()) 50 | diag_ZZT = torch.diag(ZZT).unsqueeze(1) 51 | Z_norm_sqr = diag_ZZT.expand_as(ZZT) 52 | exponent = Z_norm_sqr - 2 * ZZT + Z_norm_sqr.t() 53 | 54 | K = 0.0 55 | for sigma in sigma_list: 56 | gamma = 1.0 / (2 * sigma**2) 57 | K += torch.exp(-gamma * exponent) 58 | 59 | return K[:m, :m], K[:m, m:], K[m:, m:], len(sigma_list) 60 | 61 | 62 | def mix_rbf_mmd2(X, Y, sigma_list, biased=True): 63 | K_XX, K_XY, K_YY, d = _mix_rbf_kernel(X, Y, sigma_list) 64 | # return _mmd2(K_XX, K_XY, K_YY, const_diagonal=d, biased=biased) 65 | return _mmd2(K_XX, K_XY, K_YY, const_diagonal=False, biased=biased) 66 | 67 | 68 | def mix_rbf_mmd2_and_ratio(X, Y, sigma_list, biased=True): 69 | K_XX, K_XY, K_YY, d = _mix_rbf_kernel(X, Y, sigma_list) 70 | # return _mmd2_and_ratio(K_XX, K_XY, K_YY, const_diagonal=d, biased=biased) 71 | return _mmd2_and_ratio(K_XX, K_XY, K_YY, const_diagonal=False, biased=biased) 72 | 73 | 74 | ################################################################################ 75 | # Helper functions to compute variances based on kernel matrices 76 | ################################################################################ 77 | 78 | 79 | def _mmd2(K_XX, K_XY, K_YY, const_diagonal=False, biased=False): 80 | m = K_XX.size(0) # assume X, Y are same shape 81 | 82 | # Get the various sums of kernels that we'll use 83 | # Kts drop the diagonal, but we don't need to compute them explicitly 84 | if const_diagonal is not False: 85 | diag_X = diag_Y = const_diagonal 86 | sum_diag_X = sum_diag_Y = m * const_diagonal 87 | else: 88 | diag_X = torch.diag(K_XX) # (m,) 89 | diag_Y = torch.diag(K_YY) # (m,) 90 | 
sum_diag_X = torch.sum(diag_X) 91 | sum_diag_Y = torch.sum(diag_Y) 92 | 93 | Kt_XX_sums = K_XX.sum(dim=1) - diag_X # \tilde{K}_XX * e = K_XX * e - diag_X 94 | Kt_YY_sums = K_YY.sum(dim=1) - diag_Y # \tilde{K}_YY * e = K_YY * e - diag_Y 95 | K_XY_sums_0 = K_XY.sum(dim=0) # K_{XY}^T * e 96 | 97 | Kt_XX_sum = Kt_XX_sums.sum() # e^T * \tilde{K}_XX * e 98 | Kt_YY_sum = Kt_YY_sums.sum() # e^T * \tilde{K}_YY * e 99 | K_XY_sum = K_XY_sums_0.sum() # e^T * K_{XY} * e 100 | 101 | if biased: 102 | mmd2 = ( 103 | (Kt_XX_sum + sum_diag_X) / (m * m) 104 | + (Kt_YY_sum + sum_diag_Y) / (m * m) 105 | - 2.0 * K_XY_sum / (m * m) 106 | ) 107 | else: 108 | mmd2 = ( 109 | Kt_XX_sum / (m * (m - 1)) 110 | + Kt_YY_sum / (m * (m - 1)) 111 | - 2.0 * K_XY_sum / (m * m) 112 | ) 113 | 114 | return mmd2 115 | 116 | 117 | def _mmd2_and_ratio(K_XX, K_XY, K_YY, const_diagonal=False, biased=False): 118 | mmd2, var_est = _mmd2_and_variance( 119 | K_XX, K_XY, K_YY, const_diagonal=const_diagonal, biased=biased 120 | ) 121 | loss = mmd2 / torch.sqrt(torch.clamp(var_est, min=min_var_est)) 122 | return loss, mmd2, var_est 123 | 124 | 125 | def _mmd2_and_variance(K_XX, K_XY, K_YY, const_diagonal=False, biased=False): 126 | m = K_XX.size(0) # assume X, Y are same shape 127 | 128 | # Get the various sums of kernels that we'll use 129 | # Kts drop the diagonal, but we don't need to compute them explicitly 130 | if const_diagonal is not False: 131 | diag_X = diag_Y = const_diagonal 132 | sum_diag_X = sum_diag_Y = m * const_diagonal 133 | sum_diag2_X = sum_diag2_Y = m * const_diagonal**2 134 | else: 135 | diag_X = torch.diag(K_XX) # (m,) 136 | diag_Y = torch.diag(K_YY) # (m,) 137 | sum_diag_X = torch.sum(diag_X) 138 | sum_diag_Y = torch.sum(diag_Y) 139 | sum_diag2_X = diag_X.dot(diag_X) 140 | sum_diag2_Y = diag_Y.dot(diag_Y) 141 | 142 | Kt_XX_sums = K_XX.sum(dim=1) - diag_X # \tilde{K}_XX * e = K_XX * e - diag_X 143 | Kt_YY_sums = K_YY.sum(dim=1) - diag_Y # \tilde{K}_YY * e = K_YY * e - diag_Y 144 | K_XY_sums_0 
= K_XY.sum(dim=0) # K_{XY}^T * e 145 | K_XY_sums_1 = K_XY.sum(dim=1) # K_{XY} * e 146 | 147 | Kt_XX_sum = Kt_XX_sums.sum() # e^T * \tilde{K}_XX * e 148 | Kt_YY_sum = Kt_YY_sums.sum() # e^T * \tilde{K}_YY * e 149 | K_XY_sum = K_XY_sums_0.sum() # e^T * K_{XY} * e 150 | 151 | Kt_XX_2_sum = (K_XX**2).sum() - sum_diag2_X # \| \tilde{K}_XX \|_F^2 152 | Kt_YY_2_sum = (K_YY**2).sum() - sum_diag2_Y # \| \tilde{K}_YY \|_F^2 153 | K_XY_2_sum = (K_XY**2).sum() # \| K_{XY} \|_F^2 154 | 155 | if biased: 156 | mmd2 = ( 157 | (Kt_XX_sum + sum_diag_X) / (m * m) 158 | + (Kt_YY_sum + sum_diag_Y) / (m * m) 159 | - 2.0 * K_XY_sum / (m * m) 160 | ) 161 | else: 162 | mmd2 = ( 163 | Kt_XX_sum / (m * (m - 1)) 164 | + Kt_YY_sum / (m * (m - 1)) 165 | - 2.0 * K_XY_sum / (m * m) 166 | ) 167 | 168 | var_est = ( 169 | 2.0 170 | / (m**2 * (m - 1.0) ** 2) 171 | * ( 172 | 2 * Kt_XX_sums.dot(Kt_XX_sums) 173 | - Kt_XX_2_sum 174 | + 2 * Kt_YY_sums.dot(Kt_YY_sums) 175 | - Kt_YY_2_sum 176 | ) 177 | - (4.0 * m - 6.0) 178 | / (m**3 * (m - 1.0) ** 3) 179 | * (Kt_XX_sum**2 + Kt_YY_sum**2) 180 | + 4.0 181 | * (m - 2.0) 182 | / (m**3 * (m - 1.0) ** 2) 183 | * (K_XY_sums_1.dot(K_XY_sums_1) + K_XY_sums_0.dot(K_XY_sums_0)) 184 | - 4.0 * (m - 3.0) / (m**3 * (m - 1.0) ** 2) * (K_XY_2_sum) 185 | - (8 * m - 12) / (m**5 * (m - 1)) * K_XY_sum**2 186 | + 8.0 187 | / (m**3 * (m - 1.0)) 188 | * ( 189 | 1.0 / m * (Kt_XX_sum + Kt_YY_sum) * K_XY_sum 190 | - Kt_XX_sums.dot(K_XY_sums_1) 191 | - Kt_YY_sums.dot(K_XY_sums_0) 192 | ) 193 | ) 194 | return mmd2, var_est 195 | -------------------------------------------------------------------------------- /TrajectoryNet/optimal_transport/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KrishnaswamyLab/TrajectoryNet/810c89b081f95405bc0ad42bfb3cb41038e13700/TrajectoryNet/optimal_transport/__init__.py -------------------------------------------------------------------------------- 
# /TrajectoryNet/optimal_transport/emd.py
# --------------------------------------------------------------------------------
import numpy as np
import ot as pot  # Python Optimal Transport package
import scipy.sparse
from sklearn.metrics.pairwise import pairwise_distances


def _to_dense(x):
    """Return ``x.toarray()`` if *x* is a SciPy sparse matrix/array, else *x*."""
    return x.toarray() if scipy.sparse.issparse(x) else x


def _normalized_weights(weights, n):
    """Return *weights* as float64 scaled to sum to 1, or uniform ``1/n`` if None."""
    if weights is None:
        return np.ones(n) / n
    weights = np.asarray(weights, dtype=np.float64)
    return weights / weights.sum()


def _safe_divide(num, den):
    """Broadcasting ``num / den`` that yields 0 wherever ``den`` is 0.

    Used for transport-map columns/rows that carry no mass, where a plain
    division would produce NaN and make ``np.random.choice`` fail.
    """
    num = np.asarray(num, dtype=np.float64)
    den = np.asarray(den, dtype=np.float64)
    out = np.zeros(np.broadcast(num, den).shape, dtype=np.float64)
    return np.divide(num, den, out=out, where=den != 0)


def _validate_interpolation_inputs(p0, p1, tmap):
    """Densify and cast the inputs to float64, checking shapes against *tmap*.

    Raises
    ------
    ValueError
        If the feature dimensions differ or *tmap* is not (len(p0), len(p1)).
    """
    p0 = np.asarray(_to_dense(p0), dtype=np.float64)
    p1 = np.asarray(_to_dense(p1), dtype=np.float64)
    tmap = np.asarray(tmap, dtype=np.float64)
    if p0.shape[1] != p1.shape[1]:
        raise ValueError("Unable to interpolate. Number of genes do not match")
    if p0.shape[0] != tmap.shape[0] or p1.shape[0] != tmap.shape[1]:
        raise ValueError(
            "Unable to interpolate. Tmap size is {}, expected {}".format(
                tmap.shape, (len(p0), len(p1))
            )
        )
    return p0, p1, tmap


def earth_mover_distance(
    p,
    q,
    eigenvals=None,
    weights1=None,
    weights2=None,
    return_matrix=False,
    metric="sqeuclidean",
):
    """Return the earth mover's distance between two point clouds.

    The ground cost is ``pairwise_distances(p, q, metric=metric)`` and the
    returned value is the square root of the optimal transport cost, so with
    the default ``metric="sqeuclidean"`` this is the 2-Wasserstein distance.

    Parameters
    ----------
    p : 2-D array or scipy sparse matrix
        First point cloud, one point per row.
    q : 2-D array or scipy sparse matrix
        Second point cloud, one point per row.
    eigenvals : array, optional
        If given, both clouds are first mapped with ``cloud.dot(eigenvals)``.
    weights1, weights2 : 1-D array-like, optional
        Per-point weights for ``p`` / ``q``; normalized to sum to one.
        Uniform weights are used when omitted.
    return_matrix : bool
        Forwarded to ``ot.emd2``; when True the log returned by ``ot.emd2``
        is returned as a second value.
    metric : str
        Metric passed to ``sklearn.metrics.pairwise.pairwise_distances``.

    Returns
    -------
    distance : float
        Square root of the optimal transport cost.
    log_dict : dict
        Only when ``return_matrix`` is True: the log returned by ``ot.emd2``.
    """
    p = _to_dense(p)
    q = _to_dense(q)
    if eigenvals is not None:
        p = p.dot(eigenvals)
        q = q.dot(eigenvals)
    p_weights = _normalized_weights(weights1, len(p))
    q_weights = _normalized_weights(weights2, len(q))

    # ot.emd2 documents its cost matrix as a C-order float64 array.
    pairwise_dist = np.ascontiguousarray(
        pairwise_distances(p, Y=q, metric=metric, n_jobs=-1)
    )

    result = pot.emd2(
        p_weights,
        q_weights,
        pairwise_dist,
        numItermax=int(1e7),  # the iteration cap is an integer count
        return_matrix=return_matrix,
    )
    if return_matrix:
        square_emd, log_dict = result
        return np.sqrt(square_emd), log_dict
    return np.sqrt(result)


def interpolate_with_ot(p0, p1, tmap, interp_frac, size):
    """Sample a population interpolated between *p0* and *p1* along *tmap*.

    Pairs ``(i, j)`` are drawn with probability proportional to
    ``tmap[i, j] / tmap[:, j].sum() ** (1 - interp_frac)``. Each sample is
    ``(1 - interp_frac) * p0[i] + interp_frac * p1[j]``.

    Parameters
    ----------
    p0 : 2-D array or scipy sparse matrix
        The genes of each cell in the source population, shape (I, d).
    p1 : 2-D array or scipy sparse matrix
        The genes of each cell in the destination population, shape (J, d).
    tmap : 2-D array
        A transport map from p0 to p1, shape (I, J).
    interp_frac : float
        The fraction at which to interpolate (0 gives p0, 1 gives p1).
    size : int
        The number of cells in the interpolated population.

    Returns
    -------
    p05 : 2-D array
        An interpolated population of ``size`` cells, shape (size, d).

    Raises
    ------
    ValueError
        On mismatched shapes, or if *tmap* carries no mass.
    """
    p0, p1, tmap = _validate_interpolation_inputs(p0, p1, tmap)
    n_dst = len(p1)
    # Assume growth is exponential and retrieve the growth rate at interp_frac.
    # If all column sums are equal this is a no-op; it only matters when the
    # column sums differ. Columns without mass get probability 0 (not NaN).
    p = _safe_divide(tmap, np.power(tmap.sum(axis=0), 1.0 - interp_frac))
    p = p.flatten(order="C")
    total = p.sum()
    if not total > 0:
        raise ValueError("Unable to interpolate. Transport map has no mass")
    p = p / total
    choices = np.random.choice(p.size, p=p, size=size)
    # Flat C-order index k corresponds to the pair (k // J, k % J).
    src, dst = np.divmod(choices, n_dst)
    return p0[src] * (1 - interp_frac) + p1[dst] * interp_frac


def interpolate_per_point_with_ot(p0, p1, tmap, interp_frac):
    """Move every source cell part-way toward a destination sampled from *tmap*.

    For each source cell ``i``, a destination ``j`` is drawn with probability
    proportional to ``tmap[i, j] / tmap[:, j].sum() ** (1 - interp_frac)``.
    This is the same growth correction as ``interpolate_with_ot``, normalized
    per row. Output row ``i`` is
    ``(1 - interp_frac) * p0[i] + interp_frac * p1[j]``.

    Parameters
    ----------
    p0 : 2-D array or scipy sparse matrix
        The genes of each cell in the source population, shape (I, d).
    p1 : 2-D array or scipy sparse matrix
        The genes of each cell in the destination population, shape (J, d).
    tmap : 2-D array
        A transport map from p0 to p1, shape (I, J).
    interp_frac : float
        The fraction at which to interpolate (0 gives p0, 1 gives p1).

    Returns
    -------
    p05 : 2-D array
        One interpolated cell per source cell, shape (I, d).

    Raises
    ------
    ValueError
        On mismatched shapes, or if some source row of *tmap* has no mass.
    """
    p0, p1, tmap = _validate_interpolation_inputs(p0, p1, tmap)
    n_dst = len(p1)
    p = _safe_divide(tmap, np.power(tmap.sum(axis=0), 1.0 - interp_frac))
    # Each row must be a probability distribution over destinations, because
    # row i is what np.random.choice samples from for source cell i.
    row_mass = p.sum(axis=1, keepdims=True)
    if not np.all(row_mass > 0):
        raise ValueError(
            "Unable to interpolate. Every source point needs transport mass"
        )
    p = p / row_mass
    choices = np.array(
        [np.random.choice(n_dst, p=row) for row in p], dtype=np.intp
    )
    return p0 * (1 - interp_frac) + p1[choices] * interp_frac


# --------------------------------------------------------------------------------
# /TrajectoryNet/optimal_transport/gcs.npy:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/KrishnaswamyLab/TrajectoryNet/810c89b081f95405bc0ad42bfb3cb41038e13700/TrajectoryNet/optimal_transport/gcs.npy
# --------------------------------------------------------------------------------
# /TrajectoryNet/optimal_transport/growth/traj.mp4:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/KrishnaswamyLab/TrajectoryNet/810c89b081f95405bc0ad42bfb3cb41038e13700/TrajectoryNet/optimal_transport/growth/traj.mp4
# -----------------------------------------------------------------------------
# /TrajectoryNet/optimal_transport/model:
# -----------------------------------------------------------------------------
# https://raw.githubusercontent.com/KrishnaswamyLab/TrajectoryNet/810c89b081f95405bc0ad42bfb3cb41038e13700/TrajectoryNet/optimal_transport/model
# -----------------------------------------------------------------------------
# /TrajectoryNet/optimal_transport/plot_UOT_1D.py:
# -----------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
===============================
1D Unbalanced optimal transport
===============================

This example illustrates the computation of Unbalanced Optimal transport
using a Kullback-Leibler relaxation.

In this copy the data are two hand-written bins, and the unbalanced plan is
computed with the three-parameter ``sinkhorn_knopp_unbalanced`` from this
directory (and, further down, with POT's ``ot.unbalanced.sinkhorn_unbalanced``).
"""

# Author: Hicham Janati
#
# License: MIT License

import numpy as np
import matplotlib.pylab as pl
import ot
import ot.plot

##############################################################################
# Generate data
# -------------


# %% parameters

n = 2  # nb bins

# bin positions
x = np.arange(n, dtype=np.float64)

# Two-bin source and target histograms.
# (Gaussian histograms used to be built here and immediately overwritten; with
# only 2 bins a Gaussian centred at 60 underflows to all zeros, so normalising
# it divided 0 by 0 and emitted a RuntimeWarning for nothing.)
a = [0.1, 0.9]
b = [0.9, 0.1]

# make distributions unbalanced (a and b are lists, so scale element-wise)
# b = [5.0 * v for v in b]
# loss matrix
M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))
M /= M.max()


##############################################################################
# Plot distributions and loss matrix
# ----------------------------------

# %% plot the distributions

pl.figure(1, figsize=(6.4, 3))
pl.plot(x, a, "b", label="Source distribution")
pl.plot(x, b, "r", label="Target distribution")
pl.legend()

# plot distributions and loss matrix

pl.figure(2, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, M, "Cost matrix M")


##############################################################################
# Solve Unbalanced Sinkhorn
# -------------------------


def get_transform_matrix(gamma, a, epsilon=1e-8):
    """Return matrix T such that T @ a = b.

    gamma : transport plan with gamma @ 1 = a and gamma^T @ 1 = b
    a : source marginal; epsilon guards the division against zero entries.
    """
    return (np.diag(1.0 / (a + epsilon)) @ gamma).T


def get_growth_coeffs(gamma, a, epsilon=1e-8, normalize=False):
    """Return one growth coefficient per source point.

    The coefficient of source point i is (sum_j gamma[i, j]) / (a[i] + epsilon).
    With ``normalize=True`` the coefficients are rescaled to have mean 1.
    """
    T = get_transform_matrix(gamma, a, epsilon)
    unnormalized_coeffs = np.sum(T, axis=0)
    if not normalize:
        return unnormalized_coeffs
    return unnormalized_coeffs / np.sum(unnormalized_coeffs) * len(unnormalized_coeffs)


# Sinkhorn

# The solver lives next to this script (same convention as train_growth.py in
# this directory); previously it was used without ever being imported.
from sinkhorn_knopp_unbalanced import sinkhorn_knopp_unbalanced  # noqa: E402

epsilon = 0.1  # entropy parameter
alpha = 1  # Unbalanced KL relaxation parameter
beta = 10000
# Gs = ot.emd(a, b, M)
Gs = sinkhorn_knopp_unbalanced(a, b, M, epsilon, alpha, beta, verbose=True)
print(Gs)
print(a, b)
print(get_growth_coeffs(Gs, np.array(a)))
print(get_transform_matrix(Gs, np.array(a)) @ a)
print(get_growth_coeffs(Gs, np.array(a)) * a)
# Stop here; `raise SystemExit` does what the site-provided exit() did without
# depending on the site module. Everything below is unreachable debug output.
raise SystemExit
print(Gs)
print(Gs @ np.ones_like(a))
print("bbbbbbbbbb")
tt = get_transform_matrix(Gs, np.array(a))
print(tt)
print("col_sum(tt)", np.sum(tt, axis=0))
print("tt @ a", tt @ a)
print("bbbbbbbbbb")
# NOTE(review): the script exits earlier, right after printing the growth
# coefficients, so the statements below never run.
print(np.sum(Gs, axis=0), np.sum(Gs, axis=1))
print("aaaaaaa == 1")
alpha = 1  # Unbalanced KL relaxation parameter
Gs = ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, verbose=True)
print(Gs)

pl.figure(4, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, Gs, "UOT matrix Sinkhorn")

# pl.show()
# -----------------------------------------------------------------------------
# /TrajectoryNet/optimal_transport/sinkhorn_knopp_unbalanced.py:
# -----------------------------------------------------------------------------
""" Implements unbalanced sinkhorn knopp optimization for unbalanced ot.

This is from the package python optimal transport but modified to take
three regularization parameters instead of two. This is necessary to find
growth rates of the source distribution that best match the target distribution
or vice versa. By setting reg_m_1 to something low and reg_m_2 to something
large we can compute an unbalanced optimal transport where all the scaling is
done on the source distribution and none is done on the target distribution.
"""
import numpy as np
import warnings


def sinkhorn_knopp_unbalanced(
    a,
    b,
    M,
    reg,
    reg_m_1,
    reg_m_2,
    numItermax=1000,
    stopThr=1e-6,
    verbose=False,
    log=False,
    **kwargs
):
    r"""
    Solve the entropic regularization unbalanced optimal transport problem

    The function solves the following optimization problem:

    .. math::
        W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) +
        reg_{m_1} KL(\gamma 1, a) + reg_{m_2} KL(\gamma^T 1, b)

        s.t.
             \gamma\geq 0
    where :

    - M is the (dim_a, dim_b) metric cost matrix
    - :math:`\Omega` is the entropic regularization term
      :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
    - a and b are source and target unbalanced distributions
    - KL is the Kullback-Leibler divergence

    The algorithm used for solving the problem is the generalized
    Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 25]_


    Parameters
    ----------
    a : np.ndarray (dim_a,)
        Unnormalized histogram of dimension dim_a (uniform if empty)
    b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists)
        One or multiple unnormalized histograms of dimension dim_b
        (uniform if empty). If many, compute all the OT distances (a, b_i)
    M : np.ndarray (dim_a, dim_b)
        loss matrix
    reg : float
        Entropy regularization term > 0
    reg_m_1 : float
        Marginal relaxation term > 0 on the source marginal (gamma 1 vs a)
    reg_m_2 : float
        Marginal relaxation term > 0 on the target marginal (gamma^T 1 vs b)
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (> 0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    **kwargs
        Accepted for API compatibility and ignored.


    Returns
    -------
    If b is one-dimensional:
        gamma : (dim_a x dim_b) ndarray
            Optimal transportation matrix for the given parameters
        log : dict
            log dictionary returned only if `log` is `True`
    If b has shape (dim_b, n_hists):
        ot_distance : (n_hists,) ndarray
            the OT distance between `a` and each of the histograms `b_i`
        log : dict
            log dictionary returned only if `log` is `True`
    Examples
    --------

    >>> a=[.5, .5]
    >>> b=[.5, .5]
    >>> M=[[0., 1.],[1., 0.]]
    >>> sinkhorn_knopp_unbalanced(a, b, M, 1., 1., 1.)
    array([[0.51122823, 0.18807035],
           [0.18807035, 0.51122823]])

    References
    ----------

    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
        Scaling algorithms for unbalanced transport problems. arXiv preprint
        arXiv:1607.05816.

    .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. :
        Learning with a Wasserstein Loss,  Advances in Neural Information
        Processing Systems (NIPS) 2015

    See Also
    --------
    ot.lp.emd : Unregularized OT
    ot.optim.cg : General regularized OT

    """

    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    M = np.asarray(M, dtype=np.float64)

    dim_a, dim_b = M.shape

    if len(a) == 0:
        a = np.ones(dim_a, dtype=np.float64) / dim_a
    if len(b) == 0:
        b = np.ones(dim_b, dtype=np.float64) / dim_b

    if len(b.shape) > 1:
        n_hists = b.shape[1]
    else:
        n_hists = 0

    if log:
        log = {"err": []}

    # we assume that no distances are null except those of the diagonal of
    # distances
    if n_hists:
        u = np.ones((dim_a, 1)) / dim_a
        v = np.ones((dim_b, n_hists)) / dim_b
        a = a.reshape(dim_a, 1)
    else:
        u = np.ones(dim_a) / dim_a
        v = np.ones(dim_b) / dim_b

    # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
    K = np.empty(M.shape, dtype=M.dtype)
    np.divide(M, -reg, out=K)
    np.exp(K, out=K)

    cpt = 0
    err = 1.0

    while err > stopThr and cpt < numItermax:
        uprev = u
        vprev = v

        # Separate exponents let the source and target marginals be relaxed
        # by different amounts (reg_m_1 vs reg_m_2).
        Kv = K.dot(v)
        u = (a / Kv) ** (reg_m_1 / (reg_m_1 + reg))
        Ktu = K.T.dot(u)
        v = (b / Ktu) ** (reg_m_2 / (reg_m_2 + reg))

        if (
            np.any(Ktu == 0.0)
            or np.any(np.isnan(u))
            or np.any(np.isnan(v))
            or np.any(np.isinf(u))
            or np.any(np.isinf(v))
        ):
            # we have reached the machine precision
            # come back to previous solution and quit loop
            warnings.warn("Numerical errors at iteration %s" % cpt)
            u = uprev
            v = vprev
            break
        if cpt % 10 == 0:
            # we can speed up the process by checking for the error only all
            # the 10th iterations
            err_u = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(), 1.0)
            err_v = abs(v - vprev).max() / max(abs(v).max(), abs(vprev).max(), 1.0)
            err = 0.5 * (err_u + err_v)
            if log:
                log["err"].append(err)
            if verbose:
                if cpt % 200 == 0:
                    print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19)
                print("{:5d}|{:8e}|".format(cpt, err))
        cpt += 1

    if log:
        log["logu"] = np.log(u + 1e-16)
        log["logv"] = np.log(v + 1e-16)

    if n_hists:  # return only loss
        res = np.einsum("ik,ij,jk,ij->k", u, K, v, M)
        if log:
            return res, log
        else:
            return res

    else:  # return OT matrix

        if log:
            return u[:, None] * K * v[None, :], log
        else:
            return u[:, None] * K * v[None, :]
# -----------------------------------------------------------------------------
# /TrajectoryNet/optimal_transport/train_growth.py:
# -----------------------------------------------------------------------------
import matplotlib.pyplot as plt
import numpy as np
from sklearn.preprocessing import StandardScaler
from scipy.spatial.distance import cdist
import scprep
import torch

# import atongtf.dataset as atd

from sinkhorn_knopp_unbalanced import sinkhorn_knopp_unbalanced


def load_data_full():
    """Return the standardised EB embedding, its sample labels and the scaler.

    NOTE(review): ``atd`` is undefined because the ``atongtf`` import above is
    commented out, so this call raises NameError until that import is restored.
    """
    data = atd.EB_Velocity_Dataset()
    labels = data.data["sample_labels"]
    scaler = StandardScaler()
    scaler.fit(data.emb)
    transformed = scaler.transform(data.emb)
    return transformed, labels, scaler


def get_transform_matrix(gamma, a, epsilon=1e-8):
    """Return matrix such that T @ a = b
    gamma : gamma @ 1 = a; gamma^T @ 1 = b
    """
    return (np.diag(1.0 / (a + epsilon)) @ gamma).T


def get_growth_coeffs(gamma, a, epsilon=1e-8, normalize=False):
    """Return (row sums of gamma) / a per source point, optionally mean-1 scaled."""
    T = get_transform_matrix(gamma, a, epsilon)
    unnormalized_coeffs = np.sum(T, axis=0)
    if not normalize:
        return unnormalized_coeffs
    return unnormalized_coeffs / np.sum(unnormalized_coeffs) * len(unnormalized_coeffs)


data, labels, _ = load_data_full()

print(data.shape, labels.shape)
# NOTE(review): the script stops here; everything below is currently unreachable.
exit()

# Compute couplings

timepoints = np.unique(labels)

dfs = [data[labels == tp] for tp in timepoints]


def get_all_growth_coeffs(alpha):
    """Growth coefficients for each consecutive pair of timepoints.

    ``alpha`` relaxes the source marginal; the target marginal is kept almost
    tight (reg_m_2 = 10000), so all mass change is attributed to the source.
    """
    gcs = []
    for i in range(len(dfs) - 1):
        a, b = dfs[i], dfs[i + 1]
        m, n = a.shape[0], b.shape[0]
        M = cdist(a, b)
        entropy_reg = 0.1
        reg_1, reg_2 = alpha, 10000
        gamma = sinkhorn_knopp_unbalanced(
            np.ones(m) / m, np.ones(n) / n, M, entropy_reg, reg_1, reg_2
        )
        gc = get_growth_coeffs(gamma, np.ones(m) / m)
        gcs.append(gc)
    return gcs


gcs = np.load("gcs.npy")
print(gcs)


class GrowthNet(torch.nn.Module):
    """MLP mapping a 3-feature input (here: 2-D embedding + time) to a growth rate."""

    def __init__(self):
        super().__init__()

        self.fc1 = torch.nn.Linear(3, 64)
        self.fc2 = torch.nn.Linear(64, 64)
        self.fc3 = torch.nn.Linear(64, 1)

    def forward(self, x):
        x = torch.nn.functional.leaky_relu(self.fc1(x))
        x = torch.nn.functional.leaky_relu(self.fc2(x))
        x = self.fc3(x)
        return x


X = np.concatenate([data, labels[:, None]], axis=1)[labels != timepoints[-1]]
Y = gcs[:, None]

# NOTE(review): hard-codes the second GPU ("cuda:1"); fails on single-GPU hosts.
device = torch.device("cuda:" + str(1) if torch.cuda.is_available() else "cpu")

model = GrowthNet().to(device)
model.train()
optimizer = torch.optim.Adam(model.parameters())

"""
for it in range(100000):
    optimizer.zero_grad()
    batch_idx = np.random.randint(len(X), size=256)
    x = torch.from_numpy(X[batch_idx,:]).type(torch.float32).to(device)
    y = torch.from_numpy(Y[batch_idx,:]).type(torch.float32).to(device)
    negative_samples = np.concatenate([np.random.uniform(size=(256,2)) * 8 - 4,
                                       np.random.choice(timepoints, size=(256,1))], axis=1)
    negative_samples = torch.from_numpy(negative_samples).type(torch.float32).to(device)
    x = torch.cat([x, negative_samples])
    y = torch.cat([y, torch.ones_like(y)])
    pred = model(x)
    loss = torch.nn.MSELoss()
    output = loss(pred, y)
    output.backward()
    optimizer.step()
    if it % 100 == 0:
        print(it, output)

torch.save(model, 'model')

exit()
"""

# NOTE(review): torch.load unpickles a whole nn.Module; only load trusted files.
# Recent PyTorch releases default to weights_only=True, which rejects full-module
# checkpoints like this one -- TODO confirm against the installed version.
model = torch.load("model")
model.eval()
import matplotlib

for i, tp in enumerate(np.linspace(0, 3, 100)):
    fig, axes = plt.subplots(
        2, 1, gridspec_kw={"height_ratios": [7, 1]}, figsize=(8, 8)
    )
    ax = axes[0]
    npts = 200
    side = np.linspace(-4, 4, npts)
    xx, yy = np.meshgrid(side, side)
    xx = torch.from_numpy(xx).type(torch.float32).to(device)
    yy = torch.from_numpy(yy).type(torch.float32).to(device)
    z_grid = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1)], 1)
    data_in = torch.cat([z_grid, tp * torch.ones(z_grid.shape[0], 1).to(device)], 1)
    gr = model(data_in)
    gr = gr.reshape(npts, npts).to("cpu").detach().numpy()
    ax.pcolormesh(
        xx.cpu().detach(), yy.cpu().detach(), gr, cmap="RdBu_r", vmin=0, vmax=2
    )
    scprep.plot.scatter2d(data, c="Gray", ax=ax, alpha=0.1)
    ax.get_xaxis().set_ticks([])
    ax.get_yaxis().set_ticks([])
    ax.set_title("Growth Rate")
    # ax.set_title("%0.2f" % tp, fontsize=32)
    ax = axes[1]
    # Colorbar
    cb = matplotlib.colorbar.ColorbarBase(ax, cmap="Spectral", orientation="horizontal")
    cb.set_ticks(np.linspace(0, 1, 4))
    cb.set_ticklabels(np.arange(4))
    ax.axvline(tp / 3, c="k", linewidth=15)
    ax.set_title("Time")
    plt.savefig("growth/viz-%05d.jpg" % i)
    plt.close()


def trajectory_to_video(savedir):
    """Encode ``savedir/viz-%05d.jpg`` frames into ``savedir/traj.mp4`` via ffmpeg.

    The command is passed as an argument list (no whitespace splitting), so
    directories whose path contains spaces work.
    """
    import subprocess
    import os

    command = [
        "ffmpeg",
        "-y",
        "-i",
        os.path.join(savedir, "viz-%05d.jpg"),
        os.path.join(savedir, "traj.mp4"),
    ]
    process = subprocess.Popen(command, stdout=subprocess.PIPE)
    process.communicate()


trajectory_to_video("growth")
"""
fig, axes = plt.subplots(2,2, figsize=(10,10))
axes = axes.flatten()
for i, tp in enumerate(timepoints[:-1]):
    ax = axes[i]
    # Construct Grid,
    npts = 200
    side = np.linspace(-4, 4, npts)
    xx, yy = np.meshgrid(side, side)
    xx = torch.from_numpy(xx).type(torch.float32).to(device)
    yy = torch.from_numpy(yy).type(torch.float32).to(device)
    z_grid = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1)], 1)
    data_in = torch.cat([z_grid, tp * torch.ones(z_grid.shape[0],1).to(device)], 1)
    gr = model(data_in)
    gr = gr.reshape(npts, npts).to('cpu').detach().numpy()
    ax.pcolormesh(xx.cpu().detach(), yy.cpu().detach(), gr, cmap='RdBu_r')
    scprep.plot.scatter2d(data, c='Gray', ax=ax, alpha=0.1)
    ax.get_xaxis().set_ticks([])
    ax.get_yaxis().set_ticks([])
    ax.set_title("%0.2f" % tp, fontsize=32)

plt.savefig('growth_function_snapshots.png')
plt.close()
"""
"""
gcs = get_all_growth_coeffs(2)
gcs = np.concatenate(gcs)
print(gcs.shape)
np.save('gcs.npy', gcs)
"""

"""
for alpha in [0.1, 0.2, 0.5, 1, 2, 5, 10, 20, 50, 100]:
    fig, axes = plt.subplots(2,2, figsize=(10,10))
    axes = axes.flatten()
    for i in range(len(dfs) - 1):
        a, b = dfs[i], dfs[i+1]
        m, n = a.shape[0], b.shape[0]
        M = cdist(a, b)
        print(a.shape, b.shape, M.shape)
        entropy_reg = 0.1
        reg_1, reg_2 = alpha, 10000
        gamma = sinkhorn_knopp_unbalanced(np.ones(m) / m, np.ones(n) / n,
                                          M, entropy_reg, reg_1, reg_2)
        gc = get_growth_coeffs(gamma, np.ones(m) / m)
        scprep.plot.scatter2d(data, c='Gray', ax=axes[i], alpha=0.1)
        scprep.plot.scatter2d(a, c=gc, cmap='RdBu_r', ax=axes[i], vmin=0, vmax=2)
        axes[i].set_title(i)
    plt.savefig('figs/alpha_%0.2f.png' % alpha)
    plt.close()
"""
# -----------------------------------------------------------------------------
# /TrajectoryNet/parse.py:
# -----------------------------------------------------------------------------
import argparse
from .lib.layers import odefunc

SOLVERS = ["dopri5", "bdf", "rk4", "midpoint", "adams", "explicit_adams", "fixed_adams"]


def _str2bool(value):
    """Parse a command-line boolean such as ``True``/``False``/``1``/``no``.

    Used instead of ``type=eval`` so text passed on the command line is never
    executed as Python code.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean, got %r" % value)


parser = argparse.ArgumentParser("Continuous Normalizing Flow")
parser.add_argument("--test", action="store_true")
parser.add_argument("--dataset", type=str, default="EB")
parser.add_argument("--use_growth", action="store_true")
parser.add_argument("--use_density", action="store_true")
parser.add_argument("--leaveout_timepoint", type=int, default=-1)
parser.add_argument(
    "--layer_type",
    type=str,
    default="concatsquash",
    choices=[
        "ignore",
        "concat",
        "concat_v2",
        "squash",
        "concatsquash",
        "concatcoord",
        "hyper",
        "blend",
    ],
)
parser.add_argument("--max_dim", type=int, default=10)
parser.add_argument("--dims", type=str, default="64-64-64")
parser.add_argument("--num_blocks", type=int, default=1, help="Number of stacked CNFs.")
parser.add_argument("--time_scale", type=float, default=0.5)
parser.add_argument("--train_T", type=_str2bool, default=True)
parser.add_argument(
    "--divergence_fn",
    type=str,
    default="brute_force",
    choices=["brute_force", "approximate"],
)
parser.add_argument(
    "--nonlinearity", type=str, default="tanh", choices=odefunc.NONLINEARITIES
)
parser.add_argument("--stochastic", action="store_true")

parser.add_argument(
    "--alpha", type=float, default=0.0, help="loss weight parameter for growth model"
)
parser.add_argument("--solver", type=str, default="dopri5", choices=SOLVERS)
parser.add_argument("--atol", type=float, default=1e-5)
parser.add_argument("--rtol", type=float, default=1e-5)
parser.add_argument(
    "--step_size", type=float, default=None, help="Optional fixed step size."
)

parser.add_argument("--test_solver", type=str, default=None, choices=SOLVERS + [None])
parser.add_argument("--test_atol", type=float, default=None)
parser.add_argument("--test_rtol", type=float, default=None)

parser.add_argument("--residual", action="store_true")
parser.add_argument("--rademacher", action="store_true")
parser.add_argument("--spectral_norm", action="store_true")
parser.add_argument("--batch_norm", action="store_true")
parser.add_argument("--bn_lag", type=float, default=0)

parser.add_argument("--niters", type=int, default=10000)
parser.add_argument("--num_workers", type=int, default=8)
parser.add_argument("--batch_size", type=int, default=1000)
parser.add_argument("--test_batch_size", type=int, default=1000)
parser.add_argument("--viz_batch_size", type=int, default=2000)
parser.add_argument("--lr", type=float, default=1e-3)
parser.add_argument("--weight_decay", type=float, default=1e-5)

# Track quantities
parser.add_argument("--l1int", type=float, default=None, help="int_t ||f||_1")
parser.add_argument("--l2int", type=float, default=None, help="int_t ||f||_2")
parser.add_argument("--sl2int", type=float, default=None, help="int_t ||f||_2^2")
parser.add_argument(
    "--dl2int", type=float, default=None, help="int_t ||f^T df/dt||_2"
)  # f df/dx?
parser.add_argument(
    "--dtl2int", type=float, default=None, help="int_t ||f^T df/dx + df/dt||_2"
)
parser.add_argument("--JFrobint", type=float, default=None, help="int_t ||df/dx||_F")
parser.add_argument(
    "--JdiagFrobint", type=float, default=None, help="int_t ||df_i/dx_i||_F"
)
parser.add_argument(
    "--JoffdiagFrobint", type=float, default=None, help="int_t ||df/dx - df_i/dx_i||_F"
)
parser.add_argument("--vecint", type=float, default=None, help="regularize direction")
parser.add_argument(
    "--use_magnitude",
    action="store_true",
    help="regularize direction using MSE loss instead of cosine loss",
)

parser.add_argument(
    "--interp_reg", type=float, default=None, help="regularize interpolation"
)

parser.add_argument("--save", type=str, default="../results/tmp")
parser.add_argument("--save_freq", type=int, default=1000)
parser.add_argument("--viz_freq", type=int, default=100)
parser.add_argument("--viz_freq_growth", type=int, default=100)
parser.add_argument("--val_freq", type=int, default=100)
parser.add_argument("--log_freq", type=int, default=10)
parser.add_argument("--gpu", type=int, default=0)
parser.add_argument("--use_cpu", action="store_true")
# store_false: args.no_display_loss is True unless this flag is passed.
parser.add_argument("--no_display_loss", action="store_false")
parser.add_argument(
    "--top_k_reg", type=float, default=0.0, help="density following regularization"
)
parser.add_argument("--training_noise", type=float, default=0.1)
parser.add_argument(
    "--embedding_name",
    type=str,
    default="pca",
    help="choose embedding name to perform TrajectoryNet on",
)
parser.add_argument(
    "--whiten", action="store_true", help="Whiten data before running TrajectoryNet"
)
# store_false: args.save_movie defaults to True; passing the flag turns it off.
# The help text used to say the flag *constructs* the movie, which is inverted.
parser.add_argument(
    "--save_movie",
    action="store_false",
    help="Do NOT construct the trajectory movie (it is built by default and "
    "requires ffmpeg to be installed)",
)
# -----------------------------------------------------------------------------
# /TrajectoryNet/train_growth.py:
# -----------------------------------------------------------------------------
import matplotlib.pyplot as plt
import numpy as np
from sklearn.preprocessing import StandardScaler
from scipy.spatial.distance import cdist
import scprep
import torch
import time

from TrajectoryNet import dataset
from .optimal_transport.sinkhorn_knopp_unbalanced import sinkhorn_knopp_unbalanced


# NOTE(review): this module loads data at import time and (at the bottom) runs
# ffmpeg; importing it is not side-effect free.
eb_data = dataset.EBData("pcs", max_dim=5)


def get_transform_matrix(gamma, a, epsilon=1e-8):
    """Return matrix such that T @ a = b
    gamma : gamma @ 1 = a; gamma^T @ 1 = b
    """
    return (np.diag(1.0 / (a + epsilon)) @ gamma).T


def get_growth_coeffs(gamma, a, epsilon=1e-8, normalize=False):
    """Return (row sums of gamma) / a per source point, optionally mean-1 scaled."""
    T = get_transform_matrix(gamma, a, epsilon)
    unnormalized_coeffs = np.sum(T, axis=0)
    if not normalize:
        return unnormalized_coeffs
    return unnormalized_coeffs / np.sum(unnormalized_coeffs) * len(unnormalized_coeffs)


data, labels = eb_data.data, eb_data.get_times()

# Compute couplings

timepoints = np.unique(labels)
print("timepoints", timepoints)

dfs = [data[labels == tp] for tp in timepoints]
pairs = [(0, 1), (1, 2), (2, 3), (3, 4), (0, 2), (1, 3), (2, 4)]


def get_all_growth_coeffs(alpha):
    """Growth coefficients of the source population for every pair in ``pairs``.

    ``alpha`` relaxes the source marginal; the target marginal is kept almost
    tight (reg_m_2 = 10000). Returns a list aligned with ``pairs``.
    """
    gcs = []
    for a_ind, b_ind in pairs:
        start = time.time()
        print(a_ind, b_ind)
        a, b = dfs[a_ind], dfs[b_ind]
        m, n = a.shape[0], b.shape[0]
        M = cdist(a, b)
        entropy_reg = 0.1
        reg_1, reg_2 = alpha, 10000
        gamma = sinkhorn_knopp_unbalanced(
            np.ones(m) / m, np.ones(n) / n, M, entropy_reg, reg_1, reg_2
        )
        gc = get_growth_coeffs(gamma, np.ones(m) / m)
        gcs.append(gc)
        end = time.time()
        print("%s to %s took %0.2f sec" % (a_ind, b_ind, end - start))
    print(gcs)
    return gcs


# NOTE(review): allow_pickle=True unpickles object arrays; only load trusted files.
gcs = np.load("../data/growth/gcs.npy", allow_pickle=True)


class GrowthNet(torch.nn.Module):
    """MLP mapping a 6-feature input to a scalar growth rate.

    Presumably 5 embedding dims (max_dim=5) plus time -- TODO confirm.
    """

    def __init__(self):
        super().__init__()

        self.fc1 = torch.nn.Linear(6, 64)
        self.fc2 = torch.nn.Linear(64, 64)
        self.fc3 = torch.nn.Linear(64, 1)

    def forward(self, x):
        x = torch.nn.functional.leaky_relu(self.fc1(x))
        x = torch.nn.functional.leaky_relu(self.fc2(x))
        x = self.fc3(x)
        return x


device = torch.device("cuda:" + str(0) if torch.cuda.is_available() else "cpu")
# The CPU override below always wins; the line above is effectively dead.
device = torch.device("cpu")


def train(leaveout_tp):
    """Fit a GrowthNet to the precomputed growth coefficients and save it.

    Parameters
    ----------
    leaveout_tp : int
        Timepoint excluded from training (1, 2 or 3), or -1 to keep all.
        The model is saved to ``model_<leaveout_tp>``.

    Raises
    ------
    RuntimeError
        If ``leaveout_tp`` is not one of the supported values.
    """
    # Data, timepoint
    X = np.concatenate([data, labels[:, None]], axis=1)[
        (labels != timepoints[-1]) & (labels != leaveout_tp)
    ]
    # gcs[k] presumably holds the coefficients of pairs[k] -- TODO confirm.
    # NOTE(review): for leaveout_tp 2 and 3 the concatenation order is not
    # chronological (e.g. 1->3 before 0->1), unlike leaveout_tp 1; if rows of
    # `data` are grouped by time, Y is misaligned with X. Verify.
    if leaveout_tp == 1:
        Y = np.concatenate([gcs[4], gcs[2], gcs[3]])
    elif leaveout_tp == 2:
        Y = np.concatenate([gcs[5], gcs[0], gcs[3]])
    elif leaveout_tp == 3:
        Y = np.concatenate([gcs[6], gcs[0], gcs[1]])
    elif leaveout_tp == -1:
        Y = np.concatenate(gcs[:4])
    else:
        raise RuntimeError("Unknown leaveout_tp %d" % leaveout_tp)
    print(X.shape, Y.shape)
    assert X.shape[0] == Y.shape[0]
    Y = Y[:, np.newaxis]

    model = GrowthNet().to(device)
    model.train()
    optimizer = torch.optim.Adam(model.parameters())

    for it in range(100000):
        optimizer.zero_grad()
        batch_idx = np.random.randint(len(X), size=256)
        x = torch.from_numpy(X[batch_idx, :]).type(torch.float32).to(device)
        y = torch.from_numpy(Y[batch_idx, :]).type(torch.float32).to(device)
        # Uniform random points in [-4, 4) at random timepoints, regressed to a
        # growth rate of 1 (no growth) away from the data.
        negative_samples = np.concatenate(
            [
                np.random.uniform(size=(256, X.shape[1] - 1)) * 8 - 4,
                np.random.choice(timepoints, size=(256, 1)),
            ],
            axis=1,
        )
        negative_samples = (
            torch.from_numpy(negative_samples).type(torch.float32).to(device)
        )
        x = torch.cat([x, negative_samples])
        y = torch.cat([y, torch.ones_like(y)])
        pred = model(x)
        loss = torch.nn.MSELoss()
        output = loss(pred, y)
        output.backward()
        optimizer.step()
        if it % 100 == 0:
            print(it, output)

    torch.save(model, ("model_%d" % leaveout_tp))


# train(1)
# train(2)
# train(3)


def trajectory_to_video(savedir):
    """Encode ``savedir/viz-%05d.jpg`` frames into ``savedir/traj.mp4`` via ffmpeg.

    The command is passed as an argument list (no whitespace splitting), so
    directories whose path contains spaces work.
    """
    import subprocess
    import os

    command = [
        "ffmpeg",
        "-y",
        "-i",
        os.path.join(savedir, "viz-%05d.jpg"),
        os.path.join(savedir, "traj.mp4"),
    ]
    process = subprocess.Popen(command, stdout=subprocess.PIPE)
    process.communicate()


trajectory_to_video("../data/growth/viz/")
# -----------------------------------------------------------------------------
# /TrajectoryNet/train_misc.py:
# -----------------------------------------------------------------------------
import six
import math

from .lib.layers.wrappers import cnf_regularization as reg_lib
from .lib import spectral_norm, layers
from .lib.layers.odefunc import divergence_bf, divergence_approx


def standard_normal_logprob(z):
    """Elementwise log-density of a standard normal evaluated at tensor ``z``."""
    logZ = -0.5 * math.log(2 * math.pi)
    return logZ - z.pow(2) / 2


def set_cnf_options(args, model):
    """Copy solver / tolerance / ODE-function options from ``args`` onto ``model``."""

    def _set(module):
        if isinstance(module, layers.CNF):
            # Set training settings
            module.solver = args.solver
            module.atol = args.atol
            module.rtol = args.rtol
            if args.step_size is not None:
                module.solver_options["step_size"] = args.step_size

            # If using fixed-grid adams, restrict order to not be too high.
            if args.solver in ["fixed_adams", "explicit_adams"]:
                module.solver_options["max_order"] = 4

            # Set the test settings
            module.test_solver = args.test_solver if args.test_solver else args.solver
            module.test_atol = args.test_atol if args.test_atol else args.atol
            module.test_rtol = args.test_rtol if args.test_rtol else args.rtol

        if isinstance(module, layers.ODEfunc):
            module.rademacher = args.rademacher
            module.residual = args.residual

    model.apply(_set)


def override_divergence_fn(model, divergence_fn):
    """Set the divergence function of every ODEfunc; unknown names are ignored."""

    def _set(module):
        if isinstance(module, layers.ODEfunc):
            if divergence_fn == "brute_force":
                module.divergence_fn = divergence_bf
            elif divergence_fn == "approximate":
                module.divergence_fn = divergence_approx

    model.apply(_set)


def count_nfe(model):
    """Sum ``num_evals()`` over all CNF submodules of ``model``."""

    class AccNumEvals(object):
        def __init__(self):
            self.num_evals = 0

        def __call__(self, module):
            if isinstance(module, layers.CNF):
                self.num_evals += module.num_evals()

    accumulator = AccNumEvals()
    model.apply(accumulator)
    return accumulator.num_evals


def count_parameters(model):
    """Number of trainable parameters in ``model``."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def count_total_time(model):
    """Sum of ``sqrt_end_time ** 2`` over all CNF submodules of ``model``."""

    class Accumulator(object):
        def __init__(self):
            self.total_time = 0

        def __call__(self, module):
            if isinstance(module, layers.CNF):
                self.total_time = (
                    self.total_time + module.sqrt_end_time * module.sqrt_end_time
                )

    accumulator = Accumulator()
    model.apply(accumulator)
    return accumulator.total_time


def add_spectral_norm(model, logger=None):
    """Applies spectral norm to all modules within the scope of a CNF."""

    def apply_spectral_norm(module):
        if "weight" in module._parameters:
            if logger:
                logger.info("Adding spectral norm to {}".format(module))
            spectral_norm.inplace_spectral_norm(module, "weight")

    def find_cnf(module):
        if isinstance(module, layers.CNF):
            module.apply(apply_spectral_norm)
        else:
            for child in module.children():
                find_cnf(child)

    find_cnf(model)


def spectral_norm_power_iteration(model, n_power_iterations=1):
    """Run the spectral-norm power-iteration hook on every module that has one."""

    def recursive_power_iteration(module):
        if hasattr(module, spectral_norm.POWER_ITERATION_FN):
            getattr(module, spectral_norm.POWER_ITERATION_FN)(n_power_iterations)

    model.apply(recursive_power_iteration)


REGULARIZATION_FNS = {
    "l1int": reg_lib.l1_regularzation_fn,
    "l2int": reg_lib.l2_regularzation_fn,
    "sl2int": reg_lib.squared_l2_regularization_fn,
    "dl2int": reg_lib.directional_l2_regularization_fn,
    "dtl2int": reg_lib.directional_l2_change_penalty_fn,
    "JFrobint": reg_lib.jacobian_frobenius_regularization_fn,
    "JdiagFrobint": reg_lib.jacobian_diag_frobenius_regularization_fn,
    "JoffdiagFrobint": reg_lib.jacobian_offdiag_frobenius_regularization_fn,
}

INV_REGULARIZATION_FNS = {v: k for k, v in six.iteritems(REGULARIZATION_FNS)}


def append_regularization_to_log(log_message, regularization_fns, reg_states):
    """Append `` | <name>: <value>`` for every active regularizer to the message."""
    for i, reg_fn in enumerate(regularization_fns):
        log_message = (
            log_message
            + " | "
            + INV_REGULARIZATION_FNS[reg_fn]
            + ": {:.8f}".format(reg_states[i].item())
        )
    return log_message


def create_regularization_fns(args):
    """Return (fns, coeffs) tuples for every regularizer whose arg is not None."""
    regularization_fns = []
    regularization_coeffs = []

    for arg_key, reg_fn in six.iteritems(REGULARIZATION_FNS):
        # getattr instead of eval("args." + arg_key): same value, no code eval.
        coeff = getattr(args, arg_key)
        if coeff is not None:
            regularization_fns.append(reg_fn)
            regularization_coeffs.append(coeff)

    regularization_fns = tuple(regularization_fns)
    regularization_coeffs = tuple(regularization_coeffs)
    return regularization_fns, regularization_coeffs


def get_regularization(model, regularization_coeffs):
    """Sum regularization states over all CNFs; None when no coefficients."""
    if len(regularization_coeffs) == 0:
        return None

    acc_reg_states = tuple([0.0] * len(regularization_coeffs))
    for module in model.modules():
        if isinstance(module, layers.CNF):
            acc_reg_states = tuple(
                acc + reg
                for acc, reg in zip(acc_reg_states, module.get_regularization_states())
            )
    return acc_reg_states


def build_model_tabular(args, dims, regularization_fns=None):
    """Build a SequentialFlow of ``args.num_blocks`` CNFs (optionally batch-normed)."""

    hidden_dims = tuple(map(int, args.dims.split("-")))

    def build_cnf():
        diffeq = layers.ODEnet(
            hidden_dims=hidden_dims,
            input_shape=(dims,),
            strides=None,
            conv=False,
            layer_type=args.layer_type,
            nonlinearity=args.nonlinearity,
        )
        odefunc = layers.ODEfunc(
            diffeq=diffeq,
            divergence_fn=args.divergence_fn,
            residual=args.residual,
            rademacher=args.rademacher,
        )
        cnf = layers.CNF(
            odefunc=odefunc,
            T=args.time_scale,
            train_T=args.train_T,
            regularization_fns=regularization_fns,
            solver=args.solver,
        )
        return cnf

    chain = [build_cnf() for _ in range(args.num_blocks)]
    if args.batch_norm:
        bn_layers = [
            layers.MovingBatchNorm1d(dims, bn_lag=args.bn_lag)
            for _ in range(args.num_blocks)
        ]
        bn_chain = [layers.MovingBatchNorm1d(dims, bn_lag=args.bn_lag)]
        for a, b in zip(chain, bn_layers):
            bn_chain.append(a)
            bn_chain.append(b)
        chain = bn_chain
    model = layers.SequentialFlow(chain)

    set_cnf_options(args, model)

    return model
# -----------------------------------------------------------------------------
# /TrajectoryNet/version.py:
# -----------------------------------------------------------------------------
-------------------------------------------------------------------------------- 1 | "0.2.4" 2 | -------------------------------------------------------------------------------- /data/eb_genes.txt: -------------------------------------------------------------------------------- 1 | TAL1 2 | CD34 3 | PECAM1 4 | HOXD1 5 | HOXB4 6 | HAND1 7 | GATA6 8 | GATA5 9 | TNNT2 10 | TBX18 11 | TBX15 12 | PDGFRA 13 | SIX2 14 | TBX5 15 | WT1 16 | MYC 17 | LEF1 18 | SOX10 19 | FOXD3 20 | PAX3 21 | SOX9 22 | HOXA2 23 | OLIG3 24 | ONECUT2 25 | KLF7 26 | ONECUT1 27 | MAP2 28 | ISL1 29 | DLX1 30 | HOXB1 31 | NR2F1 32 | LMX1A 33 | DMRT3 34 | OLIG1 35 | PAX6 36 | NPAS1 37 | SOX1 38 | NKX2-8 39 | EN2 40 | ZBTB16 41 | SOX17 42 | FOXA2 43 | EOMES 44 | T 45 | GATA4 46 | ASCL2 47 | CDX2 48 | ARID3A 49 | KLF5 50 | RFX6 51 | NKX2-1 52 | SOX15 53 | TP63 54 | GATA3 55 | SATB1 56 | CER1 57 | LHX5 58 | SIX3 59 | LHX2 60 | GLI3 61 | SIX6 62 | NES 63 | GBX2 64 | ZIC2 65 | NANOG 66 | MIXL1 67 | OTX2 68 | POU5F1 -------------------------------------------------------------------------------- /data/eb_velocity_v5.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KrishnaswamyLab/TrajectoryNet/810c89b081f95405bc0ad42bfb3cb41038e13700/data/eb_velocity_v5.npz -------------------------------------------------------------------------------- /figures/EB-Trajectory.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KrishnaswamyLab/TrajectoryNet/810c89b081f95405bc0ad42bfb3cb41038e13700/figures/EB-Trajectory.gif -------------------------------------------------------------------------------- /figures/eb_high_quality.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KrishnaswamyLab/TrajectoryNet/810c89b081f95405bc0ad42bfb3cb41038e13700/figures/eb_high_quality.png 
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | argparse 2 | matplotlib>=3.2.1 3 | numpy>=1.18.4 4 | POT>=0.7.0 5 | scanpy 6 | scikit-learn>=0.23.1 7 | scipy>=1.4.1 8 | torch>=1.5.0 9 | torchdiffeq==0.0.1 10 | -------------------------------------------------------------------------------- /results/fig8_results/backward_trajectories.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KrishnaswamyLab/TrajectoryNet/810c89b081f95405bc0ad42bfb3cb41038e13700/results/fig8_results/backward_trajectories.npy -------------------------------------------------------------------------------- /results/fig8_results/checkpt.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KrishnaswamyLab/TrajectoryNet/810c89b081f95405bc0ad42bfb3cb41038e13700/results/fig8_results/checkpt.pth -------------------------------------------------------------------------------- /results/fig8_results/train_eval.csv: -------------------------------------------------------------------------------- 1 | 100,tensor(4.9269) 2 | 200,tensor(4.3838) 3 | 300,tensor(4.0277) 4 | 400,tensor(3.7389) 5 | 500,tensor(3.5852) 6 | 600,tensor(3.5531) 7 | 700,tensor(3.4902) 8 | 800,tensor(3.4699) 9 | 900,tensor(3.4623) 10 | 1000,tensor(3.3754) 11 | 1100,tensor(3.3479) 12 | 1200,tensor(3.3542) 13 | 1300,tensor(3.3215) 14 | 1400,tensor(3.3450) 15 | 1500,tensor(3.3318) 16 | 1600,tensor(3.3250) 17 | 1700,tensor(3.3133) 18 | 1800,tensor(3.2920) 19 | 1900,tensor(3.2693) 20 | 2000,tensor(3.2431) 21 | 2100,tensor(3.3150) 22 | 2200,tensor(3.2185) 23 | 2300,tensor(3.2650) 24 | 2400,tensor(3.2519) 25 | 2500,tensor(3.2392) 26 | 2600,tensor(3.2440) 27 | 2700,tensor(3.2327) 28 | 2800,tensor(3.2221) 29 | 2900,tensor(3.2445) 30 | 3000,tensor(3.1992) 31 | 
3100,tensor(3.2227) 32 | 3200,tensor(3.2277) 33 | 3300,tensor(3.2125) 34 | 3400,tensor(3.2534) 35 | 3500,tensor(3.2225) 36 | 3600,tensor(3.2168) 37 | 3700,tensor(3.2291) 38 | 3800,tensor(3.2067) 39 | 3900,tensor(3.1941) 40 | 4000,tensor(3.1983) 41 | 4100,tensor(3.1990) 42 | 4200,tensor(3.1563) 43 | 4300,tensor(3.1885) 44 | 4400,tensor(3.1989) 45 | 4500,tensor(3.1565) 46 | 4600,tensor(3.2174) 47 | 4700,tensor(3.2120) 48 | 4800,tensor(3.2327) 49 | 4900,tensor(3.1519) 50 | 5000,tensor(3.1598) 51 | 5100,tensor(3.1531) 52 | 5200,tensor(3.1642) 53 | 5300,tensor(3.1758) 54 | 5400,tensor(3.1578) 55 | 5500,tensor(3.1683) 56 | 5600,tensor(3.1276) 57 | 5700,tensor(3.1566) 58 | 5800,tensor(3.1818) 59 | 5900,tensor(3.1101) 60 | 6000,tensor(3.1899) 61 | 6100,tensor(3.1512) 62 | 6200,tensor(3.1105) 63 | 6300,tensor(3.0995) 64 | 6400,tensor(3.1727) 65 | 6500,tensor(3.2001) 66 | 6600,tensor(3.1633) 67 | 6700,tensor(3.1675) 68 | 6800,tensor(3.1565) 69 | 6900,tensor(3.1832) 70 | 7000,tensor(3.1601) 71 | 7100,tensor(3.0829) 72 | 7200,tensor(3.1549) 73 | 7300,tensor(3.1271) 74 | 7400,tensor(3.1491) 75 | 7500,tensor(3.1483) 76 | 7600,tensor(3.1139) 77 | 7700,tensor(3.1430) 78 | 7800,tensor(3.1419) 79 | 7900,tensor(3.1240) 80 | 8000,tensor(3.1638) 81 | 8100,tensor(3.1157) 82 | 8200,tensor(3.1502) 83 | 8300,tensor(3.1453) 84 | 8400,tensor(3.1747) 85 | 8500,tensor(3.1588) 86 | 8600,tensor(3.0674) 87 | 8700,tensor(3.1591) 88 | 8800,tensor(3.0796) 89 | 8900,tensor(3.0984) 90 | 9000,tensor(3.0799) 91 | 9100,tensor(3.0863) 92 | 9200,tensor(3.1053) 93 | 9300,tensor(3.0981) 94 | 9400,tensor(3.1166) 95 | 9500,tensor(3.0888) 96 | 9600,tensor(3.1042) 97 | 9700,tensor(3.0866) 98 | 9800,tensor(3.1327) 99 | 9900,tensor(3.1291) 100 | 10000,tensor(3.1143) 101 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [build_sphinx] 2 | all-files = 1 3 | source-dir = doc/source 4 | 
build-dir = doc/build 5 | warning-is-error = 0 6 | 7 | [flake8] 8 | ignore = 9 | # top-level module docstring 10 | D100, D104, W503, 11 | # space before : conflicts with black 12 | E203 13 | per-file-ignores = 14 | # imported but unused 15 | __init__.py: F401 16 | max-line-length = 88 17 | exclude = 18 | .git, 19 | __pycache__, 20 | build, 21 | dist, 22 | test, 23 | doc 24 | 25 | [isort] 26 | profile = black 27 | force_single_line = true 28 | force_alphabetical_sort = true 29 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import find_packages, setup 3 | 4 | 5 | install_requires = [ 6 | "argparse", 7 | "matplotlib>=3.2.1", 8 | "numpy>=1.18.4", 9 | "POT>=0.7.0", 10 | "scanpy", 11 | "scikit-learn>=0.23.1", 12 | "scipy>=1.4.1", 13 | "torch>=1.5.0", 14 | "torchdiffeq==0.0.1", 15 | ] 16 | 17 | version_py = os.path.join(os.path.dirname(__file__), "TrajectoryNet", "version.py") 18 | version = open(version_py).read().strip().split("=")[-1].replace('"', "").strip() 19 | 20 | readme = open("README.rst").read() 21 | 22 | setup( 23 | name="TrajectoryNet", 24 | packages=find_packages(), 25 | install_requires=install_requires, 26 | version=version, 27 | description="A neural ode solution for imputing trajectories between pointclouds.", 28 | author="Alexander Tong", 29 | author_email="alexandertongdev@gmail.com", 30 | license="MIT", 31 | long_description=readme, 32 | url="https://github.com/KrishnaswamyLab/TrajectoryNet", 33 | ) 34 | --------------------------------------------------------------------------------