├── .gitignore ├── README.md ├── arguments.py ├── distributions ├── DiagonalHWN │ ├── __init__.py │ ├── distribution.py │ ├── layers.py │ └── prior.py ├── EuclideanNormal │ ├── __init__.py │ ├── distribution.py │ ├── layers.py │ └── prior.py ├── FullHWN │ ├── __init__.py │ ├── distribution.py │ ├── layers.py │ └── prior.py ├── IsotropicHWN │ ├── __init__.py │ ├── distribution.py │ ├── layers.py │ └── prior.py ├── RoWN │ ├── __init__.py │ ├── distribution.py │ ├── layers.py │ └── prior.py ├── __init__.py ├── hwn.py └── utils.py ├── requirements.txt ├── tasks ├── Breakout │ ├── __init__.py │ ├── arguments.py │ ├── dataset.py │ ├── evaluation.py │ └── model.py ├── NSBT │ ├── __init__.py │ ├── arguments.py │ ├── dataset.py │ ├── evaluation.py │ ├── model.py │ └── utils.py ├── WordNet │ ├── __init__.py │ ├── dataset.py │ ├── evaluation.py │ ├── mammals_filter.txt │ └── utils.py └── __init__.py ├── train_embedding.py ├── train_vae.py └── vae.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | data/ 132 | wandb/ 133 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A Rotated Hyperbolic Wrapped Normal Distribution for Hierarchical Representation Learning 2 | This repository is the official implementation of ["A Rotated Hyperbolic Wrapped Normal Distribution for Hierarchical Representation Learning"](https://arxiv.org/abs/2205.13371) accepted by NeurIPS 2022. 3 | 4 | ## Abstract 5 | We present a rotated hyperbolic wrapped normal distribution (RoWN), a simple yet effective alteration of a hyperbolic wrapped normal distribution (HWN). The HWN expands the domain of probabilistic modeling from Euclidean to hyperbolic space, where a tree can be embedded with arbitrary low distortion in theory. In this work, we analyze the geometric properties of the \emph{diagonal} HWN, a standard choice of distribution in probabilistic modeling. The analysis shows that the distribution is inappropriate to represent the data points at the same hierarchy level through their angular distance with the same norm in the Poincar\'e disk model. We then empirically verify the presence of limitations of HWN, and show how RoWN, the proposed distribution, can alleviate the limitations on various hierarchical datasets, including noisy synthetic binary tree, WordNet, and Atari 2600 Breakout. 6 | 7 | ## Usages 8 | You can reproduce the experiments from our paper using the following command: 9 | ``` 10 | > python train_vae.py --task NSBT --dist RoWN --depth=7 --device=cuda:0 --eval_interval=1001 --exp_name=nsbt --lr=0.0001 --n_epochs=1000 --n_layers=1 --test_samples=500 --train_batch_size=128 --seed 1 11 | > python train_vae.py --task Breakout --dist RoWN --data_dir= --device=cuda:0 --eval_interval=201 --exp_name=breakout --latent_dim=20 --lr=0.0001 --n_epochs=200 --test_batch_size=64 --test_samples=100 --train_batch_size=100 --train_samples=1 --seed 1 12 | > python train_embedding.py --dist=RoWN --data_dir data/ --device=cuda:0 --latent_dim=20 --seed=1 13 | ``` 14 | 15 | ### Dataset 16 | For noisy synthetic binary tree, we can generate the dataset using the following command: 17 | ``` 18 | > cd tasks/NSBT; python utils.py 19 | ``` 20 | 21 | For Atari 2600 Breakout, the images can be download from [here](https://www.dropbox.com/s/hyq44euztzz23o8/breakout_states_v2.h5?dl=0). 22 | 23 | For WordNet, we can download the dataset using the following command: 24 | ``` 25 | > mkdir data; cd tasks/WordNet; python utils.py 26 | ``` 27 | 28 | ### Distributions 29 | For the distributions, we implemented: 30 | - `EuclideanNormal`: Gaussian distribution defined in Euclidean space. 31 | - `IsotropicHWN`: Hyperbolic wrapped normal distribution with isotropic covariance. 32 | - `DiagonalHWN`: Hyperbolic wrapped normal distribution with diagonal covariance. 
33 | - `FullHWN`: Hyperbolic wrapped normal distribution with full covariance. 34 | - `RoWN`: Hyperbolic wrapped normal distribution with rotated covariance. 35 | 36 | ## Cite 37 | Please cite our paper if you use the model or this code in your own work: 38 | ``` 39 | @article{cho2022rotated, 40 | title={A Rotated Hyperbolic Wrapped Normal Distribution for Hierarchical Representation Learning}, 41 | author={Cho, Seunghyuk and Lee, Juyong and Park, Jaesik and Kim, Dongwoo}, 42 | journal={arXiv preprint arXiv:2205.13371}, 43 | year={2022} 44 | } 45 | ``` 46 | -------------------------------------------------------------------------------- /arguments.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | tasks = ['Breakout', 'NSBT'] 5 | distributions = ['EuclideanNormal', 'IsotropicHWN', 'DiagonalHWN', 'FullHWN', 'RoWN'] 6 | 7 | 8 | def get_initial_parser(): 9 | parser = argparse.ArgumentParser(add_help=False) 10 | parser.add_argument('--task', type=str, choices=tasks) 11 | parser.add_argument('--dist', type=str, choices=distributions) 12 | return parser 13 | 14 | 15 | def add_train_args(parser): 16 | group = parser.add_argument_group('train') 17 | group.add_argument('--task', type=str, choices=tasks) 18 | group.add_argument('--dist', type=str, choices=distributions) 19 | group.add_argument('--seed', type=int, default=7777) 20 | group.add_argument('--latent_dim', type=int, default=2) 21 | group.add_argument('--beta', type=float, default=1.) 22 | group.add_argument('--n_epochs', type=int, default=10) 23 | group.add_argument('--train_batch_size', type=int, default=32) 24 | group.add_argument('--test_batch_size', type=int, default=32) 25 | group.add_argument('--lr', type=float, default=1e-5) 26 | group.add_argument('--device', type=str, default='cuda:0') 27 | group.add_argument('--eval_interval', type=int, default=10) 28 | group.add_argument('--log_interval', type=int, default=10) 29 | group.add_argument('--log_dir', type=str, default='logs/') 30 | group.add_argument('--train_samples', type=int, default=1) 31 | group.add_argument('--test_samples', type=int, default=500) 32 | group.add_argument('--exp_name', type=str, default='dummy') 33 | 34 | -------------------------------------------------------------------------------- /distributions/DiagonalHWN/__init__.py: -------------------------------------------------------------------------------- 1 | from .layers import EncoderLayer, DecoderLayer, EmbeddingLayer 2 | from .distribution import Distribution 3 | from .prior import get_prior 4 | -------------------------------------------------------------------------------- /distributions/DiagonalHWN/distribution.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.distributions import Normal 3 | 4 | from ..hwn import HWN 5 | 6 | 7 | class Distribution(HWN): 8 | def __init__(self, mean, covar) -> None: 9 | base = Normal( 10 | torch.zeros( 11 | covar.size(), 12 | device=covar.device 13 | ), 14 | covar 15 | ) 16 | 17 | super().__init__(mean, base) 18 | 19 | -------------------------------------------------------------------------------- /distributions/DiagonalHWN/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import geoopt 3 | from torch import nn 4 | from torch.nn import functional as F 5 | 6 | from ..utils import ExpLayer, LogLayer 7 | 8 | EncoderLayer = ExpLayer 9 | DecoderLayer = LogLayer 10 | 11 | 12 | class 
EmbeddingLayer(nn.Module): 13 | def __init__(self, args, n_words): 14 | super().__init__() 15 | 16 | self.args = args 17 | self.latent_dim = args.latent_dim 18 | self.n_words = n_words 19 | self.initial_sigma = args.initial_sigma 20 | self.manifold = geoopt.manifolds.Lorentz() 21 | 22 | mean_initialize = torch.empty([self.n_words, self.latent_dim]) 23 | nn.init.normal_(mean_initialize, std=args.initial_sigma) 24 | self.mean = nn.Embedding.from_pretrained(mean_initialize, freeze=False) 25 | 26 | covar_initialize = torch.empty([self.n_words, self.latent_dim]) 27 | nn.init.normal_(covar_initialize, std=args.initial_sigma) 28 | self.covar = nn.Embedding.from_pretrained(covar_initialize, freeze=False) 29 | 30 | def forward(self, x): 31 | mean = self.mean(x) 32 | mean = F.pad(mean, (1, 0)) 33 | mean = self.manifold.expmap0(mean) 34 | 35 | covar = F.softplus(self.covar(x)) 36 | 37 | return mean, covar 38 | 39 | -------------------------------------------------------------------------------- /distributions/DiagonalHWN/prior.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import geoopt 3 | from .distribution import Distribution 4 | 5 | 6 | def get_prior(args): 7 | m = geoopt.manifolds.Lorentz() 8 | 9 | mean = m.origin([1, args.latent_dim + 1], device=args.device) 10 | covar = torch.ones( 11 | 1, 12 | args.latent_dim, 13 | device=args.device 14 | ) 15 | 16 | prior = Distribution(mean, covar) 17 | return prior 18 | 19 | -------------------------------------------------------------------------------- /distributions/EuclideanNormal/__init__.py: -------------------------------------------------------------------------------- 1 | from .layers import EncoderLayer, DecoderLayer, EmbeddingLayer 2 | from .distribution import Distribution 3 | from .prior import get_prior 4 | -------------------------------------------------------------------------------- /distributions/EuclideanNormal/distribution.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.distributions import Normal 3 | 4 | def kl_dist(mu0, std0, mu1, std1): 5 | k = mu0.size(-1) 6 | logvar0, logvar1 = 2 * std0.log(), 2 * std1.log() 7 | 8 | dist = logvar1 - logvar0 + (((mu0 - mu1).pow(2) + 1e-9).log() - logvar1).exp() + (logvar0 - logvar1).exp() 9 | dist = dist.sum(dim=-1) - k 10 | return dist * 0.5 11 | 12 | class Distribution(): 13 | def __init__(self, mean, covar) -> None: 14 | self.mean = mean 15 | self.covar = covar 16 | 17 | self.base = Normal(self.mean, self.covar) 18 | 19 | def log_prob(self, z): 20 | return self.base.log_prob(z).sum(dim=-1) 21 | 22 | def rsample(self, N): 23 | return self.base.rsample([N]) 24 | 25 | def sample(self, N): 26 | with torch.no_grad(): 27 | return self.rsample(N) 28 | 29 | -------------------------------------------------------------------------------- /distributions/EuclideanNormal/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | 6 | class EncoderLayer(nn.Module): 7 | def __init__(self, args, feature_dim) -> None: 8 | super().__init__() 9 | 10 | self.latent_dim = args.latent_dim 11 | self.feature_dim = feature_dim 12 | 13 | self.variational = nn.Linear( 14 | self.feature_dim, 15 | 2 * self.latent_dim 16 | ) 17 | 18 | def forward(self, feature): 19 | feature = self.variational(feature) 20 | mean, covar = torch.split( 21 | feature, 22 | [self.latent_dim, 
self.latent_dim], 23 | dim=-1 24 | ) 25 | covar = F.softplus(covar) 26 | 27 | return mean, covar 28 | 29 | 30 | class DecoderLayer(nn.Module): 31 | def __init__(self) -> None: 32 | super().__init__() 33 | 34 | def forward(self, z): 35 | return z 36 | 37 | 38 | class EmbeddingLayer(nn.Module): 39 | def __init__(self, args, n_words): 40 | super().__init__() 41 | 42 | self.args = args 43 | self.latent_dim = args.latent_dim 44 | self.n_words = n_words 45 | self.initial_sigma = args.initial_sigma 46 | 47 | mean_initialize = torch.empty([self.n_words, self.latent_dim]) 48 | nn.init.normal_(mean_initialize, std=args.initial_sigma) 49 | self.mean = nn.Embedding.from_pretrained(mean_initialize, freeze=False) 50 | 51 | covar_initialize = torch.empty([self.n_words, self.latent_dim]) 52 | nn.init.normal_(covar_initialize, std=args.initial_sigma) 53 | self.covar = nn.Embedding.from_pretrained(covar_initialize, freeze=False) 54 | 55 | def forward(self, x): 56 | mean = self.mean(x) 57 | covar = F.softplus(self.covar(x)) 58 | 59 | return mean, covar 60 | 61 | -------------------------------------------------------------------------------- /distributions/EuclideanNormal/prior.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .distribution import Distribution 4 | 5 | 6 | def get_prior(args): 7 | mean = torch.zeros( 8 | [1, args.latent_dim], 9 | device=args.device 10 | ) 11 | covar = torch.ones( 12 | [1, args.latent_dim], 13 | device=args.device 14 | ) 15 | 16 | prior = Distribution(mean, covar) 17 | return prior 18 | 19 | -------------------------------------------------------------------------------- /distributions/FullHWN/__init__.py: -------------------------------------------------------------------------------- 1 | from .layers import EncoderLayer, DecoderLayer, EmbeddingLayer 2 | from .distribution import Distribution 3 | from .prior import get_prior 4 | -------------------------------------------------------------------------------- /distributions/FullHWN/distribution.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.distributions import MultivariateNormal 3 | 4 | from ..hwn import HWN 5 | 6 | 7 | class Distribution(HWN): 8 | def __init__(self, mean, covar) -> None: 9 | base = MultivariateNormal( 10 | torch.zeros( 11 | mean.size(), 12 | device=covar.device 13 | )[..., 1:], 14 | covar 15 | ) 16 | 17 | super().__init__(mean, base) 18 | 19 | def log_prob(self, z): 20 | u = self.manifold.logmap(self.mean, z) 21 | v = self.manifold.transp(self.mean, self.origin, u) 22 | log_prob_v = self.base.log_prob(v[:, :, 1:]) 23 | 24 | r = self.manifold.norm(u) 25 | log_det = (self.latent_dim - 1) * (torch.sinh(r).log() - r.log()) 26 | 27 | log_prob_z = log_prob_v - log_det 28 | return log_prob_z 29 | 30 | -------------------------------------------------------------------------------- /distributions/FullHWN/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import geoopt 3 | from torch import nn 4 | from torch.nn import functional as F 5 | 6 | from ..utils import LogLayer 7 | 8 | 9 | class EncoderLayer(nn.Module): 10 | def __init__(self, args, feature_dim) -> None: 11 | super().__init__() 12 | 13 | self.latent_dim = args.latent_dim 14 | self.feature_dim = feature_dim 15 | 16 | self.manifold = geoopt.manifolds.Lorentz() 17 | self.variational = nn.Linear( 18 | self.feature_dim, 19 | self.latent_dim + self.latent_dim ** 2 20 | ) 
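    # The single linear head above outputs `latent_dim` tangent-space mean coordinates
    # plus `latent_dim ** 2` entries of an unconstrained factor A; forward() below
    # reshapes A and builds the full covariance as A @ A.T + 1e-9 * I, which is
    # symmetric positive definite by construction.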
21 | 22 | def forward(self, feature): 23 | feature = self.variational(feature) 24 | mu, covar = torch.split( 25 | feature, 26 | [self.latent_dim, self.latent_dim ** 2], 27 | dim=-1 28 | ) 29 | 30 | mu = F.pad(mu, (1, 0)) 31 | mu = self.manifold.expmap0(mu) 32 | 33 | covar_size = covar.size()[:-1] 34 | covar = covar.view( 35 | *covar_size, 36 | self.latent_dim, 37 | self.latent_dim 38 | ) 39 | covar = covar.matmul(covar.transpose(-1, -2)) 40 | covar = covar + 1e-9 * torch.eye( 41 | self.latent_dim, 42 | device=covar.device 43 | )[None, ...] 44 | 45 | return mu, covar 46 | 47 | 48 | DecoderLayer = LogLayer 49 | 50 | 51 | class EmbeddingLayer(nn.Module): 52 | def __init__(self, args, n_words): 53 | super().__init__() 54 | 55 | self.args = args 56 | self.latent_dim = args.latent_dim 57 | self.n_words = n_words 58 | self.initial_sigma = args.initial_sigma 59 | self.manifold = geoopt.manifolds.Lorentz() 60 | 61 | mean_initialize = torch.empty([self.n_words, self.latent_dim]) 62 | nn.init.normal_(mean_initialize, std=args.initial_sigma) 63 | self.mean = nn.Embedding.from_pretrained(mean_initialize, freeze=False) 64 | 65 | covar_initialize = torch.stack( 66 | [torch.eye(self.latent_dim) for _ in range(self.n_words)] 67 | ).view(self.n_words, -1) 68 | covar_initialize = covar_initialize * torch.randn(covar_initialize.size()) * self.initial_sigma 69 | self.covar = nn.Embedding.from_pretrained(covar_initialize, freeze=False) 70 | 71 | def forward(self, x): 72 | mean = self.mean(x) 73 | mean = F.pad(mean, (1, 0)) 74 | mean = self.manifold.expmap0(mean) 75 | 76 | covar = self.covar(x) 77 | covar_size = covar.size()[:-1] 78 | covar = covar.view( 79 | *covar_size, 80 | self.latent_dim, 81 | self.latent_dim 82 | ) 83 | covar = covar.matmul(covar.transpose(-1, -2)) 84 | covar = covar + 1e-9 * torch.eye( 85 | self.latent_dim, 86 | device=covar.device 87 | )[None, ...] 88 | 89 | return mean, covar 90 | 91 | -------------------------------------------------------------------------------- /distributions/FullHWN/prior.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import geoopt 3 | from .distribution import Distribution 4 | 5 | 6 | def get_prior(args): 7 | m = geoopt.manifolds.Lorentz() 8 | 9 | mean = m.origin([1, args.latent_dim + 1], device=args.device) 10 | covar = torch.eye( 11 | args.latent_dim, 12 | device=args.device 13 | )[None, ...] 
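    # With an identity base covariance at the origin, this full-covariance prior
    # coincides with the standard hyperbolic wrapped normal used by the other
    # distributions' priors.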
14 | 15 | prior = Distribution(mean, covar) 16 | return prior 17 | 18 | -------------------------------------------------------------------------------- /distributions/IsotropicHWN/__init__.py: -------------------------------------------------------------------------------- 1 | from .layers import EncoderLayer, DecoderLayer, EmbeddingLayer 2 | from .distribution import Distribution 3 | from .prior import get_prior 4 | -------------------------------------------------------------------------------- /distributions/IsotropicHWN/distribution.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.distributions import Normal 3 | 4 | from ..hwn import HWN 5 | 6 | 7 | class Distribution(HWN): 8 | def __init__(self, mean, covar) -> None: 9 | base = Normal( 10 | torch.zeros( 11 | mean.size(), 12 | device=covar.device 13 | )[..., 1:], 14 | covar 15 | ) 16 | 17 | super().__init__(mean, base) 18 | 19 | -------------------------------------------------------------------------------- /distributions/IsotropicHWN/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import geoopt 3 | from torch import nn 4 | from torch.nn import functional as F 5 | 6 | from ..utils import LogLayer 7 | 8 | 9 | class EncoderLayer(nn.Module): 10 | def __init__(self, args, feature_dim) -> None: 11 | super().__init__() 12 | 13 | self.latent_dim = args.latent_dim 14 | self.feature_dim = feature_dim 15 | 16 | self.manifold = geoopt.manifolds.Lorentz() 17 | self.variational = nn.Linear( 18 | self.feature_dim, 19 | self.latent_dim + 1 20 | ) 21 | 22 | def forward(self, feature): 23 | feature = self.variational(feature) 24 | mu, covar = torch.split( 25 | feature, 26 | [self.latent_dim, 1], 27 | dim=-1 28 | ) 29 | 30 | mu = F.pad(mu, (1, 0)) 31 | mu = self.manifold.expmap0(mu) 32 | covar = F.softplus(covar) 33 | 34 | return mu, covar 35 | 36 | 37 | DecoderLayer = LogLayer 38 | 39 | 40 | class EmbeddingLayer(nn.Module): 41 | def __init__(self, args, n_words): 42 | super().__init__() 43 | 44 | self.args = args 45 | self.latent_dim = args.latent_dim 46 | self.n_words = n_words 47 | self.initial_sigma = args.initial_sigma 48 | self.manifold = geoopt.manifolds.Lorentz() 49 | 50 | mean_initialize = torch.empty([self.n_words, self.latent_dim]) 51 | nn.init.normal_(mean_initialize, std=args.initial_sigma) 52 | self.mean = nn.Embedding.from_pretrained(mean_initialize, freeze=False) 53 | 54 | covar_initialize = torch.empty([self.n_words, 1]) 55 | nn.init.normal_(covar_initialize, std=args.initial_sigma) 56 | self.covar = nn.Embedding.from_pretrained(covar_initialize, freeze=False) 57 | 58 | def forward(self, x): 59 | mean = self.mean(x) 60 | mean = F.pad(mean, (1, 0)) 61 | mean = self.manifold.expmap0(mean) 62 | 63 | covar = F.softplus(self.covar(x)) 64 | 65 | return mean, covar 66 | 67 | -------------------------------------------------------------------------------- /distributions/IsotropicHWN/prior.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import geoopt 3 | from .distribution import Distribution 4 | 5 | 6 | def get_prior(args): 7 | m = geoopt.manifolds.Lorentz() 8 | 9 | mean = m.origin([1, args.latent_dim + 1], device=args.device) 10 | covar = torch.ones( 11 | 1, 12 | args.latent_dim, 13 | device=args.device 14 | ) 15 | 16 | prior = Distribution(mean, covar) 17 | return prior 18 | 19 | 
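All of the distribution subpackages above expose the same small interface (`EncoderLayer`, `DecoderLayer`, `EmbeddingLayer`, `Distribution`, `get_prior`), which is what lets `--dist` swap them freely in the training scripts. Below is a minimal, illustrative sketch of that interface using `DiagonalHWN`; it is not part of the repository, assumes the repository root is on the import path, and uses made-up shapes and values.

```
from argparse import Namespace

import geoopt
import torch

from distributions.DiagonalHWN import Distribution, get_prior

torch.set_default_tensor_type(torch.DoubleTensor)  # the training scripts run in double precision

latent_dim, batch = 2, 4
manifold = geoopt.manifolds.Lorentz()

# A batch of means on the hyperboloid, lifted from tangent vectors at the origin
# exactly as the encoder layers above do (pad the time-like coordinate, expmap0).
mu = torch.nn.functional.pad(0.1 * torch.randn(batch, latent_dim), (1, 0))
mean = manifold.expmap0(mu)
covar = torch.ones(batch, latent_dim)

variational = Distribution(mean, covar)   # posterior-like diagonal HWN
z = variational.rsample(8)                # (8, batch, latent_dim + 1)
log_q = variational.log_prob(z)           # (8, batch)

prior = get_prior(Namespace(latent_dim=latent_dim, device='cpu'))
kl = (log_q - prior.log_prob(z)).mean(0)  # Monte-Carlo KL estimate, as in vae.py
print(z.shape, kl.shape)
```

The `FullHWN` and `RoWN` variants are constructed the same way; only the parameterization of the base covariance in the tangent space differs.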
-------------------------------------------------------------------------------- /distributions/RoWN/__init__.py: -------------------------------------------------------------------------------- 1 | from .layers import EncoderLayer, DecoderLayer, EmbeddingLayer 2 | from .distribution import Distribution 3 | from .prior import get_prior 4 | -------------------------------------------------------------------------------- /distributions/RoWN/distribution.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.distributions import MultivariateNormal 3 | 4 | from ..hwn import HWN 5 | 6 | 7 | def rotation_matrix(x, y): 8 | dim = x.size(-1) 9 | x = x / (x.norm(dim=-1, keepdim=True) + 1e-9) 10 | y = y / (y.norm(dim=-1, keepdim=True) + 1e-9) 11 | 12 | x = x[..., None] 13 | y = y[..., None] 14 | I = torch.eye(dim, device=x.device) 15 | tmp = y.matmul(x.transpose(-1, -2)) - x.matmul(y.transpose(-1, -2)) 16 | R = I + tmp + 1 / (1 + (y * x).sum([-1, -2], keepdim=True)) * tmp.matmul(tmp) 17 | return R 18 | 19 | 20 | class Distribution(HWN): 21 | def __init__(self, mean, covar) -> None: 22 | target_axis = mean[..., 1:] 23 | base_axis = torch.zeros( 24 | target_axis.size(), 25 | device=mean.device 26 | ) 27 | base_axis[..., 0] = torch.where( 28 | target_axis[..., 0] >= 0, 1, -1 29 | ) 30 | R = rotation_matrix(base_axis, target_axis) 31 | 32 | covar = (R * covar[..., None, :]).matmul(R.transpose(-1, -2)) 33 | base = MultivariateNormal( 34 | torch.zeros( 35 | target_axis.size(), 36 | device=covar.device 37 | ), 38 | covar 39 | ) 40 | 41 | super().__init__(mean, base) 42 | 43 | def log_prob(self, z): 44 | u = self.manifold.logmap(self.mean, z) 45 | v = self.manifold.transp(self.mean, self.origin, u) 46 | log_prob_v = self.base.log_prob(v[:, :, 1:]) 47 | 48 | r = self.manifold.norm(u) 49 | log_det = (self.latent_dim - 1) * (torch.sinh(r).log() - r.log()) 50 | 51 | log_prob_z = log_prob_v - log_det 52 | return log_prob_z 53 | 54 | -------------------------------------------------------------------------------- /distributions/RoWN/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import geoopt 3 | from torch import nn 4 | from torch.nn import functional as F 5 | 6 | from ..utils import ExpLayer, LogLayer 7 | 8 | EncoderLayer = ExpLayer 9 | DecoderLayer = LogLayer 10 | 11 | 12 | class EmbeddingLayer(nn.Module): 13 | def __init__(self, args, n_words): 14 | super().__init__() 15 | 16 | self.args = args 17 | self.latent_dim = args.latent_dim 18 | self.n_words = n_words 19 | self.initial_sigma = args.initial_sigma 20 | self.manifold = geoopt.manifolds.Lorentz() 21 | 22 | mean_initialize = torch.empty([self.n_words, self.latent_dim]) 23 | nn.init.normal_(mean_initialize, std=args.initial_sigma) 24 | self.mean = nn.Embedding.from_pretrained(mean_initialize, freeze=False) 25 | 26 | covar_initialize = torch.empty([self.n_words, self.latent_dim]) 27 | nn.init.normal_(covar_initialize, std=args.initial_sigma) 28 | self.covar = nn.Embedding.from_pretrained(covar_initialize, freeze=False) 29 | 30 | def forward(self, x): 31 | mean = self.mean(x) 32 | mean = F.pad(mean, (1, 0)) 33 | mean = self.manifold.expmap0(mean) 34 | 35 | covar = F.softplus(self.covar(x)) 36 | 37 | return mean, covar 38 | 39 | -------------------------------------------------------------------------------- /distributions/RoWN/prior.py: -------------------------------------------------------------------------------- 1 | import torch 
2 | import geoopt 3 | from .distribution import Distribution 4 | 5 | 6 | def get_prior(args): 7 | m = geoopt.manifolds.Lorentz() 8 | 9 | mean = m.origin([1, args.latent_dim + 1], device=args.device) 10 | covar = torch.ones( 11 | 1, 12 | args.latent_dim, 13 | device=args.device 14 | ) 15 | 16 | prior = Distribution(mean, covar) 17 | return prior 18 | 19 | -------------------------------------------------------------------------------- /distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-postech/RoWN/288830d28b6012a3e6af00b3b32e982732117d4c/distributions/__init__.py -------------------------------------------------------------------------------- /distributions/hwn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import geoopt 3 | from torch.nn import functional as F 4 | 5 | class HWN(): 6 | def __init__(self, mean, base) -> None: 7 | self.mean = mean 8 | self.base = base 9 | self.manifold = geoopt.manifolds.Lorentz() 10 | 11 | self.origin = self.manifold.origin( 12 | self.mean.size(), 13 | device=self.mean.device 14 | ) 15 | self.latent_dim = self.mean.size(-1) - 1 16 | 17 | def log_prob(self, z): 18 | u = self.manifold.logmap(self.mean, z) 19 | v = self.manifold.transp(self.mean, self.origin, u) 20 | log_prob_v = self.base.log_prob(v[:, :, 1:]).sum(-1) 21 | 22 | r = self.manifold.norm(u) 23 | log_det = (self.latent_dim - 1) * (torch.sinh(r).log() - r.log()) 24 | 25 | log_prob_z = log_prob_v - log_det 26 | return log_prob_z 27 | 28 | def rsample(self, N): 29 | v = self.base.rsample([N]) 30 | v = F.pad(v, (1, 0)) 31 | 32 | u = self.manifold.transp0(self.mean, v) 33 | z = self.manifold.expmap(self.mean, u) 34 | 35 | return z 36 | 37 | def sample(self, N): 38 | with torch.no_grad(): 39 | return self.rsample(N) 40 | 41 | -------------------------------------------------------------------------------- /distributions/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import geoopt 3 | from torch import nn 4 | from torch.nn import functional as F 5 | 6 | 7 | class ExpLayer(nn.Module): 8 | def __init__(self, args, feature_dim) -> None: 9 | super().__init__() 10 | 11 | self.latent_dim = args.latent_dim 12 | self.feature_dim = feature_dim 13 | 14 | self.manifold = geoopt.manifolds.Lorentz() 15 | self.variational = nn.Linear( 16 | self.feature_dim, 17 | 2 * self.latent_dim 18 | ) 19 | 20 | def forward(self, feature): 21 | feature = self.variational(feature) 22 | mu, covar = torch.split( 23 | feature, 24 | [self.latent_dim, self.latent_dim], 25 | dim=-1 26 | ) 27 | 28 | mu = F.pad(mu, (1, 0)) 29 | mu = self.manifold.expmap0(mu) 30 | covar = F.softplus(covar) 31 | 32 | return mu, covar 33 | 34 | 35 | class LogLayer(nn.Module): 36 | def __init__(self) -> None: 37 | super().__init__() 38 | 39 | self.manifold = geoopt.manifolds.Lorentz() 40 | 41 | def forward(self, z): 42 | z = self.manifold.logmap0(z) 43 | return z[..., 1:] 44 | 45 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | torch 3 | torchvision 4 | opencv-python 5 | wandb 6 | pandas 7 | plotly 8 | git+https://github.com/geoopt/geoopt.git 9 | scipy 10 | nltk 11 | -------------------------------------------------------------------------------- /tasks/Breakout/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .arguments import add_task_args 2 | from .model import Encoder, Decoder 3 | from .dataset import Dataset 4 | from .evaluation import evaluation 5 | 6 | recon_loss_type = 'BCE' 7 | -------------------------------------------------------------------------------- /tasks/Breakout/arguments.py: -------------------------------------------------------------------------------- 1 | from .dataset import add_dataset_args 2 | from .model import add_model_args 3 | 4 | def add_task_args(parser): 5 | group = parser.add_argument_group('Breakout') 6 | add_dataset_args(group) 7 | add_model_args(group) 8 | 9 | -------------------------------------------------------------------------------- /tasks/Breakout/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pathlib import Path 3 | from torch.utils import data 4 | from torchvision.transforms import ToTensor 5 | 6 | 7 | def add_dataset_args(parser): 8 | parser.add_argument('--data_dir', type=str) 9 | 10 | 11 | class Dataset(data.Dataset): 12 | def __init__(self, args, is_train=True) -> None: 13 | super().__init__() 14 | 15 | self.args = args 16 | self.is_train = is_train 17 | self.transform = ToTensor() 18 | self.data_dir = Path(args.data_dir) 19 | 20 | prefix = 'train_' if self.is_train else 'test_' 21 | raw_data = np.load( 22 | self.data_dir / f'{prefix}data.npy' 23 | ) 24 | self.data = np.transpose( 25 | raw_data, 26 | (0, 2, 3, 1) 27 | ).astype(np.float64) 28 | 29 | self.features = np.load( 30 | self.data_dir / f'{prefix}labels.npy' 31 | ) 32 | 33 | def __len__(self): 34 | return self.data.shape[0] 35 | 36 | def __getitem__(self, idx): 37 | x = self.data[idx] 38 | x = self.transform(x) 39 | return x 40 | 41 | -------------------------------------------------------------------------------- /tasks/Breakout/evaluation.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | import numpy as np 3 | 4 | 5 | def evaluation( 6 | args, 7 | vae, 8 | dataset, 9 | means 10 | ): 11 | means = means.detach().cpu().numpy() 12 | if args.dist != 'EuclideanNormal': 13 | norm = means[..., 0] 14 | norm = np.sqrt((norm - 1) / (norm + 1)) 15 | else: 16 | norm = np.sqrt((means ** 2).sum(axis=-1)) 17 | 18 | features = dataset.features 19 | metric = np.corrcoef(norm, features)[0, 1] 20 | print(f'===========> Correlation with cumulative rewards: {metric}') 21 | wandb.log({ 22 | 'test_corr_reward': metric 23 | }) 24 | 25 | -------------------------------------------------------------------------------- /tasks/Breakout/model.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | def add_model_args(parser): 5 | pass 6 | 7 | 8 | class Encoder(nn.Module): 9 | def __init__(self, args) -> None: 10 | super().__init__() 11 | 12 | self.args = args 13 | self.output_dim = 10 * 10 * 64 14 | self.encoder = nn.Sequential( 15 | nn.Conv2d(1, 16, 3, 1, 1), 16 | nn.ReLU(), # (80, 80) 17 | nn.Conv2d(16, 32, 4, 2, 1), 18 | nn.ReLU(), # (40, 40) 19 | nn.Conv2d(32, 32, 3, 1, 1), 20 | nn.ReLU(), # (40, 40) 21 | nn.Conv2d(32, 64, 4, 2, 1), 22 | nn.ReLU(), # (20, 20) 23 | nn.Conv2d(64, 64, 3, 1, 1), 24 | nn.ReLU(), # (20, 20) 25 | nn.Conv2d(64, 64, 4, 2, 1), 26 | nn.ReLU(), # (10, 10) 27 | nn.Flatten(), 28 | ) 29 | 30 | def forward(self, x): 31 | feature = self.encoder(x) 32 | return feature 33 | 34 | 35 | class Decoder(nn.Module): 36 
| def __init__(self, args) -> None: 37 | super().__init__() 38 | 39 | self.args = args 40 | self.latent_dim = args.latent_dim 41 | 42 | self.decoder1 = nn.Sequential( 43 | nn.Linear(self.latent_dim, 10 * 10 * 64), 44 | nn.ReLU(), 45 | ) 46 | self.decoder2 = nn.Sequential( 47 | nn.ConvTranspose2d(64, 32, 4, 2, 1), # (32, 20, 20) 48 | nn.ReLU(), 49 | nn.Conv2d(32, 32, 3, 1, 1), 50 | nn.ReLU(), 51 | nn.ConvTranspose2d(32, 16, 4, 2, 1), # (16, 40, 40) 52 | nn.ReLU(), 53 | nn.Conv2d(16, 16, 3, 1, 1), 54 | nn.ReLU(), 55 | nn.ConvTranspose2d(16, 1, 4, 2, 1), # (1, 80, 80) 56 | nn.Sigmoid() 57 | ) 58 | 59 | def forward(self, z): 60 | fixed_shapes = z.size()[:-1] 61 | z = z.view(-1, self.latent_dim) 62 | z = self.decoder1(z) 63 | z = z.view(-1, 64, 10, 10) 64 | x = self.decoder2(z) 65 | x = x.view(*fixed_shapes, 1, 80, 80) 66 | return x 67 | 68 | -------------------------------------------------------------------------------- /tasks/NSBT/__init__.py: -------------------------------------------------------------------------------- 1 | from .arguments import add_task_args 2 | from .model import Encoder, Decoder 3 | from .dataset import Dataset 4 | from .evaluation import evaluation 5 | 6 | recon_loss_type = 'NLL' 7 | -------------------------------------------------------------------------------- /tasks/NSBT/arguments.py: -------------------------------------------------------------------------------- 1 | from .dataset import add_dataset_args 2 | from .model import add_model_args 3 | 4 | def add_task_args(parser): 5 | group = parser.add_argument_group('Noisy synthetic binary tree') 6 | add_dataset_args(group) 7 | add_model_args(group) 8 | 9 | -------------------------------------------------------------------------------- /tasks/NSBT/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pathlib import Path 3 | from torch.utils import data 4 | 5 | from .utils import synthetic_binary_tree, noisy_sythetic_binary_tree 6 | 7 | 8 | def add_dataset_args(parser): 9 | parser.add_argument('--data_dir', type=str, default=None) 10 | parser.add_argument('--depth', type=int, default=4) 11 | 12 | 13 | class Dataset(data.Dataset): 14 | def __init__(self, args, is_train=True) -> None: 15 | super().__init__() 16 | 17 | self.args = args 18 | self.is_train = is_train 19 | self.features = None 20 | self.depth = args.depth 21 | 22 | if self.is_train: 23 | if args.data_dir is None: 24 | self.data, _ = noisy_sythetic_binary_tree(args.depth) 25 | else: 26 | data_dir = Path(args.data_dir) 27 | self.data = np.load(data_dir / f'depth_{self.depth}.npy') 28 | else: 29 | self.data = synthetic_binary_tree(self.depth) 30 | 31 | self.data = (self.data - 0.5) * 2 32 | 33 | def __len__(self): 34 | return self.data.shape[0] 35 | 36 | def __getitem__(self, idx): 37 | x = self.data[idx] 38 | return x 39 | 40 | -------------------------------------------------------------------------------- /tasks/NSBT/evaluation.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | import geoopt 3 | import numpy as np 4 | 5 | 6 | def evaluation( 7 | args, 8 | vae, 9 | dataset, 10 | means 11 | ): 12 | x = dataset.data 13 | means = means.detach() 14 | x_recon = vae.generate(means).detach().cpu().numpy() 15 | 16 | N = x.shape[0] 17 | d_true = np.zeros((N, N)) 18 | d_pred = np.zeros((N, N)) 19 | m = geoopt.manifolds.Euclidean(1) if args.dist == 'EuclideanNormal' else geoopt.manifolds.Lorentz() 20 | 21 | test_error = 0 22 | for i in 
range(x.shape[0]): 23 | for j in range(x.shape[1]): 24 | if x[i, j] * x_recon[i, j] <= 0: 25 | test_error += 1 26 | 27 | n_data = 2 ** args.depth - 1 28 | test_error /= n_data * n_data 29 | 30 | x = x / 2 + 0.5 31 | for i in range(N): 32 | for j in range(i): 33 | d_true[i, j] = (x[i].astype(np.int32) ^ x[j].astype(np.int32)).sum() 34 | d_pred[i, j] = m.dist(means[i], means[j]) 35 | 36 | mask = np.fromfunction(lambda i, j: i > j, shape=d_true.shape) 37 | corr_distance = np.corrcoef(d_pred[mask], d_true[mask])[0, 1] 38 | 39 | depths = x.sum(axis=-1) 40 | if args.dist != 'EuclideanNormal': 41 | norm = means[..., 0].cpu().numpy() 42 | norm = np.sqrt((norm - 1) / (norm + 1)) 43 | else: 44 | norm = means.cpu().numpy() 45 | norm = np.sqrt((norm ** 2).sum(axis=-1)) 46 | 47 | corr_depth = np.corrcoef(norm, depths)[0, 1] 48 | 49 | print(f'===========> Test error: {test_error}') 50 | print(f'===========> Correlation with hamming distance: {corr_distance} | with depth: {corr_depth}') 51 | wandb.log({ 52 | 'test_error': test_error, 53 | 'test_corr_distance': corr_distance, 54 | 'test_corr_depth': corr_depth 55 | }) 56 | 57 | -------------------------------------------------------------------------------- /tasks/NSBT/model.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from .utils import stack_linear_layers 3 | 4 | 5 | def add_model_args(parser): 6 | parser.add_argument('--n_layers', type=int, default=2) 7 | parser.add_argument('--n_hids', type=int, default=256) 8 | 9 | 10 | class Encoder(nn.Module): 11 | def __init__(self, args) -> None: 12 | super().__init__() 13 | 14 | self.args = args 15 | self.depth = args.depth 16 | self.n_layers = args.n_layers 17 | self.n_hids = args.n_hids 18 | 19 | self.encoder = nn.Sequential( 20 | nn.Linear(2 ** self.depth - 1, self.n_hids), 21 | nn.ReLU(), 22 | *stack_linear_layers(self.n_hids, self.n_layers), 23 | # nn.Linear(self.n_hids, 2 * self.latent_dim) 24 | ) 25 | 26 | self.output_dim = self.n_hids 27 | 28 | def forward(self, x): 29 | feature = self.encoder(x) 30 | return feature 31 | 32 | 33 | class Decoder(nn.Module): 34 | def __init__(self, args) -> None: 35 | super().__init__() 36 | 37 | self.args = args 38 | self.depth = args.depth 39 | self.n_layers = args.n_layers 40 | self.n_hids = args.n_hids 41 | self.latent_dim = args.latent_dim 42 | 43 | self.decoder = nn.Sequential( 44 | nn.Linear(self.latent_dim, self.n_hids), 45 | nn.ReLU(), 46 | *stack_linear_layers(self.n_hids, self.n_layers), 47 | nn.Linear(self.n_hids, 2 ** self.depth - 1), 48 | nn.Tanh() 49 | ) 50 | 51 | def forward(self, z): 52 | fixed_shapes = z.size()[:-1] 53 | z = z.view(-1, self.latent_dim) 54 | x = self.decoder(z) 55 | x = x.view(*fixed_shapes, -1) 56 | 57 | return x 58 | 59 | -------------------------------------------------------------------------------- /tasks/NSBT/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from math import pi 3 | from torch import nn 4 | from tqdm import tqdm 5 | 6 | 7 | def stack_linear_layers(n_hids, n_layers): 8 | return [nn.Sequential( 9 | nn.Linear(n_hids, n_hids), 10 | nn.ReLU() 11 | ) for _ in range(n_layers)] 12 | 13 | 14 | def hasone(node_index, dim_index): 15 | bin_i, bin_j = np.binary_repr(node_index), np.binary_repr(dim_index) 16 | length = len(bin_j) 17 | return (bin_i[:length] == bin_j) * 1 18 | 19 | 20 | def synthetic_binary_tree(depth): 21 | n = 2 ** depth - 1 22 | x = np.fromfunction( 23 | lambda i, j: 
np.vectorize(hasone)(i + 1, j + 1), 24 | (n, n), 25 | dtype=np.int64 26 | ).astype(np.float64) 27 | 28 | return x 29 | 30 | 31 | def noisy_sythetic_binary_tree(depth, n_samples=100): 32 | original_data = synthetic_binary_tree(depth) 33 | data = np.empty((0, original_data.shape[-1])) 34 | features = [] 35 | for idx in tqdm(range(original_data.shape[0])): 36 | x = original_data[idx] 37 | idxs = (x == 1).nonzero()[0][1:] 38 | for _ in range(n_samples): 39 | x_ = x.copy() 40 | if len(idxs) > 0: 41 | theta = np.random.random(len(idxs)) * 0.5 * pi / 2 42 | eps_x = np.cos(theta) 43 | eps_y = np.sin(theta) 44 | 45 | x_[idxs] = eps_x 46 | idxs_ = idxs + (idxs % 2 - 0.5) * 2 47 | idxs_ = idxs_.astype(np.int64) 48 | x_[idxs_] = eps_y 49 | features.append(theta[-1]) 50 | else: 51 | features.append(0.) 52 | data = np.concatenate( 53 | (data, x_[None, ...]), 54 | axis=0 55 | ) 56 | 57 | features = np.array(features) 58 | return data, features 59 | 60 | 61 | if __name__ == "__main__": 62 | for depth in range(4, 9): 63 | data, feature = noisy_sythetic_binary_tree(depth) 64 | np.save(f'./depth_{depth}.npy', data) 65 | np.save(f'./feature_{depth}.npy', feature) 66 | 67 | -------------------------------------------------------------------------------- /tasks/WordNet/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import Dataset 2 | from .evaluation import evaluation 3 | -------------------------------------------------------------------------------- /tasks/WordNet/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pathlib import Path 3 | from torch.utils import data 4 | 5 | from .utils import slurp 6 | 7 | 8 | class Dataset(data.Dataset): 9 | def __init__(self, args): 10 | self.args = args 11 | file_name = Path(args.data_dir).expanduser() / f'{args.data_type}_closure.tsv' 12 | indices, objects = slurp(file_name.as_posix(), symmetrize=False) 13 | 14 | self.relations = indices[:, :2] 15 | self.words = objects 16 | self.n_negatives = args.n_negatives 17 | self.n_words = len(self.words) 18 | 19 | def __len__(self): 20 | return len(self.relations) 21 | 22 | def __getitem__(self, i): 23 | return np.r_[ 24 | self.relations[i], 25 | np.random.randint(self.n_words, size=self.n_negatives) 26 | ] 27 | 28 | -------------------------------------------------------------------------------- /tasks/WordNet/evaluation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | from sklearn import metrics 4 | 5 | from .utils import create_adjacency, calculate_energy 6 | 7 | 8 | def evaluation( 9 | args, 10 | model, 11 | dataset, 12 | dist_fn 13 | ): 14 | ranks = [] 15 | ap_scores = [] 16 | 17 | adjacency = create_adjacency(dataset.relations) 18 | 19 | iterator = tqdm(adjacency.items()) 20 | batch_size = dataset.n_words // 10 21 | for i, (source, targets) in enumerate(iterator): 22 | if i % 1000 != 0: 23 | continue 24 | input_ = np.c_[ 25 | source * np.ones(dataset.n_words).astype(np.int64), 26 | np.arange(dataset.n_words) 27 | ] 28 | _energies = calculate_energy( 29 | model, 30 | input_, 31 | args.test_samples, 32 | batch_size, 33 | dist_fn 34 | ).detach().cpu().numpy() 35 | 36 | _energies[source] = 1e+12 37 | _labels = np.zeros(dataset.n_words) 38 | _energies_masked = _energies.copy() 39 | _ranks = [] 40 | for o in targets: 41 | _energies_masked[o] = np.Inf 42 | _labels[o] = 1 43 | 
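        # Lower KL "energy" means a more plausible edge, so the negated energies act as
        # scores for average precision; the masked copy restores one true target at a
        # time below to compute its filtered rank among all candidate words.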
ap_scores.append(metrics.average_precision_score(_labels, -_energies)) 44 | for o in targets: 45 | ene = _energies_masked.copy() 46 | ene[o] = _energies[o] 47 | r = np.argsort(ene) 48 | _ranks.append(np.where(r == o)[0][0] + 1) 49 | ranks += _ranks 50 | 51 | return np.mean(ranks), np.mean(ap_scores) 52 | 53 | -------------------------------------------------------------------------------- /tasks/WordNet/mammals_filter.txt: -------------------------------------------------------------------------------- 1 | \sliving_thing.n.01 2 | \sobject.n.01 3 | \sorganism.n.01 4 | \sanimal.n.01 5 | \sentity.n.01 6 | \sphysical_entity.n.01 7 | \swhole.n.02 8 | \svertebrate.n.01 9 | \schordate.n.01 10 | \sbeast_of_burden.n.01 11 | \swork_animal.n.01 12 | \sfemale.n.01 13 | \sfissipedia.n.01 14 | \spup.n.01 15 | \sabstraction.n.06 16 | \sgroup.n.01 17 | ^tusker.n.01 18 | ^female_mammal.n.01 19 | \scub.n.03 20 | \syoung.n.01 21 | \syoung_mammal.n.01 22 | \sdomestic_animal.n.01 23 | \sracer.n.03 24 | \smale.n.01 25 | -------------------------------------------------------------------------------- /tasks/WordNet/utils.py: -------------------------------------------------------------------------------- 1 | # Code from https://github.com/pfnet-research/hyperbolic_wrapped_distribution/blob/master/lib/dataset/wordnet.py 2 | 3 | 4 | import torch 5 | import pathlib 6 | from itertools import count 7 | from collections import defaultdict 8 | 9 | import nltk 10 | import numpy as np 11 | from nltk.corpus import wordnet as wn 12 | 13 | 14 | def generate_dataset(output_dir, with_mammal=False): 15 | output_path = pathlib.Path(output_dir) / 'noun_closure.tsv' 16 | 17 | # make sure each edge is included only once 18 | edges = set() 19 | for synset in wn.all_synsets(pos='n'): 20 | # write the transitive closure of all hypernyms of a synset to file 21 | for hyper in synset.closure(lambda s: s.hypernyms()): 22 | edges.add((synset.name(), hyper.name())) 23 | 24 | # also write transitive closure for all instances of a synset 25 | for instance in synset.instance_hyponyms(): 26 | for hyper in instance.closure(lambda s: s.instance_hypernyms()): 27 | edges.add((instance.name(), hyper.name())) 28 | for h in hyper.closure(lambda s: s.hypernyms()): 29 | edges.add((instance.name(), h.name())) 30 | 31 | with output_path.open('w') as fout: 32 | for i, j in edges: 33 | fout.write('{}\t{}\n'.format(i, j)) 34 | 35 | if with_mammal: 36 | import subprocess 37 | mammaltxt_path = pathlib.Path(output_dir).resolve() / 'mammals.txt' 38 | mammaltxt = mammaltxt_path.open('w') 39 | mammal = (pathlib.Path(output_dir) / 'mammal_closure.tsv').open('w') 40 | commands_first = [ 41 | ['cat', '{}'.format(output_path)], 42 | ['grep', '-e', r'\smammal.n.01'], 43 | ['cut', '-f1'], 44 | ['sed', r's/\(.*\)/\^\1/g'] 45 | ] 46 | commands_second = [ 47 | ['cat', '{}'.format(output_path)], 48 | ['grep', '-f', '{}'.format(mammaltxt_path)], 49 | ['grep', '-v', '-f', '{}'.format( 50 | 'mammals_filter.txt' 51 | )] 52 | ] 53 | for writer, commands in zip([mammaltxt, mammal], [commands_first, commands_second]): 54 | for i, c in enumerate(commands): 55 | if i == 0: 56 | p = subprocess.Popen(c, stdout=subprocess.PIPE) 57 | elif i == len(commands) - 1: 58 | p = subprocess.Popen(c, stdin=p.stdout, stdout=writer) 59 | else: 60 | p = subprocess.Popen(c, stdin=p.stdout, stdout=subprocess.PIPE) 61 | # prev_p = p 62 | p.communicate() 63 | mammaltxt.close() 64 | mammal.close() 65 | 66 | 67 | def parse_seperator(line, length, sep='\t'): 68 | d = line.strip().split(sep) 69 | if len(d) 
== length: 70 | w = 1 71 | elif len(d) == length + 1: 72 | w = int(d[-1]) 73 | d = d[:-1] 74 | else: 75 | raise RuntimeError('Malformed input ({})'.format(line.strip())) 76 | return tuple(d) + (w,) 77 | 78 | 79 | def parse_tsv(line, length=2): 80 | return parse_seperator(line, length, '\t') 81 | 82 | 83 | def iter_line(file_name, parse_function, length=2, comment='#'): 84 | with open(file_name, 'r') as fin: 85 | for line in fin: 86 | if line[0] == comment: 87 | continue 88 | tpl = parse_function(line, length=length) 89 | if tpl is not None: 90 | yield tpl 91 | 92 | 93 | def intmap_to_list(d): 94 | arr = [None for _ in range(len(d))] 95 | for v, i in d.items(): 96 | arr[i] = v 97 | assert not any(x is None for x in arr) 98 | return arr 99 | 100 | 101 | def slurp(file_name, parse_function=parse_tsv, symmetrize=False): 102 | ecount = count() 103 | enames = defaultdict(ecount.__next__) 104 | 105 | subs = [] 106 | for i, j, w in iter_line(file_name, parse_function, length=2): 107 | if i == j: 108 | continue 109 | subs.append((enames[i], enames[j], w)) 110 | if symmetrize: 111 | subs.append((enames[j], enames[i], w)) 112 | idx = np.array(subs, dtype=np.int64) 113 | 114 | # freeze defaultdicts after training data and convert to arrays 115 | objects = intmap_to_list(dict(enames)) 116 | print('slurp: file_name={}, objects={}, edges={}'.format( 117 | file_name, len(objects), len(idx))) 118 | return idx, objects 119 | 120 | 121 | def create_adjacency(indices): 122 | adjacency = defaultdict(set) 123 | for i in range(len(indices)): 124 | s, o = indices[i] 125 | adjacency[s].add(o) 126 | return adjacency 127 | 128 | 129 | def calculate_energy(model, x, test_samples, batch_size, dist_fn): 130 | x = torch.tensor(x).cuda() 131 | kl_target = torch.zeros(x.size(0)).cuda() 132 | nb_batch = np.ceil(x.size(0) / batch_size).astype(int) 133 | 134 | for i in range(nb_batch): 135 | idx_start = i * batch_size 136 | idx_end = (i + 1) * batch_size 137 | data = x[idx_start:idx_end] 138 | 139 | mean, covar = model(data) 140 | dist_anchor = dist_fn(mean[:, 0, :], covar[:, 0, :]) 141 | dist_target = dist_fn(mean[:, 1, :], covar[:, 1, :]) 142 | 143 | z = dist_anchor.rsample(test_samples) 144 | log_prob_anchor = dist_anchor.log_prob(z) 145 | log_prob_target = dist_target.log_prob(z) 146 | kl_target[idx_start:idx_end] = (log_prob_anchor - log_prob_target).mean(dim=0) 147 | 148 | return kl_target 149 | 150 | 151 | if __name__ == "__main__": 152 | try: 153 | nltk.data.find('corpora/wordnet') 154 | except LookupError: 155 | print('wordnet dataset is not found, start download') 156 | nltk.download('wordnet') 157 | print('generate dataset') 158 | generate_dataset('../../data/', with_mammal=False) 159 | generate_dataset('../../data/', with_mammal=True) 160 | 161 | -------------------------------------------------------------------------------- /tasks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-postech/RoWN/288830d28b6012a3e6af00b3b32e982732117d4c/tasks/__init__.py -------------------------------------------------------------------------------- /train_embedding.py: -------------------------------------------------------------------------------- 1 | # Code from https://github.com/pfnet-research/hyperbolic_wrapped_distribution/blob/master/lib/models/embedding.py 2 | 3 | 4 | import copy 5 | import wandb 6 | import torch 7 | import argparse 8 | import importlib 9 | import numpy as np 10 | from math import ceil 11 | from torch.optim import Adagrad 12 
| from torch.nn import functional as F 13 | from torch.utils.data import DataLoader 14 | 15 | from tasks.WordNet import Dataset, evaluation 16 | 17 | 18 | class LRScheduler(): 19 | def __init__(self, optimizer, lr, c, n_burnin_steps): 20 | self.optimizer = optimizer 21 | self.lr = lr 22 | self.n_burnin_steps = n_burnin_steps 23 | self.c = c 24 | self.n_steps = 0 25 | 26 | def step_and_update_lr(self): 27 | self._update_learning_rate() 28 | self.optimizer.step() 29 | 30 | def zero_grad(self): 31 | self.optimizer.zero_grad() 32 | 33 | def _update_learning_rate(self): 34 | self.n_steps += 1 35 | if self.n_steps <= self.n_burnin_steps: 36 | lr = self.lr / self.c 37 | else: 38 | lr = self.lr 39 | 40 | for param_group in self.optimizer.param_groups: 41 | param_group['lr'] = lr 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = argparse.ArgumentParser(add_help=True) 46 | parser.add_argument('--data_dir', type=str, default='data/') 47 | parser.add_argument('--data_type', type=str, default='noun') 48 | parser.add_argument('--n_negatives', type=int, default=1) 49 | parser.add_argument('--latent_dim', type=int) 50 | parser.add_argument('--batch_size', type=int, default=50000) 51 | parser.add_argument('--lr', type=float, default=0.6) 52 | parser.add_argument('--n_epochs', type=int, default=10000) 53 | parser.add_argument('--dist', type=str, choices=['EuclideanNormal', 'IsotropicHWN', 'DiagonalHWN', 'RoWN', 'FullHWN']) 54 | parser.add_argument('--initial_sigma', type=float, default=0.01) 55 | parser.add_argument('--bound', type=float, default=37) 56 | parser.add_argument('--train_samples', type=int, default=1) 57 | parser.add_argument('--test_samples', type=int, default=100) 58 | parser.add_argument('--eval_interval', type=int, default=1000) 59 | parser.add_argument('--seed', type=int, default=1234) 60 | parser.add_argument('--c', type=float, default=40) 61 | parser.add_argument('--burnin_epochs', type=int, default=100) 62 | parser.add_argument('--device', type=str, default='cuda:0') 63 | args = parser.parse_args() 64 | 65 | np.random.seed(args.seed) 66 | torch.manual_seed(args.seed) 67 | torch.backends.cudnn.deterministic = True 68 | torch.backends.cudnn.benchmark = True 69 | torch.set_default_tensor_type(torch.DoubleTensor) 70 | 71 | dataset = Dataset(args) 72 | loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=1) 73 | 74 | dist_module = importlib.import_module(f'distributions.{args.dist}') 75 | model = getattr(dist_module, 'EmbeddingLayer')(args, dataset.n_words).to(args.device) 76 | dist_fn = getattr(dist_module, 'Distribution') 77 | 78 | optimizer = Adagrad(model.parameters(), lr=args.lr) 79 | n_batches = int(ceil(len(dataset) / args.batch_size)) 80 | n_burnin_steps = args.burnin_epochs * n_batches 81 | lr_scheduler = LRScheduler(optimizer, args.lr, args.c, n_burnin_steps) 82 | 83 | best_model = copy.deepcopy(model) 84 | best_score = None 85 | 86 | wandb.init(project='RoWN') 87 | wandb.run.name = 'wordnet' 88 | wandb.config.update(args) 89 | for epoch in range(1, args.n_epochs + 1): 90 | total_loss, total_kl_target, total_kl_negative = 0., 0., 0. 91 | total_diff = 0. 
92 | n_batches = 0 93 | model.train() 94 | 95 | for x in loader: 96 | for param in model.parameters(): 97 | param.grad = None 98 | x = x.cuda() 99 | mean, covar = model(x) 100 | dist_anchor = dist_fn(mean[:, 0, :], covar[:, 0, :]) 101 | dist_target = dist_fn(mean[:, 1, :], covar[:, 1, :]) 102 | dist_negative = dist_fn(mean[:, 2, :], covar[:, 2, :]) 103 | 104 | z = dist_anchor.rsample(args.train_samples) 105 | log_prob_anchor = dist_anchor.log_prob(z) 106 | log_prob_target = dist_target.log_prob(z) 107 | log_prob_negative = dist_negative.log_prob(z) 108 | kl_target = (log_prob_anchor - log_prob_target).mean(dim=0) 109 | kl_negative = (log_prob_anchor - log_prob_negative).mean(dim=0) 110 | 111 | loss = F.relu(args.bound + kl_target - kl_negative).mean() 112 | loss.backward() 113 | lr_scheduler.step_and_update_lr() 114 | 115 | total_loss += loss.item() * kl_target.size(0) 116 | total_kl_target += kl_target.sum().item() 117 | total_kl_negative += kl_negative.sum().item() 118 | total_diff += (kl_target - kl_negative).sum().item() 119 | n_batches += kl_target.size(0) 120 | 121 | if best_score is None or best_score > total_loss: 122 | best_score = total_loss 123 | best_model = copy.deepcopy(model) 124 | 125 | print(f"Epoch {epoch:8d} | Total loss: {total_loss / n_batches:.3f} | KL Target: {total_kl_target / n_batches:.3f} | KL Negative: {total_kl_negative / n_batches:.3f}") 126 | wandb.log({ 127 | 'epoch': epoch, 128 | 'train_loss': total_loss / n_batches, 129 | 'train_kl_target': total_kl_target / n_batches, 130 | 'train_kl_negative': total_kl_negative / n_batches 131 | }) 132 | 133 | if epoch % args.eval_interval == 0 or epoch == args.n_epochs: 134 | best_model.eval() 135 | rank, ap = evaluation(args, best_model, dataset, dist_fn) 136 | print(f"===========> Mean rank: {rank} | MAP: {ap}") 137 | wandb.log({ 138 | 'epoch': epoch, 139 | 'rank': rank, 140 | 'map': ap 141 | }) 142 | 143 | -------------------------------------------------------------------------------- /train_vae.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import wandb 3 | import datetime 4 | import importlib 5 | from pathlib import Path 6 | 7 | import torch 8 | import argparse 9 | import numpy as np 10 | from torch.optim import Adam 11 | from torch.utils.data import DataLoader 12 | 13 | from vae import VAE 14 | from arguments import add_train_args, get_initial_parser 15 | 16 | 17 | def init_weights(m): 18 | if isinstance(m, torch.nn.Linear): 19 | torch.nn.init.xavier_normal_(m.weight) 20 | elif isinstance(m, torch.nn.Conv2d): 21 | torch.nn.init.xavier_normal_(m.weight) 22 | elif isinstance(m, torch.nn.ConvTranspose2d): 23 | torch.nn.init.xavier_normal_(m.weight) 24 | 25 | 26 | def train(epoch, args, train_loader, vae, optimizer): 27 | n_data = 0 28 | train_elbo, train_recon, train_kl = 0., 0., 0. 
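    # Gradients are cleared by setting .grad to None (cheaper than zeroing) before each
    # backward pass; the per-epoch sums below are normalized by the number of data
    # points only when logging.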
29 | 
30 |     for x in train_loader:
31 |         for param in vae.parameters():
32 |             param.grad = None
33 |         x = x.to(args.device)
34 |         loss, elbo, _, _, recon_loss, kl_loss = vae(
35 |             x,
36 |             args.train_samples,
37 |             args.beta
38 |         )
39 | 
40 |         loss.backward()
41 |         optimizer.step()
42 | 
43 |         n_data += x.size(0)
44 |         train_elbo += elbo.item()
45 |         train_recon += recon_loss.item()
46 |         train_kl += kl_loss.item()
47 | 
48 |     train_elbo /= n_data
49 |     train_recon /= n_data
50 |     train_kl /= n_data
51 |     if epoch % args.log_interval == 0 or epoch == args.n_epochs:
52 |         print(f'Epoch: {epoch:6d} | ELBO: {train_elbo:.2f} | Recon Loss: {train_recon:.2f} | KL: {train_kl:.3f}')
53 |         wandb.log({
54 |             'epoch': epoch,
55 |             'train_elbo': train_elbo,
56 |             'train_recon': train_recon,
57 |             'train_kl': train_kl
58 |         })
59 | 
60 |     return train_elbo
61 | 
62 | 
63 | def eval(epoch, args, test_loader, vae, root_dir, test_data, eval_fn):
64 |     log_dir = root_dir / str(epoch)
65 |     log_dir.mkdir(parents=True, exist_ok=True)
66 | 
67 |     with torch.no_grad():
68 |         vae.eval()
69 | 
70 |         n_data = 0
71 |         means = torch.empty((
72 |             0,
73 |             (args.latent_dim + 1 if args.dist != 'EuclideanNormal' else args.latent_dim)
74 |         ), device=args.device)
75 |         total_elbo, total_recon, total_kl = 0., 0., 0.
76 |         for x in test_loader:
77 |             x = x.to(args.device)
78 |             _, elbo, _, means_, recon_loss, kl_loss = vae(x, args.test_samples)
79 | 
80 |             n_data += x.size(0)
81 |             total_elbo += elbo.item()
82 |             total_recon += recon_loss.item()
83 |             total_kl += kl_loss.item()
84 |             means = torch.concat((means, means_))
85 | 
86 |         total_elbo /= n_data
87 |         total_recon /= n_data
88 |         total_kl /= n_data
89 |         print(f'===========> Test ELBO: {total_elbo:.2f} | Test Recon: {total_recon:.2f} | Test KL: {total_kl:.2f}')
90 |         wandb.log({
91 |             'test_elbo': total_elbo,
92 |             'test_recon': total_recon,
93 |             'test_kl': total_kl
94 |         })
95 | 
96 |         eval_fn(args, vae, test_data, means)
97 |         torch.save(vae.state_dict(), root_dir / 'model.pt')
98 | 
99 | 
100 | if __name__ == "__main__":
101 |     init_parser = get_initial_parser()
102 |     task_name = init_parser.parse_known_args()[0].task
103 |     task_module = importlib.import_module(f'tasks.{task_name}')
104 |     dist_name = init_parser.parse_known_args()[0].dist
105 |     dist_module = importlib.import_module(f'distributions.{dist_name}')
106 | 
107 |     parser = argparse.ArgumentParser()
108 |     add_train_args(parser)
109 |     getattr(task_module, 'add_task_args')(parser)
110 |     args = parser.parse_args()
111 | 
112 |     if args.task == 'NSBT':
113 |         args.latent_dim = args.depth
114 |         args.n_hids = 8 * (2 ** args.depth)
115 |         # args.train_batch_size = 2 ** args.depth - 1
116 | 
117 |     np.random.seed(args.seed)
118 |     torch.manual_seed(args.seed)
119 |     torch.backends.cudnn.deterministic = True
120 |     torch.backends.cudnn.benchmark = True
121 |     torch.set_default_tensor_type(torch.DoubleTensor)
122 | 
123 |     runId = datetime.datetime.now().isoformat().replace(':', '_')
124 |     root_dir = Path(args.log_dir) / runId
125 | 
126 |     train_data = getattr(task_module, 'Dataset')(args, is_train=True)
127 |     train_loader = DataLoader(train_data, batch_size=args.train_batch_size, shuffle=True)
128 |     test_data = getattr(task_module, 'Dataset')(args, is_train=False)
129 |     test_loader = DataLoader(test_data, batch_size=args.test_batch_size, shuffle=False)
130 |     eval_fn = getattr(task_module, 'evaluation')
131 | 
132 |     variational_fn = getattr(dist_module, 'Distribution')
133 |     prior = getattr(dist_module, 'get_prior')(args)
134 | 
135 |     encoder = getattr(task_module, 'Encoder')(args)
136 |     encoder_layer = getattr(dist_module, 'EncoderLayer')(args, encoder.output_dim)
137 |     decoder = getattr(task_module, 'Decoder')(args)
138 |     decoder_layer = getattr(dist_module, 'DecoderLayer')()
139 | 
140 |     recon_loss_type = getattr(task_module, 'recon_loss_type')
141 |     vae = VAE(
142 |         prior,
143 |         variational_fn,
144 |         encoder,
145 |         encoder_layer,
146 |         decoder,
147 |         decoder_layer,
148 |         recon_loss_type
149 |     )
150 |     # vae.apply(init_weights)
151 |     vae = vae.to(args.device)
152 | 
153 |     optimizer = Adam(
154 |         list(encoder.parameters()) + list(decoder.parameters()),
155 |         lr=args.lr
156 |     )
157 | 
158 |     wandb.init(project='RoWN')
159 |     wandb.run.name = args.exp_name
160 |     wandb.config.update(args)
161 |     print(root_dir)
162 | 
163 |     best_model = copy.deepcopy(vae)
164 |     best_elbo = -1e9
165 | 
166 |     for epoch in range(1, args.n_epochs + 1):
167 |         vae.train()
168 |         train_elbo = train(epoch, args, train_loader, vae, optimizer)
169 |         if best_elbo < train_elbo:
170 |             best_elbo = train_elbo
171 |             best_model = copy.deepcopy(vae)
172 | 
173 |         if epoch % args.eval_interval == 0 or epoch == args.n_epochs:
174 |             best_model.eval()
175 |             eval(epoch, args, test_loader, best_model, root_dir, test_data, eval_fn)
176 | 
177 | 
--------------------------------------------------------------------------------
/vae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 | 
5 | 
6 | class VAE(nn.Module):
7 |     def __init__(self,
8 |                  prior,
9 |                  dist,
10 |                  encoder,
11 |                  encoder_layer,
12 |                  decoder,
13 |                  decoder_layer,
14 |                  loss_type
15 |                  ):
16 |         super().__init__()
17 | 
18 |         self.prior = prior
19 |         self.dist = dist
20 |         self.encoder = encoder
21 |         self.encoder_layer = encoder_layer
22 |         self.decoder = decoder
23 |         self.decoder_layer = decoder_layer
24 |         self.loss_type = loss_type
25 | 
26 |     def forward(self, x, n_samples=1, beta=1.):
27 |         mean, covar = self.encoder_layer(self.encoder(x))
28 |         variational = self.dist(mean, covar)
29 | 
30 |         z = variational.rsample(n_samples)
31 |         log_prob_base = variational.log_prob(z)
32 |         log_prob_target = self.prior.log_prob(z)
33 |         kl_loss = (log_prob_base - log_prob_target).mean(dim=0)
34 | 
35 |         x_generated = self.generate(z)
36 |         if self.loss_type == 'BCE':
37 |             recon_loss = F.binary_cross_entropy(
38 |                 x_generated,
39 |                 x.unsqueeze(0).expand(x_generated.size()),
40 |                 reduction='none'
41 |             )
42 |         else:
43 |             recon_loss = F.gaussian_nll_loss(
44 |                 x_generated,
45 |                 x.unsqueeze(0).expand(x_generated.size()),
46 |                 torch.ones(x_generated.size(), device=x.device) * 0.01,
47 |                 reduction='none'
48 |             )
49 | 
50 |         while len(recon_loss.size()) > 2:
51 |             recon_loss = recon_loss.sum(-1)
52 |         recon_loss = recon_loss.mean(dim=0)
53 | 
54 |         total_loss_sum = recon_loss + beta * kl_loss
55 |         loss = total_loss_sum.mean()
56 | 
57 |         recon_loss = recon_loss.sum()
58 |         kl_loss = kl_loss.sum()
59 |         elbo = -(recon_loss + kl_loss)
60 | 
61 |         return loss, elbo, z, mean, recon_loss, kl_loss
62 | 
63 |     def generate(self, z):
64 |         return self.decoder(self.decoder_layer(z))
65 | 
66 | 
--------------------------------------------------------------------------------