├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── argmax_utils.py ├── assets ├── fig1.png └── voronoi.png ├── charlm ├── LICENSE ├── README.md ├── data │ ├── data.py │ ├── dataset_enwik8.py │ ├── dataset_text8.py │ └── vocab.py ├── experiment │ ├── base.py │ ├── flow.py │ └── utils.py ├── model │ ├── CategoricalNF │ │ ├── __init__.py │ │ ├── categorical_encoding │ │ │ ├── __init__.py │ │ │ ├── decoder.py │ │ │ ├── linear_encoding.py │ │ │ ├── mutils.py │ │ │ ├── variational_auto_encoding.py │ │ │ ├── variational_dequantization.py │ │ │ └── variational_encoding.py │ │ ├── flows │ │ │ ├── activation_normalization.py │ │ │ ├── autoregressive_coupling.py │ │ │ ├── autoregressive_coupling2.py │ │ │ ├── coupling_layer.py │ │ │ ├── discrete_coupling.py │ │ │ ├── distributions.py │ │ │ ├── flow_layer.py │ │ │ ├── flow_model.py │ │ │ ├── mixture_cdf_layer.py │ │ │ ├── permutation_layers.py │ │ │ └── sigmoid_layer.py │ │ ├── general │ │ │ ├── README.md │ │ │ ├── mutils.py │ │ │ ├── parameter_scheduler.py │ │ │ ├── radam.py │ │ │ ├── task.py │ │ │ └── train.py │ │ └── networks │ │ │ ├── autoregressive_layers.py │ │ │ ├── autoregressive_layers2.py │ │ │ ├── graph_layers.py │ │ │ └── help_layers.py │ ├── ar │ │ ├── encoder_context.py │ │ ├── encoder_transforms.py │ │ ├── flow.py │ │ └── vorflow.py │ ├── coupling │ │ ├── cond_ar_affine.py │ │ ├── cond_ar_spline.py │ │ ├── flow.py │ │ ├── masked_linear.py │ │ └── vorflow.py │ ├── distributions │ │ ├── __init__.py │ │ ├── binary_encoder.py │ │ ├── conv_normal1d.py │ │ └── gumbel.py │ ├── model.py │ └── transforms │ │ ├── __init__.py │ │ ├── argmax_product.py │ │ ├── autoregressive │ │ ├── __init__.py │ │ ├── ar.py │ │ ├── ar_linear.py │ │ ├── ar_mixtures.py │ │ ├── ar_splines.py │ │ ├── conditional │ │ │ ├── __init__.py │ │ │ ├── ar.py │ │ │ ├── ar_linear.py │ │ │ ├── ar_mixtures.py │ │ │ └── ar_splines.py │ │ └── utils.py │ │ ├── squeeze1d.py │ │ ├── utils.py │ │ └── voronoi_surjection.py ├── optim │ ├── base.py │ └── expdecay.py └── utils.py ├── configs ├── charlm.yaml ├── default.yaml ├── discrete2d.yaml ├── disjoint2d.yaml ├── disjoint_uci.yaml ├── itemsets.yaml └── uci_categorical.yaml ├── data ├── discrete_8gaussians_pmf.csv ├── discrete_pinwheel_pmf.csv ├── download_itemsets.py └── download_uci.py ├── datasets ├── LICENSE.txt ├── __init__.py ├── bsds300.py ├── gas.py ├── hepmass.py ├── miniboone.py └── power.py ├── diagnostics └── voronoi_plot_2d.py ├── discrete_datasets.py ├── layers ├── __init__.py ├── act_norm.py ├── autoreg.py ├── base │ ├── __init__.py │ ├── activations.py │ ├── lipschitz.py │ ├── mixed_lipschitz.py │ └── utils.py ├── cnf.py ├── container.py ├── coupling.py ├── diffeq_layers │ ├── __init__.py │ ├── basic.py │ ├── container.py │ ├── normalization.py │ ├── resnet.py │ └── wrappers.py ├── elemwise.py ├── glow.py ├── iresblock.py ├── made.py ├── mixlogcdf.py ├── normalization.py ├── softmax.py └── squeeze.py ├── model_itemsets.py ├── model_utils.py ├── multihead_attention.py ├── multiscale_flow.py ├── test_voronoi.py ├── toy_data.py ├── train_charlm.py ├── train_discrete2d.py ├── train_disjoint2d.py ├── train_disjoint_uci.py ├── train_itemsets.py ├── train_uci.py ├── utils.py └── voronoi.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.pyc 3 | .DS_Store 4 | .vscode 5 | *.png 6 | *.pdf 7 | data/* 8 | *.pt -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: 
-------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 
67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to semi-discrete-flow 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to semi-discrete-flow, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Semi-Discrete Normalizing Flows through Differentiable Tessellation 2 | 3 | This is code for the NeurIPS 2022 paper https://arxiv.org/abs/2203.06832. 4 | 5 |

6 | ![Semi-Discrete Flow](assets/fig1.png) 7 |

8 | 9 | ## Bijective transformation onto the Voronoi cell 10 | 11 | This script provides some examples of using the Voronoi tessellation and the bijective transformation used in the paper. 12 | ``` 13 | python test_voronoi.py 14 | ``` 15 | This creates the following figure, among other visualizations. 16 | 17 |

18 | ![Semi-Discrete Flow](assets/voronoi.png) 19 |
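For intuition, the tessellation assigns every continuous point to the cell of its nearest anchor point, and the bijective transformation above maps points onto a chosen cell. The snippet below is a minimal conceptual sketch of the nearest-anchor assignment only; it is not the API of `voronoi.py`, and `voronoi_cell_assignment` is a hypothetical helper used purely for illustration.

```python
import torch

def voronoi_cell_assignment(x, anchors):
    """Index of the Voronoi cell (nearest anchor) containing each point in x."""
    # x: (N, D) query points; anchors: (K, D) anchor points defining K cells.
    d2 = ((x[:, None, :] - anchors[None, :, :]) ** 2).sum(-1)  # (N, K) squared distances
    return d2.argmin(dim=-1)  # (N,) cell index per point

anchors = torch.randn(5, 2)    # 5 anchors -> 5 Voronoi cells in 2D
points = torch.randn(100, 2)
print(voronoi_cell_assignment(points, anchors)[:10])
```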

20 | 21 | ## Data 22 | 23 | Scripts to download the discrete UCI and itemset datasets are provided. 24 | 25 | ``` 26 | cd data 27 | python download_itemsets.py 28 | python download_uci.py 29 | ``` 30 | 31 | Download the continuous UCI data from https://github.com/gpapamak/maf#how-to-get-the-datasets and place them in the `data/` directory. 32 | 33 | ## Experiments 34 | 35 | Discrete 2D: 36 | ``` 37 | python train_discrete2d.py -m dataset=cluster3,cluster10,discrete_8gaussians embedding_dim=2,4,8 seed=0,1,2 38 | ``` 39 | 40 | UCI Categorical: 41 | ``` 42 | python train_uci.py -m dequantization=voronoi dataset=mushroom,nursery,connect4,uscensus90,pokerhand,forests num_transformer_layers=2,3 num_flows=4,8 embedding_dim=2,4,6 share_embeddings=True,False base=gaussian,resampled 43 | ``` 44 | 45 | Character-level Language Modeling: 46 | 47 | Install the `survae` dependency 48 | ``` 49 | pip install git+https://github.com/didriknielsen/survae_flows 50 | ``` 51 | 52 | ``` 53 | python train_charlm.py -m dequantization=voronoi dataset=text8,enwik8 model=ar embedding_dim=5,8,12 54 | ``` 55 | 56 | ## Citations 57 | If you find this repository helpful for your publications, 58 | please consider citing our paper: 59 | 60 | ``` 61 | @inproceedings{ 62 | chen2022semidiscrete, 63 | title={Semi-Discrete Normalizing Flows through Differentiable Tessellation}, 64 | author={Ricky T. Q. Chen and Brandon Amos and Maximilian Nickel}, 65 | booktitle={Advances in Neural Information Processing Systems}, 66 | year={2022}, 67 | } 68 | ``` 69 | 70 | ## License 71 | This repository is licensed under the 72 | [CC BY-NC 4.0 License](https://creativecommons.org/licenses/by-nc/4.0/). -------------------------------------------------------------------------------- /argmax_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2021 Didrik Nielsen 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | """ 24 | 25 | import torch 26 | 27 | 28 | def integer_to_base(idx_tensor, base, dims): 29 | """ 30 | Encodes index tensor to a Cartesian product representation. 31 | Args: 32 | idx_tensor (LongTensor): An index tensor, shape (...), to be encoded. 33 | base (int): The base to use for encoding. 34 | dims (int): The number of dimensions to use for encoding. 35 | Returns: 36 | LongTensor: The encoded tensor, shape (..., dims). 
37 | """ 38 | powers = base ** torch.arange(dims - 1, -1, -1, device=idx_tensor.device) 39 | floored = torch.div(idx_tensor[..., None], powers, rounding_mode="floor") 40 | remainder = floored % base 41 | 42 | base_tensor = remainder 43 | return base_tensor 44 | 45 | 46 | def base_to_integer(base_tensor, base): 47 | """ 48 | Decodes Cartesian product representation to an index tensor. 49 | Args: 50 | base_tensor (LongTensor): The encoded tensor, shape (..., dims). 51 | base (int): The base used in the encoding. 52 | Returns: 53 | LongTensor: The index tensor, shape (...). 54 | """ 55 | dims = base_tensor.shape[-1] 56 | powers = base ** torch.arange(dims - 1, -1, -1, device=base_tensor.device) 57 | powers = powers[(None,) * (base_tensor.dim() - 1)] 58 | 59 | idx_tensor = (base_tensor * powers).sum(-1) 60 | return idx_tensor 61 | -------------------------------------------------------------------------------- /assets/fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/semi-discrete-flow/005b4e87ba44d86cc5ebe6cd843b6e168ac51320/assets/fig1.png -------------------------------------------------------------------------------- /assets/voronoi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/semi-discrete-flow/005b4e87ba44d86cc5ebe6cd843b6e168ac51320/assets/voronoi.png -------------------------------------------------------------------------------- /charlm/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Didrik Nielsen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /charlm/README.md: -------------------------------------------------------------------------------- 1 | Code is modified based on original version from https://github.com/didriknielsen/argmax_flows. 
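As a quick aside on `argmax_utils.py` above: `integer_to_base` expands each integer index into `dims` digits in the given `base` (most significant digit first), and `base_to_integer` inverts the encoding exactly. A minimal round-trip sketch, included only as illustrative usage and not part of the repository:

```python
import torch
from argmax_utils import integer_to_base, base_to_integer  # argmax_utils.py lives at the repo root

idx = torch.tensor([0, 5, 12])                  # categorical indices
digits = integer_to_base(idx, base=3, dims=3)   # e.g. 5 -> [0, 1, 2], 12 -> [1, 1, 0]
recovered = base_to_integer(digits, base=3)
assert torch.equal(recovered, idx)              # exact round trip
```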
-------------------------------------------------------------------------------- /charlm/data/data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader, ConcatDataset 3 | from .dataset_text8 import Text8 4 | from .dataset_enwik8 import EnWik8 5 | 6 | dataset_choices = {"text8", "enwik8"} 7 | 8 | 9 | def get_data(args): 10 | assert args.dataset in dataset_choices 11 | 12 | # Dataset 13 | if args.dataset == "text8": 14 | data = Text8(seq_len=256) 15 | data_shape = (1, 256) 16 | num_classes = 27 17 | elif args.dataset == "enwik8": 18 | data = EnWik8(seq_len=320) 19 | data_shape = (1, 320) 20 | num_classes = 256 21 | 22 | # Data Loader 23 | dataset_train = ConcatDataset([data.train, data.valid]) 24 | train_loader = DataLoader( 25 | dataset_train, batch_size=args.batch_size, shuffle=True, num_workers=4 26 | ) 27 | valid_loader = DataLoader( 28 | data.valid, batch_size=args.test_batch_size, shuffle=False, num_workers=4 29 | ) 30 | test_loader = DataLoader( 31 | data.test, batch_size=args.test_batch_size, shuffle=False, num_workers=4 32 | ) 33 | return train_loader, valid_loader, test_loader, data_shape, num_classes 34 | -------------------------------------------------------------------------------- /charlm/data/dataset_enwik8.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import json 4 | import zipfile 5 | import urllib.request 6 | from torch.utils.data import Dataset 7 | from survae.data import TrainValidTestLoader, DATA_PATH 8 | from .vocab import Vocab 9 | 10 | 11 | class EnWik8(TrainValidTestLoader): 12 | def __init__(self, root=DATA_PATH, seq_len=256, download=True): 13 | self.train = EnWik8Dataset(root, seq_len=seq_len, split='train', download=download) 14 | self.valid = EnWik8Dataset(root, seq_len=seq_len, split='valid') 15 | self.test = EnWik8Dataset(root, seq_len=seq_len, split='test') 16 | 17 | 18 | class EnWik8Dataset(Dataset): 19 | 20 | def __init__(self, root=DATA_PATH, seq_len=256, split='train', download=False): 21 | assert split in {'train', 'valid', 'test'} 22 | self.root = os.path.join(root, 'enwik8') 23 | self.seq_len = seq_len 24 | self.split = split 25 | 26 | if not os.path.exists(self.raw_file): 27 | if download: 28 | self.download() 29 | else: 30 | raise RuntimeError('Dataset not found. You can use download=True to download it.') 31 | 32 | # Get vocabulary 33 | self.vocab = Vocab() 34 | vocab_file = os.path.join(self.root, 'vocab.json') 35 | stoi = self._create_stoi() 36 | self.vocab.fill(stoi) 37 | 38 | # Load data 39 | self.data = self._preprocess_data(split).unsqueeze(1) 40 | 41 | def __getitem__(self, index): 42 | return self.data[index], self.seq_len 43 | 44 | def __len__(self): 45 | return len(self.data) 46 | 47 | def _create_stoi(self): 48 | # Just a simple identity conversion for 8bit (byte)-valued chunks. 
49 | stoi = {i: i for i in range(256)} 50 | return stoi 51 | 52 | def _preprocess_data(self, split): 53 | # Read raw data 54 | rawdata = zipfile.ZipFile(self.raw_file).read('enwik8') 55 | 56 | n_train = int(90e6) 57 | n_valid = int(5e6) 58 | n_test = int(5e6) 59 | 60 | # Extract subset 61 | if split == 'train': 62 | rawdata = rawdata[:n_train] 63 | elif split == 'valid': 64 | rawdata = rawdata[n_train:n_train+n_valid] 65 | elif split == 'test': 66 | rawdata = rawdata[n_train+n_valid:n_train+n_valid+n_test] 67 | 68 | # Encode characters 69 | data = torch.tensor([self.vocab.stoi[s] for s in rawdata]) 70 | 71 | # Split into chunks 72 | data = data.reshape(-1, self.seq_len) 73 | 74 | return data 75 | 76 | @property 77 | def raw_file(self): 78 | return os.path.join(self.root, 'enwik8.zip') 79 | 80 | def download(self): 81 | if not os.path.exists(self.root): 82 | os.makedirs(self.root) 83 | 84 | print('Downloading enwik8...') 85 | url = 'http://mattmahoney.net/dc/enwik8.zip' 86 | print('Downloading from {}...'.format(url)) 87 | urllib.request.urlretrieve(url, self.raw_file) 88 | print('Saved to {}'.format(self.raw_file)) 89 | -------------------------------------------------------------------------------- /charlm/data/dataset_text8.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as data 3 | import os 4 | import urllib.request 5 | import zipfile 6 | import json 7 | from survae.data import TrainValidTestLoader, DATA_PATH 8 | 9 | 10 | class Text8(TrainValidTestLoader): 11 | def __init__(self, root=DATA_PATH, seq_len=256, download=True): 12 | self.train = Text8Dataset(root, seq_len=seq_len, split='train', download=download) 13 | self.valid = Text8Dataset(root, seq_len=seq_len, split='valid') 14 | self.test = Text8Dataset(root, seq_len=seq_len, split='test') 15 | 16 | 17 | class Text8Dataset(data.Dataset): 18 | """ 19 | The text8 dataset consisting of 100M characters (with vocab size 27). 20 | We here split the dataset into (90M, 5M, 5M) characters for 21 | (train, val, test) as in [1,2,3]. 22 | The sets are then split into chunks of equal length as specified by `seq_len`. 23 | The default is 256, corresponding to what was used in [1]. Other choices 24 | include 180, as [2] reports using. 25 | [1] Discrete Flows: Invertible Generative Models of Discrete Data 26 | Tran et al., 2019, https://arxiv.org/abs/1905.10347 27 | [2] Architectural Complexity Measures of Recurrent Neural Networks 28 | Zhang et al., 2016, https://arxiv.org/abs/1602.08210 29 | [3] Subword Language Modeling with Neural Networks 30 | Mikolov et al., 2013, http://www.fit.vutbr.cz/~imikolov/rnnlm/char.pdf 31 | """ 32 | 33 | def __init__(self, root=DATA_PATH, seq_len=256, split='train', download=False): 34 | assert split in {'train', 'valid', 'test'} 35 | self.root = os.path.join(os.path.expanduser(root), 'text8') 36 | self.seq_len = seq_len 37 | self.split = split 38 | 39 | if not self._check_raw_exists(): 40 | if download: 41 | self.download() 42 | else: 43 | raise RuntimeError('Dataset not found. 
You can use download=True to download it.') 44 | 45 | if not self._check_processed_exists(split): 46 | self._preprocess_data(split) 47 | 48 | # Load data 49 | self.data = torch.load(self.processed_filename(split)) 50 | 51 | # Load lookup tables 52 | char2idx_file = os.path.join(self.root, 'char2idx.json') 53 | idx2char_file = os.path.join(self.root, 'idx2char.json') 54 | with open(char2idx_file) as f: 55 | self.char2idx = json.load(f) 56 | with open(idx2char_file) as f: 57 | self.idx2char = json.load(f) 58 | 59 | def __getitem__(self, index): 60 | return self.data[index], self.seq_len 61 | 62 | def __len__(self): 63 | return len(self.data) 64 | 65 | def s2t(self, s): 66 | assert len(s) == self.seq_len, 'String not of length {}'.format(self.seq_len) 67 | return torch.tensor([self.char2idx[char] for char in s]) 68 | 69 | def t2s(self, t): 70 | return ''.join([self.idx2char[t[i]] if t[i] < len(self.idx2char) else ' ' for i in range(self.seq_len)]) 71 | 72 | def text2tensor(self, text): 73 | if isinstance(text, str): 74 | tensor = self.s2t(text).unsqueeze(0) 75 | else: 76 | tensor = torch.stack([self.s2t(s) for s in text], dim=0) 77 | return tensor.unsqueeze(1) # (B, 1, L) 78 | 79 | def tensor2text(self, tensor): 80 | assert tensor.dim() == 3, 'Tensor should have shape (batch_size, 1, {})'.format(self.seq_len) 81 | assert tensor.shape[1] == 1, 'Tensor should have shape (batch_size, 1, {})'.format(self.seq_len) 82 | assert tensor.shape[2] == self.seq_len, 'Tensor should have shape (batch_size, 1, {})'.format(self.seq_len) 83 | bsize = tensor.shape[0] 84 | text = [self.t2s(tensor[b].squeeze(0)) for b in range(bsize)] 85 | return text 86 | 87 | def _preprocess_data(self, split): 88 | # Read raw data 89 | rawdata = zipfile.ZipFile(self.local_filename).read('text8').decode('utf-8') 90 | 91 | # Extract vocab 92 | vocab = sorted(list(set(rawdata))) 93 | char2idx, idx2char = {}, [] 94 | for i, char in enumerate(vocab): 95 | char2idx[char] = i 96 | idx2char.append(char) 97 | 98 | # Extract subset 99 | if split == 'train': 100 | rawdata = rawdata[:90000000] 101 | elif split == 'valid': 102 | rawdata = rawdata[90000000:95000000] 103 | elif split == 'test': 104 | rawdata = rawdata[95000000:] 105 | 106 | # Encode characters 107 | data = torch.tensor([char2idx[char] for char in rawdata]) 108 | 109 | # Split into chunks 110 | data = data[:self.seq_len*(len(data)//self.seq_len)] 111 | data = data.reshape(-1, 1, self.seq_len) 112 | 113 | # Save processed data 114 | torch.save(data, self.processed_filename(split)) 115 | 116 | # Save lookup tables 117 | char2idx_file = os.path.join(self.root, 'char2idx.json') 118 | idx2char_file = os.path.join(self.root, 'idx2char.json') 119 | with open(char2idx_file, 'w') as f: 120 | json.dump(char2idx, f) 121 | with open(idx2char_file, 'w') as f: 122 | json.dump(idx2char, f) 123 | 124 | @property 125 | def local_filename(self): 126 | return os.path.join(self.root, 'text8.zip') 127 | 128 | def processed_filename(self, split): 129 | return os.path.join(self.root, '{}.pt'.format(split)) 130 | 131 | def download(self): 132 | if not os.path.exists(self.root): 133 | os.makedirs(self.root) 134 | 135 | print('Downloading text8...') 136 | 137 | url = 'http://mattmahoney.net/dc/text8.zip' 138 | print('Downloading from {}...'.format(url)) 139 | urllib.request.urlretrieve(url, self.local_filename) 140 | print('Saved to {}'.format(self.local_filename)) 141 | 142 | def _check_raw_exists(self): 143 | return os.path.exists(self.local_filename) 144 | 145 | def _check_processed_exists(self, 
split): 146 | return os.path.exists(self.processed_filename(split)) 147 | -------------------------------------------------------------------------------- /charlm/data/vocab.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import warnings 4 | import torch 5 | import torch.nn as nn 6 | import numpy as np 7 | 8 | 9 | class Vocab(): 10 | 11 | def __init__(self, stoi={}): 12 | self.fill(stoi) 13 | 14 | def fill(self, stoi): 15 | self.stoi = stoi 16 | self.itos = {i:s for s,i in stoi.items()} 17 | 18 | def save_json(self, path): 19 | if not os.path.exists(path): os.makedirs(path) 20 | vocab_file = os.path.join(path, 'vocab.json') 21 | with open(vocab_file, 'w') as f: 22 | json.dump(self.stoi, f, indent=4) 23 | 24 | def load_json(self, path): 25 | vocab_file = os.path.join(path, 'vocab.json') 26 | with open(vocab_file, 'r') as f: 27 | stoi = json.load(f) 28 | self.fill(stoi) 29 | 30 | def string_to_idx(self, string): 31 | assert isinstance(string, str) 32 | return [self.stoi[s] for s in string] 33 | 34 | def idx_to_string(self, idx): 35 | assert isinstance(idx, list) 36 | count_err = np.sum([1 for i in idx if i not in self.itos]) 37 | if count_err > 0: 38 | print(f'Warning, {count_err} decodings were not in vocab.') 39 | print(set([i for i in idx if i not in self.itos])) 40 | return ''.join([self.itos[i] for i in idx if i in self.itos]) 41 | 42 | def encode(self, text, padding_value=0): 43 | assert isinstance(text, list) 44 | length = torch.tensor([len(string) for string in text]) 45 | tensor_list = [torch.tensor(self.string_to_idx(string)) for string in text] 46 | tensor = nn.utils.rnn.pad_sequence(tensor_list, batch_first=True, padding_value=padding_value) 47 | return tensor, length 48 | 49 | def decode(self, tensor, length): 50 | assert torch.is_tensor(tensor) 51 | assert tensor.dim() == 2, 'Tensor should have shape (batch_size, seq_len)' 52 | text = [self.idx_to_string(tensor[b][:length[b]].tolist()) for b in range(tensor.shape[0])] 53 | return text 54 | -------------------------------------------------------------------------------- /charlm/experiment/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import torch 4 | from .utils import get_args_table, get_metric_table 5 | 6 | 7 | class BaseExperiment(object): 8 | 9 | def __init__(self, model, optimizer, scheduler_iter, scheduler_epoch, 10 | log_path, eval_every, check_every): 11 | 12 | # Objects 13 | self.model = model 14 | self.optimizer = optimizer 15 | self.scheduler_iter = scheduler_iter 16 | self.scheduler_epoch = scheduler_epoch 17 | 18 | # Paths 19 | self.log_path = log_path 20 | self.check_path = os.path.join(log_path, 'check') 21 | 22 | # Intervals 23 | self.eval_every = eval_every 24 | self.check_every = check_every 25 | 26 | # Initialize 27 | self.current_epoch = 0 28 | self.train_metrics = {} 29 | self.eval_metrics = {} 30 | self.eval_epochs = [] 31 | 32 | def train_fn(self, epoch): 33 | raise NotImplementedError() 34 | 35 | def eval_fn(self, epoch): 36 | raise NotImplementedError() 37 | 38 | def log_fn(self, epoch, train_dict, eval_dict): 39 | raise NotImplementedError() 40 | 41 | def log_train_metrics(self, train_dict): 42 | if len(self.train_metrics)==0: 43 | for metric_name, metric_value in train_dict.items(): 44 | self.train_metrics[metric_name] = [metric_value] 45 | else: 46 | for metric_name, metric_value in train_dict.items(): 47 | 
self.train_metrics[metric_name].append(metric_value) 48 | 49 | def log_eval_metrics(self, eval_dict): 50 | if len(self.eval_metrics)==0: 51 | for metric_name, metric_value in eval_dict.items(): 52 | self.eval_metrics[metric_name] = [metric_value] 53 | else: 54 | for metric_name, metric_value in eval_dict.items(): 55 | self.eval_metrics[metric_name].append(metric_value) 56 | 57 | def create_folders(self): 58 | 59 | # Create log folder 60 | os.makedirs(self.log_path) 61 | print("Storing logs in:", self.log_path) 62 | 63 | # Create check folder 64 | if self.check_every is not None: 65 | os.makedirs(self.check_path) 66 | print("Storing checkpoints in:", self.check_path) 67 | 68 | def save_args(self, args): 69 | 70 | # Save args 71 | with open(os.path.join(self.log_path, 'args.pickle'), "wb") as f: 72 | pickle.dump(args, f) 73 | 74 | # Save args table 75 | args_table = get_args_table(vars(args)) 76 | with open(os.path.join(self.log_path,'args_table.txt'), "w") as f: 77 | f.write(str(args_table)) 78 | 79 | def save_metrics(self): 80 | 81 | # Save metrics 82 | with open(os.path.join(self.log_path,'metrics_train.pickle'), 'wb') as f: 83 | pickle.dump(self.train_metrics, f) 84 | with open(os.path.join(self.log_path,'metrics_eval.pickle'), 'wb') as f: 85 | pickle.dump(self.eval_metrics, f) 86 | 87 | # Save metrics table 88 | metric_table = get_metric_table(self.train_metrics, epochs=list(range(1, self.current_epoch+2))) 89 | with open(os.path.join(self.log_path,'metrics_train.txt'), "w") as f: 90 | f.write(str(metric_table)) 91 | metric_table = get_metric_table(self.eval_metrics, epochs=[e+1 for e in self.eval_epochs]) 92 | with open(os.path.join(self.log_path,'metrics_eval.txt'), "w") as f: 93 | f.write(str(metric_table)) 94 | 95 | def checkpoint_save(self): 96 | checkpoint = {'current_epoch': self.current_epoch, 97 | 'train_metrics': self.train_metrics, 98 | 'eval_metrics': self.eval_metrics, 99 | 'eval_epochs': self.eval_epochs, 100 | 'model': self.model.state_dict(), 101 | 'optimizer': self.optimizer.state_dict(), 102 | 'scheduler_iter': self.scheduler_iter.state_dict() if self.scheduler_iter else None, 103 | 'scheduler_epoch': self.scheduler_epoch.state_dict() if self.scheduler_epoch else None} 104 | torch.save(checkpoint, os.path.join(self.check_path, 'checkpoint.pt')) 105 | 106 | def checkpoint_load(self, check_path): 107 | checkpoint = torch.load(os.path.join(check_path, 'checkpoint.pt')) 108 | self.current_epoch = checkpoint['current_epoch'] 109 | self.train_metrics = checkpoint['train_metrics'] 110 | self.eval_metrics = checkpoint['eval_metrics'] 111 | self.eval_epochs = checkpoint['eval_epochs'] 112 | self.model.load_state_dict(checkpoint['model']) 113 | self.optimizer.load_state_dict(checkpoint['optimizer']) 114 | if self.scheduler_iter: self.scheduler_iter.load_state_dict(checkpoint['scheduler_iter']) 115 | if self.scheduler_epoch: self.scheduler_epoch.load_state_dict(checkpoint['scheduler_epoch']) 116 | 117 | def run(self, epochs): 118 | 119 | for epoch in range(self.current_epoch, epochs): 120 | 121 | # Train 122 | train_dict = self.train_fn(epoch) 123 | self.log_train_metrics(train_dict) 124 | 125 | # Eval 126 | if (epoch+1) % self.eval_every == 0: 127 | eval_dict = self.eval_fn(epoch) 128 | self.log_eval_metrics(eval_dict) 129 | self.eval_epochs.append(epoch) 130 | else: 131 | eval_dict = None 132 | 133 | # Log 134 | self.save_metrics() 135 | self.log_fn(epoch, train_dict, eval_dict) 136 | 137 | # Checkpoint 138 | self.current_epoch += 1 139 | if (epoch+1) % self.check_every == 0: 
140 | self.checkpoint_save() 141 | -------------------------------------------------------------------------------- /charlm/experiment/utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from prettytable import PrettyTable 3 | 4 | 5 | def get_args_table(args_dict): 6 | table = PrettyTable(['Arg', 'Value']) 7 | for arg, val in args_dict.items(): 8 | table.add_row([arg, val]) 9 | return table 10 | 11 | 12 | def get_metric_table(metric_dict, epochs): 13 | max_len = len(epochs) 14 | for _, vs in metric_dict.items(): 15 | max_len = min(max_len, len(vs)) 16 | table = PrettyTable() 17 | table.add_column('Epoch', epochs[:max_len]) 18 | if len(metric_dict)>0: 19 | for metric_name, metric_values in metric_dict.items(): 20 | table.add_column(metric_name, metric_values[:max_len]) 21 | return table 22 | 23 | 24 | def clean_dict(d, keys): 25 | d2 = copy.deepcopy(d) 26 | for key in keys: 27 | if key in d2: 28 | del d2[key] 29 | return d2 30 | -------------------------------------------------------------------------------- /charlm/model/CategoricalNF/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/semi-discrete-flow/005b4e87ba44d86cc5ebe6cd843b6e168ac51320/charlm/model/CategoricalNF/__init__.py -------------------------------------------------------------------------------- /charlm/model/CategoricalNF/categorical_encoding/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/semi-discrete-flow/005b4e87ba44d86cc5ebe6cd843b6e168ac51320/charlm/model/CategoricalNF/categorical_encoding/__init__.py -------------------------------------------------------------------------------- /charlm/model/CategoricalNF/categorical_encoding/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import sys 6 | sys.path.append("../") 7 | 8 | from ..general.mutils import get_param_val 9 | from ..networks.help_layers import LinearNet 10 | 11 | 12 | def create_embed_layer(vocab, vocab_size, default_embed_layer_dims): 13 | ## Creating an embedding layer either from a torchtext vocabulary or from scratch 14 | use_vocab_vectors = (vocab is not None and vocab.vectors is not None) 15 | embed_layer_dims = vocab.vectors.shape[1] if use_vocab_vectors else default_embed_layer_dims 16 | vocab_size = len(vocab) if use_vocab_vectors else vocab_size 17 | embed_layer = nn.Embedding(vocab_size, embed_layer_dims) 18 | if use_vocab_vectors: 19 | embed_layer.weight.data.copy_(vocab.vectors) 20 | embed_layer.weight.requires_grad = True 21 | return embed_layer, vocab_size 22 | 23 | 24 | def create_decoder(num_categories, num_dims, config, **kwargs): 25 | num_layers = get_param_val(config, "num_layers", 1) 26 | hidden_size = get_param_val(config, "hidden_size", 64) 27 | 28 | return DecoderLinear(num_categories, 29 | embed_dim=num_dims, 30 | hidden_size=hidden_size, 31 | num_layers=num_layers, 32 | **kwargs) 33 | 34 | 35 | class DecoderLinear(nn.Module): 36 | """ 37 | A simple linear decoder with flexible number of layers. 
38 | """ 39 | 40 | def __init__(self, num_categories, embed_dim, hidden_size, num_layers, class_prior_log=None): 41 | super().__init__() 42 | self.hidden_size = hidden_size 43 | self.num_layers = num_layers 44 | 45 | self.layers = LinearNet(c_in=3*embed_dim, 46 | c_out=num_categories, 47 | hidden_size=hidden_size, 48 | num_layers=num_layers) 49 | self.log_softmax = nn.LogSoftmax(dim=-1) 50 | 51 | if class_prior_log is not None: 52 | if isinstance(class_prior_log, np.ndarray): 53 | class_prior_log = torch.from_numpy(class_prior_log) 54 | self.layers.set_bias(class_prior_log) 55 | 56 | def forward(self, z_cont): 57 | z_cont = torch.cat([z_cont, F.elu(z_cont), F.elu(-z_cont)], dim=-1) 58 | out = self.layers(z_cont) 59 | logits = self.log_softmax(out) 60 | return logits 61 | 62 | def info(self): 63 | return "Linear model with hidden size %i and %i layers" % (self.hidden_size, self.num_layers) -------------------------------------------------------------------------------- /charlm/model/CategoricalNF/categorical_encoding/mutils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import sys 5 | sys.path.append("../") 6 | 7 | from ..general.mutils import get_param_val 8 | from .variational_dequantization import VariationalDequantization 9 | from .linear_encoding import LinearCategoricalEncoding 10 | from .variational_encoding import VariationalCategoricalEncoding 11 | 12 | 13 | 14 | def add_encoding_parameters(parser, postfix=""): 15 | # General parameters 16 | parser.add_argument("--encoding_dim" + postfix, help="Dimensionality of the embeddings.", type=int, default=4) 17 | parser.add_argument("--encoding_dequantization" + postfix, help="If selected, variational dequantization is used for encoding categorical data.", action="store_true") 18 | parser.add_argument("--encoding_variational" + postfix, help="If selected, the encoder distribution is joint over categorical variables.", action="store_true") 19 | 20 | # Flow parameters 21 | parser.add_argument("--encoding_num_flows" + postfix, help="Number of flows used in the embedding layer.", type=int, default=0) 22 | parser.add_argument("--encoding_hidden_layers" + postfix, help="Number of hidden layers of flows used in the parallel embedding layer.", type=int, default=2) 23 | parser.add_argument("--encoding_hidden_size" + postfix, help="Hidden size of flows used in the parallel embedding layer.", type=int, default=128) 24 | parser.add_argument("--encoding_num_mixtures" + postfix, help="Number of mixtures used in the coupling layers (if applicable).", type=int, default=8) 25 | 26 | # Decoder parameters 27 | parser.add_argument("--encoding_use_decoder" + postfix, help="If selected, we use a decoder instead of calculating the likelihood by inverting all flows.", action="store_true") 28 | parser.add_argument("--encoding_dec_num_layers" + postfix, help="Number of hidden layers used in the decoder of the parallel embedding layer.", type=int, default=1) 29 | parser.add_argument("--encoding_dec_hidden_size" + postfix, help="Hidden size used in the decoder of the parallel embedding layer.", type=int, default=64) 30 | 31 | 32 | def encoding_args_to_params(args, postfix=""): 33 | params = { 34 | "use_dequantization": getattr(args, "encoding_dequantization" + postfix), 35 | "use_variational": getattr(args, "encoding_variational" + postfix), 36 | "use_decoder": getattr(args, "encoding_use_decoder" + postfix), 37 | "num_dimensions": getattr(args, 
"encoding_dim" + postfix), 38 | "flow_config": { 39 | "num_flows": getattr(args, "encoding_num_flows" + postfix), 40 | "hidden_layers": getattr(args, "encoding_hidden_layers" + postfix), 41 | "hidden_size": getattr(args, "encoding_hidden_size" + postfix) 42 | }, 43 | "decoder_config": { 44 | "num_layers": getattr(args, "encoding_dec_num_layers" + postfix), 45 | "hidden_size": getattr(args, "encoding_dec_hidden_size" + postfix) 46 | } 47 | } 48 | return params 49 | 50 | 51 | def create_encoding(encoding_params, dataset_class, vocab=None, vocab_size=-1, category_prior=None): 52 | assert not (vocab is None and vocab_size <= 0), "[!] ERROR: When creating the encoding, either a torchtext vocabulary or the vocabulary size needs to be passed." 53 | use_dequantization = encoding_params.pop("use_dequantization") 54 | use_variational = encoding_params.pop("use_variational") 55 | 56 | 57 | if use_dequantization and "model_func" not in encoding_params["flow_config"]: 58 | print("[#] WARNING: For using variational dequantization as encoding scheme, a model function needs to be specified" + \ 59 | " in the encoding parameters, key \"flow_config\" which was missing here. Will deactivate dequantization...") 60 | use_dequantization = False 61 | 62 | if use_dequantization: 63 | encoding_flow = VariationalDequantization 64 | elif use_variational: 65 | encoding_flow = VariationalCategoricalEncoding 66 | else: 67 | encoding_flow = LinearCategoricalEncoding 68 | 69 | print('HALLLO WAT GAAN WE GEBRUIKEN', encoding_flow) 70 | 71 | return encoding_flow(dataset_class=dataset_class, 72 | vocab=vocab, 73 | vocab_size=vocab_size, 74 | category_prior=category_prior, 75 | **encoding_params) -------------------------------------------------------------------------------- /charlm/model/CategoricalNF/categorical_encoding/variational_dequantization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import sys 4 | sys.path.append("../") 5 | 6 | from ..general.mutils import get_param_val 7 | from ..flows.flow_layer import FlowLayer 8 | from ..flows.sigmoid_layer import SigmoidFlow 9 | from ..flows.activation_normalization import ActNormFlow 10 | from ..flows.coupling_layer import CouplingLayer 11 | from ..categorical_encoding.decoder import create_embed_layer 12 | 13 | 14 | class VariationalDequantization(FlowLayer): 15 | """ 16 | Flow layer to encode discrete variables using variational dequantization. 
17 | """ 18 | 19 | 20 | def __init__(self, flow_config, 21 | vocab=None, vocab_size=-1, 22 | default_embed_layer_dims=128, 23 | **kwargs): 24 | super().__init__() 25 | self.embed_layer, self.vocab_size = create_embed_layer(vocab, vocab_size, default_embed_layer_dims) 26 | self.flow_layers = _create_flows(flow_config, self.embed_layer.weight.shape[1]) 27 | self.sigmoid_flow = SigmoidFlow(reverse=True) 28 | 29 | 30 | def forward(self, z, ldj=None, reverse=False, **kwargs): 31 | batch_size, seq_length = z.size(0), z.size(1) 32 | 33 | if ldj is None: 34 | ldj = z.new_zeros(z.size(0), dtype=torch.float32) 35 | 36 | if not reverse: 37 | # Sample from noise distribution, modeled by the normalizing flow 38 | rand_inp = torch.rand_like(z, dtype=torch.float32).unsqueeze(dim=-1) # Output range [0,1] 39 | rand_inp, ldj = self.sigmoid_flow(rand_inp, ldj=ldj, reverse=False) # Output range [-inf,inf] 40 | rand_inp, ldj = self._flow_forward(rand_inp, z, ldj, **kwargs) # Output range [-inf,inf] 41 | rand_inp, ldj = self.sigmoid_flow(rand_inp, ldj=ldj, reverse=True) # Output range [0,1] 42 | # Checking that noise is indeed in the range [0,1]. Any value outside indicates a numerical issue in the dequantization flow 43 | assert (rand_inp<0.0).sum() == 0 and (rand_inp>1.0).sum() == 0, "ERROR: Variational Dequantization output is out of bounds.\n" + \ 44 | str(torch.where(rand_inp<0.0)) + "\n" + \ 45 | str(torch.where(rand_inp>1.0)) 46 | # Adding the noise to the discrete values 47 | z_out = z.to(torch.float32).unsqueeze(dim=-1) + rand_inp 48 | assert torch.isnan(z_out).sum() == 0, "ERROR: Found NaN values in variational dequantization.\n" + \ 49 | "NaN z_out: " + str(torch.isnan(z_out).sum().item()) + "\n" + \ 50 | "NaN rand_inp: " + str(torch.isnan(rand_inp).sum().item()) + "\n" + \ 51 | "NaN ldj: " + str(torch.isnan(ldj).sum().item()) 52 | else: 53 | # Inverting the flow is done by finding the next whole integer for each continuous value 54 | z_out = torch.floor(z).clamp(min=0, max=self.vocab_size-1) 55 | z_out = z_out.long().squeeze(dim=-1) 56 | 57 | return z_out, ldj 58 | 59 | 60 | def _flow_forward(self, rand_inp, z, ldj, **kwargs): 61 | # Adding discrete values to flow transformation input by an embedding layer 62 | embed_features = self.embed_layer(z) 63 | for flow in self.flow_layers: 64 | rand_inp, ldj = flow(rand_inp, ldj, ext_input=embed_features, reverse=False, **kwargs) 65 | return rand_inp, ldj 66 | 67 | 68 | def info(self): 69 | s = "Variational Dequantization with %i flows.\n" % (len(self.flow_layers)) 70 | s += "\n".join(["-> [%i] " % (flow_index+1) + flow.info() for flow_index, flow in enumerate(self.flow_layers)]) 71 | return s 72 | 73 | 74 | def _create_flows(config, embed_dims): 75 | num_flows = get_param_val(config, "num_flows", 4) 76 | model_func = get_param_val(config, "model_func", allow_default=False) 77 | block_type = get_param_val(config, "block_type", None) 78 | 79 | def _create_block(flow_index): 80 | # For variational dequantization we apply a combination of activation normalization and coupling layers. 
81 | # Invertible convolutions are not useful here as our dimensionality is 1 anyways 82 | mask = CouplingLayer.create_chess_mask() 83 | if flow_index % 2 == 0: 84 | mask = 1 - mask 85 | return [ 86 | ActNormFlow(c_in=1, data_init=False), 87 | CouplingLayer(c_in=1, 88 | mask=mask, 89 | model_func=model_func, 90 | block_type=block_type) 91 | ] 92 | 93 | flow_layers = [] 94 | for flow_index in range(num_flows): 95 | flow_layers += _create_block(flow_index) 96 | 97 | return nn.ModuleList(flow_layers) 98 | 99 | 100 | if __name__ == '__main__': 101 | ## Example code for using variational dequantization 102 | torch.manual_seed(42) 103 | batch_size, seq_len = 3, 6 104 | vocab_size = 4 105 | hidden_size, embed_layer_dims = 128, 128 106 | 107 | class ExampleNetwork(nn.Module): 108 | def __init__(self, c_out): 109 | super().__init__() 110 | self.inp_layer = nn.Linear(1, hidden_size) 111 | self.main_net = nn.Sequential( 112 | nn.Linear(hidden_size + embed_layer_dims, hidden_size), 113 | nn.ReLU(), 114 | nn.Linear(hidden_size, c_out) 115 | ) 116 | 117 | def forward(self, x, ext_input, **kwargs): 118 | inp = self.inp_layer(x) 119 | out = self.main_net(torch.cat([inp, ext_input], dim=-1)) 120 | return out 121 | 122 | model_func = lambda c_out : ExampleNetwork(c_out) 123 | 124 | flow_config = { 125 | "num_flows": 2, 126 | "model_func": model_func, 127 | "block_type": "Linear" 128 | } 129 | vardeq_flow = VariationalDequantization(vocab_size=vocab_size, 130 | flow_config=flow_config, 131 | embed_layer_dims=embed_layer_dims) 132 | 133 | z = torch.randint(high=vocab_size, size=(batch_size, seq_len), dtype=torch.long) 134 | z_cont, _ = vardeq_flow(z, reverse=False) 135 | z_rec, _ = vardeq_flow(z_cont, reverse=True) 136 | 137 | print("-"*90) 138 | print(vardeq_flow.info()) 139 | print("-"*90) 140 | print("Z\n", z) 141 | print("Z reconstructed\n", z_rec) 142 | print("Z continuous\n", z_cont) 143 | assert (z_rec == z).all() 144 | -------------------------------------------------------------------------------- /charlm/model/CategoricalNF/flows/autoregressive_coupling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import os 5 | import sys 6 | import math 7 | # sys.path.append("../../") 8 | from survae.transforms import Bijection 9 | 10 | from ..general.mutils import get_device, create_channel_mask 11 | # from ..flows.flow_layer import FlowLayer 12 | from ..flows.mixture_cdf_layer import MixtureCDFCoupling 13 | 14 | 15 | class AutoregressiveMixtureCDFCoupling(Bijection): 16 | 17 | def __init__(self, c_in, model_func, block_type=None, num_mixtures=10): 18 | super().__init__() 19 | self.c_in = c_in 20 | self.num_mixtures = num_mixtures 21 | self.block_type = block_type 22 | self.scaling_factor = nn.Parameter(torch.zeros(self.c_in)) 23 | self.mixture_scaling_factor = nn.Parameter(torch.zeros(self.c_in, self.num_mixtures)) 24 | self.nn = model_func(c_out=c_in*(2 + 3 * self.num_mixtures)) 25 | 26 | def forward(self, z, reverse=False, **kwargs): 27 | ldj = z.new_zeros(z.size(0),) 28 | 29 | kwargs['length'] = z.new_ones(size=(z.size(0),)).int() * z.size(2) 30 | 31 | z = z.permute(0, 2, 1) 32 | 33 | if not reverse: 34 | nn_out = self.nn(x=z, **kwargs) 35 | 36 | t, log_s, log_pi, mixt_t, mixt_log_s = MixtureCDFCoupling.get_mixt_params(nn_out, mask=None, 37 | num_mixtures=self.num_mixtures, 38 | scaling_factor=self.scaling_factor, 39 | mixture_scaling_factor=self.mixture_scaling_factor) 40 | 41 | z = z.double() 42 | 
z_out, ldj_mixt = MixtureCDFCoupling.run_with_params(z, t, log_s, log_pi, mixt_t, mixt_log_s, reverse=reverse) 43 | else: 44 | with torch.no_grad(): 45 | B, L, D = z.shape 46 | param_num = 2 + self.num_mixtures * 3 47 | x = torch.zeros_like(z) 48 | for l in range(L): 49 | print(f'{l+1}/{L}', end='\r') 50 | for d in range(D): 51 | nn_out = self.nn(x=x, **kwargs) 52 | nn_out = nn_out.reshape(nn_out.shape[:-1] + (nn_out.shape[-1]//param_num, param_num))[:,l,d,:] 53 | t, log_s, log_pi, mixt_t, mixt_log_s = MixtureCDFCoupling.get_mixt_params(nn_out, mask=None, 54 | num_mixtures=self.num_mixtures, 55 | scaling_factor=self.scaling_factor, 56 | mixture_scaling_factor=self.mixture_scaling_factor) 57 | z = z.double() 58 | z_out = z[:,l,:] 59 | z_out, _ = MixtureCDFCoupling.run_with_params(z_out, t, log_s, log_pi, mixt_t, mixt_log_s, reverse=reverse) 60 | x[:,l,d] = z_out[:,d].float() 61 | return x.permute(0, 2, 1) 62 | 63 | ldj = ldj + ldj_mixt.float() 64 | z_out = z_out.float() 65 | if "channel_padding_mask" in kwargs and kwargs["channel_padding_mask"] is not None: 66 | z_out = z_out * kwargs["channel_padding_mask"] 67 | 68 | z_out = z_out.permute(0, 2, 1) 69 | 70 | return z_out, ldj 71 | 72 | def inverse(self, z): 73 | return self.forward(z, reverse=True) 74 | 75 | def info(self): 76 | s = "Autoregressive Mixture CDF Coupling Layer - Input size %i" % (self.c_in) 77 | if self.block_type is not None: 78 | s += ", block type %s" % (self.block_type) 79 | return s 80 | 81 | 82 | if __name__ == "__main__": 83 | torch.manual_seed(42) 84 | np.random.seed(42) 85 | 86 | batch_size, seq_len, c_in = 1, 3, 3 87 | hidden_size = 8 88 | _inp = torch.randn(batch_size, seq_len, c_in) 89 | lengths = torch.LongTensor([seq_len]*batch_size) 90 | channel_padding_mask = create_channel_mask(length=lengths, max_len=seq_len) 91 | time_embed = nn.Linear(2*seq_len, 2) 92 | 93 | module = AutoregressiveMixtureCDFCoupling1D(c_in=c_in, hidden_size=hidden_size, num_mixtures=4, 94 | time_embed=time_embed, autoreg_hidden=True) 95 | 96 | orig_out, _ = module(z=_inp, length=lengths, channel_padding_mask=channel_padding_mask) 97 | print("Out", orig_out) 98 | 99 | _inp[0,1,1] = 10 100 | alt_out, _ = module(z=_inp, length=lengths, channel_padding_mask=channel_padding_mask) 101 | print("Out diff", (orig_out - alt_out).abs()) 102 | -------------------------------------------------------------------------------- /charlm/model/CategoricalNF/flows/autoregressive_coupling2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import os 5 | import sys 6 | import math 7 | # sys.path.append("../../") 8 | from survae.transforms import Bijection 9 | 10 | from ..general.mutils import get_device, create_channel_mask 11 | # from ..flows.flow_layer import FlowLayer 12 | from ..flows.mixture_cdf_layer import MixtureCDFCoupling 13 | 14 | 15 | class CouplingMixtureCDFCoupling(Bijection): 16 | 17 | def __init__(self, c_in, c_out, model_func, block_type=None, num_mixtures=10): 18 | super().__init__() 19 | self.c_in = c_in 20 | self.c_out = c_out 21 | self.num_mixtures = num_mixtures 22 | self.block_type = block_type 23 | self.scaling_factor = nn.Parameter(torch.zeros(self.c_out)) 24 | self.mixture_scaling_factor = nn.Parameter(torch.zeros(self.c_out, self.num_mixtures)) 25 | self.nn = model_func(c_in=c_in, c_out=c_out*(2 + 3 * self.num_mixtures)) 26 | 27 | def forward(self, z, reverse=False, **kwargs): 28 | ldj = z.new_zeros(z.size(0),) 29 | 30 | 
kwargs['length'] = z.new_ones(size=(z.size(0),)).int() * z.size(2) 31 | 32 | z = z.permute(0, 2, 1) 33 | 34 | id, z = torch.split(z, (self.c_in, self.c_out), dim=-1) 35 | 36 | if not reverse: 37 | nn_out = self.nn(x=id, **kwargs) 38 | 39 | t, log_s, log_pi, mixt_t, mixt_log_s = MixtureCDFCoupling.get_mixt_params(nn_out, mask=None, 40 | num_mixtures=self.num_mixtures, 41 | scaling_factor=self.scaling_factor, 42 | mixture_scaling_factor=self.mixture_scaling_factor) 43 | 44 | z = z.double() 45 | z_out, ldj_mixt = MixtureCDFCoupling.run_with_params(z, t, log_s, log_pi, mixt_t, mixt_log_s, reverse=reverse) 46 | else: 47 | with torch.no_grad(): 48 | nn_out = self.nn(x=id, **kwargs) 49 | t, log_s, log_pi, mixt_t, mixt_log_s = MixtureCDFCoupling.get_mixt_params(nn_out, mask=None, 50 | num_mixtures=self.num_mixtures, 51 | scaling_factor=self.scaling_factor, 52 | mixture_scaling_factor=self.mixture_scaling_factor) 53 | z = z.double() 54 | z_out, ldj_mixt = MixtureCDFCoupling.run_with_params(z, t, log_s, log_pi, mixt_t, mixt_log_s, reverse=reverse) 55 | 56 | ldj = ldj + ldj_mixt.float() 57 | z_out = z_out.float() 58 | if "channel_padding_mask" in kwargs and kwargs["channel_padding_mask"] is not None: 59 | z_out = z_out * kwargs["channel_padding_mask"] 60 | 61 | z_out = torch.cat([id, z_out], dim=-1) 62 | z_out = z_out.permute(0, 2, 1) 63 | 64 | if reverse: 65 | return z_out 66 | 67 | return z_out, ldj 68 | 69 | def inverse(self, z): 70 | return self.forward(z, reverse=True) 71 | 72 | def info(self): 73 | s = "Autoregressive Mixture CDF Coupling Layer - Input size %i" % (self.c_in) 74 | if self.block_type is not None: 75 | s += ", block type %s" % (self.block_type) 76 | return s 77 | 78 | 79 | if __name__ == "__main__": 80 | torch.manual_seed(42) 81 | np.random.seed(42) 82 | 83 | batch_size, seq_len, c_in = 1, 3, 3 84 | hidden_size = 8 85 | _inp = torch.randn(batch_size, seq_len, c_in) 86 | lengths = torch.LongTensor([seq_len]*batch_size) 87 | channel_padding_mask = create_channel_mask(length=lengths, max_len=seq_len) 88 | time_embed = nn.Linear(2*seq_len, 2) 89 | 90 | module = AutoregressiveMixtureCDFCoupling1D(c_in=c_in, hidden_size=hidden_size, num_mixtures=4, 91 | time_embed=time_embed, autoreg_hidden=True) 92 | 93 | orig_out, _ = module(z=_inp, length=lengths, channel_padding_mask=channel_padding_mask) 94 | print("Out", orig_out) 95 | 96 | _inp[0,1,1] = 10 97 | alt_out, _ = module(z=_inp, length=lengths, channel_padding_mask=channel_padding_mask) 98 | print("Out diff", (orig_out - alt_out).abs()) 99 | -------------------------------------------------------------------------------- /charlm/model/CategoricalNF/flows/coupling_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import sys 5 | sys.path.append("../../") 6 | from ..flows.flow_layer import FlowLayer 7 | from ..networks.help_layers import run_sequential_with_mask 8 | 9 | 10 | class CouplingLayer(FlowLayer): 11 | 12 | def __init__(self, c_in, mask, 13 | model_func, 14 | block_type=None, 15 | c_out=-1, 16 | **kwargs): 17 | super().__init__() 18 | self.c_in = c_in 19 | self.c_out = c_out if c_out > 0 else 2 * c_in 20 | self.register_buffer('mask', mask) 21 | self.block_type = block_type 22 | 23 | # Scaling factor 24 | self.scaling_factor = nn.Parameter(torch.zeros(c_in)) 25 | self.nn = model_func(c_out=self.c_out) 26 | 27 | 28 | def run_network(self, x, length=None, **kwargs): 29 | if isinstance(self.nn, nn.Sequential): 30 | nn_out = 
run_sequential_with_mask(self.nn, x, 31 | length=length, 32 | **kwargs) 33 | else: 34 | nn_out = self.nn(x, length=length, 35 | **kwargs) 36 | 37 | if "channel_padding_mask" in kwargs and kwargs["channel_padding_mask"] is not None: 38 | nn_out = nn_out * kwargs["channel_padding_mask"] 39 | return nn_out 40 | 41 | 42 | def forward(self, z, ldj=None, reverse=False, channel_padding_mask=None, **kwargs): 43 | if ldj is None: 44 | ldj = z.new_zeros(z.size(0),) 45 | if channel_padding_mask is None: 46 | channel_padding_mask = torch.ones_like(z) 47 | 48 | mask = self._prepare_mask(self.mask, z) 49 | z_in = z * mask 50 | 51 | nn_out = self.run_network(x=z_in, **kwargs) 52 | 53 | nn_out = nn_out.view(nn_out.shape[:-1] + (nn_out.shape[-1]//2, 2)) 54 | s, t = nn_out[...,0], nn_out[...,1] 55 | 56 | scaling_fac = self.scaling_factor.exp().view([1, 1, s.size(-1)]) 57 | s = torch.tanh(s / scaling_fac.clamp(min=1.0)) * scaling_fac 58 | 59 | s = s * (1 - mask) 60 | t = t * (1 - mask) 61 | 62 | z, layer_ldj = CouplingLayer.run_with_params(z, s, t, reverse=reverse) 63 | ldj = ldj + layer_ldj 64 | 65 | return z, ldj # , detail_dict 66 | 67 | def _prepare_mask(self, mask, z): 68 | # Mask input so that we only use the un-masked regions as input 69 | mask = self.mask.unsqueeze(dim=0) if len(z.shape) > len(self.mask.shape) else self.mask 70 | if mask.size(1) < z.size(1) and mask.size(1) > 1: 71 | mask = mask.repeat(1, int(math.ceil(z.size(1)/mask.size(1))), 1).contiguous() 72 | if mask.size(1) > z.size(1): 73 | mask = mask[:,:z.size(1)] 74 | return mask 75 | 76 | @staticmethod 77 | def get_coup_params(nn_out, mask, scaling_factor=None): 78 | nn_out = nn_out.view(nn_out.shape[:-1] + (nn_out.shape[-1]//2, 2)) 79 | s, t = nn_out[...,0], nn_out[...,1] 80 | if scaling_factor is not None: 81 | scaling_fac = scaling_factor.exp().view([1, 1, s.size(-1)]) 82 | s = torch.tanh(s / scaling_fac.clamp(min=1.0)) * scaling_fac 83 | 84 | s = s * (1 - mask) 85 | t = t * (1 - mask) 86 | return s, t 87 | 88 | @staticmethod 89 | def run_with_params(orig_z, s, t, reverse=False): 90 | if not reverse: 91 | scale = torch.exp(s) 92 | z_out = (orig_z + t) * scale 93 | ldj = s.sum(dim=[1,2]) 94 | else: 95 | inv_scale = torch.exp(-1 * s) 96 | z_out = orig_z * inv_scale - t 97 | ldj = -s.sum(dim=[1,2]) 98 | return z_out, ldj 99 | 100 | 101 | @staticmethod 102 | def create_channel_mask(c_in, ratio=0.5, mask_floor=True): 103 | """ 104 | Ratio: number of channels that are alternated/for which we predict parameters 105 | """ 106 | if mask_floor: 107 | c_masked = int(math.floor(c_in * ratio)) 108 | else: 109 | c_masked = int(math.ceil(c_in * ratio)) 110 | c_unmasked = c_in - c_masked 111 | mask = torch.cat([torch.ones(1, c_masked), torch.zeros(1, c_unmasked)], dim=1) 112 | return mask 113 | 114 | 115 | @staticmethod 116 | def create_chess_mask(seq_len=2): 117 | assert seq_len > 1 118 | seq_unmask = int(seq_len // 2) 119 | seq_mask = seq_len - seq_unmask 120 | mask = torch.cat([torch.ones(seq_mask, 1), torch.zeros(seq_unmask, 1)], dim=1).view(-1, 1) 121 | return mask 122 | 123 | 124 | def info(self): 125 | is_channel_mask = (self.mask.size(0) == 1) 126 | info_str = "Coupling Layer - Input size %i" % (self.c_in) 127 | if self.block_type is not None: 128 | info_str += ", block type %s" % (self.block_type) 129 | info_str += ", mask ratio %.2f, %s mask" % ((1-self.mask).mean().item(), "channel" if is_channel_mask else "chess") 130 | return info_str -------------------------------------------------------------------------------- 
/charlm/model/CategoricalNF/flows/discrete_coupling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import sys 5 | sys.path.append("../../") 6 | 7 | from general.mutils import one_hot 8 | from layers.flows.coupling_layer import CouplingLayer 9 | 10 | 11 | class DiscreteCouplingLayer(CouplingLayer): 12 | 13 | def __init__(self, c_in, mask, 14 | model_func, 15 | block_type=None, 16 | temp=0.1, 17 | **kwargs): 18 | super().__init__(c_in=c_in, 19 | mask=mask, 20 | model_func=model_func, 21 | block_type=block_type, 22 | c_out=c_in) 23 | self.temp = temp 24 | del self.scaling_factor 25 | 26 | 27 | def run_network(self, x, length=None, **kwargs): 28 | if isinstance(self.nn, nn.Sequential): 29 | nn_out = run_sequential_with_mask(self.nn, x, 30 | length=length, 31 | **kwargs) 32 | else: 33 | nn_out = self.nn(x, length=length, 34 | **kwargs) 35 | 36 | if "channel_padding_mask" in kwargs and kwargs["channel_padding_mask"] is not None: 37 | nn_out = nn_out * kwargs["channel_padding_mask"] 38 | return nn_out 39 | 40 | 41 | def forward(self, z, ldj=None, reverse=False, channel_padding_mask=None, **kwargs): 42 | if ldj is None: 43 | ldj = z.new_zeros(z.size(0),) 44 | if channel_padding_mask is None: 45 | channel_padding_mask = torch.ones_like(z) 46 | 47 | mask = self._prepare_mask(self.mask, z) 48 | z_in = z * mask 49 | nn_out = self.run_network(x=z_in, **kwargs) 50 | shift_one_hot = one_hot_argmax(logits=nn_out, temperature=self.temp) 51 | default_shift = torch.zeros_like(shift_one_hot) 52 | default_shift[...,0] = 1 # Does not shift output 53 | shift_one_hot = shift_one_hot * (1 - mask) + default_shift * mask 54 | 55 | z = one_hot_add(z, shift_one_hot, reverse=reverse) 56 | return z, ldj 57 | 58 | 59 | def info(self): 60 | info_str = "Discrete Coupling Layer - Input size %i" % (self.c_in) 61 | if self.block_type is not None: 62 | info_str += ", block type %s" % (self.block_type) 63 | info_str += ", temperature %.1f" % (self.temp) 64 | return info_str 65 | 66 | 67 | def one_hot_argmax(logits, temperature=1.0): 68 | probs = F.softmax(logits, dim=-1) 69 | one_hot_argmax = one_hot(probs.argmax(dim=-1), num_classes=probs.size(-1)) 70 | one_hot_approx = (one_hot_argmax - probs).detach() + probs 71 | return one_hot_approx 72 | 73 | 74 | def one_hot_add(inp_one_hot, shift_one_hot, reverse=False): 75 | num_categ = inp_one_hot.size(-1) 76 | roll_matrix = torch.stack([torch.roll(input=shift_one_hot, shifts=i, dims=-1) for i in range(num_categ)], dim=-2) 77 | if reverse: 78 | roll_matrix = torch.transpose(roll_matrix, dim0=-2, dim1=-1) 79 | inp_one_hot = inp_one_hot.unsqueeze(dim=-2) 80 | out_one_hot = torch.matmul(inp_one_hot, roll_matrix) 81 | out_one_hot = out_one_hot.squeeze(dim=-2) 82 | return out_one_hot -------------------------------------------------------------------------------- /charlm/model/CategoricalNF/flows/flow_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class FlowLayer(nn.Module): 6 | 7 | def __init__(self): 8 | super().__init__() 9 | 10 | 11 | def forward(self, z, ldj=None, reverse=False, **kwargs): 12 | raise NotImplementedError 13 | 14 | 15 | def reverse(self, z, ldj=None, **kwargs): 16 | return self.forward(z, ldj, reverse=True, **kwargs) 17 | 18 | 19 | def need_data_init(self): 20 | # This function indicates whether a specific flow needs a data-dependent initialization 21 | 
# or not. For instance, activation normalization requires such a initialization 22 | return False 23 | 24 | 25 | def data_init_forward(self, input_data, **kwargs): 26 | # Only necessary if need_data_init is True. Contains processing of data initialization 27 | raise NotImplementedError 28 | 29 | 30 | def info(self): 31 | # Function to retrieve small summary/info string 32 | raise NotImplementedError -------------------------------------------------------------------------------- /charlm/model/CategoricalNF/flows/flow_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | class FlowModel(nn.Module): 7 | 8 | 9 | def __init__(self, layers=None, name="Flow model"): 10 | super().__init__() 11 | 12 | self.flow_layers = nn.ModuleList() 13 | self.name = name 14 | 15 | if layers is not None: 16 | self.add_layers(layers) 17 | 18 | 19 | def add_layers(self, layers): 20 | for l in layers: 21 | self.flow_layers.append(l) 22 | self.print_overview() 23 | 24 | 25 | def forward(self, z, ldj=None, reverse=False, get_ldj_per_layer=False, **kwargs): 26 | if ldj is None: 27 | ldj = z.new_zeros(z.size(0), dtype=torch.float32) 28 | 29 | ldj_per_layer = [] 30 | for layer_index, layer in (enumerate(self.flow_layers) if not reverse else reversed(list(enumerate(self.flow_layers)))): 31 | 32 | layer_res = layer(z, reverse=reverse, get_ldj_per_layer=get_ldj_per_layer, **kwargs) 33 | 34 | if len(layer_res) == 2: 35 | z, layer_ldj = layer_res 36 | detailed_layer_ldj = layer_ldj 37 | elif len(layer_res) == 3: 38 | z, layer_ldj, detailed_layer_ldj = layer_res 39 | else: 40 | print("[!] ERROR: Got more return values than expected: %i" % (len(layer_res))) 41 | 42 | assert torch.isnan(z).sum() == 0, "[!] ERROR: Found NaN latent values. Layer (%i):\n%s" % (layer_index + 1, layer.info()) 43 | 44 | ldj = ldj + layer_ldj 45 | if isinstance(detailed_layer_ldj, list): 46 | ldj_per_layer += detailed_layer_ldj 47 | else: 48 | ldj_per_layer.append(detailed_layer_ldj) 49 | 50 | if get_ldj_per_layer: 51 | return z, ldj, ldj_per_layer 52 | else: 53 | return z, ldj 54 | 55 | 56 | def reverse(self, z): 57 | return self.forward(z, reverse) 58 | 59 | 60 | def test_reversibility(self, z, **kwargs): 61 | test_failed = False 62 | for layer_index, layer in enumerate(self.flow_layers): 63 | z_layer, ldj_layer = layer(z, reverse=False, **kwargs) 64 | z_reconst, ldj_reconst = layer(z_layer, reverse=True, **kwargs) 65 | 66 | ldj_diff = (ldj_layer + ldj_reconst).abs().sum() 67 | z_diff = (z_layer - z_reconst).abs().sum() 68 | 69 | if z_diff != 0 or ldj_diff != 0: 70 | print("-"*100) 71 | print("[!] 
WARNING: Reversibility check failed for layer index %i" % layer_index) 72 | print(layer.info()) 73 | print("-"*100) 74 | test_failed = True 75 | 76 | print("+"*100) 77 | print("Reversibility test %s (tested %i layers)" % ("failed" if test_failed else "succeeded", len(self.flow_layers))) 78 | print("+"*100) 79 | 80 | 81 | def get_inner_activations(self, z, reverse=False, return_names=False, **kwargs): 82 | out_per_layer = [z.detach()] 83 | layer_names = [] 84 | for layer_index, layer in enumerate((self.flow_layers if not reverse else reversed(self.flow_layers))): 85 | 86 | z = layer(z, reverse=reverse, **kwargs)[0] 87 | out_per_layer.append(z.detach()) 88 | layer_names.append(layer.__class__.__name__) 89 | 90 | if not return_names: 91 | return out_per_layer 92 | else: 93 | return out_per_layer, return_names 94 | 95 | 96 | def initialize_data_dependent(self, batch_list): 97 | # Batch list needs to consist of tuples: (z, kwargs) 98 | with torch.no_grad(): 99 | for layer_index, layer in enumerate(self.flow_layers): 100 | print("Processing layer %i..." % (layer_index+1), end="\r") 101 | batch_list = FlowModel.run_data_init_layer(batch_list, layer) 102 | 103 | 104 | @staticmethod 105 | def run_data_init_layer(batch_list, layer): 106 | if layer.need_data_init(): 107 | stacked_kwargs = {key: [b[1][key] for b in batch_list] for key in batch_list[0][1].keys()} 108 | for key in stacked_kwargs.keys(): 109 | if isinstance(stacked_kwargs[key][0], torch.Tensor): 110 | stacked_kwargs[key] = torch.cat(stacked_kwargs[key], dim=0) 111 | else: 112 | stacked_kwargs[key] = stacked_kwargs[key][0] 113 | if not (isinstance(batch_list[0][0], tuple) or isinstance(batch_list[0][0], list)): 114 | input_data = torch.cat([z for z, _ in batch_list], dim=0) 115 | layer.data_init_forward(input_data, **stacked_kwargs) 116 | else: 117 | input_data = [torch.cat([z[i] for z, _ in batch_list], dim=0) for i in range(len(batch_list[0][0]))] 118 | layer.data_init_forward(*input_data, **stacked_kwargs) 119 | out_list = [] 120 | for z, kwargs in batch_list: 121 | if isinstance(z, tuple) or isinstance(z, list): 122 | z = layer(*z, reverse=False, **kwargs) 123 | out_list.append([e.detach() for e in z[:-1] if isinstance(e, torch.Tensor)]) 124 | if len(z) == 4 and isinstance(z[-1], dict): 125 | kwargs.update(z[-1]) 126 | out_list[-1] = out_list[-1][:-1] 127 | else: 128 | z = layer(z, reverse=False, **kwargs)[0] 129 | out_list.append(z.detach()) 130 | batch_list = [(out_list[i], batch_list[i][1]) for i in range(len(batch_list))] 131 | return batch_list 132 | 133 | 134 | def need_data_init(self): 135 | return any([flow.need_data_init() for flow in self.flow_layers]) 136 | 137 | 138 | def print_overview(self): 139 | # Retrieve layer descriptions for all flows 140 | layer_descp = list() 141 | for layer_index, layer in enumerate(self.flow_layers): 142 | layer_descp.append("(%2i) %s" % (layer_index+1, layer.info())) 143 | num_tokens = max([20] + [len(s) for s in "\n".join(layer_descp).split("\n")]) 144 | # Print out info in a nicer format 145 | print("="*num_tokens) 146 | print("%s with %i flows" % (self.name, len(self.flow_layers))) 147 | print("-"*num_tokens) 148 | print("\n".join(layer_descp)) 149 | print("="*num_tokens) -------------------------------------------------------------------------------- /charlm/model/CategoricalNF/flows/sigmoid_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | import sys 
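# (Added explanatory note, not part of the original file.) The log-det-Jacobian
# terms used by SigmoidFlow below follow from d/dz sigmoid(z) = sigmoid(z) * (1 - sigmoid(z)):
#   log sigmoid(z)       = -softplus(-z)
#   log(1 - sigmoid(z))  = -z - softplus(-z)
# so the per-element forward ldj is -z - 2 * softplus(-z), exactly as computed in forward().
# For the reverse (logit) direction, d/dz [log z - log(1 - z)] = 1 / (z * (1 - z)),
# giving -log(z) - log(1 - z); rescaling the input by (1 - alpha) to avoid the
# boundaries at 0 and 1 contributes the extra log(1 - alpha) term.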
6 | sys.path.append("../../") 7 | 8 | from layers.flows.flow_layer import FlowLayer 9 | 10 | 11 | 12 | class SigmoidFlow(FlowLayer): 13 | """ 14 | Applies a sigmoid on an output 15 | """ 16 | 17 | 18 | def __init__(self, reverse=False): 19 | super().__init__() 20 | self.sigmoid = nn.Sigmoid() 21 | self.reverse_layer = reverse 22 | 23 | 24 | def forward(self, z, ldj=None, reverse=False, sum_ldj=True, **kwargs): 25 | if ldj is None: 26 | ldj = z.new_zeros(z.size(0),) 27 | 28 | alpha = 1e-5 29 | reverse = (self.reverse_layer != reverse) # XOR over reverse parameters 30 | 31 | if not reverse: 32 | layer_ldj = -z - 2 * F.softplus(-z) 33 | z = torch.sigmoid(z) 34 | else: 35 | z = z*(1-alpha) + alpha*0.5 # Remove boundaries of 0 and 1 (which would result in minus infinity and inifinity) 36 | layer_ldj = (-torch.log(z) - torch.log(1-z) + math.log(1 - alpha)) 37 | z = torch.log(z) - torch.log(1-z) 38 | 39 | assert torch.isnan(z).sum() == 0, "[!] ERROR: z contains NaN values." 40 | assert torch.isnan(layer_ldj).sum() == 0, "[!] ERROR: ldj contains NaN values." 41 | 42 | if sum_ldj: 43 | ldj = ldj + layer_ldj.view(z.size(0), -1).sum(dim=1) 44 | else: 45 | ldj = layer_ldj 46 | 47 | return z, ldj 48 | 49 | 50 | def info(self): 51 | return "Sigmoid Flow" -------------------------------------------------------------------------------- /charlm/model/CategoricalNF/general/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/semi-discrete-flow/005b4e87ba44d86cc5ebe6cd843b6e168ac51320/charlm/model/CategoricalNF/general/README.md -------------------------------------------------------------------------------- /charlm/model/CategoricalNF/general/radam.py: -------------------------------------------------------------------------------- 1 | """ 2 | This implementation of RAdam is taken from the official repository: 3 | https://github.com/LiyuanLucasLiu/RAdam/blob/master/radam/radam.py 4 | """ 5 | 6 | 7 | import math 8 | import torch 9 | from torch.optim.optimizer import Optimizer, required 10 | 11 | class RAdam(Optimizer): 12 | 13 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): 14 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 15 | self.buffer = [[None, None, None] for ind in range(10)] 16 | super(RAdam, self).__init__(params, defaults) 17 | 18 | def __setstate__(self, state): 19 | super(RAdam, self).__setstate__(state) 20 | 21 | def step(self, closure=None): 22 | 23 | loss = None 24 | if closure is not None: 25 | loss = closure() 26 | 27 | for group in self.param_groups: 28 | 29 | for p in group['params']: 30 | if p.grad is None: 31 | continue 32 | grad = p.grad.data.float() 33 | if grad.is_sparse: 34 | raise RuntimeError('RAdam does not support sparse gradients') 35 | 36 | p_data_fp32 = p.data.float() 37 | 38 | state = self.state[p] 39 | 40 | if len(state) == 0: 41 | state['step'] = 0 42 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 43 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 44 | else: 45 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 46 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 47 | 48 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 49 | beta1, beta2 = group['betas'] 50 | 51 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 52 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 53 | 54 | state['step'] += 1 55 | buffered = self.buffer[int(state['step'] % 10)] 56 | if state['step'] == 
buffered[0]: 57 | N_sma, step_size = buffered[1], buffered[2] 58 | else: 59 | buffered[0] = state['step'] 60 | beta2_t = beta2 ** state['step'] 61 | N_sma_max = 2 / (1 - beta2) - 1 62 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 63 | buffered[1] = N_sma 64 | 65 | # more conservative since it's an approximated value 66 | if N_sma >= 5: 67 | step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 68 | else: 69 | step_size = 1.0 / (1 - beta1 ** state['step']) 70 | buffered[2] = step_size 71 | 72 | if group['weight_decay'] != 0: 73 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 74 | 75 | # more conservative since it's an approximated value 76 | if N_sma >= 5: 77 | denom = exp_avg_sq.sqrt().add_(group['eps']) 78 | p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom) 79 | else: 80 | p_data_fp32.add_(-step_size * group['lr'], exp_avg) 81 | 82 | p.data.copy_(p_data_fp32) 83 | 84 | return loss 85 | 86 | 87 | class AdamW(Optimizer): 88 | 89 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, warmup = 0): 90 | defaults = dict(lr=lr, betas=betas, eps=eps, 91 | weight_decay=weight_decay, warmup = warmup) 92 | super(AdamW, self).__init__(params, defaults) 93 | 94 | def __setstate__(self, state): 95 | super(AdamW, self).__setstate__(state) 96 | 97 | def step(self, closure=None): 98 | loss = None 99 | if closure is not None: 100 | loss = closure() 101 | 102 | for group in self.param_groups: 103 | 104 | for p in group['params']: 105 | if p.grad is None: 106 | continue 107 | grad = p.grad.data.float() 108 | if grad.is_sparse: 109 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 110 | 111 | p_data_fp32 = p.data.float() 112 | 113 | state = self.state[p] 114 | 115 | if len(state) == 0: 116 | state['step'] = 0 117 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 118 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 119 | else: 120 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 121 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 122 | 123 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 124 | beta1, beta2 = group['betas'] 125 | 126 | state['step'] += 1 127 | 128 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 129 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 130 | 131 | denom = exp_avg_sq.sqrt().add_(group['eps']) 132 | bias_correction1 = 1 - beta1 ** state['step'] 133 | bias_correction2 = 1 - beta2 ** state['step'] 134 | 135 | if group['warmup'] > state['step']: 136 | scheduled_lr = 1e-8 + state['step'] * group['lr'] / group['warmup'] 137 | else: 138 | scheduled_lr = group['lr'] 139 | 140 | step_size = scheduled_lr * math.sqrt(bias_correction2) / bias_correction1 141 | 142 | if group['weight_decay'] != 0: 143 | p_data_fp32.add_(-group['weight_decay'] * scheduled_lr, p_data_fp32) 144 | 145 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 146 | 147 | p.data.copy_(p_data_fp32) 148 | 149 | return loss -------------------------------------------------------------------------------- /charlm/model/CategoricalNF/networks/help_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import sys 6 | sys.path.append("../../") 7 | 8 | 9 | 10 | class PositionalEmbedding(nn.Module): 11 | 12 | def 
__init__(self, d_model, max_seq_len): 13 | super(PositionalEmbedding, self).__init__() 14 | 15 | pos_embed_dim = int(d_model//2) 16 | self.pos_emb = nn.Parameter(torch.randn(1, max_seq_len, pos_embed_dim), requires_grad=True) 17 | self.comb_layer = nn.Sequential( 18 | nn.Linear(d_model + pos_embed_dim, d_model), 19 | nn.ELU(), 20 | nn.LayerNorm(d_model) 21 | ) 22 | 23 | def forward(self, x): 24 | x = torch.cat([x, self.pos_emb.expand(x.size(0),-1,-1)], dim=-1) 25 | out = self.comb_layer(x) 26 | return out 27 | 28 | class IndependentLinear(nn.Module): 29 | 30 | def __init__(self, D, hidden_dim, c_out): 31 | super().__init__() 32 | self.layer1 = nn.Linear(D*hidden_dim, D*hidden_dim) 33 | self.act_fn = nn.GELU() 34 | self.layer2 = nn.Linear(D*hidden_dim, D*c_out) 35 | 36 | mask_layer1 = torch.zeros_like(self.layer1.weight.data) 37 | for d in range(D): 38 | mask_layer1[d*hidden_dim:(d+1)*hidden_dim, d*hidden_dim:(d+1)*hidden_dim] = 1.0 39 | self.register_buffer("mask_layer1", mask_layer1) 40 | 41 | mask_layer2 = torch.zeros_like(self.layer2.weight.data) 42 | for d in range(D): 43 | mask_layer2[d,d*hidden_dim:(d+1)*hidden_dim] = 1.0 44 | self.register_buffer("mask_layer2", mask_layer2) 45 | 46 | def forward(self, x): 47 | self.layer1.weight.data.mul_(self.mask_layer1) 48 | self.layer2.weight.data.mul_(self.mask_layer2) 49 | 50 | x = self.layer1(x) 51 | x = self.act_fn(x) 52 | x = self.layer2(x) 53 | 54 | return x 55 | 56 | 57 | class SimpleLinearLayer(nn.Module): 58 | 59 | def __init__(self, c_in, c_out, data_init=False): 60 | super().__init__() 61 | self.layer = nn.Linear(c_in, c_out) 62 | if data_init: 63 | scale_dims = int(c_out//2) 64 | self.layer.weight.data[scale_dims:,:] = 0 65 | self.layer.weight.data = self.layer.weight.data * 4 / np.sqrt(c_out/2) 66 | self.layer.bias.data.zero_() 67 | 68 | def forward(self, x, **kwargs): 69 | return self.layer(x) 70 | 71 | def initialize_zeros(self): 72 | self.layer.weight.data.zero_() 73 | self.layer.bias.data.zero_() 74 | 75 | 76 | class LinearNet(nn.Module): 77 | 78 | def __init__(self, c_in, c_out, num_layers, hidden_size, ext_input_dims=0, zero_init=False): 79 | super().__init__() 80 | self.inp_layer = nn.Sequential( 81 | nn.Linear(c_in, hidden_size), 82 | nn.GELU() 83 | ) 84 | self.main_net = [] 85 | for i in range(num_layers): 86 | self.main_net += [ 87 | nn.Linear(hidden_size if i>0 else hidden_size + ext_input_dims, 88 | hidden_size), 89 | nn.GELU() 90 | ] 91 | self.main_net += [ 92 | nn.Linear(hidden_size, c_out) 93 | ] 94 | self.main_net = nn.Sequential(*self.main_net) 95 | if zero_init: 96 | self.main_net[-1].weight.data.zero_() 97 | self.main_net[-1].bias.data.zero_() 98 | 99 | def forward(self, x, ext_input=None, **kwargs): 100 | x_feat = self.inp_layer(x) 101 | if ext_input is not None: 102 | x_feat = torch.cat([x_feat, ext_input], dim=-1) 103 | out = self.main_net(x_feat) 104 | return out 105 | 106 | def set_bias(self, bias): 107 | self.main_net[-1].bias.data = bias 108 | 109 | 110 | 111 | def run_sequential_with_mask(net, x, length=None, channel_padding_mask=None, src_key_padding_mask=None, length_one_hot=None, time_embed=None, gt=None, importance_weight=1, detail_out=False, **kwargs): 112 | dict_detail_out = dict() 113 | if channel_padding_mask is None: 114 | nn_out = net(x) 115 | else: 116 | x = x * channel_padding_mask 117 | for l in net: 118 | x = l(x) 119 | nn_out = x * channel_padding_mask # Making sure to zero out the outputs for all padding symbols 120 | 121 | if not detail_out: 122 | return nn_out 123 | else: 124 | return 
nn_out, dict_detail_out 125 | 126 | 127 | def run_padded_LSTM(x, lstm_cell, length, input_memory=None, return_final_states=False): 128 | if length is not None and (length != x.size(1)).sum() > 0: 129 | # Sort input elements for efficient LSTM application 130 | sorted_lengths, perm_index = length.sort(0, descending=True) 131 | x = x[perm_index] 132 | 133 | packed_input = torch.nn.utils.rnn.pack_padded_sequence(x, sorted_lengths, batch_first=True) 134 | packed_outputs, _ = lstm_cell(packed_input, input_memory) 135 | outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(packed_outputs, batch_first=True) 136 | 137 | # Redo sort 138 | _, unsort_indices = perm_index.sort(0, descending=False) 139 | outputs = outputs[unsort_indices] 140 | else: 141 | outputs, _ = lstm_cell(x, input_memory) 142 | return outputs 143 | 144 | 145 | if __name__ == "__main__": 146 | pass -------------------------------------------------------------------------------- /charlm/model/ar/encoder_context.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from survae.nn.layers.autoregressive import AutoregressiveShift 4 | 5 | 6 | class IdxContextNet(nn.Sequential): 7 | def __init__(self, num_classes, context_size, num_layers, hidden_size, dropout, maxlen=400): 8 | super(IdxContextNet, self).__init__( 9 | SqueezeLayer(), # (B,1,L) -> (B,L) 10 | nn.Embedding(num_classes, hidden_size), # (B,L,H) 11 | PermuteLayer((1, 0, 2)), # (B,L,H) -> (L,B,H) 12 | LayerLSTM( 13 | hidden_size, 14 | hidden_size, 15 | num_layers=num_layers, 16 | dropout=dropout, 17 | bidirectional=True, 18 | ), # (L,B,H) -> (L,B,2*H) 19 | # Transformer(hidden_size, num_layers=num_layers, dropout=dropout, maxlen=maxlen), 20 | nn.Linear(hidden_size * 2, context_size), # (L,B,2*H) -> (L,B,P) 21 | # AutoregressiveShift(context_size), 22 | PermuteLayer((1, 2, 0)), # (L,B,P) -> (B,P,L) 23 | ) 24 | 25 | 26 | class Transformer(nn.Module): 27 | def __init__(self, hidden_size, num_layers, dropout, maxlen): 28 | super().__init__() 29 | self.embed_positions = nn.Embedding(maxlen, hidden_size) 30 | self.layers = nn.ModuleList( 31 | [ 32 | nn.TransformerEncoderLayer(hidden_size, nhead=8, dropout=dropout, norm_first=True) 33 | for _ in range(num_layers) 34 | ] 35 | ) 36 | 37 | def forward(self, x): 38 | seqlen = x.shape[0] 39 | positions = self.embed_positions(torch.arange(seqlen).to(x.device)).reshape(seqlen, 1, -1) 40 | x = x + positions 41 | for mod in self.layers: 42 | x = mod(x) 43 | return x 44 | 45 | 46 | class LayerLSTM(nn.LSTM): 47 | def forward(self, x): 48 | output, _ = super(LayerLSTM, self).forward(x) # output, (c_n, h_n) 49 | return output 50 | 51 | 52 | class SqueezeLayer(nn.Module): 53 | def forward(self, x): 54 | return x.squeeze(1) 55 | 56 | 57 | class PermuteLayer(nn.Module): 58 | def __init__(self, order): 59 | super().__init__() 60 | self.order = order 61 | 62 | def forward(self, x): 63 | return x.permute(*self.order) 64 | -------------------------------------------------------------------------------- /charlm/model/ar/encoder_transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from survae.utils import sum_except_batch 4 | from survae.transforms.bijections.functional import splines 5 | from survae.nn.layers.autoregressive import AutoregressiveShift 6 | from ..transforms.autoregressive.conditional import ConditionalAutoregressiveBijection 7 | from ..transforms.autoregressive.utils import 
InvertSequentialCL 8 | 9 | 10 | class ConditionalSplineAutoregressive1d(ConditionalAutoregressiveBijection): 11 | def __init__( 12 | self, c, num_layers, hidden_size, dropout, num_bins, context_size, unconstrained 13 | ): 14 | self.unconstrained = unconstrained 15 | self.num_bins = num_bins 16 | scheme = InvertSequentialCL(order="cl") 17 | lstm = ConditionalAutoregressiveLSTM( 18 | C=c, 19 | P=self._num_params(), 20 | num_layers=num_layers, 21 | hidden_size=hidden_size, 22 | dropout=dropout, 23 | context_size=context_size, 24 | ) 25 | super(ConditionalSplineAutoregressive1d, self).__init__( 26 | ar_net=lstm, scheme=scheme 27 | ) 28 | self.register_buffer("constant", torch.log(torch.exp(torch.ones(1)) - 1)) 29 | self.autoregressive_net = self.ar_net # For backwards compatability 30 | 31 | def _num_params(self): 32 | # return 3 * self.num_bins + 1 33 | return 3 * self.num_bins - 1 34 | 35 | def _forward(self, x, params): 36 | unnormalized_widths = params[..., : self.num_bins] 37 | unnormalized_heights = params[..., self.num_bins : 2 * self.num_bins] 38 | unnormalized_derivatives = params[..., 2 * self.num_bins :] + self.constant 39 | if self.unconstrained: 40 | z, ldj_elementwise = splines.unconstrained_rational_quadratic_spline( 41 | x, 42 | unnormalized_widths=unnormalized_widths, 43 | unnormalized_heights=unnormalized_heights, 44 | unnormalized_derivatives=unnormalized_derivatives, 45 | inverse=False, 46 | ) 47 | else: 48 | z, ldj_elementwise = splines.rational_quadratic_spline( 49 | x, 50 | unnormalized_widths=unnormalized_widths, 51 | unnormalized_heights=unnormalized_heights, 52 | unnormalized_derivatives=unnormalized_derivatives, 53 | inverse=False, 54 | ) 55 | 56 | ldj = sum_except_batch(ldj_elementwise) 57 | return z, ldj 58 | 59 | def _element_inverse(self, z, element_params): 60 | unnormalized_widths = element_params[..., : self.num_bins] 61 | unnormalized_heights = element_params[..., self.num_bins : 2 * self.num_bins] 62 | unnormalized_derivatives = ( 63 | element_params[..., 2 * self.num_bins :] + self.constant 64 | ) 65 | if self.unconstrained: 66 | x, _ = splines.unconstrained_rational_quadratic_spline( 67 | z, 68 | unnormalized_widths=unnormalized_widths, 69 | unnormalized_heights=unnormalized_heights, 70 | unnormalized_derivatives=unnormalized_derivatives, 71 | inverse=True, 72 | ) 73 | else: 74 | x, _ = splines.rational_quadratic_spline( 75 | z, 76 | unnormalized_widths=unnormalized_widths, 77 | unnormalized_heights=unnormalized_heights, 78 | unnormalized_derivatives=unnormalized_derivatives, 79 | inverse=True, 80 | ) 81 | return x 82 | 83 | 84 | class ConditionalAutoregressiveLSTM(nn.Module): 85 | def __init__(self, C, P, num_layers, hidden_size, dropout, context_size): 86 | super(ConditionalAutoregressiveLSTM, self).__init__() 87 | 88 | self.l_in = PermuteLayer((2, 0, 1)) # (B,C,L) -> (L,B,C) 89 | self.lstm = ConditionalLayerLSTM( 90 | C + context_size, hidden_size, num_layers=num_layers, dropout=dropout 91 | ) # (L,B,C) -> (L,B,H) 92 | self.l_out = nn.Sequential( 93 | nn.Linear(hidden_size, P * C), # (L,B,H) -> (L,B,P*C) 94 | AutoregressiveShift(P * C), 95 | ReshapeLayer(C, P), # (L,B,P*C) -> (L,B,C,P) 96 | PermuteLayer((1, 2, 0, 3)), # (L,B,C,P) -> (B,C,L,P) 97 | ) 98 | 99 | def forward(self, x, context): 100 | x = self.l_in(x) 101 | context = self.l_in(context) 102 | 103 | x = self.lstm(x, context=context) 104 | return self.l_out(x) 105 | 106 | 107 | class ConditionalLayerLSTM(nn.LSTM): 108 | def forward(self, x, context): 109 | output, _ = super(ConditionalLayerLSTM, 
self).forward( 110 | torch.cat([x, context], dim=-1) 111 | ) # output, (c_n, h_n) 112 | return output 113 | 114 | 115 | class ReshapeLayer(nn.Module): 116 | def __init__(self, C, P): 117 | super().__init__() 118 | self.C = C 119 | self.P = P 120 | 121 | def forward(self, x): 122 | return x.reshape(*x.shape[0:2], self.C, self.P) 123 | 124 | 125 | class PermuteLayer(nn.Module): 126 | def __init__(self, order): 127 | super().__init__() 128 | self.order = order 129 | 130 | def forward(self, x): 131 | return x.permute(*self.order) 132 | -------------------------------------------------------------------------------- /charlm/model/ar/flow.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from survae.flows import Flow, ConditionalInverseFlow 4 | from survae.distributions import StandardNormal, ConditionalNormal 5 | from survae.transforms import ActNormBijection1d, PermuteAxes, Reshape, Shuffle, Conv1x1, Reverse 6 | from ..transforms import BinaryProductArgmaxSurjection 7 | from ..distributions import StandardGumbel, ConvNormal1d, BinaryEncoder 8 | from .encoder_context import IdxContextNet 9 | from .encoder_transforms import ConditionalSplineAutoregressive1d 10 | 11 | from ..CategoricalNF.flows.autoregressive_coupling import \ 12 | AutoregressiveMixtureCDFCoupling 13 | from ..CategoricalNF.networks.autoregressive_layers import \ 14 | AutoregressiveLSTMModel 15 | 16 | def ar_func(c_in, c_out, hidden, num_layers, max_seq_len, input_dp_rate): 17 | return AutoregressiveLSTMModel( 18 | c_in=c_in, 19 | c_out=c_out, 20 | max_seq_len=max_seq_len, 21 | num_layers=num_layers, 22 | hidden_size=hidden, 23 | dp_rate=0, 24 | input_dp_rate=input_dp_rate) 25 | 26 | 27 | class ArgmaxARFlow(Flow): 28 | 29 | def __init__(self, data_shape, num_classes, 30 | num_steps, actnorm, perm_channel, perm_length, base_dist, 31 | encoder_steps, encoder_bins, context_size, 32 | lstm_layers, lstm_size, lstm_dropout, 33 | context_lstm_layers, context_lstm_size, input_dp_rate): 34 | 35 | transforms = [] 36 | C, L = data_shape 37 | K = BinaryProductArgmaxSurjection.classes2dims(num_classes) 38 | 39 | # Encoder context 40 | context_net = IdxContextNet(num_classes=num_classes, 41 | context_size=context_size, 42 | num_layers=context_lstm_layers, 43 | hidden_size=context_lstm_size, 44 | dropout=lstm_dropout) 45 | 46 | # Encoder base 47 | encoder_shape = (C*K, L) 48 | encoder_base = ConditionalNormal(nn.Conv1d(context_size, 2*C*K, kernel_size=1, padding=0), split_dim=1) 49 | 50 | # Encoder transforms 51 | encoder_transforms = [] 52 | for step in range(encoder_steps): 53 | if step > 0: 54 | if actnorm: encoder_transforms.append(ActNormBijection1d(encoder_shape[0])) 55 | if perm_length == 'reverse': encoder_transforms.append(Reverse(encoder_shape[1], dim=2)) 56 | if perm_channel == 'conv': encoder_transforms.append(Conv1x1(encoder_shape[0], slogdet_cpu=False)) 57 | elif perm_channel == 'shuffle': encoder_transforms.append(Shuffle(encoder_shape[0])) 58 | 59 | encoder_transforms.append(ConditionalSplineAutoregressive1d(c=encoder_shape[0], 60 | num_layers=lstm_layers, 61 | hidden_size=lstm_size, 62 | dropout=lstm_dropout, 63 | num_bins=encoder_bins, 64 | context_size=context_size, 65 | unconstrained=True)) 66 | encoder_transforms.append(Reshape((C*K,L), (C,K,L))) # (B,C*K,L) -> (B,C,K,L) 67 | encoder_transforms.append(PermuteAxes([0,1,3,2])) # (B,C,K,L) -> (B,C,L,K) 68 | 69 | # Encoder 70 | encoder = BinaryEncoder(ConditionalInverseFlow(base_dist=encoder_base, 71 | 
transforms=encoder_transforms, 72 | context_init=context_net), dims=K) 73 | transforms.append(BinaryProductArgmaxSurjection(encoder, num_classes)) 74 | 75 | # Reshape 76 | transforms.append(PermuteAxes([0,1,3,2])) # (B,C,L,K) -> (B,C,K,L) 77 | transforms.append(Reshape((C,K,L), (C*K,L))) # (B,C,K,L) -> (B,C*K,L) 78 | current_shape = (C*K,L) 79 | 80 | # Coupling blocks 81 | for step in range(num_steps): 82 | if step > 0: 83 | if actnorm: transforms.append(ActNormBijection1d(current_shape[0])) 84 | if perm_length == 'reverse': transforms.append(Reverse(current_shape[1], dim=2)) 85 | if perm_channel == 'conv': transforms.append(Conv1x1(current_shape[0], slogdet_cpu=False)) 86 | elif perm_channel == 'shuffle': transforms.append(Shuffle(current_shape[0])) 87 | 88 | def model_func(c_out): 89 | return ar_func( 90 | c_in=current_shape[0], 91 | c_out=c_out, 92 | hidden=lstm_size, 93 | num_layers=lstm_layers, 94 | max_seq_len=L, 95 | input_dp_rate=input_dp_rate) 96 | 97 | transforms.append( 98 | AutoregressiveMixtureCDFCoupling( 99 | c_in=current_shape[0], 100 | model_func=model_func, 101 | block_type="LSTM model", 102 | num_mixtures=32) 103 | ) 104 | 105 | if base_dist == 'conv_gauss': base_dist = ConvNormal1d(current_shape) 106 | elif base_dist == 'gauss': base_dist = StandardNormal(current_shape) 107 | elif base_dist == 'gumbel': base_dist = StandardGumbel(current_shape) 108 | super(ArgmaxARFlow, self).__init__(base_dist=base_dist, 109 | transforms=transforms) 110 | -------------------------------------------------------------------------------- /charlm/model/ar/vorflow.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from survae.flows import Flow, ConditionalInverseFlow 4 | from survae.distributions import StandardNormal, ConditionalNormal 5 | from survae.transforms import ( 6 | ActNormBijection1d, 7 | PermuteAxes, 8 | Reshape, 9 | Shuffle, 10 | Conv1x1, 11 | Reverse, 12 | ) 13 | 14 | from ..distributions import StandardGumbel, ConvNormal1d 15 | from .encoder_context import IdxContextNet 16 | from .encoder_transforms import ConditionalSplineAutoregressive1d 17 | from ..transforms.voronoi_surjection import VoronoiSurjection 18 | 19 | from ..CategoricalNF.flows.autoregressive_coupling import ( 20 | AutoregressiveMixtureCDFCoupling, 21 | ) 22 | from ..CategoricalNF.networks.autoregressive_layers import AutoregressiveLSTMModel 23 | 24 | from voronoi import VoronoiTransform 25 | 26 | 27 | def ar_func(c_in, c_out, hidden, num_layers, max_seq_len, input_dp_rate): 28 | return AutoregressiveLSTMModel( 29 | c_in=c_in, 30 | c_out=c_out, 31 | max_seq_len=max_seq_len, 32 | num_layers=num_layers, 33 | hidden_size=hidden, 34 | dp_rate=0, 35 | input_dp_rate=input_dp_rate, 36 | ) 37 | 38 | 39 | class VoronoiARFlow(Flow): 40 | def __init__( 41 | self, 42 | data_shape, 43 | num_classes, 44 | embedding_dim, 45 | num_steps, 46 | actnorm, 47 | perm_channel, 48 | perm_length, 49 | base_dist, 50 | encoder_steps, 51 | encoder_bins, 52 | context_size, 53 | lstm_layers, 54 | lstm_size, 55 | lstm_dropout, 56 | context_lstm_layers, 57 | context_lstm_size, 58 | input_dp_rate, 59 | ): 60 | 61 | transforms = [] 62 | C, L = data_shape 63 | K = embedding_dim 64 | 65 | # Encoder context 66 | context_net = IdxContextNet( 67 | num_classes=num_classes, 68 | context_size=context_size, 69 | num_layers=context_lstm_layers, 70 | hidden_size=context_lstm_size, 71 | dropout=lstm_dropout, 72 | maxlen=L, 73 | ) 74 | 75 | # Encoder base 76 | encoder_shape = (C * K, L) 77 | 
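# (Added explanatory comments; shapes inferred from the constructor arguments.)
# C = data_shape[0] discrete variables per position, L = sequence length, and
# K = embedding_dim of the Voronoi dequantization, so the encoder flow operates
# on C*K real-valued channels per position. The conditional base below outputs
# 2*C*K channels per position, presumably a mean and a log-scale for each of
# the C*K latent channels.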
encoder_base = ConditionalNormal( 78 | nn.Conv1d(context_size, 2 * C * K, kernel_size=1, padding=0), split_dim=1 79 | ) 80 | 81 | # Encoder transforms 82 | encoder_transforms = [] 83 | for step in range(encoder_steps): 84 | if step > 0: 85 | if actnorm: 86 | encoder_transforms.append(ActNormBijection1d(encoder_shape[0])) 87 | if perm_length == "reverse": 88 | encoder_transforms.append(Reverse(encoder_shape[1], dim=2)) 89 | if perm_channel == "conv": 90 | encoder_transforms.append( 91 | Conv1x1(encoder_shape[0], slogdet_cpu=False) 92 | ) 93 | elif perm_channel == "shuffle": 94 | encoder_transforms.append(Shuffle(encoder_shape[0])) 95 | 96 | encoder_transforms.append( 97 | ConditionalSplineAutoregressive1d( 98 | c=encoder_shape[0], 99 | num_layers=lstm_layers, 100 | hidden_size=lstm_size, 101 | dropout=lstm_dropout, 102 | num_bins=encoder_bins, 103 | context_size=context_size, 104 | unconstrained=True, 105 | ) 106 | ) 107 | encoder_transforms.append( 108 | Reshape((C * K, L), (C, K, L)) 109 | ) # (B,C*K,L) -> (B,C,K,L) 110 | encoder_transforms.append(PermuteAxes([0, 1, 3, 2])) # (B,C,K,L) -> (B,C,L,K) 111 | 112 | # Encoder 113 | voronoi_transform = VoronoiTransform( 114 | num_discrete_variables=1, 115 | num_classes=num_classes, 116 | embedding_dim=embedding_dim, 117 | ) 118 | transforms.append( 119 | VoronoiSurjection( 120 | noise_dist=ConditionalInverseFlow( 121 | base_dist=encoder_base, 122 | transforms=encoder_transforms, 123 | context_init=context_net, 124 | ), 125 | voronoi_transform=voronoi_transform, 126 | num_classes=num_classes, 127 | embedding_dim=K, 128 | ) 129 | ) 130 | 131 | # Reshape 132 | transforms.append(PermuteAxes([0, 1, 3, 2])) # (B,C,L,K) -> (B,C,K,L) 133 | transforms.append(Reshape((C, K, L), (C * K, L))) # (B,C,K,L) -> (B,C*K,L) 134 | current_shape = (C * K, L) 135 | 136 | # Coupling blocks 137 | for step in range(num_steps): 138 | if step > 0: 139 | if actnorm: 140 | transforms.append(ActNormBijection1d(current_shape[0])) 141 | if perm_length == "reverse": 142 | transforms.append(Reverse(current_shape[1], dim=2)) 143 | if perm_channel == "conv": 144 | transforms.append(Conv1x1(current_shape[0], slogdet_cpu=False)) 145 | elif perm_channel == "shuffle": 146 | transforms.append(Shuffle(current_shape[0])) 147 | 148 | def model_func(c_out): 149 | return ar_func( 150 | c_in=current_shape[0], 151 | c_out=c_out, 152 | hidden=lstm_size, 153 | num_layers=lstm_layers, 154 | max_seq_len=L, 155 | input_dp_rate=input_dp_rate, 156 | ) 157 | 158 | transforms.append( 159 | AutoregressiveMixtureCDFCoupling( 160 | c_in=current_shape[0], 161 | model_func=model_func, 162 | block_type="LSTM model", 163 | num_mixtures=27, 164 | ) 165 | ) 166 | 167 | if base_dist == "conv_gauss": 168 | base_dist = ConvNormal1d(current_shape) 169 | elif base_dist == "gauss": 170 | base_dist = StandardNormal(current_shape) 171 | elif base_dist == "gumbel": 172 | base_dist = StandardGumbel(current_shape) 173 | super().__init__(base_dist=base_dist, transforms=transforms) 174 | self.voronoi_transform = voronoi_transform 175 | -------------------------------------------------------------------------------- /charlm/model/coupling/cond_ar_affine.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from survae.utils import sum_except_batch 3 | from ..transforms.autoregressive.conditional import ConditionalAffineAutoregressiveBijection 4 | from ..transforms.autoregressive.utils import InvertSequentialCL 5 | 6 | 7 | class 
CondAffineAR(ConditionalAffineAutoregressiveBijection): 8 | 9 | def __init__(self, ar_net): 10 | scheme = InvertSequentialCL(order='cl') 11 | super(CondAffineAR, self).__init__(ar_net, scheme=scheme) 12 | 13 | def _forward(self, x, params): 14 | unconstrained_scale, shift = self._split_params(params) 15 | log_scale = 2. * torch.tanh(unconstrained_scale / 2.) 16 | z = shift + torch.exp(log_scale) * x 17 | ldj = sum_except_batch(log_scale) 18 | return z, ldj 19 | 20 | def _element_inverse(self, z, element_params): 21 | unconstrained_scale, shift = self._split_params(element_params) 22 | log_scale = 2. * torch.tanh(unconstrained_scale / 2.) 23 | x = (z - shift) * torch.exp(-log_scale) 24 | return x 25 | -------------------------------------------------------------------------------- /charlm/model/coupling/cond_ar_spline.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from survae.utils import sum_except_batch 4 | from survae.transforms.bijections.functional import splines 5 | from ..transforms.autoregressive.conditional import ConditionalAutoregressiveBijection 6 | from ..transforms.autoregressive.utils import InvertSequentialCL 7 | 8 | 9 | class CondSplineAR(ConditionalAutoregressiveBijection): 10 | 11 | def __init__(self, ar_net, num_bins, unconstrained): 12 | self.unconstrained = unconstrained 13 | self.num_bins = num_bins 14 | scheme = InvertSequentialCL(order='cl') 15 | super(CondSplineAR, self).__init__(ar_net=ar_net, scheme=scheme) 16 | self.register_buffer('constant', torch.log(torch.exp(torch.ones(1)) - 1)) 17 | 18 | def _num_params(self): 19 | return 3 * self.num_bins + 1 20 | 21 | def _forward(self, x, params): 22 | unnormalized_widths = params[..., :self.num_bins] 23 | unnormalized_heights = params[..., self.num_bins:2*self.num_bins] 24 | unnormalized_derivatives = params[..., 2*self.num_bins:] + self.constant 25 | if self.unconstrained: 26 | z, ldj_elementwise = splines.unconstrained_rational_quadratic_spline( 27 | x, 28 | unnormalized_widths=unnormalized_widths, 29 | unnormalized_heights=unnormalized_heights, 30 | unnormalized_derivatives=unnormalized_derivatives, 31 | inverse=False) 32 | else: 33 | z, ldj_elementwise = splines.rational_quadratic_spline( 34 | x, 35 | unnormalized_widths=unnormalized_widths, 36 | unnormalized_heights=unnormalized_heights, 37 | unnormalized_derivatives=unnormalized_derivatives, 38 | inverse=False) 39 | 40 | ldj = sum_except_batch(ldj_elementwise) 41 | return z, ldj 42 | 43 | def _element_inverse(self, z, element_params): 44 | unnormalized_widths = element_params[..., :self.num_bins] 45 | unnormalized_heights = element_params[..., self.num_bins:2*self.num_bins] 46 | unnormalized_derivatives = element_params[..., 2*self.num_bins:] + self.constant 47 | if self.unconstrained: 48 | x, _ = splines.unconstrained_rational_quadratic_spline( 49 | z, 50 | unnormalized_widths=unnormalized_widths, 51 | unnormalized_heights=unnormalized_heights, 52 | unnormalized_derivatives=unnormalized_derivatives, 53 | inverse=True) 54 | else: 55 | x, _ = splines.rational_quadratic_spline( 56 | z, 57 | unnormalized_widths=unnormalized_widths, 58 | unnormalized_heights=unnormalized_heights, 59 | unnormalized_derivatives=unnormalized_derivatives, 60 | inverse=True) 61 | return x 62 | -------------------------------------------------------------------------------- /charlm/model/coupling/masked_linear.py: -------------------------------------------------------------------------------- 1 | 2 | 
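# (Added illustrative note, not part of the original file.) MaskedLinear below
# zeroes weight entries so that an output unit assigned to data dimension
# a = index % data_dim only receives input from units assigned to dimensions
# b < a (causal=True) or b <= a (causal=False). For in_dim = out_dim = data_dim = 3
# with causal=True, create_mask yields the strictly lower-triangular pattern
#   [[0, 0, 0],
#    [1, 0, 0],
#    [1, 1, 0]]
# i.e. the MADE-style constraint that keeps the layer autoregressive.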
import math 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class MaskedLinear(nn.Linear): 8 | 9 | def __init__(self, in_dim, out_dim, data_dim, causal, bias=True): 10 | super(MaskedLinear, self).__init__(in_features=in_dim, out_features=out_dim, bias=bias) 11 | self.register_buffer('mask', self.create_mask(in_dim, out_dim, data_dim, causal)) 12 | self.data_dim = data_dim 13 | self.causal = causal 14 | 15 | @staticmethod 16 | def create_mask(in_dim, out_dim, data_dim, causal): 17 | base_mask = torch.ones([data_dim,data_dim]) 18 | if causal: base_mask = base_mask.tril(-1) 19 | else: base_mask = base_mask.tril(0) 20 | rep_out, rep_in = math.ceil(out_dim / data_dim), math.ceil(in_dim / data_dim) 21 | return base_mask.repeat(rep_out, rep_in)[0:out_dim, 0:in_dim] 22 | 23 | def forward(self, x): 24 | self.weight.data *= self.mask 25 | return super(MaskedLinear, self).forward(x) 26 | -------------------------------------------------------------------------------- /charlm/model/distributions/__init__.py: -------------------------------------------------------------------------------- 1 | from .gumbel import StandardGumbel 2 | from .conv_normal1d import ConvNormal1d 3 | from .binary_encoder import BinaryEncoder 4 | -------------------------------------------------------------------------------- /charlm/model/distributions/binary_encoder.py: -------------------------------------------------------------------------------- 1 | from survae.distributions import ConditionalDistribution 2 | from survae.transforms import Softplus 3 | from ..transforms.utils import integer_to_base 4 | 5 | 6 | class BinaryEncoder(ConditionalDistribution): 7 | '''An encoder for BinaryProductArgmaxSurjection.''' 8 | 9 | def __init__(self, noise_dist, dims): 10 | super(BinaryEncoder, self).__init__() 11 | self.noise_dist = noise_dist 12 | self.dims = dims 13 | self.softplus = Softplus() 14 | 15 | def sample_with_log_prob(self, context): 16 | # Example: context.shape = (B, C, H, W) with values in {0,1,...,K-1} 17 | # Sample z.shape = (B, C, H, W, K) 18 | 19 | binary = integer_to_base(context, base=2, dims=self.dims) 20 | sign = binary * 2 - 1 21 | 22 | u, log_pu = self.noise_dist.sample_with_log_prob(context=context) 23 | u_positive, ldj = self.softplus(u) 24 | 25 | log_pu_positive = log_pu - ldj 26 | z = u_positive * sign 27 | 28 | log_pz = log_pu_positive 29 | return z, log_pz 30 | -------------------------------------------------------------------------------- /charlm/model/distributions/conv_normal1d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from survae.distributions import DiagonalNormal 3 | 4 | 5 | class ConvNormal1d(DiagonalNormal): 6 | def __init__(self, shape): 7 | super(DiagonalNormal, self).__init__() 8 | assert len(shape) == 2 9 | self.shape = torch.Size(shape) 10 | self.loc = torch.nn.Parameter(torch.zeros(1, shape[0], 1)) 11 | self.log_scale = torch.nn.Parameter(torch.zeros(1, shape[0], 1)) 12 | -------------------------------------------------------------------------------- /charlm/model/distributions/gumbel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from survae.distributions import Distribution 3 | from survae.utils import sum_except_batch 4 | 5 | 6 | class StandardGumbel(Distribution): 7 | """A standard Gumbel distribution.""" 8 | 9 | def __init__(self, shape): 10 | super(StandardGumbel, self).__init__() 11 | self.shape = torch.Size(shape) 12 | self.register_buffer('buffer', 
torch.zeros(1)) 13 | 14 | def log_prob(self, x): 15 | return sum_except_batch(- x - (-x).exp()) 16 | 17 | def sample(self, num_samples): 18 | u = torch.rand(num_samples, *self.shape, device=self.buffer.device, dtype=self.buffer.dtype) 19 | eps = torch.finfo(u.dtype).tiny # 1.18e-38 for float32 20 | return -torch.log(-torch.log(u + eps) + eps) 21 | -------------------------------------------------------------------------------- /charlm/model/model.py: -------------------------------------------------------------------------------- 1 | from .ar.flow import ArgmaxARFlow 2 | from .coupling.flow import ArgmaxCouplingFlow 3 | from .ar.vorflow import VoronoiARFlow 4 | from .coupling.vorflow import VoronoiCouplingFlow 5 | 6 | 7 | def get_model(cfg, args, data_shape, num_classes): 8 | 9 | if cfg.dequantization == "argmax": 10 | 11 | if args.model == "ar": 12 | return ArgmaxARFlow( 13 | data_shape=data_shape, 14 | num_classes=num_classes, 15 | num_steps=args.num_steps, 16 | actnorm=args.actnorm, 17 | perm_channel=args.perm_channel, 18 | perm_length=args.perm_length, 19 | base_dist=args.base_dist, 20 | encoder_steps=args.encoder_steps, 21 | encoder_bins=args.encoder_bins, 22 | context_size=args.context_size, 23 | lstm_layers=args.lstm_layers, 24 | lstm_size=args.lstm_size, 25 | lstm_dropout=args.lstm_dropout, 26 | context_lstm_layers=args.context_lstm_layers, 27 | context_lstm_size=args.context_lstm_size, 28 | input_dp_rate=args.input_dp_rate, 29 | ) 30 | 31 | elif args.model == "coupling": 32 | return ArgmaxCouplingFlow( 33 | data_shape=data_shape, 34 | num_classes=num_classes, 35 | num_steps=args.num_steps, 36 | actnorm=args.actnorm, 37 | num_mixtures=args.num_mixtures, 38 | perm_channel=args.perm_channel, 39 | perm_length=args.perm_length, 40 | base_dist=args.base_dist, 41 | encoder_steps=args.encoder_steps, 42 | encoder_bins=args.encoder_bins, 43 | encoder_ff_size=args.encoder_ff_size, 44 | context_size=args.context_size, 45 | context_ff_layers=args.context_ff_layers, 46 | context_ff_size=args.context_ff_size, 47 | context_dropout=args.context_dropout, 48 | lstm_layers=args.lstm_layers, 49 | lstm_size=args.lstm_size, 50 | lstm_dropout=args.lstm_dropout, 51 | input_dp_rate=args.input_dp_rate, 52 | ) 53 | 54 | elif cfg.dequantization == "voronoi": 55 | 56 | if args.model == "ar": 57 | return VoronoiARFlow( 58 | data_shape=data_shape, 59 | num_classes=num_classes, 60 | embedding_dim=cfg.embedding_dim, 61 | num_steps=args.num_steps, 62 | actnorm=args.actnorm, 63 | perm_channel=args.perm_channel, 64 | perm_length=args.perm_length, 65 | base_dist=args.base_dist, 66 | encoder_steps=args.encoder_steps, 67 | encoder_bins=args.encoder_bins, 68 | context_size=args.context_size, 69 | lstm_layers=args.lstm_layers, 70 | lstm_size=args.lstm_size, 71 | lstm_dropout=args.lstm_dropout, 72 | context_lstm_layers=args.context_lstm_layers, 73 | context_lstm_size=args.context_lstm_size, 74 | input_dp_rate=args.input_dp_rate, 75 | ) 76 | 77 | elif args.model == "coupling": 78 | return VoronoiCouplingFlow( 79 | data_shape=data_shape, 80 | num_classes=num_classes, 81 | embedding_dim=cfg.embedding_dim, 82 | num_steps=args.num_steps, 83 | actnorm=args.actnorm, 84 | num_mixtures=args.num_mixtures, 85 | perm_channel=args.perm_channel, 86 | perm_length=args.perm_length, 87 | base_dist=args.base_dist, 88 | encoder_steps=args.encoder_steps, 89 | encoder_bins=args.encoder_bins, 90 | encoder_ff_size=args.encoder_ff_size, 91 | context_size=args.context_size, 92 | context_ff_layers=args.context_ff_layers, 93 | 
context_ff_size=args.context_ff_size, 94 | context_dropout=args.context_dropout, 95 | lstm_layers=args.lstm_layers, 96 | lstm_size=args.lstm_size, 97 | lstm_dropout=args.lstm_dropout, 98 | input_dp_rate=args.input_dp_rate, 99 | ) 100 | -------------------------------------------------------------------------------- /charlm/model/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from .squeeze1d import Squeeze1d 2 | 3 | from .argmax_product import BinaryProductArgmaxSurjection 4 | from .utils import integer_to_base, base_to_integer 5 | -------------------------------------------------------------------------------- /charlm/model/transforms/argmax_product.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn.functional as F 4 | from survae.distributions import ConditionalDistribution 5 | from survae.transforms.surjections import Surjection 6 | from .utils import integer_to_base, base_to_integer 7 | 8 | 9 | class BinaryProductArgmaxSurjection(Surjection): 10 | ''' 11 | A generative argmax surjection using a Cartesian product of binary spaces. Argmax is performed over the final dimension. 12 | Args: 13 | encoder: ConditionalDistribution, a distribution q(z|x) with support over z s.t. x=argmax z. 14 | Example: 15 | Input tensor x of shape (B, D, L) with discrete values {0,1,...,C-1}: 16 | encoder should be a distribution of shape (B, D, L, D), where D=ceil(log2(C)). 17 | When e.g. C=27, we have D=5, such that 2**5=32 classes are represented. 18 | ''' 19 | stochastic_forward = True 20 | 21 | def __init__(self, encoder, num_classes): 22 | super(BinaryProductArgmaxSurjection, self).__init__() 23 | assert isinstance(encoder, ConditionalDistribution) 24 | self.encoder = encoder 25 | self.num_classes = num_classes 26 | self.dims = self.classes2dims(num_classes) 27 | 28 | @staticmethod 29 | def classes2dims(num_classes): 30 | return int(np.ceil(np.log2(num_classes))) 31 | 32 | def idx2base(self, idx_tensor): 33 | return integer_to_base(idx_tensor, base=2, dims=self.dims) 34 | 35 | def base2idx(self, base_tensor): 36 | return base_to_integer(base_tensor, base=2) 37 | 38 | def forward(self, x): 39 | z, log_qz = self.encoder.sample_with_log_prob(context=x) 40 | ldj = -log_qz 41 | return z, ldj 42 | 43 | def inverse(self, z): 44 | binary = torch.gt(z, 0.0).long() 45 | idx = self.base2idx(binary) 46 | return idx 47 | -------------------------------------------------------------------------------- /charlm/model/transforms/autoregressive/__init__.py: -------------------------------------------------------------------------------- 1 | from .ar import * 2 | 3 | from .ar_linear import * 4 | from .ar_splines import * 5 | from .ar_mixtures import * 6 | -------------------------------------------------------------------------------- /charlm/model/transforms/autoregressive/ar.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from survae.transforms.bijections import Bijection 3 | from survae.utils import sum_except_batch 4 | 5 | 6 | class AutoregressiveBijection(Bijection): 7 | """ 8 | Autoregressive bijection. 9 | Transforms each input variable with an invertible elementwise bijection, 10 | conditioned on the previous elements. 11 | 12 | NOTE: Calculating the inverse transform is D times slower than calculating the 13 | forward transform, where D is the dimensionality of the input to the transform. 
14 | 15 | Args: 16 | ar_net: nn.Module, an autoregressive network such that `params = ar_net(x)`. 17 | scheme: An inversion scheme. E.g. RasterScan from utils. 18 | """ 19 | def __init__(self, ar_net, scheme): 20 | super(AutoregressiveBijection, self).__init__() 21 | self.ar_net = ar_net 22 | self.scheme = scheme 23 | self.scheme.setup(ar_net=self.ar_net, 24 | element_inverse_fn=self._element_inverse) 25 | 26 | def forward(self, x): 27 | params = self.ar_net(x) 28 | z, ldj = self._forward(x, params) 29 | return z, ldj 30 | 31 | def inverse(self, z): 32 | return self.scheme.inverse(z=z) 33 | 34 | def _num_params(self): 35 | raise NotImplementedError() 36 | 37 | def _forward(self, x, params): 38 | raise NotImplementedError() 39 | 40 | def _element_inverse(self, z, element_params): 41 | raise NotImplementedError() 42 | -------------------------------------------------------------------------------- /charlm/model/transforms/autoregressive/ar_linear.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from survae.utils import sum_except_batch 3 | from .ar import AutoregressiveBijection 4 | 5 | 6 | class AdditiveAutoregressiveBijection(AutoregressiveBijection): 7 | '''Additive autoregressive bijection.''' 8 | 9 | def _num_params(self): 10 | return 1 11 | 12 | def _forward(self, x, params): 13 | return x + params, x.new_zeros(x.shape[0]) 14 | 15 | def _element_inverse(self, z, element_params): 16 | return z - element_params 17 | 18 | 19 | class AffineAutoregressiveBijection(AutoregressiveBijection): 20 | '''Affine autoregressive bijection.''' 21 | 22 | def _num_params(self): 23 | return 2 24 | 25 | def _forward(self, x, params): 26 | assert params.shape[-1] == self._num_params() 27 | log_scale, shift = self._split_params(params) 28 | scale = torch.exp(log_scale) 29 | z = scale * x + shift 30 | ldj = sum_except_batch(log_scale) 31 | return z, ldj 32 | 33 | def _element_inverse(self, z, element_params): 34 | assert element_params.shape[-1] == self._num_params() 35 | log_scale, shift = self._split_params(element_params) 36 | scale = torch.exp(log_scale) 37 | x = (z - shift) / scale 38 | return x 39 | 40 | def _split_params(self, params): 41 | unconstrained_scale = params[..., 0] 42 | shift = params[..., 1] 43 | return unconstrained_scale, shift 44 | -------------------------------------------------------------------------------- /charlm/model/transforms/autoregressive/ar_mixtures.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from survae.utils import sum_except_batch 3 | from survae.transforms.bijections.functional.mixtures import gaussian_mixture_transform, logistic_mixture_transform, censored_logistic_mixture_transform 4 | from survae.transforms.bijections.functional.mixtures import get_mixture_params 5 | from .ar import AutoregressiveBijection 6 | 7 | 8 | class GaussianMixtureAutoregressiveBijection(AutoregressiveBijection): 9 | 10 | def __init__(self, ar_net, scheme, num_mixtures): 11 | super(GaussianMixtureAutoregressiveBijection, self).__init__(ar_net=ar_net, scheme=scheme) 12 | self.num_mixtures = num_mixtures 13 | self.set_bisection_params() 14 | 15 | def set_bisection_params(self, eps=1e-10, max_iters=100): 16 | self.max_iters = max_iters 17 | self.eps = eps 18 | 19 | def _num_params(self): 20 | return 3 * self.num_mixtures 21 | 22 | def _elementwise(self, inputs, params, inverse): 23 | assert params.shape[-1] == self._num_params() 24 | 25 | logit_weights, means, log_scales = 
get_mixture_params(params, num_mixtures=self.num_mixtures) 26 | 27 | x = gaussian_mixture_transform(inputs=inputs, 28 | logit_weights=logit_weights, 29 | means=means, 30 | log_scales=log_scales, 31 | eps=self.eps, 32 | max_iters=self.max_iters, 33 | inverse=inverse) 34 | 35 | if inverse: 36 | return x 37 | else: 38 | z, ldj_elementwise = x 39 | ldj = sum_except_batch(ldj_elementwise) 40 | return z, ldj 41 | 42 | def _forward(self, x, params): 43 | return self._elementwise(x, params, inverse=False) 44 | 45 | def _element_inverse(self, z, element_params): 46 | return self._elementwise(z, element_params, inverse=True) 47 | 48 | 49 | class LogisticMixtureAutoregressiveBijection(AutoregressiveBijection): 50 | 51 | def __init__(self, ar_net, scheme, num_mixtures): 52 | super(LogisticMixtureAutoregressiveBijection, self).__init__(ar_net=ar_net, scheme=scheme) 53 | self.num_mixtures = num_mixtures 54 | self.set_bisection_params() 55 | 56 | def set_bisection_params(self, eps=1e-10, max_iters=100): 57 | self.max_iters = max_iters 58 | self.eps = eps 59 | 60 | def _num_params(self): 61 | return 3 * self.num_mixtures 62 | 63 | def _elementwise(self, inputs, params, inverse): 64 | assert params.shape[-1] == self._num_params() 65 | 66 | logit_weights, means, log_scales = get_mixture_params(params, num_mixtures=self.num_mixtures) 67 | 68 | x = logistic_mixture_transform(inputs=inputs, 69 | logit_weights=logit_weights, 70 | means=means, 71 | log_scales=log_scales, 72 | eps=self.eps, 73 | max_iters=self.max_iters, 74 | inverse=inverse) 75 | 76 | if inverse: 77 | return x 78 | else: 79 | z, ldj_elementwise = x 80 | ldj = sum_except_batch(ldj_elementwise) 81 | return z, ldj 82 | 83 | def _forward(self, x, params): 84 | return self._elementwise(x, params, inverse=False) 85 | 86 | def _element_inverse(self, z, element_params): 87 | return self._elementwise(z, element_params, inverse=True) 88 | 89 | 90 | class CensoredLogisticMixtureAutoregressiveBijection(AutoregressiveBijection): 91 | 92 | def __init__(self, ar_net, scheme, num_mixtures, num_bins): 93 | super(CensoredLogisticMixtureAutoregressiveBijection, self).__init__(ar_net=ar_net, scheme=scheme) 94 | self.num_mixtures = num_mixtures 95 | self.num_bins = num_bins 96 | self.set_bisection_params() 97 | 98 | def set_bisection_params(self, eps=1e-10, max_iters=100): 99 | self.max_iters = max_iters 100 | self.eps = eps 101 | 102 | def _num_params(self): 103 | return 3 * self.num_mixtures 104 | 105 | def _elementwise(self, inputs, params, inverse): 106 | assert params.shape[-1] == self._num_params() 107 | 108 | logit_weights, means, log_scales = get_mixture_params(params, num_mixtures=self.num_mixtures) 109 | 110 | x = censored_logistic_mixture_transform(inputs=inputs, 111 | logit_weights=logit_weights, 112 | means=means, 113 | log_scales=log_scales, 114 | num_bins=self.num_bins, 115 | eps=self.eps, 116 | max_iters=self.max_iters, 117 | inverse=inverse) 118 | 119 | if inverse: 120 | return x 121 | else: 122 | z, ldj_elementwise = x 123 | ldj = sum_except_batch(ldj_elementwise) 124 | return z, ldj 125 | 126 | def _forward(self, x, params): 127 | return self._elementwise(x, params, inverse=False) 128 | 129 | def _element_inverse(self, z, element_params): 130 | return self._elementwise(z, element_params, inverse=True) 131 | -------------------------------------------------------------------------------- /charlm/model/transforms/autoregressive/conditional/__init__.py: -------------------------------------------------------------------------------- 1 | from .ar 
import * 2 | 3 | from .ar_linear import * 4 | from .ar_splines import * 5 | from .ar_mixtures import * 6 | -------------------------------------------------------------------------------- /charlm/model/transforms/autoregressive/conditional/ar.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from survae.transforms.bijections.conditional import ConditionalBijection 3 | 4 | 5 | class ConditionalAutoregressiveBijection(ConditionalBijection): 6 | """ 7 | Autoregressive bijection. 8 | Transforms each input variable with an invertible elementwise bijection, 9 | conditioned on the previous elements. 10 | 11 | NOTE: Calculating the inverse transform is D times slower than calculating the 12 | forward transform, where D is the dimensionality of the input to the transform. 13 | 14 | Args: 15 | ar_net: nn.Module, an autoregressive network such that `params = ar_net(x)`. 16 | scheme: An inversion scheme. E.g. RasterScan from utils. 17 | """ 18 | def __init__(self, ar_net, scheme): 19 | super(ConditionalAutoregressiveBijection, self).__init__() 20 | self.ar_net = ar_net 21 | self.scheme = scheme 22 | self.scheme.setup(ar_net=self.ar_net, 23 | element_inverse_fn=self._element_inverse) 24 | 25 | def forward(self, x, context): 26 | params = self.ar_net(x, context=context) 27 | z, ldj = self._forward(x, params) 28 | return z, ldj 29 | 30 | def inverse(self, z, context): 31 | return self.scheme.inverse(z=z, context=context) 32 | 33 | def _output_dim_multiplier(self): 34 | raise NotImplementedError() 35 | 36 | def _forward(self, x, params): 37 | raise NotImplementedError() 38 | 39 | def _element_inverse(self, z, element_params): 40 | raise NotImplementedError() 41 | -------------------------------------------------------------------------------- /charlm/model/transforms/autoregressive/conditional/ar_linear.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from survae.utils import sum_except_batch 3 | from .ar import ConditionalAutoregressiveBijection 4 | 5 | 6 | class ConditionalAdditiveAutoregressiveBijection(ConditionalAutoregressiveBijection): 7 | '''Additive autoregressive bijection.''' 8 | 9 | def _num_params(self): 10 | return 1 11 | 12 | def _forward(self, x, params): 13 | return x + params, x.new_zeros(x.shape[0]) 14 | 15 | def _element_inverse(self, z, element_params): 16 | return z - element_params 17 | 18 | 19 | class ConditionalAffineAutoregressiveBijection(ConditionalAutoregressiveBijection): 20 | '''Affine autoregressive bijection.''' 21 | 22 | def _num_params(self): 23 | return 2 24 | 25 | def _forward(self, x, params): 26 | assert params.shape[-1] == self._num_params() 27 | log_scale, shift = self._split_params(params) 28 | scale = torch.exp(log_scale) 29 | z = scale * x + shift 30 | ldj = sum_except_batch(log_scale) 31 | return z, ldj 32 | 33 | def _element_inverse(self, z, element_params): 34 | assert element_params.shape[-1] == self._num_params() 35 | log_scale, shift = self._split_params(element_params) 36 | scale = torch.exp(log_scale) 37 | x = (z - shift) / scale 38 | return x 39 | 40 | def _split_params(self, params): 41 | unconstrained_scale = params[..., 0] 42 | shift = params[..., 1] 43 | return unconstrained_scale, shift 44 | -------------------------------------------------------------------------------- /charlm/model/transforms/autoregressive/conditional/ar_mixtures.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 
from survae.utils import sum_except_batch 3 | from survae.transforms.bijections.functional.mixtures import gaussian_mixture_transform, logistic_mixture_transform, censored_logistic_mixture_transform 4 | from survae.transforms.bijections.functional.mixtures import get_mixture_params 5 | from .ar import ConditionalAutoregressiveBijection 6 | 7 | 8 | class ConditionalGaussianMixtureAutoregressiveBijection(ConditionalAutoregressiveBijection): 9 | 10 | def __init__(self, ar_net, scheme, num_mixtures): 11 | super(ConditionalGaussianMixtureAutoregressiveBijection, self).__init__(ar_net=ar_net, scheme=scheme) 12 | self.num_mixtures = num_mixtures 13 | self.set_bisection_params() 14 | 15 | def set_bisection_params(self, eps=1e-10, max_iters=100): 16 | self.max_iters = max_iters 17 | self.eps = eps 18 | 19 | def _num_params(self): 20 | return 3 * self.num_mixtures 21 | 22 | def _elementwise(self, inputs, params, inverse): 23 | assert params.shape[-1] == self._num_params() 24 | 25 | logit_weights, means, log_scales = get_mixture_params(params, num_mixtures=self.num_mixtures) 26 | 27 | x = gaussian_mixture_transform(inputs=inputs, 28 | logit_weights=logit_weights, 29 | means=means, 30 | log_scales=log_scales, 31 | eps=self.eps, 32 | max_iters=self.max_iters, 33 | inverse=inverse) 34 | 35 | if inverse: 36 | return x 37 | else: 38 | z, ldj_elementwise = x 39 | ldj = sum_except_batch(ldj_elementwise) 40 | return z, ldj 41 | 42 | def _forward(self, x, params): 43 | return self._elementwise(x, params, inverse=False) 44 | 45 | def _element_inverse(self, z, element_params): 46 | return self._elementwise(z, element_params, inverse=True) 47 | 48 | 49 | class ConditionalLogisticMixtureAutoregressiveBijection(ConditionalAutoregressiveBijection): 50 | 51 | def __init__(self, ar_net, scheme, num_mixtures): 52 | super(ConditionalLogisticMixtureAutoregressiveBijection, self).__init__(ar_net=ar_net, scheme=scheme) 53 | self.num_mixtures = num_mixtures 54 | self.set_bisection_params() 55 | 56 | def set_bisection_params(self, eps=1e-10, max_iters=100): 57 | self.max_iters = max_iters 58 | self.eps = eps 59 | 60 | def _num_params(self): 61 | return 3 * self.num_mixtures 62 | 63 | def _elementwise(self, inputs, params, inverse): 64 | assert params.shape[-1] == self._num_params() 65 | 66 | logit_weights, means, log_scales = get_mixture_params(params, num_mixtures=self.num_mixtures) 67 | 68 | x = logistic_mixture_transform(inputs=inputs, 69 | logit_weights=logit_weights, 70 | means=means, 71 | log_scales=log_scales, 72 | eps=self.eps, 73 | max_iters=self.max_iters, 74 | inverse=inverse) 75 | 76 | if inverse: 77 | return x 78 | else: 79 | z, ldj_elementwise = x 80 | ldj = sum_except_batch(ldj_elementwise) 81 | return z, ldj 82 | 83 | def _forward(self, x, params): 84 | return self._elementwise(x, params, inverse=False) 85 | 86 | def _element_inverse(self, z, element_params): 87 | return self._elementwise(z, element_params, inverse=True) 88 | 89 | 90 | class ConditionalCensoredLogisticMixtureAutoregressiveBijection(ConditionalAutoregressiveBijection): 91 | 92 | def __init__(self, ar_net, scheme, num_mixtures, num_bins): 93 | super(ConditionalCensoredLogisticMixtureAutoregressiveBijection, self).__init__(ar_net=ar_net, scheme=scheme) 94 | self.num_mixtures = num_mixtures 95 | self.num_bins = num_bins 96 | self.set_bisection_params() 97 | 98 | def set_bisection_params(self, eps=1e-10, max_iters=100): 99 | self.max_iters = max_iters 100 | self.eps = eps 101 | 102 | def _num_params(self): 103 | return 3 * self.num_mixtures 104 | 
105 | def _elementwise(self, inputs, params, inverse): 106 | assert params.shape[-1] == self._num_params() 107 | 108 | logit_weights, means, log_scales = get_mixture_params(params, num_mixtures=self.num_mixtures) 109 | 110 | x = censored_logistic_mixture_transform(inputs=inputs, 111 | logit_weights=logit_weights, 112 | means=means, 113 | log_scales=log_scales, 114 | num_bins=self.num_bins, 115 | eps=self.eps, 116 | max_iters=self.max_iters, 117 | inverse=inverse) 118 | 119 | if inverse: 120 | return x 121 | else: 122 | z, ldj_elementwise = x 123 | ldj = sum_except_batch(ldj_elementwise) 124 | return z, ldj 125 | 126 | def _forward(self, x, params): 127 | return self._elementwise(x, params, inverse=False) 128 | 129 | def _element_inverse(self, z, element_params): 130 | return self._elementwise(z, element_params, inverse=True) 131 | -------------------------------------------------------------------------------- /charlm/model/transforms/autoregressive/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class InvertRasterScan(): 5 | ''' 6 | Invert autoregressive bijection in 7 | raster scan order (pixelwise, width first, then height). 8 | Data is assumed to be images of shape (C, H, W). 9 | Args: 10 | order (str): The order in which to invert. Choices: `{'cwh', 'wh'}`. 11 | ''' 12 | 13 | def __init__(self, order='cwh'): 14 | assert order in {'cwh', 'wh'} 15 | self.order = order 16 | self.ready = False 17 | 18 | def setup(self, ar_net, element_inverse_fn): 19 | self.ar_net = ar_net 20 | self.element_inverse_fn = element_inverse_fn 21 | self.ready = True 22 | 23 | def inverse(self, z, **kwargs): 24 | assert self.ready, 'Run scheme.setup(...) before scheme.inverse(...).' 25 | with torch.no_grad(): 26 | if self.order == 'cwh': x = self._inverse_cwh(z, **kwargs) 27 | if self.order == 'wh': x = self._inverse_wh(z, **kwargs) 28 | return x 29 | 30 | def _inverse_cwh(self, z, **kwargs): 31 | _, C, H, W = z.shape 32 | x = torch.zeros_like(z) 33 | for h in range(H): 34 | for w in range(W): 35 | for c in range(C): 36 | element_params = self.ar_net(x, **kwargs) 37 | x[:,c,h,w] = self.element_inverse_fn(z[:,c,h,w], element_params[:,c,h,w]) 38 | return x 39 | 40 | def _inverse_wh(self, z, **kwargs): 41 | _, C, H, W = z.shape 42 | x = torch.zeros_like(z) 43 | for h in range(H): 44 | for w in range(W): 45 | element_params = self.ar_net(x, **kwargs) 46 | x[:,:,h,w] = self.element_inverse_fn(z[:,:,h,w], element_params[:,:,h,w]) 47 | return x 48 | 49 | 50 | class InvertSequentialCL(): 51 | ''' 52 | Invert autoregressive bijection in sequential order. 53 | Data is assumed to be audio / time series of shape (C, L). 54 | Args: 55 | shape (Iterable): The data shape, e.g. (2,1024). 56 | order (str): The order in which to invert. Choices: `{'cl', 'l'}`. 57 | ''' 58 | 59 | def __init__(self, order='cl'): 60 | assert order in {'cl', 'l'} 61 | self.order = order 62 | self.ready = False 63 | 64 | def setup(self, ar_net, element_inverse_fn): 65 | self.ar_net = ar_net 66 | self.element_inverse_fn = element_inverse_fn 67 | self.ready = True 68 | 69 | def inverse(self, z, **kwargs): 70 | assert self.ready, 'Run scheme.setup(...) before scheme.invert(...).' 
71 | with torch.no_grad(): 72 | if self.order == 'cl': x = self._inverse_cl(z, **kwargs) 73 | if self.order == 'l': x = self._inverse_l(z, **kwargs) 74 | return x 75 | 76 | def _inverse_cl(self, z, **kwargs): 77 | _, C, L = z.shape 78 | x = torch.zeros_like(z) 79 | for l in range(L): 80 | for c in range(C): 81 | element_params = self.ar_net(x, **kwargs) 82 | x[:,c,l] = self.element_inverse_fn(z[:,c,l], element_params[:,c,l]) 83 | return x 84 | 85 | def _inverse_l(self, z, **kwargs): 86 | _, C, L = z.shape 87 | x = torch.zeros_like(z) 88 | for l in range(L): 89 | element_params = self.ar_net(x, **kwargs) 90 | x[:,:,l] = self.element_inverse_fn(z[:,:,l], element_params[:,:,l]) 91 | return x 92 | 93 | 94 | class InvertSequential(): 95 | ''' 96 | Invert autoregressive bijection in sequential order. 97 | Data is assumed to be time series of shape (L,). 98 | Args: 99 | shape (Iterable): The data shape, e.g. (1024,). 100 | order (str): The order in which to invert. Choices: `{'l'}`. 101 | ''' 102 | 103 | def __init__(self, order='l'): 104 | assert order in {'l'} 105 | self.order = order 106 | self.ready = False 107 | 108 | def setup(self, ar_net, element_inverse_fn): 109 | self.ar_net = ar_net 110 | self.element_inverse_fn = element_inverse_fn 111 | self.ready = True 112 | 113 | def inverse(self, z, **kwargs): 114 | assert self.ready, 'Run scheme.setup(...) before scheme.inverse(...).' 115 | with torch.no_grad(): 116 | _, L = z.shape 117 | x = torch.zeros_like(z) 118 | for l in range(L): 119 | element_params = self.ar_net(x, **kwargs) 120 | x[:,l] = self.element_inverse_fn(z[:,l], element_params[:,l]) 121 | return x 122 | -------------------------------------------------------------------------------- /charlm/model/transforms/squeeze1d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from survae.transforms.bijections import Bijection 3 | 4 | 5 | class Squeeze1d(Bijection): 6 | """ 7 | A bijection defined for sequential data that trades spatial dimensions for 8 | channel dimensions, i.e. "squeezes" the inputs along the channel dimensions. 9 | Introduced in the RealNVP paper [1]. 10 | Args: 11 | factor: int, the factor to squeeze by (default=2). 12 | ordered: bool, if True, squeezing happens sequencewise. 13 | if False, squeezing happens channelwise. 14 | For more details, see example (default=False).
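Example (illustrative, derived from the `_squeeze` implementation below): with factor=2 an input of shape (B, C, L) maps to (B, 2*C, L//2); for x = [[a0, a1, a2, a3], [b0, b1, b2, b3]] (C=2, L=4), ordered=False gives [[a0, a2], [a1, a3], [b0, b2], [b1, b3]] while ordered=True gives [[a0, a2], [b0, b2], [a1, a3], [b1, b3]].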
15 | Source implementation: 16 | Based on `squeeze_nxn`, `squeeze_2x2`, `squeeze_2x2_ordered`, `unsqueeze_2x2` in: 17 | https://github.com/laurent-dinh/models/blob/master/real_nvp/real_nvp_utils.py 18 | References: 19 | [1] Density estimation using Real NVP, 20 | Dinh et al., 2017, https://arxiv.org/abs/1605.08803 21 | """ 22 | 23 | def __init__(self, factor=2, ordered=False): 24 | super(Squeeze1d, self).__init__() 25 | assert isinstance(factor, int) 26 | assert factor > 1 27 | self.factor = factor 28 | self.ordered = ordered 29 | 30 | def _squeeze(self, x): 31 | assert len(x.shape) == 3, 'Dimension should be 3, but was {}'.format(len(x.shape)) 32 | batch_size, c, l = x.shape 33 | assert l % self.factor == 0, 'l = {} not multiplicative of {}'.format(l, self.factor) 34 | t = x.view(batch_size, c, l // self.factor, self.factor) 35 | if not self.ordered: 36 | t = t.permute(0, 1, 3, 2).contiguous() 37 | else: 38 | t = t.permute(0, 3, 1, 2).contiguous() 39 | z = t.view(batch_size, c * self.factor, l // self.factor) 40 | return z 41 | 42 | def _unsqueeze(self, z): 43 | assert len(z.shape) == 3, 'Dimension should be 3, but was {}'.format(len(z.shape)) 44 | batch_size, c, l = z.shape 45 | assert c % self.factor == 0, 'c = {} not multiplicative of {}'.format(c, self.factor) 46 | if not self.ordered: 47 | t = z.view(batch_size, c // self.factor, self.factor, l) 48 | t = t.permute(0, 1, 3, 2).contiguous() 49 | else: 50 | t = z.view(batch_size, self.factor, c // self.factor, l) 51 | t = t.permute(0, 2, 3, 1).contiguous() 52 | x = t.view(batch_size, c // self.factor, l * self.factor) 53 | return x 54 | 55 | def forward(self, x): 56 | z = self._squeeze(x) 57 | ldj = torch.zeros(x.shape[0], device=x.device, dtype=x.dtype) 58 | return z, ldj 59 | 60 | def inverse(self, z): 61 | x = self._unsqueeze(z) 62 | return x 63 | -------------------------------------------------------------------------------- /charlm/model/transforms/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def integer_to_base(idx_tensor, base, dims): 5 | ''' 6 | Encodes index tensor to a Cartesian product representation. 7 | Args: 8 | idx_tensor (LongTensor): An index tensor, shape (...), to be encoded. 9 | base (int): The base to use for encoding. 10 | dims (int): The number of dimensions to use for encoding. 11 | Returns: 12 | LongTensor: The encoded tensor, shape (..., dims). 13 | ''' 14 | powers = base ** torch.arange(dims - 1, -1, -1, device=idx_tensor.device) 15 | floored = idx_tensor[..., None] // powers 16 | remainder = floored % base 17 | 18 | base_tensor = remainder 19 | return base_tensor 20 | 21 | 22 | def base_to_integer(base_tensor, base): 23 | ''' 24 | Decodes Cartesian product representation to an index tensor. 25 | Args: 26 | base_tensor (LongTensor): The encoded tensor, shape (..., dims). 27 | base (int): The base used in the encoding. 28 | Returns: 29 | LongTensor: The index tensor, shape (...). 
30 | ''' 31 | dims = base_tensor.shape[-1] 32 | powers = base ** torch.arange(dims - 1, -1, -1, device=base_tensor.device) 33 | powers = powers[(None,) * (base_tensor.dim()-1)] 34 | 35 | idx_tensor = (base_tensor * powers).sum(-1) 36 | return idx_tensor 37 | -------------------------------------------------------------------------------- /charlm/model/transforms/voronoi_surjection.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from survae.transforms import Softplus 5 | from survae.transforms.surjections import Surjection 6 | 7 | class VoronoiSurjection(Surjection): 8 | """ 9 | A generative surjection from discrete class labels x onto continuous embeddings z, where z is constrained to lie in the Voronoi cell of the anchor point of the labelled class. 10 | Args: 11 | noise_dist: ConditionalDistribution, a noise distribution q(u|x) sampled around the anchor point of class x. 12 | voronoi_transform: provides the anchor points (`anchor_pts`) and `map_onto_cell`, which maps points into a selected Voronoi cell and returns the log-det-Jacobian. 13 | num_classes: int, the number of discrete classes K. 14 | embedding_dim: int, the dimension D of the continuous embedding. 15 | The forward direction is stochastic with ldj = -log q(z|x); the inverse (z -> x) is not implemented here. 16 | """ 17 | 18 | stochastic_forward = True 19 | 20 | def __init__(self, noise_dist, voronoi_transform, num_classes, embedding_dim): 21 | super().__init__() 22 | 23 | self.noise_dist = noise_dist 24 | self.voronoi_transform = voronoi_transform 25 | self.num_classes = num_classes 26 | self.embedding_dim = embedding_dim 27 | self.softplus = Softplus() 28 | 29 | def forward(self, x): 30 | B, L = x.shape[0], x.shape[2] 31 | N, K, D = 1, self.num_classes, self.embedding_dim 32 | 33 | u, log_pu = self.noise_dist.sample_with_log_prob(context=x) 34 | u = u.reshape(B * L, 1, D) 35 | 36 | mask = F.one_hot(x, self.num_classes).bool().to(x.device) # (B, 1, L, K) 37 | mask = mask.reshape(B * L, 1, K) 38 | 39 | # Center the flow at the Voronoi cell. 40 | points = self.voronoi_transform.anchor_pts.reshape(1, N, K, D) 41 | x_k = torch.masked_select(points, mask.reshape(-1, N, K, 1)).reshape(-1, N, D) 42 | z = u + x_k 43 | 44 | # Transform into the target Voronoi cell.
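# `map_onto_cell` is assumed to map z into the Voronoi cell selected by `mask` and to return the log-det-Jacobian of that map; combined with log p(u) above, this gives log q(z|x), and the surjection returns ldj = -log q(z|x) below.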
45 | z, ldj = self.voronoi_transform.map_onto_cell(z, mask=mask) 46 | 47 | z = z.reshape(B, 1, L, D) 48 | ldj = ldj.reshape(B, -1) 49 | 50 | log_qz = log_pu - ldj.sum(1) 51 | 52 | ldj = -log_qz 53 | return z, ldj 54 | 55 | def inverse(self, z): 56 | raise NotImplementedError -------------------------------------------------------------------------------- /charlm/optim/base.py: -------------------------------------------------------------------------------- 1 | import torch.optim as optim 2 | 3 | optim_choices = {'sgd', 'adam', 'adamax'} 4 | 5 | 6 | def add_optim_args(parser): 7 | 8 | # Model params 9 | parser.add_argument('--optimizer', type=str, default='adam', choices=optim_choices) 10 | parser.add_argument('--lr', type=float, default=1e-3) 11 | parser.add_argument('--warmup', type=int, default=None) 12 | parser.add_argument('--momentum', type=float, default=0.9) 13 | parser.add_argument('--momentum_sqr', type=float, default=0.999) 14 | 15 | 16 | def get_optim_id(args): 17 | return 'base' 18 | 19 | 20 | def get_optim(args, model): 21 | assert args.optimizer in optim_choices 22 | 23 | if args.optimizer == 'sgd': 24 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum) 25 | elif args.optimizer == 'adam': 26 | optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(args.momentum, args.momentum_sqr)) 27 | elif args.optimizer == 'adamax': 28 | optimizer = optim.Adamax(model.parameters(), lr=args.lr, betas=(args.momentum, args.momentum_sqr)) 29 | 30 | scheduler_iter = None 31 | scheduler_epoch = None 32 | 33 | return optimizer, scheduler_iter, scheduler_epoch 34 | -------------------------------------------------------------------------------- /charlm/optim/expdecay.py: -------------------------------------------------------------------------------- 1 | import torch.optim as optim 2 | from torch.optim.lr_scheduler import ExponentialLR 3 | from survae.optim.schedulers import LinearWarmupScheduler 4 | 5 | optim_choices = {'sgd', 'adam', 'adamax'} 6 | 7 | def get_optim(args, model): 8 | assert args.optimizer in optim_choices 9 | 10 | if args.optimizer == 'sgd': 11 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum) 12 | elif args.optimizer == 'adam': 13 | optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(args.momentum, args.momentum_sqr)) 14 | elif args.optimizer == 'adamax': 15 | optimizer = optim.Adamax(model.parameters(), lr=args.lr, betas=(args.momentum, args.momentum_sqr)) 16 | 17 | if args.warmup > 0: 18 | scheduler_iter = LinearWarmupScheduler(optimizer, total_epoch=args.warmup) 19 | else: 20 | scheduler_iter = None 21 | 22 | scheduler_epoch = ExponentialLR(optimizer, gamma=args.gamma) 23 | 24 | 25 | return optimizer, scheduler_iter, scheduler_epoch 26 | -------------------------------------------------------------------------------- /charlm/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | 5 | 6 | def set_seeds(seed, cuda_deterministic=False): 7 | if seed is not None: 8 | random.seed(seed) 9 | np.random.seed(seed) 10 | torch.manual_seed(seed) 11 | if torch.cuda.is_available(): 12 | torch.cuda.manual_seed(seed) 13 | torch.cuda.manual_seed_all(seed) 14 | if cuda_deterministic: 15 | torch.backends.cudnn.deterministic = True 16 | torch.backends.cudnn.benchmark = False 17 | -------------------------------------------------------------------------------- /configs/charlm.yaml: 
-------------------------------------------------------------------------------- 1 | defaults: 2 | - override hydra/launcher: submitit_slurm 3 | 4 | seed: 0 5 | 6 | dataset: text8 7 | model: ar 8 | dequantization: voronoi # {voronoi, argmax} 9 | embedding_dim: 5 10 | logfreq: 20 11 | 12 | # The rest are fixed. These arguments are taken from argmax flows and are dependent. 13 | 14 | text8: 15 | ar: 16 | dataset: ${dataset} 17 | model: ${model} 18 | num_steps: 1 19 | actnorm: False 20 | perm_channel: none 21 | perm_length: reverse 22 | base_dist: conv_gauss 23 | encoder_steps: 0 24 | encoder_bins: 5 25 | context_size: 128 26 | context_lstm_layers: 1 27 | context_lstm_size: 512 28 | lstm_layers: 2 29 | lstm_size: 2048 30 | lstm_dropout: 0.0 31 | input_dp_rate: 0.25 32 | 33 | epochs: 40 34 | disable_ema_epochs: 20 35 | batch_size: 64 36 | test_batch_size: 256 37 | eval_every: 1 38 | check_every: 10 39 | 40 | optimizer: adam 41 | lr: 1e-3 42 | warmup: 0 43 | momentum: 0.9 44 | momentum_sqr: 0.999 45 | gamma: 0.995 46 | num_mixtures: 27 47 | 48 | coupling: 49 | dataset: ${dataset} 50 | model: ${model} 51 | num_steps: 8 52 | actnorm: False 53 | perm_channel: conv 54 | perm_length: reverse 55 | base_dist: conv_gauss 56 | encoder_steps: 0 57 | encoder_bins: 0 58 | encoder_ff_size: 1024 59 | context_size: 128 60 | context_ff_layers: 1 61 | context_ff_size: 512 62 | context_dropout: 0.0 63 | lstm_layers: 2 64 | lstm_size: 512 65 | lstm_dropout: 0.0 66 | input_dp_rate: 0.05 67 | 68 | epochs: 40 69 | disable_ema_epochs: 20 70 | batch_size: 16 71 | test_batch_size: 256 72 | eval_every: 1 73 | check_every: 1 74 | 75 | optimizer: adamax 76 | lr: 1e-3 77 | warmup: 1000 78 | momentum: 0.9 79 | momentum_sqr: 0.999 80 | gamma: 0.995 81 | num_mixtures: 8 82 | 83 | enwik8: 84 | ar: 85 | dataset: ${dataset} 86 | model: ${model} 87 | num_steps: 1 88 | actnorm: False 89 | perm_channel: none 90 | perm_length: reverse 91 | base_dist: conv_gauss 92 | encoder_steps: 0 93 | encoder_bins: 5 94 | context_size: 128 95 | context_lstm_layers: 1 96 | context_lstm_size: 512 97 | lstm_layers: 2 98 | lstm_size: 2048 99 | lstm_dropout: 0.0 100 | input_dp_rate: 0.25 101 | 102 | epochs: 40 103 | disable_ema_epochs: 20 104 | batch_size: 64 105 | test_batch_size: 256 106 | eval_every: 1 107 | check_every: 10 108 | 109 | optimizer: adam 110 | lr: 1e-3 111 | warmup: 0 112 | momentum: 0.9 113 | momentum_sqr: 0.999 114 | gamma: 0.995 115 | num_mixtures: 27 116 | 117 | coupling: 118 | dataset: ${dataset} 119 | model: ${model} 120 | num_steps: 8 121 | actnorm: False 122 | perm_channel: conv 123 | perm_length: reverse 124 | base_dist: conv_gauss 125 | encoder_steps: 0 126 | encoder_bins: 0 127 | encoder_ff_size: 1024 128 | context_size: 128 129 | context_ff_layers: 1 130 | context_ff_size: 512 131 | context_dropout: 0.0 132 | lstm_layers: 2 133 | lstm_size: 768 134 | lstm_dropout: 0.0 135 | input_dp_rate: 0.1 136 | 137 | epochs: 20 138 | disable_ema_epochs: 10 139 | batch_size: 32 140 | test_batch_size: 256 141 | eval_every: 1 142 | check_every: 1 143 | 144 | optimizer: adamax 145 | lr: 1e-3 146 | warmup: 1000 147 | momentum: 0.9 148 | momentum_sqr: 0.999 149 | gamma: 0.95 150 | num_mixtures: 8 151 | 152 | hydra: 153 | run: 154 | dir: ./exp_local/charlm/${now:%Y.%m.%d}/${now:%H%M%S} 155 | sweep: 156 | dir: ./exp/charlm/${now:%Y.%m.%d}/${now:%H%M%S} 157 | subdir: ${hydra.job.num} 158 | launcher: 159 | max_num_timeout: 100000 160 | timeout_min: 4319 161 | partition: learnfair 162 | mem_gb: 64 163 | cpus_per_task: 10 164 | gpus_per_node: 1 
165 | constraint: volta32gb 166 | -------------------------------------------------------------------------------- /configs/default.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - override hydra/launcher: submitit_slurm 3 | 4 | seed: 0 5 | 6 | hydra: 7 | run: 8 | dir: ./exp_local/default/${now:%Y.%m.%d}/${now:%H%M%S} 9 | sweep: 10 | dir: ./exp/default/${now:%Y.%m.%d}/${now:%H%M%S} 11 | subdir: ${hydra.job.num} 12 | launcher: 13 | max_num_timeout: 100000 14 | timeout_min: 10 15 | partition: learnfair 16 | mem_gb: 4 17 | gpus_per_node: 1 -------------------------------------------------------------------------------- /configs/discrete2d.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - override hydra/launcher: submitit_slurm 3 | 4 | seed: 0 5 | 6 | # dataset should be one of {center3, cluster3, diagonal4, cluster10, cross10, diagonal28, diamond28, discrete_8gaussians} 7 | dataset: discrete_8gaussians 8 | vardeq: True 9 | num_flows: 16 10 | num_dequant_flows: 4 11 | hdims: [256, 256] 12 | flatten: False 13 | dequantization: voronoi 14 | 15 | embedding_dim: 2 16 | actfn: relu 17 | 18 | cond_embed_dim: 8 19 | 20 | block_transform: affine 21 | num_mixtures: 32 22 | 23 | # base should be one of {gaussian, resampled} 24 | base: resampled 25 | resampled: 26 | actfn: relu 27 | hdims: [256, 256] 28 | 29 | iterations: 100000 30 | batch_size: 200 31 | test_batch_size: 200 32 | lr: 1e-3 33 | 34 | num_test_samples: 10 35 | 36 | 37 | hydra: 38 | run: 39 | dir: ./exp_local/discrete2d/${dataset}/${now:%Y.%m.%d}/${now:%H%M%S} 40 | sweep: 41 | dir: ./exp/discrete2d/${dataset}/${now:%Y.%m.%d}/${now:%H%M%S} 42 | subdir: ${hydra.job.num} 43 | launcher: 44 | max_num_timeout: 100000 45 | timeout_min: 1440 46 | partition: learnfair 47 | mem_gb: 64 48 | gpus_per_node: 1 -------------------------------------------------------------------------------- /configs/disjoint2d.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - override hydra/launcher: submitit_slurm 3 | 4 | seed: 0 5 | 6 | n_mixtures: 8 7 | cond_embed_dim: 4 8 | 9 | dataset: 8gaussians #{8gaussians,swissroll,circles,rings,moons,pinwheel,2spirals,checkerboard} 10 | 11 | block_type: coupling 12 | idim: 64 13 | depth: 4 14 | actnorm: True 15 | zero_init: False 16 | lazy_init: True 17 | skip_transform: False 18 | prior_std: 0.2 19 | 20 | # block type specific defaults. 
21 | cnf: 22 | nblocks: [4] 23 | actfn: softplus 24 | fast_adjoint: False 25 | 26 | resflow: 27 | nblocks: [8] 28 | sn_coeff: 0.98 29 | 30 | coupling: 31 | nblocks: [8, 8] 32 | actfn: relu 33 | mixlogcdf: False 34 | num_mixtures: 8 35 | 36 | iterations: 20000 37 | batchsize: 64 38 | lr: 1e-3 39 | 40 | logfreq: 200 41 | vizfreq: 2000 42 | 43 | hydra: 44 | run: 45 | dir: ./exp_local/disjoint2d/${dataset}/${block_type}/${now:%Y.%m.%d}/${now:%H%M%S} 46 | sweep: 47 | dir: ./exp/disjoint2d/${dataset}/${block_type}/${now:%Y.%m.%d}/${now:%H%M%S} 48 | subdir: ${hydra.job.num} 49 | launcher: 50 | max_num_timeout: 100000 51 | timeout_min: 4319 52 | partition: learnfair 53 | mem_gb: 16 54 | cpus_per_task: 10 55 | gpus_per_node: 1 56 | -------------------------------------------------------------------------------- /configs/disjoint_uci.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - override hydra/launcher: submitit_slurm 3 | 4 | sweep: 5 | n_sample: 75 6 | 7 | seed: 0 8 | 9 | n_mixtures: 8 10 | cond_embed_dim: 8 11 | 12 | user: ${oc.env:USER} 13 | dataset: miniboone #{miniboone,gas,power,hepmass,bsds300} 14 | 15 | block_type: coupling 16 | idim: 64 17 | depth: 4 18 | lazy_init: True 19 | skip_transform: False 20 | prior_std: 0.2 21 | zero_init: False 22 | 23 | nblocks: [16] 24 | actfn: relu 25 | 26 | iterations: 1000000 27 | batchsize: 64 28 | eval_batchsize: 128 29 | lr: 1e-3 30 | 31 | logfreq: 200 32 | evalfreq: 2000 33 | 34 | use_wandb: False 35 | wandb: 36 | save_dir: /checkpoint/${user}/ 37 | project: "voronoi-disjoint-uci" 38 | group: ${dataset} 39 | entity: ${user} 40 | 41 | hydra: 42 | run: 43 | dir: ./exp_local/disjoint_uci/${dataset}/${block_type}/${now:%Y.%m.%d}/${now:%H%M%S} 44 | sweep: 45 | dir: ./exp/disjoint_uci/${dataset}/${block_type}/${now:%Y.%m.%d}/${now:%H%M%S} 46 | subdir: ${hydra.job.num} 47 | launcher: 48 | max_num_timeout: 100000 49 | timeout_min: 4319 50 | partition: learnfair 51 | mem_gb: 16 52 | cpus_per_task: 10 53 | gpus_per_node: 1 54 | -------------------------------------------------------------------------------- /configs/itemsets.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - override hydra/launcher: submitit_slurm 3 | 4 | seed: 0 5 | n_resumes: 5 6 | 7 | dataset: retail # {retail,accidents} 8 | model: cnf # {dpp,cnf,coupling} 9 | dequantization: voronoi # {voronoi,argmax,simplex} 10 | 11 | embedding_dim: 6 12 | num_flows: 4 13 | num_layers: 2 14 | actfn: gelu 15 | 16 | iterations: 10000 17 | batch_size: 128 18 | lr: 1e-3 19 | wd: 1e-6 20 | 21 | ema_eval: True 22 | eval_batch_size: 256 23 | num_eval_samples: 40 24 | 25 | skip_eval: False 26 | logfreq: 10 27 | evalfreq: 200 28 | 29 | hydra: 30 | run: 31 | dir: ./exp_local/itemsets/${dataset}/${model}/${dequantization}/${now:%Y.%m.%d}/${now:%H%M%S} 32 | sweep: 33 | dir: ./exp/itemsets/${dataset}/${model}/${dequantization}/${now:%Y.%m.%d}/${now:%H%M%S} 34 | subdir: ${hydra.job.num} 35 | launcher: 36 | max_num_timeout: 100000 37 | timeout_min: 4319 38 | partition: learnfair 39 | mem_gb: 64 40 | cpus_per_task: 10 41 | gpus_per_node: 1 42 | -------------------------------------------------------------------------------- /configs/uci_categorical.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - override hydra/launcher: submitit_slurm 3 | 4 | sweep: 5 | n_sample: 100 6 | n_seed_per_sample: 1 7 | 8 | seed: 0 9 | n_resumes: 5 10 | 11 | # dataset should 
be one of {mushroom,nursery,connect4,uscensus90,pokerhand,forests} 12 | dataset: mushroom 13 | flow_type: coupling # {coupling,autoreg} 14 | num_flows: 16 15 | num_dequant_flows: 4 16 | hdims: [256, 256] 17 | actfn: relu 18 | dequantization: voronoi 19 | use_dequant_flow: True 20 | cond_embed_dim: 8 21 | arch: mlp 22 | num_transformer_layers: 2 23 | transformer_d_model: 256 24 | transformer_dropout: 0.0 25 | 26 | block_transform: affine 27 | num_mixtures: 8 28 | 29 | use_contextnet: True 30 | 31 | skip_eval: False 32 | 33 | embedding_dim: 3 34 | share_embeddings: False 35 | use_logit_transform: False 36 | learn_box_constraints: True 37 | 38 | # base should be one of {gaussian, resampled} 39 | base: resampled 40 | resampled: 41 | actfn: relu 42 | hdims: [256, 256] 43 | 44 | iterations: 200000 45 | batch_size: 64 46 | lr: 0.0005 47 | weight_decay: 0 48 | warmup: 0 49 | 50 | ema_eval: True 51 | eval_batch_size: 32 52 | num_eval_samples: 100 53 | 54 | logfreq: 100 55 | evalfreq: 1000 56 | 57 | hydra: 58 | run: 59 | dir: ./exp_local/uci_categorical/${dataset}/${dequantization}/${now:%Y.%m.%d}/${now:%H%M%S} 60 | sweep: 61 | dir: ./exp/uci_categorical/${dataset}/${dequantization}/${now:%Y.%m.%d}/${now:%H%M%S} 62 | subdir: ${hydra.job.num} 63 | launcher: 64 | max_num_timeout: 100000 65 | timeout_min: 4319 66 | partition: learnfair 67 | mem_gb: 64 68 | cpus_per_task: 10 69 | gpus_per_node: 1 70 | -------------------------------------------------------------------------------- /data/download_itemsets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | 5 | This source code is licensed under the license found in the 6 | LICENSE file in the root directory of this source tree. 7 | """ 8 | 9 | import os 10 | import wget 11 | import subprocess 12 | import numpy as np 13 | from collections import OrderedDict 14 | 15 | 16 | def prune_itemset(fname, *, vocab_size, freq_threshold, subsample_more=False): 17 | with open(fname) as f: 18 | lines = f.readlines() 19 | 20 | rows = OrderedDict() 21 | for row_id, line in enumerate(lines): 22 | if row_id not in rows.keys(): 23 | rows[row_id] = [] 24 | for item_id in map(int, (line.rstrip().split(" "))): 25 | rows[row_id].append(item_id) 26 | 27 | # remove rows with insufficient number of items. 28 | for row_id in list(rows.keys()): 29 | if len(rows[row_id]) < vocab_size: 30 | del rows[row_id] 31 | 32 | # subsample each row to a fixed number of items. 
33 | next_row_id = np.max(list(rows.keys())) + 1 34 | for row_id in list(rows.keys()): 35 | 36 | if subsample_more: 37 | 38 | if len(rows[row_id]) > vocab_size * 4: 39 | for _ in range(5): 40 | subsample = list( 41 | np.random.choice(rows[row_id], size=vocab_size, replace=False) 42 | ) 43 | rows[next_row_id] = subsample 44 | next_row_id += 1 45 | 46 | if len(rows[row_id]) > vocab_size * 2: 47 | for _ in range(4): 48 | subsample = list( 49 | np.random.choice(rows[row_id], size=vocab_size, replace=False) 50 | ) 51 | rows[next_row_id] = subsample 52 | next_row_id += 1 53 | 54 | if len(rows[row_id]) > vocab_size: 55 | subsample = list( 56 | np.random.choice(rows[row_id], size=vocab_size, replace=False) 57 | ) 58 | rows[next_row_id] = subsample 59 | next_row_id += 1 60 | del rows[row_id] 61 | 62 | # create items list 63 | items = OrderedDict() 64 | for row_id in list(rows.keys()): 65 | for item_id in rows[row_id]: 66 | if item_id not in items.keys(): 67 | items[item_id] = [] 68 | items[item_id].append(row_id) 69 | 70 | # remove items that don't occur frequently. 71 | for item_id in list(items.keys()): 72 | if len(items[item_id]) < freq_threshold: 73 | for row_id in items[item_id]: 74 | if row_id in rows.keys(): 75 | del rows[row_id] 76 | del items[item_id] 77 | 78 | return rows, items 79 | 80 | 81 | def split_and_save(dirname, rows, items): 82 | D = np.array(list(rows.values()), dtype=int) 83 | total_size = D.shape[0] 84 | train_size = int(np.floor(total_size * 0.8)) 85 | val_size = int(np.floor(total_size * 0.1)) 86 | test_size = int(total_size - train_size - val_size) 87 | print(f"Splitting into {train_size} train, {val_size} val, and {test_size} test") 88 | perm = np.random.permutation(total_size) 89 | train_set = D[perm[:train_size]] 90 | val_set = D[perm[train_size : train_size + val_size]] 91 | test_set = D[perm[train_size + val_size :]] 92 | assert len(np.unique(train_set)) == len( 93 | items 94 | ), "train set does not contain all items" 95 | np.savetxt( 96 | os.path.join(dirname, "train.data"), 97 | train_set.astype(int), 98 | fmt="%i", 99 | delimiter=",", 100 | ) 101 | np.savetxt( 102 | os.path.join(dirname, "val.data"), val_set.astype(int), fmt="%i", delimiter="," 103 | ) 104 | np.savetxt( 105 | os.path.join(dirname, "test.data"), 106 | test_set.astype(int), 107 | fmt="%i", 108 | delimiter=",", 109 | ) 110 | 111 | 112 | if __name__ == "__main__": 113 | os.makedirs("retail", exist_ok=True) 114 | os.makedirs("accidents", exist_ok=True) 115 | 116 | print("Downloading datasets.") 117 | wget.download( 118 | "http://fimi.uantwerpen.be/data/retail.dat", out="retail", 119 | ) 120 | wget.download( 121 | "http://fimi.uantwerpen.be/data/accidents.dat", out="accidents", 122 | ) 123 | 124 | np.random.seed(123) 125 | rows, items = prune_itemset( 126 | "retail/retail.dat", vocab_size=4, freq_threshold=300, subsample_more=True 127 | ) 128 | print(len(rows), len(items)) 129 | split_and_save("retail", rows, items) 130 | 131 | np.random.seed(234) 132 | rows, items = prune_itemset( 133 | "accidents/accidents.dat", 134 | vocab_size=4, 135 | freq_threshold=100, 136 | subsample_more=False, 137 | ) 138 | print(len(rows), len(items)) 139 | split_and_save("accidents", rows, items) 140 | -------------------------------------------------------------------------------- /datasets/LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017, George Papamakarios 2 | All rights reserved. 
3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | 1. Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 2. Redistributions in binary form must reproduce the above copyright notice, 10 | this list of conditions and the following disclaimer in the documentation 11 | and/or other materials provided with the distribution. 12 | 13 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 14 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 15 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 16 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 17 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 18 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 19 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 20 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 21 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 22 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 23 | 24 | The views and conclusions contained in the software and documentation are those 25 | of the authors and should not be interpreted as representing official policies, 26 | either expressed or implied, of anybody else. 27 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | root = 'data/' 2 | 3 | from .power import POWER 4 | from .gas import GAS 5 | from .hepmass import HEPMASS 6 | from .miniboone import MINIBOONE 7 | from .bsds300 import BSDS300 8 | -------------------------------------------------------------------------------- /datasets/bsds300.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import h5py 3 | 4 | import datasets 5 | 6 | 7 | class BSDS300: 8 | """ 9 | A dataset of patches from BSDS300. 10 | """ 11 | 12 | class Data: 13 | """ 14 | Constructs the dataset. 
15 | """ 16 | 17 | def __init__(self, data): 18 | 19 | self.x = data[:] 20 | self.N = self.x.shape[0] 21 | 22 | def __init__(self): 23 | 24 | # load dataset 25 | f = h5py.File(datasets.root + 'BSDS300/BSDS300.hdf5', 'r') 26 | 27 | self.trn = self.Data(f['train']) 28 | self.val = self.Data(f['validation']) 29 | self.tst = self.Data(f['test']) 30 | 31 | self.n_dims = self.trn.x.shape[1] 32 | self.image_size = [int(np.sqrt(self.n_dims + 1))] * 2 33 | 34 | f.close() 35 | -------------------------------------------------------------------------------- /datasets/gas.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | 4 | import datasets 5 | 6 | 7 | class GAS: 8 | 9 | class Data: 10 | 11 | def __init__(self, data): 12 | 13 | self.x = data.astype(np.float32) 14 | self.N = self.x.shape[0] 15 | 16 | def __init__(self): 17 | 18 | file = datasets.root + 'gas/ethylene_CO.pickle' 19 | trn, val, tst = load_data_and_clean_and_split(file) 20 | 21 | self.trn = self.Data(trn) 22 | self.val = self.Data(val) 23 | self.tst = self.Data(tst) 24 | 25 | self.n_dims = self.trn.x.shape[1] 26 | 27 | 28 | def load_data(file): 29 | 30 | data = pd.read_pickle(file) 31 | # data = pd.read_pickle(file).sample(frac=0.25) 32 | # data.to_pickle(file) 33 | data.drop("Meth", axis=1, inplace=True) 34 | data.drop("Eth", axis=1, inplace=True) 35 | data.drop("Time", axis=1, inplace=True) 36 | return data 37 | 38 | 39 | def get_correlation_numbers(data): 40 | C = data.corr() 41 | A = C > 0.98 42 | B = A.values.sum(axis=1) 43 | return B 44 | 45 | 46 | def load_data_and_clean(file): 47 | 48 | data = load_data(file) 49 | B = get_correlation_numbers(data) 50 | 51 | while np.any(B > 1): 52 | col_to_remove = np.where(B > 1)[0][0] 53 | col_name = data.columns[col_to_remove] 54 | data.drop(col_name, axis=1, inplace=True) 55 | B = get_correlation_numbers(data) 56 | # print(data.corr()) 57 | data = (data - data.mean()) / data.std() 58 | 59 | return data 60 | 61 | 62 | def load_data_and_clean_and_split(file): 63 | 64 | data = load_data_and_clean(file).values 65 | N_test = int(0.1 * data.shape[0]) 66 | data_test = data[-N_test:] 67 | data_train = data[0:-N_test] 68 | N_validate = int(0.1 * data_train.shape[0]) 69 | data_validate = data_train[-N_validate:] 70 | data_train = data_train[0:-N_validate] 71 | 72 | return data_train, data_validate, data_test 73 | -------------------------------------------------------------------------------- /datasets/hepmass.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | from collections import Counter 4 | from os.path import join 5 | 6 | import datasets 7 | 8 | 9 | class HEPMASS: 10 | """ 11 | The HEPMASS data set. 
12 | http://archive.ics.uci.edu/ml/datasets/HEPMASS 13 | """ 14 | 15 | class Data: 16 | 17 | def __init__(self, data): 18 | 19 | self.x = data.astype(np.float32) 20 | self.N = self.x.shape[0] 21 | 22 | def __init__(self): 23 | 24 | path = datasets.root + 'hepmass/' 25 | trn, val, tst = load_data_no_discrete_normalised_as_array(path) 26 | 27 | self.trn = self.Data(trn) 28 | self.val = self.Data(val) 29 | self.tst = self.Data(tst) 30 | 31 | self.n_dims = self.trn.x.shape[1] 32 | 33 | 34 | def load_data(path): 35 | 36 | data_train = pd.read_csv(filepath_or_buffer=join(path, "1000_train.csv"), index_col=False) 37 | data_test = pd.read_csv(filepath_or_buffer=join(path, "1000_test.csv"), index_col=False) 38 | 39 | return data_train, data_test 40 | 41 | 42 | def load_data_no_discrete(path): 43 | """ 44 | Loads the positive class examples from the first 10 percent of the dataset. 45 | """ 46 | data_train, data_test = load_data(path) 47 | 48 | # Gets rid of any background noise examples i.e. class label 0. 49 | data_train = data_train[data_train[data_train.columns[0]] == 1] 50 | data_train = data_train.drop(data_train.columns[0], axis=1) 51 | data_test = data_test[data_test[data_test.columns[0]] == 1] 52 | data_test = data_test.drop(data_test.columns[0], axis=1) 53 | # Because the data set is messed up! 54 | data_test = data_test.drop(data_test.columns[-1], axis=1) 55 | 56 | return data_train, data_test 57 | 58 | 59 | def load_data_no_discrete_normalised(path): 60 | 61 | data_train, data_test = load_data_no_discrete(path) 62 | mu = data_train.mean() 63 | s = data_train.std() 64 | data_train = (data_train - mu) / s 65 | data_test = (data_test - mu) / s 66 | 67 | return data_train, data_test 68 | 69 | 70 | def load_data_no_discrete_normalised_as_array(path): 71 | 72 | data_train, data_test = load_data_no_discrete_normalised(path) 73 | data_train, data_test = data_train.values, data_test.values 74 | 75 | i = 0 76 | # Remove any features that have too many re-occurring real values. 77 | features_to_remove = [] 78 | for feature in data_train.T: 79 | c = Counter(feature) 80 | max_count = np.array([v for k, v in sorted(c.items())])[0] 81 | if max_count > 5: 82 | features_to_remove.append(i) 83 | i += 1 84 | data_train = data_train[:, np.array([i for i in range(data_train.shape[1]) if i not in features_to_remove])] 85 | data_test = data_test[:, np.array([i for i in range(data_test.shape[1]) if i not in features_to_remove])] 86 | 87 | N = data_train.shape[0] 88 | N_validate = int(N * 0.1) 89 | data_validate = data_train[-N_validate:] 90 | data_train = data_train[0:-N_validate] 91 | 92 | return data_train, data_validate, data_test 93 | -------------------------------------------------------------------------------- /datasets/miniboone.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import datasets 4 | 5 | 6 | class MINIBOONE: 7 | 8 | class Data: 9 | 10 | def __init__(self, data): 11 | 12 | self.x = data.astype(np.float32) 13 | self.N = self.x.shape[0] 14 | 15 | def __init__(self): 16 | 17 | file = datasets.root + 'miniboone/data.npy' 18 | trn, val, tst = load_data_normalised(file) 19 | 20 | self.trn = self.Data(trn) 21 | self.val = self.Data(val) 22 | self.tst = self.Data(tst) 23 | 24 | self.n_dims = self.trn.x.shape[1] 25 | 26 | 27 | def load_data(root_path): 28 | # NOTE: To remember how the pre-processing was done. 
29 | # data = pd.read_csv(root_path, names=[str(x) for x in range(50)], delim_whitespace=True) 30 | # print data.head() 31 | # data = data.as_matrix() 32 | # # Remove some random outliers 33 | # indices = (data[:, 0] < -100) 34 | # data = data[~indices] 35 | # 36 | # i = 0 37 | # # Remove any features that have too many re-occuring real values. 38 | # features_to_remove = [] 39 | # for feature in data.T: 40 | # c = Counter(feature) 41 | # max_count = np.array([v for k, v in sorted(c.iteritems())])[0] 42 | # if max_count > 5: 43 | # features_to_remove.append(i) 44 | # i += 1 45 | # data = data[:, np.array([i for i in range(data.shape[1]) if i not in features_to_remove])] 46 | # np.save("~/data/miniboone/data.npy", data) 47 | 48 | data = np.load(root_path) 49 | N_test = int(0.1 * data.shape[0]) 50 | data_test = data[-N_test:] 51 | data = data[0:-N_test] 52 | N_validate = int(0.1 * data.shape[0]) 53 | data_validate = data[-N_validate:] 54 | data_train = data[0:-N_validate] 55 | 56 | return data_train, data_validate, data_test 57 | 58 | 59 | def load_data_normalised(root_path): 60 | 61 | data_train, data_validate, data_test = load_data(root_path) 62 | data = np.vstack((data_train, data_validate)) 63 | mu = data.mean(axis=0) 64 | s = data.std(axis=0) 65 | data_train = (data_train - mu) / s 66 | data_validate = (data_validate - mu) / s 67 | data_test = (data_test - mu) / s 68 | 69 | return data_train, data_validate, data_test 70 | -------------------------------------------------------------------------------- /datasets/power.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import datasets 4 | 5 | 6 | class POWER: 7 | 8 | class Data: 9 | 10 | def __init__(self, data): 11 | 12 | self.x = data.astype(np.float32) 13 | self.N = self.x.shape[0] 14 | 15 | def __init__(self): 16 | 17 | trn, val, tst = load_data_normalised() 18 | 19 | self.trn = self.Data(trn) 20 | self.val = self.Data(val) 21 | self.tst = self.Data(tst) 22 | 23 | self.n_dims = self.trn.x.shape[1] 24 | 25 | 26 | def load_data(): 27 | return np.load(datasets.root + 'power/data.npy') 28 | 29 | 30 | def load_data_split_with_noise(): 31 | 32 | rng = np.random.RandomState(42) 33 | 34 | data = load_data() 35 | rng.shuffle(data) 36 | N = data.shape[0] 37 | 38 | data = np.delete(data, 3, axis=1) 39 | data = np.delete(data, 1, axis=1) 40 | ############################ 41 | # Add noise 42 | ############################ 43 | # global_intensity_noise = 0.1*rng.rand(N, 1) 44 | voltage_noise = 0.01 * rng.rand(N, 1) 45 | # grp_noise = 0.001*rng.rand(N, 1) 46 | gap_noise = 0.001 * rng.rand(N, 1) 47 | sm_noise = rng.rand(N, 3) 48 | time_noise = np.zeros((N, 1)) 49 | # noise = np.hstack((gap_noise, grp_noise, voltage_noise, global_intensity_noise, sm_noise, time_noise)) 50 | # noise = np.hstack((gap_noise, grp_noise, voltage_noise, sm_noise, time_noise)) 51 | noise = np.hstack((gap_noise, voltage_noise, sm_noise, time_noise)) 52 | data = data + noise 53 | 54 | N_test = int(0.1 * data.shape[0]) 55 | data_test = data[-N_test:] 56 | data = data[0:-N_test] 57 | N_validate = int(0.1 * data.shape[0]) 58 | data_validate = data[-N_validate:] 59 | data_train = data[0:-N_validate] 60 | 61 | return data_train, data_validate, data_test 62 | 63 | 64 | def load_data_normalised(): 65 | 66 | data_train, data_validate, data_test = load_data_split_with_noise() 67 | data = np.vstack((data_train, data_validate)) 68 | mu = data.mean(axis=0) 69 | s = data.std(axis=0) 70 | data_train = (data_train - mu) / s 71 
| data_validate = (data_validate - mu) / s 72 | data_test = (data_test - mu) / s 73 | 74 | return data_train, data_validate, data_test 75 | -------------------------------------------------------------------------------- /diagnostics/voronoi_plot_2d.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | 5 | This source code is licensed under the license found in the 6 | LICENSE file in the root directory of this source tree. 7 | """ 8 | 9 | import numpy as np 10 | import matplotlib.pyplot as plt 11 | 12 | 13 | def _adjust_bounds(ax, points): 14 | margin = 0.1 * points.ptp(axis=0) 15 | xy_min = points.min(axis=0) - margin 16 | xy_max = points.max(axis=0) + margin 17 | ax.set_xlim(xy_min[0], xy_max[0]) 18 | ax.set_ylim(xy_min[1], xy_max[1]) 19 | 20 | 21 | def voronoi_plot_2d(vor, ax=None, **kw): 22 | """ 23 | Plot the given Voronoi diagram in 2-D 24 | Parameters 25 | ---------- 26 | vor : scipy.spatial.Voronoi instance 27 | Diagram to plot 28 | ax : matplotlib.axes.Axes instance, optional 29 | Axes to plot on 30 | show_points: bool, optional 31 | Add the Voronoi points to the plot. 32 | show_vertices : bool, optional 33 | Add the Voronoi vertices to the plot. 34 | line_colors : string, optional 35 | Specifies the line color for polygon boundaries 36 | line_width : float, optional 37 | Specifies the line width for polygon boundaries 38 | line_alpha: float, optional 39 | Specifies the line alpha for polygon boundaries 40 | Returns 41 | ------- 42 | fig : matplotlib.figure.Figure instance 43 | Figure for the plot 44 | See Also 45 | -------- 46 | Voronoi 47 | Notes 48 | ----- 49 | Requires Matplotlib. 50 | """ 51 | from matplotlib.collections import LineCollection 52 | 53 | if ax is None: 54 | ax = plt.gca() 55 | 56 | if vor.points.shape[1] != 2: 57 | raise ValueError("Voronoi diagram is not 2-D") 58 | 59 | if kw.get("show_points", True): 60 | ax.plot(vor.points[:, 0], vor.points[:, 1], ".") 61 | if kw.get("show_vertices", True): 62 | ax.plot(vor.vertices[:, 0], vor.vertices[:, 1], "o") 63 | 64 | line_colors = kw.get("line_colors", "k") 65 | line_width = kw.get("line_width", 1.0) 66 | line_alpha = kw.get("line_alpha", 1.0) 67 | 68 | center = vor.points.mean(axis=0) 69 | ptp_bound = vor.points.ptp(axis=0) 70 | 71 | finite_segments = [] 72 | infinite_segments = [] 73 | for pointidx, simplex in zip(vor.ridge_points, vor.ridge_vertices): 74 | simplex = np.asarray(simplex) 75 | if np.all(simplex >= 0): 76 | finite_segments.append(vor.vertices[simplex]) 77 | else: 78 | i = simplex[simplex >= 0][0] # finite end Voronoi vertex 79 | 80 | t = vor.points[pointidx[1]] - vor.points[pointidx[0]] # tangent 81 | t /= np.linalg.norm(t) 82 | n = np.array([-t[1], t[0]]) # normal 83 | 84 | midpoint = vor.points[pointidx].mean(axis=0) 85 | direction = np.sign(np.dot(midpoint - center, n)) * n 86 | far_point = vor.vertices[i] + direction * ptp_bound.max() 87 | 88 | infinite_segments.append([vor.vertices[i], far_point]) 89 | 90 | ax.add_collection( 91 | LineCollection( 92 | finite_segments, 93 | colors=line_colors, 94 | lw=line_width, 95 | alpha=line_alpha, 96 | linestyle="solid", 97 | ) 98 | ) 99 | ax.add_collection( 100 | LineCollection( 101 | infinite_segments, 102 | colors=line_colors, 103 | lw=line_width, 104 | alpha=line_alpha, 105 | linestyle="solid", 106 | ) 107 | ) 108 | 109 | _adjust_bounds(ax, vor.points) 110 | 111 | return ax.figure 
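

# A minimal usage sketch (illustrative; assumes scipy is installed — the seed,
# point count and output filename below are arbitrary choices, not from this repo):
if __name__ == "__main__":
    from scipy.spatial import Voronoi

    rng = np.random.default_rng(0)
    pts = rng.normal(size=(16, 2))        # random 2-D generator points
    vor = Voronoi(pts)                    # scipy computes the Voronoi diagram
    fig = voronoi_plot_2d(vor, show_vertices=False, line_colors="gray")
    fig.savefig("voronoi_example.png")    # save the rendered tessellation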
-------------------------------------------------------------------------------- /layers/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | 5 | This source code is licensed under the license found in the 6 | LICENSE file in the root directory of this source tree. 7 | """ 8 | 9 | from .act_norm import * 10 | from .autoreg import * 11 | from .cnf import * 12 | from .container import * 13 | from .coupling import * 14 | from .elemwise import * 15 | from .iresblock import * 16 | from .normalization import * 17 | from .softmax import * 18 | from .squeeze import * 19 | from .glow import * 20 | -------------------------------------------------------------------------------- /layers/act_norm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | 5 | This source code is licensed under the license found in the 6 | LICENSE file in the root directory of this source tree. 7 | """ 8 | 9 | import torch 10 | import torch.nn as nn 11 | from torch.nn import Parameter 12 | 13 | __all__ = ["ActNorm1d", "ActNorm2d", "ConditionalAffine1d", "ConditionalAffine2d"] 14 | 15 | 16 | class ActNormNd(nn.Module): 17 | def __init__(self, num_features, eps=1e-12): 18 | super(ActNormNd, self).__init__() 19 | self.num_features = num_features 20 | self.eps = eps 21 | self.weight = Parameter(torch.Tensor(num_features)) 22 | self.bias = Parameter(torch.Tensor(num_features)) 23 | self.register_buffer("initialized", torch.tensor(0)) 24 | 25 | def shape(self, x): 26 | raise NotImplementedError 27 | 28 | def axis(self, x): 29 | return NotImplementedError 30 | 31 | def forward(self, x, *, logp=None, **kwargs): 32 | c = x.shape[self.axis(x)] 33 | 34 | if not self.initialized: 35 | with torch.no_grad(): 36 | # compute batch statistics 37 | x_t = x.transpose(0, self.axis(x)).reshape(c, -1) 38 | batch_mean = torch.mean(x_t, dim=1) 39 | batch_var = torch.var(x_t, dim=1) 40 | 41 | # for numerical issues 42 | batch_var = torch.max(batch_var, torch.tensor(0.2).to(batch_var)) 43 | 44 | self.bias.data.copy_(-batch_mean) 45 | self.weight.data.copy_(-0.5 * torch.log(batch_var)) 46 | self.initialized.fill_(1) 47 | 48 | bias = self.bias.view(*self.shape(x)).expand_as(x) 49 | weight = self.weight.view(*self.shape(x)).expand_as(x) 50 | 51 | y = (x + bias) * torch.exp(weight) 52 | 53 | if logp is None: 54 | return y 55 | else: 56 | return y, logp - self._logdetgrad(x) 57 | 58 | def inverse(self, y, logp=None, **kwargs): 59 | assert self.initialized 60 | bias = self.bias.view(*self.shape(y)).expand_as(y) 61 | weight = self.weight.view(*self.shape(y)).expand_as(y) 62 | 63 | x = y * torch.exp(-weight) - bias 64 | 65 | if logp is None: 66 | return x 67 | else: 68 | return x, logp + self._logdetgrad(x) 69 | 70 | def _logdetgrad(self, x): 71 | weight = self.weight.view(*self.shape(x)).expand(*x.size()) 72 | return weight.reshape(x.shape[0], -1).sum(1, keepdim=True) 73 | 74 | def __repr__(self): 75 | return "{name}({num_features})".format( 76 | name=self.__class__.__name__, **self.__dict__ 77 | ) 78 | 79 | 80 | class ActNorm1d(ActNormNd): 81 | def shape(self, x): 82 | return [1] * (x.ndim - 1) + [-1] 83 | 84 | def axis(self, x): 85 | return x.ndim - 1 86 | 87 | 88 | class ActNorm2d(ActNormNd): 89 | def shape(self, x): 90 | return [1, -1, 1, 1] 91 | 92 | def axis(self, x): 93 | return 1 94 | 95 | 96 | 
class ConditionalAffineNd(nn.Module): 97 | def __init__(self, num_features, nnet): 98 | super(ConditionalAffineNd, self).__init__() 99 | self.num_features = num_features 100 | self.nnet = nnet 101 | 102 | def shape(self, x): 103 | raise NotImplementedError 104 | 105 | def axis(self, x): 106 | return NotImplementedError 107 | 108 | def _get_params(self, x, cond): 109 | f = self.nnet(cond.reshape(x.shape[0], -1)) 110 | t = f[:, : self.num_features] 111 | s = f[:, self.num_features :] 112 | 113 | s = torch.sigmoid(s) * 0.98 + 0.01 114 | 115 | t = t.reshape(*self.shape(x)).expand_as(x) 116 | s = s.reshape(*self.shape(x)).expand_as(x) 117 | 118 | s = torch.sigmoid(s) * 0.98 + 0.01 119 | 120 | return t, s 121 | 122 | def forward(self, x, *, logp=None, cond=None, **kwargs): 123 | assert cond is not None, "This module only works when cond is provided." 124 | 125 | # Ehhh... 126 | if cond.ndim == 4: 127 | cond = cond[:, :, 0, 0] 128 | 129 | t, s = self._get_params(x, cond) 130 | y = x * s + t * (1 - s) 131 | 132 | if logp is None: 133 | return y 134 | else: 135 | logpy = logp - self._logdetgrad(s) 136 | return y, logpy 137 | 138 | def inverse(self, y, logp=None, cond=None, **kwargs): 139 | if cond.ndim == 4: 140 | cond = cond[:, :, 0, 0] 141 | 142 | t, s = self._get_params(y, cond) 143 | x = (y - t * (1 - s)) / s 144 | 145 | if logp is None: 146 | return x 147 | else: 148 | return x, logp + self._logdetgrad(s) 149 | 150 | def _logdetgrad(self, s): 151 | log_s = torch.log(s) 152 | return log_s.reshape(log_s.shape[0], -1).sum(1, keepdim=True) 153 | 154 | def __repr__(self): 155 | return "{name}({num_features})".format( 156 | name=self.__class__.__name__, **self.__dict__ 157 | ) 158 | 159 | 160 | class ConditionalAffine1d(ConditionalAffineNd): 161 | def shape(self, x): 162 | return [x.shape[0]] + [1] * (x.ndim - 2) + [-1] 163 | 164 | def axis(self, x): 165 | return x.ndim - 1 166 | 167 | 168 | class ConditionalAffine2d(ConditionalAffineNd): 169 | def shape(self, x): 170 | return [x.shape[0], -1, 1, 1] 171 | 172 | def axis(self, x): 173 | return 1 -------------------------------------------------------------------------------- /layers/autoreg.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | 5 | This source code is licensed under the license found in the 6 | LICENSE file in the root directory of this source tree. 7 | """ 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | __all__ = ["AffineAutoregressive"] 13 | 14 | 15 | class AffineAutoregressive(nn.Module): 16 | def __init__(self, nnet, split_dim=-1, cond_dim=-1): 17 | nn.Module.__init__(self) 18 | self.nnet = nnet 19 | self.split_dim = ( 20 | split_dim # used for splitting the output into scale and shift. 
21 | ) 22 | self.cond_dim = cond_dim 23 | 24 | def func_s_t(self, x, cond=None, **kwargs): 25 | split_dim = self.split_dim % x.ndim 26 | input_shape = x.shape 27 | 28 | if cond is not None: 29 | x = torch.cat([x, cond], dim=self.cond_dim) 30 | 31 | f = self.nnet(x) 32 | 33 | f = f.reshape(*input_shape[:split_dim], 2, *input_shape[split_dim:]) 34 | s, t = f.split((1, 1), dim=split_dim) 35 | s = s.squeeze(split_dim) 36 | t = t.squeeze(split_dim) 37 | 38 | s = torch.sigmoid(s) * 0.98 + 0.01 39 | 40 | return s, t 41 | 42 | def forward(self, x, **kwargs): 43 | logp = kwargs.pop("logp", None) 44 | 45 | s, t = self.func_s_t(x, **kwargs) 46 | y = x * s + t * (1 - s) 47 | 48 | if logp is None: 49 | return y 50 | else: 51 | logpy = logp - self._logdetgrad(s) 52 | return y, logpy 53 | 54 | def inverse(self, y, **kwargs): 55 | raise NotImplementedError 56 | 57 | def _logdetgrad(self, s): 58 | masked_s = safe_log(s) 59 | return masked_s.reshape(s.shape[0], -1).sum(1, keepdim=True) 60 | 61 | def extra_repr(self): 62 | return "split_dim={split_dim}, cond_dim={cond_dim}".format(**self.__dict__) 63 | 64 | 65 | class Reorder(nn.Module): 66 | def __init__(self, dim, perm_dim=1): 67 | super().__init__() 68 | self.dim = dim 69 | self.perm_dim = perm_dim 70 | self.register_buffer("randperm", torch.randperm(dim)) 71 | self.register_buffer("invperm", torch.argsort(self.randperm)) 72 | 73 | def forward(self, x, logp, **kwargs): 74 | y = torch.index_select(x, self.perm_dim, self.randperm) 75 | if logp is None: 76 | return y 77 | else: 78 | return y, logp 79 | 80 | def inverse(self, y, logp, **kwargs): 81 | x = torch.index_select(y, self.perm_dim, self.invperm) 82 | if logp is None: 83 | return x 84 | else: 85 | return x, logp 86 | 87 | def extra_repr(self): 88 | return "dim={dim}, perm_dim={perm_dim}".format(**self.__dict__) 89 | 90 | 91 | def safe_log(x): 92 | return torch.log(x.clamp(min=1e-22)) 93 | -------------------------------------------------------------------------------- /layers/base/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | 5 | This source code is licensed under the license found in the 6 | LICENSE file in the root directory of this source tree. 7 | """ 8 | 9 | from .activations import * 10 | from .lipschitz import * 11 | from .mixed_lipschitz import * 12 | -------------------------------------------------------------------------------- /layers/base/activations.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | 5 | This source code is licensed under the license found in the 6 | LICENSE file in the root directory of this source tree. 
7 | """ 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | class Identity(nn.Module): 15 | 16 | def forward(self, x): 17 | return x 18 | 19 | 20 | class FullSort(nn.Module): 21 | 22 | def forward(self, x): 23 | return torch.sort(x, 1)[0] 24 | 25 | 26 | class MaxMin(nn.Module): 27 | 28 | def forward(self, x): 29 | b, d = x.shape 30 | max_vals = torch.max(x.view(b, d // 2, 2), 2)[0] 31 | min_vals = torch.min(x.view(b, d // 2, 2), 2)[0] 32 | return torch.cat([max_vals, min_vals], 1) 33 | 34 | 35 | class LipschitzCube(nn.Module): 36 | 37 | def forward(self, x): 38 | return (x >= 1).to(x) * (x - 2 / 3) + (x <= -1).to(x) * (x + 2 / 3) + ((x > -1) * (x < 1)).to(x) * x**3 / 3 39 | 40 | 41 | class SwishFn(torch.autograd.Function): 42 | 43 | @staticmethod 44 | def forward(ctx, x, beta): 45 | beta_sigm = torch.sigmoid(beta * x) 46 | output = x * beta_sigm 47 | ctx.save_for_backward(x, output, beta) 48 | return output / 1.1 49 | 50 | @staticmethod 51 | def backward(ctx, grad_output): 52 | x, output, beta = ctx.saved_tensors 53 | beta_sigm = output / x 54 | grad_x = grad_output * (beta * output + beta_sigm * (1 - beta * output)) 55 | grad_beta = torch.sum(grad_output * (x * output - output * output)).expand_as(beta) 56 | return grad_x / 1.1, grad_beta / 1.1 57 | 58 | 59 | class Swish(nn.Module): 60 | 61 | def __init__(self): 62 | super(Swish, self).__init__() 63 | self.beta = nn.Parameter(torch.tensor([0.5])) 64 | 65 | def forward(self, x): 66 | return (x * torch.sigmoid_(x * F.softplus(self.beta))).div_(1.1) 67 | 68 | 69 | LipSwish = Swish 70 | 71 | 72 | if __name__ == '__main__': 73 | 74 | m = Swish() 75 | xx = torch.linspace(-5, 5, 1000).requires_grad_(True) 76 | yy = m(xx) 77 | dd, dbeta = torch.autograd.grad(yy.sum() * 2, [xx, m.beta]) 78 | 79 | import matplotlib.pyplot as plt 80 | 81 | plt.plot(xx.detach().numpy(), yy.detach().numpy(), label='Func') 82 | plt.plot(xx.detach().numpy(), dd.detach().numpy(), label='Deriv') 83 | plt.plot(xx.detach().numpy(), torch.max(dd.detach().abs() - 1, torch.zeros_like(dd)).numpy(), label='|Deriv| > 1') 84 | plt.legend() 85 | plt.tight_layout() 86 | plt.show() 87 | -------------------------------------------------------------------------------- /layers/base/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | 5 | This source code is licensed under the license found in the 6 | LICENSE file in the root directory of this source tree. 7 | """ 8 | 9 | from collections import abc as container_abcs 10 | from itertools import repeat 11 | 12 | 13 | def _ntuple(n): 14 | 15 | def parse(x): 16 | if isinstance(x, container_abcs.Iterable): 17 | return x 18 | return tuple(repeat(x, n)) 19 | 20 | return parse 21 | 22 | 23 | _single = _ntuple(1) 24 | _pair = _ntuple(2) 25 | _triple = _ntuple(3) 26 | _quadruple = _ntuple(4) 27 | -------------------------------------------------------------------------------- /layers/cnf.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | 5 | This source code is licensed under the license found in the 6 | LICENSE file in the root directory of this source tree. 7 | """ 8 | 9 | from re import X 10 | import numpy as np 11 | import torch 12 | import torch.nn as nn 13 | from torchdiffeq import odeint_adjoint 14 | 15 | from . 
import diffeq_layers 16 | from .base.activations import Swish 17 | 18 | 19 | def divergence_bf(f, y, training, **unused_kwargs): 20 | sum_diag = 0.0 21 | for i in range(f.shape[1]): 22 | retain_graph = training or i < (f.shape[1] - 1) 23 | sum_diag += ( 24 | torch.autograd.grad( 25 | f[:, i].sum(), y, create_graph=training, retain_graph=retain_graph 26 | )[0] 27 | .contiguous()[:, i] 28 | .contiguous() 29 | ) 30 | return sum_diag.contiguous() 31 | 32 | 33 | def divergence_approx(f, y, training, e=None, **unused_kwargs): 34 | assert e is not None 35 | dim = f.shape[1] 36 | e_dzdx = torch.autograd.grad(f, y, e, create_graph=training, retain_graph=training)[ 37 | 0 38 | ][:, :dim].contiguous() 39 | e_dzdx_e = e_dzdx * e 40 | approx_tr_dzdx = e_dzdx_e.view(y.shape[0], -1).sum(dim=1) 41 | return approx_tr_dzdx 42 | 43 | 44 | def rms_norm(tensor): 45 | return tensor.pow(2).mean().sqrt() 46 | 47 | 48 | class CNF(nn.Module): 49 | 50 | start_time = 0.0 51 | end_time = 1.0 52 | 53 | def __init__( 54 | self, 55 | func, 56 | divergence_fn=divergence_approx, 57 | rtol=1e-5, 58 | atol=1e-5, 59 | method="dopri5", 60 | fast_adjoint=True, 61 | cond_dim=None, 62 | nonself_connections=False, 63 | ): 64 | super().__init__() 65 | self.func = func 66 | self.divergence_fn = divergence_fn 67 | self.rtol = rtol 68 | self.atol = atol 69 | self.method = method 70 | self.fast_adjoint = fast_adjoint 71 | self.cond_dim = cond_dim 72 | self.nonself_connections = nonself_connections 73 | 74 | self.nfe = 0 75 | 76 | def reset_nfe_ts(self): 77 | self.nfe = 0 78 | 79 | def forward(self, x, *, logp=None, cond=None, **kwargs): 80 | 81 | if cond is not None: 82 | assert self.cond_dim is not None 83 | 84 | e = torch.randn_like(x) 85 | 86 | options = {} 87 | adjoint_kwargs = {} 88 | adjoint_kwargs["adjoint_params"] = self.parameters() 89 | 90 | if logp is None: 91 | logp = torch.zeros(x.shape[0], 1, device=x.device) 92 | 93 | if cond is None: 94 | initial_state = (e, x, logp) 95 | else: 96 | initial_state = (e, x, cond, logp) 97 | 98 | if self.fast_adjoint: 99 | adjoint_kwargs["adjoint_options"] = {"norm": "seminorm"} 100 | 101 | solution = odeint_adjoint( 102 | self.diffeq, 103 | initial_state, 104 | torch.tensor([self.start_time, self.end_time]).to(x), 105 | rtol=self.rtol, 106 | atol=self.atol, 107 | method=self.method, 108 | options=options, 109 | **adjoint_kwargs, 110 | ) 111 | if cond is None: 112 | _, y, logpy = tuple(s[-1] for s in solution) 113 | else: 114 | _, y, _, logpy = tuple(s[-1] for s in solution) 115 | 116 | if logp is None: 117 | return y 118 | else: 119 | return y, logpy 120 | 121 | def diffeq(self, t, state): 122 | self.nfe += 1 123 | 124 | if self.cond_dim is None: 125 | e, x, _ = state 126 | else: 127 | e, x, cond, _ = state 128 | 129 | with torch.enable_grad(): 130 | x = x.clone().requires_grad_(True) 131 | 132 | if self.cond_dim is None: 133 | inputs = x 134 | else: 135 | inputs = torch.cat([x, cond], dim=self.cond_dim) 136 | 137 | dx = self.func(t, inputs) 138 | 139 | if self.nonself_connections: 140 | dx_div = self.func(t, inputs, rm_nonself_grads=True) 141 | else: 142 | dx_div = dx 143 | 144 | # Use brute force trace for testing if 2D. 
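# Note (added for exposition, not in the original source): divergence_bf
# computes the exact trace of the Jacobian of dx with one backward pass per
# dimension, while divergence_approx uses Hutchinson's stochastic estimator
# tr(J) ~= E_e[e^T J e] with a single vector-Jacobian product. The probe
# vector e is sampled once in forward() and carried through the ODE state, so
# the same noise is reused at every function evaluation along the trajectory.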
145 | dim = np.prod(x.shape[1:]) 146 | if not self.training and dim <= 2: 147 | div = divergence_bf(dx_div, x, self.training) 148 | else: 149 | div = self.divergence_fn(dx_div, x, self.training, e=e) 150 | 151 | if not self.training: 152 | dx = dx.detach() 153 | div = div.detach() 154 | 155 | if self.cond_dim is None: 156 | return torch.zeros_like(e), dx, -div.reshape(-1, 1) 157 | else: 158 | return torch.zeros_like(e), dx, torch.zeros_like(cond), -div.reshape(-1, 1) 159 | 160 | def extra_repr(self): 161 | return f"method={self.method}, cond_dim={self.cond_dim}, rtol={self.rtol}, atol={self.atol}, fast_adjoint={self.fast_adjoint}" 162 | 163 | 164 | def cnf_block_fn( 165 | i, 166 | input_size, 167 | fc=False, 168 | idim=64, 169 | zero_init=False, 170 | depth=4, 171 | actfn="softplus", 172 | cond_embed_dim=0, 173 | **kwargs, 174 | ): 175 | del i 176 | 177 | actfns = { 178 | "softplus": nn.Softplus, 179 | "swish": Swish, 180 | } 181 | 182 | layer_fn = diffeq_layers.ConcatLinear if fc else diffeq_layers.ConcatConv2d 183 | 184 | dim = np.prod(input_size) if fc else input_size[0] 185 | 186 | if depth > 1: 187 | in_dims = [dim + cond_embed_dim] + [idim] * (depth - 1) 188 | out_dims = [idim] * (depth - 1) + [dim] 189 | layers = [] 190 | for d_in, d_out in zip(in_dims, out_dims): 191 | layers.append(layer_fn(d_in, d_out)) 192 | layers.append(actfns[actfn]()) 193 | layers = layers[:-1] # remove last actfn. 194 | else: 195 | layers = [layer_fn(dim + cond_embed_dim, dim)] 196 | 197 | if zero_init: 198 | layers[-1]._layer.weight.data.fill_(0) 199 | 200 | net = diffeq_layers.SequentialDiffEq(*layers) 201 | 202 | return CNF(net, **kwargs) 203 | -------------------------------------------------------------------------------- /layers/container.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | 5 | This source code is licensed under the license found in the 6 | LICENSE file in the root directory of this source tree. 7 | """ 8 | 9 | import torch.nn as nn 10 | 11 | 12 | class SequentialFlow(nn.Module): 13 | """A generalized nn.Sequential container for normalizing flows. 
14 | """ 15 | 16 | def __init__(self, layersList): 17 | super(SequentialFlow, self).__init__() 18 | self.chain = nn.ModuleList(layersList) 19 | 20 | def forward(self, x, **kwargs): 21 | logp = kwargs.pop("logp", None) 22 | if logp is None: 23 | for i in range(len(self.chain)): 24 | x = self.chain[i](x, **kwargs) 25 | return x 26 | else: 27 | for i in range(len(self.chain)): 28 | x, logp = self.chain[i](x, logp=logp, **kwargs) 29 | return x, logp 30 | 31 | def inverse(self, y, **kwargs): 32 | logp = kwargs.pop("logp", None) 33 | if logp is None: 34 | for i in range(len(self.chain) - 1, -1, -1): 35 | y = self.chain[i].inverse(y, **kwargs) 36 | return y 37 | else: 38 | for i in range(len(self.chain) - 1, -1, -1): 39 | y, logp = self.chain[i].inverse(y, logp=logp, **kwargs) 40 | return y, logp 41 | 42 | 43 | class Inverse(nn.Module): 44 | 45 | def __init__(self, flow): 46 | super(Inverse, self).__init__() 47 | self.flow = flow 48 | 49 | def forward(self, x, **kwargs): 50 | return self.flow.inverse(x, **kwargs) 51 | 52 | def inverse(self, y, **kwargs): 53 | return self.flow.forward(y, **kwargs) 54 | 55 | 56 | class Lambda(nn.Module): 57 | 58 | def __init__(self, forward_fn, inverse_fn): 59 | super(Lambda, self).__init__() 60 | self.forward_fn = forward_fn 61 | self.inverse_fn = inverse_fn 62 | 63 | def forward(self, x, logp=None): 64 | y = self.forward_fn(x) 65 | if logp is None: 66 | return y 67 | else: 68 | return y, logp 69 | 70 | def inverse(self, y, logp=None): 71 | x = self.inverse_fn(y) 72 | if logp is None: 73 | return x 74 | else: 75 | return x, logp 76 | -------------------------------------------------------------------------------- /layers/diffeq_layers/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | 5 | This source code is licensed under the license found in the 6 | LICENSE file in the root directory of this source tree. 7 | """ 8 | 9 | from .container import * 10 | from .resnet import * 11 | from .normalization import * 12 | from .basic import * 13 | from .wrappers import * 14 | -------------------------------------------------------------------------------- /layers/diffeq_layers/container.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | 5 | This source code is licensed under the license found in the 6 | LICENSE file in the root directory of this source tree. 7 | """ 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | from .wrappers import diffeq_wrapper 13 | 14 | 15 | class SequentialDiffEq(nn.Module): 16 | """A container for a sequential chain of layers. Supports both regular and diffeq layers. 17 | """ 18 | 19 | def __init__(self, *layers): 20 | super(SequentialDiffEq, self).__init__() 21 | self.layers = nn.ModuleList([diffeq_wrapper(layer) for layer in layers]) 22 | 23 | def forward(self, t, x): 24 | for layer in self.layers: 25 | x = layer(t, x) 26 | return x 27 | 28 | 29 | class MixtureODELayer(nn.Module): 30 | """Produces a mixture of experts where output = sigma(t) * f(t, x). 31 | Time-dependent weights sigma(t) help learn to blend the experts without resorting to a highly stiff f. 32 | Supports both regular and diffeq experts. 
33 | """ 34 | 35 | def __init__(self, experts): 36 | super(MixtureODELayer, self).__init__() 37 | assert len(experts) > 1 38 | wrapped_experts = [diffeq_wrapper(ex) for ex in experts] 39 | self.experts = nn.ModuleList(wrapped_experts) 40 | self.mixture_weights = nn.Linear(1, len(self.experts)) 41 | 42 | def forward(self, t, y): 43 | dys = [] 44 | for f in self.experts: 45 | dys.append(f(t, y)) 46 | dys = torch.stack(dys, 0) 47 | weights = self.mixture_weights(t).view(-1, *([1] * (dys.ndimension() - 1))) 48 | 49 | dy = torch.sum(dys * weights, dim=0, keepdim=False) 50 | return dy 51 | -------------------------------------------------------------------------------- /layers/diffeq_layers/resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | 5 | This source code is licensed under the license found in the 6 | LICENSE file in the root directory of this source tree. 7 | """ 8 | 9 | import torch.nn as nn 10 | 11 | from . import basic 12 | from . import container 13 | 14 | NGROUPS = 16 15 | 16 | 17 | class ResNet(container.SequentialDiffEq): 18 | def __init__(self, dim, intermediate_dim, n_resblocks, conv_block=None): 19 | super(ResNet, self).__init__() 20 | 21 | if conv_block is None: 22 | conv_block = basic.ConcatCoordConv2d 23 | 24 | self.dim = dim 25 | self.intermediate_dim = intermediate_dim 26 | self.n_resblocks = n_resblocks 27 | 28 | layers = [] 29 | layers.append(conv_block(dim, intermediate_dim, ksize=3, stride=1, padding=1, bias=False)) 30 | for _ in range(n_resblocks): 31 | layers.append(BasicBlock(intermediate_dim, conv_block)) 32 | layers.append(nn.GroupNorm(NGROUPS, intermediate_dim, eps=1e-4)) 33 | layers.append(nn.ReLU(inplace=True)) 34 | layers.append(conv_block(intermediate_dim, dim, ksize=1, bias=False)) 35 | 36 | super(ResNet, self).__init__(*layers) 37 | 38 | def __repr__(self): 39 | return ( 40 | '{name}({dim}, intermediate_dim={intermediate_dim}, n_resblocks={n_resblocks})'.format( 41 | name=self.__class__.__name__, **self.__dict__ 42 | ) 43 | ) 44 | 45 | 46 | class BasicBlock(nn.Module): 47 | expansion = 1 48 | 49 | def __init__(self, dim, conv_block=None): 50 | super(BasicBlock, self).__init__() 51 | 52 | if conv_block is None: 53 | conv_block = basic.ConcatCoordConv2d 54 | 55 | self.norm1 = nn.GroupNorm(NGROUPS, dim, eps=1e-4) 56 | self.relu1 = nn.ReLU(inplace=True) 57 | self.conv1 = conv_block(dim, dim, ksize=3, stride=1, padding=1, bias=False) 58 | self.norm2 = nn.GroupNorm(NGROUPS, dim, eps=1e-4) 59 | self.relu2 = nn.ReLU(inplace=True) 60 | self.conv2 = conv_block(dim, dim, ksize=3, stride=1, padding=1, bias=False) 61 | 62 | def forward(self, t, x): 63 | residual = x 64 | 65 | out = self.norm1(x) 66 | out = self.relu1(out) 67 | out = self.conv1(t, out) 68 | 69 | out = self.norm2(out) 70 | out = self.relu2(out) 71 | out = self.conv2(t, out) 72 | 73 | out += residual 74 | 75 | return out 76 | -------------------------------------------------------------------------------- /layers/diffeq_layers/wrappers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | 5 | This source code is licensed under the license found in the 6 | LICENSE file in the root directory of this source tree. 
7 | """ 8 | 9 | from inspect import signature 10 | import torch.nn as nn 11 | 12 | __all__ = ["diffeq_wrapper", "reshape_wrapper"] 13 | 14 | 15 | class DiffEqWrapper(nn.Module): 16 | def __init__(self, module): 17 | super(DiffEqWrapper, self).__init__() 18 | self.module = module 19 | 20 | def forward(self, t, y): 21 | if len(signature(self.module.forward).parameters) == 1: 22 | return self.module(y) 23 | elif len(signature(self.module.forward).parameters) == 2: 24 | return self.module(t, y) 25 | else: 26 | raise ValueError("Differential equation needs to either take (t, y) or (y,) as input.") 27 | 28 | def __repr__(self): 29 | return self.module.__repr__() 30 | 31 | 32 | def diffeq_wrapper(layer): 33 | return DiffEqWrapper(layer) 34 | 35 | 36 | class ReshapeDiffEq(nn.Module): 37 | def __init__(self, input_shape, net): 38 | super(ReshapeDiffEq, self).__init__() 39 | assert len(signature(net.forward).parameters) == 2, "use diffeq_wrapper before reshape_wrapper." 40 | self.input_shape = input_shape 41 | self.net = net 42 | 43 | def forward(self, t, x): 44 | batchsize = x.shape[0] 45 | x = x.view(batchsize, *self.input_shape) 46 | return self.net(t, x).view(batchsize, -1) 47 | 48 | def __repr__(self): 49 | return self.diffeq.__repr__() 50 | 51 | 52 | def reshape_wrapper(input_shape, layer): 53 | return ReshapeDiffEq(input_shape, layer) 54 | -------------------------------------------------------------------------------- /layers/elemwise.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | 5 | This source code is licensed under the license found in the 6 | LICENSE file in the root directory of this source tree. 7 | """ 8 | 9 | import math 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | 15 | class ZeroMeanTransform(nn.Module): 16 | 17 | def __init__(self): 18 | nn.Module.__init__(self) 19 | 20 | def forward(self, x, logp=None): 21 | x = x - .5 22 | if logp is None: 23 | return x 24 | return x, logp 25 | 26 | def inverse(self, y, logp=None): 27 | y = y + .5 28 | if logp is None: 29 | return y 30 | return y, logp 31 | 32 | 33 | class Normalize(nn.Module): 34 | 35 | def __init__(self, mean, std): 36 | nn.Module.__init__(self) 37 | self.register_buffer('mean', torch.as_tensor(mean, dtype=torch.float32)) 38 | self.register_buffer('std', torch.as_tensor(std, dtype=torch.float32)) 39 | 40 | def forward(self, x, logp=None): 41 | y = x.clone() 42 | c = len(self.mean) 43 | y[:, :c].sub_(self.mean[None, :, None, None]).div_(self.std[None, :, None, None]) 44 | if logp is None: 45 | return y 46 | else: 47 | return y, logp - self._logdetgrad(x) 48 | 49 | def inverse(self, y, logp=None): 50 | x = y.clone() 51 | c = len(self.mean) 52 | x[:, :c].mul_(self.std[None, :, None, None]).add_(self.mean[None, :, None, None]) 53 | if logp is None: 54 | return x 55 | else: 56 | return x, logp + self._logdetgrad(x) 57 | 58 | def _logdetgrad(self, x): 59 | logdetgrad = ( 60 | self.std.abs().log().mul_(-1).view(1, -1, 1, 1).expand(x.shape[0], len(self.std), x.shape[2], x.shape[3]) 61 | ) 62 | return logdetgrad.reshape(x.shape[0], -1).sum(-1, keepdim=True) 63 | 64 | 65 | class LogitTransform(nn.Module): 66 | """ 67 | The proprocessing step used in Real NVP: 68 | y = (sigmoid(x) - a) / (1 - 2a) 69 | x = logit(a + (1 - 2a)*y) 70 | """ 71 | 72 | def __init__(self, alpha=1e-6): 73 | nn.Module.__init__(self) 74 | self.alpha = alpha 75 | 76 | def forward(self, x, 
logp=None): 77 | s = self.alpha + (1 - 2 * self.alpha) * x 78 | y = safe_log(s) - safe_log(1 - s) 79 | if logp is None: 80 | return y 81 | return y, logp - self._logdetgrad(x) 82 | 83 | def inverse(self, y, logp=None): 84 | x = (torch.sigmoid(y) - self.alpha) / (1 - 2 * self.alpha) 85 | if logp is None: 86 | return x 87 | return x, logp + self._logdetgrad(x) 88 | 89 | def _logdetgrad(self, x): 90 | s = self.alpha + (1 - 2 * self.alpha) * x 91 | logdetgrad = -safe_log(s - s * s) + math.log(1 - 2 * self.alpha) 92 | logdetgrad = logdetgrad.view(x.size(0), -1).sum(1, keepdim=True) 93 | return logdetgrad 94 | 95 | def __repr__(self): 96 | return ('{name}({alpha})'.format(name=self.__class__.__name__, **self.__dict__)) 97 | 98 | 99 | def safe_log(x): 100 | return torch.log(x.clamp(min=1e-22)) 101 | 102 | 103 | class Softplus(nn.Module): 104 | def __init__(self, eps=1e-7): 105 | super().__init__() 106 | self.eps = eps 107 | 108 | def forward(self, x, logp=None): 109 | ''' 110 | z = softplus(x) = log(1+exp(z)) 111 | ldj = log(dsoftplus(x)/dx) = log(1/(1+exp(-x))) = log(sigmoid(x)) 112 | ''' 113 | z = F.softplus(x) 114 | ldj = F.logsigmoid(x).reshape(x.shape[0], -1).sum(1, keepdim=True) 115 | 116 | if logp is None: 117 | return z 118 | else: 119 | ldj = F.logsigmoid(x).reshape(x.shape[0], -1).sum(1, keepdim=True) 120 | return z, logp - ldj 121 | 122 | def inverse(self, z, logp=None): 123 | '''x = softplus_inv(z) = log(exp(z)-1) = z + log(1-exp(-z))''' 124 | zc = z.clamp(self.eps) 125 | x = z + torch.log1p(-torch.exp(-zc)) 126 | 127 | if logp is None: 128 | return x 129 | else: 130 | ldj = -F.logsigmoid(x).reshape(x.shape[0], -1).sum(1, keepdim=True) 131 | return x, logp - ldj -------------------------------------------------------------------------------- /layers/glow.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | 5 | This source code is licensed under the license found in the 6 | LICENSE file in the root directory of this source tree. 7 | """ 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | 15 | class InvertibleLinear(nn.Module): 16 | 17 | def __init__(self, dim): 18 | super(InvertibleLinear, self).__init__() 19 | self.dim = dim 20 | w_init = torch.randn(dim, dim) 21 | w_init = np.linalg.qr(w_init.numpy())[0].astype(np.float32) 22 | self.weight = nn.Parameter(torch.from_numpy(w_init)) 23 | 24 | def forward(self, x, **kwargs): 25 | logp = kwargs.pop("logp", None) 26 | input_mask = kwargs.pop("input_mask", None) 27 | 28 | # Only apply to those with full input masks. 
29 | if input_mask is None: 30 | input_mask = torch.as_tensor(True).expand_as(x) 31 | input_mask = input_mask.all(dim=-1, keepdim=True) 32 | 33 | y = F.linear(x, self.weight) 34 | y = y * input_mask + x * ~input_mask 35 | if logp is None: 36 | return y 37 | else: 38 | return y, logp - self._logdetgrad(x, input_mask) 39 | 40 | def inverse(self, y, **kwargs): 41 | logp = kwargs.pop("logp", None) 42 | input_mask = kwargs.pop("input_mask", None) 43 | input_mask = input_mask.all(dim=-1, keepdim=True) 44 | x = F.linear(y, self.weight.double().inverse().float()) 45 | x = x * input_mask + y * ~input_mask 46 | if logp is None: 47 | return x 48 | else: 49 | return x, logp + self._logdetgrad(x, input_mask) 50 | 51 | def _logdetgrad(self, x, input_mask): 52 | nreps = input_mask.reshape(input_mask.shape[0], -1).sum(1, keepdim=True) 53 | return torch.slogdet(self.weight)[1] * nreps 54 | 55 | def extra_repr(self): 56 | return 'dim={}'.format(self.dim) 57 | 58 | 59 | class InvertibleConv2d(nn.Module): 60 | 61 | def __init__(self, dim): 62 | super(InvertibleConv2d, self).__init__() 63 | self.dim = dim 64 | w_init = torch.randn(dim, dim) 65 | w_init = np.linalg.qr(w_init.numpy())[0].astype(np.float32) 66 | self.weight = nn.Parameter(torch.from_numpy(w_init)) 67 | 68 | def forward(self, x, **kwargs): 69 | logp = kwargs.pop("logp", None) 70 | y = F.conv2d(x, self.weight.view(self.dim, self.dim, 1, 1)) 71 | if logp is None: 72 | return y 73 | else: 74 | return y, logp - self._logdetgrad * x.shape[2] * x.shape[3] 75 | 76 | def inverse(self, y, **kwargs): 77 | logp = kwargs.pop("logp", None) 78 | x = F.conv2d(y, self.weight.inverse().view(self.dim, self.dim, 1, 1)) 79 | if logp is None: 80 | return x 81 | else: 82 | return x, logp + self._logdetgrad * x.shape[2] * x.shape[3] 83 | 84 | @property 85 | def _logdetgrad(self): 86 | return torch.slogdet(self.weight)[1] 87 | 88 | def extra_repr(self): 89 | return 'dim={}'.format(self.dim) 90 | -------------------------------------------------------------------------------- /layers/normalization.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | 5 | This source code is licensed under the license found in the 6 | LICENSE file in the root directory of this source tree. 
7 | """ 8 | 9 | import torch 10 | import torch.nn as nn 11 | from torch.nn import Parameter 12 | 13 | __all__ = ['MovingBatchNorm1d', 'MovingBatchNorm2d'] 14 | 15 | 16 | class MovingBatchNormNd(nn.Module): 17 | 18 | def __init__(self, num_features, eps=1e-4, decay=0.1, bn_lag=0., affine=True): 19 | super(MovingBatchNormNd, self).__init__() 20 | self.num_features = num_features 21 | self.affine = affine 22 | self.eps = eps 23 | self.decay = decay 24 | self.bn_lag = bn_lag 25 | self.register_buffer('step', torch.zeros(1)) 26 | if self.affine: 27 | self.bias = Parameter(torch.Tensor(num_features)) 28 | else: 29 | self.register_parameter('bias', None) 30 | self.register_buffer('running_mean', torch.zeros(num_features)) 31 | self.reset_parameters() 32 | 33 | def shape(self, x): 34 | raise NotImplementedError 35 | 36 | def reset_parameters(self): 37 | self.running_mean.zero_() 38 | if self.affine: 39 | self.bias.data.zero_() 40 | 41 | def forward(self, x, logp=None): 42 | c = x.size(1) 43 | used_mean = self.running_mean.clone().detach() 44 | 45 | if self.training: 46 | # compute batch statistics 47 | x_t = x.transpose(0, 1).contiguous().view(c, -1) 48 | batch_mean = torch.mean(x_t, dim=1) 49 | 50 | # moving average 51 | if self.bn_lag > 0: 52 | used_mean = batch_mean - (1 - self.bn_lag) * (batch_mean - used_mean.detach()) 53 | used_mean /= (1. - self.bn_lag**(self.step[0] + 1)) 54 | 55 | # update running estimates 56 | self.running_mean -= self.decay * (self.running_mean - batch_mean.data) 57 | self.step += 1 58 | 59 | # perform normalization 60 | used_mean = used_mean.view(*self.shape(x)).expand_as(x) 61 | 62 | y = x - used_mean 63 | 64 | if self.affine: 65 | bias = self.bias.view(*self.shape(x)).expand_as(x) 66 | y = y + bias 67 | 68 | if logp is None: 69 | return y 70 | else: 71 | return y, logp 72 | 73 | def inverse(self, y, logp=None): 74 | used_mean = self.running_mean 75 | 76 | if self.affine: 77 | bias = self.bias.view(*self.shape(y)).expand_as(y) 78 | y = y - bias 79 | 80 | used_mean = used_mean.view(*self.shape(y)).expand_as(y) 81 | x = y + used_mean 82 | 83 | if logp is None: 84 | return x 85 | else: 86 | return x, logp 87 | 88 | def __repr__(self): 89 | return ( 90 | '{name}({num_features}, eps={eps}, decay={decay}, bn_lag={bn_lag},' 91 | ' affine={affine})'.format(name=self.__class__.__name__, **self.__dict__) 92 | ) 93 | 94 | 95 | class MovingBatchNorm1d(MovingBatchNormNd): 96 | 97 | def shape(self, x): 98 | return [1] * (x.ndim - 1) + [-1] 99 | 100 | 101 | class MovingBatchNorm2d(MovingBatchNormNd): 102 | 103 | def shape(self, x): 104 | return [1, -1, 1, 1] 105 | -------------------------------------------------------------------------------- /layers/squeeze.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | 5 | This source code is licensed under the license found in the 6 | LICENSE file in the root directory of this source tree. 
7 | """ 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | __all__ = ['SqueezeLayer'] 13 | 14 | 15 | class SqueezeLayer(nn.Module): 16 | 17 | def __init__(self, downscale_factor): 18 | super(SqueezeLayer, self).__init__() 19 | self.downscale_factor = downscale_factor 20 | 21 | def forward(self, x, logp=None, **kwargs): 22 | squeeze_x = squeeze(x, self.downscale_factor) 23 | if logp is None: 24 | return squeeze_x 25 | else: 26 | return squeeze_x, logp 27 | 28 | def inverse(self, y, logp=None, **kwargs): 29 | unsqueeze_y = unsqueeze(y, self.downscale_factor) 30 | if logp is None: 31 | return unsqueeze_y 32 | else: 33 | return unsqueeze_y, logp 34 | 35 | 36 | def unsqueeze(input, upscale_factor=2): 37 | return torch.pixel_shuffle(input, upscale_factor) 38 | 39 | 40 | def squeeze(input, downscale_factor=2): 41 | ''' 42 | [:, C, H*r, W*r] -> [:, C*r^2, H, W] 43 | ''' 44 | batch_size, in_channels, in_height, in_width = input.shape 45 | out_channels = in_channels * (downscale_factor**2) 46 | 47 | out_height = in_height // downscale_factor 48 | out_width = in_width // downscale_factor 49 | 50 | input_view = input.reshape(batch_size, in_channels, out_height, downscale_factor, out_width, downscale_factor) 51 | 52 | output = input_view.permute(0, 1, 3, 5, 2, 4) 53 | return output.reshape(batch_size, out_channels, out_height, out_width) 54 | -------------------------------------------------------------------------------- /toy_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | 5 | This source code is licensed under the license found in the 6 | LICENSE file in the root directory of this source tree. 7 | """ 8 | 9 | import numpy as np 10 | import sklearn 11 | import sklearn.datasets 12 | from sklearn.utils import shuffle as util_shuffle 13 | 14 | 15 | # Dataset iterator 16 | def inf_train_gen(data, batch_size=200): 17 | 18 | if data == "swissroll": 19 | data = sklearn.datasets.make_swiss_roll(n_samples=batch_size, noise=1.0)[0] 20 | data = data.astype("float32")[:, [0, 2]] 21 | data /= 5 22 | return data 23 | 24 | elif data == "circles": 25 | data = sklearn.datasets.make_circles(n_samples=batch_size, factor=.5, noise=0.08)[0] 26 | data = data.astype("float32") 27 | data *= 3 28 | return data 29 | 30 | elif data == "rings": 31 | n_samples4 = n_samples3 = n_samples2 = batch_size // 4 32 | n_samples1 = batch_size - n_samples4 - n_samples3 - n_samples2 33 | 34 | # so as not to have the first point = last point, we set endpoint=False 35 | linspace4 = np.linspace(0, 2 * np.pi, n_samples4, endpoint=False) 36 | linspace3 = np.linspace(0, 2 * np.pi, n_samples3, endpoint=False) 37 | linspace2 = np.linspace(0, 2 * np.pi, n_samples2, endpoint=False) 38 | linspace1 = np.linspace(0, 2 * np.pi, n_samples1, endpoint=False) 39 | 40 | circ4_x = np.cos(linspace4) 41 | circ4_y = np.sin(linspace4) 42 | circ3_x = np.cos(linspace4) * 0.75 43 | circ3_y = np.sin(linspace3) * 0.75 44 | circ2_x = np.cos(linspace2) * 0.5 45 | circ2_y = np.sin(linspace2) * 0.5 46 | circ1_x = np.cos(linspace1) * 0.25 47 | circ1_y = np.sin(linspace1) * 0.25 48 | 49 | X = np.vstack([ 50 | np.hstack([circ4_x, circ3_x, circ2_x, circ1_x]), 51 | np.hstack([circ4_y, circ3_y, circ2_y, circ1_y]) 52 | ]).T * 3.0 53 | X = util_shuffle(X) 54 | 55 | # Add noise 56 | X = X + np.random.normal(scale=0.08, size=X.shape) 57 | 58 | return X.astype("float32") 59 | 60 | elif data == "moons": 61 | data = 
sklearn.datasets.make_moons(n_samples=batch_size, noise=0.1)[0] 62 | data = data.astype("float32") 63 | data = data * 2 + np.array([-1, -0.2]) 64 | return data 65 | 66 | elif data == "8gaussians": 67 | scale = 4. 68 | centers = [(1, 0), (-1, 0), (0, 1), (0, -1), (1. / np.sqrt(2), 1. / np.sqrt(2)), 69 | (1. / np.sqrt(2), -1. / np.sqrt(2)), (-1. / np.sqrt(2), 70 | 1. / np.sqrt(2)), (-1. / np.sqrt(2), -1. / np.sqrt(2))] 71 | centers = [(scale * x, scale * y) for x, y in centers] 72 | 73 | dataset = [] 74 | for i in range(batch_size): 75 | point = np.random.randn(2) * 0.5 76 | idx = np.random.randint(8) 77 | center = centers[idx] 78 | point[0] += center[0] 79 | point[1] += center[1] 80 | dataset.append(point) 81 | dataset = np.array(dataset, dtype="float32") 82 | dataset /= 1.414 83 | return dataset 84 | 85 | elif data == "pinwheel": 86 | radial_std = 0.3 87 | tangential_std = 0.1 88 | num_classes = 5 89 | num_per_class = batch_size // 5 90 | rate = 0.25 91 | rads = np.linspace(0, 2 * np.pi, num_classes, endpoint=False) 92 | 93 | features = np.random.randn(num_classes*num_per_class, 2) \ 94 | * np.array([radial_std, tangential_std]) 95 | features[:, 0] += 1. 96 | labels = np.repeat(np.arange(num_classes), num_per_class) 97 | 98 | angles = rads[labels] + rate * np.exp(features[:, 0]) 99 | rotations = np.stack([np.cos(angles), -np.sin(angles), np.sin(angles), np.cos(angles)]) 100 | rotations = np.reshape(rotations.T, (-1, 2, 2)) 101 | 102 | return 2 * np.random.permutation(np.einsum("ti,tij->tj", features, rotations)) 103 | 104 | elif data == "2spirals": 105 | n = np.sqrt(np.random.rand(batch_size // 2, 1)) * 540 * (2 * np.pi) / 360 106 | d1x = -np.cos(n) * n + np.random.rand(batch_size // 2, 1) * 0.5 107 | d1y = np.sin(n) * n + np.random.rand(batch_size // 2, 1) * 0.5 108 | x = np.vstack((np.hstack((d1x, d1y)), np.hstack((-d1x, -d1y)))) / 3 109 | x += np.random.randn(*x.shape) * 0.1 110 | return x 111 | 112 | elif data == "checkerboard": 113 | x1 = np.random.rand(batch_size) * 4 - 2 114 | x2_ = np.random.rand(batch_size) - np.random.randint(0, 2, batch_size) * 2 115 | x2 = x2_ + (np.floor(x1) % 2) 116 | return np.concatenate([x1[:, None], x2[:, None]], 1) * 2 117 | 118 | elif data == "line": 119 | x = np.random.rand(batch_size) * 5 - 2.5 120 | y = x 121 | return np.stack((x, y), 1) 122 | elif data == "cos": 123 | x = np.random.rand(batch_size) * 5 - 2.5 124 | y = np.sin(x) * 2.5 125 | return np.stack((x, y), 1) 126 | else: 127 | return inf_train_gen("8gaussians", batch_size) 128 | --------------------------------------------------------------------------------
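A minimal sketch (not part of the repository) of how the 2D toy densities defined in toy_data.py might be inspected; it assumes matplotlib is installed and that toy_data.py is importable from the working directory.

import matplotlib.pyplot as plt

from toy_data import inf_train_gen

# Illustrative only: sample a large batch from each toy distribution and scatter-plot it.
names = ["swissroll", "circles", "rings", "moons",
         "8gaussians", "pinwheel", "2spirals", "checkerboard"]
fig, axes = plt.subplots(2, 4, figsize=(16, 8))
for name, ax in zip(names, axes.flat):
    samples = inf_train_gen(name, batch_size=5000)
    ax.scatter(samples[:, 0], samples[:, 1], s=1, alpha=0.5)
    ax.set_title(name)
    ax.set_aspect("equal")
plt.tight_layout()
plt.show()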