├── .gitignore
├── CODE_OF_CONDUCT.md
├── LICENSE
├── README.md
├── SECURITY.md
├── SUPPORT.md
├── advanced_example.py
├── environment.yaml
├── example.py
├── otdd
│   ├── __init__.py
│   ├── plotting.py
│   ├── pytorch
│   │   ├── __init__.py
│   │   ├── datasets.py
│   │   ├── distance.py
│   │   ├── flows.py
│   │   ├── functionals.py
│   │   ├── moments.py
│   │   ├── nets.py
│   │   ├── sqrtm.py
│   │   ├── utils.py
│   │   └── wasserstein.py
│   └── utils.py
├── requirements.txt
└── setup.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 | 
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 | 
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 | 
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 | 
54 | # Translations
55 | *.mo
56 | *.pot
57 | 
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 | 
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 | 
68 | # Scrapy stuff:
69 | .scrapy
70 | 
71 | # Sphinx documentation
72 | docs/_build/
73 | 
74 | # PyBuilder
75 | target/
76 | 
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 | 
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 | 
84 | # pyenv
85 | .python-version
86 | 
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 | 
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 | 
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 | 
101 | # SageMath parsed files
102 | *.sage.py
103 | 
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 | 
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 | 
117 | # Rope project settings
118 | .ropeproject
119 | 
120 | # mkdocs documentation
121 | /site
122 | 
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 | 
128 | # Pyre type checker
129 | .pyre/
130 | 
131 | 
132 | # Don't Save Data or Models
133 | data/
134 | models/
135 | 
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Microsoft Open Source Code of Conduct
2 | 
3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
4 | 
5 | Resources:
6 | 
7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
10 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) Microsoft Corporation.
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Optimal Transport Dataset Distance (OTDD)
2 | 
3 | Codebase accompanying the papers:
4 | * [Geometric Dataset Distances via Optimal Transport](https://papers.nips.cc/paper/2020/file/f52a7b2610fb4d3f74b4106fb80b233d-Paper.pdf).
5 | * [Dataset Dynamics via Gradient Flows in Probability Space](http://proceedings.mlr.press/v139/alvarez-melis21a/alvarez-melis21a.pdf).
6 | 
7 | See the papers for technical details, or the [MSR Blog Post](https://www.microsoft.com/en-us/research/blog/measuring-dataset-similarity-using-optimal-transport/) for a high-level introduction.
8 | 
9 | ## Getting Started
10 | 
11 | ### Installation
12 | 
13 | **Note**: It is highly recommended that the following be done inside a virtual environment.
14 | 
15 | 
16 | #### Via Conda (recommended)
17 | 
18 | If you use [ana|mini]conda, you can simply do:
19 | 
20 | ```
21 | conda env create -f environment.yaml
22 | conda activate otdd
23 | pip install .
24 | ```
25 | 
26 | (you might need to install pytorch separately if you need a custom install)
27 | 
28 | #### Via pip
29 | 
30 | First install the dependencies. Start by installing PyTorch with your desired configuration, following the instructions on the [PyTorch website](https://pytorch.org/get-started/locally/). Then do:
31 | ```
32 | pip install -r requirements.txt
33 | ```
34 | Finally, install this package:
35 | ```
36 | pip install .
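# Optional alternative (standard pip usage, not part of the original instructions):
#   pip install -e .   # editable install, convenient for development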
37 | ```
38 | 
39 | ## Usage Examples
40 | 
41 | A vanilla example for OTDD:
42 | 
43 | ```python
44 | from otdd.pytorch.datasets import load_torchvision_data
45 | from otdd.pytorch.distance import DatasetDistance
46 | 
47 | 
48 | # Load datasets
49 | loaders_src = load_torchvision_data('MNIST', valid_size=0, resize = 28, maxsize=2000)[0]
50 | loaders_tgt = load_torchvision_data('USPS', valid_size=0, resize = 28, maxsize=2000)[0]
51 | 
52 | # Instantiate distance
53 | dist = DatasetDistance(loaders_src['train'], loaders_tgt['train'],
54 |                        inner_ot_method = 'exact',
55 |                        debiased_loss = True,
56 |                        p = 2, entreg = 1e-1,
57 |                        device='cpu')
58 | 
59 | d = dist.distance(maxsamples = 1000)
60 | print(f'OTDD(src,tgt)={d}')
61 | 
62 | ```
63 | 
64 | ## Advanced Usage
65 | 
66 | ### Using a custom feature distance
67 | 
68 | By default, OTDD uses the (squared) Euclidean distance between features. To use a custom distance in domains where it makes sense to use one (e.g., images), one can pass a callable to OTDD using the `feature_cost` arg. Example:
69 | 
70 | ```python
71 | 
72 | import torch
73 | from torchvision.models import resnet18
74 | 
75 | from otdd.pytorch.datasets import load_torchvision_data
76 | from otdd.pytorch.distance import DatasetDistance, FeatureCost
77 | 
78 | # Load MNIST/CIFAR in 3 channels (needed by torchvision models)
79 | loaders_src = load_torchvision_data('CIFAR10', resize=28, maxsize=2000)[0]
80 | loaders_tgt = load_torchvision_data('MNIST', resize=28, to3channels=True, maxsize=2000)[0]
81 | 
82 | # Embed using a pretrained (+frozen) resnet
83 | embedder = resnet18(pretrained=True).eval()
84 | embedder.fc = torch.nn.Identity()
85 | for p in embedder.parameters():
86 |     p.requires_grad = False
87 | 
88 | # Here we use the same embedder for both datasets
89 | feature_cost = FeatureCost(src_embedding = embedder,
90 |                            src_dim = (3,28,28),
91 |                            tgt_embedding = embedder,
92 |                            tgt_dim = (3,28,28),
93 |                            p = 2,
94 |                            device='cpu')
95 | 
96 | dist = DatasetDistance(loaders_src['train'], loaders_tgt['train'],
97 |                        inner_ot_method = 'exact',
98 |                        debiased_loss = True,
99 |                        feature_cost = feature_cost,
100 |                        sqrt_method = 'spectral',
101 |                        sqrt_niters=10,
102 |                        precision='single',
103 |                        p = 2, entreg = 1e-1,
104 |                        device='cpu')
105 | 
106 | d = dist.distance(maxsamples = 10000)
107 | 
108 | ```
109 | 
110 | 
111 | ### Gradient Flows
112 | 
113 | ```python
114 | 
115 | import os
116 | import matplotlib
117 | %matplotlib inline # Comment out if not in a notebook
118 | from otdd.pytorch.datasets import load_torchvision_data
119 | from otdd.pytorch.flows import OTDD_Gradient_Flow
120 | from otdd.pytorch.flows import CallbackList, ImageGridCallback, TrajectoryDump
121 | 
122 | # Load datasets
123 | loaders_src = load_torchvision_data('MNIST', valid_size=0, resize = 28, maxsize=2000)[0]
124 | loaders_tgt = load_torchvision_data('USPS', valid_size=0, resize = 28, maxsize=2000)[0]
125 | 
126 | 
127 | outdir = os.path.join('out', 'flows')
128 | callbacks = CallbackList([
129 |     ImageGridCallback(display_freq=2, animate=False, save_path = outdir + '/grid'),
130 | ])
131 | 
132 | flow = OTDD_Gradient_Flow(loaders_src['train'], loaders_tgt['train'],
133 |                           ### Gradient Flow Args
134 |                           method = 'xonly-attached',
135 |                           use_torchoptim=True,
136 |                           optim='adam',
137 |                           steps=10,
138 |                           step_size=1,
139 |                           callback=callbacks,
140 |                           clustering_method='kmeans',
141 |                           ### OTDD Args
142 |                           online_stats=True,
143 |                           diagonal_cov = False,
144 |                           device='cpu'
145 |                           )
146 | d,out = flow.flow()
147 | 
148 | ```
149 | 
150 | 
151 | 
152 | ## Acknowledgements
153 | 
154 | This repo relies on the 
[geomloss](https://www.kernel-operations.io/geomloss/) and [POT](https://pythonot.github.io/) packages for internal EMD and Sinkhorn algorithm implementation. We are grateful to the authors and maintainers of those projects. 155 | 156 | ## Contributing 157 | 158 | This project welcomes contributions and suggestions. Most contributions require you to agree to a 159 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us 160 | the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com. 161 | 162 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide 163 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions 164 | provided by the bot. You will only need to do this once across all repos using our CLA. 165 | 166 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 167 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or 168 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. 169 | 170 | ## Trademarks 171 | 172 | This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft 173 | trademarks or logos is subject to and must follow 174 | [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general). 175 | Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. 176 | Any use of third-party trademarks or logos are subject to those third-party's policies. 177 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. 
Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
18 | 
19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
20 | 
21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
22 | * Full paths of source file(s) related to the manifestation of the issue
23 | * The location of the affected source code (tag/branch/commit or direct URL)
24 | * Any special configuration required to reproduce the issue
25 | * Step-by-step instructions to reproduce the issue
26 | * Proof-of-concept or exploit code (if possible)
27 | * Impact of the issue, including how an attacker might exploit the issue
28 | 
29 | This information will help us triage your report more quickly.
30 | 
31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs.
32 | 
33 | ## Preferred Languages
34 | 
35 | We prefer all communications to be in English.
36 | 
37 | ## Policy
38 | 
39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd).
40 | 
41 | 
--------------------------------------------------------------------------------
/SUPPORT.md:
--------------------------------------------------------------------------------
1 | # Support
2 | 
3 | ## How to file issues and get help
4 | 
5 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing
6 | issues before filing new issues to avoid duplicates. For new issues, file your bug or
7 | feature request as a new Issue.
8 | 
9 | For help and questions about using this project, please read the README.md, the associated [paper](https://arxiv.org/abs/2002.02923), or the [blog post](https://www.microsoft.com/en-us/research/blog/measuring-dataset-similarity-using-optimal-transport/). If after reviewing those you still have questions, feel free to file an issue here.
10 | 
11 | ## Microsoft Support Policy
12 | 
13 | Support for this project is limited to the resources listed above.
14 | 
--------------------------------------------------------------------------------
/advanced_example.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision.models import resnet18
3 | 
4 | from otdd.pytorch.datasets import load_torchvision_data
5 | from otdd.pytorch.distance import DatasetDistance, FeatureCost
6 | 
7 | # Load MNIST/CIFAR in 3 channels (needed by torchvision models)
8 | loaders_src = load_torchvision_data('CIFAR10', resize=28, maxsize=2000)[0]
9 | loaders_tgt = load_torchvision_data('MNIST', resize=28, to3channels=True, maxsize=2000)[0]
10 | 
11 | # Embed using a pretrained (+frozen) resnet
12 | embedder = resnet18(pretrained=True).eval()
13 | embedder.fc = torch.nn.Identity()
14 | for p in embedder.parameters():
15 |     p.requires_grad = False
16 | 
17 | # Here we use the same embedder for both datasets
18 | feature_cost = FeatureCost(src_embedding = embedder,
19 |                            src_dim = (3,28,28),
20 |                            tgt_embedding = embedder,
21 |                            tgt_dim = (3,28,28),
22 |                            p = 2,
23 |                            device='cpu')
24 | 
25 | dist = DatasetDistance(loaders_src['train'], loaders_tgt['train'],
26 |                        inner_ot_method = 'exact',
27 |                        debiased_loss = True,
28 |                        feature_cost = feature_cost,
29 |                        sqrt_method = 'spectral',
30 |                        sqrt_niters=10,
31 |                        precision='single',
32 |                        p = 2, entreg = 1e-1,
33 |                        device='cpu')
34 | 
35 | d = dist.distance(maxsamples = 10000)
36 | print(f'Embedded OTDD(CIFAR10,MNIST)={d:8.2f}')
37 | 
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: otdd
2 | channels:
3 |   - pytorch
4 |   - conda-forge
5 | dependencies:
6 |   - python=3.8.10
7 |   - numpy
8 |   - scipy
9 |   - pandas
10 |   - matplotlib
11 |   - scikit-learn
12 |   - pytorch
13 |   - torchvision=0.9.1
14 |   - torchtext
15 |   - tqdm
16 |   - attrdict
17 |   - seaborn
18 |   - adjustText
19 |   - h5py
20 |   - cmake # for geomloss
21 |   - watermark
22 |   - jupyterlab
23 |   - conda-build
24 |   - ipywidgets
25 |   - pip:
26 |     - pot
27 |     - pykeops
28 |     - git+https://github.com/jeanfeydy/geomloss
29 |     - opentsne
30 |     - munkres
31 |     - celluloid
32 |     - git+https://github.com/rossant/ipycache
33 | 
--------------------------------------------------------------------------------
/example.py:
--------------------------------------------------------------------------------
1 | from otdd.pytorch.datasets import load_torchvision_data
2 | from otdd.pytorch.distance import DatasetDistance
3 | 
4 | # Load data
5 | loaders_src = load_torchvision_data('MNIST', valid_size=0, resize = 28, maxsize=2000)[0]
6 | loaders_tgt = load_torchvision_data('USPS', valid_size=0, resize = 28, maxsize=2000)[0]
7 | 
8 | # Instantiate distance
9 | dist = DatasetDistance(loaders_src['train'], loaders_tgt['train'],
10 |                        inner_ot_method = 'exact',
11 |                        debiased_loss = True,
12 |                        p = 2, entreg = 1e-1,
13 |                        device='cpu')
14 | 
15 | d = dist.distance(maxsamples = 1000)
16 | print(f'OTDD(MNIST,USPS)={d:8.2f}')
17 | 
--------------------------------------------------------------------------------
/otdd/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os.path import dirname, abspath
3 | import logging
4 | # Defaults
5 | ROOT_DIR = dirname(dirname(abspath(__file__))) # Project Root
6 | HOME_DIR = os.getenv("HOME") # User home dir
7 | DATA_DIR = os.path.join(ROOT_DIR, 'data')
8 | OUTPUT_DIR = os.path.join(ROOT_DIR, 'out')
9 | MODELS_DIR = os.path.join(ROOT_DIR, 'models')
10 
| from .utils import launch_logger
11 | 
--------------------------------------------------------------------------------
/otdd/plotting.py:
--------------------------------------------------------------------------------
1 | """Plotting tools for Optimal Transport Dataset Distance.
2 | 
3 | 
4 | """
5 | 
6 | import logging
7 | import matplotlib as mpl
8 | import matplotlib.pyplot as plt
9 | from matplotlib import cm
10 | 
11 | import numpy as np
12 | import seaborn as sns
13 | import torch
14 | 
15 | import scipy.stats
16 | from scipy.stats import pearsonr, spearmanr
17 | 
18 | from mpl_toolkits.axes_grid1 import make_axes_locatable
19 | 
20 | from adjustText import adjust_text
21 | 
22 | import pdb
23 | 
24 | logger = logging.getLogger(__name__)
25 | 
26 | def as_si(x, ndp):
27 |     """ Convert number to latex-style x10 scientific notation string"""
28 |     s = '{x:0.{ndp:d}e}'.format(x=x, ndp=ndp)
29 |     m, e = s.split('e')
30 |     return r'{m:s}\times 10^{{{e:d}}}'.format(m=m, e=int(e))
31 | 
32 | 
33 | def get_plot_ranges(X):
34 |     x, y = X[:,0], X[:,1]
35 |     dx = (x.max() - x.min())/10
36 |     dy = (y.max() - y.min())/10
37 |     xmin = x.min() - dx
38 |     xmax = x.max() + dx
39 |     ymin = y.min() - dy
40 |     ymax = y.max() + dy
41 |     return (xmin,xmax,ymin,ymax)
42 | 
43 | def gaussian_density_plot(P=None, X=None, method = 'exact', nsamples = 1000,
44 |                           color='blue', label_means=True, cmap='coolwarm',ax=None,eps=1e-4):
45 |     if X is None and P is not None:
46 |         X = P.sample(sample_shape=torch.Size([nsamples])).numpy()
47 | 
48 |     if ax is None:
49 |         fig = plt.figure(figsize=(8,8))
50 |         ax = fig.gca()
51 |         xmin, xmax, ymin, ymax = get_plot_ranges(X)
52 |         logger.info('%s %s %s %s', xmin, xmax, ymin, ymax)
53 |         ax.set_xlim(xmin, xmax)
54 |         ax.set_ylim(ymin, ymax)
55 |     else:
56 |         xmin,xmax = ax.get_xlim()
57 |         ymin,ymax = ax.get_ylim()
58 | 
59 |     XY = np.mgrid[xmin:xmax:100j, ymin:ymax:100j]
60 |     xx,yy = XY[0,:,:],XY[1,:,:]
61 | 
62 | 
63 |     if method == 'samples':
64 |         positions = np.vstack([xx.ravel(), yy.ravel()])
65 |         kernel = scipy.stats.gaussian_kde(X.T)
66 |         f = np.reshape(kernel(positions).T, xx.shape)
67 |     elif method == 'exact':
68 |         μ,Σ = P.loc.numpy(), P.covariance_matrix.numpy()
69 |         f = scipy.stats.multivariate_normal.pdf(XY.transpose(1,2,0),μ,Σ)
70 | 
71 |     step = 0.01
72 |     levels = np.arange(0, np.amax(f), step) + step
73 | 
74 |     if len(levels) < 2:
75 |         levels = [step/2, levels[0]]
76 | 
77 |     cfset = ax.contourf(xx, yy, f, levels, cmap=cmap, alpha=0.5)
78 | 
79 |     cset = ax.contour(xx, yy, f, levels, colors='k', alpha=0.5)
80 |     ax.clabel(cset, inline=1, fontsize=10)
81 |     ax.set_xlabel('X')
82 |     ax.set_ylabel('Y')
83 |     if method == 'samples':
84 |         ax.scatter(X[:,0], X[:,1], color=plt.get_cmap(cmap)(0.8))
85 |         ax.set_title('2D Gaussian Kernel density estimation')
86 |     elif method == 'exact':
87 |         ax.scatter(μ[0],μ[1], s=5, c= 'black')
88 |         if label_means:
89 |             ax.text(μ[0]+eps,μ[1]+eps,'μ=({:.2},{:.2})'.format(μ[0],μ[1]), fontsize=12)
90 |         ax.set_title('Exact Gaussian Density')
91 | 
92 | 
93 | def heatmap(data, row_labels, col_labels, ax=None, cbar=True,
94 |             cbar_kw={}, cbarlabel="", **kwargs):
95 |     """ Create a heatmap from a numpy array and two lists of labels.
96 | 
97 |     Args:
98 |         data: A 2D numpy array of shape (N, M).
99 |         row_labels: A list or array of length N with the labels for the rows.
100 |         col_labels: A list or array of length M with the labels for the columns.
101 |         ax: A `matplotlib.axes.Axes` instance to which the heatmap is plotted. If
102 |             not provided, use current axes or create a new one. Optional.
103 | cbar: A boolear value, whether to display colorbar or not 104 | cbar_kw: A dictionary with arguments to `matplotlib.Figure.colorbar`. Optional. 105 | cbarlabel: The label for the colorbar. Optional. 106 | **kwargs: All other arguments are forwarded to `imshow`. 107 | """ 108 | 109 | if not ax: 110 | ax = plt.gca() 111 | 112 | im = ax.imshow(data, **kwargs) 113 | 114 | 115 | if cbar: 116 | if 'alpha' in kwargs: 117 | cbar_kw['alpha'] = kwargs.get('alpha') 118 | cbar = ax.figure.colorbar(im, ax=ax, fraction=0.046, pad=0.04, **cbar_kw) 119 | cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom") 120 | 121 | ax.set_xticks(np.arange(data.shape[1])) 122 | ax.set_yticks(np.arange(data.shape[0])) 123 | ax.set_xticklabels(col_labels) 124 | ax.set_yticklabels(row_labels) 125 | 126 | ax.tick_params(top=False, bottom=True, 127 | labeltop=False, labelbottom=True) 128 | 129 | plt.setp(ax.get_xticklabels(), rotation=0, ha="right", rotation_mode="anchor") 130 | 131 | for edge, spine in ax.spines.items(): 132 | spine.set_visible(False) 133 | 134 | ax.set_xticks(np.arange(data.shape[1]+1)-.5, minor=True) 135 | ax.set_yticks(np.arange(data.shape[0]+1)-.5, minor=True) 136 | ax.grid(which="minor", color="w", linestyle='-', linewidth=3) 137 | ax.tick_params(which="minor", bottom=False, left=False) 138 | 139 | return im, cbar 140 | 141 | 142 | def annotate_heatmap(im, data=None, valfmt="{x:.2f}", 143 | textcolors=["black", "white"], 144 | threshold=None, **textkw): 145 | """ A function to annotate a heatmap. 146 | 147 | Args: 148 | im: The AxesImage to be labeled. 149 | data: Data used to annotate. If None, the image's data is used. Optional. 150 | valfmt: The format of the annotations inside the heatmap. This should either 151 | use the string format method, e.g. "$ {x:.2f}", or be a 152 | `matplotlib.ticker.Formatter`. Optional. 153 | textcolors: A list or array of two color specifications. The first is used for 154 | values below a threshold, the second for those above. Optional. 155 | threshold: Value in data units according to which the colors from textcolors are 156 | applied. If None (the default) uses the middle of the colormap as 157 | separation. Optional. 158 | **kwargs: All other arguments are forwarded to each call to `text` used to create 159 | the text labels. 160 | """ 161 | 162 | if not isinstance(data, (list, np.ndarray)): 163 | data = im.get_array() 164 | 165 | if threshold is not None: 166 | threshold = im.norm(threshold) 167 | else: 168 | threshold = im.norm(data.max())/2. 169 | 170 | kw = dict(horizontalalignment="center", 171 | verticalalignment="center") 172 | kw.update(textkw) 173 | 174 | if isinstance(valfmt, str): 175 | valfmt = mpl.ticker.StrMethodFormatter(valfmt) 176 | 177 | texts = [] 178 | for i in range(data.shape[0]): 179 | for j in range(data.shape[1]): 180 | kw.update(color=textcolors[int(im.norm(data[i, j]) > threshold)]) 181 | text = im.axes.text(j, i, valfmt(data[i, j], None), **kw) 182 | texts.append(text) 183 | 184 | return texts 185 | 186 | 187 | def distance_scatter(d, topk=10, show=True, save_path =None): 188 | """ Distance vs adaptation scatter plots as used in the OTDD paper. 
189 | Args: 190 | d (dict): dictionary of task pair (string), distance (float) 191 | topk (int): number k of top/bottom distances that will be annotated 192 | """ 193 | sorted_d = sorted(d.items(), key=lambda kv: kv[1]) 194 | keys, dists = zip(*sorted_d) 195 | if type(keys[0]) is tuple and len(keys[0]) == 2: 196 | labels = ['{}<->{}'.format(p,q) for (p,q) in keys] 197 | else: 198 | labels = ['{}'.format(p) for p in keys] 199 | x_coord = np.linspace(0,1,len(keys)) 200 | 201 | fig, ax = plt.subplots(figsize=(10,10)) 202 | ax.scatter(x_coord, dists, s = min(100/len(keys), 1)) 203 | texts=[] 204 | for i, (x, y, name) in enumerate(zip(x_coord,dists,keys)): 205 | if i < topk or i >= len(keys) - topk: 206 | label = '{}<->{}'.format(*name) if type(name) is tuple else str(name) 207 | texts.append(ax.text(x, y, label)) 208 | adjust_text(texts, force_text=0.05, arrowprops=dict(arrowstyle="-|>", 209 | color='r', alpha=0.5)) 210 | 211 | ax.set_title('Pairwise Distance Between MNIST Binary Classification Tasks') 212 | ax.set_ylabel('Dataset Distance') 213 | if save_path: 214 | plt.savefig(save_path, format='pdf', dpi=300) #bbox_inches='tight', 215 | if show: plt.show() 216 | 217 | def dist_adapt_joinplot(df, yvar='delta', show=True, type='joinplot', save_path = None): 218 | j = sns.jointplot(x='dist', y=yvar, data=df, kind="reg", height=7) 219 | j.annotate(scipy.stats.pearsonr) 220 | y_label = 'Acc. Improvement w/ Adapt'#.format(direction[yvar]) 221 | j.set_axis_labels('OT Task Distance', y_label) 222 | if save_path: 223 | plt.savefig(save_path, format='pdf', dpi=300) #bbox_inches='tight', 224 | if show: 225 | plt.show() 226 | 227 | 228 | def dist_adapt_regplot(df, yvar, xvar='dist', xerrvar=None, yerrvar=None, 229 | figsize=(6,5), title=None, 230 | show_correlation=True, corrtype='pearson', sci_pval=True, 231 | annotate=True, annotation_arrows=True, annotation_fontsize=12, 232 | force_text=0.5, 233 | legend_fontsize=12, 234 | title_fontsize=12, 235 | marker_size=10, 236 | arrowcolor='gray', 237 | barcolor='gray', 238 | xlabel = 'OT Dataset Distance', 239 | ylabel = r'Relative Drop in Test Error ($\%$)', 240 | color='#1f77b4', 241 | lw=1, 242 | ax=None, 243 | show=True, 244 | save_path=None): 245 | 246 | if ax is None: 247 | fig, ax = plt.subplots(figsize=figsize) 248 | else: 249 | show = False 250 | 251 | 252 | #### Compute Correlation 253 | if show_correlation: 254 | if corrtype == 'spearman': 255 | corr, p = spearmanr(df[xvar], df[yvar]) 256 | corrsymbol = '\\rho' 257 | elif corrtype == 'pearson': 258 | corr, p = pearsonr(df[xvar], df[yvar]) 259 | corrsymbol = 'r' 260 | else: 261 | raise ValueError('Unrecognized correlation type') 262 | if p < 0.01 and sci_pval: 263 | legend_label = r"${}: {:2.2f}$".format(corrsymbol,corr) + "\n" + r"p-value: ${:s}$".format(as_si(p,1)) 264 | else: 265 | legend_label = r"${}: {:2.2f}$".format(corrsymbol,corr) + "\n" + r"p-value: ${:2.2f}$".format(p) 266 | else: 267 | legend_label = None 268 | 269 | 270 | ### Actual Plots - First does scatter only, second does line 271 | sns.regplot(x=xvar, y=yvar, data=df, ax = ax, color=color, label=legend_label, 272 | scatter_kws={'s':marker_size}, 273 | line_kws={'lw': 1} 274 | ) 275 | 276 | ### Add Error Bars 277 | if xerrvar or yerrvar: 278 | xerr = df[xerrvar] if xerrvar else None 279 | yerr = df[yerrvar] if yerrvar else None 280 | ax.errorbar(df[xvar], df[yvar], xerr=xerr, yerr=yerr, fmt='none', ecolor='#d6d4d4', alpha=0.75,elinewidth=0.75) 281 | 282 | ### Annotate Points 283 | if annotate: 284 | texts = [] 285 | for i,a in 
df.iterrows(): 286 | lab = r'{}$\rightarrow${}'.format(a.src,a.tgt) if a.tgt is not None else r'{}'.format(a.src) 287 | texts.append(ax.text(a[xvar], a[yvar], lab,fontsize=annotation_fontsize)) 288 | if annotation_arrows: 289 | adjust_text(texts, force_text=force_text, arrowprops=dict(arrowstyle="-", color=arrowcolor, alpha=0.5, lw=0.5)) 290 | else: 291 | adjust_text(texts, force_text=force_text) 292 | 293 | ### Fix Legend for Correlation (otherwise don't show) 294 | if show_correlation: 295 | plt.rc('legend',fontsize=legend_fontsize)#,borderpad=0.2,handletextpad=0, handlelength=0) # using a size in points 296 | ax.legend([ax.get_lines()[0]], ax.get_legend_handles_labels()[-1],handlelength=1.0,loc='best')#, handletextpad=0.0) 297 | 298 | 299 | ### Add title and labels 300 | ax.set_xlabel(xlabel, fontsize=title_fontsize) 301 | ax.set_ylabel(ylabel, fontsize=title_fontsize) 302 | ax.set_title(r'Distance vs Adaptation' + (': {}'.format(title) if title else ''), fontsize=title_fontsize) 303 | 304 | if save_path: 305 | plt.savefig(save_path+'.pdf', dpi=300, bbox_inches = "tight") 306 | plt.savefig(save_path+'.png', dpi=300, bbox_inches = "tight") 307 | 308 | if show: plt.show() 309 | 310 | return ax 311 | 312 | 313 | def plot2D_samples_mat(xs, xt, G, thr=1e-8, ax=None, **kwargs): 314 | """ (ADAPTED FROM PYTHON OT LIBRARY). 315 | Plot matrix M in 2D with lines using alpha values 316 | Plot lines between source and target 2D samples with a color 317 | proportional to the value of the matrix G between samples. 318 | Parameters 319 | ---------- 320 | xs : ndarray, shape (ns,2) 321 | Source samples positions 322 | b : ndarray, shape (nt,2) 323 | Target samples positions 324 | G : ndarray, shape (na,nb) 325 | OT matrix 326 | thr : float, optional 327 | threshold above which the line is drawn 328 | **kwargs : dict 329 | paameters given to the plot functions (default color is black if 330 | nothing given) 331 | """ 332 | if ('color' not in kwargs) and ('c' not in kwargs): 333 | kwargs['color'] = 'gray' 334 | mx = G.max() 335 | if not ax: 336 | fig,ax = plt.subplots() 337 | for i in range(xs.shape[0]): 338 | for j in range(xt.shape[0]): 339 | if G[i, j] / mx > thr: 340 | ax.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], 341 | alpha=G[i, j] / mx, **kwargs) 342 | 343 | return ax 344 | 345 | 346 | def annotate_group(name, span, ax=None, orient='h', side=None): 347 | """Annotates a span of the x-axis (or y-axis if orient ='v')""" 348 | if not side: 349 | side = 'left' if orient == 'v' else 'bottom' 350 | def annotate(ax, name, left, right, y, pad): 351 | xy = (left, y) if orient == 'h' else (y, left) 352 | xytext=(right, y+pad) if orient =='h' else (y+pad, right) 353 | valign = 'top' if orient =='h' else 'center' 354 | halign = 'center' if orient == 'h' else 'center' 355 | rot = 0 if orient == 'h' else 0 356 | if orient == 'h': 357 | connectionstyle='angle,angleB=90,angleA=0,rad=5' 358 | else: 359 | connectionstyle='angle,angleB=0,angleA=-90,rad=5' 360 | 361 | arrow = ax.annotate(name, 362 | xy=xy, xycoords='data', 363 | xytext=xytext, textcoords='data', 364 | annotation_clip=False, verticalalignment=valign, 365 | horizontalalignment=halign, linespacing=2.0, 366 | arrowprops=dict(arrowstyle='-', shrinkA=0, shrinkB=0, 367 | connectionstyle=connectionstyle), 368 | fontsize=8, rotation=rot 369 | ) 370 | return arrow 371 | if ax is None: 372 | ax = plt.gca() 373 | lims = ax.get_ylim() if orient=='h' else ax.get_xlim() 374 | range = np.abs(lims[1] - lims[0]) 375 | lim = lims[0] if side == 'bottom' or side == 
'left' else lims[1] 376 | 377 | if side == 'bottom': 378 | arrow_coord = lim + 0.01*range# if orient == 'h' else lim - 0.02*range 379 | text_pad = 0.02*range 380 | elif side == 'right': 381 | arrow_coord = lim + 0.01*range# if orient == 'h' else lim - 0.02*range 382 | text_pad = 0.02*range 383 | elif side == 'top': 384 | arrow_coord = lim - 0.01*range# if orient == 'h' else lim - 0.02*range 385 | text_pad = -0.05*range 386 | else: # left 387 | arrow_coord = lim - 0.01*range 388 | text_pad = -0.02*range 389 | 390 | 391 | 392 | center = np.mean(span) 393 | left_arrow = annotate(ax, name, span[0], center, arrow_coord, text_pad) 394 | right_arrow = annotate(ax, name, span[1], center, arrow_coord, text_pad) 395 | return left_arrow, right_arrow 396 | 397 | 398 | def imshow_group_boundaries(ax, gU, gV, group_names, side = 'both', alpha=0.2, lw=0.5): 399 | """Imshow must be sorted according to order in groups""" 400 | if side in ['source','both']: 401 | xmin,xmax = ax.get_xlim() 402 | ax.hlines(np.cumsum(gU[:-1]) - 0.5,xmin=xmin,xmax=xmax,lw=lw, linestyles='dashed', alpha = alpha) 403 | if side in ['target','both']: 404 | ymin,ymax = ax.get_ylim() 405 | ax.vlines(np.cumsum(gV[:-1]) - 0.5,ymin=ymin,ymax=ymax,lw=lw,linestyles='dashed', alpha=alpha) 406 | 407 | if group_names: 408 | offset = -0.5 409 | posx = np.cumsum(gU)# + offset 410 | posy = np.cumsum(gV)# + offset 411 | posx = np.insert(posx, 0, offset) 412 | posy = np.insert(posy, 0, offset) 413 | for i,y in enumerate(posy[:-1]): 414 | annotate_group(group_names[1][i], (posy[i], posy[i+1]), ax, orient='h', side = 'top') 415 | for i,x in enumerate(posx[:-1]): 416 | annotate_group(group_names[0][i], (posx[i], posx[i+1]), ax, orient='v', side = 'right') 417 | 418 | 419 | def method_comparison_plot(df, hue_var = 'method', style_var = 'method', 420 | figsize = (15,4), ax = None, save_path=None): 421 | """ Produce plots comparing OTDD variants in terms of runtime and distance """ 422 | if ax is None: 423 | fig, ax = plt.subplots(1, 2, figsize=figsize) 424 | 425 | lplot_args = { 426 | 'hue': hue_var, 427 | 'style': style_var, 428 | 'data': df, 429 | 'x': 'n', 430 | 'markers': True 431 | } 432 | 433 | sns.lineplot(y='dist', ax= ax[0], **lplot_args) 434 | ax[0].legend(loc='center left', bbox_to_anchor=(1, 0.5), fontsize=16) 435 | ax[0].set_ylabel('Dataset Distance') 436 | ax[0].set_xlabel('Dataset Size') 437 | ax[0].set_xscale("log") 438 | ax[0].grid(True,which="both",ls="--",c='gray') 439 | 440 | sns.lineplot(y='time', ax= ax[1], **lplot_args) 441 | ax[1].legend(loc='center left', bbox_to_anchor=(1, 0.5), fontsize=16) 442 | ax[1].set_ylabel('Runtime (s)') 443 | ax[1].set_xlabel('Dataset Size') 444 | ax[1].set_xscale("log") 445 | ax[1].set_yscale("log") 446 | ax[1].grid(True,which="both",ls="--",c='gray') 447 | 448 | handles, labels = ax[1].get_legend_handles_labels() 449 | ax[1].get_legend().remove() 450 | 451 | plt.tight_layout() 452 | if save_path: 453 | plt.savefig(save_path + '.pdf', dpi=300) 454 | plt.savefig(save_path + '.png', dpi=300) 455 | plt.show() 456 | 457 | return ax 458 | -------------------------------------------------------------------------------- /otdd/pytorch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/otdd/72f1b22d6c688e15db416194c9fe42262d927303/otdd/pytorch/__init__.py -------------------------------------------------------------------------------- /otdd/pytorch/datasets.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import pdb 3 | from functools import partial 4 | import random 5 | import logging 6 | import string 7 | 8 | import numpy as np 9 | import torch 10 | from torch.distributions.multivariate_normal import MultivariateNormal 11 | from torch.utils.data import TensorDataset 12 | import torch.nn as nn 13 | import torch.utils.data as torchdata 14 | import torch.utils.data.dataloader as dataloader 15 | from torch.utils.data.sampler import SubsetRandomSampler 16 | 17 | import torchvision 18 | import torchvision.transforms as transforms 19 | import torchvision.transforms.functional as TF 20 | import torchvision.datasets as dset 21 | 22 | import torchtext 23 | from torchtext.data.utils import get_tokenizer 24 | 25 | import h5py 26 | 27 | from .. import DATA_DIR 28 | 29 | from .utils import interleave, process_device_arg, random_index_split, \ 30 | spectrally_prescribed_matrix, rot, rot_evecs 31 | 32 | from .sqrtm import create_symm_matrix 33 | 34 | logger = logging.getLogger(__name__) 35 | 36 | 37 | DATASET_NCLASSES = { 38 | 'MNIST': 10, 39 | 'FashionMNIST': 10, 40 | 'EMNIST': 26, 41 | 'KMNIST': 10, 42 | 'USPS': 10, 43 | 'CIFAR10': 10, 44 | 'SVHN': 10, 45 | 'STL10': 10, 46 | 'LSUN': 10, 47 | 'tiny-ImageNet': 200 48 | } 49 | 50 | DATASET_SIZES = { 51 | 'MNIST': (28,28), 52 | 'FashionMNIST': (28,28), 53 | 'EMNIST': (28,28), 54 | 'QMNIST': (28,28), 55 | 'KMNIST': (28,28), 56 | 'USPS': (16,16), 57 | 'SVHN': (32, 32), 58 | 'CIFAR10': (32, 32), 59 | 'STL10': (96, 96), 60 | 'tiny-ImageNet': (64,64) 61 | } 62 | 63 | DATASET_NORMALIZATION = { 64 | 'MNIST': ((0.1307,), (0.3081,)), 65 | 'USPS' : ((0.1307,), (0.3081,)), 66 | 'FashionMNIST' : ((0.1307,), (0.3081,)), 67 | 'QMNIST' : ((0.1307,), (0.3081,)), 68 | 'EMNIST' : ((0.1307,), (0.3081,)), 69 | 'KMNIST' : ((0.1307,), (0.3081,)), 70 | 'ImageNet': ((0.485, 0.456, 0.406),(0.229, 0.224, 0.225)), 71 | 'tiny-ImageNet': ((0.485, 0.456, 0.406),(0.229, 0.224, 0.225)), 72 | 'CIFAR10': ((0.485, 0.456, 0.406),(0.229, 0.224, 0.225)), 73 | 'CIFAR100': ((0.485, 0.456, 0.406),(0.229, 0.224, 0.225)), 74 | 'STL10': ((0.485, 0.456, 0.406),(0.229, 0.224, 0.225)) 75 | } 76 | 77 | 78 | def sort_by_label(X,Y): 79 | idxs = np.argsort(Y) 80 | return X[idxs,:], Y[idxs] 81 | 82 | 83 | ### Data Transforms 84 | class DiscreteRotation: 85 | """Rotate by one of the given angles.""" 86 | 87 | def __init__(self, angles): 88 | self.angles = angles 89 | 90 | def __call__(self, x): 91 | angle = random.choice(self.angles) 92 | return TF.rotate(x, angle) 93 | 94 | class Cutout(object): 95 | def __init__(self, length): 96 | self.length = length 97 | 98 | def __call__(self, img): 99 | h, w = img.size(1), img.size(2) 100 | mask = np.ones((h, w), np.float32) 101 | y = np.random.randint(h) 102 | x = np.random.randint(w) 103 | 104 | y1 = np.clip(y - self.length // 2, 0, h) 105 | y2 = np.clip(y + self.length // 2, 0, h) 106 | x1 = np.clip(x - self.length // 2, 0, w) 107 | x2 = np.clip(x + self.length // 2, 0, w) 108 | 109 | mask[y1:y2, x1:x2] = 0.0 110 | mask = torch.from_numpy(mask) 111 | mask = mask.expand_as(img) 112 | img *= mask 113 | return img 114 | 115 | 116 | class SubsetSampler(torch.utils.data.Sampler): 117 | r"""Samples elements in order (not randomly) from a given list of indices, without replacement. 
118 | 119 | Arguments: 120 | indices (sequence): a sequence of indices 121 | (this is identical to torch's SubsetRandomSampler except not random) 122 | """ 123 | 124 | def __init__(self, indices): 125 | self.indices = indices 126 | 127 | def __iter__(self): 128 | return (self.indices[i] for i in range(len(self.indices))) 129 | 130 | def __len__(self): 131 | return len(self.indices) 132 | 133 | class CustomTensorDataset(torch.utils.data.Dataset): 134 | """TensorDataset with support of transforms.""" 135 | def __init__(self, tensors, transform=None, target_transform=None): 136 | assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors) 137 | self.tensors = tensors 138 | self.transform = transform 139 | self.target_transform = target_transform 140 | 141 | def __getitem__(self, index): 142 | x = self.tensors[0][index] 143 | if self.transform: 144 | x = self.transform(x) 145 | 146 | y = self.tensors[1][index] 147 | if self.target_transform: 148 | y = self.target_transform(y) 149 | 150 | return x, y 151 | 152 | def __len__(self): 153 | return self.tensors[0].size(0) 154 | 155 | class SubsetFromLabels(torch.utils.data.dataset.Dataset): 156 | """ Subset of a dataset at specified indices. 157 | 158 | Adapted from torch.utils.data.dataset.Subset to allow for label re-mapping 159 | without having to copy whole dataset. 160 | 161 | Arguments: 162 | dataset (Dataset): The whole Dataset 163 | indices (sequence): Indices in the whole set selected for subset 164 | targets_map (dict, optional): Dictionary to map targets with 165 | """ 166 | def __init__(self, dataset, labels, remap=False): 167 | self.dataset = dataset 168 | self.labels = labels 169 | self.classes = [dataset.classes[i] for i in labels] 170 | self.mask = np.isin(dataset.targets, labels).squeeze() 171 | self.indices = np.where(self.mask)[0] 172 | self.remap = remap 173 | targets = dataset.targets[self.indices] 174 | if remap: 175 | V = sorted(np.unique(targets)) 176 | assert list(V) == list(labels) 177 | targets = torch.tensor(np.digitize(targets, self.labels, right=True)) 178 | self.tmap = dict(zip(V,range(len(V)))) 179 | self.targets = targets 180 | 181 | def __getitem__(self, idx): 182 | if self.remap is False: 183 | return self.dataset[self.indices[idx]] 184 | else: 185 | item = self.dataset[self.indices[idx]] 186 | return (item[0], self.tmap[item[1]]) 187 | 188 | def __len__(self): 189 | return len(self.indices) 190 | 191 | def subdataset_from_labels(dataset, labels, remap=True): 192 | mask = np.isin(dataset.targets, labels).squeeze() 193 | idx = np.where(mask)[0] 194 | subdataset = Subset(dataset,idx, remap_targets=True) 195 | return subdataset 196 | 197 | 198 | def dataset_from_numpy(X, Y, classes = None): 199 | targets = torch.LongTensor(list(Y)) 200 | ds = TensorDataset(torch.from_numpy(X).type(torch.FloatTensor),targets) 201 | ds.targets = targets 202 | ds.classes = classes if classes is not None else [i for i in range(len(np.unique(Y)))] 203 | return ds 204 | 205 | 206 | gmm_configs = { 207 | 'star': { 208 | 'means': [torch.Tensor([0,0]), 209 | torch.Tensor([0,-2]), 210 | torch.Tensor([2,0]), 211 | torch.Tensor([0,2]), 212 | torch.Tensor([-2,0])], 213 | 'covs': [spectrally_prescribed_matrix([1,1], torch.eye(2)), 214 | spectrally_prescribed_matrix([2.5,1], torch.eye(2)), 215 | spectrally_prescribed_matrix([1,20], torch.eye(2)), 216 | spectrally_prescribed_matrix([10,1], torch.eye(2)), 217 | spectrally_prescribed_matrix([1,5], torch.eye(2)) 218 | ], 219 | 'spread': 6, 220 | } 221 | 222 | } 223 | 224 | def 
make_gmm_dataset(config='random', classes=10,dim=2,samples=10,spread = 1, 225 | shift=None, rotate=None, diagonal_cov=False, shuffle=True): 226 | """ Generate Gaussian Mixture Model datasets. 227 | 228 | Arguments: 229 | config (str): determines cluster locations, one of 'random' or 'star' 230 | classes (int): number of classes in dataset 231 | dim (int): feature dimension of dataset 232 | samples (int): number of samples in dataset 233 | spread (int): separation of clusters 234 | shift (bool): whether to add a shift to dataset 235 | rotate (bool): whether to rotate dataset 236 | diagonal_cov(bool): whether to use a diagonal covariance matrix 237 | shuffle (bool): whether to shuffle example indices 238 | 239 | Returns: 240 | X (tensor): tensor of size (samples, dim) with features 241 | Y (tensor): tensor of size (samples, 1) with labels 242 | distribs (torch.distributions): data-generating distributions of each class 243 | 244 | """ 245 | means, covs, distribs = [], [], [] 246 | _configd = None if config == 'random' else gmm_configs[config] 247 | spread = spread if (config == 'random' or not 'spread' in _configd) else _configd['spread'] 248 | shift = shift if (config == 'random' or not 'shift' in _configd) else _configd['shift'] 249 | 250 | for i in range(classes): 251 | if config == 'random': 252 | mean = torch.randn(dim) 253 | cov = create_symm_matrix(1, dim, verbose=False).squeeze() 254 | elif config == 'star': 255 | mean = gmm_configs['star']['means'][i] 256 | cov = gmm_configs['star']['covs'][i] 257 | if rotate: 258 | mean = rot(mean, rotate) 259 | cov = rot_evecs(cov, rotate) 260 | 261 | if diagonal_cov: 262 | cov.masked_fill_(~torch.eye(dim, dtype=bool), 0) 263 | 264 | means.append(spread*mean) 265 | covs.append(cov) 266 | distribs.append(MultivariateNormal(means[-1],covs[-1])) 267 | 268 | X = torch.cat([P.sample(sample_shape=torch.Size([samples])) for P in distribs]) 269 | Y = torch.LongTensor([samples*[i] for i in range(classes)]).flatten() 270 | 271 | if shift: 272 | print(X.shape) 273 | X += torch.tensor(shift) 274 | 275 | if shuffle: 276 | idxs = torch.randperm(Y.shape[0]) 277 | X = X[idxs, :] 278 | Y = Y[idxs] 279 | return X, Y, distribs 280 | 281 | def load_torchvision_data(dataname, valid_size=0.1, splits=None, shuffle=True, 282 | stratified=False, random_seed=None, batch_size = 64, 283 | resize=None, to3channels=False, 284 | maxsize = None, maxsize_test=None, num_workers = 0, transform=None, 285 | data=None, datadir=None, download=True, filt=False, print_stats = False): 286 | """ Load torchvision datasets. 
287 | 288 | We return train and test for plots and post-training experiments 289 | """ 290 | if shuffle == True and random_seed: 291 | np.random.seed(random_seed) 292 | if transform is None: 293 | if dataname in DATASET_NORMALIZATION.keys(): 294 | transform_dataname = dataname 295 | else: 296 | transform_dataname = 'ImageNet' 297 | 298 | transform_list = [] 299 | 300 | if dataname in ['MNIST', 'USPS'] and to3channels: 301 | transform_list.append(torchvision.transforms.Grayscale(3)) 302 | 303 | transform_list.append(torchvision.transforms.ToTensor()) 304 | transform_list.append( 305 | torchvision.transforms.Normalize(*DATASET_NORMALIZATION[transform_dataname]) 306 | ) 307 | 308 | if resize: 309 | if not dataname in DATASET_SIZES or DATASET_SIZES[dataname][0] != resize: 310 | ## Avoid adding an "identity" resizing 311 | transform_list.insert(0, transforms.Resize((resize, resize))) 312 | 313 | transform = transforms.Compose(transform_list) 314 | logger.info(transform) 315 | train_transform, valid_transform = transform, transform 316 | elif data is None: 317 | if len(transform) == 1: 318 | train_transform, valid_transform = transform, transform 319 | elif len(transform) == 2: 320 | train_transform, valid_transform = transform 321 | else: 322 | raise ValueError() 323 | 324 | if data is None: 325 | DATASET = getattr(torchvision.datasets, dataname) 326 | if datadir is None: 327 | datadir = DATA_DIR 328 | if dataname == 'EMNIST': 329 | split = 'letters' 330 | train = DATASET(datadir, split=split, train=True, download=download, transform=train_transform) 331 | test = DATASET(datadir, split=split, train=False, download=download, transform=valid_transform) 332 | ## EMNIST seems to have a bug - classes are wrong 333 | _merged_classes = set(['C', 'I', 'J', 'K', 'L', 'M', 'O', 'P', 'S', 'U', 'V', 'W', 'X', 'Y', 'Z']) 334 | _all_classes = set(list(string.digits + string.ascii_letters)) 335 | classes_split_dict = { 336 | 'byclass': list(_all_classes), 337 | 'bymerge': sorted(list(_all_classes - _merged_classes)), 338 | 'balanced': sorted(list(_all_classes - _merged_classes)), 339 | 'letters': list(string.ascii_lowercase), 340 | 'digits': list(string.digits), 341 | 'mnist': list(string.digits), 342 | } 343 | train.classes = classes_split_dict[split] 344 | if split == 'letters': 345 | ## The letters fold (and only that fold!!!) 
is 1-indexed
346 |                 train.targets -= 1
347 |                 test.targets -= 1
348 |         elif dataname == 'STL10':
349 |             train = DATASET(datadir, split='train', download=download, transform=train_transform)
350 |             test = DATASET(datadir, split='test', download=download, transform=valid_transform)
351 |             train.classes = ['airplane', 'bird', 'car', 'cat', 'deer', 'dog', 'horse', 'monkey', 'ship', 'truck']
352 |             test.classes = train.classes
353 |             train.targets = torch.tensor(train.labels)
354 |             test.targets = torch.tensor(test.labels)
355 |         elif dataname == 'SVHN':
356 |             train = DATASET(datadir, split='train', download=download, transform=train_transform)
357 |             test = DATASET(datadir, split='test', download=download, transform=valid_transform)
358 |             ## In torchvision, SVHN 0s have label 0, not 10
359 |             train.classes = test.classes = [str(i) for i in range(10)]
360 |             train.targets = torch.tensor(train.labels)
361 |             test.targets = torch.tensor(test.labels)
362 |         elif dataname == 'LSUN':
363 |             pdb.set_trace()
364 |             train = DATASET(datadir, classes='train', download=download, transform=train_transform)
365 |         else:
366 |             train = DATASET(datadir, train=True, download=download, transform=train_transform)
367 |             test = DATASET(datadir, train=False, download=download, transform=valid_transform)
368 |     else:
369 |         train, test = data
370 | 
371 | 
372 |     if type(train.targets) is list:
373 |         train.targets = torch.LongTensor(train.targets)
374 |         test.targets = torch.LongTensor(test.targets)
375 | 
376 |     if not hasattr(train, 'classes') or not train.classes:
377 |         train.classes = sorted(torch.unique(train.targets).tolist())
378 |         test.classes = sorted(torch.unique(train.targets).tolist())
379 | 
380 | 
381 |     ### Data splitting
382 |     fold_idxs = {}
383 |     if splits is None and valid_size == 0:
384 |         ## Only train
385 |         fold_idxs['train'] = np.arange(len(train))
386 |     elif splits is None and valid_size > 0:
387 |         ## Train/Valid
388 |         train_idx, valid_idx = random_index_split(len(train), 1-valid_size, (maxsize, None)) # No maxsize for validation
389 |         fold_idxs['train'] = train_idx
390 |         fold_idxs['valid'] = valid_idx
391 |     elif splits is not None:
392 |         ## Custom splits - must be integer.
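        ## Illustrative values (consistent with the handling below; not in the original
        ## source): splits={'train': 5000, 'valid': 1000}, or splits=[5000, 1000, -1],
        ## where a single -1 entry receives all leftover examples.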
393 | if type(splits) is dict: 394 | snames, slens = zip(*splits.items()) 395 | elif type(splits) in [list, np.ndarray]: 396 | snames = ['split_{}'.format(i) for i in range(len(splits))] 397 | slens = splits 398 | slens = np.array(slens) 399 | if any(slens < 0): # Split expressed as -1, i.e., 'leftover' 400 | assert sum(slens < 0) == 1, 'Can only deal with one split being -1' 401 | idx_neg = np.where(slens == -1)[0][0] 402 | slens[idx_neg] = len(train) - np.array([x for x in slens if x > 0]).sum() 403 | elif slens.sum() > len(train): 404 | logging.warning("Not enough samples to satify splits..cropping train...") 405 | if 'train' in snames: 406 | slens[snames.index('train')] = len(train) - slens[np.array(snames) != 'train'].sum() 407 | 408 | idxs = np.arange(len(train)) 409 | if not stratified: 410 | np.random.shuffle(idxs) 411 | else: 412 | ## If stratified, we'll interleave the per-class shuffled indices 413 | idxs_class = [np.random.permutation(np.where(train.targets==c)).T for c in np.unique(train.targets)] 414 | idxs = interleave(*idxs_class).squeeze().astype(int) 415 | 416 | slens = np.array(slens).cumsum() # Need to make cumulative for np.split 417 | split_idxs = [np.sort(s) for s in np.split(idxs, slens)[:-1]] # The last one are leftovers 418 | assert len(split_idxs) == len(splits) 419 | fold_idxs = {snames[i]: v for i,v in enumerate(split_idxs)} 420 | 421 | 422 | for k, idxs in fold_idxs.items(): 423 | if maxsize and maxsize < len(idxs): 424 | fold_idxs[k] = np.sort(np.random.choice(idxs, maxsize, replace = False)) 425 | 426 | sampler_class = SubsetRandomSampler if shuffle else SubsetSampler 427 | fold_samplers = {k: sampler_class(idxs) for k,idxs in fold_idxs.items()} 428 | 429 | 430 | ### Create DataLoaders 431 | dataloader_args = dict(batch_size=batch_size,num_workers=num_workers) 432 | 433 | fold_loaders = {k:dataloader.DataLoader(train, sampler=sampler,**dataloader_args) 434 | for k,sampler in fold_samplers.items()} 435 | 436 | if maxsize_test and maxsize_test < len(test): 437 | test_idxs = np.sort(np.random.choice(len(test), maxsize_test, replace = False)) 438 | sampler_test = SubsetSampler(test_idxs) # For test don't want Random 439 | dataloader_args['sampler'] = sampler_test 440 | else: 441 | dataloader_args['shuffle'] = False 442 | test_loader = dataloader.DataLoader(test, **dataloader_args) 443 | fold_loaders['test'] = test_loader 444 | 445 | fnames, flens = zip(*[[k,len(v)] for k,v in fold_idxs.items()]) 446 | fnames = '/'.join(list(fnames) + ['test']) 447 | flens = '/'.join(map(str, list(flens) + [len(test)])) 448 | 449 | if hasattr(train, 'data'): 450 | logger.info('Input Dim: {}'.format(train.data.shape[1:])) 451 | logger.info('Classes: {} (effective: {})'.format(len(train.classes), len(torch.unique(train.targets)))) 452 | print(f'Fold Sizes: {flens} ({fnames})') 453 | 454 | return fold_loaders, {'train': train, 'test':test} 455 | 456 | 457 | def load_imagenet(datadir=None, resize=None, tiny=False, augmentations=False, **kwargs): 458 | """ Load ImageNet dataset """ 459 | if datadir is None and (not tiny): 460 | datadir = os.path.join(DATA_DIR,'imagenet') 461 | elif datadir is None and tiny: 462 | datadir = os.path.join(DATA_DIR,'tiny-imagenet-200') 463 | 464 | traindir = os.path.join(datadir, "train") 465 | validdir = os.path.join(datadir, "val") 466 | if augmentations: 467 | train_transform_list = [ 468 | transforms.RandomResizedCrop(224), 469 | transforms.RandomHorizontalFlip(), 470 | transforms.ColorJitter( 471 | brightness=0.4, contrast=0.4, saturation=0.4, 
hue=0.2 472 | ), 473 | transforms.ToTensor(), 474 | transforms.Normalize(*DATASET_NORMALIZATION['ImageNet']) 475 | ] 476 | else: 477 | train_transform_list = [ 478 | transforms.Resize(224), # revert back to 256 479 | transforms.CenterCrop(224), 480 | transforms.ToTensor(), 481 | transforms.Normalize(*DATASET_NORMALIZATION['ImageNet']) 482 | ] 483 | 484 | valid_transform_list = [ 485 | transforms.Resize(224),# revert back to 256 486 | transforms.CenterCrop(224), 487 | transforms.ToTensor(), 488 | transforms.Normalize(*DATASET_NORMALIZATION['ImageNet']) 489 | ] 490 | 491 | if resize is not None: 492 | train_transform_list.insert(3, transforms.Resize( 493 | (resize, resize))) 494 | valid_transform_list.insert(2, transforms.Resize( 495 | (resize, resize))) 496 | 497 | train_data = dset.ImageFolder( 498 | traindir, 499 | transforms.Compose( 500 | train_transform_list 501 | ), 502 | ) 503 | 504 | valid_data = dset.ImageFolder( 505 | validdir, 506 | transforms.Compose( 507 | valid_transform_list 508 | ), 509 | ) 510 | fold_loaders, dsets = load_torchvision_data('Imagenet', transform=[], 511 | data=(train_data, valid_data), 512 | **kwargs) 513 | 514 | return fold_loaders, dsets 515 | 516 | 517 | TEXTDATA_PATHS = { 518 | 'AG_NEWS': 'ag_news_csv', 519 | 'SogouNews': 'sogou_news_csv', 520 | 'DBpedia': 'dbpedia_csv', 521 | 'YelpReviewPolarity': 'yelp_review_polarity_csv', 522 | 'YelpReviewFull': 'yelp_review_full_csv', 523 | 'YahooAnswers': 'yahoo_answers_csv', 524 | 'AmazonReviewPolarity': 'amazon_review_polarity_csv', 525 | 'AmazonReviewFull': 'amazon_review_full_csv', 526 | } 527 | 528 | def load_textclassification_data(dataname, vecname='glove.42B.300d', shuffle=True, 529 | random_seed=None, num_workers = 0, preembed_sentences=True, 530 | loading_method='sentence_transformers', device='cpu', 531 | embedding_model=None, 532 | batch_size = 16, valid_size=0.1, maxsize=None, print_stats = False): 533 | """ Load torchtext datasets. 534 | 535 | Note: torchtext's TextClassification datasets are a bit different from the others: 536 | - they don't have split method. 537 | - no obvious creation of (nor access to) fields 538 | 539 | """ 540 | 541 | 542 | 543 | def batch_processor_tt(batch, TEXT=None, sentemb=None, return_lengths=True, device=None): 544 | """ For torchtext data/models """ 545 | labels, texts = zip(*batch) 546 | lens = [len(t) for t in texts] 547 | labels = torch.Tensor(labels) 548 | pad_idx = TEXT.vocab.stoi[TEXT.pad_token] 549 | texttensor = torch.nn.utils.rnn.pad_sequence(texts, batch_first=True, padding_value=pad_idx) 550 | if sentemb: 551 | texttensor = sentemb(texttensor) 552 | if return_lengths: 553 | return texttensor, labels, lens 554 | else: 555 | return texttensor, labels 556 | 557 | def batch_processor_st(batch, model, device=None): 558 | """ For sentence_transformers data/models """ 559 | device = process_device_arg(device) 560 | with torch.no_grad(): 561 | batch = model.smart_batching_collate(batch) 562 | ## Always run embedding model on gpu if available 563 | features, labels = st.util.batch_to_device(batch, device) 564 | emb = model(features[0])['sentence_embedding'] 565 | return emb, labels 566 | 567 | 568 | if shuffle == True and random_seed: 569 | np.random.seed(random_seed) 570 | 571 | debug = False 572 | 573 | dataroot = '/tmp/' if debug else DATA_DIR #os.path.join(ROOT_DIR, 'data') 574 | veccache = os.path.join(dataroot,'.vector_cache') 575 | 576 | if loading_method == 'torchtext': 577 | ## TextClassification object datasets already do word to token mapping inside. 
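        ## A minimal illustrative call for this branch (dataset names as in
        ## TEXTDATA_PATHS above; example added for clarity, not in the original file):
        ##   load_textclassification_data('AG_NEWS', loading_method='torchtext')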
578 | DATASET = getattr(torchtext.datasets, dataname) 579 | train, test = DATASET(root=dataroot, ngrams=1) 580 | 581 | ## load_vectors reindexes embeddings so that they match the vocab's itos indices. 582 | train._vocab.load_vectors(vecname,cache=veccache,max_vectors = 50000) 583 | test._vocab.load_vectors(vecname,cache=veccache, max_vectors = 50000) 584 | 585 | ## Define Fields for Text and Labels 586 | text_field = torchtext.data.Field(sequential=True, lower=True, 587 | tokenize=get_tokenizer("basic_english"), 588 | batch_first=True, 589 | include_lengths=True, 590 | use_vocab=True) 591 | 592 | text_field.vocab = train._vocab 593 | 594 | if preembed_sentences: 595 | ## This will be used for distance computation 596 | vsize = len(text_field.vocab) 597 | edim = text_field.vocab.vectors.shape[1] 598 | pidx = text_field.vocab.stoi[text_field.pad_token] 599 | sentembedder = BoWSentenceEmbedding(vsize, edim, text_field.vocab.vectors, pidx) 600 | batch_processor = partial(batch_processor_tt,TEXT=text_field,sentemb=sentembedder,return_lengths=False) 601 | else: 602 | batch_processor = partial(batch_processor_tt,TEXT=text_field,return_lengths=True) 603 | elif loading_method == 'sentence_transformers': 604 | import sentence_transformers as st 605 | dpath = os.path.join(dataroot,TEXTDATA_PATHS[dataname]) 606 | reader = st.readers.LabelSentenceReader(dpath) 607 | if embedding_model is None: 608 | model = st.SentenceTransformer('distilbert-base-nli-stsb-mean-tokens').eval() 609 | elif type(embedding_model) is str: 610 | model = st.SentenceTransformer(embedding_model).eval() 611 | elif isinstance(embedding_model, st.SentenceTransformer): 612 | model = embedding_model.eval() 613 | else: 614 | raise ValueError('embedding model has wrong type') 615 | print('Reading and embedding {} train data...'.format(dataname)) 616 | train = st.SentencesDataset(reader.get_examples('train.tsv'), model=model) 617 | train.targets = train.labels 618 | print('Reading and embedding {} test data...'.format(dataname)) 619 | test = st.SentencesDataset(reader.get_examples('test.tsv'), model=model) 620 | test.targets = test.labels 621 | if preembed_sentences: 622 | batch_processor = partial(batch_processor_st, model=model, device=device) 623 | else: 624 | batch_processor = None 625 | 626 | ## Seems like torchtext alredy maps class ids to 0...n-1. Adapt class names to account for this. 
627 | classes = torchtext.datasets.text_classification.LABELS[dataname] 628 | classes = [classes[k+1] for k in range(len(classes))] 629 | train.classes = classes 630 | test.classes = classes 631 | 632 | train_idx, valid_idx = random_index_split(len(train), 1-valid_size, (maxsize, None)) # No maxsize for validation 633 | train_sampler = SubsetRandomSampler(train_idx) 634 | valid_sampler = SubsetRandomSampler(valid_idx) 635 | 636 | dataloader_args = dict(batch_size=batch_size,num_workers=num_workers,collate_fn=batch_processor) 637 | train_loader = dataloader.DataLoader(train, sampler=train_sampler,**dataloader_args) 638 | valid_loader = dataloader.DataLoader(train, sampler=valid_sampler,**dataloader_args) 639 | dataloader_args['shuffle'] = False 640 | test_loader = dataloader.DataLoader(test, **dataloader_args) 641 | 642 | if print_stats: 643 | print('Classes: {} (effective: {})'.format(len(train.classes), len(torch.unique(train.targets)))) 644 | print('Fold Sizes: {}/{}/{} (train/valid/test)'.format(len(train_idx), len(valid_idx), len(test))) 645 | 646 | return train_loader, valid_loader, test_loader, train, test 647 | 648 | 649 | class H5Dataset(torchdata.Dataset): 650 | def __init__(self, images_path, labels_path, transform=None): 651 | super(H5Dataset, self).__init__() 652 | 653 | f = h5py.File(images_path, "r") 654 | self.data = f.get("x") 655 | 656 | g = h5py.File(labels_path, "r") 657 | self.targets = torch.from_numpy(g.get("y")[:].flatten()) 658 | 659 | self.transform = transform 660 | self.classes = [0, 1] 661 | 662 | def __getitem__(self, index): 663 | if type(index) != slice: 664 | X = ( 665 | torch.from_numpy(self.data[index, :, :, :]).permute(2, 0, 1).float() 666 | / 255 667 | ) 668 | else: 669 | X = ( 670 | torch.from_numpy(self.data[index, :, :, :]).permute(0, 3, 1, 2).float() 671 | / 255 672 | ) 673 | 674 | y = int(self.targets[index]) 675 | 676 | if self.transform: 677 | X = self.transform(torchvision.transforms.functional.to_pil_image(X)) 678 | 679 | return X, y 680 | 681 | def __len__(self): 682 | return self.data.shape[0] 683 | 684 | 685 | def combine_datasources(dset, dset_extra, valid_size=0, shuffle=True, random_seed=2019, 686 | maxsize=None, device='cpu'): 687 | """ Combine two datasets. 688 | 689 | Extends dataloader with additional data from other dataset(s). Note that we 690 | add the examples in dset only to train (no validation) 691 | 692 | Arguments: 693 | dset (DataLoader): first dataloader 694 | dset_extra (DataLoader): additional dataloader 695 | valid_size (float): fraction of data use for validation fold 696 | shiffle (bool): whether to shuffle train data 697 | random_seed (int): random seed 698 | maxsize (int): maximum number of examples in either train or validation loader 699 | device (str): device for data loading 700 | 701 | Returns: 702 | train_loader_ext (DataLoader): train dataloader for combined data sources 703 | valid_loader_ext (DataLoader): validation dataloader for combined data sources 704 | 705 | """ 706 | if shuffle == True and random_seed: 707 | np.random.seed(random_seed) 708 | 709 | ## Convert both to TensorDataset 710 | if isinstance(dset, torch.utils.data.DataLoader): 711 | dataloader_args = {k:getattr(dset, k) for k in ['batch_size', 'num_workers']} 712 | X, Y = load_full_dataset(dset, targets=True, device=device) 713 | d = int(np.sqrt(X.shape[1])) 714 | X = X.reshape(-1, 1, d, d) 715 | dset = torch.utils.data.TensorDataset(X, Y) 716 | logger.info(f'Main data size. 
X: {X.shape}, Y: {Y.shape}') 717 | elif isinstance(dst, torch.utils.data.Dataset): 718 | raise NotImplemented('Error: combine_datasources cant take Datasets yet.') 719 | 720 | merged_dset = torch.utils.data.ConcatDataset([dset, dset_extra]) 721 | train_idx, valid_idx = random_index_split(len(dset), 1-valid_size, (maxsize, None)) # No maxsize for validation 722 | train_idx = np.concatenate([train_idx, np.arange(len(dset_extra)) + len(dset)]) 723 | 724 | if shuffle: 725 | train_sampler = SubsetRandomSampler(train_idx) 726 | valid_sampler = SubsetRandomSampler(valid_idx) 727 | else: 728 | train_sampler = SubsetSampler(train_idx) 729 | valid_sampler = SubsetSampler(valid_idx) 730 | 731 | train_loader_ext = dataloader.DataLoader(merged_dset, sampler = train_sampler, **dataloader_args) 732 | valid_loader_ext = dataloader.DataLoader(merged_dset, sampler = valid_sampler, **dataloader_args) 733 | 734 | logger.info(f'Fold Sizes: {len(train_idx)}/{len(valid_idx)} (train/valid)') 735 | 736 | return train_loader_ext, valid_loader_ext 737 | -------------------------------------------------------------------------------- /otdd/pytorch/flows.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import pdb 5 | import logging 6 | 7 | import matplotlib 8 | if os.name == 'posix' and "DISPLAY" not in os.environ: 9 | matplotlib.use('Agg') 10 | nodisplay = True 11 | else: 12 | nodisplay = False 13 | import matplotlib.pyplot as plt 14 | 15 | 16 | import numpy as np 17 | from tqdm.autonotebook import tqdm 18 | 19 | import torch 20 | from torch.utils.data import TensorDataset 21 | from torchvision.utils import make_grid 22 | 23 | 24 | from celluloid import Camera 25 | 26 | from .distance import DatasetDistance 27 | from .moments import compute_label_stats 28 | from .utils import show_grid, register_gradient_hook, inverse_normalize 29 | from ..plotting import gaussian_density_plot, plot2D_samples_mat 30 | 31 | logger = logging.getLogger(__name__) 32 | 33 | try: 34 | from tsnecuda import TSNE 35 | tsnelib = 'tsnecuda' 36 | except: 37 | logger.warning("tsnecuda not found - will use (slower) TSNE from sklearn") 38 | from sklearn.manifold import TSNE 39 | tsnelib = 'sklearn' 40 | # 41 | from openTSNE import TSNE 42 | tsnelib = 'opentsne' 43 | 44 | 45 | ################################################################################ 46 | ##################### MAIN GRADIENT FLOW CLASSES ############################ 47 | ################################################################################ 48 | 49 | class GradientFlow(): 50 | """ Parent class for Gradient flows. Subclasses should provide at the very 51 | least init, step, end and flow methods. 
52 | 53 | """ 54 | def __init__(src, tgt): 55 | self.Ds = src 56 | self.Dt = tgt 57 | self.Zt = [] # To store trajectories 58 | 59 | def _add_callback(self, cb): 60 | self.callback = cb 61 | self.compute_coupling = cb.compute_coupling if hasattr(cb,'compute_coupling') else False 62 | self.store_trajectories = cb.store_trajectories if hasattr(cb,'store_trajectories') else False 63 | 64 | def init(self): 65 | pass 66 | 67 | def end(self): 68 | pass 69 | 70 | def step(self): 71 | pass 72 | 73 | def flow(self): 74 | """ """ 75 | f_val = self.init() 76 | self.callback.on_flow_begin(self.otdd, d) 77 | 78 | pbar = tqdm(self.times, leave=False) 79 | for iter,t in enumerate(pbar): 80 | pbar.set_description(f'Flow Step {iter}/{len(self.times)}') 81 | self.callback.on_step_begin(self.otdd, iter) 82 | f_val = self.step() 83 | logger.info("t={:8.4f}, F(a_t)={:8.2f}".format(t,f_val)) 84 | self.callback.on_step_end(self.otdd, iter, t, d) 85 | print(get_gpu_memory_map()) 86 | 87 | logger.info('Done') 88 | 89 | self.end() 90 | 91 | return f_val, 92 | 93 | def animate(self, display_freq=None, save_path=None, **kwargs): 94 | """ A shortcut to creating the plotting callback with animation. 95 | 96 | - kwargs: additional arguments to be passed to Plotting2DCallback 97 | """ 98 | cb = Plotting2DCallback(animate=True, display_freq=display_freq, 99 | save_path=save_path, **kwargs) 100 | self._add_callback(cb) 101 | self.flow() 102 | return cb.animation 103 | 104 | 105 | class OTDD_Gradient_Flow(GradientFlow): 106 | """ 107 | Gradient Flows According to the OT Dataset Distance. 108 | 109 | Attributes 110 | ---------- 111 | D1 : 112 | 113 | Method (str): the flow method 114 | - xonly: gradient updates only on the features. Label distribs recomputed 115 | based on these. Graph from X to Means and Covariances is dettached 116 | - xonly-attached: gradient updates only on the features. Label distribs 117 | recomputed based on these. Graph from X to Means and Covariances is kept attached 118 | - xytied: gradient updates on both features and labels distribs. But the 119 | assignments of examples to labels are kept fixed throughout (no 120 | class creation/destruction, fixed cluster sizes) 121 | - xyaugm: gradient updates on both features and label distribs, without 122 | tying. Relies on "augmented" representation computation for OTDD 123 | and therefore on diagonal approximation of the covariance matrix. 124 | 125 | Note that these are relveant only for inner_ot_method =! exact. For the exact 126 | (nonparametric) OTDD, there's no μ, Σ, so xyaugm/xytied don't make sense, and 127 | xonly/xonly-attached are equivalent. 128 | 129 | """ 130 | def __init__(self, src, tgt=None, objective_type = 'otdd_only', functional = None, 131 | method = 'xonly', optim='sgd', step_size = 0.1, steps = 20, 132 | use_torchoptim=False, 133 | compute_coupling=False, entreg_π = 1e-3, 134 | fixed_labels=True, 135 | clustering_method='kmeans', 136 | clustering_input=None, 137 | callback=None, 138 | noisy_update = False, 139 | noise_β = 0.01, 140 | device = 'cpu', 141 | precision='single', 142 | eigen_correction=False, 143 | **kwargs): 144 | """ 145 | kwargs are for DatasetDistance. 
146 | """ 147 | self.device = device 148 | self.Ds = src 149 | self.Dt = tgt 150 | 151 | assert method in ['xonly', 'xonly-attached', 'xytied', 'xyaugm'] 152 | self.method = method 153 | self.optim = optim 154 | self.use_torchoptim = use_torchoptim 155 | self.fixed_labels = fixed_labels 156 | self.clustering_method = clustering_method 157 | self.clustering_input = clustering_input 158 | 159 | 160 | self.entreg_π = entreg_π 161 | self.precision = precision 162 | self.eigen_correction = eigen_correction 163 | 164 | 165 | assert objective_type in ['otdd_only', 'ot_only', 'F_only', 'mixed'] 166 | assert (functional is None) or callable(functional) 167 | assert (functional is None) or (objective_type in ['F_only', 'mixed']) 168 | 169 | assert not (objective_type in ['otdd_only', 'ot_only', 'mixed']) or self.Dt is not None, 'If objective contains distance, must provide tgt dataset' 170 | self.objective_type = objective_type 171 | self.functional = functional 172 | 173 | ### I had this in my oher flows script. Makes sense only for fixed time horizon t in [0,1] 174 | self.times = np.arange(step_size,step_size*(steps+1), step_size) 175 | self.step_size = step_size 176 | self.steps = steps 177 | self.callback = callback if callback is not None else Callback() 178 | self.initialized = False 179 | self.X1_init = None 180 | self.compute_coupling = callback.compute_coupling if hasattr(callback,'compute_coupling') else False 181 | self.store_trajectories = callback.store_trajectories if hasattr(callback,'store_trajectories') else False 182 | self.trajectory_freq = callback.trajectory_freq if hasattr(callback,'trajectory_freq') else 1 183 | 184 | self.X1 = None # to store trajectories 185 | 186 | otdd_args = { 187 | 'inner_ot_method': 'gaussian_approx', 188 | 'nworkers_dists': 1, 189 | 'nworkers_stats': 1, 190 | 'debiased_loss': True, 191 | 'sqrt_method': 'exact', 192 | 'sqrt_niters': None, 193 | 'sqrt_pref': 1, # to save some bacward computation on sqrts of src side. 
194 | 'p': 2, 195 | 'entreg': 1e-1, 196 | 'precision': precision, 197 | 'device': device, 198 | 'λ_y': None if objective_type == 'ot_only' else 1.0, 199 | 'eigen_correction': eigen_correction 200 | } 201 | 202 | otdd_args.update(kwargs) 203 | otdd_args['method'] = 'augmentation' if method == 'xyaugm' else 'precomputed_labeldist' 204 | if 'diagonal_cov' not in otdd_args: 205 | otdd_args['diagonal_cov'] = (method == 'xy_augm') 206 | 207 | assert not ((otdd_args['inner_ot_method']=='exact') and (method in ['xy_augm', 'xytied'])), \ 208 | "If inner_ot_method == 'exact', then flow method cannot be '{}'.".format(method) 209 | 210 | self.otdd = DatasetDistance(src, tgt, **otdd_args) 211 | self.history = [] 212 | 213 | def init(self): 214 | self.t = 0.0 215 | 216 | ### Also initial target Mean and Covs (not in computational graph, static) 217 | ###_, _ = self.otdd._get_label_stats(side='tgt') 218 | 219 | d = self.otdd.distance(return_coupling=False) 220 | 221 | self.class_sizes = torch.unique(self.otdd.Y1, return_counts=True)[1] #torch.tensor([(self.otdd.Y1 == c).sum().item() for c in self.otdd.V1]) 222 | 223 | logger.info('Using method: {}'.format(self.method)) 224 | if self.method == 'xonly' or self.method == 'xonly-attached': 225 | self.otdd.X1.requires_grad_(True) 226 | flow_params = [self.otdd.X1] # Updated by optim / gradient 227 | 228 | elif self.method == 'xytied': 229 | self.otdd.X1.requires_grad_(True) 230 | self.otdd.Covs[0].requires_grad_(True) 231 | self.otdd.Means[0].requires_grad_(True) 232 | stats_lr = self.step_size/(self.class_sizes.max()*0.2) 233 | flow_params = [ 234 | {'params': self.otdd.X1, 'lr': self.step_size}, 235 | {'params': self.otdd.Means[0], 'lr': stats_lr}, 236 | {'params': self.otdd.Covs[0], 'lr': stats_lr}, 237 | ] 238 | flow_params = [self.otdd.X1, self.otdd.Means[0], self.otdd.Covs[0]] 239 | 240 | elif self.method == 'xyaugm': 241 | self.otdd.XμΣ1.requires_grad_(True) 242 | flow_params = [self.otdd.XμΣ1] 243 | 244 | self._flow_params = flow_params 245 | 246 | if self.use_torchoptim: 247 | if self.optim == 'sgd': 248 | optimizer = torch.optim.SGD(flow_params, lr=self.step_size, momentum=0.5) 249 | elif self.optim == 'adam': 250 | optimizer = torch.optim.Adam(flow_params, lr=self.step_size) 251 | elif self.optim == 'adagrad': 252 | optimizer = torch.optim.Adagrad(flow_params, lr=self.step_size) 253 | self.optimizer = optimizer 254 | 255 | if self.compute_coupling is not False: self.coupling_update() 256 | 257 | ### Trigger first forward pass, now that things require grad 258 | 259 | if self.X1_init is None: 260 | self.X1_init = self.otdd.X1.detach().clone().cpu() 261 | self.Y1_init = self.otdd.Y1.detach().clone().cpu() 262 | if self.store_trajectories: 263 | self.Xt = self.X1_init.unsqueeze(-1).float() # time will be last dim 264 | self.Yt = self.Y1_init.unsqueeze(-1) 265 | if self.otdd.inner_ot_method != 'exact': 266 | self.M1_init = self.otdd.Means[0].detach().clone().cpu() 267 | self.C1_init = self.otdd.Covs[0].detach().clone().cpu() 268 | 269 | # 270 | logger.info("t={:8.2f}, F(a_t)={:8.2f}".format(0.0,d.item())) 271 | self.initialized = True 272 | return d.item() 273 | 274 | def coupling_update(self): 275 | logger.info('Coupling Update Computation') 276 | with torch.no_grad(): 277 | otmethod = 'emd' if self.entreg_π <= 1e-3 else 'sinkhorn_epsilon_scaling' 278 | self.otdd.π = self.otdd.compute_coupling(entreg=self.entreg_π, method=otmethod, 279 | verbose = True, numItermax=50) 280 | 281 | def gradient_update(self): 282 | if self.use_torchoptim: 283 | L_αβ = 0 
284 | if self.objective_type in ['otdd_only', 'ot_only', 'mixed']: 285 | L_αβ += self.otdd.distance(return_coupling=False) 286 | if callable(self.functional): 287 | L_αβ += self.functional(self.otdd.X1, self.otdd.Y1) 288 | self.optimizer.zero_grad() 289 | L_αβ.backward() 290 | gnorm = torch.nn.utils.clip_grad_norm_(self._flow_params, 10) 291 | self.optimizer.step() 292 | 293 | else: # Manual Updating 294 | L_αβ = self.otdd.distance(return_coupling=False) 295 | lr = self.step_size 296 | X = self.otdd.X1 297 | if self.noisy_update: X += torch.randn(X.shape)*self.noise_β 298 | 299 | if self.method in ['xonly', 'xonly-attached']: 300 | [gX] = torch.autograd.grad(L_αβ, [X]) 301 | self.otdd.X1.data -= lr * len(X) * gX 302 | elif self.method == 'xytied': 303 | [gX,gM,gC] = torch.autograd.grad(L_αβ, [X, self.otdd.Means[0], self.otdd.Covs[0]]) 304 | for g in [gX,gM,gC]: 305 | if torch.isnan(g).any(): pdb.set_trace(header='Failed Grad Check') 306 | 307 | class_sizes = torch.tensor([(self.otdd.Y1 == c).sum().item() for c in self.otdd.V1]) * 0.1 308 | 309 | self.otdd.X1.data -= lr * len(X) * gX 310 | self.otdd.Means[0].data -= lr * len(self.otdd.Means[0]) * gM / class_sizes.view(-1,1) #* 0.2 # slower lr for Means and Covs 311 | self.otdd.Covs[0].data -= lr * len(self.otdd.Covs[0]) * gC / class_sizes.view(-1,1,1) #* 0.2 312 | elif self.method == 'xyaugm': 313 | if self.noisy_update: 314 | pdb.set_trace(header='Not sure xyaugm works with noisy update yet') 315 | [gXMC] = torch.autograd.grad(L_αβ, [self.otdd.XμΣ1])#, self.otdd.Means[0], self.otdd.Covs[0]]) 316 | self.otdd.XμΣ1.data -= lr * len(self.otdd.XμΣ1) * gXMC 317 | self.otdd.X1 = self.otdd.XμΣ1.data[:,:2] 318 | return L_αβ.item() 319 | 320 | 321 | def label_update(self): 322 | """ Triggers an update to the categorial labels associated with each 323 | particle, based on current state. 324 | 325 | Arguments: 326 | cinput (str): input to the clustering method, i.e., what parts of the 327 | current state are used to impute labels. One of 'feats', 'stats' 328 | or 'both'. 
329 | 330 | """ 331 | cinput = self.clustering_input 332 | cmethod = self.clustering_method 333 | 334 | if cinput == 'feats': 335 | U = self.otdd.X1.detach().numpy() 336 | elif cinput == 'stats': 337 | if self.method == 'xyaugm': 338 | d = self.otdd.X1.shape[1] 339 | U = self.otdd.XμΣ1.detach().numpy()[:,d:] 340 | else: 341 | raise NotImplemented() 342 | elif cinput == 'both': 343 | if self.method == 'xyaugm': 344 | U = self.otdd.XμΣ1.detach().numpy() 345 | else: 346 | raise NotImplemented() 347 | 348 | if cmethod == 'kmeans': 349 | k = len(self.otdd.classes2) 350 | C,L,_ = k_means(U, k) 351 | elif cmethod == 'dbscan': 352 | L = DBSCAN(eps=5, min_samples = 4).fit(U).labels_ 353 | else: 354 | raise ValueError() 355 | 356 | self.otdd.Y1 = torch.LongTensor(L) 357 | self.otdd.targets1 = self.otdd.Y1 358 | self.otdd.classes1 = [str(i.item()) for i in torch.unique(self.otdd.targets1)] 359 | 360 | 361 | def stats_update(self): 362 | """ Triggers update on means and covariances of particles """ 363 | dtype = torch.DoubleTensor if self.precision == 'double' else torch.FloatTensor 364 | 365 | if self.method == 'xonly': 366 | Ds_t = TensorDataset(self.otdd.X1.detach().clone(), self.otdd.Y1)#self.Ds.tensors[1]) 367 | 368 | self.otdd.Means[0], self.otdd.Covs[0] = compute_label_stats(Ds_t, 369 | targets=self.otdd.targets1, 370 | indices=np.arange(len(self.otdd.Y1)), 371 | classnames=self.otdd.classes1, 372 | to_tensor=True, 373 | diagonal_cov = self.otdd.diagonal_cov, 374 | online=self.otdd.online_stats, 375 | device=self.otdd.device, 376 | dtype=dtype, 377 | eigen_correction=self.eigen_correction, 378 | ) 379 | 380 | elif self.method == 'xonly-attached': 381 | Ds_t = TensorDataset(self.otdd.X1, self.otdd.Y1) 382 | 383 | self.otdd.Means[0],self.otdd.Covs[0] = compute_label_stats(Ds_t, 384 | targets=self.otdd.targets1, 385 | indices=np.arange(len(self.otdd.Y1)), 386 | classnames=self.otdd.classes1, 387 | to_tensor=True, 388 | diagonal_cov = self.otdd.diagonal_cov, 389 | online=self.otdd.online_stats, 390 | device=self.otdd.device, 391 | dtype=dtype, 392 | eigen_correction=self.eigen_correction 393 | ) 394 | 395 | if torch.isnan(self.otdd.Covs[0]).any(): 396 | pdb.set_trace(header='Nans in Cov Matrices') 397 | 398 | 399 | elif self.method == 'xytied': 400 | pass 401 | elif self.method == 'xyaugm': 402 | d = self.otdd.X1.shape[1] 403 | 404 | self.otdd.X1 = self.otdd.XμΣ1.data[:,:d] 405 | 406 | if self.clustering_method == 'kmeans': 407 | k = len(self.otdd.classes2) 408 | C,L,_ = k_means(self.otdd.XμΣ1.detach().numpy()[:,d:], k) 409 | elif self.clustering_method == 'dbscan': 410 | L = DBSCAN(eps=5, min_samples = 4).fit(self.otdd.XμΣ1.detach().numpy()[:,d:]).labels_ 411 | 412 | self.otdd.Y1 = torch.LongTensor(L) 413 | self.otdd.targets1 = self.otdd.Y1 414 | self.otdd.classes1 = [str(i.item()) for i in torch.unique(self.otdd.targets1)] 415 | 416 | Ds_t = TensorDataset(self.otdd.X1, self.otdd.Y1) 417 | 418 | M, S = compute_label_stats(Ds_t, 419 | targets=self.otdd.targets1, 420 | indices=np.arange(len(self.otdd.Y1)), 421 | classnames=self.otdd.classes1, 422 | to_tensor=True, 423 | nworkers=0, 424 | diagonal_cov = True, 425 | online=self.otdd.online_stats, 426 | device=self.otdd.device, 427 | eigen_correction=self.eigen_correction 428 | ) 429 | self.otdd.Means[0] = M 430 | self.otdd.Covs[0] = S 431 | _, cts = torch.unique(self.otdd.Y1, return_counts=True) 432 | logger.info('Counts: {}'.format(cts)) 433 | self.class_sizes = cts 434 | 435 | else: 436 | raise ValueError('Unrecoginzed flow method') 437 | 438 | 439 
| def step(self, iter): 440 | assert self.initialized 441 | 442 | ##### Update differentiable-dynamic params 443 | self.otdd.label_distances = None 444 | self.otdd._pwlabel_stats_1 = None 445 | 446 | logger.info('Performing flow gradient step...') 447 | L_αβ = self.gradient_update() 448 | 449 | ##### Update labels (if fixed_labels != False ) 450 | if not self.fixed_labels: 451 | logger.info('Performing label update...') 452 | self.label_update() 453 | 454 | ###### Trigger update of those that are not updated by gradient 455 | if self.otdd.inner_ot_method != 'exact': 456 | logger.info('Performing stats update...') 457 | self.stats_update() 458 | 459 | if self.compute_coupling == 'every_iteration': 460 | logger.info('Performing coupling update...') 461 | self.coupling_update() 462 | 463 | if self.store_trajectories and (iter % self.trajectory_freq == 0): 464 | self.Xt = torch.cat([self.Xt, self.otdd.X1.detach().clone().cpu().float().unsqueeze(-1)], dim=-1) 465 | self.Yt = torch.cat([self.Yt, self.otdd.Y1.detach().clone().cpu().unsqueeze(-1)], dim=-1) 466 | 467 | 468 | return L_αβ 469 | 470 | def flow(self, tol = 1e-3): 471 | """ """ 472 | prev = np.Inf 473 | obj = self.init() 474 | self.history.append(obj) 475 | self.callback.on_flow_begin(self.otdd, obj) 476 | pbar = tqdm(self.times, leave=False) 477 | for iter,t in enumerate(pbar, 1): 478 | pbar.set_description(f'Flow Step {iter}/{len(self.times)}, F_t={obj:8.2f}') 479 | self.callback.on_step_begin(self.otdd, iter) 480 | obj = self.step(iter) 481 | logger.info(f't={t:8.2f}, F(a_t)={obj:8.2f}') # Although things have been updated, this is obj of time t still 482 | self.history.append(obj) 483 | self.t = t 484 | self.callback.on_step_end(self, self.otdd, iter, t, obj) # Now that we pass whole flow obj, maybe no need to pass otdd 485 | Δ = np.abs(obj - prev) 486 | if tol and (Δ < tol): 487 | logger.warning(f'Stoping condition met (Δ = {Δ:2.2e} < {tol:2.2e} = tol). Terminating flow.') 488 | break 489 | else: 490 | prev = obj 491 | 492 | logger.info('Done') 493 | cbout = self.callback.on_flow_end(self, self.otdd)#, iter, t, d) 494 | self.end() 495 | return obj, cbout 496 | 497 | def end(self): 498 | self.otdd.X1 = self.X1_init 499 | self.otdd.Y1 = self.Y1_init 500 | if self.otdd.inner_ot_method != 'exact': 501 | self.otdd.Means[0] = self.M1_init 502 | self.otdd.Covs[0] = self.C1_init 503 | 504 | 505 | 506 | ################################################################################ 507 | ################### CALLBACKS FOR GRADIENT FLOW CLASS ###################### 508 | ################################################################################ 509 | 510 | class Callback(): 511 | compute_coupling = False 512 | store_trajectories = False 513 | trajectory_freq = None 514 | def __init__(self): pass 515 | def on_flow_begin(self, *args, **kwargs): pass 516 | def on_flow_end(self, *args, **kwargs): pass 517 | def on_step_begin(self, *args, **kwargs): pass 518 | def on_step_end(self, *args, **kwargs): pass 519 | 520 | 521 | class CallbackList(Callback): 522 | def __init__(self, cbs): 523 | self.cbs = cbs 524 | ### Aggregate requierements imposed by callbacks on Flow - if at least 525 | ### one of them needs it, ask for it. 
526 | coupling_attrs = [cb.compute_coupling for cb in cbs] 527 | if 'every_iteration' in coupling_attrs: 528 | self.compute_coupling = 'every_iteration' 529 | elif 'initial' in coupling_attrs: 530 | self.compute_coupling = 'initial' 531 | else: 532 | self.compute_coupling = False 533 | 534 | trajectory_attrs = [cb.store_trajectories for cb in cbs] 535 | self.store_trajectories = np.array(trajectory_attrs).any() 536 | trajfreq_attrs = [cb.trajectory_freq for cb in cbs if cb.trajectory_freq is not None] 537 | self.trajectory_freq = np.array(trajfreq_attrs).min() if trajfreq_attrs else None 538 | 539 | def __getitem__(self, i): 540 | return self.cbs[i] 541 | 542 | def on_flow_begin(self, *args, **kwargs): 543 | for cb in self.cbs: cb.on_flow_begin(*args, **kwargs) 544 | def on_flow_end(self, *args, **kwargs): 545 | for cb in self.cbs: cb.on_flow_end(*args, **kwargs) 546 | def on_step_begin(self, *args, **kwargs): 547 | for cb in self.cbs: cb.on_step_begin(*args, **kwargs) 548 | def on_step_end(self, *args, **kwargs): 549 | for cb in self.cbs: cb.on_step_end(*args, **kwargs) 550 | 551 | 552 | class Plotting2DCallback(Callback): 553 | def __init__(self, display_freq=None, animate=False, entreg_π = 1e-4, 554 | show_coupling = True, show_trajectories=True, trajectory_size=5, show_target = True, 555 | plot_pad = 2, save_format='pdf', ndim=2, save_path=None): 556 | self.animate = animate 557 | self.save_path = save_path 558 | self.display_freq = display_freq 559 | self.entreg_π = entreg_π 560 | self.show_coupling = show_coupling 561 | self.compute_coupling = 'every_iteration' if show_coupling else False 562 | self.show_trajectories = show_trajectories 563 | self.store_trajectories = self.show_trajectories 564 | self.trajectory_freq = self.display_freq 565 | self.trajectory_size = trajectory_size 566 | self.show_target = show_target 567 | self.plot_pad = plot_pad 568 | self.ndim = ndim 569 | 570 | self.ax_ranges = [None]*ndim 571 | self.figsize = (6,4) if not animate else (10,7) 572 | self.save_format = save_format 573 | 574 | raise DeprecationWarning('Plotting2DCallback has been deprecated in favor of PlottingCallback') 575 | 576 | 577 | def _plot(self, otdd, X1, Y1=None, X2=None, Y2=None, title=None, trajectories=None): 578 | with torch.no_grad(): 579 | ### Now parent flow takes care of coupling computation 580 | if self.animate: 581 | ax = self.ax 582 | xrng, yrng = ax.get_xlim(), ax.get_ylim() 583 | else: 584 | fig, ax = plt.subplots(figsize=self.figsize) 585 | xrng, yrng = self._get_ax_ranges(X1, X2) 586 | 587 | otdd.plot_label_stats(ax = ax, same_plot=True, show_target=self.show_target, 588 | label_means=False, 589 | label_groups=True, show=False,shift=(2,-2)) 590 | 591 | if self.show_trajectories and trajectories is not None: 592 | pdb.set_trace() 593 | for x in trajectories: 594 | ax.plot(*x, color='k', alpha=0.2, linewidth=0.5) 595 | 596 | if self.show_target and self.show_coupling: 597 | π = otdd.π 598 | plot2D_samples_mat(X1, X2, π, ax=ax, linewidth=0.1, thr=1e-10, linestyle=':') 599 | ax.set_xlim(xrng) 600 | ax.set_ylim(yrng) 601 | ax.set_title('') 602 | ax.get_xaxis().set_ticks([]) 603 | ax.get_yaxis().set_ticks([]) 604 | if title is not None: 605 | ax.text(0.5, 1.01, title, transform=ax.transAxes, ha='center',size=18) 606 | 607 | def _get_ax_ranges(self, X1, X2): 608 | pad = self.plot_pad 609 | if all(v is not None for v in self.ax_ranges): 610 | return self.ax_ranges 611 | with torch.no_grad(): 612 | mins, maxs = [], [] 613 | for i in range(self.ndim): 614 | if self.show_target: 615 | 
mins.append(min(X1[:,i].min(), X2[:,i].min()) - pad) 616 | maxs.append(max(X1[:,i].max(), X2[:,i].max()) + pad) 617 | else: 618 | mins.append(X1[:,i].min() - pad) 619 | maxs.append(X1[:,i].max() + pad) 620 | 621 | self.ax_ranges = [(mins[i].item(), maxs[i].item()) for i in range(self.ndim)] 622 | return self.ax_ranges 623 | 624 | def on_flow_begin(self, otdd, d): 625 | if self.save_path: 626 | save_dir = os.path.dirname(self.save_path) 627 | if not os.path.exists(save_dir): 628 | os.makedirs(save_dir) 629 | ax_ranges = self._get_ax_ranges(otdd.X1,otdd.X2) 630 | if self.animate: 631 | self.fig, self.ax = plt.subplots(figsize=(10,7)) 632 | self.ax.set_xlim(xrng) 633 | self.ax.set_ylim(yrng) 634 | self.camera = Camera(self.fig) 635 | 636 | title = r'Time t=0, $F(\rho_t)$={:4.2f}'.format(d) 637 | _ = self._plot(otdd, otdd.X1, otdd.Y1, otdd.X2, otdd.Y2, title) 638 | if self.animate: 639 | self.camera.snap() 640 | else: 641 | if self.save_path: 642 | outpath = self.save_path + 't0.' + self.save_format 643 | plt.tight_layout() 644 | plt.savefig(outpath, dpi=300) #bbox_inches='tight', 645 | plt.show(block=False) 646 | plt.pause(1) 647 | plt.close() 648 | 649 | def on_flow_end(self, flow, otdd): 650 | if self.animate: 651 | animation = self.camera.animate() 652 | if self.save_path: 653 | animation.save(self.save_path +'.mp4') 654 | self.animation = animation 655 | plt.close(self.fig) 656 | 657 | def on_step_end(self, flow, otdd, iteration, t, d, **kwargs): 658 | if self.display_freq is None or (iteration % self.display_freq == 0): # display 659 | title = r'Time t={:.2f}, $F(\rho_t)$={:4.2f}'.format(t, d) 660 | if self.show_trajectories and 'trajectories' in kwargs: 661 | self._plot(otdd, otdd.X1.detach(), otdd.Y1.detach(), otdd.X2, otdd.Y2, trajectories=kwargs['trajectories'], title=title) 662 | elif self.show_trajectories: 663 | self._plot(otdd, otdd.X1.detach(), otdd.Y1.detach(), otdd.X2, otdd.Y2, trajectories=flow.Xt, title=title) 664 | else: 665 | self._plot(otdd, otdd.X1.detach(), otdd.Y1.detach(), otdd.X2, otdd.Y2, title=title) 666 | 667 | if self.animate: 668 | self.camera.snap() 669 | else: 670 | if self.save_path: 671 | outpath = self.save_path + 't{}.{}'.format(iteration, self.save_format) 672 | plt.tight_layout() 673 | plt.savefig(outpath, dpi=300) #bbox_inches='tight', 674 | plt.show(block=False) 675 | plt.pause(1) 676 | plt.close() 677 | 678 | 679 | class PlottingCallback(Callback): 680 | def __init__(self, display_freq=None, animate=False, entreg_π = 1e-4, 681 | show_coupling = True, show_trajectories=True, trajectory_size=5, show_target = True, 682 | plot_pad = 2, figsize=(6,4), save_format='pdf', ndim=2, azim=-80 , elev=5, save_path=None): 683 | self.animate = animate 684 | self.display_freq = display_freq 685 | self.entreg_π = entreg_π 686 | self.show_coupling = show_coupling 687 | self.compute_coupling = 'every_iteration' if show_coupling else False 688 | self.show_trajectories = show_trajectories 689 | self.store_trajectories = self.show_trajectories 690 | self.trajectory_freq = self.display_freq 691 | self.trajectory_size = trajectory_size 692 | self.show_target = show_target 693 | self.ndim = ndim 694 | 695 | ## Low-level plotting args 696 | self.figsize = figsize 697 | self.azim = azim 698 | self.elev = elev 699 | self.plot_pad = plot_pad 700 | self.ax_ranges = [None]*ndim 701 | 702 | self.save_format = save_format 703 | self.save_path = save_path 704 | 705 | 706 | def _plot(self, otdd, X1, Y1=None, X2=None, Y2=None, title=None, trajectories=None): 707 | with torch.no_grad(): 
708 | if self.animate: 709 | ax = self.ax 710 | ax_ranges = [ax.get_xlim(), ax.get_ylim()] 711 | if self.ndim == 3: 712 | ax_ranges.append(ax.get_zlim()) 713 | else: 714 | fig = plt.figure(figsize=self.figsize) 715 | if self.ndim == 2: 716 | ax = fig.add_subplot(111) 717 | elif self.ndim == 3: 718 | ax = fig.add_subplot(111, azim=self.azim, elev=self.elev, projection='3d') 719 | ax_ranges = self._get_ax_ranges(X1, X2) 720 | 721 | if self.ndim == 2: 722 | otdd.plot_label_stats(ax = ax, same_plot=True, show_target=self.show_target, 723 | label_means=False, 724 | label_groups=True, show=False,shift=(2,-2)) 725 | else: 726 | ax.scatter(*X1.detach().T, c=Y1, cmap = 'tab10') 727 | 728 | if self.show_trajectories and trajectories is not None: 729 | for x in trajectories[:,:,-self.trajectory_size:]: 730 | ax.plot(*x, color='k', alpha=0.2, linewidth=0.5) 731 | 732 | 733 | ax.set_title('') 734 | ax.set_xlim(ax_ranges[0]) 735 | ax.set_ylim(ax_ranges[1]) 736 | if self.ndim == 2: 737 | ## To remove ticks & axes grids 738 | ax.get_xaxis().set_ticks([]) 739 | ax.get_yaxis().set_ticks([]) 740 | else: 741 | ## To keep grid 742 | for _ax in [ax.xaxis, ax.yaxis, ax.zaxis]: 743 | _ax.set_ticklabels([]) 744 | _ax.set_ticks_position('none') 745 | ax.set_zlim(ax_ranges[2]) 746 | 747 | if title is not None: 748 | if self.ndim == 2: 749 | ax.text(0.5, 1.1, title, transform=ax.transAxes, ha='center',size=18) 750 | else: 751 | ax.text2D(0.5, 0.85, title, transform=ax.transAxes, ha='center',size=18) 752 | 753 | def _get_ax_ranges(self, X1, X2): 754 | pad = self.plot_pad 755 | if all(v is not None for v in self.ax_ranges): 756 | return self.ax_ranges 757 | with torch.no_grad(): 758 | mins, maxs = [], [] 759 | for i in range(self.ndim): 760 | if self.show_target: 761 | mins.append(min(X1[:,i].min(), X2[:,i].min()) - pad) 762 | maxs.append(max(X1[:,i].max(), X2[:,i].max()) + pad) 763 | else: 764 | mins.append(X1[:,i].min() - pad) 765 | maxs.append(X1[:,i].max() + pad) 766 | 767 | self.ax_ranges = [(mins[i].item(), maxs[i].item()) for i in range(self.ndim)] 768 | return self.ax_ranges 769 | 770 | def on_flow_begin(self, otdd, d): 771 | if self.save_path: 772 | save_dir = os.path.dirname(self.save_path) 773 | if not os.path.exists(save_dir): 774 | os.makedirs(save_dir) 775 | ax_ranges = self._get_ax_ranges(otdd.X1,otdd.X2) 776 | 777 | if self.animate: 778 | self.fig = plt.figure(figsize=self.figsize) 779 | if self.ndim == 2: 780 | self.ax = self.fig.add_subplot(111) 781 | elif self.ndim == 3: 782 | self.ax = self.fig.add_subplot(111, azim=self.azim, elev=self.elev, projection='3d') 783 | 784 | self.ax.set_xlim(ax_ranges[0]) 785 | self.ax.set_ylim(ax_ranges[1]) 786 | if self.ndim == 3: 787 | self.ax.set_zlim(ax_ranges[2]) 788 | self.camera = Camera(self.fig) 789 | 790 | title = r'Time t=0, $F(\rho_t)$={:4.2f}'.format(d) 791 | _ = self._plot(otdd, otdd.X1, otdd.Y1, otdd.X2, otdd.Y2, title) 792 | if self.animate: 793 | self.camera.snap() 794 | else: 795 | if self.save_path: 796 | outpath = self.save_path + 't0.' 
+ self.save_format 797 | plt.tight_layout() 798 | plt.savefig(outpath, dpi=300) #bbox_inches='tight', 799 | plt.show(block=False) 800 | plt.pause(1) 801 | plt.close() 802 | 803 | def on_flow_end(self, flow, otdd): 804 | if self.animate: 805 | animation = self.camera.animate() 806 | if self.save_path: 807 | animation.save(self.save_path +'.mp4') 808 | self.animation = animation 809 | plt.close(self.fig) 810 | 811 | def on_step_end(self, flow, otdd, iteration, t, d, **kwargs): 812 | if self.display_freq is None or (iteration % self.display_freq == 0): # display 813 | title = r'Time t={:.2f}, $F(\rho_t)$={:4.2f}'.format(t, d) 814 | if self.show_trajectories and 'trajectories' in kwargs: 815 | self._plot(otdd, otdd.X1.detach(), otdd.Y1.detach(), otdd.X2, otdd.Y2, trajectories=kwargs['trajectories'], title=title) 816 | elif self.show_trajectories: 817 | self._plot(otdd, otdd.X1.detach(), otdd.Y1.detach(), otdd.X2, otdd.Y2, trajectories=flow.Xt, title=title) 818 | else: 819 | self._plot(otdd, otdd.X1.detach(), otdd.Y1.detach(), otdd.X2, otdd.Y2, title=title) 820 | 821 | if self.animate: 822 | self.camera.snap() 823 | else: 824 | if self.save_path: 825 | outpath = self.save_path + 't{}.{}'.format(iteration, self.save_format) 826 | plt.tight_layout() 827 | plt.savefig(outpath, dpi=300) #bbox_inches='tight', 828 | plt.show(block=False) 829 | plt.pause(0.2) 830 | plt.close() 831 | 832 | 833 | class Embedding2DCallback(Plotting2DCallback): 834 | def __init__(self, method = 'tsne', joint = True, **kwargs): 835 | super().__init__(**kwargs) 836 | self.method = method 837 | self.joint = joint 838 | 839 | 840 | def _embed(self, otdd): 841 | with torch.no_grad(): 842 | X = torch.cat([otdd.X1.detach().clone().cpu(),otdd.X2.clone().cpu()], dim=0) 843 | if tsnelib in ['sklearn', 'tsnecuda']: 844 | X_emb = TSNE(n_components=self.ndim, verbose=0, perplexity=50).fit_transform(X) 845 | else: 846 | if not hasattr(self, 'tsne') or self.tsne is None: 847 | X_emb = TSNE(n_components=self.ndim, perplexity=50, 848 | n_jobs=8,verbose=True).fit(X) 849 | self.tsne = X_emb 850 | X_emb = self.tsne.transform(X).astype(np.float32) 851 | if isinstance(X_emb, np.ndarray): 852 | X_emb = torch.from_numpy(X_emb) 853 | X1_emb, X2_emb = X_emb[:otdd.X1.shape[0],:], X_emb[otdd.X1.shape[0]:,:] 854 | return X1_emb.to('cpu'), X2_emb.to('cpu') 855 | 856 | def on_flow_begin(self, otdd, d): 857 | 858 | X1_emb, X2_emb = self._embed(otdd) 859 | 860 | ## Compute Stats for Embedded Points 861 | Ds_emb = TensorDataset(X1_emb, otdd.Y1) 862 | Dt_emb = TensorDataset(X2_emb, otdd.Y2) 863 | 864 | Ms_emb, Cs_emb = compute_label_stats(Ds_emb, 865 | targets=otdd.Y1.cpu() - otdd.Y1.cpu().min(), 866 | indices=np.arange(len(otdd.Y1)), 867 | classnames=otdd.classes1, 868 | to_tensor=True, 869 | nworkers=0, device=otdd.device, 870 | diagonal_cov = otdd.diagonal_cov, 871 | online=otdd.online_stats, 872 | eigen_correction=otdd.eigen_correction, 873 | ) 874 | 875 | Mt_emb, Ct_emb = compute_label_stats(Dt_emb, 876 | targets=otdd.Y2.cpu() - otdd.Y2.cpu().min(), 877 | indices=np.arange(len(otdd.Y2)), 878 | classnames=otdd.classes2, 879 | to_tensor=True, 880 | nworkers=0, device=otdd.device, 881 | diagonal_cov = otdd.diagonal_cov, 882 | online=otdd.online_stats, 883 | eigen_correction=otdd.eigen_correction, 884 | ) 885 | 886 | 887 | 888 | otdd_emb = otdd.copy(keep=['classes1','classes2', 'targets1', 'targets2','Y1','Y2']) 889 | otdd_emb.Y1 = otdd_emb.Y1.cpu() 890 | otdd_emb.Y2 = otdd_emb.Y2.cpu() 891 | otdd_emb.X1 = X1_emb 892 | otdd_emb.X2 = X2_emb 893 | 
otdd_emb.Means = [Ms_emb.cpu(),Mt_emb.cpu()] 894 | otdd_emb.Covs = [Cs_emb.cpu(),Ct_emb.cpu()] 895 | 896 | self.otdd_emb = otdd_emb 897 | 898 | if self.store_trajectories: 899 | self.Xt = self.otdd_emb.X1.detach().clone().cpu().unsqueeze(-1).float() # time will be last dim 900 | self.Yt = self.otdd_emb.Y1.detach().clone().cpu().unsqueeze(-1) 901 | 902 | 903 | super().on_flow_begin(self.otdd_emb, d) 904 | 905 | 906 | def on_step_end(self, flow, otdd, iteration, t, d, **kwargs): 907 | if self.display_freq is None or (iteration % self.display_freq == 0): 908 | X1_emb, X2_emb = self._embed(otdd) 909 | Ds_emb = TensorDataset(X1_emb, otdd.Y1) 910 | 911 | Ms_emb, Cs_emb = compute_label_stats(Ds_emb, 912 | targets=otdd.Y1.cpu() - otdd.Y1.cpu().min(), 913 | indices=np.arange(len(otdd.Y1)), 914 | classnames=otdd.classes1, 915 | to_tensor=True, 916 | nworkers=0, device=otdd.device, 917 | diagonal_cov = otdd.diagonal_cov, 918 | eigen_correction=otdd.eigen_correction, 919 | online=otdd.online_stats) 920 | self.otdd_emb.X1.data = X1_emb 921 | self.otdd_emb.Means[0] = Ms_emb 922 | self.otdd_emb.Covs[0] = Cs_emb 923 | 924 | Dt_emb = TensorDataset(X2_emb, otdd.Y2) 925 | Mt_emb, Ct_emb = compute_label_stats(Dt_emb, 926 | targets=otdd.Y2.cpu() - otdd.Y2.cpu().min(), 927 | indices=np.arange(len(otdd.Y2)), 928 | classnames=otdd.classes2, 929 | to_tensor=True, 930 | nworkers=0, device=otdd.device, 931 | diagonal_cov = otdd.diagonal_cov, 932 | eigen_correction=otdd.eigen_correction, 933 | online=otdd.online_stats) 934 | self.otdd_emb.X2.data = X2_emb 935 | self.otdd_emb.Means[1] = Mt_emb 936 | self.otdd_emb.Covs[1] = Ct_emb 937 | 938 | if self.store_trajectories: 939 | ## Convert to cpu, float (in case it was double) for dumping 940 | self.Xt = torch.cat([self.Xt, self.otdd_emb.X1.detach().clone().cpu().float().unsqueeze(-1)], dim=-1) 941 | self.Yt = torch.cat([self.Yt, self.otdd_emb.Y1.detach().clone().cpu().unsqueeze(-1)], dim=-1) 942 | 943 | super().on_step_end( 944 | flow, self.otdd_emb, iteration, t, d, 945 | trajectories=self.Xt if self.show_trajectories else None, # Must override non-embedded trajectories 946 | **kwargs) 947 | 948 | 949 | 950 | 951 | class ImageGridCallback(Callback): 952 | """ 953 | 954 | by_class: Grid will sample so that each col contains a single class 955 | only_matched: Only display properly matched particles (assumes labels 956 | of src and tgt are in direct correspondence (1st <-> 1st) 957 | etc, but will compensate in case Y2's are shifted. If 958 | True, automatically does by_class too. 
959 | """ 960 | def __init__(self, display_freq=None, animate=False, entreg_π = 1e-4, 961 | byclass = True, only_matched=True, nrow=10, ncol=10, 962 | channels = 1, transparent=False, denormalize=None, save_path=None): 963 | self.animate = animate 964 | self.save_path = save_path 965 | if save_path: 966 | self.outdir = os.path.dirname(save_path) 967 | self.display_freq = display_freq 968 | self.channels = channels 969 | self.entreg_π = entreg_π 970 | self.byclass = byclass 971 | self.only_matched = only_matched 972 | self.compute_coupling = 'initial' if only_matched else False 973 | self.transparent = transparent 974 | self.denormalize = denormalize 975 | 976 | if not self.byclass: 977 | self.nrow = nrow 978 | self.ncol = nrow if ncol is None else ncol 979 | else: 980 | self.ncol = ncol 981 | self.nrow = None # will be number of classes, determined later 982 | self.indices = None 983 | 984 | def _plot(self, otdd, X1, X2, title): 985 | with torch.no_grad(): 986 | batch = X1[self.indices].view(len(self.indices), self.channels, self.imdim[0],self.imdim[1]) 987 | if self.denormalize is not None: 988 | batch = inverse_normalize(batch, *self.denormalize) 989 | ## make_grid reverts nrow ncol for some strange reason. 990 | grid = make_grid(batch, nrow=self.ncol, padding=2, normalize=False, range=None, scale_each=False, pad_value=0) 991 | if self.animate: 992 | ax = self.ax 993 | else: 994 | fig, ax = plt.subplots(figsize=(self.ncol,self.nrow)) 995 | show_grid(grid, ax=ax) 996 | ax.text(0.5, 1.03, title, transform=ax.transAxes, ha='center',size=18) 997 | 998 | def _choose_examples(self, otdd): 999 | X1, Y1 = otdd.X1.detach().cpu(), otdd.Y1.cpu() 1000 | X2, Y2 = otdd.X2.detach().cpu(), otdd.Y2.cpu() 1001 | if self.only_matched: 1002 | ## we match src -> trg 1003 | tgt_idxs = np.argmax(otdd.π, axis=1) 1004 | tgt_labels = Y2[tgt_idxs] - min(Y2) 1005 | matched = (Y1 == tgt_labels) 1006 | idxs = [] 1007 | for c in otdd.V1: 1008 | idxs_class_matched = torch.where((Y1 == c) & matched)[0] 1009 | if len(idxs_class_matched) >= self.nrow: 1010 | ## Have enough matched, select from these with prob prop ot weight on correct class 1011 | p = otdd.π[idxs_class_matched,:][:,Y2-min(Y2)==c].sum(axis=1) 1012 | p /= p.sum() 1013 | idxs_class_selected = np.sort(np.random.choice(idxs_class_matched, self.nrow, replace=False, p=p)) 1014 | else: 1015 | ## Not enough matched, complete with unmatched 1016 | idxs_class = torch.where(Y1 == c)[0] 1017 | unmatched = np.random.choice(idxs_class, self.nrow - len(idxs_class_matched), replace=False) 1018 | idxs_class_selected = np.concatenate((idxs_class_matched, unmatched)) 1019 | 1020 | assert len(idxs_class_selected) == self.nrow 1021 | idxs.append(idxs_class_selected) 1022 | idxs = np.concatenate(idxs) 1023 | elif not self.byclass: 1024 | idxs = np.sort(np.random.choice(X1.shape[0], self.nrow*self.ncol, replace=False)) 1025 | else: 1026 | idxs = [] 1027 | for c in otdd.V1: # V1 is never index-shifted, so this works 1028 | idxs_class = torch.where(Y1 == c)[0] 1029 | idxs.append(np.sort(np.random.choice(idxs_class, self.ncol, replace=False))) 1030 | idxs = np.concatenate(idxs) 1031 | self.indices = idxs 1032 | 1033 | def on_flow_begin(self, otdd, d): 1034 | if not os.path.exists(self.outdir): os.makedirs(self.outdir) 1035 | ## Choose indices of examples to show in grid 1036 | n, d2 = otdd.X1.shape 1037 | self.imdim = (int(np.sqrt(d2/self.channels)),int(np.sqrt(d2/self.channels))) 1038 | if self.byclass: 1039 | self.nrow = len(otdd.V1) 1040 | self.ncol = self.nrow if self.ncol is 
None else self.ncol 1041 | self._choose_examples(otdd) 1042 | 1043 | if self.animate: 1044 | self.fig, self.ax = plt.subplots(figsize=(self.ncol,self.nrow)) 1045 | self.camera = Camera(self.fig) 1046 | title = 'Time t=0, OTDD(S,T)={:4.2f}'.format(d) 1047 | self._plot(otdd, otdd.X1.detach(), otdd.X2, title) 1048 | if self.animate: 1049 | self.camera.snap() 1050 | else: 1051 | if self.save_path: 1052 | outpath = self.save_path + 't0' 1053 | plt.savefig(outpath+'.pdf', dpi=300, transparent=self.transparent) #bbox_inches='tight', 1054 | plt.savefig(outpath+'.png', dpi=300, transparent=self.transparent) #bbox_inches='tight', 1055 | plt.show(block=False) 1056 | plt.pause(1) 1057 | plt.close() 1058 | 1059 | def on_flow_end(self, flow, otdd): 1060 | if self.animate: 1061 | animation = self.camera.animate() 1062 | if self.save_path: 1063 | animation.save(self.save_path +'flow.mov', codec='png', dpi=300, 1064 | savefig_kwargs={'transparent': self.transparent, 1065 | 'facecolor': 'none'}) 1066 | self.animation = animation 1067 | plt.close(self.fig) 1068 | 1069 | def on_step_end(self, flow, otdd, iteration, t, d, **kwargs): 1070 | if not self.display_freq or (iteration % self.display_freq == 0): 1071 | title = r'Time t={:.2f}, OTDD(S,T)={:4.2f}'.format(t, d) 1072 | self._plot(otdd, otdd.X1.detach(), otdd.X2, title) 1073 | 1074 | if self.animate: 1075 | self.camera.snap() 1076 | else: 1077 | if self.save_path: 1078 | outpath = self.save_path + 't{}'.format(iteration) 1079 | plt.savefig(outpath+'.pdf', dpi=300, transparent=self.transparent) 1080 | plt.savefig(outpath+'.png', dpi=300, transparent=self.transparent) 1081 | plt.show(block=False) 1082 | plt.pause(1) 1083 | plt.close() 1084 | 1085 | 1086 | class TrainingCallback(Callback): 1087 | def __init__(self, criterion = 'xent', lr = 0.01, momentum=0.9, iters=20): 1088 | self.criterion = criterion 1089 | self.lr = lr 1090 | self.momentum = momentum 1091 | self.iters = iters 1092 | 1093 | def init_model(self, nclasses=10): 1094 | net = torch.nn.Sequential( 1095 | torch.nn.Linear(2, 10), 1096 | torch.nn.ReLU(), 1097 | torch.nn.Linear(10, 10), 1098 | torch.nn.ReLU(), 1099 | torch.nn.Linear(10, nclasses), 1100 | ) 1101 | return net 1102 | 1103 | def accuracy(output, target, topk=(1,)): 1104 | """Computes the accuracy over the k top predictions for the specified values of k""" 1105 | with torch.no_grad(): 1106 | maxk = max(topk) 1107 | batch_size = target.size(0) 1108 | 1109 | _, pred = output.topk(maxk, 1, True, True) 1110 | pred = pred.t() 1111 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 1112 | 1113 | res = [] 1114 | for k in topk: 1115 | correct_k = correct[:k].view(k,-1).float().sum() 1116 | res.append(correct_k.mul_(100.0 / batch_size)) 1117 | return res 1118 | 1119 | def train(self, model, X,Y, **kwargs): 1120 | criterion = torch.nn.CrossEntropyLoss() 1121 | optimizer = torch.optim.SGD(model.parameters(), lr=self.lr, momentum=self.momentum) 1122 | 1123 | for it in range(self.iters): 1124 | output = model(X) 1125 | loss = criterion(output, Y) 1126 | optimizer.zero_grad() 1127 | loss.backward() 1128 | optimizer.step() 1129 | 1130 | acc1, acc3 = self.accuracy(output,Y, topk=(1,3)) 1131 | 1132 | logger.info('Finall Loss: {} ({}%)'.format(loss, acc1.item())) 1133 | 1134 | def on_step_end(self, flow, otdd, iteration, t, d, **kwargs): 1135 | 1136 | net = self.init_model(nclasses=len(torch.unique(otdd.Y1))) 1137 | 1138 | self.train(net, otdd.X1.detach().clone(), otdd.Y1) 1139 | 1140 | 1141 | class TrajectoryDump(Callback): 1142 | def __init__(self, 
save_freq=None, save_path=None): 1143 | self.save_freq = save_freq 1144 | self.save_path = save_path 1145 | self.store_trajectories = True 1146 | self.trajectory_freq = save_freq 1147 | self.outdir = save_path 1148 | 1149 | def on_flow_begin(self, otdd, d): 1150 | if not os.path.exists(self.outdir): os.makedirs(self.outdir) 1151 | 1152 | def on_flow_end(self, flow, otdd): 1153 | trajpath = os.path.join(self.outdir, 'trajectories_X.pt') 1154 | torch.save(flow.Xt, trajpath) 1155 | trajpath = os.path.join(self.outdir, 'trajectories_Y.pt') 1156 | torch.save(flow.Yt, trajpath) 1157 | if hasattr(flow.otdd, 'Y1_true') and (flow.otdd.Y1_true is not None): 1158 | trajpath = os.path.join(self.outdir, 'Y_init_true.pt') 1159 | torch.save(flow.otdd.Y1_true, trajpath) 1160 | logger.info('Saved trajectories to: {}'.format(trajpath)) 1161 | 1162 | def on_step_end(self, flow, otdd, iteration, t, d, **kwargs): 1163 | pass 1164 | -------------------------------------------------------------------------------- /otdd/pytorch/functionals.py: -------------------------------------------------------------------------------- 1 | ################################################################################ 2 | ############### COLLECTION OF FUNCTIONALS ON DATASETS ########################## 3 | ################################################################################ 4 | import numpy as np 5 | import torch 6 | 7 | class Functional(): 8 | """ 9 | Defines a JKO functional over measures implicitly by defining it over 10 | individual particles (points). 11 | 12 | The input should be a full dataset: points X (n x d) with labels Y (n x 1). 13 | Optionally, the means/variances associated with each class can be passed. 14 | 15 | (extra space do to repeating) 16 | 17 | """ 18 | def __init__(self, V=None, W=None, f=None, weights=None): 19 | self.V = V # The functional on Z space in potential energy 𝒱() = V 20 | self.W = W # The bi-linear form on ZxZ spaces in interaction energy 𝒲 21 | self.f = f # The scalar-valued function in the niternal energy term ℱ 22 | 23 | def __call__(x, y, μ=None, Σ=None): 24 | sum = 0 25 | if self.F is not None: 26 | sum += self.F(x,y,μ,Σ) 27 | if self.V is not None: 28 | sum += self.V(x,y,μ,Σ) 29 | if self.W is not None: 30 | sum += self.W(x,y,μ,Σ) 31 | return sum 32 | 33 | ################################################################################ 34 | ####### Potential energy functionals (denoted by V in the paper) ######### 35 | ################################################################################ 36 | 37 | def affine_feature_norm(X,Y=None,A=None, b=None, threshold=None, weight=1.0): 38 | """ A simple (feature-only) potential energy based on affine transform + norm: 39 | 40 | v(x,y) = || Ax - b ||, so that V(ρ) = ∫|| Ax - b || dρ(x,y) 41 | 42 | where the integral is approximated by empirical expectation (mean). 
43 | """ 44 | if A is None and b is None: 45 | norm = X.norm(dim=1) 46 | elif A is None and not b is None: 47 | norm = (X - b).norm(dim=1) 48 | elif not A is None and b is None: 49 | norm = (X - b).norm(dim=1) 50 | else: 51 | norm = (X@A - b).norm(dim=1) 52 | if threshold: 53 | norm = torch.nn.functional.threshold(norm, threshold, 0) 54 | return weight*norm.mean() 55 | 56 | def binary_hyperplane_margin(X, Y, w, b, weight=1.0): 57 | """ A potential function based on margin separation according to a (given 58 | and fixed) hyperplane: 59 | 60 | v(x,y) = max(0, 1 - y(x'w - b) ), so that V(ρ) = ∫ max(0, y(x'w - b) ) dρ(x,y) 61 | 62 | Returns 0 if all points are at least 1 away from margin. 63 | 64 | Note that y is expected to be {0,1} 65 | 66 | Needs separation hyperplane be determined by (w, b) parameters. 67 | """ 68 | Y_hat = 2*Y-1 # To map Y to {-1, 1}, required by the SVM-type margin obj we use 69 | margin = torch.relu(1-Y_hat*(torch.matmul(X, w) - b)) 70 | return weight*margin.mean() 71 | 72 | def dimension_collapse(X, Y, dim=1, v=None, weight=1.0): 73 | """ Potential function to induce a dimension collapse """ 74 | if v is None: 75 | v = 0 76 | deviation = (X[:,dim] - v)**2 77 | return weight*deviation.mean() 78 | 79 | 80 | 81 | def cluster_repulsion(X, Y): 82 | pdb.set_trace() 83 | 84 | ################################################################################ 85 | ######## Interaction energy functionals (denoted by W in the paper) ######### 86 | ################################################################################ 87 | 88 | def interaction_fun(X, Y, weight=1.0): 89 | """ 90 | 91 | """ 92 | Z = torch.cat((X, Y.float().unsqueeze(1)), -1) 93 | 94 | n,d = Z.shape 95 | Diffs = Z.repeat(n,1,1).transpose(0,1) - Z.repeat(n,1,1) 96 | 97 | def _f(δz): # Enforces cluster repulsion: 98 | δx, δy = torch.split(δz,[δz.shape[-1]-1,1], dim=-1) 99 | δy = torch.abs(δy/δy.max()).ceil() # Hacky way to get 0/1 loss for δy 100 | return -(δx*δy).norm(dim=-1).mean(dim=-1) 101 | 102 | val = _f(Diffs).mean() 103 | 104 | return val*weight 105 | 106 | 107 | def binary_cluster_margin(X, Y, μ=None, weight=1.0): 108 | """ Similar to binary_hyperplane_margin but does to require a separating 109 | hyperplane be provided in advance. Instead, computes one based on current 110 | datapoints as the hyperplane through the midpoint of their means. 111 | 112 | Also, ensures that ..., so it requires point-to-point comparison (interaction) 113 | 114 | """ 115 | 116 | μ_0 = X[Y==0].mean(0) 117 | μ_1 = X[Y==1].mean(0) 118 | 119 | n,d = X.shape 120 | diffs_x = X.repeat(n,1,1).transpose(0,1) - X.repeat(n,1,1) 121 | diffs_x = torch.nn.functional.normalize(diffs_x, dim=2, p=2) 122 | 123 | μ = torch.zeros(n,d) 124 | μ[Y==0,:] = μ_0 125 | μ[Y==1,:] = μ_1 126 | 127 | diffs_μ = μ.repeat(n,1,1).transpose(0,1) - μ.repeat(n,1,1) 128 | diffs_μ = torch.nn.functional.normalize(diffs_μ, dim=2, p=2) 129 | 130 | 131 | inner_prod = torch.einsum("ijk,ijl->ij", diffs_x, diffs_μ) 132 | 133 | print(inner_prod.min(), inner_prod.max()) 134 | 135 | out = torch.relu(-inner_prod + 1) 136 | 137 | print(out.shape) 138 | 139 | margin = torch.exp(out) 140 | return weight*margin.mean() 141 | -------------------------------------------------------------------------------- /otdd/pytorch/moments.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tools for moment (mean/cov) computation needed by OTTD and other routines. 
3 | """ 4 | 5 | import logging 6 | import pdb 7 | 8 | import torch 9 | import torch.utils.data.dataloader as dataloader 10 | from torch.utils.data.sampler import SubsetRandomSampler 11 | 12 | from .utils import process_device_arg, extract_data_targets 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | def cov(m, mean=None, rowvar=True, inplace=False): 18 | """ Estimate a covariance matrix given data. 19 | 20 | Covariance indicates the level to which two variables vary together. 21 | If we examine N-dimensional samples, `X = [x_1, x_2, ... x_N]^T`, 22 | then the covariance matrix element `C_{ij}` is the covariance of 23 | `x_i` and `x_j`. The element `C_{ii}` is the variance of `x_i`. 24 | 25 | Arguments: 26 | m (tensor): A 1-D or 2-D array containing multiple variables and observations. 27 | Each row of `m` represents a variable, and each column a single 28 | observation of all those variables. 29 | rowvar (bool): If `rowvar` is True, then each row represents a 30 | variable, with observations in the columns. Otherwise, the 31 | relationship is transposed: each column represents a variable, 32 | while the rows contain observations. 33 | 34 | Returns: 35 | The covariance matrix of the variables. 36 | """ 37 | if m.dim() > 2: 38 | raise ValueError('m has more than 2 dimensions') 39 | if m.dim() < 2: 40 | m = m.view(1, -1) 41 | if not rowvar and m.size(0) != 1: 42 | m = m.t() 43 | fact = 1.0 / (m.size(1) - 1) 44 | if mean is None: 45 | mean = torch.mean(m, dim=1, keepdim=True) 46 | else: 47 | mean = mean.unsqueeze(1) # For broadcasting 48 | if inplace: 49 | m -= mean 50 | else: 51 | m = m - mean 52 | mt = m.t() # if complex: mt = m.t().conj() 53 | return fact * m.matmul(mt).squeeze() 54 | 55 | class OnlineStatsRecorder: 56 | """ Online batch estimation of multivariate sample mean and covariance matrix. 57 | 58 | Alleviates numerical instability due to catastrophic cancellation that 59 | the naive estimation suffers from. 60 | 61 | Two pass approach first computes population mean, and then uses stable 62 | one pass algorithm on residuals x' = (x - μ). Uses the fact that Cov is 63 | translation invariant, and less cancellation happens if E[XX'] and 64 | E[X]E[X]' are far apart, which is the case for centered data. 65 | 66 | Ideas from: 67 | - https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance 68 | - https://notmatthancock.github.io/2017/03/23/simple-batch-stat-updates.html 69 | """ 70 | def __init__(self, data=None, twopass=True, centered_cov=False, 71 | diagonal_cov=False, embedding=None, 72 | device='cpu', dtype=torch.FloatTensor): 73 | """ 74 | Arguments: 75 | data (torch tensor): batch of data of shape (nobservations, ndimensions) 76 | twopass (bool): whether two use the two-pass approach (recommended) 77 | centered_cov (bool): whether covariance matrix is centered throughout 78 | the iterations. If false, centering happens once, 79 | at the end. 80 | diagonal_cov (bool): whether covariance matrix should be diagonal 81 | (i.e. ignore cross-correlation terms). In this 82 | case only diagonal (1xdim) tensor retrieved. 
83 | embedding (callable): if provided, will map features using this 84 | device (str): device for storage of computed statistics 85 | dtype (torch data type): data type for computed statistics 86 | 87 | """ 88 | self.device = device 89 | self.centered_cov = centered_cov 90 | self.diagonal_cov = diagonal_cov 91 | self.twopass = twopass 92 | self.dtype = dtype 93 | self.embedding = embedding 94 | 95 | self._init_values() 96 | 97 | def _init_values(self): 98 | self.μ = None 99 | self.Σ = None 100 | self.n = 0 101 | 102 | def compute_from_loader(self, dataloader): 103 | """ Compute statistics from dataloader """ 104 | device = process_device_arg(self.device) 105 | for x, _ in dataloader: 106 | x = x.type(self.dtype).to(device) 107 | x = self.embedding(x).detach() if self.embedding is not None else x 108 | self.update(x.view(x.shape[0], -1)) 109 | μ, Σ = self.retrieve() 110 | if self.twopass: 111 | self._init_values() 112 | self.centered_cov = False 113 | for x, _ in dataloader: 114 | x = x.type(self.dtype).to(device) 115 | x = self.embedding(x).detach() if self.embedding is not None else x 116 | self.update(x.view(x.shape[0],-1)-μ) # We compute cov on residuals 117 | _, Σ = self.retrieve() 118 | return μ, Σ 119 | 120 | def update(self, batch): 121 | """ Update statistics using batch of data. 122 | 123 | Arguments: 124 | data (tensor): tensor of shape (nobservations, ndimensions) 125 | """ 126 | if self.n == 0: 127 | self.n,self.d = batch.shape 128 | self.μ = batch.mean(axis=0) 129 | if self.diagonal_cov and self.centered_cov: 130 | self.Σ = torch.var(batch, axis=0, unbiased=True) 131 | ## unbiased is default in pytorch, shown here just to be explicit 132 | elif self.diagonal_cov and not self.centered_cov: 133 | self.Σ = batch.pow(2).sum(axis=0)/(1.0*self.n-1) 134 | elif self.centered_cov: 135 | self.Σ = ((batch-self.μ).T).matmul(batch-self.μ)/(1.0*self.n-1) 136 | else: 137 | self.Σ = (batch.T).matmul(batch)/(1.0*self.n-1) 138 | ## note that this not really covariance yet (not centered) 139 | else: 140 | if batch.shape[1] != self.d: 141 | raise ValueError("Data dims don't match prev observations.") 142 | 143 | ### Dimensions 144 | m = self.n * 1.0 145 | n = batch.shape[0] *1.0 146 | 147 | ### Mean Update 148 | self.μ = self.μ + (batch-self.μ).sum(axis=0)/(m+n) # Stable Algo 149 | 150 | ### Cov Update 151 | if self.diagonal_cov and self.centered_cov: 152 | self.Σ = ((m-1)*self.Σ + ((m-1)/(m+n-1))*((batch-self.μ).pow(2).sum(axis=0)))/(m+n-1) 153 | elif self.diagonal_cov and not self.centered_cov: 154 | self.Σ = (m-1)/(m+n-1)*self.Σ + 1/(m+n-1)*(batch.pow(2).sum(axis=0)) 155 | elif self.centered_cov: 156 | self.Σ = ((m-1)*self.Σ + ((m-1)/(m+n-1))*((batch-self.μ).T).matmul(batch-self.μ))/(m+n-1) 157 | else: 158 | self.Σ = (m-1)/(m+n-1)*self.Σ + 1/(m+n-1)*(batch.T).matmul(batch) 159 | 160 | ### Update total number of examples seen 161 | self.n += n 162 | 163 | def retrieve(self, verbose=False): 164 | """ Retrieve current statistics """ 165 | if verbose: print('Mean and Covariance computed on {} samples'.format(int(self.n))) 166 | if self.centered_cov: 167 | return self.μ, self.Σ 168 | elif self.diagonal_cov: 169 | Σ = self.Σ - self.μ.pow(2)*self.n/(self.n-1) 170 | Σ = torch.nn.functional.relu(Σ) # To avoid negative variances due to rounding 171 | return self.μ, Σ 172 | else: 173 | return self.μ, self.Σ - torch.ger(self.μ.T,self.μ)*self.n/(self.n-1) 174 | 175 | 176 | def _single_label_stats(data, i, c, label_indices, M=None, S=None, batch_size=256, 177 | embedding=None, online=True, 
diagonal_cov=False, 178 | dtype=None, device=None): 179 | """ Computes mean/covariance of examples that have a given label. Note that 180 | classname c is only needed for vanity printing. Device info needed here since 181 | dataloaders are used inside. 182 | 183 | Arguments: 184 | data (pytorch Dataset or Dataloader): data to compute stats on 185 | i (int): index of label (a.k.a class) to filter 186 | c (int/str): value of label (a.k.a class) to filter 187 | 188 | Returns: 189 | μ (torch tensor): empirical mean of samples with given label 190 | Σ (torch tensor): empirical covariance of samples with given label 191 | n (int): number of samples with giben label 192 | 193 | """ 194 | device = process_device_arg(device) 195 | if len(label_indices) < 2: 196 | logger.warning(" -- Class '{:10}' has too few examples ({})." \ 197 | " Ignoring it.".format(c, len(label_indices))) 198 | if M is None: 199 | return None,None,len(label_indices) 200 | else: 201 | if type(data) == dataloader.DataLoader: 202 | ## We'll reuse the provided dataloader, just setting indices. 203 | ## If loader had indices before, we restore them when we're done 204 | filtered_loader = data 205 | if hasattr(data.sampler,'indices'): 206 | _orig_indices = data.sampler.indices 207 | else: 208 | _orig_indices = None 209 | filtered_loader.sampler.indices = label_indices 210 | 211 | else: 212 | ## Create our own loader 213 | filtered_loader = dataloader.DataLoader(data, batch_size=batch_size, 214 | sampler=SubsetRandomSampler(label_indices)) 215 | _orig_indices = None 216 | 217 | if online: 218 | ## Will compute online (i.e. without loading all the data at once) 219 | stats_rec = OnlineStatsRecorder(centered_cov=True, twopass=True, 220 | diagonal_cov=diagonal_cov, device=device, 221 | embedding=embedding, 222 | dtype=dtype) 223 | μ, Σ = stats_rec.compute_from_loader(filtered_loader) 224 | 225 | n = int(stats_rec.n) 226 | else: 227 | X = torch.cat([d[0].to(device) for d in filtered_loader]).squeeze() 228 | X = embedding(X) if embedding is not None else X 229 | μ = torch.mean(X, dim = 0).flatten() 230 | if diagonal_cov: 231 | Σ = torch.var(X, dim=0).flatten() 232 | else: 233 | Σ = cov(X.view(X.shape[0], -1).t()) 234 | n = X.shape[0] 235 | logger.info(' -> class {:10} (id {:2}): {} examples'.format(c, i, n)) 236 | 237 | if diagonal_cov: 238 | try: 239 | assert Σ.min() >= 0 240 | except: 241 | pdb.set_trace() 242 | 243 | ## Reinstante original indices in sampler 244 | if _orig_indices is not None: data.sampler.indices = _orig_indices 245 | 246 | if M is not None: 247 | M[i],S[i] = μ.cpu(),Σ.cpu() # To avoid GPU parallelism problems 248 | else: 249 | return μ,Σ,n 250 | 251 | 252 | def compute_label_stats(data, targets=None,indices=None,classnames=None, 253 | online=True, batch_size=100, to_tensor=True, 254 | eigen_correction=False, 255 | eigen_correction_scale=1.0, 256 | nworkers=0, diagonal_cov = False, 257 | embedding=None, 258 | device=None, dtype = torch.FloatTensor): 259 | """ 260 | Computes mean/covariance of examples grouped by label. Data can be passed as 261 | a pytorch dataset or a dataloader. Uses dataloader to avoid loading all 262 | classes at once. 263 | 264 | Arguments: 265 | data (pytorch Dataset or Dataloader): data to compute stats on 266 | targets (Tensor, optional): If provided, will use this target array to 267 | avoid re-extracting targets. 268 | indices (array-like, optional): If provided, filtering is based on these 269 | indices (useful if e.g. 
dataloader has subsampler)
270 | eigen_correction (str or bool, optional): If truthy, will shift the covariance
271 | matrix's diagonal to ensure PSD'ness; one of 'constant', 'jitter' or 'exact'.
272 | eigen_correction_scale (numeric, optional): Magnitude of eigenvalue
273 | correction (used only if :attr:`eigen_correction` is truthy)
274 | 
275 | Returns:
276 | M (dict): Dictionary with sample means (Tensors) indexed by target class
277 | S (dict): Dictionary with sample covariances (Tensors) indexed by target class
278 | """
279 | 
280 | device = process_device_arg(device)
281 | M = {} # Means
282 | S = {} # Covariances
283 | 
284 | ## We need to get all targets in advance, in order to filter.
285 | ## Here we assume targets is the full dataset targets (ignoring subsets, etc)
286 | ## so we need to find effective targets.
287 | if targets is None:
288 | targets, classnames, indices = extract_data_targets(data)
289 | else:
290 | assert (indices is not None), "If targets are provided, so must be indices"
291 | if classnames is None:
292 | classnames = sorted([a.item() for a in torch.unique(targets)])
293 | 
294 | effective_targets = targets[indices]
295 | 
296 | if nworkers > 1:
297 | import torch.multiprocessing as mp # Ugly, sure. But useful.
298 | mp.set_start_method('spawn',force=True)
299 | M = mp.Manager().dict() # Alternatively, M = {}; M.share_memory
300 | S = mp.Manager().dict()
301 | processes = []
302 | for i,c in enumerate(classnames): # No. of processes
303 | label_indices = indices[effective_targets == i]
304 | p = mp.Process(target=_single_label_stats,
305 | args=(data, i,c,label_indices,M,S),
306 | kwargs={'device': device, 'online':online})
307 | p.start()
308 | processes.append(p)
309 | for p in processes: p.join()
310 | else:
311 | for i,c in enumerate(classnames):
312 | label_indices = indices[effective_targets == i]
313 | μ,Σ,n = _single_label_stats(data, i,c,label_indices, device=device,
314 | dtype=dtype, embedding=embedding,
315 | online=online, diagonal_cov=diagonal_cov)
316 | M[i],S[i] = μ, Σ
317 | 
318 | if to_tensor:
319 | ## Warning: this assumes classes are *exactly* {0,...,n}, might break things
320 | ## downstream if data is missing some classes
321 | M = torch.stack([μ.to(device) for i,μ in sorted(M.items()) if μ is not None], dim=0)
322 | S = torch.stack([Σ.to(device) for i,Σ in sorted(S.items()) if Σ is not None], dim=0)
323 | 
324 | ### Shift the Covariance matrix's diagonal to ensure PSD'ness
325 | if eigen_correction:
326 | logger.warning('Applying eigenvalue correction to Covariance Matrix')
327 | λ = eigen_correction_scale
328 | for i in range(S.shape[0]):
329 | if eigen_correction == 'constant':
330 | S[i] += torch.diag(λ*torch.ones(S.shape[1], device = device))
331 | elif eigen_correction == 'jitter':
332 | S[i] += torch.diag(λ*torch.ones(S.shape[1], device=device).uniform_(0.99, 1.01))
333 | elif eigen_correction == 'exact':
334 | ## Shift the diagonal by the most negative eigenvalue (if any).
335 | ## Note: torch.lobpcg(S[i], largest=False) or torch.eig yield the
336 | ## same smallest eigenvalue; symeig suffices for symmetric S.
337 | s, v = torch.symeig(S[i])
338 | s_min = s.min()
339 | if s_min <= 1e-10:
340 | S[i] += torch.diag(λ*torch.abs(s_min)*torch.ones(S.shape[1], device=device))
341 | ## The 'exact' correction mode is still unfinished, so fail loudly
342 | ## rather than returning silently-corrected covariances:
343 | raise NotImplementedError()
344 | 
345 | return M,S
346 | 
347 | 
348 | def dimreduce_means_covs(Means, Covs, redtype='diagonal'):
349 | """ Methods to reduce the dimensionality of the Feature-Mean/Covariance
350 | representation of Labels.
351 | 352 | Arguments: 353 | Means (tensor or list of tensors): original mean vectors 354 | Covs (tensor or list of tensors): original covariances matrices 355 | redtype (str): dimensionality reduction methods, one of 'diagonal', 'mds' 356 | or 'distance_embedding'. 357 | 358 | Returns: 359 | Means (tensor or list of tensors): dimensionality-reduced mean vectors 360 | Covs (tensor or list of tensors): dimensionality-reduced covariance matrices 361 | 362 | """ 363 | n1, d1 = Means[0].shape 364 | n2, d2 = Means[1].shape 365 | k = d1 366 | 367 | print(n1, d1, n2, d2) 368 | if redtype == 'diagonal': 369 | ## Leave Means As Is, Keep Only Diag of Covariance Matrices, Independent DR for Each Task 370 | Covs[0] = torch.stack([torch.diag(C) for C in Covs[0]]) 371 | Covs[1] = torch.stack([torch.diag(C) for C in Covs[1]]) 372 | elif redtype == 'mds': 373 | ## Leave Means As Is, Use MDS to DimRed Covariance Matrices, Independent DR for Each Task 374 | Covs[0] = mds(Covs[0].view(Covs[0].shape[0], -1), output_dim=k) 375 | Covs[1] = mds(Covs[1].view(Covs[1].shape[0], -1), output_dim=k) 376 | elif redtype == 'distance_embedding': 377 | ## Leaves Means As Is, Use Bipartitie MSE Embedding, Which Embeds the Pairwise Distance Matrix, Rather than the Cov Matrices Directly 378 | print('Will reduce dimension of Σs by embedding pairwise distance matrix...') 379 | D = torch.zeros(n1, n2) 380 | print('... computing pairwise bures distances ...') 381 | for (i, j) in tqdm(itertools.product(range(n1), range(n2))): 382 | D[i, j] = bures_distance(Covs[0][i], Covs[1][j]) 383 | print('... embedding distance matrix ...') 384 | U, V = bipartite_mse_embedding(D, k=k) 385 | Covs = [U, V] 386 | print("Done! Σ's Dimensions: {} (Task 1) and {} (Task 2)".format( 387 | list(U.shape), list(V.shape))) 388 | else: 389 | raise ValueError('Reduction type not recognized') 390 | return Means, Covs 391 | 392 | 393 | def pairwise_distance_mse(U, V, D, reg=1): 394 | d_uv = torch.cdist(U, V) 395 | l = torch.norm(D - d_uv)**2 / D.numel() + reg * (torch.norm(U) ** 396 | 2 / U.numel() + torch.norm(V)**2 / V.numel()) # MSE per entry 397 | return l 398 | 399 | 400 | def bipartite_mse_embedding(D, k=100, niters=10000): 401 | n, m = D.shape 402 | U = torch.randn(n, k, requires_grad=True) 403 | V = torch.randn(m, k, requires_grad=True) 404 | optim = torch.optim.SGD([U, V], lr=1e-1) 405 | for i in range(niters): 406 | optim.zero_grad() 407 | loss = pairwise_distance_mse(U, V, D) 408 | loss.backward() 409 | if i % 100 == 0: 410 | print(i, loss.item()) 411 | optim.step() 412 | loss = pairwise_distance_mse(U, V, D, reg=0) 413 | print( 414 | "Final distortion: ||D - D'||\u00b2/|D| = {:4.2f}".format(loss.item())) 415 | return U.detach(), V.detach() 416 | -------------------------------------------------------------------------------- /otdd/pytorch/nets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Collection of basic neural net models used in the OTDD experiments 3 | """ 4 | 5 | import os 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import pdb 10 | 11 | from .. 
import ROOT_DIR, HOME_DIR
12 | 
13 | MODELS_DIR = os.path.join(ROOT_DIR, 'models')
14 | 
15 | MNIST_FLAT_DIM = 28 * 28
16 | 
17 | def reset_parameters(m):
18 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
19 | m.reset_parameters()
20 | 
21 | class LeNet(nn.Module):
22 | def __init__(self, pretrained=False, num_classes = 10, input_size=28, **kwargs):
23 | super(LeNet, self).__init__()
24 | suffix = f'dim{input_size}_nc{num_classes}'
25 | self.model_path = os.path.join(MODELS_DIR, f'lenet_mnist_{suffix}.pt')
26 | assert input_size in [28,32], "Can only do LeNet on 28x28 or 32x32 for now."
27 | 
28 | feat_dim = 16*5*5 if input_size == 32 else 16*4*4
29 | self.feat_dim = feat_dim
30 | self.num_classes = num_classes
31 | if input_size == 32:
32 | self.conv1 = nn.Conv2d(1, 6, 3)
33 | self.conv2 = nn.Conv2d(6, 16, 3)
34 | elif input_size == 28:
35 | self.conv1 = nn.Conv2d(1, 6, 5)
36 | self.conv2 = nn.Conv2d(6, 16, 5)
37 | else:
38 | raise ValueError()
39 | 
40 | self._init_classifier()
41 | 
42 | if pretrained:
43 | state_dict = torch.load(self.model_path)
44 | self.load_state_dict(state_dict)
45 | 
46 | def _init_classifier(self, num_classes=None):
47 | """ Useful for fine-tuning """
48 | num_classes = self.num_classes if num_classes is None else num_classes
49 | self.classifier = nn.Sequential(
50 | nn.Linear(self.feat_dim, 120), # feat_dim set from the conv stack above
51 | nn.ReLU(),
52 | nn.Dropout(),
53 | nn.Linear(120, 84),
54 | nn.ReLU(),
55 | nn.Dropout(),
56 | nn.Linear(84, num_classes)
57 | )
58 | 
59 | def forward(self, x):
60 | x = F.max_pool2d(F.relu(self.conv1(x)), 2)
61 | x = F.max_pool2d(F.relu(self.conv2(x)), 2)
62 | x = x.view(-1, self.num_flat_features(x))
63 | return self.classifier(x)
64 | 
65 | def num_flat_features(self, x):
66 | size = x.size()[1:] # all dimensions except the batch dimension
67 | num_features = 1
68 | for s in size:
69 | num_features *= s
70 | return num_features
71 | 
72 | def save(self):
73 | state_dict = self.state_dict()
74 | torch.save(state_dict, self.model_path)
75 | 
76 | class MNIST_MLP(nn.Module):
77 | def __init__(
78 | self,
79 | input_dim=MNIST_FLAT_DIM,
80 | hidden_dim=98,
81 | output_dim=10,
82 | dropout=0.5,
83 | ):
84 | super(MNIST_MLP, self).__init__()
85 | self.dropout = nn.Dropout(dropout)
86 | self.hidden = nn.Linear(input_dim, hidden_dim)
87 | self.output = nn.Linear(hidden_dim, output_dim)
88 | 
89 | def forward(self, X, **kwargs):
90 | X = X.reshape(-1, self.hidden.in_features)
91 | X = F.relu(self.hidden(X))
92 | X = self.dropout(X)
93 | X = F.softmax(self.output(X), dim=-1)
94 | return X
95 | 
96 | class MNIST_CNN(nn.Module):
97 | def __init__(self, input_size=28, dropout=0.3, nclasses=10, pretrained=False):
98 | super(MNIST_CNN, self).__init__()
99 | self.nclasses = nclasses
100 | self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
101 | self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
102 | self.conv2_drop = nn.Dropout2d(p=dropout)
103 | self.fc1 = nn.Linear(1600, 100) # 1600 = number channels * width * height
104 | self.logit = nn.Linear(100, self.nclasses)
105 | self.fc1_drop = nn.Dropout(p=dropout)
106 | suffix = f'dim{input_size}_nc{nclasses}'
107 | self.model_path = os.path.join(MODELS_DIR, f'cnn_mnist_{suffix}.pt')
108 | if pretrained:
109 | state_dict = torch.load(self.model_path)
110 | self.load_state_dict(state_dict)
111 | 
112 | def forward(self, x):
113 | x = torch.relu(F.max_pool2d(self.conv1(x), 2))
114 | x = torch.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
115 | x = x.view(-1, x.size(1) * x.size(2) * x.size(3))
116 | x = torch.relu(self.fc1_drop(self.fc1(x)))
117 | x = self.logit(x)
118 | x = F.log_softmax(x, dim=-1)
119 | return x
120 | 
121 | def save(self):
122 | state_dict = self.state_dict()
123 | torch.save(state_dict, self.model_path)
124 | 
125 | 
126 | class MLPClassifier(nn.Module):
127 | def __init__(
128 | self,
129 | input_size=None,
130 | hidden_size=400,
131 | num_classes=2,
132 | dropout=0.2,
133 | pretrained=False,
134 | ):
135 | super(MLPClassifier, self).__init__()
136 | self.num_classes = num_classes
137 | self.hidden_sizes = [hidden_size, int(hidden_size/2), int(hidden_size/4)]
138 | 
139 | self.dropout = nn.Dropout(dropout)
140 | self.fc1 = nn.Linear(input_size, self.hidden_sizes[0])
141 | self.fc2 = nn.Linear(self.hidden_sizes[0], self.hidden_sizes[1])
142 | self.fc3 = nn.Linear(self.hidden_sizes[1], self.hidden_sizes[2])
143 | 
144 | self._init_classifier()
145 | 
146 | def _init_classifier(self, num_classes=None):
147 | num_classes = self.num_classes if num_classes is None else num_classes
148 | self.classifier = nn.Sequential(
149 | nn.Linear(self.hidden_sizes[-1], 20),
150 | nn.ReLU(),
151 | nn.Linear(20, num_classes)
152 | )
153 | 
154 | def forward(self, x, **kwargs):
155 | x = self.dropout(F.relu(self.fc1(x)))
156 | x = self.dropout(F.relu(self.fc2(x)))
157 | x = self.dropout(F.relu(self.fc3(x)))
158 | x = self.classifier(x)
159 | return x
160 | 
161 | class BoWSentenceEmbedding():
162 | def __init__(self, vocab_size, embedding_dim, pretrained_vec, padding_idx=None, method = 'naive'):
163 | self.method = method
164 | if method == 'bag':
165 | self.emb = nn.EmbeddingBag.from_pretrained(pretrained_vec, padding_idx=padding_idx)
166 | else:
167 | self.emb = nn.Embedding.from_pretrained(pretrained_vec)
168 | 
169 | def __call__(self, x):
170 | if self.method == 'bag':
171 | return self.emb(x)
172 | else:
173 | return self.emb(x).mean(dim=1)
174 | 
175 | class MLPPushforward(nn.Module):
176 | def __init__(self, input_size=2, nlayers = 3, **kwargs):
177 | super(MLPPushforward, self).__init__()
178 | d = input_size
179 | 
180 | _layers = []
181 | _d = d
182 | for i in range(nlayers):
183 | _layers.append(nn.Linear(_d, 2*_d))
184 | _layers.append(nn.ReLU())
185 | _layers.append(nn.Dropout(0.0))
186 | _d = 2*_d
187 | for i in range(nlayers):
188 | _layers.append(nn.Linear(_d,int(0.5*_d)))
189 | if i < nlayers - 1: _layers.append(nn.ReLU())
190 | _layers.append(nn.Dropout(0.0))
191 | _d = int(0.5*_d)
192 | 
193 | self.mapping = nn.Sequential(*_layers)
194 | 
195 | def forward(self, x):
196 | return self.mapping(x)
197 | 
198 | def reset_parameters(self):
199 | self.mapping.apply(reset_parameters)
200 | 
201 | 
202 | class ConvPushforward(nn.Module):
203 | def __init__(self, input_size=28, channels = 1, nlayers_conv = 2, nlayers_mlp = 3, **kwargs):
204 | super(ConvPushforward, self).__init__()
205 | self.input_size = input_size
206 | self.channels = channels
207 | if input_size == 32:
208 | self.upconv1 = nn.Conv2d(1, 6, 3)
209 | self.upconv2 = nn.Conv2d(6, 16, 3)
210 | feat_dim = 16*5*5
211 | ## decoder layers ##
212 | self.dnconv1 = nn.ConvTranspose2d(4, 16, 2, stride=2)
213 | self.dnconv2 = nn.ConvTranspose2d(16, 1, 2, stride=2)
214 | elif input_size == 28:
215 | self.upconv1 = nn.Conv2d(1, 6, 5)
216 | self.upconv2 = nn.Conv2d(6, 16, 5)
217 | feat_dim = 16*4*4
218 | self.dnconv1 = nn.ConvTranspose2d(16, 6, 5)
219 | self.dnconv2 = nn.ConvTranspose2d(6, 1, 5)
220 | else:
221 | raise NotImplementedError("Can only do 28x28 or 32x32 for now.")
222 | self.feat_dim = feat_dim
223 | 
224 | self.mlp = MLPPushforward(input_size = feat_dim, nlayers = nlayers_mlp)
225 | 
226 | def forward(self, x):
227 | _orig_shape = x.shape
228 | x = x.reshape(-1, self.channels, self.input_size, self.input_size)
229 | x, idx1 = F.max_pool2d(F.relu(self.upconv1(x)), 2, return_indices=True)
230 | x, idx2 = F.max_pool2d(F.relu(self.upconv2(x)), 2, return_indices=True)
231 | _nonflat_shape = x.shape
232 | x = x.view(-1, self.num_flat_features(x))
233 | x = self.mlp(x).reshape(_nonflat_shape)
234 | x = F.relu(self.dnconv1(F.max_unpool2d(x, idx2, kernel_size=2)))
235 | x = torch.tanh(self.dnconv2(F.max_unpool2d(x, idx1, kernel_size=2)))
236 | return x.reshape(_orig_shape)
237 | 
238 | def num_flat_features(self, x):
239 | size = x.size()[1:] # all dimensions except the batch dimension
240 | num_features = 1
241 | for s in size:
242 | num_features *= s
243 | return num_features
244 | 
245 | def reset_parameters(self):
246 | for name, module in self.named_children():
247 | module.reset_parameters()
248 | 
249 | 
250 | class ConvPushforward2(nn.Module):
251 | def __init__(self, input_size=28, channels = 1, nlayers_conv = 2, nlayers_mlp = 3, **kwargs):
252 | super(ConvPushforward2, self).__init__()
253 | self.input_size = input_size
254 | self.channels = channels
255 | if input_size == 32:
256 | self.upconv1 = nn.Conv2d(1, 6, 3)
257 | self.upconv2 = nn.Conv2d(6, 16, 3)
258 | feat_dim = 16*5*5
259 | ## decoder layers ##
260 | self.dnconv1 = nn.ConvTranspose2d(4, 16, 2, stride=2)
261 | self.dnconv2 = nn.ConvTranspose2d(16, 1, 2, stride=2)
262 | elif input_size == 28:
263 | self.upconv1 = nn.Conv2d(1, 16, 3, stride=3, padding=1) # b, 16, 10, 10
264 | self.upconv2 = nn.Conv2d(16, 8, 3, stride=2, padding=1) # b, 8, 3, 3
265 | feat_dim = 8*2*2
266 | self.dnconv1 = nn.ConvTranspose2d(8, 16, 3, stride=2) # b, 16, 5, 5
267 | self.dnconv2 = nn.ConvTranspose2d(16, 8, 5, stride=3, padding=1) # b, 8, 15, 15
268 | self.dnconv3 = nn.ConvTranspose2d(8, 1, 2, stride=2, padding=1) # b, 1, 28, 28
269 | else:
270 | raise NotImplementedError("Can only do 28x28 or 32x32 for now.")
271 | self.feat_dim = feat_dim
272 | 
273 | self.mlp = MLPPushforward(input_size = feat_dim, nlayers = nlayers_mlp)
274 | 
275 | def forward(self, x):
276 | x = x.reshape(-1, self.channels, self.input_size, self.input_size)
277 | x = F.max_pool2d(F.relu(self.upconv1(x)), 2, stride=2)
278 | x = F.max_pool2d(F.relu(self.upconv2(x)), 2, stride=1)
279 | _nonflat_shape = x.shape
280 | x = x.view(-1, self.num_flat_features(x))
281 | x = self.mlp(x).reshape(_nonflat_shape)
282 | x = F.relu(self.dnconv1(x))
283 | x = F.relu(self.dnconv2(x))
284 | x = torch.tanh(self.dnconv3(x))
285 | return x
286 | 
287 | def num_flat_features(self, x):
288 | size = x.size()[1:] # all dimensions except the batch dimension
289 | num_features = 1
290 | for s in size:
291 | num_features *= s
292 | return num_features
293 | 
294 | def reset_parameters(self):
295 | for name, module in self.named_children():
296 | print('resetting ', name)
297 | module.reset_parameters()
298 | 
299 | 
300 | class ConvPushforward3(nn.Module):
301 | def __init__(self, input_size=28, channels = 1, nlayers_conv = 2, nlayers_mlp = 3, **kwargs):
302 | super(ConvPushforward3, self).__init__()
303 | self.input_size = input_size
304 | self.channels = channels
305 | 
306 | self.upconv1 = nn.Conv2d(1, 128, 3, 1, 2, dilation=2)
307 | self.upconv2 = nn.Conv2d(128, 128, 3, 1, 2)
308 | self.upconv3 = nn.Conv2d(128, 256, 3, 1, 2)
309 | self.upconv4 = nn.Conv2d(256, 256, 3, 1, 2)
310 | self.upconv5 = nn.Conv2d(128, 128, 3, 1, 2)
311 | 
self.upconv6 = nn.Conv2d(128, 128, 3, 1, 2) 312 | self.upconv7 = nn.Conv2d(128, 128, 3, 1, 2) 313 | self.upconv8 = nn.Conv2d(128, 128, 3, 1, 2) 314 | 315 | self.dnconv4 = nn.ConvTranspose2d(256, 256, 3, 1, 2) 316 | self.dnconv3 = nn.ConvTranspose2d(256, 128, 3, 1, 2) 317 | self.dnconv2 = nn.ConvTranspose2d(128, 128, 3, 1, 2) 318 | self.dnconv1 = nn.ConvTranspose2d(128, 1, 3, 1, 2, dilation=2) 319 | 320 | self.maxpool1 = nn.MaxPool2d(2, return_indices=True) 321 | self.maxpool2 = nn.MaxPool2d(2, return_indices=True) 322 | self.maxpool3 = nn.MaxPool2d(2, return_indices=True) 323 | self.maxunpool1 = nn.MaxUnpool2d(2) 324 | self.maxunpool2 = nn.MaxUnpool2d(2) 325 | 326 | self.relu1 = nn.ReLU() 327 | self.relu2 = nn.ReLU() 328 | self.relu3 = nn.ReLU() 329 | self.relu4 = nn.ReLU() 330 | self.relu5 = nn.ReLU() 331 | self.relu6 = nn.ReLU() 332 | self.relu7 = nn.ReLU() 333 | self.relu8 = nn.ReLU() 334 | self.derelu1 = nn.ReLU() 335 | self.derelu2 = nn.ReLU() 336 | self.derelu3 = nn.ReLU() 337 | self.derelu4 = nn.ReLU() 338 | self.derelu5 = nn.ReLU() 339 | self.derelu6 = nn.ReLU() 340 | self.derelu7 = nn.ReLU() 341 | self.bn1 = nn.BatchNorm2d(16) 342 | self.bn2 = nn.BatchNorm2d(32) 343 | self.bn3 = nn.BatchNorm2d(16) 344 | self.bn4 = nn.BatchNorm2d(1) 345 | 346 | 347 | def forward(self, x): 348 | x = self.upconv1(x) 349 | x = self.relu1(x) 350 | 351 | x = self.upconv2(x) 352 | x = self.relu2(x) 353 | 354 | x = self.upconv3(x) 355 | x = self.relu3(x) 356 | 357 | x = self.upconv4(x) 358 | x = self.relu4(x) 359 | 360 | x = self.derelu4(x) 361 | x = self.dnconv4(x) 362 | 363 | x = self.derelu3(x) 364 | x = self.dnconv3(x) 365 | 366 | x = self.derelu2(x) 367 | x = self.dnconv2(x) 368 | 369 | x = self.derelu1(x) 370 | x = self.dnconv1(x) 371 | 372 | return x 373 | 374 | def num_flat_features(self, x): 375 | size = x.size()[1:] # all dimensions except the batch dimension 376 | num_features = 1 377 | for s in size: 378 | num_features *= s 379 | return num_features 380 | 381 | def reset_parameters(self): 382 | for name, module in self.named_children(): 383 | try: 384 | module.reset_parameters() 385 | except: 386 | pass 387 | -------------------------------------------------------------------------------- /otdd/pytorch/sqrtm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Routines for computing matrix square roots. 3 | 4 | With ideas from: 5 | 6 | https://github.com/steveli/pytorch-sqrtm/blob/master/sqrtm.py 7 | https://github.com/pytorch/pytorch/issues/25481 8 | """ 9 | 10 | import pdb 11 | import torch 12 | from torch.autograd import Function 13 | from functools import partial 14 | import numpy as np 15 | import scipy.linalg 16 | try: 17 | import cupy as cp 18 | except: 19 | import numpy as cp 20 | 21 | #### VIA SVD, version 1: from https://github.com/pytorch/pytorch/issues/25481 22 | def symsqrt_v1(A, func='symeig'): 23 | """Compute the square root of a symmetric positive definite matrix.""" 24 | ## https://github.com/pytorch/pytorch/issues/25481#issuecomment-576493693 25 | ## perform the decomposition 26 | ## Recall that for Sym Real matrices, SVD, EVD coincide, |λ_i| = σ_i, so 27 | ## for PSD matrices, these are equal and coincide, so we can use either. 28 | if func == 'symeig': 29 | s, v = A.symeig(eigenvectors=True) # This is faster in GPU than CPU, fails gradcheck. 
See https://github.com/pytorch/pytorch/issues/30578
30 | elif func == 'svd':
31 | _, s, v = A.svd() # But this passes torch.autograd.gradcheck()
32 | else:
33 | raise ValueError()
34 | 
35 | ## truncate small components
36 | good = s > s.max(-1, True).values * s.size(-1) * torch.finfo(s.dtype).eps
37 | components = good.sum(-1)
38 | common = components.max()
39 | unbalanced = common != components.min()
40 | if common < s.size(-1):
41 | s = s[..., :common]
42 | v = v[..., :common]
43 | if unbalanced:
44 | good = good[..., :common]
45 | if unbalanced:
46 | s = s.where(good, torch.zeros((), device=s.device, dtype=s.dtype))
47 | return (v * s.sqrt().unsqueeze(-2)) @ v.transpose(-2, -1)
48 | 
49 | 
50 | #### VIA SVD, version 2: from https://github.com/pytorch/pytorch/issues/25481
51 | def symsqrt_v2(A, func='symeig'):
52 | """Compute the square root of a symmetric positive definite matrix."""
53 | if func == 'symeig':
54 | s, v = A.symeig(eigenvectors=True) # This is faster in GPU than CPU, fails gradcheck. See https://github.com/pytorch/pytorch/issues/30578
55 | elif func == 'svd':
56 | _, s, v = A.svd() # But this passes torch.autograd.gradcheck()
57 | else:
58 | raise ValueError()
59 | 
60 | above_cutoff = s > s.max() * s.size(-1) * torch.finfo(s.dtype).eps
61 | 
62 | ### Zero-out components below the cutoff. Alternatives either didn't
63 | ### support batching or failed gradcheck due to in-place operations;
64 | ### torch.where appears to be equivalent, batch-friendly, and gradcheck-safe.
65 | 
66 | 
67 | s = torch.where(above_cutoff, s, torch.zeros_like(s))
68 | 
69 | sol = torch.matmul(torch.matmul(v, torch.diag_embed(s.sqrt(), dim1=-2, dim2=-1)), v.transpose(-2, -1))
70 | 
71 | return sol
72 | 
73 | 
74 | 
75 | 
76 | def special_sylvester(a, b):
77 | """Solves the equation `A @ X + X @ A = B` for a positive definite `A`."""
78 | s, v = a.symeig(eigenvectors=True)
79 | d = s.unsqueeze(-1)
80 | d = d + d.transpose(-2, -1)
81 | vt = v.transpose(-2, -1)
82 | c = vt @ b @ v
83 | return v @ (c / d) @ vt
84 | 
85 | 
86 | ##### Via Newton-Schulz: based on
87 | ## https://github.com/msubhransu/matrix-sqrt/blob/master/matrix_sqrt.py, and
88 | ## https://github.com/BorisMuzellec/EllipticalEmbeddings/blob/master/utils.py
89 | def sqrtm_newton_schulz(A, numIters, reg=None, return_error=False, return_inverse=False):
90 | """ Matrix squareroot based on Newton-Schulz method """
91 | if A.ndim <= 2: # Non-batched mode
92 | A = A.unsqueeze(0)
93 | batched = False
94 | else:
95 | batched = True
96 | batchSize = A.shape[0]
97 | dim = A.shape[1]
98 | normA = (A**2).sum((-2,-1)).sqrt() # Slightly faster than : A.mul(A).sum((-2,-1)).sqrt()
99 | 
100 | if reg:
101 | ## Renormalize so that each matrix has norm at most 1/reg,
102 | ## but only normalize when necessary
103 | normA *= reg
104 | renorm = torch.ones_like(normA)
105 | renorm[torch.where(normA > 1.0)] = normA[torch.where(normA > 1.0)]
106 | else:
107 | renorm = normA
108 | 
109 | Y = A.div(renorm.view(batchSize, 1, 1).expand_as(A))
110 | I = torch.eye(dim,dim).view(1, dim, dim).repeat(batchSize,1,1).to(A.device)#.type(dtype)
111 | Z = torch.eye(dim,dim).view(1, dim, dim).repeat(batchSize,1,1).to(A.device)#.type(dtype)
112 | for i in range(numIters):
113 | T = 0.5*(3.0*I - Z.bmm(Y))
114 | Y = Y.bmm(T)
115 | Z = T.bmm(Z)
116 | sA = Y*torch.sqrt(normA).view(batchSize, 1, 1).expand_as(A)
117 | sAinv = Z/torch.sqrt(normA).view(batchSize, 1, 1).expand_as(A)
118 | if not batched:
119 | sA = sA[0,:,:]
120 | sAinv = sAinv[0,:,:]
121 | 
122 | if not return_inverse and not return_error:
123 | return sA
124 | 
elif not return_inverse and return_error: 125 | return sA, compute_error(A, sA) 126 | elif return_inverse and not return_error: 127 | return sA,sAinv 128 | else: 129 | return sA, sAinv, compute_error(A, sA) 130 | 131 | def create_symm_matrix(batchSize, dim, numPts=20, tau=1.0, dtype=torch.float32, 132 | verbose=False): 133 | """ Creates a random PSD matrix """ 134 | A = torch.zeros(batchSize, dim, dim).type(dtype) 135 | for i in range(batchSize): 136 | pts = np.random.randn(numPts, dim).astype(np.float32) 137 | sA = np.dot(pts.T, pts)/numPts + tau*np.eye(dim).astype(np.float32); 138 | A[i,:,:] = torch.from_numpy(sA); 139 | if verbose: print('Creating batch %d, dim %d, pts %d, tau %f, dtype %s' % (batchSize, dim, numPts, tau, dtype)) 140 | return A 141 | 142 | def compute_error(A, sA): 143 | """ Computes error in approximation """ 144 | normA = torch.sqrt(torch.sum(torch.sum(A * A, dim=1),dim=1)) 145 | error = A - torch.bmm(sA, sA) 146 | error = torch.sqrt((error * error).sum(dim=1).sum(dim=1)) / normA 147 | return torch.mean(error) 148 | 149 | ###========================== 150 | 151 | class MatrixSquareRoot(Function): 152 | """Square root of a positive definite matrix. 153 | 154 | NOTE: square root is not differentiable for matrices with zero eigenvalues. 155 | 156 | """ 157 | @staticmethod 158 | def forward(ctx, input, method = 'numpy'): 159 | _dev = input.device 160 | if method == 'numpy': 161 | m = input.cpu().detach().numpy().astype(np.float_) 162 | sqrtm = torch.from_numpy(scipy.linalg.sqrtm(m).real).type_as(input) 163 | elif method == 'pytorch': 164 | sqrtm = symsqrt(input) 165 | ctx.save_for_backward(sqrtm) 166 | return sqrtm 167 | 168 | @staticmethod 169 | def backward(ctx, grad_output, method = 'numpy'): 170 | grad_input = None 171 | if ctx.needs_input_grad[0]: 172 | sqrtm, = ctx.saved_tensors 173 | if method == 'numpy': 174 | sqrtm = sqrtm.data.numpy().astype(np.float_) 175 | gm = grad_output.data.numpy().astype(np.float_) 176 | grad_sqrtm = scipy.linalg.solve_sylvester(sqrtm, sqrtm, gm) 177 | grad_input = torch.from_numpy(grad_sqrtm).type_as(grad_output.data) 178 | elif method == 'pytorch': 179 | grad_input = special_sylvester(sqrtm, grad_output) 180 | return grad_input 181 | 182 | 183 | ## ========================================================================== ## 184 | ## NOTE: Must pick which version of matrix square root to use!!!! 
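## The available implementations trade off exactness, differentiability and
## speed: MatrixSquareRoot.apply is scipy-based (exact, but round-trips through
## numpy on CPU); symsqrt_v1/symsqrt_v2 use a spectral decomposition (the 'svd'
## path passes gradcheck); sqrtm_newton_schulz is an iterative, batched,
## GPU-friendly approximation.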
185 | 186 | ## sqrtm = MatrixSquareRoot.apply 187 | sqrtm = symsqrt_v2 188 | ## sqrtm = symsqrt_v1 189 | ## sqrtm = symsqrt_diff 190 | ## ========================================================================== ## 191 | 192 | def main(): 193 | from torch.autograd import gradcheck 194 | 195 | k = torch.randn(5, 20, 20).double() 196 | M = k @ k.transpose(-1,-2) 197 | 198 | s1 = symsqrt_v1(M, func='symeig') 199 | test = torch.allclose(M, s1 @ s1.transpose(-1,-2)) 200 | print('Via symeig:', test) 201 | 202 | s2 = symsqrt_v1(M, func='svd') 203 | test = torch.allclose(M, s2 @ s2.transpose(-1,-2)) 204 | print('Via svd: ', test) 205 | 206 | print('Sqrtm with symeig and svd match:', torch.allclose(s1,s2)) 207 | 208 | M.requires_grad = True 209 | 210 | ## Check gradients for symsqrt 211 | _sqrt = partial(symsqrt, func='svd') 212 | test = gradcheck(_sqrt, (M,)) 213 | print('Grach Check for sqrtm/svd:', test) 214 | 215 | ## Check symeig itself 216 | S = torch.rand(5,20,20, requires_grad=True).double() 217 | def func(S): 218 | x = 0.5 * (S + S.transpose(-2, -1)) 219 | return torch.symeig(x, eigenvectors=True) 220 | print('Grad check for symeig', gradcheck(func, [S])) 221 | 222 | ## Check gradients for symsqrt with symeig 223 | _sqrt = partial(symsqrt, func='symeig') 224 | test = gradcheck(_sqrt, (M,)) 225 | print('Grach Check for sqrtm/symeig:', test) 226 | 227 | if __name__ == '__main__': 228 | main() 229 | -------------------------------------------------------------------------------- /otdd/pytorch/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from itertools import zip_longest, product 3 | from functools import partial 4 | from os.path import dirname 5 | import numpy as np 6 | import scipy.sparse 7 | from tqdm.autonotebook import tqdm 8 | import torch 9 | import random 10 | import pdb 11 | import string 12 | import logging 13 | from sklearn.cluster import k_means, DBSCAN 14 | 15 | import matplotlib.pyplot as plt 16 | 17 | 18 | from PIL import Image 19 | import PIL.ImageOps 20 | 21 | import torch.nn as nn 22 | import torch.utils.data as torchdata 23 | import torch.utils.data.dataloader as dataloader 24 | from torch.utils.data.sampler import SubsetRandomSampler 25 | from munkres import Munkres 26 | 27 | from .nets import BoWSentenceEmbedding 28 | from .sqrtm import sqrtm, sqrtm_newton_schulz 29 | 30 | DATASET_NORMALIZATION = { 31 | 'MNIST': ((0.1307,), (0.3081,)), 32 | 'USPS' : ((0.1307,), (0.3081,)), 33 | 'FashionMNIST' : ((0.1307,), (0.3081,)), 34 | 'QMNIST' : ((0.1307,), (0.3081,)), 35 | 'EMNIST' : ((0.1307,), (0.3081,)), 36 | 'KMNIST' : ((0.1307,), (0.3081,)), 37 | 'ImageNet': ((0.485, 0.456, 0.406),(0.229, 0.224, 0.225)), 38 | 'CIFAR10': ((0.485, 0.456, 0.406),(0.229, 0.224, 0.225)), 39 | 'CIFAR100': ((0.485, 0.456, 0.406),(0.229, 0.224, 0.225)), 40 | 'camelyonpatch': ((0.70038027, 0.53827554, 0.69125885), (0.23614734, 0.27760974, 0.21410067)) 41 | } 42 | 43 | logger = logging.getLogger(__name__) 44 | 45 | def inverse_normalize(tensor, mean, std): 46 | _tensor = tensor.clone() 47 | for ch in range(len(mean)): 48 | _tensor[:,ch,:,:].mul_(std[ch]).add_(mean[ch]) 49 | return _tensor 50 | 51 | def process_device_arg(device): 52 | " Convient function to abstract away processing of torch.device argument" 53 | if device is None: # Default to cuda:0 if possible, otherwise cpu 54 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 55 | elif type(device) is str: 56 | device = torch.device(device) 57 | else: 58 | pass 59 
| return device 60 | 61 | 62 | def interleave(*a): 63 | ## zip_longest filling values with as many NaNs as values in second axis 64 | l = *zip_longest(*a, fillvalue=[np.nan]*a[0].shape[1]), 65 | ## build a 2d array from the list 66 | out = np.concatenate(l) 67 | ## return non-NaN values 68 | return out[~np.isnan(out[:,0])] 69 | 70 | 71 | def random_index_split(input, alpha=0.9, max_split_sizes=(None,None)): 72 | " Returns two np arrays of indices, such that the first one has size alpha*n" 73 | if type(input) is int: 74 | indices, n = np.arange(input), input 75 | elif type(input) is list: 76 | indices, n = np.array(input), len(input) 77 | elif type(input) is np.ndarray: 78 | indices, n = input, len(input) 79 | np.random.shuffle(indices) # inplace 80 | split = int(np.floor(alpha * n)) 81 | idxs1, idxs2 = np.array(indices[:split]), np.array(indices[split:]) 82 | if max_split_sizes[0] is not None and (max_split_sizes[0] < len(idxs1)): 83 | idxs1 = np.sort(np.random.choice(idxs1, max_split_sizes[0], replace = False)) 84 | if max_split_sizes[1] is not None and (max_split_sizes[1] < len(idxs2)): 85 | idxs2 = np.sort(np.random.choice(idxs2, max_split_sizes[1], replace = False)) 86 | return idxs1, idxs2 87 | 88 | 89 | def extract_dataset_targets(d): 90 | """ Extracts targets from dataset. 91 | 92 | Extracts labels, classes and effective indices from a object of type 93 | torch.util.data.dataset.**. 94 | 95 | Arguments: 96 | d (torch Dataset): dataset to extract targets from 97 | 98 | Returns: 99 | targets (tensor): tensor with integer targets 100 | classes (tensor): tensor with class labels (might or might not be integer) 101 | indices (tensor): indices of examples 102 | 103 | Note: 104 | Indices can differ from range(len(d)) if, for example, this is a Subset dataset. 105 | 106 | """ 107 | assert isinstance(d, torch.utils.data.dataset.Dataset) 108 | if isinstance(d, torch.utils.data.dataset.Subset): 109 | dataset = d.dataset 110 | indices = d.indices 111 | elif isinstance(d, torch.utils.data.dataset.Dataset): # should be last option, since all above satisfy it 112 | dataset = d 113 | indices = d.indices if hasattr(d, 'indices') else None # this should always return None. Check. 114 | 115 | if hasattr(dataset, 'targets'): # most torchivision datasets 116 | targets = dataset.targets 117 | elif hasattr(dataset, '_data'): # some torchtext datasets 118 | targets = torch.LongTensor([e[0] for e in dataset._data]) 119 | elif hasattr(dataset, 'tensors') and len(dataset.tensors) == 2: # TensorDatasets 120 | targets = dataset.tensors[1] 121 | elif hasattr(dataset, 'tensors') and len(dataset.tensors) == 1: 122 | logger.warning('Dataset seems to be unlabeled - this modality is in beta mode!') 123 | targets = None 124 | else: 125 | raise ValueError("Could not find targets in dataset.") 126 | 127 | classes = dataset.classes if hasattr(dataset, 'classes') else torch.sort(torch.unique(targets)).values 128 | 129 | if (indices is None) and (targets is not None): 130 | indices = np.arange(len(targets)) 131 | elif indices is None: 132 | indices = np.arange(len(dataset)) 133 | else: 134 | indices = np.sort(indices) 135 | 136 | return targets, classes, indices 137 | 138 | 139 | def extract_dataloader_targets(dl): 140 | """ Extracts targets from dataloader. 141 | 142 | Extracts labels, classes and effective indices from a object of type 143 | torch.util.data.dataset.**. 
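Unlike extract_dataset_targets, this also inspects the loader's sampler, so
that subsampling (e.g. via SubsetRandomSampler indices) is reflected in the
returned indices.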
144 | 145 | Arguments: 146 | d (torch DataLoader): dataloader to extract targets from 147 | 148 | Returns: 149 | targets (tensor): tensor with integer targets 150 | classes (tensor): tensor with class labels (might or might not be integer) 151 | indices (tensor): indices of examples 152 | 153 | Note: 154 | Indices can differ from range(len(d)) if, for example, this is a Subset dataset. 155 | 156 | """ 157 | assert isinstance(dl, torch.utils.data.dataloader.DataLoader) 158 | assert hasattr(dl, 'dataset'), "Dataloader does not have dataset attribute." 159 | 160 | ## Extract targets from underlying dataset 161 | targets, classes, indices = extract_dataset_targets(dl.dataset) 162 | 163 | ## Now need to check if loader does some subsampling 164 | if hasattr(dl, 'sampler') and hasattr(dl.sampler, 'indices'): 165 | idxs_sampler = dl.sampler.indices 166 | if indices is not None and len(indices)!=len(targets) and idxs_sampler is not None: 167 | ## Sampler indices should be subset of datasetd indices 168 | if set(idxs_sampler).issubset(set(indices)): 169 | indices = idxs_sampler 170 | else: 171 | print("STOPPING. Incosistent dataset and sampler indices.") 172 | pdb.set_trace() 173 | else: 174 | indices = idxs_sampler 175 | 176 | if indices is None: 177 | indices = np.arange(len(targets)) 178 | else: 179 | indices = np.sort(indices) 180 | 181 | return targets, classes, indices 182 | 183 | 184 | def extract_data_targets(d): 185 | """ Wrapper around extract_dataloader_targets and extract_dataset_targets, 186 | for convenience """ 187 | if isinstance(d, torch.utils.data.dataloader.DataLoader): 188 | return extract_dataloader_targets(d) 189 | elif isinstance(d, torch.utils.data.dataset.Dataset): 190 | return extract_dataset_targets(d) 191 | else: 192 | raise ValueError("Incompatible data object") 193 | 194 | 195 | def load_full_dataset(data, targets=False, return_both_targets=False, 196 | labels_keep=None, min_labelcount=None, 197 | batch_size = 256, 198 | maxsamples = None, device='cpu', dtype=torch.FloatTensor, 199 | feature_embedding=None, labeling_function=None, 200 | force_label_alignment = False, 201 | reindex=False, reindex_start=0): 202 | """ Loads full dataset into memory. 203 | 204 | Arguments: 205 | targets (bool, or 'infer'): Whether to colleect and return targets (labels) too 206 | return_both_targets (bool): Only used when targets='infer'. Indicates whether 207 | the true targets should also be returned. 208 | labels_keep (list): If provided, will only keep examples with these labels 209 | reindex (bool): Whether/how to reindex labels. If True, will 210 | reindex to {reindex_start,...,reindex_start+num_unique_labels}. 211 | 212 | maxsamples (int): Maximum number of examples to load. (this might not equal 213 | actual size of return tensors, if label_keep also provided) 214 | 215 | Returns: 216 | X (tensor): tensor of dataset features, stacked along first dimension 217 | Y (tensor): tensor of dataset targets 218 | 219 | """ 220 | device = process_device_arg(device) 221 | orig_idxs = None 222 | if type(data) == dataloader.DataLoader: 223 | loader = data 224 | if maxsamples: 225 | if hasattr(loader, 'sampler') and hasattr(loader.sampler, 'indices'): 226 | if len(loader.sampler.indices) <= maxsamples: 227 | logger.warning('Maxsamples is greater than number of effective examples in loader. Will not subsample.') 228 | else: 229 | ## Resample from sampler indices. 
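## (the original sampler indices are saved first, so the loader can be
## restored to its initial state once loading is done)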
230 | orig_idxs = loader.sampler.indices 231 | idxs = np.sort(np.random.choice(orig_idxs, maxsamples, replace=False)) 232 | loader.sampler.indices = idxs 233 | elif hasattr(loader, 'dataset'): # This probably means the sampler is not a subsampler. So len(dataset) is indeed true size. 234 | if len(loader.dataset) <= maxsamples: 235 | logger.warning('Maxsamples is greater than number of examples in loader. Will not subsample.') 236 | else: 237 | ## Create new sampler 238 | idxs = np.sort(np.random.choice(len(loader.dataset), maxsamples, replace=False)) 239 | sampler = SubsetRandomSampler(idxs) 240 | loader = dataloader.DataLoader(data, sampler=sampler, batch_size=batch_size) 241 | else: 242 | ## I don't think we'll ever be in this case. 243 | print('Warning: maxsamplers provided but loader doesnt have subsampler or dataset. Cannot subsample.') 244 | else: 245 | ## data is a dataset 246 | if maxsamples and len(data) > maxsamples: 247 | idxs = np.sort(np.random.choice(len(data), maxsamples, replace=False)) 248 | sampler = SubsetRandomSampler(idxs) 249 | loader = dataloader.DataLoader(data, sampler=sampler, batch_size=batch_size) 250 | else: 251 | ## No subsampling 252 | loader = dataloader.DataLoader(data, batch_size=batch_size) 253 | 254 | X = [] 255 | Y = [] 256 | seen_targets = {} 257 | keeps = None 258 | collect_targets = targets and ((targets != 'infer') or return_both_targets) 259 | 260 | for batch in tqdm(loader, leave=False): 261 | x = batch[0] 262 | if (len(batch) == 2) and targets: 263 | y = batch[1] 264 | 265 | if feature_embedding is not None: 266 | ## if embedding is cuda, and device='cpu', want to map to device *after* 267 | ## embedding, to take advantage of CUDA forward pass. 268 | try: 269 | x = feature_embedding(x.type(dtype).cuda()).detach().to(device) 270 | except: 271 | x = feature_embedding(x.type(dtype).to(device)).detach() 272 | else: 273 | x = x.type(dtype).to(device) 274 | 275 | X.append(x.squeeze().view(x.shape[0],-1)) 276 | if collect_targets: # = True or infer 277 | Y.append(y.to(device).squeeze()) 278 | X = torch.cat(X) 279 | 280 | if collect_targets: Y = torch.cat(Y) 281 | 282 | if targets == 'infer': 283 | logger.warning('Performing clustering') 284 | if Y is not None: # Save true targets before overwriting them with inferred 285 | Y_true = Y 286 | Y = labeling_function(X) 287 | 288 | if force_label_alignment: 289 | K = torch.unique(Y_true).shape[0] 290 | M = [((Y == k) & (Y_true == l)).sum().item() for k,l in product(range(K),range(K))] 291 | M = np.array(M).reshape(K,K) 292 | idx_map = dict(Munkres().compute(1 - M/len(Y))) 293 | Y = torch.tensor([idx_map[int(y.item())] for y in Y]) 294 | 295 | if min_labelcount is not None: 296 | assert not labels_keep, "Cannot specify both min_labelcount and labels_keep" 297 | vals, cts = torch.unique(Y, return_counts=True) 298 | labels_keep = torch.sort(vals[cts >= min_labelcount])[0] 299 | 300 | 301 | if labels_keep is not None: # Filter out examples with unwanted label 302 | keeps = np.isin(Y.cpu(), labels_keep) 303 | X = X[keeps,:] 304 | Y = Y[keeps] 305 | 306 | if orig_idxs is not None: 307 | loader.sampler.indices = orig_idxs 308 | if targets is False: 309 | return X 310 | else: 311 | if reindex: 312 | labels = sorted(torch.unique(Y).tolist()) 313 | reindex_vals = range(reindex_start, reindex_start + len(labels)) 314 | lmap = dict(zip(labels, reindex_vals)) 315 | Y = torch.LongTensor([lmap[y.item()] for y in Y]).to(device) 316 | if not return_both_targets: 317 | return X, Y 318 | else: 319 | return X, Y, Y_true 320 | 
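## Usage sketch (illustrative only): assuming `ds` is a standard labeled
## torchvision-style dataset,
##
##     X, Y = load_full_dataset(ds, targets=True, maxsamples=5000,
##                              min_labelcount=2, reindex=True)
##
## returns an (n, d) tensor X of flattened features and a LongTensor Y of
## labels re-indexed to {0, ..., num_classes - 1}.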
321 | 322 | def sample_kshot_task(dataset,k=10,valid=None): 323 | """ This is agnostic to the labels used, it will inferr them from dataset 324 | so it works equally well with remaped or non remap subsets. 325 | """ 326 | inds_train = [] 327 | Y = dataset.targets 328 | V = sorted(list(torch.unique(Y))) 329 | inds_valid = [] 330 | for c in V: 331 | m = torch.where(Y == c)[0].squeeze() 332 | srt_ind = m[torch.randperm(len(m))] 333 | inds_train.append(srt_ind[:k]) 334 | if valid: 335 | inds_valid.append(srt_ind[k:k+valid]) 336 | inds_train = torch.sort(torch.cat(inds_train))[0] 337 | assert len(inds_train) == k*len(V) 338 | train = torch.utils.data.Subset(dataset,inds_train) 339 | tr_lbls = [train[i][1] for i in range(len(train))] 340 | tr_cnts = np.bincount(tr_lbls) 341 | assert np.all(tr_cnts == [k]*len(V)) 342 | 343 | if valid: 344 | inds_valid = torch.sort(torch.cat(inds_valid))[0] 345 | valid = torch.utils.data.Subset(dataset,inds_valid) 346 | return train, valid 347 | else: 348 | return train 349 | 350 | 351 | def load_trajectories(path, device='cpu'): 352 | Xt = torch.load(path + '/trajectories_X.pt') 353 | Yt = torch.load(path + '/trajectories_Y.pt') 354 | assert Xt.ndim == 3 355 | assert Yt.ndim == 2 356 | assert Xt.shape[0] == Yt.shape[0] 357 | assert Xt.shape[-1] == Yt.shape[-1] 358 | n,d,t = Xt.shape 359 | logger.info(f'Trajectories: {n} points, {d} dim, {t} steps.') 360 | if device is not None: 361 | Xt = Xt.to(torch.device(device)) 362 | Yt = Yt.to(torch.device(device)) 363 | return Xt, Yt 364 | 365 | 366 | def augmented_dataset(dataset, means, covs, maxn=1000):#, diagonal_cov=False): 367 | """ Generate moment-augmented dataset by concatenating features, means and 368 | covariances. This will only make sense when using Gaussians for target 369 | representation. Every instance in the augmented dataset will have form: 370 | 371 | x̂_i = [x_i,mean(y_i),vec(cov(y_i))] 372 | 373 | Therefore: 374 | ||x̂_i - x̂_j||_p^p = ||x_i - x_j||_p^p + 375 | ||mean(y_i)-mean(y_j)||_p^p + 376 | ||sqrt(cov(y_i))-sqrt(cov(y_j))||_p^p 377 | 378 | """ 379 | if type(dataset) is tuple and type(dataset[0]) is torch.Tensor: 380 | X, Y = dataset 381 | elif type(dataset) is torch.utils.data.dataset.Dataset: 382 | X, Y = load_full_dataset(dataset, targets=True) 383 | else: 384 | raise ValueError('Wrong Format') 385 | 386 | if maxn and maxn < X.shape[0]: 387 | idxs = sorted(np.random.choice(range(X.shape[0]),maxn, replace=False)) 388 | else: 389 | idxs = range(X.shape[0]) 390 | 391 | X = X[idxs,:] 392 | Y = Y[idxs] 393 | if Y.min() > 0: # We reindxed the labels, need to revert 394 | Y -= Y.min() 395 | M = means[Y[idxs],:] 396 | if covs[0].ndim == 1: 397 | ## Implies Covariance is diagonal 398 | sqrt_covs = torch.sqrt(covs) 399 | else: 400 | sqrt_covs = torch.stack([sqrtm(c) for c in torch.unbind(covs, 0)]) 401 | 402 | C = sqrt_covs[Y[idxs],:] 403 | 404 | C = C.view(C.shape[0], -1) 405 | 406 | dim_before = X.shape[1] 407 | X_aug = torch.cat([X,M,C],1) 408 | logger.info('Augmented from dim {} to {}'.format(dim_before, X_aug.shape[1])) 409 | return X_aug 410 | 411 | 412 | def extract_torchmeta_task(cs, class_ids): 413 | """ Extracts a single "episode" (ie, task) from a ClassSplitter object, in the 414 | form of a dataset, and appends variables needed by DatasetDistance computation. 
415 | 
416 | Arguments:
417 | cs (torchmeta.transforms.ClassSplitter): the ClassSplitter to extract data from
418 | class_ids (tuple): indices of classes to be selected by Splitter
419 | 
420 | Returns:
421 | ds_train (Dataset): train dataset
422 | ds_test (Dataset): test dataset
423 | 
424 | """
425 | ds = cs[class_ids]
426 | ds_train, ds_test = ds['train'], ds['test']
427 | 
428 | for ds in [ds_train, ds_test]:
429 | ds.targets = torch.tensor([ds[i][1] for i in range(len(ds))])
430 | ds.classes = [p[-1] for i,p in enumerate(cs.dataset._labels) if i in class_ids]
431 | return ds_train, ds_test
432 | 
433 | 
434 | def save_image(tensor, fp, dataname, format='png', invert=True):
435 | """ Similar to torchvision's save_image, but corrects normalization """
436 | if dataname and dataname in DATASET_NORMALIZATION:
437 | ## Brings back to [0,1] range
438 | mean, std = (d[0] for d in DATASET_NORMALIZATION[dataname])
439 | tensor = tensor.mul(std).add_(mean)
440 | ndarr = tensor.mul(255).add_(0.5).clamp_(0, 255).to('cpu', torch.uint8).numpy()
441 | im = Image.fromarray(ndarr)
442 | if invert:
443 | im = PIL.ImageOps.invert(im)
444 | im.save(fp, format=format)
445 | 
446 | def show_grid(tensor, dataname=None, invert=True, title=None,
447 | save_path=None, to_pil=False, ax = None,format='png'):
448 | """ Displays image grid. To be used after torchvision's make_grid """
449 | if dataname and dataname in DATASET_NORMALIZATION:
450 | ## Brings back to [0,1] range
451 | mean, std = (d[0] for d in DATASET_NORMALIZATION[dataname])
452 | tensor = tensor.mul(std).add_(mean)
453 | ndarr = tensor.mul(255).add_(0.5).clamp_(0, 255).to('cpu', torch.uint8).numpy()
454 | ndarr = np.transpose(ndarr, (1,2,0))
455 | if to_pil:
456 | im = Image.fromarray(ndarr)
457 | if invert:
458 | im = PIL.ImageOps.invert(im)
459 | im.show(title=title)
460 | if save_path:
461 | im.save(save_path, format=format)
462 | else:
463 | if not ax: fig, ax = plt.subplots()
464 | ax.imshow(ndarr, interpolation='nearest')
465 | ax.set_xticks([])
466 | ax.set_yticks([])
467 | if title: ax.set_title(title)
468 | 
469 | def coupling_to_csv(G, fp, thresh = 1e-14, sep=',', labels1=None,labels2=None):
470 | """ Dumps an OT coupling matrix to a csv file """
471 | sG = G.copy()
472 | if thresh is not None:
473 | sG[G < thresh] = 0
-------------------------------------------------------------------------------- /otdd/pytorch/wasserstein.py: --------------------------------------------------------------------------------
108 | if nworkers > 1:
109 | results = Parallel(n_jobs=nworkers, verbose=1, backend="threading")(
110 | delayed(wasserstein_gauss_distance)(M1[i], M2[j], S1[i], S2[j], squared=True) for i, j in pairs)
111 | for (i, j), d in zip(pairs, results):
112 | D[i, j] = d
113 | if symmetric:
114 | D[j, i] = D[i, j]
115 | else:
116 | for i, j in tqdm(pairs, leave=False):
117 | D[i, j] = wasserstein_gauss_distance(
118 | M1[i], M2[j], S1[i], S2[j], squared=True, commute=commute)
119 | if symmetric:
120 | D[j, i] = D[i, j]
121 | 
122 | if return_dmeans:
123 | D_means = torch.cdist(M1, M2) # For viz purposes only
124 | return D, D_means
125 | else:
126 | return D
127 | 
128 | 
129 | def efficient_pwdist_gauss(M1, S1, M2=None, S2=None, sqrtS1=None, sqrtS2=None,
130 | symmetric=False, diagonal_cov=False, commute=False,
131 | sqrt_method='spectral',sqrt_niters=20,sqrt_pref=0,
132 | device='cpu',nworkers=1,
133 | cost_function='euclidean',
134 | return_dmeans=False, return_sqrts=False):
135 | """ [Formerly known as efficient_pwassdist] Efficient computation of pairwise
136 | label-to-label Wasserstein distances between various distributions.
Saves 137 | computation by precomputing and storing covariance square roots.""" 138 | if M2 is None: 139 | symmetric = True 140 | M2, S2 = M1, S1 141 | 142 | n1, n2 = len(M1), len(M2) # Number of clusters in each side 143 | if symmetric: 144 | ## If tasks are symmetric (same data on both sides) only need combinations 145 | pairs = list(itertools.combinations(range(n1), 2)) 146 | else: 147 | ## If tasks are assymetric, need n1 x n2 comparisons 148 | pairs = list(itertools.product(range(n1), range(n2))) 149 | 150 | D = torch.zeros((n1, n2), device = device, dtype=M1.dtype) 151 | 152 | sqrtS = [] 153 | ## Note that we need inverses of only one of two datasets. 154 | ## If sqrtS of S1 provided, use those. If S2 provided, flip roles of covs in Bures 155 | both_sqrt = (sqrtS1 is not None) and (sqrtS2 is not None) 156 | if (both_sqrt and sqrt_pref==0) or (sqrtS1 is not None): 157 | ## Either both were provided and S1 (idx=0) is prefered, or only S1 provided 158 | flip = False 159 | sqrtS = sqrtS1 160 | elif sqrtS2 is not None: 161 | ## S1 wasn't provided 162 | if sqrt_pref == 0: logger.warning('sqrt_pref=0 but S1 not provided!') 163 | flip = True 164 | sqrtS = sqrtS2 # S2 playes role of S1 165 | elif len(S1) <= len(S2): # No precomputed squareroots provided. Compute, but choose smaller of the two! 166 | flip = False 167 | S = S1 168 | else: 169 | flip = True 170 | S = S2 # S2 playes role of S1 171 | 172 | if not sqrtS: 173 | logger.info('Precomputing covariance matrix square roots...') 174 | for i, Σ in tqdm(enumerate(S), leave=False): 175 | if diagonal_cov: 176 | assert Σ.ndim == 1 177 | sqrtS.append(torch.sqrt(Σ)) # This is actually not needed. 178 | else: 179 | sqrtS.append(sqrtm(Σ) if sqrt_method == 180 | 'spectral' else sqrtm_newton_schulz(Σ, sqrt_niters)) 181 | 182 | logger.info('Computing gaussian-to-gaussian wasserstein distances...') 183 | pbar = tqdm(pairs, leave=False) 184 | pbar.set_description('Computing label-to-label distances') 185 | for i, j in pbar: 186 | if not flip: 187 | D[i, j] = wasserstein_gauss_distance(M1[i], M2[j], S1[i], S2[j], sqrtS[i], 188 | diagonal_cov=diagonal_cov, 189 | commute=commute, squared=True, 190 | cost_function=cost_function, 191 | sqrt_method=sqrt_method, 192 | sqrt_niters=sqrt_niters) 193 | else: 194 | D[i, j] = wasserstein_gauss_distance(M2[j], M1[i], S2[j], S1[i], sqrtS[j], 195 | diagonal_cov=diagonal_cov, 196 | commute=commute, squared=True, 197 | cost_function=cost_function, 198 | sqrt_method=sqrt_method, 199 | sqrt_niters=sqrt_niters) 200 | if symmetric: 201 | D[j, i] = D[i, j] 202 | 203 | if return_dmeans: 204 | D_means = torch.cdist(M1, M2) # For viz purposes only 205 | if return_sqrts: 206 | return D, D_means, sqrtS 207 | else: 208 | return D, D_means 209 | elif return_sqrts: 210 | return D, sqrtS 211 | else: 212 | return D 213 | 214 | def pwdist_means_only(M1, M2=None, symmetric=False, device=None): 215 | if M2 is None or symmetric: 216 | symmetric = True 217 | M2 = M1 218 | D = torch.cdist(M1, M2) 219 | if device: 220 | D = D.to(device) 221 | return D 222 | 223 | def pwdist_upperbound(M1, S1, M2=None, S2=None,symmetric=False, means_only=False, 224 | diagonal_cov=False, commute=False, device=None, 225 | return_dmeans=False): 226 | """ Computes upper bound of the Wasserstein distance between distributions 227 | with given mean and covariance. 
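For Gaussians, W₂² = ||μ_1 - μ_2||² + tr(Σ_1) + tr(Σ_2) - 2·tr[(Σ_1^½ Σ_2 Σ_1^½)^½];
since the last (cross) term is nonnegative, dropping it yields the upper bound
||μ_1 - μ_2||² + tr(Σ_1) + tr(Σ_2) computed here, which avoids matrix square
roots entirely.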
223 | def pwdist_upperbound(M1, S1, M2=None, S2=None, symmetric=False, means_only=False,
224 |                       diagonal_cov=False, commute=False, device=None,
225 |                       return_dmeans=False):
226 |     """ Computes an upper bound on the pairwise Wasserstein distances between
227 |         distributions with given means and covariances.
228 |     """
229 | 
230 |     if M2 is None:
231 |         symmetric = True
232 |         M2, S2 = M1, S1
233 | 
234 |     n1, n2 = len(M1), len(M2)  # Number of clusters on each side
235 |     if symmetric:
236 |         ## If tasks are symmetric (same data on both sides) only combinations are needed
237 |         pairs = list(itertools.combinations(range(n1), 2))
238 |     else:
239 |         ## If tasks are asymmetric, all n1 x n2 comparisons are needed
240 |         pairs = list(itertools.product(range(n1), range(n2)))
241 | 
242 |     D = torch.zeros((n1, n2), device=device, dtype=M1.dtype)
243 | 
244 |     logger.info('Computing Gaussian-to-Gaussian Wasserstein distance bounds...')
245 |     pbar = tqdm(pairs, leave=False)
246 |     pbar.set_description('Computing label-to-label distances')
247 | 
248 |     if means_only or return_dmeans:
249 |         D_means = torch.cdist(M1, M2)
250 | 
251 |     if not means_only:
252 |         for i, j in pbar:
253 |             ## Upper bound: squared distance between means plus both covariance traces
254 |             ## (i.e., the exact W2^2 with the nonnegative Bures cross term dropped)
255 |             D[i, j] = ((M1[i] - M2[j])**2).sum(axis=-1) + \
256 |                       (S1[i] + S2[j]).diagonal(dim1=-2, dim2=-1).sum(-1)
257 |             if symmetric:
258 |                 D[j, i] = D[i, j]
259 |     else:
260 |         D = D_means
261 | 
262 |     if return_dmeans:
263 |         D_means = torch.cdist(M1, M2)  # For viz purposes only
264 |         return D, D_means
265 |     else:
266 |         return D
267 | 
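## ---------------------------------------------------------------------------
## Why the quantity above is an upper bound (a note, not from the original
## file): the exact squared W2 between Gaussians subtracts the nonnegative
## Bures cross term 2 tr((S1^{1/2} S2 S1^{1/2})^{1/2}) from exactly this
## expression, so dropping it can only overestimate. A quick check under
## assumed toy shapes, run where this module's functions are in scope:
import torch

k, d = 4, 8
M = torch.randn(k, d)
S = torch.stack([torch.eye(d) for _ in range(k)])
D_ub = pwdist_upperbound(M, S)         # ||mi - mj||^2 + tr(Si) + tr(Sj)
D_ex = efficient_pwdist_gauss(M, S)    # exact closed form
assert (D_ub + 1e-4 >= D_ex).all()
## ---------------------------------------------------------------------------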
268 | def pwdist_exact(X1, Y1, X2=None, Y2=None, symmetric=False, loss='sinkhorn',
269 |                  cost_function='euclidean', p=2, debias=True, entreg=1e-1, device='cpu'):
270 |     """ Efficient computation of pairwise label-to-label Wasserstein distances
271 |         between multiple distributions, without using the Gaussian assumption.
272 | 
273 |     Args:
274 |         X1, X2 (tensor): n x d matrices of features
275 |         Y1, Y2 (tensor): labels corresponding to the samples
276 |         symmetric (bool): whether X1/Y1 and X2/Y2 are to be treated as the same dataset
277 |         cost_function (callable/string): the 'ground metric' between features to
278 |             be used in the optimal transport problem. If callable, should follow
279 |             the convention of the cost argument in geomloss.SamplesLoss
280 |         p (int): power of the cost (i.e., order of the p-Wasserstein distance). Ignored
281 |             if cost_function is a callable.
282 |         debias (bool): Only relevant for Sinkhorn. If True, uses the debiased
283 |             Sinkhorn divergence.
284 | 
285 |     Returns:
286 |         D (tensor): n1 x n2 matrix of pairwise label-to-label distances
287 |     """
288 |     device = process_device_arg(device)
289 |     if X2 is None:
290 |         symmetric = True
291 |         X2, Y2 = X1, Y1
292 | 
293 |     c1 = torch.unique(Y1)
294 |     c2 = torch.unique(Y2)
295 |     n1, n2 = len(c1), len(c2)
296 | 
297 |     ## We account for the possibility that the labels are shifted (c1[0] != 0), see below
298 | 
299 |     if symmetric:
300 |         ## If tasks are symmetric (same data on both sides) only combinations are needed
301 |         pairs = list(itertools.combinations(range(n1), 2))
302 |     else:
303 |         ## If tasks are asymmetric, all n1 x n2 comparisons are needed
304 |         pairs = list(itertools.product(range(n1), range(n2)))
305 | 
306 | 
307 |     if cost_function == 'euclidean':
308 |         if p == 1:
309 |             cost_function = lambda x, y: geomloss.utils.distances(x, y)
310 |         elif p == 2:
311 |             cost_function = lambda x, y: geomloss.utils.squared_distances(x, y)
312 |         else:
313 |             raise ValueError("'euclidean' cost requires p=1 or p=2; pass a callable cost otherwise")
314 | 
315 |     if loss == 'sinkhorn':
316 |         distance = geomloss.SamplesLoss(
317 |             loss=loss, p=p,
318 |             cost=cost_function,
319 |             debias=debias,
320 |             blur=entreg**(1 / p),
321 |         )
322 |     elif loss == 'wasserstein':
323 |         def distance(Xa, Xb):
324 |             C = cost_function(Xa, Xb).cpu()
325 |             return torch.tensor(ot.emd2(ot.unif(Xa.shape[0]), ot.unif(Xb.shape[0]), C))
326 |     else:
327 |         raise ValueError(f"Unrecognized loss: {loss}")
328 | 
329 | 
330 |     logger.info('Computing label-to-label (exact) Wasserstein distances...')
331 |     pbar = tqdm(pairs, leave=False)
332 |     pbar.set_description('Computing label-to-label distances')
333 |     D = torch.zeros((n1, n2), device=device, dtype=X1.dtype)
334 |     for i, j in pbar:
335 |         try:
336 |             D[i, j] = distance(X1[Y1 == c1[i]].to(device), X2[Y2 == c2[j]].to(device)).item()
337 |         except Exception:
338 |             print("This is awkward. Distance computation failed. Geomloss is hard to debug, "
339 |                   "but here are a few things that might be happening: "
340 |                   "1. too many samples with this label, causing memory issues; "
341 |                   "2. datatype mismatches, e.g., the two datasets having different dtypes.")
342 |             sys.exit('Distance computation failed. Aborting.')
343 |         if symmetric:
344 |             D[j, i] = D[i, j]
345 |     return D
346 | 
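## ---------------------------------------------------------------------------
## Hypothetical usage sketch for pwdist_exact (not from the original file;
## the sample count, feature dimension, and label range are assumptions):
import torch

X = torch.randn(300, 32)                # 300 samples with 32-d features
Y = torch.randint(0, 5, (300,))         # labels in {0,...,4}
## Debiased entropic OT between every pair of label-conditioned point clouds:
D = pwdist_exact(X, Y, loss='sinkhorn', p=2, debias=True, entreg=1e-1)
## ---------------------------------------------------------------------------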
--------------------------------------------------------------------------------
/otdd/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import pickle as pkl
4 | import pdb
5 | import shutil
6 | import logging
7 | import tempfile
8 | 
9 | def launch_logger(console_level='warning'):
10 |     ############################### Logging Config #################################
11 |     ## Remove all handlers of the root logger -> needed to override any earlier basicConfig
12 |     for handler in logging.root.handlers[:]:
13 |         logging.root.removeHandler(handler)
14 | 
15 |     _logger = logging.getLogger()
16 |     _logger.setLevel(logging.INFO)  # Has to be the minimum of all the handlers' levels
17 | 
18 |     ## Create a file handler that logs info messages and above, with a random logfile name
19 |     logfile = tempfile.NamedTemporaryFile(prefix="otddlog_", dir=tempfile.gettempdir()).name
20 |     fh = logging.FileHandler(logfile)
21 |     fh.setLevel(logging.INFO)
22 | 
23 |     ## Create a console handler with a (possibly) higher log level
24 |     ch = logging.StreamHandler(stream=sys.stdout)
25 |     if console_level == 'warning':
26 |         ch.setLevel(logging.WARNING)
27 |     elif console_level == 'info':
28 |         ch.setLevel(logging.INFO)
29 |     else:
30 |         raise ValueError(f"Unrecognized console_level: {console_level}")
31 |     ## Create a formatter and add it to both handlers
32 |     formatter = logging.Formatter('%(asctime)s:%(name)s:%(levelname)s: %(message)s',
33 |                                   datefmt='%Y-%m-%d %H:%M:%S')
34 |     fh.setFormatter(formatter)
35 |     ch.setFormatter(formatter)
36 |     _logger.addHandler(fh)
37 |     _logger.addHandler(ch)
38 |     ################################################################################
39 |     return _logger
40 | 
41 | def safedump(d, f):
42 |     try:
43 |         pkl.dump(d, open(f, 'wb'))
44 |     except Exception:
45 |         pdb.set_trace()
46 | 
47 | def append_to_file(fname, l):
48 |     with open(fname, "a") as f:
49 |         f.write('\t'.join(l) + '\n')
50 | 
51 | def delete_if_exists(path, typ='f'):
52 |     if typ == 'f':
53 |         if os.path.exists(path): os.remove(path)
54 |     elif typ == 'd':
55 |         if os.path.isdir(path): shutil.rmtree(path)
56 |     else:
57 |         raise ValueError("Unrecognized path type")
58 | 
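## ---------------------------------------------------------------------------
## Hypothetical usage sketch for launch_logger (not from the original file):
from otdd.utils import launch_logger   # assuming the package is installed
logger = launch_logger(console_level='info')
logger.info('OTDD run starting')       # echoed to stdout and to the temp logfile
## ---------------------------------------------------------------------------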
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | 
2 | pandas
3 | scikit-learn
4 | 
5 | numpy
6 | scipy
7 | matplotlib
8 | torchvision
9 | torchtext
10 | sentence-transformers
11 | skorch
12 | 
13 | pot
14 | 
15 | cmake
16 | pykeops
17 | git+https://github.com/jeanfeydy/geomloss
18 | munkres
19 | 
20 | tqdm
21 | attrdict
22 | seaborn
23 | adjustText
24 | h5py
25 | opentsne
26 | 
27 | watermark
28 | celluloid
29 | git+https://github.com/rossant/ipycache
30 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import find_packages, setup
2 | 
3 | setup(
4 |     name='otdd',
5 |     version='0.1.0',
6 |     description='Optimal Transport Dataset Distance',
7 |     author='David Alvarez-Melis & Nicolo Fusi',
8 |     license='MIT',
9 |     packages=find_packages(),
10 |     install_requires=[
11 |         'numpy',
12 |         'scipy',
13 |         'matplotlib',
14 |         'tqdm',
15 |         'pot',
16 |         'torch',
17 |         'torchvision',
18 |         'torchtext',
19 |         'attrdict',
20 |         'opentsne',
21 |         'seaborn',
22 |         'scikit-learn',
23 |         'pandas',
24 |         'geomloss',
25 |         'munkres',
26 |         'adjustText'
27 |     ],
28 |     include_package_data=True,
29 |     zip_safe=False
30 | )
31 | 
--------------------------------------------------------------------------------