├── .flake8
├── .github
│   └── workflows
│       ├── pre-commit.yml
│       └── python-publish.yml
├── .gitignore
├── LICENSE.md
├── README.rst
├── TrajectoryNet
│   ├── __init__.py
│   ├── dataset.py
│   ├── eval.py
│   ├── eval_utils.py
│   ├── lib
│   │   ├── __init__.py
│   │   ├── growth_net.py
│   │   ├── layers
│   │   │   ├── __init__.py
│   │   │   ├── cnf.py
│   │   │   ├── container.py
│   │   │   ├── coupling.py
│   │   │   ├── diffeq_layers
│   │   │   │   ├── __init__.py
│   │   │   │   ├── basic.py
│   │   │   │   ├── container.py
│   │   │   │   ├── resnet.py
│   │   │   │   └── wrappers.py
│   │   │   ├── elemwise.py
│   │   │   ├── glow.py
│   │   │   ├── norm_flows.py
│   │   │   ├── normalization.py
│   │   │   ├── odefunc.py
│   │   │   ├── resnet.py
│   │   │   ├── squeeze.py
│   │   │   └── wrappers
│   │   │       ├── __init__.py
│   │   │       └── cnf_regularization.py
│   │   ├── spectral_norm.py
│   │   ├── utils.py
│   │   ├── visualize_flow.py
│   │   └── viz_scrna.py
│   ├── main.py
│   ├── optimal_transport
│   │   ├── MMD.py
│   │   ├── __init__.py
│   │   ├── emd.py
│   │   ├── gcs.npy
│   │   ├── growth
│   │   │   └── traj.mp4
│   │   ├── model
│   │   ├── plot_UOT_1D.py
│   │   ├── sinkhorn_knopp_unbalanced.py
│   │   └── train_growth.py
│   ├── parse.py
│   ├── train_growth.py
│   ├── train_misc.py
│   └── version.py
├── data
│   ├── eb_genes.txt
│   └── eb_velocity_v5.npz
├── figures
│   ├── EB-Trajectory.gif
│   └── eb_high_quality.png
├── notebooks
│   ├── DentateGyrus-Load.ipynb
│   ├── EmbryoidBody_TrajectoryInference.ipynb
│   ├── Example_Anndata_to_TrajectoryNet.ipynb
│   └── WOT-Schiebinger-load.ipynb
├── requirements.txt
├── results
│   └── fig8_results
│       ├── backward_trajectories.npy
│       ├── checkpt.pth
│       ├── logs
│       └── train_eval.csv
├── setup.cfg
└── setup.py
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 88
3 |
--------------------------------------------------------------------------------
/.github/workflows/pre-commit.yml:
--------------------------------------------------------------------------------
1 | name: pre-commit
2 | on:
3 | push:
4 | branches-ignore:
5 | - 'master'
6 |
7 | jobs:
8 | pre-commit:
9 | runs-on: ubuntu-latest
10 | steps:
11 | - name: Cancel Previous Runs
12 | uses: styfle/cancel-workflow-action@0.6.0
13 | with:
14 | access_token: ${{ github.token }}
15 | - uses: actions/checkout@v2
16 | with:
17 | fetch-depth: 0
18 |
19 | - uses: actions/setup-python@v2
20 | with:
21 | python-version: "3.7"
22 | architecture: "x64"
23 |
24 | - uses: actions/cache@v2
25 | with:
26 | path: ~/.cache/pre-commit
27 | key: pre-commit-${{ hashFiles('.pre-commit-config.yaml') }}-
28 |
29 | - uses: pre-commit/action@v2.0.0
30 | continue-on-error: true
31 |
32 | - name: Commit files
33 | run: |
34 | if [[ `git status --porcelain --untracked-files=no` ]]; then
35 | git config --local user.email "41898282+github-actions[bot]@users.noreply.github.com"
36 | git config --local user.name "github-actions[bot]"
37 | git commit -m "pre-commit" -a
38 | fi
39 |
40 | - name: Push changes
41 | uses: ad-m/github-push-action@master
42 | with:
43 | github_token: ${{ secrets.GITHUB_TOKEN }}
44 | branch: ${{ github.ref }}
45 |
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | # This workflow uses actions that are not certified by GitHub.
5 | # They are provided by a third-party and are governed by
6 | # separate terms of service, privacy policy, and support
7 | # documentation.
8 |
9 | name: Publish Python 🐍 distributions 📦 to PyPI
10 |
11 | on:
12 | release:
13 | types: [published]
14 |
15 | jobs:
16 | deploy:
17 | runs-on: ubuntu-latest
18 |
19 | steps:
20 | - uses: actions/checkout@v2
21 | - name: Set up Python
22 | uses: actions/setup-python@v2
23 | with:
24 | python-version: '3.x'
25 | - name: Install dependencies
26 | run: |
27 | python -m pip install --upgrade pip
28 | pip install build
29 | - name: Build package
30 | run: python -m build
31 | - name: Publish distribution 📦 to Test PyPI
32 | uses: pypa/gh-action-pypi-publish@release/v1
33 | with:
34 | skip_existing: true
35 | user: __token__
36 | password: ${{ secrets.TEST_PYPI_API_TOKEN }}
37 | repository_url: https://test.pypi.org/legacy/
38 | - name: Publish distribution 📦 to PyPI
39 | uses: pypa/gh-action-pypi-publish@release/v1
40 | with:
41 | user: __token__
42 | password: ${{ secrets.PYPI_API_TOKEN }}
43 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | results/tmp/
2 | *.jpg
3 |
4 | #Vim
5 | *.sw?
6 |
7 | # Byte-compiled / optimized / DLL files
8 | __pycache__/
9 | *.py[cod]
10 |
11 | # C extensions
12 | *.so
13 |
14 | # Distribution / packaging
15 | .Python
16 | env/
17 | build/
18 | develop-eggs/
19 | dist/
20 | downloads/
21 | eggs/
22 | .eggs/
23 | lib/
24 | lib64/
25 | parts/
26 | sdist/
27 | var/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *.cover
51 |
52 | # Translations
53 | *.mo
54 | *.pot
55 |
56 | # Django stuff:
57 | *.log
58 |
59 | # Sphinx documentation
60 | docs/_build/
61 |
62 | # PyBuilder
63 | target/
64 |
65 | # DotEnv configuration
66 | .env
67 |
68 | # Database
69 | *.db
70 | *.rdb
71 |
72 | # Pycharm
73 | .idea
74 |
75 | # VS Code
76 | .vscode/
77 |
78 | # Spyder
79 | .spyproject/
80 |
81 | # Jupyter NB Checkpoints
82 | .ipynb_checkpoints/
83 |
84 | # exclude data from source control by default
85 | /data/
86 |
87 | # exclude old folder by default
88 | /old/
89 |
90 | # Mac OS-specific storage files
91 | .DS_Store
92 |
93 | # vim
94 | *.swp
95 | *.swo
96 |
97 | # Mypy cache
98 | .mypy_cache/
99 |
100 | # Snakemake cache
101 | .snakemake/
102 | *.out
103 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | ----------------------------------
2 |
3 | Non-Commercial License
4 | Yale Copyright © 2024 Yale University.
5 |
6 | Permission is hereby granted to use, copy, modify, and distribute this Software for any non-commercial purpose. Any distribution or modification or derivations of the Software (together “Derivative Works”) must be made available on GitHub and shall include this copyright notice and this permission notice in all copies or substantial portions of the Software. For the purposes of this license, "non-commercial" means not intended for or directed towards commercial advantage or monetary compensation either via the Software itself or Derivative Works or uses of either which lead to or generate any commercial products. In any event, the use and modification of the Software or Derivative Works shall remain governed by the terms and conditions of this Agreement; Any commercial use of the Software requires a separate commercial license from the copyright holder at Yale University. Direct any requests for commercial licenses to Yale Ventures at yaleventures@yale.edu.
7 |
8 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
9 |
10 | ----------------------------------
11 |
--------------------------------------------------------------------------------
/README.rst:
--------------------------------------------------------------------------------
1 | PyTorch Implementation of TrajectoryNet
2 | =======================================
3 |
4 | This library runs code associated with the TrajectoryNet paper [1].
5 |
6 | In brief, TrajectoryNet is a continuous normalizing flow model that can
7 | perform dynamic optimal transport using energy regularization and/or a
8 | combination of velocity, density, and growth regularizations to better match
9 | cellular trajectories.
10 |
11 | Our setting is similar to that of `WaddingtonOT
12 | `_, in that we have access to a set of
13 | population-level measurements of cells over time and would like to model the
14 | dynamics of cells over that time period. TrajectoryNet is trained end-to-end
15 | and is continuous both in gene space and in time.
16 |
17 |
18 | Installation
19 | ------------
20 |
21 | TrajectoryNet is available on `PyPI`. Install it by running the following:
22 |
23 | .. code-block:: bash
24 |
25 | pip install TrajectoryNet
26 |
27 | This code was tested with python 3.7 and 3.8.
28 |
29 | Example
30 | -------
31 |
32 | .. image:: figures/eb_high_quality.png
33 | :alt: EB PHATE Scatterplot
34 | :height: 300
35 |
36 | .. image:: figures/EB-Trajectory.gif
37 | :alt: Trajectory of density over time
38 | :height: 300
39 |
40 |
41 | Basic Usage
42 | -----------
43 |
44 | To run TrajectoryNet on the `S Curve` example from the paper, run
45 |
46 | .. code-block:: bash
47 |
48 |     python -m TrajectoryNet.main --dataset SCURVE
49 |
50 | To use a custom dataset, expose the coordinates and timepoint
51 | information following the example Jupyter notebooks in the
52 | `/notebooks/` folder.
53 |
54 | If you have an `AnnData `_ object, then take a look at
55 | `notebooks/Example_Anndata_to_TrajectoryNet.ipynb
56 | `_,
57 | which shows how to load one of the example `scvelo `_ anndata objects into
58 | TrajectoryNet. Alternatively you can use the custom (compressed) format for
59 | TrajectoryNet as described below.
60 |
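A minimal sketch of that conversion (the field names ``X_pca`` and
``timepoint`` are illustrative, not part of the TrajectoryNet API;
substitute whatever your AnnData object actually uses):

.. code-block:: python

    import anndata
    import numpy as np

    adata = anndata.read_h5ad("mydata.h5ad")  # hypothetical input file

    np.savez(
        "mydata.npz",
        # Embedding matrix (Cells x Dimensions); this key is what you
        # later pass as --embedding_name
        pcs=adata.obsm["X_pca"][:, :5],
        # One integer timepoint label per cell
        sample_labels=adata.obs["timepoint"].to_numpy(),
    )
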
61 | For this format, TrajectoryNet requires the following (see the sketch below):
62 |
63 | 1. An embedding matrix titled `[embedding_name]` (Cells x Dimensions)
64 | 2. A sample labels array titled `sample_labels` (Cells)
65 |
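As a self-contained sketch of this format (``phate`` is an arbitrary key
name; it only has to match the ``--embedding_name`` argument used below):

.. code-block:: python

    import numpy as np

    n_cells, n_dims = 3000, 2
    embedding = np.random.randn(n_cells, n_dims)        # (Cells x Dimensions)
    sample_labels = np.repeat([0, 1, 2], n_cells // 3)  # one timepoint per cell

    np.savez("custom_data.npz", phate=embedding, sample_labels=sample_labels)
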
66 | To run TrajectoryNet with a custom dataset use:
67 |
68 | .. code-block:: bash
69 |
70 | python -m TrajectoryNet.main --dataset [PATH_TO_NPZ_FILE] --embedding_name [EMBEDDING_NAME]
71 | python -m TrajectoryNet.eval --dataset [PATH_TO_NPZ_FILE] --embedding_name [EMBEDDING_NAME]
72 |
73 |
74 | See `notebooks/EB-Eval.ipynb` for an example of how to use TrajectoryNet on
75 | a PCA embedding to get trajectories in gene space.
76 |
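The gene-space mapping amounts to applying the inverse of the embedding to
the saved trajectories. A sketch of the idea (the PCA here is fit to random
stand-in data; in practice, reuse the PCA fitted to your expression matrix):

.. code-block:: python

    import numpy as np
    from sklearn.decomposition import PCA

    # Trajectories saved by TrajectoryNet.eval live in the embedding space.
    traj = np.load("results/fig8_results/backward_trajectories.npy")
    t, n, d = traj.shape                       # (timepoints, cells, dims)

    expression = np.random.rand(1000, 50)      # stand-in for a cells x genes matrix
    pca = PCA(n_components=d).fit(expression)  # the embedding used for training

    # Map each point back to gene space via the inverse transform.
    gene_traj = pca.inverse_transform(traj.reshape(t * n, d)).reshape(t, n, -1)
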
77 |
78 | References
79 | ----------
80 | [1] Tong, A., Huang, J., Wolf, G., van Dijk, D., and Krishnaswamy, S. TrajectoryNet: A Dynamic Optimal Transport Network for Modeling Cellular Dynamics. In International Conference on Machine Learning, 2020. `arxiv `_ `ICML `_
81 |
82 | ----
83 |
84 | If you found this library useful, please consider citing::
85 |
86 | @inproceedings{tong2020trajectorynet,
87 | title = {TrajectoryNet: A Dynamic Optimal Transport Network for Modeling Cellular Dynamics},
88 | shorttitle = {TrajectoryNet},
89 | booktitle = {Proceedings of the 37th International Conference on Machine Learning},
90 | author = {Tong, Alexander and Huang, Jessie and Wolf, Guy and {van Dijk}, David and Krishnaswamy, Smita},
91 | year = {2020}
92 | }
93 |
--------------------------------------------------------------------------------
/TrajectoryNet/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KrishnaswamyLab/TrajectoryNet/810c89b081f95405bc0ad42bfb3cb41038e13700/TrajectoryNet/__init__.py
--------------------------------------------------------------------------------
/TrajectoryNet/eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | import matplotlib.pyplot as plt
5 | import matplotlib
6 |
7 | from TrajectoryNet import dataset, eval_utils
8 | from TrajectoryNet.parse import parser
9 | from TrajectoryNet.lib.growth_net import GrowthNet
10 | from TrajectoryNet.lib.viz_scrna import trajectory_to_video, save_vectors
11 | from TrajectoryNet.lib.viz_scrna import (
12 | save_trajectory_density,
13 | save_2d_trajectory,
14 | save_2d_trajectory_v2,
15 | )
16 |
17 | from TrajectoryNet.train_misc import (
18 | set_cnf_options,
19 | count_nfe,
20 | count_parameters,
21 | count_total_time,
22 | add_spectral_norm,
23 | spectral_norm_power_iteration,
24 | create_regularization_fns,
25 | get_regularization,
26 | append_regularization_to_log,
27 | build_model_tabular,
28 | )
29 |
30 |
31 | def makedirs(dirname):
32 | if not os.path.exists(dirname):
33 | os.makedirs(dirname)
34 |
35 |
36 | def save_trajectory(
37 | prior_logdensity,
38 | prior_sampler,
39 | model,
40 | data_samples,
41 | savedir,
42 | ntimes=101,
43 | end_times=None,
44 | memory=0.01,
45 | device="cpu",
46 | ):
47 | model.eval()
48 |
49 | # Sample from prior
50 | z_samples = prior_sampler(1000, 2).to(device)
51 |
52 | # sample from a grid
53 | npts = 100
54 | side = np.linspace(-4, 4, npts)
55 | xx, yy = np.meshgrid(side, side)
56 | xx = torch.from_numpy(xx).type(torch.float32).to(device)
57 | yy = torch.from_numpy(yy).type(torch.float32).to(device)
58 | z_grid = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1)], 1)
59 |
60 | with torch.no_grad():
61 | # We expect the model is a chain of CNF layers wrapped in a SequentialFlow container.
62 | logp_samples = prior_logdensity(z_samples)
63 | logp_grid = prior_logdensity(z_grid)
64 | t = 0
65 | for cnf in model.chain:
66 |
67 | # Construct integration_list
68 | if end_times is None:
69 | end_times = [(cnf.sqrt_end_time * cnf.sqrt_end_time)]
70 | integration_list = [torch.linspace(0, end_times[0], ntimes).to(device)]
71 | for i, et in enumerate(end_times[1:]):
72 | integration_list.append(
73 | torch.linspace(end_times[i], et, ntimes).to(device)
74 | )
75 | full_times = torch.cat(integration_list, 0)
76 | print(full_times.shape)
77 |
78 | # Integrate over evenly spaced samples
79 | z_traj, logpz = cnf(
80 | z_samples,
81 | logp_samples,
82 | integration_times=integration_list[0],
83 | reverse=True,
84 | )
85 | full_traj = [(z_traj, logpz)]
86 | for int_times in integration_list[1:]:
87 | prev_z, prev_logp = full_traj[-1]
88 | z_traj, logpz = cnf(
89 | prev_z[-1], prev_logp[-1], integration_times=int_times, reverse=True
90 | )
91 | full_traj.append((z_traj[1:], logpz[1:]))
92 | full_zip = list(zip(*full_traj))
93 | z_traj = torch.cat(full_zip[0], 0)
94 | # z_logp = torch.cat(full_zip[1], 0)
95 | z_traj = z_traj.cpu().numpy()
96 |
97 | grid_z_traj, grid_logpz_traj = [], []
98 | inds = torch.arange(0, z_grid.shape[0]).to(torch.int64)
99 | for ii in torch.split(inds, int(z_grid.shape[0] * memory)):
100 | _grid_z_traj, _grid_logpz_traj = cnf(
101 | z_grid[ii],
102 | logp_grid[ii],
103 | integration_times=integration_list[0],
104 | reverse=True,
105 | )
106 | full_traj = [(_grid_z_traj, _grid_logpz_traj)]
107 | for int_times in integration_list[1:]:
108 | prev_z, prev_logp = full_traj[-1]
109 | _grid_z_traj, _grid_logpz_traj = cnf(
110 | prev_z[-1],
111 | prev_logp[-1],
112 | integration_times=int_times,
113 | reverse=True,
114 | )
115 | full_traj.append((_grid_z_traj, _grid_logpz_traj))
116 | full_zip = list(zip(*full_traj))
117 | _grid_z_traj = torch.cat(full_zip[0], 0).cpu().numpy()
118 | _grid_logpz_traj = torch.cat(full_zip[1], 0).cpu().numpy()
119 | print(_grid_z_traj.shape)
120 | grid_z_traj.append(_grid_z_traj)
121 | grid_logpz_traj.append(_grid_logpz_traj)
122 |
123 | grid_z_traj = np.concatenate(grid_z_traj, axis=1)
124 | grid_logpz_traj = np.concatenate(grid_logpz_traj, axis=1)
125 |
126 | plt.figure(figsize=(8, 8))
127 | for _ in range(z_traj.shape[0]):
128 |
129 | plt.clf()
130 |
131 | # plot target potential function
132 | ax = plt.subplot(1, 1, 1, aspect="equal")
133 |
134 | """
135 | ax.hist2d(data_samples[:, 0], data_samples[:, 1], range=[[-4, 4], [-4, 4]], bins=200)
136 | ax.invert_yaxis()
137 | ax.get_xaxis().set_ticks([])
138 | ax.get_yaxis().set_ticks([])
139 | ax.set_title("Target", fontsize=32)
140 |
141 | """
142 | # plot the density
143 | # ax = plt.subplot(2, 2, 2, aspect="equal")
144 |
145 | z, logqz = grid_z_traj[t], grid_logpz_traj[t]
146 |
147 | xx = z[:, 0].reshape(npts, npts)
148 | yy = z[:, 1].reshape(npts, npts)
149 | qz = np.exp(logqz).reshape(npts, npts)
150 | rgb = plt.cm.Spectral(t / z_traj.shape[0])
151 | print(t, rgb)
152 | background_color = "white"
153 | cvals = [0, np.percentile(qz, 0.1)]
154 | colors = [
155 | background_color,
156 | rgb,
157 | ]
158 | norm = plt.Normalize(min(cvals), max(cvals))
159 | tuples = list(zip(map(norm, cvals), colors))
160 | cmap = matplotlib.colors.LinearSegmentedColormap.from_list("", tuples)
161 | from matplotlib.colors import LogNorm
162 |
163 | plt.pcolormesh(
164 | xx,
165 | yy,
166 | qz,
167 | # norm=LogNorm(vmin=qz.min(), vmax=qz.max()),
168 | cmap=cmap,
169 | )
170 | ax.set_xlim(-4, 4)
171 | ax.set_ylim(-4, 4)
172 | cmap = matplotlib.cm.get_cmap(None)
173 | ax.set_facecolor(background_color)
174 | ax.invert_yaxis()
175 | ax.get_xaxis().set_ticks([])
176 | ax.get_yaxis().set_ticks([])
177 | ax.set_title("Density", fontsize=32)
178 |
179 | """
180 | # plot the samples
181 | ax = plt.subplot(2, 2, 3, aspect="equal")
182 |
183 | zk = z_traj[t]
184 | ax.hist2d(zk[:, 0], zk[:, 1], range=[[-4, 4], [-4, 4]], bins=200)
185 | ax.invert_yaxis()
186 | ax.get_xaxis().set_ticks([])
187 | ax.get_yaxis().set_ticks([])
188 | ax.set_title("Samples", fontsize=32)
189 |
190 | # plot vector field
191 | ax = plt.subplot(2, 2, 4, aspect="equal")
192 |
193 | K = 13j
194 | y, x = np.mgrid[-4:4:K, -4:4:K]
195 | K = int(K.imag)
196 | zs = torch.from_numpy(np.stack([x, y], -1).reshape(K * K, 2)).to(device, torch.float32)
197 | logps = torch.zeros(zs.shape[0], 1).to(device, torch.float32)
198 | dydt = cnf.odefunc(full_times[t], (zs, logps))[0]
199 | dydt = -dydt.cpu().detach().numpy()
200 | dydt = dydt.reshape(K, K, 2)
201 |
202 | logmag = 2 * np.log(np.hypot(dydt[:, :, 0], dydt[:, :, 1]))
203 | ax.quiver(
204 | x, y, dydt[:, :, 0], -dydt[:, :, 1],
205 | # x, y, dydt[:, :, 0], dydt[:, :, 1],
206 | np.exp(logmag), cmap="coolwarm", scale=20., width=0.015, pivot="mid"
207 | )
208 | ax.set_xlim(-4, 4)
209 | ax.set_ylim(4, -4)
210 | #ax.set_ylim(-4, 4)
211 | ax.axis("off")
212 | ax.set_title("Vector Field", fontsize=32)
213 | """
214 |
215 | makedirs(savedir)
216 | plt.savefig(os.path.join(savedir, f"viz-{t:05d}.jpg"))
217 | t += 1
218 |
219 |
220 | def get_trajectory_samples(device, model, data, n=2000):
221 | ntimes = 5
222 | model.eval()
223 | z_samples = data.base_sample()(n, 2).to(device)
224 |
225 | integration_list = [torch.linspace(0, args.int_tps[0], ntimes).to(device)]
226 | for i, et in enumerate(args.int_tps[1:]):
227 | integration_list.append(torch.linspace(args.int_tps[i], et, ntimes).to(device))
228 | print(integration_list)
229 |
230 |
231 | def plot_output(device, args, model, data):
232 | # logger.info('Plotting trajectory to {}'.format(save_traj_dir))
233 | data_samples = data.get_data()[data.sample_index(2000, 0)]
234 | start_points = data.base_sample()(1000, 2)
235 | # start_points = data.get_data()[idx]
236 | # start_points = torch.from_numpy(start_points).type(torch.float32)
237 | """
238 | save_vectors(
239 | data.base_density(),
240 | model,
241 | start_points,
242 | data.get_data()[data.get_times() == 1],
243 | data.get_times()[data.get_times() == 1],
244 | args.save,
245 | device=device,
246 | end_times=args.int_tps,
247 | ntimes=100,
248 | memory=1.0,
249 | lim=1.5,
250 | )
251 | save_traj_dir = os.path.join(args.save, "trajectory_2d")
252 | save_2d_trajectory_v2(
253 | data.base_density(),
254 | data.base_sample(),
255 | model,
256 | data_samples,
257 | save_traj_dir,
258 | device=device,
259 | end_times=args.int_tps,
260 | ntimes=3,
261 | memory=1.0,
262 | limit=2.5,
263 | )
264 | """
265 |
266 | density_dir = os.path.join(args.save, "density2")
267 | save_trajectory_density(
268 | data.base_density(),
269 | model,
270 | data_samples,
271 | density_dir,
272 | device=device,
273 | end_times=args.int_tps,
274 | ntimes=100,
275 | memory=1,
276 | )
277 | trajectory_to_video(density_dir)
278 |
279 |
280 | def integrate_backwards(
281 | end_samples, model, savedir, ntimes=100, memory=0.1, device="cpu"
282 | ):
283 |     """Integrate samples backwards through the CNF and save them (uses the module-level ``args``)."""
284 | with torch.no_grad():
285 | z = torch.from_numpy(end_samples).type(torch.float32).to(device)
286 | zero = torch.zeros(z.shape[0], 1).to(z)
287 | cnf = model.chain[0]
288 |
289 | zs = [z]
290 | deltas = []
291 | int_tps = np.linspace(args.int_tps[0], args.int_tps[-1], ntimes)
292 | for i, itp in enumerate(int_tps[::-1][:-1]):
293 | # tp counts down from last
294 | timescale = int_tps[1] - int_tps[0]
295 | integration_times = torch.tensor([itp - timescale, itp])
296 | # integration_times = torch.tensor([np.linspace(itp - args.time_scale, itp, ntimes)])
297 | integration_times = integration_times.type(torch.float32).to(device)
298 |
299 | # transform to previous timepoint
300 | z, delta_logp = cnf(zs[-1], zero, integration_times=integration_times)
301 | zs.append(z)
302 | deltas.append(delta_logp)
303 | zs = torch.stack(zs, 0)
304 | zs = zs.cpu().numpy()
305 | np.save(os.path.join(savedir, "backward_trajectories.npy"), zs)
306 |
307 |
308 | def main(args):
309 | device = torch.device(
310 | "cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu"
311 | )
312 | if args.use_cpu:
313 | device = torch.device("cpu")
314 |
315 | data = dataset.SCData.factory(args.dataset, args)
316 |
317 | args.timepoints = data.get_unique_times()
318 |
319 | # Use maximum timepoint to establish integration_times
320 | # as some timepoints may be left out for validation etc.
321 | args.int_tps = (np.arange(max(args.timepoints) + 1) + 1.0) * args.time_scale
322 |
323 | regularization_fns, regularization_coeffs = create_regularization_fns(args)
324 | model = build_model_tabular(args, data.get_shape()[0], regularization_fns).to(
325 | device
326 | )
327 | if args.use_growth:
328 | growth_model_path = data.get_growth_net_path()
329 | # growth_model_path = "/home/atong/TrajectoryNet/data/externel/growth_model_v2.ckpt"
330 | growth_model = torch.load(growth_model_path, map_location=device)
331 | if args.spectral_norm:
332 | add_spectral_norm(model)
333 | set_cnf_options(args, model)
334 |
335 | state_dict = torch.load(args.save + "/checkpt.pth", map_location=device)
336 | model.load_state_dict(state_dict["state_dict"])
337 |
338 | # plot_output(device, args, model, data)
339 | # exit()
340 | # get_trajectory_samples(device, model, data)
341 |
342 | args.data = data
343 | args.timepoints = args.data.get_unique_times()
344 | args.int_tps = (np.arange(max(args.timepoints) + 1) + 1.0) * args.time_scale
345 |
346 | print("integrating backwards")
347 | # end_time_data = data.data_dict[args.embedding_name]
348 | end_time_data = data.get_data()[
349 | args.data.get_times() == np.max(args.data.get_times())
350 | ]
351 | # np.random.permutation(end_time_data)
352 | # rand_idx = np.random.randint(end_time_data.shape[0], size=5000)
353 | # end_time_data = end_time_data[rand_idx,:]
354 | integrate_backwards(end_time_data, model, args.save, ntimes=100, device=device)
355 | exit()
356 | losses_list = []
357 | # for factor in np.linspace(0.05, 0.95, 19):
358 | # for factor in np.linspace(0.91, 0.99, 9):
359 | if args.dataset == "CHAFFER": # Do timepoint adjustment
360 | print("adjusting_timepoints")
361 | lt = args.leaveout_timepoint
362 | if lt == 1:
363 | factor = 0.6799872494335812
364 | factor = 0.95
365 | elif lt == 2:
366 | factor = 0.2905983814032348
367 | factor = 0.01
368 | else:
369 | raise RuntimeError("Unknown timepoint %d" % args.leaveout_timepoint)
370 | args.int_tps[lt] = (1 - factor) * args.int_tps[lt - 1] + factor * args.int_tps[
371 | lt + 1
372 | ]
373 | losses = eval_utils.evaluate_kantorovich_v2(device, args, model)
374 | losses_list.append(losses)
375 | print(np.array(losses_list))
376 | np.save(os.path.join(args.save, "emd_list"), np.array(losses_list))
377 | # zs = np.load(os.path.join(args.save, 'backward_trajectories'))
378 | # losses = eval_utils.evaluate_mse(device, args, model)
379 | # losses = eval_utils.evaluate_kantorovich(device, args, model)
380 | # print(losses)
381 | # eval_utils.generate_samples(device, args, model, growth_model, timepoint=args.timepoints[-1])
382 | # eval_utils.calculate_path_length(device, args, model, data, args.int_tps[-1])
383 |
384 |
385 | if __name__ == "__main__":
386 | args = parser.parse_args()
387 | main(args)
388 |
--------------------------------------------------------------------------------
/TrajectoryNet/eval_utils.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | import os
4 | import torch
5 |
6 | from .optimal_transport.emd import earth_mover_distance
7 |
8 |
9 | def generate_samples(device, args, model, growth_model, n=10000, timepoint=None):
10 |     """Generate samples using the model and the base density.
11 |
12 | This is useful for measuring the wasserstein distance between the
13 | predicted distribution and the true distribution for evaluation
14 | purposes against other types of models. We should use
15 | negative log likelihood if possible as it is deterministic and
16 | more discriminative for this model type.
17 |
18 | TODO: Is this biased???
19 | """
20 | z_samples = args.data.base_sample()(n, *args.data.get_shape()).to(device)
21 | # Forward pass through the model / growth model
22 | with torch.no_grad():
23 | int_list = [
24 | torch.tensor([it - args.time_scale, it]).type(torch.float32).to(device)
25 | for it in args.int_tps[: timepoint + 1]
26 | ]
27 |
28 | logpz = args.data.base_density()(z_samples)
29 | z = z_samples
30 | for it in int_list:
31 | z, logpz = model(z, logpz, integration_times=it, reverse=True)
32 | z = z.cpu().numpy()
33 | np.save(os.path.join(args.save, "samples_%0.2f.npy" % timepoint), z)
34 | logpz = logpz.cpu().numpy()
35 | plt.scatter(z[:, 0], z[:, 1], s=0.1, alpha=0.5)
36 | original_data = args.data.get_data()[args.data.get_times() == timepoint]
37 | idx = np.random.randint(original_data.shape[0], size=n)
38 | samples = original_data[idx, :]
39 | plt.scatter(samples[:, 0], samples[:, 1], alpha=0.5)
40 | plt.savefig(os.path.join(args.save, "samples%d.png" % timepoint))
41 | plt.close()
42 |
43 | pz = np.exp(logpz)
44 | pz = pz / np.sum(pz)
45 | print(pz)
46 |
47 | print(
48 | earth_mover_distance(
49 | original_data, samples + np.random.randn(*samples.shape) * 0.1
50 | )
51 | )
52 |
53 | print(earth_mover_distance(z, original_data))
54 | print(earth_mover_distance(z, samples))
55 | # print(earth_mover_distance(z, original_data, weights1=pz.flatten()))
56 | # print(
57 | # earth_mover_distance(
58 | # args.data.get_data()[args.data.get_times() == (timepoint - 1)],
59 | # original_data,
60 | # )
61 | # )
62 |
63 | if args.use_growth and growth_model is not None:
64 | raise NotImplementedError(
65 | "generating samples with growth model is not yet implemented"
66 | )
67 |
68 |
69 | def calculate_path_length(device, args, model, data, end_time, n_pts=10000):
70 | """Calculates the total length of the path from time 0 to timepoint"""
71 | # z_samples = torch.tensor(data.get_data()).type(torch.float32).to(device)
72 | z_samples = data.base_sample()(n_pts, *data.get_shape()).to(device)
73 | model.eval()
74 | n = 1001
75 | with torch.no_grad():
76 | integration_times = (
77 | torch.tensor(np.linspace(0, end_time, n)).type(torch.float32).to(device)
78 | )
79 | # z, _ = model(z_samples, torch.zeros_like(z_samples), integration_times=integration_times, reverse=False)
80 | z, _ = model(
81 | z_samples,
82 | torch.zeros_like(z_samples),
83 | integration_times=integration_times,
84 | reverse=True,
85 | )
86 | z = z.cpu().numpy()
87 | z_diff = np.diff(z, axis=0)
88 | z_lengths = np.sum(np.linalg.norm(z_diff, axis=-1), axis=0)
89 | total_length = np.mean(z_lengths)
90 | import ot as pot
91 | from scipy.spatial.distance import cdist
92 |
93 | emd = pot.emd2(
94 | np.ones(n_pts) / n_pts,
95 | np.ones(n_pts) / n_pts,
96 | cdist(z[-1, :, :], data.get_data()),
97 | )
98 | print(total_length, emd)
99 | plt.scatter(z[-1, :, 0], z[-1, :, 1])
100 | plt.savefig("test.png")
101 | plt.close()
102 |
103 |
104 | def evaluate_mse(device, args, model, growth_model=None):
105 | if args.use_growth or growth_model is not None:
106 | print("WARNING: Ignoring growth model and computing anyway")
107 |
108 | paths = args.data.get_paths()
109 |
110 | z_samples = torch.tensor(paths[:, 0, :]).type(torch.float32).to(device)
111 | # Forward pass through the model / growth model
112 | with torch.no_grad():
113 | int_list = [
114 | torch.tensor([it - args.time_scale, it]).type(torch.float32).to(device)
115 | for it in args.int_tps
116 | ]
117 |
118 | logpz = args.data.base_density()(z_samples)
119 | z = z_samples
120 | zs = []
121 | for it in int_list:
122 | z, _ = model(z, logpz, integration_times=it, reverse=True)
123 | zs.append(z.cpu().numpy())
124 | zs = np.stack(zs)
125 | np.save(os.path.join(args.save, "path_samples.npy"), zs)
126 |
127 | # logpz = logpz.cpu().numpy()
128 | # plt.scatter(z[:, 0], z[:, 1], s=0.1, alpha=0.5)
129 | mses = []
130 | print(zs.shape, paths[:, 1, :].shape)
131 | for tpi in range(len(args.timepoints)):
132 | mses.append(np.mean((paths[:, tpi + 1, :] - zs[tpi]) ** 2, axis=(-2, -1)))
133 | mses = np.array(mses)
134 | print(mses.shape)
135 | np.save(os.path.join(args.save, "mses.npy"), mses)
136 | return mses
137 |
138 |
139 | def evaluate_kantorovich_v2(device, args, model, growth_model=None):
140 |     """Evaluate the model via Kantorovich distance on the left-out timepoint.
141 |
142 |     v2 computes samples from the adjacent timepoints instead of the base
143 |     distribution. This is arguably a fairer comparison to other methods, such
144 |     as WOT, which are not model based, and it should accumulate much less
145 |     numerical error in the integration procedure. However, it fixes the number
146 |     of samples to the number in the adjacent timepoints.
147 |
148 | If we have a growth model we should use this to modify the weighting of the
149 | points over time.
150 | """
151 | if args.use_growth or growth_model is not None:
152 | # raise NotImplementedError(
153 | # "generating samples with growth model is not yet implemented"
154 | # )
155 | print("WARNING: Ignoring growth model and computing anyway")
156 |
157 | # Backward pass through the model / growth model
158 | with torch.no_grad():
159 | int_times = torch.tensor(
160 | [
161 | args.int_tps[args.leaveout_timepoint],
162 | args.int_tps[args.leaveout_timepoint + 1],
163 | ]
164 | )
165 | int_times = int_times.type(torch.float32).to(device)
166 | next_z = args.data.get_data()[
167 | args.data.get_times() == args.leaveout_timepoint + 1
168 | ]
169 | next_z = torch.from_numpy(next_z).type(torch.float32).to(device)
170 | prev_z = args.data.get_data()[
171 | args.data.get_times() == args.leaveout_timepoint - 1
172 | ]
173 | prev_z = torch.from_numpy(prev_z).type(torch.float32).to(device)
174 | zero = torch.zeros(next_z.shape[0], 1).to(device)
175 | z_backward, _ = model.chain[0](next_z, zero, integration_times=int_times)
176 | z_backward = z_backward.cpu().numpy()
177 | int_times = torch.tensor(
178 | [
179 | args.int_tps[args.leaveout_timepoint - 1],
180 | args.int_tps[args.leaveout_timepoint],
181 | ]
182 | )
183 | zero = torch.zeros(prev_z.shape[0], 1).to(device)
184 | z_forward, _ = model.chain[0](
185 | prev_z, zero, integration_times=int_times, reverse=True
186 | )
187 | z_forward = z_forward.cpu().numpy()
188 |
189 | emds = []
190 | for tpi in [args.leaveout_timepoint]:
191 | original_data = args.data.get_data()[
192 | args.data.get_times() == args.timepoints[tpi]
193 | ]
194 | emds.append(earth_mover_distance(z_backward, original_data))
195 | emds.append(earth_mover_distance(z_forward, original_data))
196 |
197 | emds = np.array(emds)
198 | np.save(os.path.join(args.save, "emds_v2.npy"), emds)
199 | return emds
200 |
201 |
202 | def evaluate_kantorovich(device, args, model, growth_model=None, n=10000):
203 |     """Evaluate the model via Kantorovich distance on all timepoints.
204 |
205 |     Computes samples forward from the starting parametric distribution, keeping
206 |     track of the growth rate to scale the final distribution.
207 |
208 | The growth model is a single model of time independent cell growth /
209 | death rate defined as a variation from uniform.
210 |
211 | If we have a growth model we should use this to modify the weighting of the
212 | points over time.
213 | """
214 | if args.use_growth or growth_model is not None:
215 | # raise NotImplementedError(
216 | # "generating samples with growth model is not yet implemented"
217 | # )
218 | print("WARNING: Ignoring growth model and computing anyway")
219 |
220 | z_samples = args.data.base_sample()(n, *args.data.get_shape()).to(device)
221 | # Forward pass through the model / growth model
222 | with torch.no_grad():
223 | int_list = []
224 | for i, it in enumerate(args.int_tps):
225 | if i == 0:
226 | prev = 0.0
227 | else:
228 | prev = args.int_tps[i - 1]
229 | int_list.append(torch.tensor([prev, it]).type(torch.float32).to(device))
230 |
231 | # int_list = [
232 | # torch.tensor([it - args.time_scale, it]).type(torch.float32).to(device)
233 | # for it in args.int_tps
234 | # ]
235 | print(args.int_tps)
236 |
237 | logpz = args.data.base_density()(z_samples)
238 | z = z_samples
239 | zs = []
240 | growthrates = [torch.ones(z_samples.shape[0], 1).to(device)]
241 | for it, tp in zip(int_list, args.timepoints):
242 | z, _ = model(z, logpz, integration_times=it, reverse=True)
243 | zs.append(z.cpu().numpy())
244 | if args.use_growth:
245 | time_state = tp * torch.ones(z.shape[0], 1).to(device)
246 | full_state = torch.cat([z, time_state], 1)
247 | # Multiply growth rates together to get total mass along path
248 | growthrates.append(
249 | torch.clamp(growth_model(full_state), 1e-4, 1e4) * growthrates[-1]
250 | )
251 | zs = np.stack(zs)
252 | if args.use_growth:
253 | growthrates = growthrates[1:]
254 | growthrates = torch.stack(growthrates)
255 | growthrates = growthrates.cpu().numpy()
256 | np.save(os.path.join(args.save, "sample_weights.npy"), growthrates)
257 | np.save(os.path.join(args.save, "samples.npy"), zs)
258 |
259 | # logpz = logpz.cpu().numpy()
260 | # plt.scatter(z[:, 0], z[:, 1], s=0.1, alpha=0.5)
261 | emds = []
262 | for tpi in range(len(args.timepoints)):
263 | original_data = args.data.get_data()[
264 | args.data.get_times() == args.timepoints[tpi]
265 | ]
266 | if args.use_growth:
267 | emds.append(
268 | earth_mover_distance(
269 | zs[tpi], original_data, weights1=growthrates[tpi].flatten()
270 | )
271 | )
272 | else:
273 | emds.append(earth_mover_distance(zs[tpi], original_data))
274 |
275 | # Add validation point kantorovich distance evaluation
276 | if args.data.has_validation_samples():
277 | for tpi in np.unique(args.data.val_labels):
278 | original_data = args.data.val_data[
279 | args.data.val_labels == args.timepoints[tpi]
280 | ]
281 | if args.use_growth:
282 | emds.append(
283 | earth_mover_distance(
284 | zs[tpi], original_data, weights1=growthrates[tpi].flatten()
285 | )
286 | )
287 | else:
288 | emds.append(earth_mover_distance(zs[tpi], original_data))
289 |
290 | emds = np.array(emds)
291 | print(emds)
292 | np.save(os.path.join(args.save, "emds.npy"), emds)
293 | return emds
294 |
295 |
296 | def evaluate(device, args, model, growth_model=None):
297 |     """Evaluate the model via negative log likelihood on all timepoints.
298 |
299 |     The loss is computed by integrating backwards from the last time step:
300 |     at each step, integrate back one timepoint and concatenate the result
301 |     with samples of the empirical distribution at that previous timepoint,
302 |     repeating iteratively to calculate the likelihood of samples at later
303 |     timepoints, making sure that the ODE is evaluated at every time step
304 |     needed for those later points.
305 |
306 | The growth model is a single model of time independent cell growth /
307 | death rate defined as a variation from uniform.
308 | """
309 | use_growth = args.use_growth and growth_model is not None
310 |
311 | # Backward pass accumulating losses, previous state and deltas
312 | deltas = []
313 | zs = []
314 | z = None
315 | for i, (itp, tp) in enumerate(zip(args.int_tps[::-1], args.timepoints[::-1])):
316 | # tp counts down from last
317 | integration_times = torch.tensor([itp - args.time_scale, itp])
318 | integration_times = integration_times.type(torch.float32).to(device)
319 |
320 | x = args.data.get_data()[args.data.get_times() == tp]
321 | x = torch.from_numpy(x).type(torch.float32).to(device)
322 |
323 | if i > 0:
324 | x = torch.cat((z, x))
325 | zs.append(z)
326 | zero = torch.zeros(x.shape[0], 1).to(x)
327 |
328 | # transform to previous timepoint
329 | z, delta_logp = model(x, zero, integration_times=integration_times)
330 | deltas.append(delta_logp)
331 |
332 | logpz = args.data.base_density()(z)
333 |
334 | # build growth rates
335 | if use_growth:
336 | growthrates = [torch.ones_like(logpz)]
337 | for z_state, tp in zip(zs[::-1], args.timepoints[::-1][1:]):
338 | # Full state includes time parameter to growth_model
339 | time_state = tp * torch.ones(z_state.shape[0], 1).to(z_state)
340 | full_state = torch.cat([z_state, time_state], 1)
341 | growthrates.append(growth_model(full_state))
342 |
343 | # Accumulate losses
344 | losses = []
345 | logps = [logpz]
346 | for i, (delta_logp, tp) in enumerate(zip(deltas[::-1], args.timepoints)):
347 | n_cells_in_tp = np.sum(args.data.get_times() == tp)
348 | logpx = logps[-1] - delta_logp
349 | if use_growth:
350 | logpx += torch.log(growthrates[i])
351 | logps.append(logpx[:-n_cells_in_tp])
352 | losses.append(-torch.sum(logpx[-n_cells_in_tp:]))
353 | losses = torch.stack(losses).cpu().numpy()
354 | np.save(os.path.join(args.save, "nll.npy"), losses)
355 | return losses
356 |
--------------------------------------------------------------------------------
/TrajectoryNet/lib/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/TrajectoryNet/lib/growth_net.py:
--------------------------------------------------------------------------------
1 | """Implements a very simple growth network.
2 |
3 | """
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | class GrowthNet(nn.Module):
9 | def __init__(self):
10 | super().__init__()
11 | self.fc1 = nn.Linear(3, 64)
12 | self.fc2 = nn.Linear(64, 64)
13 | self.fc3 = nn.Linear(64, 1)
14 |
15 | def forward(self, x):
16 | x = F.leaky_relu(self.fc1(x))
17 | x = F.leaky_relu(self.fc2(x))
18 | x = self.fc3(x)
19 | return x
20 |
--------------------------------------------------------------------------------
/TrajectoryNet/lib/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from .elemwise import *
2 | from .container import *
3 | from .cnf import *
4 | from .odefunc import *
5 | from .squeeze import *
6 | from .normalization import *
7 | from . import diffeq_layers
8 | from .coupling import *
9 | from .glow import *
10 | from .norm_flows import *
11 |
--------------------------------------------------------------------------------
/TrajectoryNet/lib/layers/cnf.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from torchdiffeq import odeint_adjoint as odeint
5 |
6 | from .wrappers.cnf_regularization import RegularizedODEfunc
7 |
8 | __all__ = ["CNF"]
9 |
10 |
11 | class CNF(nn.Module):
12 | def __init__(self, odefunc, T=1.0, train_T=False, regularization_fns=None, solver='dopri5', atol=1e-5, rtol=1e-5):
13 | super(CNF, self).__init__()
14 | if train_T:
15 | self.register_parameter("sqrt_end_time", nn.Parameter(torch.sqrt(torch.tensor(T))))
16 | else:
17 | self.register_buffer("sqrt_end_time", torch.sqrt(torch.tensor(T)))
18 |
19 | nreg = 0
20 | if regularization_fns is not None:
21 | odefunc = RegularizedODEfunc(odefunc, regularization_fns)
22 | nreg = len(regularization_fns)
23 | self.odefunc = odefunc
24 | self.nreg = nreg
25 | self.regularization_states = None
26 | self.solver = solver
27 | self.atol = atol
28 | self.rtol = rtol
29 | self.test_solver = solver
30 | self.test_atol = atol
31 | self.test_rtol = rtol
32 | self.solver_options = {}
33 |
34 | def forward(self, z, logpz=None, integration_times=None, reverse=False):
35 |
36 | if logpz is None:
37 | _logpz = torch.zeros(z.shape[0], 1).to(z)
38 | else:
39 | _logpz = logpz
40 |
41 | if integration_times is None:
42 | integration_times = torch.tensor([0.0, self.sqrt_end_time * self.sqrt_end_time]).to(z)
43 | if reverse:
44 | integration_times = _flip(integration_times, 0)
45 |
46 | # Refresh the odefunc statistics.
47 | self.odefunc.before_odeint()
48 |
49 | # Add regularization states.
50 | reg_states = tuple(torch.tensor(0).to(z) for _ in range(self.nreg))
51 |
52 | if self.training:
53 | state_t = odeint(
54 | self.odefunc,
55 | (z, _logpz) + reg_states,
56 | integration_times.to(z),
57 | atol=[self.atol, self.atol] + [1e20] * len(reg_states) if self.solver == 'dopri5' else self.atol,
58 | rtol=[self.rtol, self.rtol] + [1e20] * len(reg_states) if self.solver == 'dopri5' else self.rtol,
59 | method=self.solver,
60 | options=self.solver_options,
61 | )
62 | else:
63 | state_t = odeint(
64 | self.odefunc,
65 | (z, _logpz),
66 | integration_times.to(z),
67 | atol=self.test_atol,
68 | rtol=self.test_rtol,
69 | method=self.test_solver,
70 | )
71 |
72 | if len(integration_times) == 2:
73 | state_t = tuple(s[1] for s in state_t)
74 |
75 | z_t, logpz_t = state_t[:2]
76 | self.regularization_states = state_t[2:]
77 |
78 | if logpz is not None:
79 | return z_t, logpz_t
80 | else:
81 | return z_t
82 |
83 | def get_regularization_states(self):
84 | reg_states = self.regularization_states
85 | self.regularization_states = None
86 | return reg_states
87 |
88 | def num_evals(self):
89 | return self.odefunc._num_evals.item()
90 |
91 |
92 | def _flip(x, dim):
93 | indices = [slice(None)] * x.dim()
94 | indices[dim] = torch.arange(x.size(dim) - 1, -1, -1, dtype=torch.long, device=x.device)
95 | return x[tuple(indices)]
96 |
--------------------------------------------------------------------------------
/TrajectoryNet/lib/layers/container.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class SequentialFlow(nn.Module):
5 | """A generalized nn.Sequential container for normalizing flows.
6 | """
7 |
8 | def __init__(self, layersList):
9 | super(SequentialFlow, self).__init__()
10 | self.chain = nn.ModuleList(layersList)
11 |
12 | def forward(self, x, logpx=None, integration_times=None, reverse=False, inds=None):
13 | if inds is None:
14 | if reverse:
15 | inds = range(len(self.chain) - 1, -1, -1)
16 | else:
17 | inds = range(len(self.chain))
18 |
19 | if logpx is None:
20 | for i in inds:
21 | x = self.chain[i](x, reverse=reverse)
22 | return x
23 | else:
24 | for i in inds:
25 | x, logpx = self.chain[i](x, logpx, integration_times=integration_times, reverse=reverse)
26 | return x, logpx
27 |
--------------------------------------------------------------------------------
/TrajectoryNet/lib/layers/coupling.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | __all__ = ['CouplingLayer', 'MaskedCouplingLayer']
5 |
6 |
7 | class CouplingLayer(nn.Module):
8 | """Used in 2D experiments."""
9 |
10 | def __init__(self, d, intermediate_dim=64, swap=False):
11 | nn.Module.__init__(self)
12 | self.d = d - (d // 2)
13 | self.swap = swap
14 | self.net_s_t = nn.Sequential(
15 | nn.Linear(self.d, intermediate_dim),
16 | nn.ReLU(inplace=True),
17 | nn.Linear(intermediate_dim, intermediate_dim),
18 | nn.ReLU(inplace=True),
19 | nn.Linear(intermediate_dim, (d - self.d) * 2),
20 | )
21 |
22 | def forward(self, x, logpx=None, reverse=False):
23 |
24 | if self.swap:
25 | x = torch.cat([x[:, self.d:], x[:, :self.d]], 1)
26 |
27 | in_dim = self.d
28 | out_dim = x.shape[1] - self.d
29 |
30 | s_t = self.net_s_t(x[:, :in_dim])
31 | scale = torch.sigmoid(s_t[:, :out_dim] + 2.)
32 | shift = s_t[:, out_dim:]
33 |
34 | logdetjac = torch.sum(torch.log(scale).view(scale.shape[0], -1), 1, keepdim=True)
35 |
36 | if not reverse:
37 | y1 = x[:, self.d:] * scale + shift
38 | delta_logp = -logdetjac
39 | else:
40 | y1 = (x[:, self.d:] - shift) / scale
41 | delta_logp = logdetjac
42 |
43 | y = torch.cat([x[:, :self.d], y1], 1) if not self.swap else torch.cat([y1, x[:, :self.d]], 1)
44 |
45 | if logpx is None:
46 | return y
47 | else:
48 | return y, logpx + delta_logp
49 |
50 |
51 | class MaskedCouplingLayer(nn.Module):
52 | """Used in the tabular experiments."""
53 |
54 | def __init__(self, d, hidden_dims, mask_type='alternate', swap=False):
55 | nn.Module.__init__(self)
56 | self.d = d
57 | self.register_buffer('mask', sample_mask(d, mask_type, swap).view(1, d))
58 | self.net_scale = build_net(d, hidden_dims, activation="tanh")
59 | self.net_shift = build_net(d, hidden_dims, activation="relu")
60 |
61 | def forward(self, x, logpx=None, reverse=False):
62 |
63 | scale = torch.exp(self.net_scale(x * self.mask))
64 | shift = self.net_shift(x * self.mask)
65 |
66 | masked_scale = scale * (1 - self.mask) + torch.ones_like(scale) * self.mask
67 | masked_shift = shift * (1 - self.mask)
68 |
69 | logdetjac = torch.sum(torch.log(masked_scale).view(scale.shape[0], -1), 1, keepdim=True)
70 |
71 | if not reverse:
72 | y = x * masked_scale + masked_shift
73 | delta_logp = -logdetjac
74 | else:
75 | y = (x - masked_shift) / masked_scale
76 | delta_logp = logdetjac
77 |
78 | if logpx is None:
79 | return y
80 | else:
81 | return y, logpx + delta_logp
82 |
83 |
84 | def sample_mask(dim, mask_type, swap):
85 | if mask_type == 'alternate':
86 | # Index-based masking in MAF paper.
87 | mask = torch.zeros(dim)
88 | mask[::2] = 1
89 | if swap:
90 | mask = 1 - mask
91 | return mask
92 | elif mask_type == 'channel':
93 | # Masking type used in Real NVP paper.
94 | mask = torch.zeros(dim)
95 | mask[:dim // 2] = 1
96 | if swap:
97 | mask = 1 - mask
98 | return mask
99 | else:
100 | raise ValueError('Unknown mask_type {}'.format(mask_type))
101 |
102 |
103 | def build_net(input_dim, hidden_dims, activation="relu"):
104 | dims = (input_dim,) + tuple(hidden_dims) + (input_dim,)
105 | activation_modules = {"relu": nn.ReLU(inplace=True), "tanh": nn.Tanh()}
106 |
107 | chain = []
108 | for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
109 | chain.append(nn.Linear(in_dim, out_dim))
110 | if i < len(hidden_dims):
111 | chain.append(activation_modules[activation])
112 | return nn.Sequential(*chain)
113 |
--------------------------------------------------------------------------------
/TrajectoryNet/lib/layers/diffeq_layers/__init__.py:
--------------------------------------------------------------------------------
1 | from .container import *
2 | from .resnet import *
3 | from .basic import *
4 | from .wrappers import *
5 |
--------------------------------------------------------------------------------
/TrajectoryNet/lib/layers/diffeq_layers/basic.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | def weights_init(m):
7 | classname = m.__class__.__name__
8 | if classname.find('Linear') != -1 or classname.find('Conv') != -1:
9 | nn.init.constant_(m.weight, 0)
10 | nn.init.normal_(m.bias, 0, 0.01)
11 |
12 |
13 | class HyperLinear(nn.Module):
14 | def __init__(self, dim_in, dim_out, hypernet_dim=8, n_hidden=1, activation=nn.Tanh):
15 | super(HyperLinear, self).__init__()
16 | self.dim_in = dim_in
17 | self.dim_out = dim_out
18 | self.params_dim = self.dim_in * self.dim_out + self.dim_out
19 |
20 | layers = []
21 | dims = [1] + [hypernet_dim] * n_hidden + [self.params_dim]
22 | for i in range(1, len(dims)):
23 | layers.append(nn.Linear(dims[i - 1], dims[i]))
24 | if i < len(dims) - 1:
25 | layers.append(activation())
26 | self._hypernet = nn.Sequential(*layers)
27 | self._hypernet.apply(weights_init)
28 |
29 | def forward(self, t, x):
30 | params = self._hypernet(t.view(1, 1)).view(-1)
31 | b = params[:self.dim_out].view(self.dim_out)
32 | w = params[self.dim_out:].view(self.dim_out, self.dim_in)
33 | return F.linear(x, w, b)
34 |
35 |
36 | class IgnoreLinear(nn.Module):
37 | def __init__(self, dim_in, dim_out):
38 | super(IgnoreLinear, self).__init__()
39 | self._layer = nn.Linear(dim_in, dim_out)
40 |
41 | def forward(self, t, x):
42 | return self._layer(x)
43 |
44 |
45 | class ConcatLinear(nn.Module):
46 | def __init__(self, dim_in, dim_out):
47 | super(ConcatLinear, self).__init__()
48 | self._layer = nn.Linear(dim_in + 1, dim_out)
49 |
50 | def forward(self, t, x):
51 | tt = torch.ones_like(x[:, :1]) * t
52 | ttx = torch.cat([tt, x], 1)
53 | return self._layer(ttx)
54 |
55 |
56 | class ConcatLinear_v2(nn.Module):
57 | def __init__(self, dim_in, dim_out):
58 |         super(ConcatLinear_v2, self).__init__()
59 | self._layer = nn.Linear(dim_in, dim_out)
60 | self._hyper_bias = nn.Linear(1, dim_out, bias=False)
61 |
62 | def forward(self, t, x):
63 | return self._layer(x) + self._hyper_bias(t.view(1, 1))
64 |
65 |
66 | class SquashLinear(nn.Module):
67 | def __init__(self, dim_in, dim_out):
68 | super(SquashLinear, self).__init__()
69 | self._layer = nn.Linear(dim_in, dim_out)
70 | self._hyper = nn.Linear(1, dim_out)
71 |
72 | def forward(self, t, x):
73 | return self._layer(x) * torch.sigmoid(self._hyper(t.view(1, 1)))
74 |
75 |
76 | class ConcatSquashLinear(nn.Module):
77 | def __init__(self, dim_in, dim_out):
78 | super(ConcatSquashLinear, self).__init__()
79 | self._layer = nn.Linear(dim_in, dim_out)
80 | self._hyper_bias = nn.Linear(1, dim_out, bias=False)
81 | self._hyper_gate = nn.Linear(1, dim_out)
82 |
83 | def forward(self, t, x):
84 | return self._layer(x) * torch.sigmoid(self._hyper_gate(t.view(1, 1))) \
85 | + self._hyper_bias(t.view(1, 1))
86 |
87 |
88 | class HyperConv2d(nn.Module):
89 | def __init__(self, dim_in, dim_out, ksize=3, stride=1, padding=0, dilation=1, groups=1, bias=True, transpose=False):
90 | super(HyperConv2d, self).__init__()
91 | assert dim_in % groups == 0 and dim_out % groups == 0, "dim_in and dim_out must both be divisible by groups."
92 | self.dim_in = dim_in
93 | self.dim_out = dim_out
94 | self.ksize = ksize
95 | self.stride = stride
96 | self.padding = padding
97 | self.dilation = dilation
98 | self.groups = groups
99 | self.bias = bias
100 | self.transpose = transpose
101 |
102 | self.params_dim = int(dim_in * dim_out * ksize * ksize / groups)
103 | if self.bias:
104 | self.params_dim += dim_out
105 | self._hypernet = nn.Linear(1, self.params_dim)
106 | self.conv_fn = F.conv_transpose2d if transpose else F.conv2d
107 |
108 | self._hypernet.apply(weights_init)
109 |
110 | def forward(self, t, x):
111 | params = self._hypernet(t.view(1, 1)).view(-1)
112 | weight_size = int(self.dim_in * self.dim_out * self.ksize * self.ksize / self.groups)
113 | if self.transpose:
114 | weight = params[:weight_size].view(self.dim_in, self.dim_out // self.groups, self.ksize, self.ksize)
115 | else:
116 | weight = params[:weight_size].view(self.dim_out, self.dim_in // self.groups, self.ksize, self.ksize)
117 |         bias = params[weight_size:].view(self.dim_out) if self.bias else None
118 | return self.conv_fn(
119 | x, weight=weight, bias=bias, stride=self.stride, padding=self.padding, groups=self.groups,
120 | dilation=self.dilation
121 | )
122 |
123 |
124 | class IgnoreConv2d(nn.Module):
125 | def __init__(self, dim_in, dim_out, ksize=3, stride=1, padding=0, dilation=1, groups=1, bias=True, transpose=False):
126 | super(IgnoreConv2d, self).__init__()
127 | module = nn.ConvTranspose2d if transpose else nn.Conv2d
128 | self._layer = module(
129 | dim_in, dim_out, kernel_size=ksize, stride=stride, padding=padding, dilation=dilation, groups=groups,
130 | bias=bias
131 | )
132 |
133 | def forward(self, t, x):
134 | return self._layer(x)
135 |
136 |
137 | class SquashConv2d(nn.Module):
138 | def __init__(self, dim_in, dim_out, ksize=3, stride=1, padding=0, dilation=1, groups=1, bias=True, transpose=False):
139 | super(SquashConv2d, self).__init__()
140 | module = nn.ConvTranspose2d if transpose else nn.Conv2d
141 | self._layer = module(
142 | dim_in + 1, dim_out, kernel_size=ksize, stride=stride, padding=padding, dilation=dilation, groups=groups,
143 | bias=bias
144 | )
145 | self._hyper = nn.Linear(1, dim_out)
146 |
147 | def forward(self, t, x):
148 | return self._layer(x) * torch.sigmoid(self._hyper(t.view(1, 1))).view(1, -1, 1, 1)
149 |
150 |
151 | class ConcatConv2d(nn.Module):
152 | def __init__(self, dim_in, dim_out, ksize=3, stride=1, padding=0, dilation=1, groups=1, bias=True, transpose=False):
153 | super(ConcatConv2d, self).__init__()
154 | module = nn.ConvTranspose2d if transpose else nn.Conv2d
155 | self._layer = module(
156 | dim_in + 1, dim_out, kernel_size=ksize, stride=stride, padding=padding, dilation=dilation, groups=groups,
157 | bias=bias
158 | )
159 |
160 | def forward(self, t, x):
161 | tt = torch.ones_like(x[:, :1, :, :]) * t
162 | ttx = torch.cat([tt, x], 1)
163 | return self._layer(ttx)
164 |
165 |
166 | class ConcatConv2d_v2(nn.Module):
167 | def __init__(self, dim_in, dim_out, ksize=3, stride=1, padding=0, dilation=1, groups=1, bias=True, transpose=False):
168 |         super(ConcatConv2d_v2, self).__init__()
169 | module = nn.ConvTranspose2d if transpose else nn.Conv2d
170 | self._layer = module(
171 | dim_in, dim_out, kernel_size=ksize, stride=stride, padding=padding, dilation=dilation, groups=groups,
172 | bias=bias
173 | )
174 | self._hyper_bias = nn.Linear(1, dim_out, bias=False)
175 |
176 | def forward(self, t, x):
177 | return self._layer(x) + self._hyper_bias(t.view(1, 1)).view(1, -1, 1, 1)
178 |
179 |
180 | class ConcatSquashConv2d(nn.Module):
181 | def __init__(self, dim_in, dim_out, ksize=3, stride=1, padding=0, dilation=1, groups=1, bias=True, transpose=False):
182 | super(ConcatSquashConv2d, self).__init__()
183 | module = nn.ConvTranspose2d if transpose else nn.Conv2d
184 | self._layer = module(
185 | dim_in, dim_out, kernel_size=ksize, stride=stride, padding=padding, dilation=dilation, groups=groups,
186 | bias=bias
187 | )
188 | self._hyper_gate = nn.Linear(1, dim_out)
189 | self._hyper_bias = nn.Linear(1, dim_out, bias=False)
190 |
191 | def forward(self, t, x):
192 | return self._layer(x) * torch.sigmoid(self._hyper_gate(t.view(1, 1))).view(1, -1, 1, 1) \
193 | + self._hyper_bias(t.view(1, 1)).view(1, -1, 1, 1)
194 |
195 |
196 | class ConcatCoordConv2d(nn.Module):
197 | def __init__(self, dim_in, dim_out, ksize=3, stride=1, padding=0, dilation=1, groups=1, bias=True, transpose=False):
198 | super(ConcatCoordConv2d, self).__init__()
199 | module = nn.ConvTranspose2d if transpose else nn.Conv2d
200 | self._layer = module(
201 | dim_in + 3, dim_out, kernel_size=ksize, stride=stride, padding=padding, dilation=dilation, groups=groups,
202 | bias=bias
203 | )
204 |
205 | def forward(self, t, x):
206 | b, c, h, w = x.shape
207 | hh = torch.arange(h).to(x).view(1, 1, h, 1).expand(b, 1, h, w)
208 | ww = torch.arange(w).to(x).view(1, 1, 1, w).expand(b, 1, h, w)
209 | tt = t.to(x).view(1, 1, 1, 1).expand(b, 1, h, w)
210 | x_aug = torch.cat([x, tt, hh, ww], 1)
211 | return self._layer(x_aug)
212 |
213 |
214 | class GatedLinear(nn.Module):
215 | def __init__(self, in_features, out_features):
216 | super(GatedLinear, self).__init__()
217 | self.layer_f = nn.Linear(in_features, out_features)
218 | self.layer_g = nn.Linear(in_features, out_features)
219 |
220 | def forward(self, x):
221 | f = self.layer_f(x)
222 | g = torch.sigmoid(self.layer_g(x))
223 | return f * g
224 |
225 |
226 | class GatedConv(nn.Module):
227 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, groups=1):
228 | super(GatedConv, self).__init__()
229 | self.layer_f = nn.Conv2d(
230 | in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=1, groups=groups
231 | )
232 | self.layer_g = nn.Conv2d(
233 | in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=1, groups=groups
234 | )
235 |
236 | def forward(self, x):
237 | f = self.layer_f(x)
238 | g = torch.sigmoid(self.layer_g(x))
239 | return f * g
240 |
241 |
242 | class GatedConvTranspose(nn.Module):
243 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1):
244 | super(GatedConvTranspose, self).__init__()
245 | self.layer_f = nn.ConvTranspose2d(
246 | in_channels, out_channels, kernel_size, stride=stride, padding=padding, output_padding=output_padding,
247 | groups=groups
248 | )
249 | self.layer_g = nn.ConvTranspose2d(
250 | in_channels, out_channels, kernel_size, stride=stride, padding=padding, output_padding=output_padding,
251 | groups=groups
252 | )
253 |
254 | def forward(self, x):
255 | f = self.layer_f(x)
256 | g = torch.sigmoid(self.layer_g(x))
257 | return f * g
258 |
259 |
260 | class BlendLinear(nn.Module):
261 | def __init__(self, dim_in, dim_out, layer_type=nn.Linear, **unused_kwargs):
262 | super(BlendLinear, self).__init__()
263 | self._layer0 = layer_type(dim_in, dim_out)
264 | self._layer1 = layer_type(dim_in, dim_out)
265 |
266 | def forward(self, t, x):
267 | y0 = self._layer0(x)
268 | y1 = self._layer1(x)
269 | return y0 + (y1 - y0) * t
270 |
271 |
272 | class BlendConv2d(nn.Module):
273 | def __init__(
274 | self, dim_in, dim_out, ksize=3, stride=1, padding=0, dilation=1, groups=1, bias=True, transpose=False,
275 | **unused_kwargs
276 | ):
277 | super(BlendConv2d, self).__init__()
278 | module = nn.ConvTranspose2d if transpose else nn.Conv2d
279 | self._layer0 = module(
280 | dim_in, dim_out, kernel_size=ksize, stride=stride, padding=padding, dilation=dilation, groups=groups,
281 | bias=bias
282 | )
283 | self._layer1 = module(
284 | dim_in, dim_out, kernel_size=ksize, stride=stride, padding=padding, dilation=dilation, groups=groups,
285 | bias=bias
286 | )
287 |
288 | def forward(self, t, x):
289 | y0 = self._layer0(x)
290 | y1 = self._layer1(x)
291 | return y0 + (y1 - y0) * t
292 |
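A minimal usage sketch for the time-conditioned layers above (hypothetical shapes and import path, assuming the package is importable from the repository root): each layer is called as layer(t, x) and injects the scalar time into an otherwise ordinary linear or convolutional map.

import torch

from TrajectoryNet.lib.layers.diffeq_layers import basic

layer = basic.ConcatConv2d(dim_in=3, dim_out=8, ksize=3, padding=1)
x = torch.randn(4, 3, 16, 16)   # (batch, channels, H, W)
t = torch.tensor(0.5)           # scalar time, broadcast as an extra channel
assert layer(t, x).shape == (4, 8, 16, 16)

blend = basic.BlendLinear(2, 2)  # interpolates two linear maps: y0 + (y1 - y0) * t
assert blend(t, torch.randn(10, 2)).shape == (10, 2)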
--------------------------------------------------------------------------------
/TrajectoryNet/lib/layers/diffeq_layers/container.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .wrappers import diffeq_wrapper
5 |
6 |
7 | class SequentialDiffEq(nn.Module):
8 | """A container for a sequential chain of layers. Supports both regular and diffeq layers.
9 | """
10 |
11 | def __init__(self, *layers):
12 | super(SequentialDiffEq, self).__init__()
13 | self.layers = nn.ModuleList([diffeq_wrapper(layer) for layer in layers])
14 |
15 | def forward(self, t, x):
16 | for layer in self.layers:
17 | x = layer(t, x)
18 | return x
19 |
20 |
21 | class MixtureODELayer(nn.Module):
22 | """Produces a mixture of experts where output = sigma(t) * f(t, x).
23 | Time-dependent weights sigma(t) help learn to blend the experts without resorting to a highly stiff f.
24 | Supports both regular and diffeq experts.
25 | """
26 |
27 | def __init__(self, experts):
28 | super(MixtureODELayer, self).__init__()
29 | assert len(experts) > 1
30 | wrapped_experts = [diffeq_wrapper(ex) for ex in experts]
31 | self.experts = nn.ModuleList(wrapped_experts)
32 | self.mixture_weights = nn.Linear(1, len(self.experts))
33 |
34 | def forward(self, t, y):
35 | dys = []
36 | for f in self.experts:
37 | dys.append(f(t, y))
38 | dys = torch.stack(dys, 0)
39 | weights = self.mixture_weights(t).view(-1, *([1] * (dys.ndimension() - 1)))
40 |
41 | dy = torch.sum(dys * weights, dim=0, keepdim=False)
42 | return dy
43 |
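A composition sketch (hypothetical dims; ConcatLinear is defined in basic.py above): diffeq_wrapper lets SequentialDiffEq call every member as layer(t, x), whether or not the layer itself accepts t.

import torch
import torch.nn as nn

from TrajectoryNet.lib.layers.diffeq_layers import basic, container

net = container.SequentialDiffEq(
    basic.ConcatLinear(2, 64),  # takes (t, x)
    nn.Tanh(),                  # takes only (x,); the wrapper ignores t
    basic.ConcatLinear(64, 2),
)
dy = net(torch.tensor(0.3), torch.randn(10, 2))
assert dy.shape == (10, 2)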
--------------------------------------------------------------------------------
/TrajectoryNet/lib/layers/diffeq_layers/resnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from . import basic
4 | from . import container
5 |
6 | NGROUPS = 16
7 |
8 |
9 | class ResNet(container.SequentialDiffEq):
10 | def __init__(self, dim, intermediate_dim, n_resblocks, conv_block=None):
11 | super(ResNet, self).__init__()
12 |
13 | if conv_block is None:
14 | conv_block = basic.ConcatCoordConv2d
15 |
16 | self.dim = dim
17 | self.intermediate_dim = intermediate_dim
18 | self.n_resblocks = n_resblocks
19 |
20 | layers = []
21 | layers.append(conv_block(dim, intermediate_dim, ksize=3, stride=1, padding=1, bias=False))
22 | for _ in range(n_resblocks):
23 | layers.append(BasicBlock(intermediate_dim, conv_block))
24 | layers.append(nn.GroupNorm(NGROUPS, intermediate_dim, eps=1e-4))
25 | layers.append(nn.ReLU(inplace=True))
26 | layers.append(conv_block(intermediate_dim, dim, ksize=1, bias=False))
27 |
28 | super(ResNet, self).__init__(*layers)
29 |
30 | def __repr__(self):
31 | return (
32 | '{name}({dim}, intermediate_dim={intermediate_dim}, n_resblocks={n_resblocks})'.format(
33 | name=self.__class__.__name__, **self.__dict__
34 | )
35 | )
36 |
37 |
38 | class BasicBlock(nn.Module):
39 | expansion = 1
40 |
41 | def __init__(self, dim, conv_block=None):
42 | super(BasicBlock, self).__init__()
43 |
44 | if conv_block is None:
45 | conv_block = basic.ConcatCoordConv2d
46 |
47 | self.norm1 = nn.GroupNorm(NGROUPS, dim, eps=1e-4)
48 | self.relu1 = nn.ReLU(inplace=True)
49 | self.conv1 = conv_block(dim, dim, ksize=3, stride=1, padding=1, bias=False)
50 | self.norm2 = nn.GroupNorm(NGROUPS, dim, eps=1e-4)
51 | self.relu2 = nn.ReLU(inplace=True)
52 | self.conv2 = conv_block(dim, dim, ksize=3, stride=1, padding=1, bias=False)
53 |
54 | def forward(self, t, x):
55 | residual = x
56 |
57 | out = self.norm1(x)
58 | out = self.relu1(out)
59 | out = self.conv1(t, out)
60 |
61 | out = self.norm2(out)
62 | out = self.relu2(out)
63 | out = self.conv2(t, out)
64 |
65 | out += residual
66 |
67 | return out
68 |
--------------------------------------------------------------------------------
/TrajectoryNet/lib/layers/diffeq_layers/wrappers.py:
--------------------------------------------------------------------------------
1 | from inspect import signature
2 | import torch.nn as nn
3 |
4 | __all__ = ["diffeq_wrapper", "reshape_wrapper"]
5 |
6 |
7 | class DiffEqWrapper(nn.Module):
8 | def __init__(self, module):
9 | super(DiffEqWrapper, self).__init__()
10 | self.module = module
11 | if len(signature(self.module.forward).parameters) == 1:
12 | self.diffeq = lambda t, y: self.module(y)
13 | elif len(signature(self.module.forward).parameters) == 2:
14 | self.diffeq = self.module
15 | else:
16 | raise ValueError("Differential equation needs to either take (t, y) or (y,) as input.")
17 |
18 | def forward(self, t, y):
19 | return self.diffeq(t, y)
20 |
21 | def __repr__(self):
22 | return self.diffeq.__repr__()
23 |
24 |
25 | def diffeq_wrapper(layer):
26 | return DiffEqWrapper(layer)
27 |
28 |
29 | class ReshapeDiffEq(nn.Module):
30 | def __init__(self, input_shape, net):
31 | super(ReshapeDiffEq, self).__init__()
32 | assert len(signature(net.forward).parameters) == 2, "use diffeq_wrapper before reshape_wrapper."
33 | self.input_shape = input_shape
34 | self.net = net
35 |
36 | def forward(self, t, x):
37 | batchsize = x.shape[0]
38 | x = x.view(batchsize, *self.input_shape)
39 | return self.net(t, x).view(batchsize, -1)
40 |
41 | def __repr__(self):
42 |         return self.net.__repr__()
43 |
44 |
45 | def reshape_wrapper(input_shape, layer):
46 | return ReshapeDiffEq(input_shape, layer)
47 |
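A sketch of reshape_wrapper (hypothetical shapes): it lets an image-space network act on the flattened state vectors that an ODE solver passes around.

import torch
import torch.nn as nn

from TrajectoryNet.lib.layers.diffeq_layers.wrappers import diffeq_wrapper, reshape_wrapper

f = diffeq_wrapper(nn.Conv2d(1, 1, 3, padding=1))  # forward(x) becomes f(t, x)
g = reshape_wrapper((1, 8, 8), f)                  # flat (N, 64) <-> image (N, 1, 8, 8)
y = torch.randn(5, 64)
assert g(torch.tensor(0.0), y).shape == (5, 64)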
--------------------------------------------------------------------------------
/TrajectoryNet/lib/layers/elemwise.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 |
5 | _DEFAULT_ALPHA = 1e-6
6 |
7 |
8 | class ZeroMeanTransform(nn.Module):
9 | def __init__(self):
10 | nn.Module.__init__(self)
11 |
12 | def forward(self, x, logpx=None, reverse=False):
13 | if reverse:
14 | x = x + .5
15 | if logpx is None:
16 | return x
17 | return x, logpx
18 | else:
19 | x = x - .5
20 | if logpx is None:
21 | return x
22 | return x, logpx
23 |
24 |
25 | class LogitTransform(nn.Module):
26 |     """
27 |     The preprocessing step used in Real NVP:
28 |     forward:  y = logit(a + (1 - 2a) * x)
29 |     reverse:  x = (sigmoid(y) - a) / (1 - 2a)
30 |     """
31 |
32 | def __init__(self, alpha=_DEFAULT_ALPHA):
33 | nn.Module.__init__(self)
34 | self.alpha = alpha
35 |
36 | def forward(self, x, logpx=None, reverse=False):
37 | if reverse:
38 | return _sigmoid(x, logpx, self.alpha)
39 | else:
40 | return _logit(x, logpx, self.alpha)
41 |
42 |
43 | class SigmoidTransform(nn.Module):
44 | """Reverse of LogitTransform."""
45 |
46 | def __init__(self, alpha=_DEFAULT_ALPHA):
47 | nn.Module.__init__(self)
48 | self.alpha = alpha
49 |
50 | def forward(self, x, logpx=None, reverse=False):
51 | if reverse:
52 | return _logit(x, logpx, self.alpha)
53 | else:
54 | return _sigmoid(x, logpx, self.alpha)
55 |
56 |
57 | def _logit(x, logpx=None, alpha=_DEFAULT_ALPHA):
58 | s = alpha + (1 - 2 * alpha) * x
59 | y = torch.log(s) - torch.log(1 - s)
60 | if logpx is None:
61 | return y
62 | return y, logpx - _logdetgrad(x, alpha).view(x.size(0), -1).sum(1, keepdim=True)
63 |
64 |
65 | def _sigmoid(y, logpy=None, alpha=_DEFAULT_ALPHA):
66 | x = (torch.sigmoid(y) - alpha) / (1 - 2 * alpha)
67 | if logpy is None:
68 | return x
69 | return x, logpy + _logdetgrad(x, alpha).view(x.size(0), -1).sum(1, keepdim=True)
70 |
71 |
72 | def _logdetgrad(x, alpha):
73 | s = alpha + (1 - 2 * alpha) * x
74 | logdetgrad = -torch.log(s - s * s) + math.log(1 - 2 * alpha)
75 | return logdetgrad
76 |
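A round-trip sketch on toy data: the forward logit and reverse sigmoid invert each other, and their log-determinant corrections cancel up to float32 error.

import torch

from TrajectoryNet.lib.layers.elemwise import LogitTransform

logit = LogitTransform(alpha=1e-6)
x = torch.rand(8, 3)                        # data in [0, 1)
zero = torch.zeros(8, 1)
y, dlogp = logit(x, zero)                   # forward: x -> logit space
x2, dlogp2 = logit(y, dlogp, reverse=True)  # reverse: back through the sigmoid
assert torch.allclose(x, x2, atol=1e-5)
assert torch.allclose(dlogp2, zero, atol=1e-2)  # loose tolerance near 0 and 1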
--------------------------------------------------------------------------------
/TrajectoryNet/lib/layers/glow.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class BruteForceLayer(nn.Module):
7 |
8 | def __init__(self, dim):
9 | super(BruteForceLayer, self).__init__()
10 | self.weight = nn.Parameter(torch.eye(dim))
11 |
12 | def forward(self, x, logpx=None, reverse=False):
13 |
14 | if not reverse:
15 | y = F.linear(x, self.weight)
16 | if logpx is None:
17 | return y
18 | else:
19 | return y, logpx - self._logdetgrad.expand_as(logpx)
20 |
21 | else:
22 | y = F.linear(x, self.weight.double().inverse().float())
23 | if logpx is None:
24 | return y
25 | else:
26 | return y, logpx + self._logdetgrad.expand_as(logpx)
27 |
28 | @property
29 | def _logdetgrad(self):
30 | return torch.log(torch.abs(torch.det(self.weight.double()))).float()
31 |
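An invertibility sketch with toy values: running the layer forward and then in reverse recovers the input, and the log-det terms cancel exactly since both directions read the same weight.

import torch

from TrajectoryNet.lib.layers.glow import BruteForceLayer

layer = BruteForceLayer(dim=4)
with torch.no_grad():
    layer.weight += 0.1 * torch.randn(4, 4)  # perturb off the identity
x = torch.randn(6, 4)
logp = torch.zeros(6, 1)
y, logp_y = layer(x, logp)
x2, logp_x = layer(y, logp_y, reverse=True)
assert torch.allclose(x, x2, atol=1e-5)
assert torch.allclose(logp_x, logp, atol=1e-5)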
--------------------------------------------------------------------------------
/TrajectoryNet/lib/layers/norm_flows.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | from torch.autograd import grad
5 |
6 |
7 | class PlanarFlow(nn.Module):
8 |
9 | def __init__(self, nd=1):
10 | super(PlanarFlow, self).__init__()
11 | self.nd = nd
12 | self.activation = torch.tanh
13 |
14 | self.register_parameter('u', nn.Parameter(torch.randn(self.nd)))
15 | self.register_parameter('w', nn.Parameter(torch.randn(self.nd)))
16 | self.register_parameter('b', nn.Parameter(torch.randn(1)))
17 | self.reset_parameters()
18 |
19 | def reset_parameters(self):
20 | stdv = 1. / math.sqrt(self.nd)
21 | self.u.data.uniform_(-stdv, stdv)
22 | self.w.data.uniform_(-stdv, stdv)
23 | self.b.data.fill_(0)
24 | self.make_invertible()
25 |
26 | def make_invertible(self):
27 | u = self.u.data
28 | w = self.w.data
29 | dot = torch.dot(u, w)
30 | m = -1 + math.log(1 + math.exp(dot))
31 | du = (m - dot) / torch.norm(w) * w
32 | u = u + du
33 | self.u.data = u
34 |
35 | def forward(self, z, logp=None, reverse=False):
36 | """Computes f(z) and log q(f(z))"""
37 |
38 | assert not reverse, 'Planar normalizing flow cannot be reversed.'
39 |
40 |         # The transformation is applied once, via self.sample below; the
41 |         # Jacobian in log_density must be evaluated at the *input* z, so z
42 |         # is not modified before computing the density.
43 |
44 | f = self.sample(z)
45 | if logp is not None:
46 | qf = self.log_density(z, logp)
47 | return f, qf
48 | else:
49 | return f
50 |
51 | def sample(self, z):
52 | """Computes f(z)"""
53 | h = self.activation(torch.mm(z, self.w.view(self.nd, 1)) + self.b)
54 | output = z + self.u.expand_as(z) * h
55 | return output
56 |
57 | def _detgrad(self, z):
58 | """Computes |det df/dz|"""
59 | with torch.enable_grad():
60 | z = z.requires_grad_(True)
61 | h = self.activation(torch.mm(z, self.w.view(self.nd, 1)) + self.b)
62 | psi = grad(h, z, grad_outputs=torch.ones_like(h), create_graph=True, only_inputs=True)[0]
63 | u_dot_psi = torch.mm(psi, self.u.view(self.nd, 1))
64 | detgrad = 1 + u_dot_psi
65 | return detgrad
66 |
67 | def log_density(self, z, logqz):
68 | """Computes log density of the flow given the log density of z"""
69 | return logqz - torch.log(self._detgrad(z) + 1e-8)
70 |
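A density-bookkeeping sketch: samples are pushed through the planar map while their log-density is corrected by the change-of-variables term from _detgrad.

import torch

from TrajectoryNet.lib.layers.norm_flows import PlanarFlow

flow = PlanarFlow(nd=2)
z = torch.randn(100, 2)
logqz = torch.zeros(100, 1)   # stand-in for the base log-density
f, logqf = flow(z, logqz)
assert f.shape == (100, 2) and logqf.shape == (100, 1)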
--------------------------------------------------------------------------------
/TrajectoryNet/lib/layers/normalization.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import Parameter
4 |
5 | __all__ = ['MovingBatchNorm1d', 'MovingBatchNorm2d']
6 |
7 |
8 | class MovingBatchNormNd(nn.Module):
9 | def __init__(self, num_features, eps=1e-4, decay=0.1, bn_lag=0., affine=True):
10 | super(MovingBatchNormNd, self).__init__()
11 | self.num_features = num_features
12 | self.affine = affine
13 | self.eps = eps
14 | self.decay = decay
15 | self.bn_lag = bn_lag
16 | self.register_buffer('step', torch.zeros(1))
17 | if self.affine:
18 | self.weight = Parameter(torch.Tensor(num_features))
19 | self.bias = Parameter(torch.Tensor(num_features))
20 | else:
21 | self.register_parameter('weight', None)
22 | self.register_parameter('bias', None)
23 | self.register_buffer('running_mean', torch.zeros(num_features))
24 | self.register_buffer('running_var', torch.ones(num_features))
25 | self.reset_parameters()
26 |
27 | @property
28 | def shape(self):
29 | raise NotImplementedError
30 |
31 | def reset_parameters(self):
32 | self.running_mean.zero_()
33 | self.running_var.fill_(1)
34 | if self.affine:
35 | self.weight.data.zero_()
36 | self.bias.data.zero_()
37 |
38 | def forward(self, x, logpx=None, reverse=False):
39 | if reverse:
40 | return self._reverse(x, logpx)
41 | else:
42 | return self._forward(x, logpx)
43 |
44 | def _forward(self, x, logpx=None):
45 | c = x.size(1)
46 | used_mean = self.running_mean.clone().detach()
47 | used_var = self.running_var.clone().detach()
48 |
49 | if self.training:
50 | # compute batch statistics
51 | x_t = x.transpose(0, 1).contiguous().view(c, -1)
52 | batch_mean = torch.mean(x_t, dim=1)
53 | batch_var = torch.var(x_t, dim=1)
54 |
55 | # moving average
56 | if self.bn_lag > 0:
57 | used_mean = batch_mean - (1 - self.bn_lag) * (batch_mean - used_mean.detach())
58 | used_mean /= (1. - self.bn_lag**(self.step[0] + 1))
59 | used_var = batch_var - (1 - self.bn_lag) * (batch_var - used_var.detach())
60 | used_var /= (1. - self.bn_lag**(self.step[0] + 1))
61 |
62 | # update running estimates
63 | self.running_mean -= self.decay * (self.running_mean - batch_mean.data)
64 | self.running_var -= self.decay * (self.running_var - batch_var.data)
65 | self.step += 1
66 |
67 | # perform normalization
68 | used_mean = used_mean.view(*self.shape).expand_as(x)
69 | used_var = used_var.view(*self.shape).expand_as(x)
70 |
71 | y = (x - used_mean) * torch.exp(-0.5 * torch.log(used_var + self.eps))
72 |
73 | if self.affine:
74 | weight = self.weight.view(*self.shape).expand_as(x)
75 | bias = self.bias.view(*self.shape).expand_as(x)
76 | y = y * torch.exp(weight) + bias
77 |
78 | if logpx is None:
79 | return y
80 | else:
81 | return y, logpx - self._logdetgrad(x, used_var).view(x.size(0), -1).sum(1, keepdim=True)
82 |
83 | def _reverse(self, y, logpy=None):
84 | used_mean = self.running_mean
85 | used_var = self.running_var
86 |
87 | if self.affine:
88 | weight = self.weight.view(*self.shape).expand_as(y)
89 | bias = self.bias.view(*self.shape).expand_as(y)
90 | y = (y - bias) * torch.exp(-weight)
91 |
92 | used_mean = used_mean.view(*self.shape).expand_as(y)
93 | used_var = used_var.view(*self.shape).expand_as(y)
94 | x = y * torch.exp(0.5 * torch.log(used_var + self.eps)) + used_mean
95 |
96 | if logpy is None:
97 | return x
98 | else:
99 | return x, logpy + self._logdetgrad(x, used_var).view(x.size(0), -1).sum(1, keepdim=True)
100 |
101 | def _logdetgrad(self, x, used_var):
102 | logdetgrad = -0.5 * torch.log(used_var + self.eps)
103 | if self.affine:
104 | weight = self.weight.view(*self.shape).expand(*x.size())
105 | logdetgrad += weight
106 | return logdetgrad
107 |
108 | def __repr__(self):
109 | return (
110 | '{name}({num_features}, eps={eps}, decay={decay}, bn_lag={bn_lag},'
111 | ' affine={affine})'.format(name=self.__class__.__name__, **self.__dict__)
112 | )
113 |
114 |
115 | def stable_var(x, mean=None, dim=1):
116 | if mean is None:
117 | mean = x.mean(dim, keepdim=True)
118 | mean = mean.view(-1, 1)
119 | res = torch.pow(x - mean, 2)
120 | max_sqr = torch.max(res, dim, keepdim=True)[0]
121 | var = torch.mean(res / max_sqr, 1, keepdim=True) * max_sqr
122 | var = var.view(-1)
123 | # change nan to zero
124 | var[var != var] = 0
125 | return var
126 |
127 |
128 | class MovingBatchNorm1d(MovingBatchNormNd):
129 | @property
130 | def shape(self):
131 | return [1, -1]
132 |
133 |
134 | class MovingBatchNorm2d(MovingBatchNormNd):
135 | @property
136 | def shape(self):
137 | return [1, -1, 1, 1]
138 |
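A sketch of the invertible normalization on toy data, in eval mode so the running statistics are fixed: forward and reverse compose to the identity and the log-density corrections cancel. During training the reverse pass uses the updated running statistics, so it is only approximately inverse until those converge.

import torch

from TrajectoryNet.lib.layers.normalization import MovingBatchNorm1d

bn = MovingBatchNorm1d(5)
bn.eval()                        # freeze running statistics
x = torch.randn(32, 5) * 3 + 1
logpx = torch.zeros(32, 1)
y, logpy = bn(x, logpx)
x2, logpx2 = bn(y, logpy, reverse=True)
assert torch.allclose(x, x2, atol=1e-4)
assert torch.allclose(logpx2, logpx, atol=1e-4)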
--------------------------------------------------------------------------------
/TrajectoryNet/lib/layers/odefunc.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from . import diffeq_layers
8 | from .squeeze import squeeze, unsqueeze
9 |
10 | __all__ = ["ODEnet", "AutoencoderDiffEqNet", "ODEfunc", "AutoencoderODEfunc"]
11 |
12 |
13 | def divergence_bf(dx, y, **unused_kwargs):
14 | sum_diag = 0.
15 | # print(dx.shape, dx.requires_grad, y.shape, y.requires_grad)
16 | for i in range(y.shape[1]):
17 | sum_diag += torch.autograd.grad(dx[:, i].sum(), y, create_graph=True)[0].contiguous()[:, i].contiguous()
18 | return sum_diag.contiguous()
19 |
20 |
21 | # def divergence_bf(f, y, **unused_kwargs):
22 | # jac = _get_minibatch_jacobian(f, y)
23 | # diagonal = jac.view(jac.shape[0], -1)[:, ::jac.shape[1]]
24 | # return torch.sum(diagonal, 1)
25 |
26 |
27 | def _get_minibatch_jacobian(y, x):
28 | """Computes the Jacobian of y wrt x assuming minibatch-mode.
29 |
30 | Args:
31 | y: (N, ...) with a total of D_y elements in ...
32 | x: (N, ...) with a total of D_x elements in ...
33 | Returns:
34 | The minibatch Jacobian matrix of shape (N, D_y, D_x)
35 | """
36 | assert y.shape[0] == x.shape[0]
37 | y = y.view(y.shape[0], -1)
38 |
39 | # Compute Jacobian row by row.
40 | jac = []
41 | for j in range(y.shape[1]):
42 | dy_j_dx = torch.autograd.grad(y[:, j], x, torch.ones_like(y[:, j]), retain_graph=True,
43 | create_graph=True)[0].view(x.shape[0], -1)
44 | jac.append(torch.unsqueeze(dy_j_dx, 1))
45 | jac = torch.cat(jac, 1)
46 | return jac
47 |
48 |
49 | def divergence_approx(f, y, e=None):
50 | e_dzdx = torch.autograd.grad(f, y, e, create_graph=True)[0]
51 | e_dzdx_e = e_dzdx * e
52 | approx_tr_dzdx = e_dzdx_e.view(y.shape[0], -1).sum(dim=1)
53 | return approx_tr_dzdx
54 |
55 |
56 | def sample_rademacher_like(y):
57 | return torch.randint(low=0, high=2, size=y.shape).to(y) * 2 - 1
58 |
59 |
60 | def sample_gaussian_like(y):
61 | return torch.randn_like(y)
62 |
63 |
64 | class Swish(nn.Module):
65 |
66 | def __init__(self):
67 | super(Swish, self).__init__()
68 | self.beta = nn.Parameter(torch.tensor(1.0))
69 |
70 | def forward(self, x):
71 | return x * torch.sigmoid(self.beta * x)
72 |
73 |
74 | class Lambda(nn.Module):
75 |
76 | def __init__(self, f):
77 | super(Lambda, self).__init__()
78 | self.f = f
79 |
80 | def forward(self, x):
81 | return self.f(x)
82 |
83 |
84 | NONLINEARITIES = {
85 | "tanh": nn.Tanh(),
86 | "relu": nn.ReLU(),
87 | "softplus": nn.Softplus(),
88 | "elu": nn.ELU(),
89 | "swish": Swish(),
90 | "square": Lambda(lambda x: x**2),
91 | "identity": Lambda(lambda x: x),
92 | }
93 |
94 |
95 | class ODEnet(nn.Module):
96 | """
97 | Helper class to make neural nets for use in continuous normalizing flows
98 | """
99 |
100 | def __init__(
101 | self, hidden_dims, input_shape, strides, conv, layer_type="concat", nonlinearity="softplus", num_squeeze=0
102 | ):
103 | super(ODEnet, self).__init__()
104 | self.num_squeeze = num_squeeze
105 | if conv:
106 | assert len(strides) == len(hidden_dims) + 1
107 | base_layer = {
108 | "ignore": diffeq_layers.IgnoreConv2d,
109 | "hyper": diffeq_layers.HyperConv2d,
110 | "squash": diffeq_layers.SquashConv2d,
111 | "concat": diffeq_layers.ConcatConv2d,
112 | "concat_v2": diffeq_layers.ConcatConv2d_v2,
113 | "concatsquash": diffeq_layers.ConcatSquashConv2d,
114 | "blend": diffeq_layers.BlendConv2d,
115 | "concatcoord": diffeq_layers.ConcatCoordConv2d,
116 | }[layer_type]
117 | else:
118 | strides = [None] * (len(hidden_dims) + 1)
119 | base_layer = {
120 | "ignore": diffeq_layers.IgnoreLinear,
121 | "hyper": diffeq_layers.HyperLinear,
122 | "squash": diffeq_layers.SquashLinear,
123 | "concat": diffeq_layers.ConcatLinear,
124 | "concat_v2": diffeq_layers.ConcatLinear_v2,
125 | "concatsquash": diffeq_layers.ConcatSquashLinear,
126 | "blend": diffeq_layers.BlendLinear,
127 | "concatcoord": diffeq_layers.ConcatLinear,
128 | }[layer_type]
129 |
130 | # build layers and add them
131 | layers = []
132 | activation_fns = []
133 | hidden_shape = input_shape
134 |
135 | for dim_out, stride in zip(hidden_dims + (input_shape[0],), strides):
136 | if stride is None:
137 | layer_kwargs = {}
138 | elif stride == 1:
139 | layer_kwargs = {"ksize": 3, "stride": 1, "padding": 1, "transpose": False}
140 | elif stride == 2:
141 | layer_kwargs = {"ksize": 4, "stride": 2, "padding": 1, "transpose": False}
142 | elif stride == -2:
143 | layer_kwargs = {"ksize": 4, "stride": 2, "padding": 1, "transpose": True}
144 | else:
145 | raise ValueError('Unsupported stride: {}'.format(stride))
146 |
147 | layer = base_layer(hidden_shape[0], dim_out, **layer_kwargs)
148 | layers.append(layer)
149 | activation_fns.append(NONLINEARITIES[nonlinearity])
150 |
151 | hidden_shape = list(copy.copy(hidden_shape))
152 | hidden_shape[0] = dim_out
153 | if stride == 2:
154 | hidden_shape[1], hidden_shape[2] = hidden_shape[1] // 2, hidden_shape[2] // 2
155 | elif stride == -2:
156 | hidden_shape[1], hidden_shape[2] = hidden_shape[1] * 2, hidden_shape[2] * 2
157 |
158 | self.layers = nn.ModuleList(layers)
159 | self.activation_fns = nn.ModuleList(activation_fns[:-1])
160 |
161 | def forward(self, t, y):
162 | dx = y
163 | # squeeze
164 | for _ in range(self.num_squeeze):
165 | dx = squeeze(dx, 2)
166 |         for i, layer in enumerate(self.layers):
167 |             dx = layer(t, dx)
168 |             # if not the last layer, apply the nonlinearity
169 |             if i < len(self.layers) - 1:
170 |                 dx = self.activation_fns[i](dx)
171 | # unsqueeze
172 | for _ in range(self.num_squeeze):
173 | dx = unsqueeze(dx, 2)
174 | return dx
175 |
176 |
177 | class AutoencoderDiffEqNet(nn.Module):
178 | """
179 | Helper class to make neural nets for use in continuous normalizing flows
180 | """
181 |
182 | def __init__(self, hidden_dims, input_shape, strides, conv, layer_type="concat", nonlinearity="softplus"):
183 | super(AutoencoderDiffEqNet, self).__init__()
184 | assert layer_type in ("ignore", "hyper", "concat", "concatcoord", "blend")
185 | assert nonlinearity in ("tanh", "relu", "softplus", "elu")
186 |
187 |         self.nonlinearity = {"tanh": torch.tanh, "relu": F.relu, "softplus": F.softplus, "elu": F.elu}[nonlinearity]
188 | if conv:
189 | assert len(strides) == len(hidden_dims) + 1
190 | base_layer = {
191 | "ignore": diffeq_layers.IgnoreConv2d,
192 | "hyper": diffeq_layers.HyperConv2d,
193 | "squash": diffeq_layers.SquashConv2d,
194 | "concat": diffeq_layers.ConcatConv2d,
195 | "blend": diffeq_layers.BlendConv2d,
196 | "concatcoord": diffeq_layers.ConcatCoordConv2d,
197 | }[layer_type]
198 | else:
199 | strides = [None] * (len(hidden_dims) + 1)
200 | base_layer = {
201 | "ignore": diffeq_layers.IgnoreLinear,
202 | "hyper": diffeq_layers.HyperLinear,
203 | "squash": diffeq_layers.SquashLinear,
204 | "concat": diffeq_layers.ConcatLinear,
205 | "blend": diffeq_layers.BlendLinear,
206 | "concatcoord": diffeq_layers.ConcatLinear,
207 | }[layer_type]
208 |
209 | # build layers and add them
210 | encoder_layers = []
211 | decoder_layers = []
212 | hidden_shape = input_shape
213 | for i, (dim_out, stride) in enumerate(zip(hidden_dims + (input_shape[0],), strides)):
214 | if i <= len(hidden_dims) // 2:
215 | layers = encoder_layers
216 | else:
217 | layers = decoder_layers
218 |
219 | if stride is None:
220 | layer_kwargs = {}
221 | elif stride == 1:
222 | layer_kwargs = {"ksize": 3, "stride": 1, "padding": 1, "transpose": False}
223 | elif stride == 2:
224 | layer_kwargs = {"ksize": 4, "stride": 2, "padding": 1, "transpose": False}
225 | elif stride == -2:
226 | layer_kwargs = {"ksize": 4, "stride": 2, "padding": 1, "transpose": True}
227 | else:
228 | raise ValueError('Unsupported stride: {}'.format(stride))
229 |
230 | layers.append(base_layer(hidden_shape[0], dim_out, **layer_kwargs))
231 |
232 | hidden_shape = list(copy.copy(hidden_shape))
233 | hidden_shape[0] = dim_out
234 | if stride == 2:
235 | hidden_shape[1], hidden_shape[2] = hidden_shape[1] // 2, hidden_shape[2] // 2
236 | elif stride == -2:
237 | hidden_shape[1], hidden_shape[2] = hidden_shape[1] * 2, hidden_shape[2] * 2
238 |
239 | self.encoder_layers = nn.ModuleList(encoder_layers)
240 | self.decoder_layers = nn.ModuleList(decoder_layers)
241 |
242 | def forward(self, t, y):
243 | h = y
244 | for layer in self.encoder_layers:
245 | h = self.nonlinearity(layer(t, h))
246 |
247 | dx = h
248 | for i, layer in enumerate(self.decoder_layers):
249 | dx = layer(t, dx)
250 | # if not last layer, use nonlinearity
251 | if i < len(self.decoder_layers) - 1:
252 | dx = self.nonlinearity(dx)
253 | return h, dx
254 |
255 |
256 | class ODEfunc(nn.Module):
257 |
258 | def __init__(self, diffeq, divergence_fn="approximate", residual=False, rademacher=False):
259 | super(ODEfunc, self).__init__()
260 | assert divergence_fn in ("brute_force", "approximate")
261 |
262 | # self.diffeq = diffeq_layers.wrappers.diffeq_wrapper(diffeq)
263 | self.diffeq = diffeq
264 | self.residual = residual
265 | self.rademacher = rademacher
266 |
267 | if divergence_fn == "brute_force":
268 | self.divergence_fn = divergence_bf
269 | elif divergence_fn == "approximate":
270 | self.divergence_fn = divergence_approx
271 |
272 | self.register_buffer("_num_evals", torch.tensor(0.))
273 |
274 | def before_odeint(self, e=None):
275 | self._e = e
276 | self._num_evals.fill_(0)
277 |
278 | def forward(self, t, states):
279 | assert len(states) >= 2
280 | y = states[0]
281 |
282 | # increment num evals
283 | self._num_evals += 1
284 |
285 | # convert to tensor
286 | # t = torch.tensor(t).type_as(y)
287 | batchsize = y.shape[0]
288 |
289 | # Sample and fix the noise.
290 | if self._e is None:
291 | if self.rademacher:
292 | self._e = sample_rademacher_like(y)
293 | else:
294 | self._e = sample_gaussian_like(y)
295 |
296 | with torch.set_grad_enabled(True):
297 | y.requires_grad_(True)
298 | t.requires_grad_(True)
299 | for s_ in states[2:]:
300 | s_.requires_grad_(True)
301 | dy = self.diffeq(t, y, *states[2:])
302 | # Hack for 2D data to use brute force divergence computation.
303 | if not self.training and dy.view(dy.shape[0], -1).shape[1] == 2:
304 | divergence = divergence_bf(dy, y).view(batchsize, 1)
305 | else:
306 | divergence = self.divergence_fn(dy, y, e=self._e).view(batchsize, 1)
307 | if self.residual:
308 | dy = dy - y
309 |                 divergence -= torch.ones_like(divergence) * torch.tensor(
310 |                     np.prod(y.shape[1:]), dtype=torch.float32).to(divergence)
311 | return tuple([dy, -divergence] + [torch.zeros_like(s_).requires_grad_(True) for s_ in states[2:]])
312 |
313 |
314 | class AutoencoderODEfunc(nn.Module):
315 |
316 | def __init__(self, autoencoder_diffeq, divergence_fn="approximate", residual=False, rademacher=False):
317 |         assert divergence_fn in ("approximate",), "Only approximate divergence supported at the moment. (TODO)"
318 | assert isinstance(autoencoder_diffeq, AutoencoderDiffEqNet)
319 | super(AutoencoderODEfunc, self).__init__()
320 | self.residual = residual
321 | self.autoencoder_diffeq = autoencoder_diffeq
322 | self.rademacher = rademacher
323 |
324 | self.register_buffer("_num_evals", torch.tensor(0.))
325 |
326 | def before_odeint(self, e=None):
327 | self._e = e
328 | self._num_evals.fill_(0)
329 |
330 | def forward(self, t, y_and_logpy):
331 | y, _ = y_and_logpy # remove logpy
332 |
333 | # increment num evals
334 | self._num_evals += 1
335 |
336 | # convert to tensor
337 | t = torch.tensor(t).type_as(y)
338 | batchsize = y.shape[0]
339 |
340 | with torch.set_grad_enabled(True):
341 | y.requires_grad_(True)
342 | t.requires_grad_(True)
343 | h, dy = self.autoencoder_diffeq(t, y)
344 |
345 | # Sample and fix the noise.
346 | if self._e is None:
347 | if self.rademacher:
348 | self._e = sample_rademacher_like(h)
349 | else:
350 | self._e = sample_gaussian_like(h)
351 |
352 | e_vjp_dhdy = torch.autograd.grad(h, y, self._e, create_graph=True)[0]
353 | e_vjp_dfdy = torch.autograd.grad(dy, h, e_vjp_dhdy, create_graph=True)[0]
354 | divergence = torch.sum((e_vjp_dfdy * self._e).view(batchsize, -1), 1, keepdim=True)
355 |
356 | if self.residual:
357 | dy = dy - y
358 |             divergence -= torch.ones_like(divergence) * torch.tensor(
359 |                 np.prod(y.shape[1:]), dtype=torch.float32).to(divergence)
360 |
361 | return dy, -divergence
362 |
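A sanity sketch for the two divergence estimators on a toy linear field dy = A y, whose true divergence is trace(A): divergence_bf recovers it exactly, and the Hutchinson estimator divergence_approx is an unbiased but noisy estimate of the same quantity.

import torch

from TrajectoryNet.lib.layers.odefunc import (
    divergence_approx, divergence_bf, sample_gaussian_like)

torch.manual_seed(0)
A = torch.randn(3, 3)
y = torch.randn(16, 3).requires_grad_(True)
dy = y @ A.t()                                # linear field: div f = trace(A)
exact = divergence_bf(dy, y)
assert torch.allclose(exact, torch.trace(A).expand(16), atol=1e-5)
est = torch.stack([
    divergence_approx(dy, y, e=sample_gaussian_like(y)) for _ in range(500)
]).mean(0)                                    # noisy, but centered on trace(A)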
--------------------------------------------------------------------------------
/TrajectoryNet/lib/layers/resnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 |
5 | class BasicBlock(nn.Module):
6 | expansion = 1
7 |
8 | def __init__(self, dim):
9 | super(BasicBlock, self).__init__()
10 | self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, padding=1, bias=False)
11 | self.bn1 = nn.GroupNorm(2, dim, eps=1e-4)
12 | self.relu = nn.ReLU(inplace=True)
13 | self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, padding=1, bias=False)
14 | self.bn2 = nn.GroupNorm(2, dim, eps=1e-4)
15 |
16 | def forward(self, x):
17 | residual = x
18 |
19 | out = self.conv1(x)
20 | out = self.bn1(out)
21 | out = self.relu(out)
22 |
23 | out = self.conv2(out)
24 | out = self.bn2(out)
25 |
26 | out += residual
27 | out = self.relu(out)
28 |
29 | return out
30 |
31 |
32 | class ResNeXtBottleneck(nn.Module):
33 | """
34 |     ResNeXt bottleneck type C (https://github.com/facebookresearch/ResNeXt/blob/master/models/resnext.lua)
35 | """
36 |
37 | def __init__(self, dim, cardinality=4, base_depth=32):
38 |         """ Constructor
39 | 
40 |         Args:
41 |             dim: input and output channel dimensionality (the block is
42 |                 residual, so input and output widths match).
43 |             cardinality: number of convolution groups.
44 |             base_depth: base number of channels in each group; the
45 |                 bottleneck width D equals cardinality * base_depth.
46 |         """
47 | super(ResNeXtBottleneck, self).__init__()
48 | D = cardinality * base_depth
49 | self.conv_reduce = nn.Conv2d(dim, D, kernel_size=1, stride=1, padding=0, bias=False)
50 | self.bn_reduce = nn.BatchNorm2d(D)
51 | self.conv_grp = nn.Conv2d(D, D, kernel_size=3, stride=1, padding=1, groups=cardinality, bias=False)
52 | self.bn = nn.BatchNorm2d(D)
53 | self.conv_expand = nn.Conv2d(D, dim, kernel_size=1, stride=1, padding=0, bias=False)
54 | self.bn_expand = nn.BatchNorm2d(dim)
55 |
56 | def forward(self, x):
57 | bottleneck = self.conv_reduce.forward(x)
58 | bottleneck = F.relu(self.bn_reduce.forward(bottleneck), inplace=True)
59 | bottleneck = self.conv_grp.forward(bottleneck)
60 | bottleneck = F.relu(self.bn.forward(bottleneck), inplace=True)
61 | bottleneck = self.conv_expand.forward(bottleneck)
62 | bottleneck = self.bn_expand.forward(bottleneck)
63 | return F.relu(x + bottleneck, inplace=True)
64 |
--------------------------------------------------------------------------------
/TrajectoryNet/lib/layers/squeeze.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | __all__ = ['SqueezeLayer']
4 |
5 |
6 | class SqueezeLayer(nn.Module):
7 | def __init__(self, downscale_factor):
8 | super(SqueezeLayer, self).__init__()
9 | self.downscale_factor = downscale_factor
10 |
11 | def forward(self, x, logpx=None, reverse=False):
12 | if reverse:
13 | return self._upsample(x, logpx)
14 | else:
15 | return self._downsample(x, logpx)
16 |
17 | def _downsample(self, x, logpx=None):
18 | squeeze_x = squeeze(x, self.downscale_factor)
19 | if logpx is None:
20 | return squeeze_x
21 | else:
22 | return squeeze_x, logpx
23 |
24 | def _upsample(self, y, logpy=None):
25 | unsqueeze_y = unsqueeze(y, self.downscale_factor)
26 | if logpy is None:
27 | return unsqueeze_y
28 | else:
29 | return unsqueeze_y, logpy
30 |
31 |
32 | def unsqueeze(input, upscale_factor=2):
33 | '''
34 | [:, C*r^2, H, W] -> [:, C, H*r, W*r]
35 | '''
36 | batch_size, in_channels, in_height, in_width = input.size()
37 | out_channels = in_channels // (upscale_factor**2)
38 |
39 | out_height = in_height * upscale_factor
40 | out_width = in_width * upscale_factor
41 |
42 | input_view = input.contiguous().view(batch_size, out_channels, upscale_factor, upscale_factor, in_height, in_width)
43 |
44 | output = input_view.permute(0, 1, 4, 2, 5, 3).contiguous()
45 | return output.view(batch_size, out_channels, out_height, out_width)
46 |
47 |
48 | def squeeze(input, downscale_factor=2):
49 | '''
50 | [:, C, H*r, W*r] -> [:, C*r^2, H, W]
51 | '''
52 | batch_size, in_channels, in_height, in_width = input.size()
53 | out_channels = in_channels * (downscale_factor**2)
54 |
55 | out_height = in_height // downscale_factor
56 | out_width = in_width // downscale_factor
57 |
58 | input_view = input.contiguous().view(
59 | batch_size, in_channels, out_height, downscale_factor, out_width, downscale_factor
60 | )
61 |
62 | output = input_view.permute(0, 1, 3, 5, 2, 4).contiguous()
63 | return output.view(batch_size, out_channels, out_height, out_width)
64 |
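A shape sketch: squeeze trades spatial resolution for channels and unsqueeze is its exact inverse, so a round trip reproduces the input bit-for-bit.

import torch

from TrajectoryNet.lib.layers.squeeze import squeeze, unsqueeze

x = torch.randn(2, 3, 8, 8)
s = squeeze(x, 2)                 # [N, C, H*r, W*r] -> [N, C*r^2, H, W]
assert s.shape == (2, 12, 4, 4)
assert torch.equal(unsqueeze(s, 2), x)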
--------------------------------------------------------------------------------
/TrajectoryNet/lib/layers/wrappers/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/TrajectoryNet/lib/layers/wrappers/cnf_regularization.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class RegularizedODEfunc(nn.Module):
6 | def __init__(self, odefunc, regularization_fns):
7 | super(RegularizedODEfunc, self).__init__()
8 | self.odefunc = odefunc
9 | self.regularization_fns = regularization_fns
10 |
11 | def before_odeint(self, *args, **kwargs):
12 | self.odefunc.before_odeint(*args, **kwargs)
13 |
14 | def forward(self, t, state):
15 | class SharedContext(object):
16 | pass
17 |
18 | with torch.enable_grad():
19 | x, logp = state[:2]
20 | x.requires_grad_(True)
21 | logp.requires_grad_(True)
22 | t.requires_grad_(True)
23 | dstate = self.odefunc(t, (x, logp))
24 | if len(state) > 2:
25 | dx, dlogp = dstate[:2]
26 | reg_states = tuple(
27 | reg_fn(x, logp, dx, dlogp, t, SharedContext)
28 | for reg_fn in self.regularization_fns
29 | )
30 | return dstate + reg_states
31 | else:
32 | return dstate
33 |
34 | @property
35 | def _num_evals(self):
36 | return self.odefunc._num_evals
37 |
38 |
39 | def _batch_root_mean_squared(tensor):
40 | tensor = tensor.view(tensor.shape[0], -1)
41 | return torch.mean(torch.norm(tensor, p=2, dim=1) / tensor.shape[1] ** 0.5)
42 |
43 |
44 | def l1_regularzation_fn(x, logp, dx, dlogp, t, unused_context):
45 | del x, logp, dlogp
46 | return torch.mean(torch.abs(dx))
47 |
48 |
49 | def l2_regularzation_fn(x, logp, dx, dlogp, t, unused_context):
50 | del x, logp, dlogp
51 | return _batch_root_mean_squared(dx)
52 |
53 |
54 | def squared_l2_regularization_fn(x, logp, dx, dlogp, t, unused_context):
55 | del x, logp, dlogp
56 | to_return = dx.view(dx.shape[0], -1)
57 | # print(t)
58 | return torch.mean(torch.pow(torch.norm(to_return, p=2, dim=1), 2))
59 |
60 |
61 | def directional_l2_regularization_fn(x, logp, dx, dlogp, t, unused_context):
62 | del logp, dlogp
63 | directional_dx = torch.autograd.grad(dx, x, dx, create_graph=True)[0]
64 | # print(directional_dx.shape)
65 | # exit()
66 | return _batch_root_mean_squared(directional_dx)
67 |
68 |
69 | def directional_l2_change_penalty_fn(x, logp, dx, dlogp, t, context):
70 | del logp, dlogp
71 | # For now we ignore the directional dx penalty as this complicates things
72 | directional_dx = torch.autograd.grad(dx, x, dx, create_graph=True)[0]
73 | dfdt = _get_minibatch_jacobian(dx, t)
74 | dfdt_full = dfdt + torch.sum(directional_dx, axis=0)
75 | return torch.mean(torch.norm(dfdt_full, p=2) / dfdt_full.shape[0] ** 0.5)
76 |
77 |
78 | def jacobian_frobenius_regularization_fn(x, logp, dx, dlogp, t, context):
79 | del logp, dlogp, t
80 | if hasattr(context, "jac"):
81 | jac = context.jac
82 | else:
83 | jac = _get_minibatch_jacobian(dx, x)
84 | context.jac = jac
85 | return _batch_root_mean_squared(jac)
86 |
87 |
88 | def jacobian_diag_frobenius_regularization_fn(x, logp, dx, dlogp, t, context):
89 | del logp, dlogp, t
90 | if hasattr(context, "jac"):
91 | jac = context.jac
92 | else:
93 | jac = _get_minibatch_jacobian(dx, x)
94 | context.jac = jac
95 |     diagonal = jac.view(jac.shape[0], -1)[
96 |         :, :: jac.shape[1] + 1
97 |     ]  # assumes jac is minibatch square, i.e. (N, M, M); stride M + 1 selects the diagonal.
98 | return _batch_root_mean_squared(diagonal)
99 |
100 |
101 | def jacobian_offdiag_frobenius_regularization_fn(x, logp, dx, dlogp, t, context):
102 | del logp, dlogp, t
103 | if hasattr(context, "jac"):
104 | jac = context.jac
105 | else:
106 | jac = _get_minibatch_jacobian(dx, x)
107 | context.jac = jac
108 |     diagonal = jac.view(jac.shape[0], -1)[
109 |         :, :: jac.shape[1] + 1
110 |     ]  # assumes jac is minibatch square, i.e. (N, M, M); stride M + 1 selects the diagonal.
111 | ss_offdiag = torch.sum(jac.view(jac.shape[0], -1) ** 2, dim=1) - torch.sum(
112 | diagonal ** 2, dim=1
113 | )
114 | ms_offdiag = ss_offdiag / (diagonal.shape[1] * (diagonal.shape[1] - 1))
115 | return torch.mean(ms_offdiag)
116 |
117 |
118 | def _get_minibatch_jacobian(y, x, create_graph=True):
119 | """Computes the Jacobian of y wrt x assuming minibatch-mode.
120 |
121 | Args:
122 | y: (N, ...) with a total of D_y elements in ...
123 | x: (N, ...) with a total of D_x elements in ...
124 | Returns:
125 | The minibatch Jacobian matrix of shape (N, D_y, D_x)
126 | """
127 | # assert y.shape[0] == x.shape[0]
128 | y = y.view(y.shape[0], -1)
129 |
130 | # Compute Jacobian row by row.
131 | jac = []
132 | for j in range(y.shape[1]):
133 | dy_j_dx = torch.autograd.grad(
134 | y[:, j],
135 | x,
136 | torch.ones_like(y[:, j]),
137 | retain_graph=True,
138 | create_graph=create_graph,
139 | )[0]
140 | jac.append(torch.unsqueeze(dy_j_dx, -1))
141 | jac = torch.cat(jac, -1)
142 | return jac
143 |
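A usage sketch for the penalty functions on a toy vector field (the logp, dlogp, t, and context slots are unused by these two penalties, so None stands in; note the module really does spell the names "regularzation"):

import torch

from TrajectoryNet.lib.layers.wrappers import cnf_regularization as reg

x = torch.randn(8, 2, requires_grad=True)
dx = x.pow(3)   # toy dx = f(x), differentiable w.r.t. x
l2 = reg.l2_regularzation_fn(x, None, dx, None, None, None)
dl2 = reg.directional_l2_regularization_fn(x, None, dx, None, None, None)
print(float(l2), float(dl2))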
--------------------------------------------------------------------------------
/TrajectoryNet/lib/spectral_norm.py:
--------------------------------------------------------------------------------
1 | """
2 | Spectral Normalization from https://arxiv.org/abs/1802.05957
3 | """
4 | import types
5 | import torch
6 | from torch.nn.functional import normalize
7 |
8 | POWER_ITERATION_FN = "spectral_norm_power_iteration"
9 |
10 |
11 | class SpectralNorm(object):
12 | def __init__(self, name='weight', dim=0, eps=1e-12):
13 | self.name = name
14 | self.dim = dim
15 | self.eps = eps
16 |
17 | def compute_weight(self, module, n_power_iterations):
18 | if n_power_iterations < 0:
19 | raise ValueError(
20 | 'Expected n_power_iterations to be non-negative, but '
21 | 'got n_power_iterations={}'.format(n_power_iterations)
22 | )
23 |
24 | weight = getattr(module, self.name + '_orig')
25 | u = getattr(module, self.name + '_u')
26 | v = getattr(module, self.name + '_v')
27 | weight_mat = weight
28 | if self.dim != 0:
29 | # permute dim to front
30 | weight_mat = weight_mat.permute(self.dim, * [d for d in range(weight_mat.dim()) if d != self.dim])
31 | height = weight_mat.size(0)
32 | weight_mat = weight_mat.reshape(height, -1)
33 | with torch.no_grad():
34 | for _ in range(n_power_iterations):
35 | # Spectral norm of weight equals to `u^T W v`, where `u` and `v`
36 | # are the first left and right singular vectors.
37 | # This power iteration produces approximations of `u` and `v`.
38 | v = normalize(torch.matmul(weight_mat.t(), u), dim=0, eps=self.eps)
39 | u = normalize(torch.matmul(weight_mat, v), dim=0, eps=self.eps)
40 | setattr(module, self.name + '_u', u)
41 | setattr(module, self.name + '_v', v)
42 |
43 | sigma = torch.dot(u, torch.matmul(weight_mat, v))
44 | weight = weight / sigma
45 | setattr(module, self.name, weight)
46 |
47 | def remove(self, module):
48 | weight = getattr(module, self.name)
49 | delattr(module, self.name)
50 | delattr(module, self.name + '_u')
51 | delattr(module, self.name + '_orig')
52 | module.register_parameter(self.name, torch.nn.Parameter(weight))
53 |
54 | def get_update_method(self, module):
55 | def update_fn(module, n_power_iterations):
56 | self.compute_weight(module, n_power_iterations)
57 |
58 | return update_fn
59 |
60 | def __call__(self, module, unused_inputs):
61 | del unused_inputs
62 | self.compute_weight(module, n_power_iterations=0)
63 |
64 | # requires_grad might be either True or False during inference.
65 | if not module.training:
66 | r_g = getattr(module, self.name + '_orig').requires_grad
67 | setattr(module, self.name, getattr(module, self.name).detach().requires_grad_(r_g))
68 |
69 | @staticmethod
70 | def apply(module, name, dim, eps):
71 | fn = SpectralNorm(name, dim, eps)
72 | weight = module._parameters[name]
73 | height = weight.size(dim)
74 |
75 | u = normalize(weight.new_empty(height).normal_(0, 1), dim=0, eps=fn.eps)
76 | v = normalize(weight.new_empty(int(weight.numel() / height)).normal_(0, 1), dim=0, eps=fn.eps)
77 | delattr(module, fn.name)
78 | module.register_parameter(fn.name + "_orig", weight)
79 | # We still need to assign weight back as fn.name because all sorts of
80 | # things may assume that it exists, e.g., when initializing weights.
81 | # However, we can't directly assign as it could be an nn.Parameter and
82 | # gets added as a parameter. Instead, we register weight.data as a
83 | # buffer, which will cause weight to be included in the state dict
84 | # and also supports nn.init due to shared storage.
85 | module.register_buffer(fn.name, weight.data)
86 | module.register_buffer(fn.name + "_u", u)
87 | module.register_buffer(fn.name + "_v", v)
88 |
89 | setattr(module, POWER_ITERATION_FN, types.MethodType(fn.get_update_method(module), module))
90 |
91 | module.register_forward_pre_hook(fn)
92 | return fn
93 |
94 |
95 | def inplace_spectral_norm(module, name='weight', dim=None, eps=1e-12):
96 | r"""Applies spectral normalization to a parameter in the given module.
97 |
98 | .. math::
99 | \mathbf{W} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})} \\
100 | \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}
101 |
102 |     Spectral normalization stabilizes the training of discriminators (critics)
103 |     in Generative Adversarial Networks (GANs) by rescaling the weight tensor
104 |     by the spectral norm :math:`\sigma` of the weight matrix, calculated via
105 |     the power iteration method. If the dimension of the weight tensor is
106 |     greater than 2, it is reshaped to 2D for the power iteration to obtain the
107 |     spectral norm. This is implemented via a hook that recomputes the spectral
108 |     norm and rescales the weight before every :meth:`~Module.forward` call.
109 |
110 | See `Spectral Normalization for Generative Adversarial Networks`_ .
111 |
112 | .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957
113 |
114 |     Args:
115 |         module (nn.Module): containing module
116 |         name (str, optional): name of weight parameter
117 |         dim (int, optional): dimension corresponding to number of outputs,
118 |             the default is 0, except for modules that are instances of
119 |             ConvTranspose1/2/3d, when it is 1
120 |         eps (float, optional): epsilon for numerical stability in
121 |             calculating norms
122 |     The number of power iterations is supplied at call time, via the
123 |     update method registered on the module under POWER_ITERATION_FN.
124 |
125 | Returns:
126 |         The original module with the spectral norm hook
127 |
128 | Example::
129 |
130 |         >>> m = inplace_spectral_norm(nn.Linear(20, 40))
131 |         Linear (20 -> 40)
132 |         >>> m.weight_u.size()
133 |         torch.Size([40])
134 |
135 | """
136 | if dim is None:
137 | if isinstance(module, (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d)):
138 | dim = 1
139 | else:
140 | dim = 0
141 | SpectralNorm.apply(module, name, dim=dim, eps=eps)
142 | return module
143 |
144 |
145 | def remove_spectral_norm(module, name='weight'):
146 | r"""Removes the spectral normalization reparameterization from a module.
147 |
148 | Args:
149 | module (nn.Module): containing module
150 | name (str, optional): name of weight parameter
151 |
152 | Example:
153 |         >>> m = inplace_spectral_norm(nn.Linear(40, 10))
154 | >>> remove_spectral_norm(m)
155 | """
156 | for k, hook in module._forward_pre_hooks.items():
157 | if isinstance(hook, SpectralNorm) and hook.name == name:
158 | hook.remove(module)
159 | del module._forward_pre_hooks[k]
160 | return module
161 |
162 | raise ValueError("spectral_norm of '{}' not found in {}".format(name, module))
163 |
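A usage sketch: after inplace_spectral_norm is applied, the module carries a power-iteration method (registered under POWER_ITERATION_FN) for refining the singular-vector estimates, and a pre-forward hook that rescales the weight by the estimated spectral norm on every call.

import torch
import torch.nn as nn

from TrajectoryNet.lib.spectral_norm import POWER_ITERATION_FN, inplace_spectral_norm

m = inplace_spectral_norm(nn.Linear(20, 40))
getattr(m, POWER_ITERATION_FN)(n_power_iterations=5)  # refine u and v
y = m(torch.randn(8, 20))     # hook divides the weight by sigma before the call
print(m.weight_u.shape)       # torch.Size([40]): left singular vector estimate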
--------------------------------------------------------------------------------
/TrajectoryNet/lib/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | from numbers import Number
4 | import logging
5 | import torch
6 |
7 |
8 | def makedirs(dirname):
9 | if not os.path.exists(dirname):
10 | os.makedirs(dirname)
11 |
12 |
13 | def get_logger(logpath, filepath, package_files=[], displaying=True, saving=True, debug=False):
14 | logger = logging.getLogger()
15 | if debug:
16 | level = logging.DEBUG
17 | else:
18 | level = logging.INFO
19 | logger.setLevel(level)
20 | if saving:
21 | info_file_handler = logging.FileHandler(logpath, mode="a")
22 | info_file_handler.setLevel(level)
23 | logger.addHandler(info_file_handler)
24 | if displaying:
25 | console_handler = logging.StreamHandler()
26 | console_handler.setLevel(level)
27 | logger.addHandler(console_handler)
28 | logger.info(filepath)
29 | with open(filepath, "r") as f:
30 | logger.info(f.read())
31 |
32 | for f in package_files:
33 | logger.info(f)
34 | with open(f, "r") as package_f:
35 | logger.info(package_f.read())
36 |
37 | return logger
38 |
39 |
40 | class AverageMeter(object):
41 | """Computes and stores the average and current value"""
42 |
43 | def __init__(self):
44 | self.reset()
45 |
46 | def reset(self):
47 | self.val = 0
48 | self.avg = 0
49 | self.sum = 0
50 | self.count = 0
51 |
52 | def update(self, val, n=1):
53 | self.val = val
54 | self.sum += val * n
55 | self.count += n
56 | self.avg = self.sum / self.count
57 |
58 |
59 | class RunningAverageMeter(object):
60 | """Computes and stores the average and current value"""
61 |
62 | def __init__(self, momentum=0.99):
63 | self.momentum = momentum
64 | self.reset()
65 |
66 | def reset(self):
67 | self.val = None
68 | self.avg = 0
69 |
70 | def update(self, val):
71 | if self.val is None:
72 | self.avg = val
73 | else:
74 | self.avg = self.avg * self.momentum + val * (1 - self.momentum)
75 | self.val = val
76 |
77 |
78 | def inf_generator(iterable):
79 | """Allows training with DataLoaders in a single infinite loop:
80 | for i, (x, y) in enumerate(inf_generator(train_loader)):
81 | """
82 | iterator = iterable.__iter__()
83 | while True:
84 | try:
85 | yield iterator.__next__()
86 | except StopIteration:
87 | iterator = iterable.__iter__()
88 |
89 |
90 | def save_checkpoint(state, save, epoch):
91 | if not os.path.exists(save):
92 | os.makedirs(save)
93 | filename = os.path.join(save, 'checkpt-%04d.pth' % epoch)
94 | torch.save(state, filename)
95 |
96 |
97 | def isnan(tensor):
98 | return (tensor != tensor)
99 |
100 |
101 | def logsumexp(value, dim=None, keepdim=False):
102 | """Numerically stable implementation of the operation
103 | value.exp().sum(dim, keepdim).log()
104 | """
105 | if dim is not None:
106 | m, _ = torch.max(value, dim=dim, keepdim=True)
107 | value0 = value - m
108 | if keepdim is False:
109 | m = m.squeeze(dim)
110 | return m + torch.log(torch.sum(torch.exp(value0), dim=dim, keepdim=keepdim))
111 | else:
112 | m = torch.max(value)
113 | sum_exp = torch.sum(torch.exp(value - m))
114 | if isinstance(sum_exp, Number):
115 | return m + math.log(sum_exp)
116 | else:
117 | return m + torch.log(sum_exp)
118 |
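A quick sketch of the helpers above: the custom logsumexp agrees with torch.logsumexp, and RunningAverageMeter maintains an exponential moving average of a scalar.

import torch

from TrajectoryNet.lib.utils import RunningAverageMeter, logsumexp

v = torch.randn(4, 7)
assert torch.allclose(logsumexp(v, dim=1), torch.logsumexp(v, dim=1))

meter = RunningAverageMeter(momentum=0.9)
for loss in (1.0, 0.5, 0.25):
    meter.update(loss)
print(meter.val, meter.avg)   # latest value and its smoothed average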
--------------------------------------------------------------------------------
/TrajectoryNet/lib/visualize_flow.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib
3 | matplotlib.use("Agg")
4 | import matplotlib.pyplot as plt
5 | import torch
6 |
7 | LOW = -4
8 | HIGH = 4
9 |
10 |
11 | def plt_potential_func(potential, ax, npts=100, title="$p(x)$"):
12 | """
13 | Args:
14 | potential: computes U(z_k) given z_k
15 | """
16 | xside = np.linspace(LOW, HIGH, npts)
17 | yside = np.linspace(LOW, HIGH, npts)
18 | xx, yy = np.meshgrid(xside, yside)
19 | z = np.hstack([xx.reshape(-1, 1), yy.reshape(-1, 1)])
20 |
21 | z = torch.Tensor(z)
22 | u = potential(z).cpu().numpy()
23 | p = np.exp(-u).reshape(npts, npts)
24 |
25 | plt.pcolormesh(xx, yy, p)
26 | ax.invert_yaxis()
27 | ax.get_xaxis().set_ticks([])
28 | ax.get_yaxis().set_ticks([])
29 | ax.set_title(title)
30 |
31 |
32 | def plt_flow(prior_logdensity, transform, ax, npts=100, title="$q(x)$", device="cpu"):
33 | """
34 | Args:
35 | transform: computes z_k and log(q_k) given z_0
36 | """
37 | side = np.linspace(LOW, HIGH, npts)
38 | xx, yy = np.meshgrid(side, side)
39 | z = np.hstack([xx.reshape(-1, 1), yy.reshape(-1, 1)])
40 |
41 | z = torch.tensor(z, requires_grad=True).type(torch.float32).to(device)
42 | logqz = prior_logdensity(z)
43 | z, logqz = transform(z, logqz)
44 | logqz = torch.sum(logqz, dim=1)[:, None]
45 |
46 | xx = z[:, 0].cpu().numpy().reshape(npts, npts)
47 | yy = z[:, 1].cpu().numpy().reshape(npts, npts)
48 | qz = np.exp(logqz.cpu().numpy()).reshape(npts, npts)
49 |
50 | plt.pcolormesh(xx, yy, qz)
51 | ax.set_xlim(LOW, HIGH)
52 | ax.set_ylim(LOW, HIGH)
53 | cmap = matplotlib.cm.get_cmap(None)
54 | ax.set_facecolor(cmap(0.))
55 | ax.invert_yaxis()
56 | ax.get_xaxis().set_ticks([])
57 | ax.get_yaxis().set_ticks([])
58 | ax.set_title(title)
59 |
60 |
61 | def plt_flow_density(prior_logdensity, inverse_transform, ax, npts=100, memory=100, title="$q(x)$", device="cpu"):
62 | side = np.linspace(LOW, HIGH, npts)
63 | xx, yy = np.meshgrid(side, side)
64 | x = np.hstack([xx.reshape(-1, 1), yy.reshape(-1, 1)])
65 |
66 | x = torch.from_numpy(x).type(torch.float32).to(device)
67 | zeros = torch.zeros(x.shape[0], 1).to(x)
68 |
69 | z, delta_logp = [], []
70 | inds = torch.arange(0, x.shape[0]).to(torch.int64)
71 | for ii in torch.split(inds, int(memory**2)):
72 | z_, delta_logp_ = inverse_transform(x[ii], zeros[ii])
73 | z.append(z_)
74 | delta_logp.append(delta_logp_)
75 | z = torch.cat(z, 0)
76 | delta_logp = torch.cat(delta_logp, 0)
77 |
78 | logpz = prior_logdensity(z)
79 | logpx = logpz - delta_logp
80 |
81 | px = np.exp(logpx.cpu().numpy()).reshape(npts, npts)
82 |
83 | ax.imshow(px)
84 | ax.get_xaxis().set_ticks([])
85 | ax.get_yaxis().set_ticks([])
86 | ax.set_title(title)
87 |
88 |
89 | def plt_flow_samples(prior_sample, transform, ax, npts=100, memory=100, title="$x ~ q(x)$", device="cpu"):
90 | z = prior_sample(npts * npts, 2).type(torch.float32).to(device)
91 | zk = []
92 | inds = torch.arange(0, z.shape[0]).to(torch.int64)
93 | for ii in torch.split(inds, int(memory**2)):
94 | zk.append(transform(z[ii]))
95 | zk = torch.cat(zk, 0).cpu().numpy()
96 | ax.hist2d(zk[:, 0], zk[:, 1], range=[[LOW, HIGH], [LOW, HIGH]], bins=npts)
97 | ax.invert_yaxis()
98 | ax.get_xaxis().set_ticks([])
99 | ax.get_yaxis().set_ticks([])
100 | ax.set_title(title)
101 |
102 |
103 | def plt_samples(samples, ax, npts=100, title="$x ~ p(x)$"):
104 | ax.hist2d(samples[:, 0], samples[:, 1], range=[[LOW, HIGH], [LOW, HIGH]], bins=npts)
105 | ax.invert_yaxis()
106 | ax.get_xaxis().set_ticks([])
107 | ax.get_yaxis().set_ticks([])
108 | ax.set_title(title)
109 |
110 | def visualize_growth(growth_model, full_data, labels, npts=200, memory=100, device='cpu'):
111 | with torch.no_grad():
112 | fig, ax = plt.subplots(1,1)
113 | side = np.linspace(LOW, HIGH, npts)
114 | xx, yy = np.meshgrid(side, side)
115 | x = np.hstack([xx.reshape(-1, 1), yy.reshape(-1, 1)])
116 | x = torch.from_numpy(x).type(torch.float32).to(device)
117 | output_growth = growth_model(x).cpu().numpy()
118 | output_growth = np.reshape(output_growth, (npts, npts))
119 |         im = ax.imshow(output_growth, cmap='bwr')
120 | ax.get_xaxis().set_ticks([])
121 | ax.get_yaxis().set_ticks([])
122 | fig.colorbar(im, ax=ax)
123 | ax.set_title('Growth Rate')
124 |
125 | # rescale full data to image coordinates
126 | full_data = full_data * npts / 8 + npts / 2
127 | #ax.scatter(full_data[:,0], full_data[:,1], c=labels / 5, cmap='Spectral', s=10,alpha=0.5)
128 |
129 |
130 | def visualize_transform(
131 | potential_or_samples, prior_sample, prior_density, transform=None, inverse_transform=None, samples=True, npts=100,
132 | memory=100, device="cpu"
133 | ):
134 | """Produces visualization for the model density and samples from the model."""
135 | plt.clf()
136 | ax = plt.subplot(1, 3, 1, aspect="equal")
137 | if samples:
138 | plt_samples(potential_or_samples, ax, npts=npts)
139 | else:
140 | plt_potential_func(potential_or_samples, ax, npts=npts)
141 |
142 | ax = plt.subplot(1, 3, 2, aspect="equal")
143 | if inverse_transform is None:
144 | plt_flow(prior_density, transform, ax, npts=npts, device=device)
145 | else:
146 | plt_flow_density(prior_density, inverse_transform, ax, npts=npts, memory=memory, device=device)
147 |
148 | ax = plt.subplot(1, 3, 3, aspect="equal")
149 | if transform is not None:
150 | plt_flow_samples(prior_sample, transform, ax, npts=npts, memory=memory, device=device)
151 |
--------------------------------------------------------------------------------
/TrajectoryNet/lib/viz_scrna.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import numpy as np
4 | import matplotlib
5 | matplotlib.use('Agg')
6 | import matplotlib.pyplot as plt
7 | import torch
8 |
9 |
10 | # def standard_normal_logprob(z):
11 | # logZ = -0.5 * math.log(2 * math.pi)
12 | # return torch.sum(logZ - z.pow(2) / 2, 1, keepdim=True)
13 |
14 |
15 | def makedirs(dirname):
16 | if not os.path.exists(dirname):
17 | os.makedirs(dirname)
18 |
19 | def save_2d_trajectory_v2(prior_logdensity, prior_sampler, model, data_samples, savedir, ntimes=5, end_times=None, memory=0.01, device='cpu', limit=4):
20 |     """ Save the trajectory as a series of frames that can easily be displayed in a paper or poster. """
21 | model.eval()
22 |
23 | # Sample from prior
24 | z_samples = prior_sampler(2000, 2).to(device)
25 |
26 | with torch.no_grad():
27 | # We expect the model is a chain of CNF layers wrapped in a SequentialFlow container.
28 | logp_samples = prior_logdensity(z_samples)
29 | t = 0
30 | for cnf in model.chain:
31 |
32 | # Construct integration_list
33 | if end_times is None:
34 | end_times = [(cnf.sqrt_end_time * cnf.sqrt_end_time)]
35 | integration_list = [torch.linspace(0, end_times[0], ntimes).to(device)]
36 | for i, et in enumerate(end_times[1:]):
37 | integration_list.append(torch.linspace(end_times[i], et, ntimes).to(device))
38 | full_times = torch.cat(integration_list, 0)
39 | print('integration_list', integration_list)
40 |
41 |
42 | # Integrate over evenly spaced samples
43 | z_traj, logpz = cnf(z_samples, logp_samples, integration_times=integration_list[0], reverse=True)
44 | full_traj = [(z_traj, logpz)]
45 | for i, int_times in enumerate(integration_list[1:]):
46 | prev_z, prev_logp = full_traj[-1]
47 | z_traj, logpz = cnf(prev_z[-1], prev_logp[-1], integration_times=int_times, reverse=True)
48 | full_traj.append((z_traj[1:], logpz[1:]))
49 | full_zip = list(zip(*full_traj))
50 | z_traj = torch.cat(full_zip[0], 0)
51 | #z_logp = torch.cat(full_zip[1], 0)
52 | z_traj = z_traj.cpu().numpy()
53 |
54 | width = z_traj.shape[0]
55 |             # one figure per CNF in the chain: a single row of axes, one panel per time step
56 | fig, axes = plt.subplots(1, width, figsize=(4*width, 4), sharex=True, sharey=True)
57 | axes = axes.flatten()
58 | for w in range(width):
59 | # plot the density
60 | ax = axes[w]
61 | K = 13j
62 | y, x = np.mgrid[-0.5:2.5:K, -1.5:1.5:K]
63 | #y, x = np.mgrid[-limit:limit:K, -limit:limit:K]
64 | K = int(K.imag)
65 | zs = torch.from_numpy(np.stack([x, y], -1).reshape(K * K, 2)).to(device, torch.float32)
66 | logps = torch.zeros(zs.shape[0], 1).to(device, torch.float32)
67 | dydt = cnf.odefunc(full_times[t], (zs, logps))[0]
68 | dydt = -dydt.cpu().detach().numpy()
69 | dydt = dydt.reshape(K, K, 2)
70 |
71 | logmag = 2 * np.log(np.hypot(dydt[:, :, 0], dydt[:, :, 1]))
72 | ax.quiver(
73 | #x, y, dydt[:, :, 0], -dydt[:, :, 1],
74 | x, y, dydt[:, :, 0], dydt[:, :, 1],
75 | np.exp(logmag), cmap="coolwarm", scale=20., width=0.015, pivot="mid"
76 | )
77 |                 # axis bounds match the density grid sampled above,
78 |                 # rather than the generic +/-limit used elsewhere
79 |                 ax.set_xlim(-1.5, 1.5)
80 |                 ax.set_ylim(-0.5, 2.5)
81 | ax.axis("off")
82 |
83 | ax.scatter(z_traj[w,:,0], z_traj[w,:,1], c='k', s=0.5)
84 | #ax.set_title("Vector Field", fontsize=32)
85 | t += 1
86 |
87 | makedirs(savedir)
88 | plt.tight_layout(pad=0.0)
89 | plt.savefig(os.path.join(savedir, "vector_plot.jpg"))
90 |             plt.close()
91 | def save_2d_trajectory(prior_logdensity, prior_sampler, model, data_samples, savedir, ntimes=5, end_times=None, memory=0.01, device='cpu'):
92 |     """Save the trajectory as a series of images for easy display in a paper or poster."""
93 | model.eval()
94 |
95 | # Sample from prior
96 | z_samples = prior_sampler(2000, 2).to(device)
97 |
98 | # sample from a grid
99 | npts = 100
100 | limit = 1.5
101 | side = np.linspace(-limit, limit, npts)
102 | xx, yy = np.meshgrid(side, side)
103 | xx = torch.from_numpy(xx).type(torch.float32).to(device)
104 | yy = torch.from_numpy(yy).type(torch.float32).to(device)
105 | z_grid = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1)], 1)
106 |
107 | with torch.no_grad():
108 | # We expect the model is a chain of CNF layers wrapped in a SequentialFlow container.
109 | logp_samples = prior_logdensity(z_samples)
110 | logp_grid = prior_logdensity(z_grid)
111 | t = 0
112 | for cnf in model.chain:
113 |
114 | # Construct integration_list
115 | if end_times is None:
116 | end_times = [(cnf.sqrt_end_time * cnf.sqrt_end_time)]
117 | integration_list = [torch.linspace(0, end_times[0], ntimes).to(device)]
118 | for i, et in enumerate(end_times[1:]):
119 | integration_list.append(torch.linspace(end_times[i], et, ntimes).to(device))
120 | full_times = torch.cat(integration_list, 0)
121 | print('integration_list', integration_list)
122 |
123 |
124 | # Integrate over evenly spaced samples
125 | z_traj, logpz = cnf(z_samples, logp_samples, integration_times=integration_list[0], reverse=True)
126 | full_traj = [(z_traj, logpz)]
127 | for i, int_times in enumerate(integration_list[1:]):
128 | prev_z, prev_logp = full_traj[-1]
129 | z_traj, logpz = cnf(prev_z[-1], prev_logp[-1], integration_times=int_times, reverse=True)
130 | full_traj.append((z_traj[1:], logpz[1:]))
131 | full_zip = list(zip(*full_traj))
132 | z_traj = torch.cat(full_zip[0], 0)
133 | #z_logp = torch.cat(full_zip[1], 0)
134 | z_traj = z_traj.cpu().numpy()
135 |
136 | grid_z_traj, grid_logpz_traj = [], []
137 | inds = torch.arange(0, z_grid.shape[0]).to(torch.int64)
138 | for ii in torch.split(inds, int(z_grid.shape[0] * memory)):
139 | _grid_z_traj, _grid_logpz_traj = cnf(
140 | z_grid[ii], logp_grid[ii], integration_times=integration_list[0], reverse=True
141 | )
142 | full_traj = [(_grid_z_traj, _grid_logpz_traj)]
143 | for int_times in integration_list[1:]:
144 | prev_z, prev_logp = full_traj[-1]
145 | _grid_z_traj, _grid_logpz_traj = cnf(
146 | prev_z[-1], prev_logp[-1], integration_times=int_times, reverse=True
147 | )
148 | full_traj.append((_grid_z_traj, _grid_logpz_traj))
149 | full_zip = list(zip(*full_traj))
150 | _grid_z_traj = torch.cat(full_zip[0], 0).cpu().numpy()
151 | _grid_logpz_traj = torch.cat(full_zip[1], 0).cpu().numpy()
152 | grid_z_traj.append(_grid_z_traj)
153 | grid_logpz_traj.append(_grid_logpz_traj)
154 |
155 | grid_z_traj = np.concatenate(grid_z_traj, axis=1)
156 | grid_logpz_traj = np.concatenate(grid_logpz_traj, axis=1)
157 |
158 | width = z_traj.shape[0]
159 |             # a 2 x width grid of panels: densities on the top row, vector fields below
160 | fig, axes = plt.subplots(2, width, figsize=(4*width, 8), sharex=True, sharey=True)
161 | axes = axes.flatten()
162 | for w in range(width):
163 | # plot the density
164 | ax = axes[w]
165 |
166 | z, logqz = grid_z_traj[t], grid_logpz_traj[t]
167 |
168 | xx = z[:, 0].reshape(npts, npts)
169 | yy = z[:, 1].reshape(npts, npts)
170 | qz = np.exp(logqz).reshape(npts, npts)
171 |
172 | ax.pcolormesh(xx, yy, qz)
173 | ax.set_xlim(-limit, limit)
174 | ax.set_ylim(-limit, limit)
175 | cmap = matplotlib.cm.get_cmap(None)
176 | ax.set_facecolor(cmap(0.))
177 | ax.invert_yaxis()
178 | ax.get_xaxis().set_ticks([])
179 | ax.get_yaxis().set_ticks([])
180 | #ax.set_title("Density", fontsize=32)
181 |
182 | # plot vector field
183 | ax = axes[w+width]
184 |
185 | K = 13j
186 | y, x = np.mgrid[-limit:limit:K, -limit:limit:K]
187 | K = int(K.imag)
188 | zs = torch.from_numpy(np.stack([x, y], -1).reshape(K * K, 2)).to(device, torch.float32)
189 | logps = torch.zeros(zs.shape[0], 1).to(device, torch.float32)
190 | dydt = cnf.odefunc(full_times[t], (zs, logps))[0]
191 | dydt = -dydt.cpu().detach().numpy()
192 | dydt = dydt.reshape(K, K, 2)
193 |
194 | logmag = 2 * np.log(np.hypot(dydt[:, :, 0], dydt[:, :, 1]))
195 | ax.quiver(
196 | x, y, dydt[:, :, 0], -dydt[:, :, 1],
197 | # x, y, dydt[:, :, 0], dydt[:, :, 1],
198 | np.exp(logmag), cmap="coolwarm", scale=20., width=0.015, pivot="mid"
199 | )
200 | ax.set_xlim(-limit, limit)
201 | ax.set_ylim(limit, -limit)
202 | ax.axis("off")
203 | #ax.set_title("Vector Field", fontsize=32)
204 | t += 1
205 |
206 | makedirs(savedir)
207 | plt.tight_layout(pad=0.0)
208 | plt.savefig(os.path.join(savedir, "vector_plot.jpg"))
209 |             plt.close()
210 |
211 |
212 | def save_vectors(prior_logdensity, model, data_samples, full_data, labels, savedir, skip_first=False, ntimes=101, end_times=None, memory=0.01, device='cpu', lim=4):
213 | model.eval()
214 |
215 | # Sample from prior
216 | z_samples = data_samples.to(device)
217 |
218 | with torch.no_grad():
219 | # We expect the model is a chain of CNF layers wrapped in a SequentialFlow container.
220 | logp_samples = prior_logdensity(z_samples)
221 | t = 0
222 | for cnf in model.chain:
223 | # Construct integration_list
224 | if end_times is None:
225 | end_times = [(cnf.sqrt_end_time * cnf.sqrt_end_time)]
226 | # integration_list = []
227 | integration_list = [torch.linspace(0, end_times[0], ntimes).to(device)]
228 |
229 | # Start integration at first end_time
230 | for i, et in enumerate(end_times[1:]):
231 | integration_list.append(torch.linspace(end_times[i], et, ntimes).to(device))
232 | # if len(end_times) == 1:
233 | # integration_list = [torch.linspace(0, end_times[0], ntimes).to(device)]
234 | # print(integration_list)
235 |
236 |
237 | # Integrate over evenly spaced samples
238 | z_traj, logpz = cnf(z_samples, logp_samples, integration_times=integration_list[0], reverse=True)
239 | full_traj = [(z_traj, logpz)]
240 | for int_times in integration_list[1:]:
241 | prev_z, prev_logp = full_traj[-1]
242 | z_traj, logpz = cnf(prev_z[-1], prev_logp[-1], integration_times=int_times, reverse=True)
243 | full_traj.append((z_traj, logpz))
244 | full_zip = list(zip(*full_traj))
245 | z_traj = torch.cat(full_zip[0], 0)
246 | z_traj = z_traj.cpu().numpy()
247 |
248 | # mask out stray negative points
249 |             pos_mask = full_data[:, 1] >= 0
250 | full_data = full_data[pos_mask]
251 | labels = labels[pos_mask]
252 | print(np.unique(labels))
253 |
254 | plt.figure(figsize=(8, 8))
255 | ax = plt.subplot(aspect="equal")
256 | ax.scatter(full_data[:,0], full_data[:,1], c=labels.astype(np.int32), cmap='tab10', s=0.5, alpha=1)
257 | # If we do not have a known base density then skip vectors for the first integration.
258 |
259 | z_traj = np.swapaxes(z_traj, 0, 1)
260 | if skip_first:
261 | z_traj = z_traj[:, ntimes:, :]
262 | ax.scatter(z_traj[:,0,0], z_traj[:,0,1], s=20, c='k')
263 | for zk in z_traj:
264 | #for zk in z_traj[:,ntimes:,:]:
265 |                 ax.scatter(zk[:, 0], zk[:, 1], s=1, c=np.linspace(0, 1, zk.shape[0]), cmap='Spectral')
266 | ax.set_xlim(-lim, lim)
267 | ax.set_ylim(-lim, lim)
268 | #ax.set_ylim(4, -4)
269 | makedirs(savedir)
270 | plt.xticks([])
271 | plt.yticks([])
272 |             plt.savefig(os.path.join(savedir, "vectors.jpg"), dpi=300)
273 | t += 1
274 |
275 | def save_trajectory_density(prior_logdensity, model, data_samples, savedir, ntimes=101, end_times=None, memory=0.01, device='cpu'):
276 | model.eval()
277 |
278 | # sample from a grid
279 |     # npts = 100  # lower-resolution alternative
280 | npts = 800
281 | side = np.linspace(-4, 4, npts)
282 | xx, yy = np.meshgrid(side, side)
283 | xx = torch.from_numpy(xx).type(torch.float32).to(device)
284 | yy = torch.from_numpy(yy).type(torch.float32).to(device)
285 | z_grid = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1)], 1)
286 |
287 | with torch.no_grad():
288 | # We expect the model is a chain of CNF layers wrapped in a SequentialFlow container.
289 | logp_grid = prior_logdensity(z_grid)
290 | t = 0
291 | for cnf in model.chain:
292 | # Construct integration_list
293 | if end_times is None:
294 | end_times = [(cnf.sqrt_end_time * cnf.sqrt_end_time)]
295 | integration_list = [torch.linspace(0, end_times[0], ntimes).to(device)]
296 | for i, et in enumerate(end_times[1:]):
297 | integration_list.append(torch.linspace(end_times[i], et, ntimes).to(device))
298 | full_times = torch.cat(integration_list, 0)
299 |
300 | grid_z_traj, grid_logpz_traj = [], []
301 | inds = torch.arange(0, z_grid.shape[0]).to(torch.int64)
302 | for ii in torch.split(inds, int(z_grid.shape[0] * memory)):
303 | _grid_z_traj, _grid_logpz_traj = cnf(
304 | z_grid[ii], logp_grid[ii], integration_times=integration_list[0], reverse=True
305 | )
306 | full_traj = [(_grid_z_traj, _grid_logpz_traj)]
307 | for int_times in integration_list[1:]:
308 | prev_z, prev_logp = full_traj[-1]
309 | _grid_z_traj, _grid_logpz_traj = cnf(
310 | prev_z[-1], prev_logp[-1], integration_times=int_times, reverse=True
311 | )
312 | full_traj.append((_grid_z_traj, _grid_logpz_traj))
313 | full_zip = list(zip(*full_traj))
314 | _grid_z_traj = torch.cat(full_zip[0], 0).cpu().numpy()
315 | _grid_logpz_traj = torch.cat(full_zip[1], 0).cpu().numpy()
316 | print(_grid_z_traj.shape)
317 | grid_z_traj.append(_grid_z_traj)
318 | grid_logpz_traj.append(_grid_logpz_traj)
319 |
320 | grid_z_traj = np.concatenate(grid_z_traj, axis=1)[ntimes:]
321 | grid_logpz_traj = np.concatenate(grid_logpz_traj, axis=1)[ntimes:]
322 |
323 |
324 | #plt.figure(figsize=(8, 8))
325 | #fig, axes = plt.subplots(2,1, gridspec_kw={'height_ratios': [7, 1]}, figsize=(5,7))
326 | for _ in range(grid_z_traj.shape[0]):
327 | fig, axes = plt.subplots(2,1, gridspec_kw={'height_ratios': [7, 1]}, figsize=(8,8))
328 | #plt.clf()
329 | ax = axes[0]
330 | # Density
331 | z, logqz = grid_z_traj[t], grid_logpz_traj[t]
332 |
333 | xx = z[:, 0].reshape(npts, npts)
334 | yy = z[:, 1].reshape(npts, npts)
335 | qz = np.exp(logqz).reshape(npts, npts)
336 |
337 | ax.pcolormesh(xx, yy, qz)
338 | ax.set_xlim(-4, 4)
339 | ax.set_ylim(-4, 4)
340 | cmap = matplotlib.cm.get_cmap(None)
341 | ax.set_facecolor(cmap(0.))
342 | #ax.invert_yaxis()
343 | ax.get_xaxis().set_ticks([])
344 | ax.get_yaxis().set_ticks([])
345 | ax.set_title("Density", fontsize=32)
346 |
347 | ax=axes[1]
348 |
349 | # Colorbar
350 | cb = matplotlib.colorbar.ColorbarBase(ax,
351 | #cmap='Spectral',
352 | cmap=plt.cm.Spectral,
353 | orientation='horizontal')
354 | #cb.set_ticks(np.linspace(0,1,4))
355 | #cb.set_ticklabels(['48HR', 'Day 12', 'Day 18', 'Day 30'])
356 | #cb.set_ticklabels(['E12.5', 'E14.5', 'E16.0', 'E17.5'])
357 | #cb.set_ticks(np.linspace(0,1,5))
358 | #cb.set_ticklabels(np.arange(len(end_times)))
359 | ax.axvline(t / grid_z_traj.shape[0], c='k', linewidth=15)
360 | ax.set_title('Time')
361 |
362 | print('making dir: %s' % savedir)
363 | makedirs(savedir)
364 | plt.savefig(os.path.join(savedir, f"viz-{t:05d}.jpg"))
365 | plt.close()
366 | t += 1
367 |
368 |
369 | def save_trajectory(prior_logdensity, prior_sampler, model, data_samples, savedir, ntimes=101, end_times=None, memory=0.01, device='cpu'):
370 | model.eval()
371 |
372 | # Sample from prior
373 | z_samples = prior_sampler(2000, 2).to(device)
374 |
375 | # sample from a grid
376 | npts = 800
377 | side = np.linspace(-4, 4, npts)
378 | xx, yy = np.meshgrid(side, side)
379 | xx = torch.from_numpy(xx).type(torch.float32).to(device)
380 | yy = torch.from_numpy(yy).type(torch.float32).to(device)
381 | z_grid = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1)], 1)
382 |
383 | with torch.no_grad():
384 | # We expect the model is a chain of CNF layers wrapped in a SequentialFlow container.
385 | logp_samples = prior_logdensity(z_samples)
386 | logp_grid = prior_logdensity(z_grid)
387 | t = 0
388 | for cnf in model.chain:
389 |
390 | # Construct integration_list
391 | if end_times is None:
392 | end_times = [(cnf.sqrt_end_time * cnf.sqrt_end_time)]
393 | integration_list = [torch.linspace(0, end_times[0], ntimes).to(device)]
394 | for i, et in enumerate(end_times[1:]):
395 | integration_list.append(torch.linspace(end_times[i], et, ntimes).to(device))
396 | full_times = torch.cat(integration_list, 0)
397 | print(full_times.shape)
398 |
399 | # Integrate over evenly spaced samples
400 | z_traj, logpz = cnf(z_samples, logp_samples, integration_times=integration_list[0], reverse=True)
401 | full_traj = [(z_traj, logpz)]
402 | for int_times in integration_list[1:]:
403 | prev_z, prev_logp = full_traj[-1]
404 | z_traj, logpz = cnf(prev_z[-1], prev_logp[-1], integration_times=int_times, reverse=True)
405 | full_traj.append((z_traj, logpz))
406 | full_zip = list(zip(*full_traj))
407 | z_traj = torch.cat(full_zip[0], 0)
408 | #z_logp = torch.cat(full_zip[1], 0)
409 | z_traj = z_traj.cpu().numpy()
410 |
411 | grid_z_traj, grid_logpz_traj = [], []
412 | inds = torch.arange(0, z_grid.shape[0]).to(torch.int64)
413 | for ii in torch.split(inds, int(z_grid.shape[0] * memory)):
414 | _grid_z_traj, _grid_logpz_traj = cnf(
415 | z_grid[ii], logp_grid[ii], integration_times=integration_list[0], reverse=True
416 | )
417 | full_traj = [(_grid_z_traj, _grid_logpz_traj)]
418 | for int_times in integration_list[1:]:
419 | prev_z, prev_logp = full_traj[-1]
420 | _grid_z_traj, _grid_logpz_traj = cnf(
421 | prev_z[-1], prev_logp[-1], integration_times=int_times, reverse=True
422 | )
423 | full_traj.append((_grid_z_traj, _grid_logpz_traj))
424 | full_zip = list(zip(*full_traj))
425 | _grid_z_traj = torch.cat(full_zip[0], 0).cpu().numpy()
426 | _grid_logpz_traj = torch.cat(full_zip[1], 0).cpu().numpy()
427 | print(_grid_z_traj.shape)
428 | grid_z_traj.append(_grid_z_traj)
429 | grid_logpz_traj.append(_grid_logpz_traj)
430 |
431 | grid_z_traj = np.concatenate(grid_z_traj, axis=1)
432 | grid_logpz_traj = np.concatenate(grid_logpz_traj, axis=1)
433 |
434 | plt.figure(figsize=(8, 8))
435 | for _ in range(z_traj.shape[0]):
436 |
437 | plt.clf()
438 |
439 | # plot target potential function
440 | ax = plt.subplot(2, 2, 1, aspect="equal")
441 |
442 | ax.hist2d(data_samples[:, 0], data_samples[:, 1], range=[[-4, 4], [-4, 4]], bins=200)
443 | ax.invert_yaxis()
444 | ax.get_xaxis().set_ticks([])
445 | ax.get_yaxis().set_ticks([])
446 | ax.set_title("Target", fontsize=32)
447 |
448 | # plot the density
449 | ax = plt.subplot(2, 2, 2, aspect="equal")
450 |
451 | z, logqz = grid_z_traj[t], grid_logpz_traj[t]
452 |
453 | xx = z[:, 0].reshape(npts, npts)
454 | yy = z[:, 1].reshape(npts, npts)
455 | qz = np.exp(logqz).reshape(npts, npts)
456 |
457 | plt.pcolormesh(xx, yy, qz)
458 | ax.set_xlim(-4, 4)
459 | ax.set_ylim(-4, 4)
460 | cmap = matplotlib.cm.get_cmap(None)
461 | ax.set_facecolor(cmap(0.))
462 | ax.invert_yaxis()
463 | ax.get_xaxis().set_ticks([])
464 | ax.get_yaxis().set_ticks([])
465 | ax.set_title("Density", fontsize=32)
466 |
467 | # plot the samples
468 | ax = plt.subplot(2, 2, 3, aspect="equal")
469 |
470 | zk = z_traj[t]
471 | ax.hist2d(zk[:, 0], zk[:, 1], range=[[-4, 4], [-4, 4]], bins=200)
472 | ax.invert_yaxis()
473 | ax.get_xaxis().set_ticks([])
474 | ax.get_yaxis().set_ticks([])
475 | ax.set_title("Samples", fontsize=32)
476 |
477 | # plot vector field
478 | ax = plt.subplot(2, 2, 4, aspect="equal")
479 |
480 | K = 13j
481 | y, x = np.mgrid[-4:4:K, -4:4:K]
482 | K = int(K.imag)
483 | zs = torch.from_numpy(np.stack([x, y], -1).reshape(K * K, 2)).to(device, torch.float32)
484 | logps = torch.zeros(zs.shape[0], 1).to(device, torch.float32)
485 | dydt = cnf.odefunc(full_times[t], (zs, logps))[0]
486 | dydt = -dydt.cpu().detach().numpy()
487 | dydt = dydt.reshape(K, K, 2)
488 |
489 | logmag = 2 * np.log(np.hypot(dydt[:, :, 0], dydt[:, :, 1]))
490 | ax.quiver(
491 | x, y, dydt[:, :, 0], -dydt[:, :, 1],
492 | # x, y, dydt[:, :, 0], dydt[:, :, 1],
493 | np.exp(logmag), cmap="coolwarm", scale=20., width=0.015, pivot="mid"
494 | )
495 | ax.set_xlim(-4, 4)
496 | ax.set_ylim(4, -4)
497 | #ax.set_ylim(-4, 4)
498 | ax.axis("off")
499 | ax.set_title("Vector Field", fontsize=32)
500 |
501 | makedirs(savedir)
502 | plt.savefig(os.path.join(savedir, f"viz-{t:05d}.jpg"))
503 | t += 1
504 |
505 |
506 | def trajectory_to_video(savedir):
507 | import subprocess
508 | bashCommand = 'ffmpeg -y -i {} {}'.format(os.path.join(savedir, 'viz-%05d.jpg'), os.path.join(savedir, 'traj.mp4'))
509 | process = subprocess.Popen(bashCommand.split(), stdout=subprocess.PIPE)
510 | output, error = process.communicate()
511 |
512 |
513 | if __name__ == '__main__':
514 | import argparse
515 | import sys
516 |
517 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..')))
518 |
519 | import lib.toy_data as toy_data
520 | from train_misc import count_parameters
521 | from train_misc import set_cnf_options, add_spectral_norm, create_regularization_fns
522 | from train_misc import build_model_tabular
523 |
524 | def get_ckpt_model_and_data(args):
525 | # Load checkpoint.
526 | checkpt = torch.load(args.checkpt, map_location=lambda storage, loc: storage)
527 | ckpt_args = checkpt['args']
528 | state_dict = checkpt['state_dict']
529 |
530 | # Construct model and restore checkpoint.
531 | regularization_fns, regularization_coeffs = create_regularization_fns(ckpt_args)
532 | model = build_model_tabular(ckpt_args, 2, regularization_fns).to(device)
533 | if ckpt_args.spectral_norm: add_spectral_norm(model)
534 | set_cnf_options(ckpt_args, model)
535 |
536 | model.load_state_dict(state_dict)
537 | model.to(device)
538 |
539 | print(model)
540 | print("Number of trainable parameters: {}".format(count_parameters(model)))
541 |
542 | # Load samples from dataset
543 | data_samples = toy_data.inf_train_gen(ckpt_args.data, batch_size=2000)
544 |
545 | return model, data_samples
546 |
547 | parser = argparse.ArgumentParser()
548 | parser.add_argument('--checkpt', type=str, required=True)
549 | parser.add_argument('--ntimes', type=int, default=101)
550 |     parser.add_argument('--memory', type=float, default=0.01, help='The higher this number, the more memory is consumed.')
551 | parser.add_argument('--save', type=str, default='trajectory')
552 | args = parser.parse_args()
553 |
554 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
555 | model, data_samples = get_ckpt_model_and_data(args)
556 |     save_trajectory(standard_normal_logprob, torch.randn, model, data_samples,
557 |                     args.save, ntimes=args.ntimes, memory=args.memory, device=device)
558 |     trajectory_to_video(args.save)
--------------------------------------------------------------------------------
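Every saver in viz_scrna.py builds its time grid the same way: one linspace from 0 to the first end time, then one linspace per consecutive pair of end times, so adjacent segments share their endpoints. A standalone sketch of that construction (the end_times values here are illustrative, standing in for args.int_tps):

    import torch

    end_times = [1.0, 2.0, 3.0]
    ntimes = 3
    integration_list = [torch.linspace(0, end_times[0], ntimes)]
    for i, et in enumerate(end_times[1:]):
        # end_times[i] is the previous entry, so each segment starts where the last ended
        integration_list.append(torch.linspace(end_times[i], et, ntimes))
    print([seg.tolist() for seg in integration_list])
    # [[0.0, 0.5, 1.0], [1.0, 1.5, 2.0], [2.0, 2.5, 3.0]]

The shared endpoints are also why save_2d_trajectory_v2 appends z_traj[1:] for every segment after the first: keeping the first entry would duplicate each junction point in the concatenated trajectory.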
/TrajectoryNet/main.py:
--------------------------------------------------------------------------------
1 | """ main.py
2 |
3 | Learns ODE from scrna data
4 |
5 | """
6 | import os
7 | import matplotlib
8 | matplotlib.use("Agg")  # select a non-interactive backend before pyplot is imported
9 | import matplotlib.pyplot as plt
10 | import numpy as np
11 | import time
12 | import torch
13 | import torch.nn.functional as F
14 | import torch.optim as optim
15 |
16 | from TrajectoryNet.lib.growth_net import GrowthNet
17 | from TrajectoryNet.lib import utils
18 | from TrajectoryNet.lib.visualize_flow import visualize_transform
19 | from TrajectoryNet.lib.viz_scrna import (
20 | save_trajectory,
21 | trajectory_to_video,
22 | save_vectors,
23 | )
24 | from TrajectoryNet.lib.viz_scrna import save_trajectory_density
25 |
26 |
27 | # from train_misc import standard_normal_logprob
28 | from TrajectoryNet.train_misc import (
29 | set_cnf_options,
30 | count_nfe,
31 | count_parameters,
32 | count_total_time,
33 | add_spectral_norm,
34 | spectral_norm_power_iteration,
35 | create_regularization_fns,
36 | get_regularization,
37 | append_regularization_to_log,
38 | build_model_tabular,
39 | )
40 |
41 | from TrajectoryNet import dataset
42 | from TrajectoryNet.parse import parser
43 |
44 |
45 |
46 |
47 | def get_transforms(device, args, model, integration_times):
48 |     """
49 |     Given a list of integration times, returns a pair of functions
50 |     (sample_fn, density_fn) mapping base samples to data space and back.
51 |     """
52 |
53 | def sample_fn(z, logpz=None):
54 | int_list = [
55 | torch.tensor([it - args.time_scale, it]).type(torch.float32).to(device)
56 | for it in integration_times
57 | ]
58 | if logpz is not None:
59 |             # TODO: check that propagating logpz through each segment is correct
60 | for it in int_list:
61 | z, logpz = model(z, logpz, integration_times=it, reverse=True)
62 | return z, logpz
63 | else:
64 | for it in int_list:
65 | z = model(z, integration_times=it, reverse=True)
66 | return z
67 |
68 | def density_fn(x, logpx=None):
69 | int_list = [
70 | torch.tensor([it - args.time_scale, it]).type(torch.float32).to(device)
71 | for it in integration_times[::-1]
72 | ]
73 | if logpx is not None:
74 | for it in int_list:
75 | x, logpx = model(x, logpx, integration_times=it, reverse=False)
76 | return x, logpx
77 | else:
78 | for it in int_list:
79 | x = model(x, integration_times=it, reverse=False)
80 | return x
81 |
82 | return sample_fn, density_fn
83 |
84 |
85 | def compute_loss(device, args, model, growth_model, logger, full_data):
86 |     """
87 |     Compute the loss by integrating backwards from the last timepoint.
88 |     At each timepoint, integrate back one step and concatenate the result
89 |     with samples of the empirical distribution at the previous timepoint.
90 |     Repeating this yields the likelihood of samples at every later
91 |     timepoint while ensuring that the ODE is evaluated at each
92 |     intermediate timepoint along the way.
93 |
94 |     The growth model is a single, time-independent model of the cell
95 |     growth / death rate, defined as a deviation from uniform growth.
96 |     """
97 |
98 | # Backward pass accumulating losses, previous state and deltas
99 | deltas = []
100 | zs = []
101 | z = None
102 | interp_loss = 0.0
103 | for i, (itp, tp) in enumerate(zip(args.int_tps[::-1], args.timepoints[::-1])):
104 | # tp counts down from last
105 | integration_times = torch.tensor([itp - args.time_scale, itp])
106 | integration_times = integration_times.type(torch.float32).to(device)
107 | # integration_times.requires_grad = True
108 |
109 | # load data and add noise
110 | idx = args.data.sample_index(args.batch_size, tp)
111 | x = args.data.get_data()[idx]
112 | if args.training_noise > 0.0:
113 | x += np.random.randn(*x.shape) * args.training_noise
114 | x = torch.from_numpy(x).type(torch.float32).to(device)
115 |
116 | if i > 0:
117 | x = torch.cat((z, x))
118 | zs.append(z)
119 | zero = torch.zeros(x.shape[0], 1).to(x)
120 |
121 | # transform to previous timepoint
122 | z, delta_logp = model(x, zero, integration_times=integration_times)
123 | deltas.append(delta_logp)
124 |
125 | # Straightline regularization
126 | # Integrate to random point at time t and assert close to (1 - t) * end + t * start
127 | if args.interp_reg:
128 | t = np.random.rand()
129 | int_t = torch.tensor([itp - t * args.time_scale, itp])
130 | int_t = int_t.type(torch.float32).to(device)
131 | int_x = model(x, integration_times=int_t)
132 | int_x = int_x.detach()
133 | actual_int_x = x * (1 - t) + z * t
134 | interp_loss += F.mse_loss(int_x, actual_int_x)
135 | if args.interp_reg:
136 | print("interp_loss", interp_loss)
137 |
138 | logpz = args.data.base_density()(z)
139 |
140 | # build growth rates
141 | if args.use_growth:
142 | growthrates = [torch.ones_like(logpz)]
143 | for z_state, tp in zip(zs[::-1], args.timepoints[:-1]):
144 | # Full state includes time parameter to growth_model
145 | time_state = tp * torch.ones(z_state.shape[0], 1).to(z_state)
146 | full_state = torch.cat([z_state, time_state], 1)
147 | growthrates.append(growth_model(full_state))
148 |
149 | # Accumulate losses
150 | losses = []
151 | logps = [logpz]
152 | for i, delta_logp in enumerate(deltas[::-1]):
153 | logpx = logps[-1] - delta_logp
154 | if args.use_growth:
155 | logpx += torch.log(torch.clamp(growthrates[i], 1e-4, 1e4))
156 | logps.append(logpx[: -args.batch_size])
157 | losses.append(-torch.mean(logpx[-args.batch_size :]))
158 | losses = torch.stack(losses)
159 | weights = torch.ones_like(losses).to(logpx)
160 | if args.leaveout_timepoint >= 0:
161 | weights[args.leaveout_timepoint] = 0
162 | losses = torch.mean(losses * weights)
163 |
164 | # Direction regularization
165 | if args.vecint:
166 | similarity_loss = 0
167 | for i, (itp, tp) in enumerate(zip(args.int_tps, args.timepoints)):
168 | itp = torch.tensor(itp).type(torch.float32).to(device)
169 | idx = args.data.sample_index(args.batch_size, tp)
170 | x = args.data.get_data()[idx]
171 | v = args.data.get_velocity()[idx]
172 | x = torch.from_numpy(x).type(torch.float32).to(device)
173 | v = torch.from_numpy(v).type(torch.float32).to(device)
174 | x += torch.randn_like(x) * 0.1
175 | # Only penalizes at the time / place of visible samples
176 | direction = -model.chain[0].odefunc.odefunc.diffeq(itp, x)
177 | if args.use_magnitude:
178 | similarity_loss += torch.mean(F.mse_loss(direction, v))
179 | else:
180 | similarity_loss -= torch.mean(F.cosine_similarity(direction, v))
181 | logger.info(similarity_loss)
182 | losses += similarity_loss * args.vecint
183 |
184 | # Density regularization
185 | if args.top_k_reg > 0:
186 | density_loss = 0
187 | tp_z_map = dict(zip(args.timepoints[:-1], zs[::-1]))
188 | if args.leaveout_timepoint not in tp_z_map:
189 | idx = args.data.sample_index(args.batch_size, tp)
190 | x = args.data.get_data()[idx]
191 | if args.training_noise > 0.0:
192 | x += np.random.randn(*x.shape) * args.training_noise
193 | x = torch.from_numpy(x).type(torch.float32).to(device)
194 | t = np.random.rand()
195 | int_t = torch.tensor([itp - t * args.time_scale, itp])
196 | int_t = int_t.type(torch.float32).to(device)
197 | int_x = model(x, integration_times=int_t)
198 | samples_05 = int_x
199 | else:
200 |             # If we are leaving out a timepoint, regularize there instead
201 | samples_05 = tp_z_map[args.leaveout_timepoint]
202 |
203 | # Calculate distance to 5 closest neighbors
204 |         # WARNING: on pytorch <= 1.4.0 this cdist call fails in the backward
205 |         # pass on CUDA ("RuntimeError: CUDA error: invalid configuration
206 |         # argument") but works on CPU; fixed in pytorch 1.5.0.
207 |         # The workaround is to run on CPU or upgrade pytorch.
208 | cdist = torch.cdist(samples_05, full_data)
209 | values, _ = torch.topk(cdist, 5, dim=1, largest=False, sorted=False)
210 | # Hinge loss
211 | hinge_value = 0.1
212 | values -= hinge_value
213 | values[values < 0] = 0
214 | density_loss = torch.mean(values)
215 | print("Density Loss", density_loss.item())
216 | losses += density_loss * args.top_k_reg
217 | losses += interp_loss
218 | return losses
219 |
220 |
221 | def train(
222 | device, args, model, growth_model, regularization_coeffs, regularization_fns, logger
223 | ):
224 | optimizer = optim.Adam(
225 | model.parameters(), lr=args.lr, weight_decay=args.weight_decay
226 | )
227 |
228 | time_meter = utils.RunningAverageMeter(0.93)
229 | loss_meter = utils.RunningAverageMeter(0.93)
230 | nfef_meter = utils.RunningAverageMeter(0.93)
231 | nfeb_meter = utils.RunningAverageMeter(0.93)
232 | tt_meter = utils.RunningAverageMeter(0.93)
233 |
234 | full_data = (
235 | torch.from_numpy(
236 | args.data.get_data()[args.data.get_times() != args.leaveout_timepoint]
237 | )
238 | .type(torch.float32)
239 | .to(device)
240 | )
241 |
242 | best_loss = float("inf")
243 | if args.use_growth:
244 | growth_model.eval()
245 | end = time.time()
246 | for itr in range(1, args.niters + 1):
247 | model.train()
248 | optimizer.zero_grad()
249 |
250 | # Train
251 | if args.spectral_norm:
252 | spectral_norm_power_iteration(model, 1)
253 |
254 | loss = compute_loss(device, args, model, growth_model, logger, full_data)
255 | loss_meter.update(loss.item())
256 |
257 | if len(regularization_coeffs) > 0:
258 | # Only regularize on the last timepoint
259 | reg_states = get_regularization(model, regularization_coeffs)
260 | reg_loss = sum(
261 | reg_state * coeff
262 | for reg_state, coeff in zip(reg_states, regularization_coeffs)
263 | if coeff != 0
264 | )
265 | loss = loss + reg_loss
266 | total_time = count_total_time(model)
267 | nfe_forward = count_nfe(model)
268 |
269 | loss.backward()
270 | optimizer.step()
271 |
272 | # Eval
273 | nfe_total = count_nfe(model)
274 | nfe_backward = nfe_total - nfe_forward
275 | nfef_meter.update(nfe_forward)
276 | nfeb_meter.update(nfe_backward)
277 | time_meter.update(time.time() - end)
278 | tt_meter.update(total_time)
279 |
280 | log_message = (
281 | "Iter {:04d} | Time {:.4f}({:.4f}) | Loss {:.6f}({:.6f}) |"
282 | " NFE Forward {:.0f}({:.1f})"
283 | " | NFE Backward {:.0f}({:.1f})".format(
284 | itr,
285 | time_meter.val,
286 | time_meter.avg,
287 | loss_meter.val,
288 | loss_meter.avg,
289 | nfef_meter.val,
290 | nfef_meter.avg,
291 | nfeb_meter.val,
292 | nfeb_meter.avg,
293 | )
294 | )
295 | if len(regularization_coeffs) > 0:
296 | log_message = append_regularization_to_log(
297 | log_message, regularization_fns, reg_states
298 | )
299 | logger.info(log_message)
300 |
301 | if itr % args.val_freq == 0 or itr == args.niters:
302 | with torch.no_grad():
303 | train_eval(
304 | device, args, model, growth_model, itr, best_loss, logger, full_data
305 | )
306 |
307 | if itr % args.viz_freq == 0:
308 | if args.data.get_shape()[0] > 2:
309 | logger.warning("Skipping vis as data dimension is >2")
310 | else:
311 | with torch.no_grad():
312 | visualize(device, args, model, itr)
313 | if itr % args.save_freq == 0:
314 | chkpt = {
315 | "state_dict": model.state_dict(),
316 | }
317 | if args.use_growth:
318 | chkpt.update({"growth_state_dict": growth_model.state_dict()})
319 | utils.save_checkpoint(
320 | chkpt,
321 | args.save,
322 | epoch=itr,
323 | )
324 | end = time.time()
325 | logger.info("Training has finished.")
326 |
327 |
328 | def train_eval(device, args, model, growth_model, itr, best_loss, logger, full_data):
329 | model.eval()
330 | test_loss = compute_loss(device, args, model, growth_model, logger, full_data)
331 | test_nfe = count_nfe(model)
332 | log_message = "[TEST] Iter {:04d} | Test Loss {:.6f} |" " NFE {:.0f}".format(
333 | itr, test_loss, test_nfe
334 | )
335 | logger.info(log_message)
336 | utils.makedirs(args.save)
337 | with open(os.path.join(args.save, "train_eval.csv"), "a") as f:
338 | import csv
339 |
340 | writer = csv.writer(f)
341 | writer.writerow((itr, test_loss))
342 |
343 | if test_loss.item() < best_loss:
344 | best_loss = test_loss.item()
345 | chkpt = {
346 | "state_dict": model.state_dict(),
347 | }
348 | if args.use_growth:
349 | chkpt.update({"growth_state_dict": growth_model.state_dict()})
350 | torch.save(
351 | chkpt,
352 | os.path.join(args.save, "checkpt.pth"),
353 | )
354 |
355 |
356 | def visualize(device, args, model, itr):
357 | model.eval()
358 | for i, tp in enumerate(args.timepoints):
359 | idx = args.data.sample_index(args.viz_batch_size, tp)
360 | p_samples = args.data.get_data()[idx]
361 | sample_fn, density_fn = get_transforms(
362 | device, args, model, args.int_tps[: i + 1]
363 | )
364 | plt.figure(figsize=(9, 3))
365 | visualize_transform(
366 | p_samples,
367 | args.data.base_sample(),
368 | args.data.base_density(),
369 | transform=sample_fn,
370 | inverse_transform=density_fn,
371 | samples=True,
372 | npts=100,
373 | device=device,
374 | )
375 | fig_filename = os.path.join(
376 | args.save, "figs", "{:04d}_{:01d}.jpg".format(itr, i)
377 | )
378 | utils.makedirs(os.path.dirname(fig_filename))
379 | plt.savefig(fig_filename)
380 | plt.close()
381 |
382 |
383 | def plot_output(device, args, model):
384 | save_traj_dir = os.path.join(args.save, "trajectory")
385 | # logger.info('Plotting trajectory to {}'.format(save_traj_dir))
386 | data_samples = args.data.get_data()[args.data.sample_index(2000, 0)]
387 | np.random.seed(42)
388 | start_points = args.data.base_sample()(1000, 2)
389 | # idx = args.data.sample_index(50, 0)
390 | # start_points = args.data.get_data()[idx]
391 | # start_points = torch.from_numpy(start_points).type(torch.float32)
392 | save_vectors(
393 | args.data.base_density(),
394 | model,
395 | start_points,
396 | args.data.get_data(),
397 | args.data.get_times(),
398 | args.save,
399 | skip_first=(not args.data.known_base_density()),
400 | device=device,
401 | end_times=args.int_tps,
402 | ntimes=100,
403 | )
404 |
405 | save_trajectory(
406 | args.data.base_density(),
407 | args.data.base_sample(),
408 | model,
409 | data_samples,
410 | save_traj_dir,
411 | device=device,
412 | end_times=args.int_tps,
413 | ntimes=25,
414 | )
415 |
416 | density_dir = os.path.join(args.save, "density2")
417 | save_trajectory_density(
418 | args.data.base_density(),
419 | model,
420 | data_samples,
421 | density_dir,
422 | device=device,
423 | end_times=args.int_tps,
424 | ntimes=25,
425 | memory=0.1,
426 | )
427 |
428 | if args.save_movie:
429 | trajectory_to_video(save_traj_dir)
430 | trajectory_to_video(density_dir)
431 |
432 |
433 | def main(args):
434 | # logger
435 | print(args.no_display_loss)
436 | utils.makedirs(args.save)
437 | logger = utils.get_logger(
438 | logpath=os.path.join(args.save, "logs"),
439 | filepath=os.path.abspath(__file__),
440 |         displaying=not args.no_display_loss,
441 | )
442 |
443 | if args.layer_type == "blend":
444 | logger.info("!! Setting time_scale from None to 1.0 for Blend layers.")
445 | args.time_scale = 1.0
446 |
447 | logger.info(args)
448 |
449 | device = torch.device(
450 | "cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu"
451 | )
452 | if args.use_cpu:
453 | device = torch.device("cpu")
454 |
455 | args.data = dataset.SCData.factory(args.dataset, args)
456 |
457 | args.timepoints = args.data.get_unique_times()
458 | # Use maximum timepoint to establish integration_times
459 | # as some timepoints may be left out for validation etc.
460 | args.int_tps = (np.arange(max(args.timepoints) + 1) + 1.0) * args.time_scale
461 |
462 | regularization_fns, regularization_coeffs = create_regularization_fns(args)
463 | model = build_model_tabular(args, args.data.get_shape()[0], regularization_fns).to(
464 | device
465 | )
466 | growth_model = None
467 | if args.use_growth:
468 | if args.leaveout_timepoint == -1:
469 | growth_model_path = "../data/externel/growth_model_v2.ckpt"
470 | elif args.leaveout_timepoint in [1, 2, 3]:
471 | assert args.max_dim == 5
472 | growth_model_path = "../data/growth/model_%d" % args.leaveout_timepoint
473 | else:
474 |             raise ValueError("Cannot use growth with this timepoint")
475 |
476 | growth_model = torch.load(growth_model_path, map_location=device)
477 | if args.spectral_norm:
478 | add_spectral_norm(model)
479 | set_cnf_options(args, model)
480 |
481 | if args.test:
482 | state_dict = torch.load(args.save + "/checkpt.pth", map_location=device)
483 | model.load_state_dict(state_dict["state_dict"])
484 | # if "growth_state_dict" not in state_dict:
485 | # print("error growth model note in save")
486 | # growth_model = None
487 | # else:
488 | # checkpt = torch.load(args.save + "/checkpt.pth", map_location=device)
489 | # growth_model.load_state_dict(checkpt["growth_state_dict"])
490 | # TODO can we load the arguments from the save?
491 | # eval_utils.generate_samples(
492 | # device, args, model, growth_model, timepoint=args.leaveout_timepoint
493 | # )
494 | # with torch.no_grad():
495 | # evaluate(device, args, model, growth_model)
496 | # exit()
497 | else:
498 | logger.info(model)
499 | n_param = count_parameters(model)
500 | logger.info("Number of trainable parameters: {}".format(n_param))
501 |
502 | train(
503 | device,
504 | args,
505 | model,
506 | growth_model,
507 | regularization_coeffs,
508 | regularization_fns,
509 | logger,
510 | )
511 |
512 | if args.data.data.shape[1] == 2:
513 | plot_output(device, args, model)
514 |
515 |
516 | if __name__ == "__main__":
517 |
518 | args = parser.parse_args()
519 | main(args)
520 |
--------------------------------------------------------------------------------
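main.py assigns each observed timepoint t the integration window [int_tps[t] - time_scale, int_tps[t]], deriving int_tps from the maximum timepoint so that timepoints left out for validation still receive a window. A quick check of that arithmetic, with illustrative values:

    import numpy as np

    time_scale = 0.5
    timepoints = np.array([0, 1, 2, 3])  # e.g. args.data.get_unique_times()
    int_tps = (np.arange(max(timepoints) + 1) + 1.0) * time_scale
    print(int_tps)                       # [0.5 1.  1.5 2. ]
    # timepoint t integrates over [int_tps[t] - time_scale, int_tps[t]]:
    # consecutive, non-overlapping windows of width time_scale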
/TrajectoryNet/optimal_transport/MMD.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 |
4 |
5 | import torch
6 |
7 | min_var_est = 1e-8
8 |
9 |
10 | # Consider linear time MMD with a linear kernel:
11 | # K(f(x), f(y)) = f(x)^Tf(y)
12 | # h(z_i, z_j) = k(x_i, x_j) + k(y_i, y_j) - k(x_i, y_j) - k(x_j, y_i)
13 | # = [f(x_i) - f(y_i)]^T[f(x_j) - f(y_j)]
14 | #
15 | # f_of_X: batch_size * k
16 | # f_of_Y: batch_size * k
17 | def linear_mmd2(f_of_X, f_of_Y):
18 |     # linear-time MMD estimator: average h(z_i, z_{i+1}) over adjacent sample pairs
19 | delta = f_of_X - f_of_Y
20 | loss = torch.mean((delta[:-1] * delta[1:]).sum(1))
21 | return loss
22 |
23 |
24 | # Consider linear time MMD with a polynomial kernel:
25 | # K(f(x), f(y)) = (alpha*f(x)^Tf(y) + c)^d
26 | # f_of_X: batch_size * k
27 | # f_of_Y: batch_size * k
28 | def poly_mmd2(f_of_X, f_of_Y, d=2, alpha=1.0, c=2.0):
29 | K_XX = alpha * (f_of_X[:-1] * f_of_X[1:]).sum(1) + c
30 | K_XX_mean = torch.mean(K_XX.pow(d))
31 |
32 | K_YY = alpha * (f_of_Y[:-1] * f_of_Y[1:]).sum(1) + c
33 | K_YY_mean = torch.mean(K_YY.pow(d))
34 |
35 | K_XY = alpha * (f_of_X[:-1] * f_of_Y[1:]).sum(1) + c
36 | K_XY_mean = torch.mean(K_XY.pow(d))
37 |
38 | K_YX = alpha * (f_of_Y[:-1] * f_of_X[1:]).sum(1) + c
39 | K_YX_mean = torch.mean(K_YX.pow(d))
40 |
41 | return K_XX_mean + K_YY_mean - K_XY_mean - K_YX_mean
42 |
43 |
44 | def _mix_rbf_kernel(X, Y, sigma_list):
45 | assert X.size(0) == Y.size(0)
46 | m = X.size(0)
47 |
48 | Z = torch.cat((X, Y), 0)
49 | ZZT = torch.mm(Z, Z.t())
50 | diag_ZZT = torch.diag(ZZT).unsqueeze(1)
51 | Z_norm_sqr = diag_ZZT.expand_as(ZZT)
52 | exponent = Z_norm_sqr - 2 * ZZT + Z_norm_sqr.t()
53 |
54 | K = 0.0
55 | for sigma in sigma_list:
56 | gamma = 1.0 / (2 * sigma**2)
57 | K += torch.exp(-gamma * exponent)
58 |
59 | return K[:m, :m], K[:m, m:], K[m:, m:], len(sigma_list)
60 |
61 |
62 | def mix_rbf_mmd2(X, Y, sigma_list, biased=True):
63 | K_XX, K_XY, K_YY, d = _mix_rbf_kernel(X, Y, sigma_list)
64 | # return _mmd2(K_XX, K_XY, K_YY, const_diagonal=d, biased=biased)
65 | return _mmd2(K_XX, K_XY, K_YY, const_diagonal=False, biased=biased)
66 |
67 |
68 | def mix_rbf_mmd2_and_ratio(X, Y, sigma_list, biased=True):
69 | K_XX, K_XY, K_YY, d = _mix_rbf_kernel(X, Y, sigma_list)
70 | # return _mmd2_and_ratio(K_XX, K_XY, K_YY, const_diagonal=d, biased=biased)
71 | return _mmd2_and_ratio(K_XX, K_XY, K_YY, const_diagonal=False, biased=biased)
72 |
73 |
74 | ################################################################################
75 | # Helper functions to compute variances based on kernel matrices
76 | ################################################################################
77 |
78 |
79 | def _mmd2(K_XX, K_XY, K_YY, const_diagonal=False, biased=False):
80 | m = K_XX.size(0) # assume X, Y are same shape
81 |
82 | # Get the various sums of kernels that we'll use
83 | # Kts drop the diagonal, but we don't need to compute them explicitly
84 | if const_diagonal is not False:
85 | diag_X = diag_Y = const_diagonal
86 | sum_diag_X = sum_diag_Y = m * const_diagonal
87 | else:
88 | diag_X = torch.diag(K_XX) # (m,)
89 | diag_Y = torch.diag(K_YY) # (m,)
90 | sum_diag_X = torch.sum(diag_X)
91 | sum_diag_Y = torch.sum(diag_Y)
92 |
93 | Kt_XX_sums = K_XX.sum(dim=1) - diag_X # \tilde{K}_XX * e = K_XX * e - diag_X
94 | Kt_YY_sums = K_YY.sum(dim=1) - diag_Y # \tilde{K}_YY * e = K_YY * e - diag_Y
95 | K_XY_sums_0 = K_XY.sum(dim=0) # K_{XY}^T * e
96 |
97 | Kt_XX_sum = Kt_XX_sums.sum() # e^T * \tilde{K}_XX * e
98 | Kt_YY_sum = Kt_YY_sums.sum() # e^T * \tilde{K}_YY * e
99 | K_XY_sum = K_XY_sums_0.sum() # e^T * K_{XY} * e
100 |
101 | if biased:
102 | mmd2 = (
103 | (Kt_XX_sum + sum_diag_X) / (m * m)
104 | + (Kt_YY_sum + sum_diag_Y) / (m * m)
105 | - 2.0 * K_XY_sum / (m * m)
106 | )
107 | else:
108 | mmd2 = (
109 | Kt_XX_sum / (m * (m - 1))
110 | + Kt_YY_sum / (m * (m - 1))
111 | - 2.0 * K_XY_sum / (m * m)
112 | )
113 |
114 | return mmd2
115 |
116 |
117 | def _mmd2_and_ratio(K_XX, K_XY, K_YY, const_diagonal=False, biased=False):
118 | mmd2, var_est = _mmd2_and_variance(
119 | K_XX, K_XY, K_YY, const_diagonal=const_diagonal, biased=biased
120 | )
121 | loss = mmd2 / torch.sqrt(torch.clamp(var_est, min=min_var_est))
122 | return loss, mmd2, var_est
123 |
124 |
125 | def _mmd2_and_variance(K_XX, K_XY, K_YY, const_diagonal=False, biased=False):
126 | m = K_XX.size(0) # assume X, Y are same shape
127 |
128 | # Get the various sums of kernels that we'll use
129 | # Kts drop the diagonal, but we don't need to compute them explicitly
130 | if const_diagonal is not False:
131 | diag_X = diag_Y = const_diagonal
132 | sum_diag_X = sum_diag_Y = m * const_diagonal
133 | sum_diag2_X = sum_diag2_Y = m * const_diagonal**2
134 | else:
135 | diag_X = torch.diag(K_XX) # (m,)
136 | diag_Y = torch.diag(K_YY) # (m,)
137 | sum_diag_X = torch.sum(diag_X)
138 | sum_diag_Y = torch.sum(diag_Y)
139 | sum_diag2_X = diag_X.dot(diag_X)
140 | sum_diag2_Y = diag_Y.dot(diag_Y)
141 |
142 | Kt_XX_sums = K_XX.sum(dim=1) - diag_X # \tilde{K}_XX * e = K_XX * e - diag_X
143 | Kt_YY_sums = K_YY.sum(dim=1) - diag_Y # \tilde{K}_YY * e = K_YY * e - diag_Y
144 | K_XY_sums_0 = K_XY.sum(dim=0) # K_{XY}^T * e
145 | K_XY_sums_1 = K_XY.sum(dim=1) # K_{XY} * e
146 |
147 | Kt_XX_sum = Kt_XX_sums.sum() # e^T * \tilde{K}_XX * e
148 | Kt_YY_sum = Kt_YY_sums.sum() # e^T * \tilde{K}_YY * e
149 | K_XY_sum = K_XY_sums_0.sum() # e^T * K_{XY} * e
150 |
151 | Kt_XX_2_sum = (K_XX**2).sum() - sum_diag2_X # \| \tilde{K}_XX \|_F^2
152 | Kt_YY_2_sum = (K_YY**2).sum() - sum_diag2_Y # \| \tilde{K}_YY \|_F^2
153 | K_XY_2_sum = (K_XY**2).sum() # \| K_{XY} \|_F^2
154 |
155 | if biased:
156 | mmd2 = (
157 | (Kt_XX_sum + sum_diag_X) / (m * m)
158 | + (Kt_YY_sum + sum_diag_Y) / (m * m)
159 | - 2.0 * K_XY_sum / (m * m)
160 | )
161 | else:
162 | mmd2 = (
163 | Kt_XX_sum / (m * (m - 1))
164 | + Kt_YY_sum / (m * (m - 1))
165 | - 2.0 * K_XY_sum / (m * m)
166 | )
167 |
168 | var_est = (
169 | 2.0
170 | / (m**2 * (m - 1.0) ** 2)
171 | * (
172 | 2 * Kt_XX_sums.dot(Kt_XX_sums)
173 | - Kt_XX_2_sum
174 | + 2 * Kt_YY_sums.dot(Kt_YY_sums)
175 | - Kt_YY_2_sum
176 | )
177 | - (4.0 * m - 6.0)
178 | / (m**3 * (m - 1.0) ** 3)
179 | * (Kt_XX_sum**2 + Kt_YY_sum**2)
180 | + 4.0
181 | * (m - 2.0)
182 | / (m**3 * (m - 1.0) ** 2)
183 | * (K_XY_sums_1.dot(K_XY_sums_1) + K_XY_sums_0.dot(K_XY_sums_0))
184 | - 4.0 * (m - 3.0) / (m**3 * (m - 1.0) ** 2) * (K_XY_2_sum)
185 | - (8 * m - 12) / (m**5 * (m - 1)) * K_XY_sum**2
186 | + 8.0
187 | / (m**3 * (m - 1.0))
188 | * (
189 | 1.0 / m * (Kt_XX_sum + Kt_YY_sum) * K_XY_sum
190 | - Kt_XX_sums.dot(K_XY_sums_1)
191 | - Kt_YY_sums.dot(K_XY_sums_0)
192 | )
193 | )
194 | return mmd2, var_est
195 |
--------------------------------------------------------------------------------
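A minimal usage sketch for the mixture-of-RBF MMD above; the bandwidths and sample counts are illustrative, and X and Y must contain the same number of samples (asserted in _mix_rbf_kernel):

    import torch

    from TrajectoryNet.optimal_transport.MMD import mix_rbf_mmd2

    torch.manual_seed(0)
    X = torch.randn(256, 2)        # samples from P
    Y = torch.randn(256, 2) + 1.0  # samples from a shifted Q
    mmd2 = mix_rbf_mmd2(X, Y, sigma_list=[0.5, 1.0, 2.0], biased=True)
    print(mmd2.item())             # positive, and grows with the shift between P and Q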
/TrajectoryNet/optimal_transport/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KrishnaswamyLab/TrajectoryNet/810c89b081f95405bc0ad42bfb3cb41038e13700/TrajectoryNet/optimal_transport/__init__.py
--------------------------------------------------------------------------------
/TrajectoryNet/optimal_transport/emd.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import ot as pot # Python Optimal Transport package
3 | import scipy.sparse
4 | from sklearn.metrics.pairwise import pairwise_distances
5 |
6 |
7 | def earth_mover_distance(
8 | p,
9 | q,
10 | eigenvals=None,
11 | weights1=None,
12 | weights2=None,
13 | return_matrix=False,
14 | metric="sqeuclidean",
15 | ):
16 | """
17 | Returns the earth mover's distance between two point clouds
18 | Parameters
19 | ----------
20 |     p : 2-D array
21 |         First point cloud
22 |     q : 2-D array
23 |         Second point cloud
24 | Returns
25 | -------
26 | distance : float
27 | The distance between the two point clouds
28 | """
29 | p = p.toarray() if scipy.sparse.isspmatrix(p) else p
30 | q = q.toarray() if scipy.sparse.isspmatrix(q) else q
31 | if eigenvals is not None:
32 | p = p.dot(eigenvals)
33 | q = q.dot(eigenvals)
34 | if weights1 is None:
35 | p_weights = np.ones(len(p)) / len(p)
36 | else:
37 | weights1 = weights1.astype("float64")
38 | p_weights = weights1 / weights1.sum()
39 |
40 | if weights2 is None:
41 | q_weights = np.ones(len(q)) / len(q)
42 | else:
43 | weights2 = weights2.astype("float64")
44 | q_weights = weights2 / weights2.sum()
45 |
46 | pairwise_dist = np.ascontiguousarray(
47 | pairwise_distances(p, Y=q, metric=metric, n_jobs=-1)
48 | )
49 |
50 | result = pot.emd2(
51 |         p_weights, q_weights, pairwise_dist, numItermax=int(1e7), return_matrix=return_matrix
52 | )
53 | if return_matrix:
54 | square_emd, log_dict = result
55 | return np.sqrt(square_emd), log_dict
56 | else:
57 | return np.sqrt(result)
58 |
59 |
60 | def interpolate_with_ot(p0, p1, tmap, interp_frac, size):
61 | """
62 | Interpolate between p0 and p1 at fraction t_interpolate knowing a transport map from p0 to p1
63 | Parameters
64 | ----------
65 | p0 : 2-D array
66 | The genes of each cell in the source population
67 | p1 : 2-D array
68 | The genes of each cell in the destination population
69 | tmap : 2-D array
70 | A transport map from p0 to p1
71 |     interp_frac : float
72 |         The fraction at which to interpolate
73 | size : int
74 | The number of cells in the interpolated population
75 | Returns
76 | -------
77 | p05 : 2-D array
78 | An interpolated population of 'size' cells
79 | """
80 | p0 = p0.toarray() if scipy.sparse.isspmatrix(p0) else p0
81 | p1 = p1.toarray() if scipy.sparse.isspmatrix(p1) else p1
82 | p0 = np.asarray(p0, dtype=np.float64)
83 | p1 = np.asarray(p1, dtype=np.float64)
84 | tmap = np.asarray(tmap, dtype=np.float64)
85 | if p0.shape[1] != p1.shape[1]:
86 | raise ValueError("Unable to interpolate. Number of genes do not match")
87 | if p0.shape[0] != tmap.shape[0] or p1.shape[0] != tmap.shape[1]:
88 | raise ValueError(
89 | "Unable to interpolate. Tmap size is {}, expected {}".format(
90 | tmap.shape, (len(p0), len(p1))
91 | )
92 | )
93 | I = len(p0)
94 | J = len(p1)
95 | # Assume growth is exponential and retrieve growth rate at t_interpolate
96 | # If all sums are the same then this does not change anything
97 | # This only matters if sum is not the same for all rows
98 | p = tmap / np.power(tmap.sum(axis=0), 1.0 - interp_frac)
99 | p = p.flatten(order="C")
100 | p = p / p.sum()
101 | choices = np.random.choice(I * J, p=p, size=size)
102 | return np.asarray(
103 | [p0[i // J] * (1 - interp_frac) + p1[i % J] * interp_frac for i in choices],
104 | dtype=np.float64,
105 | )
106 |
107 |
108 | def interpolate_per_point_with_ot(p0, p1, tmap, interp_frac):
109 | """
110 | Interpolate between p0 and p1 at fraction t_interpolate knowing a transport map from p0 to p1
111 | Parameters
112 | ----------
113 | p0 : 2-D array
114 | The genes of each cell in the source population
115 | p1 : 2-D array
116 | The genes of each cell in the destination population
117 | tmap : 2-D array
118 | A transport map from p0 to p1
119 |     interp_frac : float
120 |         The fraction at which to interpolate
121 | Returns
122 | -------
123 |     p05 : 2-D array
124 |         An interpolated population with one point per row of p0
125 | """
126 | assert len(p0) == len(p1)
127 | p0 = p0.toarray() if scipy.sparse.isspmatrix(p0) else p0
128 | p1 = p1.toarray() if scipy.sparse.isspmatrix(p1) else p1
129 | p0 = np.asarray(p0, dtype=np.float64)
130 | p1 = np.asarray(p1, dtype=np.float64)
131 | tmap = np.asarray(tmap, dtype=np.float64)
132 | if p0.shape[1] != p1.shape[1]:
133 | raise ValueError("Unable to interpolate. Number of genes do not match")
134 | if p0.shape[0] != tmap.shape[0] or p1.shape[0] != tmap.shape[1]:
135 | raise ValueError(
136 | "Unable to interpolate. Tmap size is {}, expected {}".format(
137 | tmap.shape, (len(p0), len(p1))
138 | )
139 | )
140 |
141 | I = len(p0)
142 | J = len(p1)
143 | # Assume growth is exponential and retrieve growth rate at t_interpolate
144 | # If all sums are the same then this does not change anything
145 | # This only matters if sum is not the same for all rows
146 | p = tmap / (tmap.sum(axis=0) / 1.0 - interp_frac)
147 | # p = tmap / np.power(tmap.sum(axis=0), 1.0 - interp_frac)
148 | # p = p.flatten(order="C")
149 | p = p / p.sum(axis=0)
150 | choices = np.array([np.random.choice(I, p=p[i]) for i in range(I)])
151 | return np.asarray(
152 | [
153 | p0[i] * (1 - interp_frac) + p1[j] * interp_frac
154 | for i, j in enumerate(choices)
155 | ],
156 | dtype=np.float64,
157 | )
158 |
--------------------------------------------------------------------------------
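A minimal usage sketch for earth_mover_distance; the point clouds are illustrative, and the POT package must be installed. With the default metric="sqeuclidean" the function returns the square root of the squared-Euclidean transport cost, i.e. a 2-Wasserstein distance:

    import numpy as np

    from TrajectoryNet.optimal_transport.emd import earth_mover_distance

    rng = np.random.default_rng(0)
    p = rng.normal(size=(200, 2))        # first point cloud
    q = rng.normal(size=(200, 2)) + 1.0  # second cloud, shifted by 1 per axis
    print(earth_mover_distance(p, q))    # approaches sqrt(2) as samples grow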
/TrajectoryNet/optimal_transport/gcs.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KrishnaswamyLab/TrajectoryNet/810c89b081f95405bc0ad42bfb3cb41038e13700/TrajectoryNet/optimal_transport/gcs.npy
--------------------------------------------------------------------------------
/TrajectoryNet/optimal_transport/growth/traj.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KrishnaswamyLab/TrajectoryNet/810c89b081f95405bc0ad42bfb3cb41038e13700/TrajectoryNet/optimal_transport/growth/traj.mp4
--------------------------------------------------------------------------------
/TrajectoryNet/optimal_transport/model:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KrishnaswamyLab/TrajectoryNet/810c89b081f95405bc0ad42bfb3cb41038e13700/TrajectoryNet/optimal_transport/model
--------------------------------------------------------------------------------
/TrajectoryNet/optimal_transport/plot_UOT_1D.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | ===============================
4 | 1D Unbalanced optimal transport
5 | ===============================
6 |
7 | This example illustrates the computation of Unbalanced Optimal transport
8 | using a Kullback-Leibler relaxation.
9 | """
10 |
11 | # Author: Hicham Janati
12 | #
13 | # License: MIT License
14 |
15 | import numpy as np
16 | import matplotlib.pylab as pl
17 | import ot
18 | import ot.plot
19 | from ot.datasets import make_1D_gauss as gauss
20 | from sinkhorn_knopp_unbalanced import sinkhorn_knopp_unbalanced  # local sibling module
21 | ##############################################################################
22 | # Generate data
23 | # -------------
24 |
25 |
26 | #%% parameters
27 |
28 | n = 2 # nb bins
29 |
30 | # bin positions
31 | x = np.arange(n, dtype=np.float64)
32 |
33 | # Gaussian distributions
34 | a = gauss(n, m=20, s=1) # m= mean, s= std
35 | b = gauss(n, m=60, s=1)
36 |
37 | a = [0.1, 0.9]
38 | b = [0.9, 0.1]
39 |
40 | # make distributions unbalanced
41 | # b *= 5.
42 |
43 | # loss matrix
44 | M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))
45 | M /= M.max()
46 |
47 |
48 | ##############################################################################
49 | # Plot distributions and loss matrix
50 | # ----------------------------------
51 |
52 | #%% plot the distributions
53 |
54 | pl.figure(1, figsize=(6.4, 3))
55 | pl.plot(x, a, "b", label="Source distribution")
56 | pl.plot(x, b, "r", label="Target distribution")
57 | pl.legend()
58 |
59 | # plot distributions and loss matrix
60 |
61 | pl.figure(2, figsize=(5, 5))
62 | ot.plot.plot1D_mat(a, b, M, "Cost matrix M")
63 |
64 |
65 | ##############################################################################
66 | # Solve Unbalanced Sinkhorn
67 | # -------------------------
68 |
69 |
70 | def get_transform_matrix(gamma, a, epsilon=1e-8):
71 | """Return matrix such that T @ a = b
72 | gamma : gamma @ 1 = a; gamma^T @ 1 = b
73 | """
74 | return (np.diag(1.0 / (a + epsilon)) @ gamma).T
75 |
76 |
77 | def get_growth_coeffs(gamma, a, epsilon=1e-8, normalize=False):
78 | T = get_transform_matrix(gamma, a, epsilon)
79 | unnormalized_coeffs = np.sum(T, axis=0)
80 | if not normalize:
81 | return unnormalized_coeffs
82 | return unnormalized_coeffs / np.sum(unnormalized_coeffs) * len(unnormalized_coeffs)
83 |
84 |
85 | # Sinkhorn
86 |
87 | epsilon = 0.1 # entropy parameter
88 | alpha = 1 # Unbalanced KL relaxation parameter
89 | beta = 10000
90 | # Gs = ot.emd(a, b, M)
91 | Gs = sinkhorn_knopp_unbalanced(a, b, M, epsilon, alpha, beta, verbose=True)
92 | print(Gs)
93 | print(a, b)
94 | print(get_growth_coeffs(Gs, np.array(a)))
95 | print(get_transform_matrix(Gs, np.array(a)) @ a)
96 | print(get_growth_coeffs(Gs, np.array(a)) * a)
97 | exit()
98 | print(Gs)
99 | print(Gs @ np.ones_like(a))
100 | print("bbbbbbbbbb")
101 | tt = get_transform_matrix(Gs, np.array(a))
102 | print(tt)
103 | print("col_sum(tt)", np.sum(tt, axis=0))
104 | print("tt @ a", tt @ a)
105 | print("bbbbbbbbbb")
106 | print(np.sum(Gs, axis=0), np.sum(Gs, axis=1))
107 | print("aaaaaaa == 1")
108 | alpha = 1 # Unbalanced KL relaxation parameter
109 | Gs = ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, verbose=True)
110 | print(Gs)
111 |
112 | pl.figure(4, figsize=(5, 5))
113 | ot.plot.plot1D_mat(a, b, Gs, "UOT matrix Sinkhorn")
114 |
115 | # pl.show()
116 |
--------------------------------------------------------------------------------
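The two helpers in this script encode the identity T @ a = b for the row-normalized transport plan (up to the epsilon guard), and the growth coefficients are the column sums of T. A standalone check with a hand-built 2x2 plan whose row sums equal a (values illustrative):

    import numpy as np

    a = np.array([0.1, 0.9])
    gamma = np.array([[0.1, 0.0],     # toy plan: row sums reproduce a,
                      [0.4, 0.5]])    # column sums define b
    b = gamma.sum(axis=0)
    T = (np.diag(1.0 / a) @ gamma).T  # get_transform_matrix without the epsilon term
    print(T @ a, b)                   # both [0.5 0.5]: T carries a onto b
    print(T.sum(axis=0))              # growth coefficients; [1. 1.] here, since the plan is balanced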
/TrajectoryNet/optimal_transport/sinkhorn_knopp_unbalanced.py:
--------------------------------------------------------------------------------
1 | """ Implements unbalanced Sinkhorn-Knopp optimization for unbalanced OT.
2 |
3 | This is from the Python Optimal Transport package, modified to take three
4 | regularization parameters instead of two. This is necessary to find growth
5 | rates of the source distribution that best match the target distribution,
6 | or vice versa. By setting reg_m_1 to something low and reg_m_2 to something
7 | large, we can compute an unbalanced optimal transport where all the scaling
8 | is done on the source distribution and none on the target distribution.
9 | """
10 | import numpy as np
11 | import warnings
12 |
13 |
14 | def sinkhorn_knopp_unbalanced(
15 | a,
16 | b,
17 | M,
18 | reg,
19 | reg_m_1,
20 | reg_m_2,
21 | numItermax=1000,
22 | stopThr=1e-6,
23 | verbose=False,
24 | log=False,
25 | **kwargs
26 | ):
27 | """
28 | Solve the entropic regularization unbalanced optimal transport problem
29 |
30 | The function solves the following optimization problem:
31 |
32 | .. math::
33 |         W = \min_\gamma \langle \gamma, M \rangle_F + reg \cdot \Omega(\gamma)
34 |             + reg\_m\_1 \, KL(\gamma 1, a) + reg\_m\_2 \, KL(\gamma^T 1, b)
35 |
36 | s.t.
37 | \gamma\geq 0
38 | where :
39 |
40 | - M is the (dim_a, dim_b) metric cost matrix
41 | - :math:`\Omega` is the entropic regularization term
42 | :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
43 | - a and b are source and target unbalanced distributions
44 | - KL is the Kullback-Leibler divergence
45 |
46 | The algorithm used for solving the problem is the generalized
47 |     Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 25]_
48 |
49 |
50 | Parameters
51 | ----------
52 | a : np.ndarray (dim_a,)
53 | Unnormalized histogram of dimension dim_a
54 | b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists)
55 | One or multiple unnormalized histograms of dimension dim_b
56 | If many, compute all the OT distances (a, b_i)
57 | M : np.ndarray (dim_a, dim_b)
58 | loss matrix
59 | reg : float
60 | Entropy regularization term > 0
61 |     reg_m_1, reg_m_2 : float
62 |         Marginal relaxation terms (> 0) for the source and target marginals
63 | numItermax : int, optional
64 | Max number of iterations
65 | stopThr : float, optional
66 |         Stop threshold on error (> 0)
67 | verbose : bool, optional
68 | Print information along iterations
69 | log : bool, optional
70 | record log if True
71 |
72 |
73 | Returns
74 | -------
75 | if n_hists == 1:
76 | gamma : (dim_a x dim_b) ndarray
77 | Optimal transportation matrix for the given parameters
78 | log : dict
79 | log dictionary returned only if `log` is `True`
80 | else:
81 | ot_distance : (n_hists,) ndarray
82 | the OT distance between `a` and each of the histograms `b_i`
83 | log : dict
84 | log dictionary returned only if `log` is `True`
85 | Examples
86 | --------
87 |
88 | >>> # uses this module's three-parameter variant
89 | >>> a = [.5, .5]
90 | >>> b = [.5, .5]
91 | >>> M = [[0., 1.], [1., 0.]]
92 | >>> sinkhorn_knopp_unbalanced(a, b, M, 1., 1., 1.)
93 | array([[0.51122823, 0.18807035],
94 | [0.18807035, 0.51122823]])
95 |
96 | References
97 | ----------
98 |
99 | .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
100 | Scaling algorithms for unbalanced transport problems. arXiv preprint
101 | arXiv:1607.05816.
102 |
103 | .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. :
104 | Learning with a Wasserstein Loss, Advances in Neural Information
105 | Processing Systems (NIPS) 2015
106 |
107 | See Also
108 | --------
109 | ot.lp.emd : Unregularized OT
110 | ot.optim.cg : General regularized OT
111 |
112 | """
113 |
114 | a = np.asarray(a, dtype=np.float64)
115 | b = np.asarray(b, dtype=np.float64)
116 | M = np.asarray(M, dtype=np.float64)
117 |
118 | dim_a, dim_b = M.shape
119 |
120 | if len(a) == 0:
121 | a = np.ones(dim_a, dtype=np.float64) / dim_a
122 | if len(b) == 0:
123 | b = np.ones(dim_b, dtype=np.float64) / dim_b
124 |
125 | if len(b.shape) > 1:
126 | n_hists = b.shape[1]
127 | else:
128 | n_hists = 0
129 |
130 | if log:
131 | log = {"err": []}
132 |
133 | # we assume that no distances are null except those of the diagonal of
134 | # distances
135 | if n_hists:
136 | u = np.ones((dim_a, 1)) / dim_a
137 | v = np.ones((dim_b, n_hists)) / dim_b
138 | a = a.reshape(dim_a, 1)
139 | else:
140 | u = np.ones(dim_a) / dim_a
141 | v = np.ones(dim_b) / dim_b
142 |
143 | # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
144 | K = np.empty(M.shape, dtype=M.dtype)
145 | np.divide(M, -reg, out=K)
146 | np.exp(K, out=K)
147 |
148 | cpt = 0
149 | err = 1.0
150 |
151 | while err > stopThr and cpt < numItermax:
152 | uprev = u
153 | vprev = v
154 |
155 | Kv = K.dot(v)
156 | u = (a / Kv) ** (reg_m_1 / (reg_m_1 + reg))
157 | Ktu = K.T.dot(u)
158 | v = (b / Ktu) ** (reg_m_2 / (reg_m_2 + reg))
159 |
160 | if (
161 | np.any(Ktu == 0.0)
162 | or np.any(np.isnan(u))
163 | or np.any(np.isnan(v))
164 | or np.any(np.isinf(u))
165 | or np.any(np.isinf(v))
166 | ):
167 | # we have reached the machine precision
168 | # come back to previous solution and quit loop
169 | warnings.warn("Numerical errors at iteration %s" % cpt)
170 | u = uprev
171 | v = vprev
172 | break
173 | if cpt % 10 == 0:
174 | # we can speed up the process by checking for the error only
175 | # every 10th iteration
176 | err_u = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(), 1.0)
177 | err_v = abs(v - vprev).max() / max(abs(v).max(), abs(vprev).max(), 1.0)
178 | err = 0.5 * (err_u + err_v)
179 | if log:
180 | log["err"].append(err)
181 | if verbose:
182 | if cpt % 200 == 0:
183 | print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19)
184 | print("{:5d}|{:8e}|".format(cpt, err))
185 | cpt += 1
186 |
187 | if log:
188 | log["logu"] = np.log(u + 1e-16)
189 | log["logv"] = np.log(v + 1e-16)
190 |
191 | if n_hists: # return only loss
192 | res = np.einsum("ik,ij,jk,ij->k", u, K, v, M)
193 | if log:
194 | return res, log
195 | else:
196 | return res
197 |
198 | else: # return OT matrix
199 |
200 | if log:
201 | return u[:, None] * K * v[None, :], log
202 | else:
203 | return u[:, None] * K * v[None, :]
204 |
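205 |
206 | if __name__ == "__main__":
207 |     # Hedged demo (not part of the original module): with a small source
208 |     # relaxation (reg_m_1) and a very large target relaxation (reg_m_2),
209 |     # the coupling's column sums are approximately pinned to b while its
210 |     # row sums are free to deviate from a, i.e. all rescaling happens on
211 |     # the source side.
212 |     rng = np.random.RandomState(0)
213 |     a = np.full(4, 0.25)
214 |     b = rng.dirichlet(np.ones(5))
215 |     M = rng.rand(4, 5)
216 |     gamma = sinkhorn_knopp_unbalanced(a, b, M, 0.1, 1.0, 1e4)
217 |     print("col sums vs b:", gamma.sum(axis=0), b)
218 |     print("row sums vs a:", gamma.sum(axis=1), a)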
--------------------------------------------------------------------------------
/TrajectoryNet/optimal_transport/train_growth.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | from sklearn.preprocessing import StandardScaler
4 | from scipy.spatial.distance import cdist
5 | import scprep
6 | import torch
7 |
8 | import atongtf.dataset as atd  # needed by load_data_full() below
9 |
10 | from sinkhorn_knopp_unbalanced import sinkhorn_knopp_unbalanced
11 |
12 |
13 | def load_data_full():
14 | data = atd.EB_Velocity_Dataset()
15 | labels = data.data["sample_labels"]
16 | scaler = StandardScaler()
17 | scaler.fit(data.emb)
18 | transformed = scaler.transform(data.emb)
19 | return transformed, labels, scaler
20 |
21 |
22 | def get_transform_matrix(gamma, a, epsilon=1e-8):
23 | """Return matrix such that T @ a = b
24 | gamma : gamma @ 1 = a; gamma^T @ 1 = b
25 | """
26 | return (np.diag(1.0 / (a + epsilon)) @ gamma).T
27 |
28 |
29 | def get_growth_coeffs(gamma, a, epsilon=1e-8, normalize=False):
30 | T = get_transform_matrix(gamma, a, epsilon)
31 | unnormalized_coeffs = np.sum(T, axis=0)
32 | if not normalize:
33 | return unnormalized_coeffs
34 | return unnormalized_coeffs / np.sum(unnormalized_coeffs) * len(unnormalized_coeffs)
35 |
36 |
37 | data, labels, _ = load_data_full()
38 |
39 | print(data.shape, labels.shape)
40 | exit()  # NOTE: debug short-circuit; nothing below runs as committed
41 |
42 | # Compute couplings
43 |
44 | timepoints = np.unique(labels)
45 |
46 | dfs = [data[labels == tp] for tp in timepoints]
47 |
48 |
49 | def get_all_growth_coeffs(alpha):
50 | gcs = []
51 | for i in range(len(dfs) - 1):
52 | a, b = dfs[i], dfs[i + 1]
53 | m, n = a.shape[0], b.shape[0]
54 | M = cdist(a, b)
55 | entropy_reg = 0.1
56 | reg_1, reg_2 = alpha, 10000
57 | gamma = sinkhorn_knopp_unbalanced(
58 | np.ones(m) / m, np.ones(n) / n, M, entropy_reg, reg_1, reg_2
59 | )
60 | gc = get_growth_coeffs(gamma, np.ones(m) / m)
61 | gcs.append(gc)
62 | return gcs
63 |
64 |
65 | gcs = np.load("gcs.npy")
66 | print(gcs)
67 |
68 |
69 | class GrowthNet(torch.nn.Module):
70 | def __init__(self):
71 | super().__init__()
72 |
73 | self.fc1 = torch.nn.Linear(3, 64)
74 | self.fc2 = torch.nn.Linear(64, 64)
75 | self.fc3 = torch.nn.Linear(64, 1)
76 |
77 | def forward(self, x):
78 | x = torch.nn.functional.leaky_relu(self.fc1(x))
79 | x = torch.nn.functional.leaky_relu(self.fc2(x))
80 | x = self.fc3(x)
81 | return x
82 |
83 |
84 | X = np.concatenate([data, labels[:, None]], axis=1)[labels != timepoints[-1]]
85 | Y = gcs[:, None]
86 |
87 | device = torch.device("cuda:" + str(1) if torch.cuda.is_available() else "cpu")
88 |
89 | model = GrowthNet().to(device)
90 | model.train()
91 | optimizer = torch.optim.Adam(model.parameters())
92 |
93 | """
94 | for it in range(100000):
95 | optimizer.zero_grad()
96 | batch_idx = np.random.randint(len(X), size=256)
97 | x = torch.from_numpy(X[batch_idx,:]).type(torch.float32).to(device)
98 | y = torch.from_numpy(Y[batch_idx,:]).type(torch.float32).to(device)
99 | negative_samples = np.concatenate([np.random.uniform(size=(256,2)) * 8 - 4,
100 | np.random.choice(timepoints, size=(256,1))], axis=1)
101 | negative_samples = torch.from_numpy(negative_samples).type(torch.float32).to(device)
102 | x = torch.cat([x, negative_samples])
103 | y = torch.cat([y, torch.ones_like(y)])
104 | pred = model(x)
105 | loss = torch.nn.MSELoss()
106 | output = loss(pred, y)
107 | output.backward()
108 | optimizer.step()
109 | if it % 100 == 0:
110 | print(it, output)
111 |
112 | torch.save(model, 'model')
113 |
114 | exit()
115 | """
116 |
117 | model = torch.load("model")
118 | model.eval()
119 | import matplotlib
120 |
121 | for i, tp in enumerate(np.linspace(0, 3, 100)):
122 | fig, axes = plt.subplots(
123 | 2, 1, gridspec_kw={"height_ratios": [7, 1]}, figsize=(8, 8)
124 | )
125 | ax = axes[0]
126 | npts = 200
127 | side = np.linspace(-4, 4, npts)
128 | xx, yy = np.meshgrid(side, side)
129 | xx = torch.from_numpy(xx).type(torch.float32).to(device)
130 | yy = torch.from_numpy(yy).type(torch.float32).to(device)
131 | z_grid = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1)], 1)
132 | data_in = torch.cat([z_grid, tp * torch.ones(z_grid.shape[0], 1).to(device)], 1)
133 | gr = model(data_in)
134 | gr = gr.reshape(npts, npts).to("cpu").detach().numpy()
135 | ax.pcolormesh(
136 | xx.cpu().detach(), yy.cpu().detach(), gr, cmap="RdBu_r", vmin=0, vmax=2
137 | )
138 | scprep.plot.scatter2d(data, c="Gray", ax=ax, alpha=0.1)
139 | ax.get_xaxis().set_ticks([])
140 | ax.get_yaxis().set_ticks([])
141 | ax.set_title("Growth Rate")
142 | # ax.set_title("%0.2f" % tp, fontsize=32)
143 | ax = axes[1]
144 | # Colorbar
145 | cb = matplotlib.colorbar.ColorbarBase(ax, cmap="Spectral", orientation="horizontal")
146 | cb.set_ticks(np.linspace(0, 1, 4))
147 | cb.set_ticklabels(np.arange(4))
148 | ax.axvline(tp / 3, c="k", linewidth=15)
149 | ax.set_title("Time")
150 | plt.savefig("growth/viz-%05d.jpg" % i)
151 | plt.close()
152 |
153 |
154 | def trajectory_to_video(savedir):
155 | import subprocess
156 | import os
157 |
158 | bashCommand = "ffmpeg -y -i {} {}".format(
159 | os.path.join(savedir, "viz-%05d.jpg"), os.path.join(savedir, "traj.mp4")
160 | )
161 | process = subprocess.Popen(bashCommand.split(), stdout=subprocess.PIPE)
162 | output, error = process.communicate()
163 |
164 |
165 | trajectory_to_video("growth")
166 | """
167 | fig, axes = plt.subplots(2,2, figsize=(10,10))
168 | axes = axes.flatten()
169 | for i, tp in enumerate(timepoints[:-1]):
170 | ax = axes[i]
171 | # Construct Grid,
172 | npts = 200
173 | side = np.linspace(-4, 4, npts)
174 | xx, yy = np.meshgrid(side, side)
175 | xx = torch.from_numpy(xx).type(torch.float32).to(device)
176 | yy = torch.from_numpy(yy).type(torch.float32).to(device)
177 | z_grid = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1)], 1)
178 | data_in = torch.cat([z_grid, tp * torch.ones(z_grid.shape[0],1).to(device)], 1)
179 | gr = model(data_in)
180 | gr = gr.reshape(npts, npts).to('cpu').detach().numpy()
181 | ax.pcolormesh(xx.cpu().detach(), yy.cpu().detach(), gr, cmap='RdBu_r')
182 | scprep.plot.scatter2d(data, c='Gray', ax=ax, alpha=0.1)
183 | ax.get_xaxis().set_ticks([])
184 | ax.get_yaxis().set_ticks([])
185 | ax.set_title("%0.2f" % tp, fontsize=32)
186 |
187 | plt.savefig('growth_function_snapshots.png')
188 | plt.close()
189 | """
190 | """
191 | gcs = get_all_growth_coeffs(2)
192 | gcs = np.concatenate(gcs)
193 | print(gcs.shape)
194 | np.save('gcs.npy', gcs)
195 | """
196 |
197 | """
198 | for alpha in [0.1, 0.2, 0.5, 1, 2, 5, 10, 20, 50, 100]:
199 | fig, axes = plt.subplots(2,2, figsize=(10,10))
200 | axes = axes.flatten()
201 | for i in range(len(dfs) - 1):
202 | a, b = dfs[i], dfs[i+1]
203 | m, n = a.shape[0], b.shape[0]
204 | M = cdist(a, b)
205 | print(a.shape, b.shape, M.shape)
206 | entropy_reg = 0.1
207 | reg_1, reg_2 = alpha, 10000
208 | gamma = sinkhorn_knopp_unbalanced(np.ones(m) / m, np.ones(n) / n,
209 | M, entropy_reg, reg_1, reg_2)
210 | gc = get_growth_coeffs(gamma, np.ones(m) / m)
211 | scprep.plot.scatter2d(data, c='Gray', ax=axes[i], alpha=0.1)
212 | scprep.plot.scatter2d(a, c=gc, cmap='RdBu_r', ax=axes[i], vmin=0, vmax=2)
213 | axes[i].set_title(i)
214 | plt.savefig('figs/alpha_%0.2f.png' % alpha)
215 | plt.close()
216 | """
217 |
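218 | # Hedged sanity check (not part of the original script): for any coupling
219 | # gamma, T = get_transform_matrix(gamma, a) satisfies T @ a == gamma^T @ 1
220 | # (the target marginal b), and get_growth_coeffs(gamma, a) equals the
221 | # per-point mass ratio (gamma @ 1) / a. For example:
222 | #
223 | #     gamma = np.array([[0.2, 0.1], [0.1, 0.6]])  # row sums differ from a
224 | #     a = np.array([0.5, 0.5])
225 | #     assert np.allclose(get_transform_matrix(gamma, a) @ a, gamma.sum(axis=0))
226 | #     assert np.allclose(get_growth_coeffs(gamma, a), gamma.sum(axis=1) / a)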
--------------------------------------------------------------------------------
/TrajectoryNet/parse.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from .lib.layers import odefunc
3 |
4 | SOLVERS = ["dopri5", "bdf", "rk4", "midpoint", "adams", "explicit_adams", "fixed_adams"]
5 |
6 | parser = argparse.ArgumentParser("Continuous Normalizing Flow")
7 | parser.add_argument("--test", action="store_true")
8 | parser.add_argument("--dataset", type=str, default="EB")
9 | parser.add_argument("--use_growth", action="store_true")
10 | parser.add_argument("--use_density", action="store_true")
11 | parser.add_argument("--leaveout_timepoint", type=int, default=-1)
12 | parser.add_argument(
13 | "--layer_type",
14 | type=str,
15 | default="concatsquash",
16 | choices=[
17 | "ignore",
18 | "concat",
19 | "concat_v2",
20 | "squash",
21 | "concatsquash",
22 | "concatcoord",
23 | "hyper",
24 | "blend",
25 | ],
26 | )
27 | parser.add_argument("--max_dim", type=int, default=10)
28 | parser.add_argument("--dims", type=str, default="64-64-64")
29 | parser.add_argument("--num_blocks", type=int, default=1, help="Number of stacked CNFs.")
30 | parser.add_argument("--time_scale", type=float, default=0.5)
31 | parser.add_argument("--train_T", type=eval, default=True)
32 | parser.add_argument(
33 | "--divergence_fn",
34 | type=str,
35 | default="brute_force",
36 | choices=["brute_force", "approximate"],
37 | )
38 | parser.add_argument(
39 | "--nonlinearity", type=str, default="tanh", choices=odefunc.NONLINEARITIES
40 | )
41 | parser.add_argument("--stochastic", action="store_true")
42 |
43 | parser.add_argument(
44 | "--alpha", type=float, default=0.0, help="loss weight parameter for growth model"
45 | )
46 | parser.add_argument("--solver", type=str, default="dopri5", choices=SOLVERS)
47 | parser.add_argument("--atol", type=float, default=1e-5)
48 | parser.add_argument("--rtol", type=float, default=1e-5)
49 | parser.add_argument(
50 | "--step_size", type=float, default=None, help="Optional fixed step size."
51 | )
52 |
53 | parser.add_argument("--test_solver", type=str, default=None, choices=SOLVERS + [None])
54 | parser.add_argument("--test_atol", type=float, default=None)
55 | parser.add_argument("--test_rtol", type=float, default=None)
56 |
57 | parser.add_argument("--residual", action="store_true")
58 | parser.add_argument("--rademacher", action="store_true")
59 | parser.add_argument("--spectral_norm", action="store_true")
60 | parser.add_argument("--batch_norm", action="store_true")
61 | parser.add_argument("--bn_lag", type=float, default=0)
62 |
63 | parser.add_argument("--niters", type=int, default=10000)
64 | parser.add_argument("--num_workers", type=int, default=8)
65 | parser.add_argument("--batch_size", type=int, default=1000)
66 | parser.add_argument("--test_batch_size", type=int, default=1000)
67 | parser.add_argument("--viz_batch_size", type=int, default=2000)
68 | parser.add_argument("--lr", type=float, default=1e-3)
69 | parser.add_argument("--weight_decay", type=float, default=1e-5)
70 |
71 | # Track quantities
72 | parser.add_argument("--l1int", type=float, default=None, help="int_t ||f||_1")
73 | parser.add_argument("--l2int", type=float, default=None, help="int_t ||f||_2")
74 | parser.add_argument("--sl2int", type=float, default=None, help="int_t ||f||_2^2")
75 | parser.add_argument(
76 | "--dl2int", type=float, default=None, help="int_t ||f^T df/dt||_2"
77 | ) # f df/dx?
78 | parser.add_argument(
79 | "--dtl2int", type=float, default=None, help="int_t ||f^T df/dx + df/dt||_2"
80 | )
81 | parser.add_argument("--JFrobint", type=float, default=None, help="int_t ||df/dx||_F")
82 | parser.add_argument(
83 | "--JdiagFrobint", type=float, default=None, help="int_t ||df_i/dx_i||_F"
84 | )
85 | parser.add_argument(
86 | "--JoffdiagFrobint", type=float, default=None, help="int_t ||df/dx - df_i/dx_i||_F"
87 | )
88 | parser.add_argument("--vecint", type=float, default=None, help="regularize direction")
89 | parser.add_argument(
90 | "--use_magnitude",
91 | action="store_true",
92 | help="regularize direction using MSE loss instead of cosine loss",
93 | )
94 |
95 | parser.add_argument(
96 | "--interp_reg", type=float, default=None, help="regularize interpolation"
97 | )
98 |
99 | parser.add_argument("--save", type=str, default="../results/tmp")
100 | parser.add_argument("--save_freq", type=int, default=1000)
101 | parser.add_argument("--viz_freq", type=int, default=100)
102 | parser.add_argument("--viz_freq_growth", type=int, default=100)
103 | parser.add_argument("--val_freq", type=int, default=100)
104 | parser.add_argument("--log_freq", type=int, default=10)
105 | parser.add_argument("--gpu", type=int, default=0)
106 | parser.add_argument("--use_cpu", action="store_true")
107 | parser.add_argument("--no_display_loss", action="store_false")
108 | parser.add_argument(
109 | "--top_k_reg", type=float, default=0.0, help="density following regularization"
110 | )
111 | parser.add_argument("--training_noise", type=float, default=0.1)
112 | parser.add_argument(
113 | "--embedding_name",
114 | type=str,
115 | default="pca",
116 | help="choose embedding name to perform TrajectoryNet on",
117 | )
118 | parser.add_argument("--whiten", action="store_true", help="Whiten data before running TrajectoryNet")
119 | parser.add_argument("--save_movie", action="store_false", help="Construct trajectory movie, requires ffmpeg to be installed")
--------------------------------------------------------------------------------
/TrajectoryNet/train_growth.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | from sklearn.preprocessing import StandardScaler
4 | from scipy.spatial.distance import cdist
5 | import scprep
6 | import torch
7 | import time
8 |
9 | from TrajectoryNet import dataset
10 | from .optimal_transport.sinkhorn_knopp_unbalanced import sinkhorn_knopp_unbalanced
11 |
12 |
13 | eb_data = dataset.EBData("pcs", max_dim=5)
14 |
15 |
16 | def get_transform_matrix(gamma, a, epsilon=1e-8):
17 | """Return matrix such that T @ a = b
18 | gamma : gamma @ 1 = a; gamma^T @ 1 = b
19 | """
20 | return (np.diag(1.0 / (a + epsilon)) @ gamma).T
21 |
22 |
23 | def get_growth_coeffs(gamma, a, epsilon=1e-8, normalize=False):
24 | T = get_transform_matrix(gamma, a, epsilon)
25 | unnormalized_coeffs = np.sum(T, axis=0)
26 | if not normalize:
27 | return unnormalized_coeffs
28 | return unnormalized_coeffs / np.sum(unnormalized_coeffs) * len(unnormalized_coeffs)
29 |
30 |
31 | data, labels = eb_data.data, eb_data.get_times()
32 |
33 | # Compute couplings
34 |
35 | timepoints = np.unique(labels)
36 | print("timepoints", timepoints)
37 |
38 | dfs = [data[labels == tp] for tp in timepoints]
39 | pairs = [(0, 1), (1, 2), (2, 3), (3, 4), (0, 2), (1, 3), (2, 4)]
40 |
41 |
42 | def get_all_growth_coeffs(alpha):
43 | gcs = []
44 | for a_ind, b_ind in pairs:
45 | start = time.time()
46 | print(a_ind, b_ind)
47 | a, b = dfs[a_ind], dfs[b_ind]
48 | m, n = a.shape[0], b.shape[0]
49 | M = cdist(a, b)
50 | entropy_reg = 0.1
51 | reg_1, reg_2 = alpha, 10000
52 | gamma = sinkhorn_knopp_unbalanced(
53 | np.ones(m) / m, np.ones(n) / n, M, entropy_reg, reg_1, reg_2
54 | )
55 | gc = get_growth_coeffs(gamma, np.ones(m) / m)
56 | gcs.append(gc)
57 | end = time.time()
58 | print("%s to %s took %0.2f sec" % (a_ind, b_ind, end - start))
59 | print(gcs)
60 | return gcs
61 |
62 |
63 | gcs = np.load("../data/growth/gcs.npy", allow_pickle=True)
64 |
65 |
66 | class GrowthNet(torch.nn.Module):
67 | def __init__(self):
68 | super().__init__()
69 |
70 | self.fc1 = torch.nn.Linear(6, 64)
71 | self.fc2 = torch.nn.Linear(64, 64)
72 | self.fc3 = torch.nn.Linear(64, 1)
73 |
74 | def forward(self, x):
75 | x = torch.nn.functional.leaky_relu(self.fc1(x))
76 | x = torch.nn.functional.leaky_relu(self.fc2(x))
77 | x = self.fc3(x)
78 | return x
79 |
80 |
81 | device = torch.device("cuda:" + str(0) if torch.cuda.is_available() else "cpu")
82 | device = torch.device("cpu")
83 |
84 |
85 | def train(leaveout_tp):
86 | # Data, timepoint
87 | X = np.concatenate([data, labels[:, None]], axis=1)[
88 | (labels != timepoints[-1]) & (labels != leaveout_tp)
89 | ]
90 | if leaveout_tp == 1:  # use couplings (0,2), (2,3), (3,4) from `pairs`
91 | Y = np.concatenate([gcs[4], gcs[2], gcs[3]])
92 | elif leaveout_tp == 2:  # use couplings (1,3), (0,1), (3,4)
93 | Y = np.concatenate([gcs[5], gcs[0], gcs[3]])
94 | elif leaveout_tp == 3:  # use couplings (2,4), (0,1), (1,2)
95 | Y = np.concatenate([gcs[6], gcs[0], gcs[1]])
96 | elif leaveout_tp == -1:  # no leave-out: consecutive couplings (0,1)..(3,4)
97 | Y = np.concatenate(gcs[:4])
98 | else:
99 | raise RuntimeError("Unknown leaveout_tp %d" % leaveout_tp)
100 | print(X.shape, Y.shape)
101 | assert X.shape[0] == Y.shape[0]
102 | Y = Y[:, np.newaxis]
103 |
104 | model = GrowthNet().to(device)
105 | model.train()
106 | optimizer = torch.optim.Adam(model.parameters())
107 |
108 | for it in range(100000):
109 | optimizer.zero_grad()
110 | batch_idx = np.random.randint(len(X), size=256)
111 | x = torch.from_numpy(X[batch_idx, :]).type(torch.float32).to(device)
112 | y = torch.from_numpy(Y[batch_idx, :]).type(torch.float32).to(device)
113 | negative_samples = np.concatenate(
114 | [
115 | np.random.uniform(size=(256, X.shape[1] - 1)) * 8 - 4,
116 | np.random.choice(timepoints, size=(256, 1)),
117 | ],
118 | axis=1,
119 | )
120 | negative_samples = (
121 | torch.from_numpy(negative_samples).type(torch.float32).to(device)
122 | )
123 | x = torch.cat([x, negative_samples])
124 | y = torch.cat([y, torch.ones_like(y)])
125 | pred = model(x)
126 | loss = torch.nn.MSELoss()
127 | output = loss(pred, y)
128 | output.backward()
129 | optimizer.step()
130 | if it % 100 == 0:
131 | print(it, output)
132 |
133 | torch.save(model, ("model_%d" % leaveout_tp))
134 |
135 |
136 | # train(1)
137 | # train(2)
138 | # train(3)
139 |
140 |
141 | def trajectory_to_video(savedir):
142 | import subprocess
143 | import os
144 |
145 | bashCommand = "ffmpeg -y -i {} {}".format(
146 | os.path.join(savedir, "viz-%05d.jpg"), os.path.join(savedir, "traj.mp4")
147 | )
148 | process = subprocess.Popen(bashCommand.split(), stdout=subprocess.PIPE)
149 | output, error = process.communicate()
150 |
151 |
152 | trajectory_to_video("../data/growth/viz/")
153 |
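154 | # Hedged usage sketch (not part of the original file): the commented-out
155 | # calls above suggest the intended workflow is to train one growth net per
156 | # leave-one-out split and then render the saved frames to video, e.g.
157 | #
158 | #     for tp in (1, 2, 3, -1):
159 | #         train(tp)  # saves "model_%d" % tp
160 | #     trajectory_to_video("../data/growth/viz/")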
--------------------------------------------------------------------------------
/TrajectoryNet/train_misc.py:
--------------------------------------------------------------------------------
1 | import six
2 | import math
3 |
4 | from .lib.layers.wrappers import cnf_regularization as reg_lib
5 | from .lib import spectral_norm, layers
6 | from .lib.layers.odefunc import divergence_bf, divergence_approx
7 |
8 |
9 | def standard_normal_logprob(z):
10 | logZ = -0.5 * math.log(2 * math.pi)
11 | return logZ - z.pow(2) / 2  # elementwise log-density of a standard normal
12 |
13 |
14 | def set_cnf_options(args, model):
15 | def _set(module):
16 | if isinstance(module, layers.CNF):
17 | # Set training settings
18 | module.solver = args.solver
19 | module.atol = args.atol
20 | module.rtol = args.rtol
21 | if args.step_size is not None:
22 | module.solver_options["step_size"] = args.step_size
23 |
24 | # If using fixed-grid adams, restrict order to not be too high.
25 | if args.solver in ["fixed_adams", "explicit_adams"]:
26 | module.solver_options["max_order"] = 4
27 |
28 | # Set the test settings
29 | module.test_solver = args.test_solver if args.test_solver else args.solver
30 | module.test_atol = args.test_atol if args.test_atol else args.atol
31 | module.test_rtol = args.test_rtol if args.test_rtol else args.rtol
32 |
33 | if isinstance(module, layers.ODEfunc):
34 | module.rademacher = args.rademacher
35 | module.residual = args.residual
36 |
37 | model.apply(_set)
38 |
39 |
40 | def override_divergence_fn(model, divergence_fn):
41 | def _set(module):
42 | if isinstance(module, layers.ODEfunc):
43 | if divergence_fn == "brute_force":
44 | module.divergence_fn = divergence_bf
45 | elif divergence_fn == "approximate":
46 | module.divergence_fn = divergence_approx
47 |
48 | model.apply(_set)
49 |
50 |
51 | def count_nfe(model):
52 | class AccNumEvals(object):
53 | def __init__(self):
54 | self.num_evals = 0
55 |
56 | def __call__(self, module):
57 | if isinstance(module, layers.CNF):
58 | self.num_evals += module.num_evals()
59 |
60 | accumulator = AccNumEvals()
61 | model.apply(accumulator)
62 | return accumulator.num_evals
63 |
64 |
65 | def count_parameters(model):
66 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
67 |
68 |
69 | def count_total_time(model):
70 | class Accumulator(object):
71 | def __init__(self):
72 | self.total_time = 0
73 |
74 | def __call__(self, module):
75 | if isinstance(module, layers.CNF):
76 | self.total_time = (
77 | self.total_time + module.sqrt_end_time * module.sqrt_end_time
78 | )
79 |
80 | accumulator = Accumulator()
81 | model.apply(accumulator)
82 | return accumulator.total_time
83 |
84 |
85 | def add_spectral_norm(model, logger=None):
86 | """Applies spectral norm to all modules within the scope of a CNF."""
87 |
88 | def apply_spectral_norm(module):
89 | if "weight" in module._parameters:
90 | if logger:
91 | logger.info("Adding spectral norm to {}".format(module))
92 | spectral_norm.inplace_spectral_norm(module, "weight")
93 |
94 | def find_cnf(module):
95 | if isinstance(module, layers.CNF):
96 | module.apply(apply_spectral_norm)
97 | else:
98 | for child in module.children():
99 | find_cnf(child)
100 |
101 | find_cnf(model)
102 |
103 |
104 | def spectral_norm_power_iteration(model, n_power_iterations=1):
105 | def recursive_power_iteration(module):
106 | if hasattr(module, spectral_norm.POWER_ITERATION_FN):
107 | getattr(module, spectral_norm.POWER_ITERATION_FN)(n_power_iterations)
108 |
109 | model.apply(recursive_power_iteration)
110 |
111 |
112 | REGULARIZATION_FNS = {
113 | "l1int": reg_lib.l1_regularzation_fn,
114 | "l2int": reg_lib.l2_regularzation_fn,
115 | "sl2int": reg_lib.squared_l2_regularization_fn,
116 | "dl2int": reg_lib.directional_l2_regularization_fn,
117 | "dtl2int": reg_lib.directional_l2_change_penalty_fn,
118 | "JFrobint": reg_lib.jacobian_frobenius_regularization_fn,
119 | "JdiagFrobint": reg_lib.jacobian_diag_frobenius_regularization_fn,
120 | "JoffdiagFrobint": reg_lib.jacobian_offdiag_frobenius_regularization_fn,
121 | }
122 |
123 | INV_REGULARIZATION_FNS = {v: k for k, v in six.iteritems(REGULARIZATION_FNS)}
124 |
125 |
126 | def append_regularization_to_log(log_message, regularization_fns, reg_states):
127 | for i, reg_fn in enumerate(regularization_fns):
128 | log_message = (
129 | log_message
130 | + " | "
131 | + INV_REGULARIZATION_FNS[reg_fn]
132 | + ": {:.8f}".format(reg_states[i].item())
133 | )
134 | return log_message
135 |
136 |
137 | def create_regularization_fns(args):
138 | regularization_fns = []
139 | regularization_coeffs = []
140 |
141 | for arg_key, reg_fn in six.iteritems(REGULARIZATION_FNS):
142 | if getattr(args, arg_key) is not None:
143 | regularization_fns.append(reg_fn)
144 | regularization_coeffs.append(getattr(args, arg_key))
145 |
146 | regularization_fns = tuple(regularization_fns)
147 | regularization_coeffs = tuple(regularization_coeffs)
148 | return regularization_fns, regularization_coeffs
149 |
150 |
151 | def get_regularization(model, regularization_coeffs):
152 | if len(regularization_coeffs) == 0:
153 | return None
154 |
155 | acc_reg_states = tuple([0.0] * len(regularization_coeffs))
156 | for module in model.modules():
157 | if isinstance(module, layers.CNF):
158 | acc_reg_states = tuple(
159 | acc + reg
160 | for acc, reg in zip(acc_reg_states, module.get_regularization_states())
161 | )
162 | return acc_reg_states
163 |
164 |
165 | def build_model_tabular(args, dims, regularization_fns=None):
166 |
167 | hidden_dims = tuple(map(int, args.dims.split("-")))
168 |
169 | def build_cnf():
170 | diffeq = layers.ODEnet(
171 | hidden_dims=hidden_dims,
172 | input_shape=(dims,),
173 | strides=None,
174 | conv=False,
175 | layer_type=args.layer_type,
176 | nonlinearity=args.nonlinearity,
177 | )
178 | odefunc = layers.ODEfunc(
179 | diffeq=diffeq,
180 | divergence_fn=args.divergence_fn,
181 | residual=args.residual,
182 | rademacher=args.rademacher,
183 | )
184 | cnf = layers.CNF(
185 | odefunc=odefunc,
186 | T=args.time_scale,
187 | train_T=args.train_T,
188 | regularization_fns=regularization_fns,
189 | solver=args.solver,
190 | )
191 | return cnf
192 |
193 | chain = [build_cnf() for _ in range(args.num_blocks)]
194 | if args.batch_norm:
195 | bn_layers = [
196 | layers.MovingBatchNorm1d(dims, bn_lag=args.bn_lag)
197 | for _ in range(args.num_blocks)
198 | ]
199 | bn_chain = [layers.MovingBatchNorm1d(dims, bn_lag=args.bn_lag)]
200 | for a, b in zip(chain, bn_layers):
201 | bn_chain.append(a)
202 | bn_chain.append(b)
203 | chain = bn_chain
204 | model = layers.SequentialFlow(chain)
205 |
206 | set_cnf_options(args, model)
207 |
208 | return model
209 |
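210 | # Hedged usage sketch (not part of the original module, assumes the argparse
211 | # namespace defined in TrajectoryNet.parse):
212 | #
213 | #     from TrajectoryNet.parse import parser
214 | #     args = parser.parse_args([])
215 | #     reg_fns, reg_coeffs = create_regularization_fns(args)
216 | #     model = build_model_tabular(args, dims=5, regularization_fns=reg_fns)
217 | #     print(count_parameters(model), "trainable parameters")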
--------------------------------------------------------------------------------
/TrajectoryNet/version.py:
--------------------------------------------------------------------------------
1 | "0.2.4"
2 |
--------------------------------------------------------------------------------
/data/eb_genes.txt:
--------------------------------------------------------------------------------
1 | TAL1
2 | CD34
3 | PECAM1
4 | HOXD1
5 | HOXB4
6 | HAND1
7 | GATA6
8 | GATA5
9 | TNNT2
10 | TBX18
11 | TBX15
12 | PDGFRA
13 | SIX2
14 | TBX5
15 | WT1
16 | MYC
17 | LEF1
18 | SOX10
19 | FOXD3
20 | PAX3
21 | SOX9
22 | HOXA2
23 | OLIG3
24 | ONECUT2
25 | KLF7
26 | ONECUT1
27 | MAP2
28 | ISL1
29 | DLX1
30 | HOXB1
31 | NR2F1
32 | LMX1A
33 | DMRT3
34 | OLIG1
35 | PAX6
36 | NPAS1
37 | SOX1
38 | NKX2-8
39 | EN2
40 | ZBTB16
41 | SOX17
42 | FOXA2
43 | EOMES
44 | T
45 | GATA4
46 | ASCL2
47 | CDX2
48 | ARID3A
49 | KLF5
50 | RFX6
51 | NKX2-1
52 | SOX15
53 | TP63
54 | GATA3
55 | SATB1
56 | CER1
57 | LHX5
58 | SIX3
59 | LHX2
60 | GLI3
61 | SIX6
62 | NES
63 | GBX2
64 | ZIC2
65 | NANOG
66 | MIXL1
67 | OTX2
68 | POU5F1
--------------------------------------------------------------------------------
/data/eb_velocity_v5.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KrishnaswamyLab/TrajectoryNet/810c89b081f95405bc0ad42bfb3cb41038e13700/data/eb_velocity_v5.npz
--------------------------------------------------------------------------------
/figures/EB-Trajectory.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KrishnaswamyLab/TrajectoryNet/810c89b081f95405bc0ad42bfb3cb41038e13700/figures/EB-Trajectory.gif
--------------------------------------------------------------------------------
/figures/eb_high_quality.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KrishnaswamyLab/TrajectoryNet/810c89b081f95405bc0ad42bfb3cb41038e13700/figures/eb_high_quality.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | argparse
2 | matplotlib>=3.2.1
3 | numpy>=1.18.4
4 | POT>=0.7.0
5 | scanpy
6 | scikit-learn>=0.23.1
7 | scipy>=1.4.1
8 | torch>=1.5.0
9 | torchdiffeq==0.0.1
10 |
--------------------------------------------------------------------------------
/results/fig8_results/backward_trajectories.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KrishnaswamyLab/TrajectoryNet/810c89b081f95405bc0ad42bfb3cb41038e13700/results/fig8_results/backward_trajectories.npy
--------------------------------------------------------------------------------
/results/fig8_results/checkpt.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KrishnaswamyLab/TrajectoryNet/810c89b081f95405bc0ad42bfb3cb41038e13700/results/fig8_results/checkpt.pth
--------------------------------------------------------------------------------
/results/fig8_results/train_eval.csv:
--------------------------------------------------------------------------------
1 | 100,tensor(4.9269)
2 | 200,tensor(4.3838)
3 | 300,tensor(4.0277)
4 | 400,tensor(3.7389)
5 | 500,tensor(3.5852)
6 | 600,tensor(3.5531)
7 | 700,tensor(3.4902)
8 | 800,tensor(3.4699)
9 | 900,tensor(3.4623)
10 | 1000,tensor(3.3754)
11 | 1100,tensor(3.3479)
12 | 1200,tensor(3.3542)
13 | 1300,tensor(3.3215)
14 | 1400,tensor(3.3450)
15 | 1500,tensor(3.3318)
16 | 1600,tensor(3.3250)
17 | 1700,tensor(3.3133)
18 | 1800,tensor(3.2920)
19 | 1900,tensor(3.2693)
20 | 2000,tensor(3.2431)
21 | 2100,tensor(3.3150)
22 | 2200,tensor(3.2185)
23 | 2300,tensor(3.2650)
24 | 2400,tensor(3.2519)
25 | 2500,tensor(3.2392)
26 | 2600,tensor(3.2440)
27 | 2700,tensor(3.2327)
28 | 2800,tensor(3.2221)
29 | 2900,tensor(3.2445)
30 | 3000,tensor(3.1992)
31 | 3100,tensor(3.2227)
32 | 3200,tensor(3.2277)
33 | 3300,tensor(3.2125)
34 | 3400,tensor(3.2534)
35 | 3500,tensor(3.2225)
36 | 3600,tensor(3.2168)
37 | 3700,tensor(3.2291)
38 | 3800,tensor(3.2067)
39 | 3900,tensor(3.1941)
40 | 4000,tensor(3.1983)
41 | 4100,tensor(3.1990)
42 | 4200,tensor(3.1563)
43 | 4300,tensor(3.1885)
44 | 4400,tensor(3.1989)
45 | 4500,tensor(3.1565)
46 | 4600,tensor(3.2174)
47 | 4700,tensor(3.2120)
48 | 4800,tensor(3.2327)
49 | 4900,tensor(3.1519)
50 | 5000,tensor(3.1598)
51 | 5100,tensor(3.1531)
52 | 5200,tensor(3.1642)
53 | 5300,tensor(3.1758)
54 | 5400,tensor(3.1578)
55 | 5500,tensor(3.1683)
56 | 5600,tensor(3.1276)
57 | 5700,tensor(3.1566)
58 | 5800,tensor(3.1818)
59 | 5900,tensor(3.1101)
60 | 6000,tensor(3.1899)
61 | 6100,tensor(3.1512)
62 | 6200,tensor(3.1105)
63 | 6300,tensor(3.0995)
64 | 6400,tensor(3.1727)
65 | 6500,tensor(3.2001)
66 | 6600,tensor(3.1633)
67 | 6700,tensor(3.1675)
68 | 6800,tensor(3.1565)
69 | 6900,tensor(3.1832)
70 | 7000,tensor(3.1601)
71 | 7100,tensor(3.0829)
72 | 7200,tensor(3.1549)
73 | 7300,tensor(3.1271)
74 | 7400,tensor(3.1491)
75 | 7500,tensor(3.1483)
76 | 7600,tensor(3.1139)
77 | 7700,tensor(3.1430)
78 | 7800,tensor(3.1419)
79 | 7900,tensor(3.1240)
80 | 8000,tensor(3.1638)
81 | 8100,tensor(3.1157)
82 | 8200,tensor(3.1502)
83 | 8300,tensor(3.1453)
84 | 8400,tensor(3.1747)
85 | 8500,tensor(3.1588)
86 | 8600,tensor(3.0674)
87 | 8700,tensor(3.1591)
88 | 8800,tensor(3.0796)
89 | 8900,tensor(3.0984)
90 | 9000,tensor(3.0799)
91 | 9100,tensor(3.0863)
92 | 9200,tensor(3.1053)
93 | 9300,tensor(3.0981)
94 | 9400,tensor(3.1166)
95 | 9500,tensor(3.0888)
96 | 9600,tensor(3.1042)
97 | 9700,tensor(3.0866)
98 | 9800,tensor(3.1327)
99 | 9900,tensor(3.1291)
100 | 10000,tensor(3.1143)
101 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [build_sphinx]
2 | all-files = 1
3 | source-dir = doc/source
4 | build-dir = doc/build
5 | warning-is-error = 0
6 |
7 | [flake8]
8 | ignore =
9 | # top-level module docstring
10 | D100, D104, W503,
11 | # space before : conflicts with black
12 | E203
13 | per-file-ignores =
14 | # imported but unused
15 | __init__.py: F401
16 | max-line-length = 88
17 | exclude =
18 | .git,
19 | __pycache__,
20 | build,
21 | dist,
22 | test,
23 | doc
24 |
25 | [isort]
26 | profile = black
27 | force_single_line = true
28 | force_alphabetical_sort = true
29 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | from setuptools import find_packages, setup
3 |
4 |
5 | install_requires = [
6 | "argparse",
7 | "matplotlib>=3.2.1",
8 | "numpy>=1.18.4",
9 | "POT>=0.7.0",
10 | "scanpy",
11 | "scikit-learn>=0.23.1",
12 | "scipy>=1.4.1",
13 | "torch>=1.5.0",
14 | "torchdiffeq==0.0.1",
15 | ]
16 |
17 | version_py = os.path.join(os.path.dirname(__file__), "TrajectoryNet", "version.py")
18 | version = open(version_py).read().strip().split("=")[-1].replace('"', "").strip()
19 |
20 | readme = open("README.rst").read()
21 |
22 | setup(
23 | name="TrajectoryNet",
24 | packages=find_packages(),
25 | install_requires=install_requires,
26 | version=version,
27 | description="A neural ode solution for imputing trajectories between pointclouds.",
28 | author="Alexander Tong",
29 | author_email="alexandertongdev@gmail.com",
30 | license="MIT",
31 | long_description=readme,
32 | url="https://github.com/KrishnaswamyLab/TrajectoryNet",
33 | )
34 |
--------------------------------------------------------------------------------