├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── NOTICE ├── README.md ├── __init__.py ├── dataset.py ├── extract_code.py ├── models ├── __init__.py ├── fista_pixelsnail.py ├── model_utils.py ├── pixelsnail.py ├── quantizers.py └── vqvae.py ├── mt_sample.py ├── requirements.txt ├── sample.py ├── scheduler.py ├── scripts ├── __init__.py ├── calculate_jpg_psnr.py ├── calculate_model_psnr.py ├── calculate_model_psnr.sh ├── compression_psnr_graph.py ├── decompression_graph_psnr.py ├── extract_dataset_unlearned_encodings.py ├── hyperparameter_alpha_search.py ├── merge_model_psnrs_to_csv.py ├── spwan.py └── visualize_encodings.py ├── train_fista_pixelsnail.py ├── train_pixelsnail.py ├── train_vqvae.py └── utils ├── __init__.py ├── pyfista.py ├── pyfista_test.py ├── pyomp.py └── util_funcs.py /.gitignore: -------------------------------------------------------------------------------- 1 | sample/ 2 | checkpoint/ 3 | test_samples/ 4 | __pycharm__/ 5 | lyssa/ 6 | Lyssandra/ 7 | .idea/ 8 | .DS_Store 9 | __pycache__/ 10 | runs/ 11 | sampled_images/ 12 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 
34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Kim Seonghyeon 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 23 | 24 | ============================================================================== 25 | PixelSNAIL 26 | ============================================================================== 27 | 28 | MIT License 29 | 30 | Copyright (c) 2019 Xi Chen 31 | 32 | Permission is hereby granted, free of charge, to any person obtaining a copy 33 | of this software and associated documentation files (the "Software"), to deal 34 | in the Software without restriction, including without limitation the rights 35 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 36 | copies of the Software, and to permit persons to whom the Software is 37 | furnished to do so, subject to the following conditions: 38 | 39 | The above copyright notice and this permission notice shall be included in all 40 | copies or substantial portions of the Software. 41 | 42 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 43 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 44 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 45 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 46 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 47 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 48 | SOFTWARE. 49 | 50 | ============================================================================== 51 | Learning rate scheduler and VQ-VAE and SparseVQVAE 52 | ============================================================================== 53 | 54 | Apache License, Version 2.0 Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ 55 | 56 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 57 | 58 | 1. Definitions. 59 | 60 | "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. 61 | 62 | "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. 63 | 64 | "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. 65 | 66 | "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. 67 | 68 | "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. 69 | 70 | "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. 71 | 72 | "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). 
73 | 74 | "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. 75 | 76 | "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." 77 | 78 | "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 79 | 80 | 2. Grant of Copyright License. 81 | 82 | Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 83 | 84 | 3. Grant of Patent License. 85 | 86 | Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 87 | 88 | 4. Redistribution. 
89 | 90 | You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: 91 | 92 | You must give any other recipients of the Work or Derivative Works a copy of this License; and You must cause any modified files to carry prominent notices stating that You changed the files; and You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 93 | 94 | 5. Submission of Contributions. 95 | 96 | Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 97 | 98 | 6. Trademarks. 99 | 100 | This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 101 | 102 | 7. Disclaimer of Warranty. 103 | 104 | Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 105 | 106 | 8. Limitation of Liability. 
107 | 108 | In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 109 | 110 | 9. Accepting Warranty or Additional Liability. 111 | 112 | While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. 113 | 114 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | -------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------

# SparseVQVAE: Sparse Dictionary based Vector Quantized Variational AutoEncoder

Experimental implementation of a sparse-dictionary-based version of the VQ-VAE2 paper
(see: [Generating Diverse High-Fidelity Images with VQ-VAE-2](https://arxiv.org/abs/1906.00446)).

This repository is built on PyTorch.

## Authors
Yiftach Ginger ([iftachg](https://github.com/iftachg)), Or Perel ([orperel](https://github.com/orperel)), Roee Litman ([rlit](https://github.com/rlit))

## Introduction

VQ-VAE is a promising direction for image synthesis that is entirely separate from the GAN line of work.
The main idea of this codebase is to create a generalized VQ-VAE by replacing the hard selection at the heart of the method with a softer selection obtained via sparse coding.
This stems from the observation that hard selection is, in essence, the "sparsest code": a single non-zero element, i.e. a one-hot vector.

In this generalized implementation, we allow the VAE to code each patch with a small set of sparse dictionary atoms, rather than the single code used in the original work.
We therefore build on the VQ-VAE2 paper:
1. We perform sparse dictionary learning to generate a set of atoms that best describe the data.
2. During training, new images are encoded per patch, where each patch is encoded by a small set of atoms.
3. We then decode the image back from sparse codes to pixel space using a learned decoder.

At inference time, images may be compressed by employing both the encoder and the decoder.
Alternatively, new images can be synthesized by randomizing sparse codes and employing only the decoder.
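To make the generalization concrete, here is the selection rule in symbols (our notation, not taken from the papers): for an encoder feature $x \in \mathbb{R}^d$ of a patch and a dictionary $D = [d_1, \dots, d_K] \in \mathbb{R}^{d \times K}$, vanilla VQ-VAE picks the single nearest atom,

$$k^\ast = \arg\min_k \lVert x - d_k \rVert_2^2, \qquad \hat{x} = d_{k^\ast},$$

so the code is a one-hot vector. The sparse generalization instead solves

$$z^\ast = \arg\min_z \tfrac{1}{2}\lVert x - Dz \rVert_2^2 + \alpha \lVert z \rVert_1 \;\;\text{(FISTA)} \qquad \text{or} \qquad z^\ast = \arg\min_{\lVert z \rVert_0 \le K'} \lVert x - Dz \rVert_2^2 \;\;\text{(OMP)},$$

and reconstructs $\hat{x} = Dz^\ast$; the one-hot code is recovered as the special case of exactly one non-zero coefficient.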
We summarize the main contributions of this repository as follows:
1. A sparse dictionary over PyTorch:
   - The sparse dictionary is learned via [Task-Driven Dictionary Learning][1], implemented to be compatible with PyTorch's auto-differentiation.
   - Fast parallel implementations of the [FISTA][2] and [OMP][3] sparse-coding algorithms.
2. A complete sparse-dictionary-empowered VQ-VAE2 implementation, including training & evaluation code.

[1]: https://arxiv.org/abs/1009.5358
[2]: https://people.rennes.inria.fr/Cedric.Herzet/Cedric.Herzet/Sparse_Seminar/Entrees/2012/11/12_A_Fast_Iterative_Shrinkage-Thresholding_Algorithmfor_Linear_Inverse_Problems_(A._Beck,_M._Teboulle)_files/Breck_2009.pdf
[3]: http://www.cs.technion.ac.il/~ronrubin/Publications/KSVD-OMP-v2.pdf

## Dictionary Learning

The sparse coding problem involves integer programming over a non-convex L0 norm, and is therefore NP-hard.
In practice, the solution is approximated using pursuit algorithms, in which the atoms "compete" over which of them get to describe the input signal.
Generally speaking, pursuit algorithms come in two flavours: greedy and convex-relaxation.
We provide one example from each family.


#### OMP

This method approximates the exact L0-norm solution in a greedy manner, at each step selecting the atom with the smallest angular residual w.r.t. the current code (a minimal sketch appears at the end of the Project Structure section below).
The benefit is a guaranteed number of `K` non-zero elements after `K` iterations.
On the other hand, the hard selection step makes the method less suitable for differentiable programming (i.e. back-propagation).

#### FISTA

Here, the L0 norm is relaxed to its nearest convex counterpart, the L1 norm, which is treated as an additive penalty.
The resulting LASSO problem is convex and has several efficient solution methods.
The iterative nature of this method allows unrolling its structure and approximating it with a neural net (see [LISTA](http://yann.lecun.com/exdb/publis/pdf/gregor-icml-10.pdf)).
The drawback is that the resulting code can have an arbitrary number of non-zero elements after a fixed number of iterations.

#### Task-Driven Dictionary Learning

Without going into too many details, this paper proposes a way to calculate the derivative of the sparse coding problem with respect to the dictionary.
This opens the way for a bi-level optimization procedure, where we optimize the result of an optimization process.
Using this method we can create a dictionary optimized for any task, specifically the one our VQ-VAE is meant to solve.

## Applications

#### Compression
Images are compressed by running both the encoder and the decoder: each patch is reduced to a small set of atom indices and coefficients (see `scripts/calculate_model_psnr.py` for PSNR evaluation).

#### Synthesis
New images are synthesized by sampling sparse codes and running only the decoder (see the PixelSNAIL training and sampling scripts; this functionality is only partially supported, see Limitations).

## Limitations

This repository contains research materials of an unpublished work.
Training and inference code based on FISTA and OMP over PyTorch is fully functional for compression use cases.
PixelSNAIL synthesis functionality is partially supported.

## Installation

The code assumes a Linux environment (tested on Ubuntu 16).

### Prerequisites

* python >= 3.6
* pytorch >= 1.4
* cuda 10 or higher (recommended)

### Environment Setup

After cloning this repository:
> cd Sparse_VAE

> pip install -r requirements.txt

## Project Structure

`FISTAFunction` and `OMPFunction` are fast GPU implementations of both algorithms over PyTorch.
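For orientation, below is a minimal, unoptimized sketch of the FISTA iteration that `FISTAFunction` implements in batched, GPU-friendly form. The names here are illustrative only (not the repository's API); the essential ingredients are the gradient step, the soft-thresholding (shrinkage) step that produces exact zeros, and the momentum term:

```python
import torch

def soft_threshold(v, thresh):
    # Proximal operator of the L1 penalty: shrinks values toward zero,
    # setting everything with magnitude below `thresh` to exactly zero.
    return torch.sign(v) * torch.clamp(v.abs() - thresh, min=0.0)

def fista(D, x, alpha, n_iter=100):
    """Solve min_z 0.5 * ||x - D @ z||^2 + alpha * ||z||_1 with FISTA.

    D: (d, K) dictionary, x: (d,) signal. Illustrative single-signal version.
    """
    # The squared Frobenius norm is a cheap upper bound on the Lipschitz
    # constant ||D||_2^2 of the gradient of the smooth data term.
    L = torch.norm(D) ** 2
    z = torch.zeros(D.shape[1])
    y, t = z.clone(), 1.0
    for _ in range(n_iter):
        grad = D.t() @ (D @ y - x)                      # gradient of the data term at y
        z_next = soft_threshold(y - grad / L, alpha / L)
        t_next = (1 + (1 + 4 * t * t) ** 0.5) / 2
        y = z_next + ((t - 1) / t_next) * (z_next - z)  # momentum (extrapolation) step
        z, t = z_next, t_next
    return z
```

The `alpha / L` threshold in the shrinkage step is what drives coefficients to exactly zero, which is where the sparsity of the resulting code comes from.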
Practitioners are welcome to incorporate these functions into their own repositories under the license terms of this repository.

The rest of the project structure can be briefly described as follows:

* checkpoint/
  * Trained models (both the vanilla VQ-VAE and PixelSNAIL) are saved here, along with the args used to generate them.
* models/
  * fista_pixelsnail - the modified, FISTA-based PixelSNAIL implementation.
  * model_utils - functions for generating VQVAE objects and loading datasets (CIFAR, ImageNet, etc.). All files are downloaded relative to the project path.
  * pixelsnail - the original PixelSNAIL model. fista_pixelsnail extends this model and adds additional heads; the vanilla setup instantiates this model twice (top and bottom).
  * quantizers - contains the quantization logic that generates the codes: FISTA, OMP and vanilla VQ-VAE quantization.
  * vqvae - composed of an Encoder and a Decoder. Of note: the stride can be changed to achieve different effects (0 for decompression; 1 for vanilla VQ-VAE; 2 for compression), and it should be changed in both the Encoder and the Decoder.
* scripts/
  * calculate_jpg_psnr - a standalone script that takes a dataset (currently hardcoded to CIFAR), runs JPEG compression at multiple quality levels, and outputs the PSNR.
  * calculate_model_psnr - similar to the above, except it receives a model as input and prints its compression PSNR. Note that PSNR is computed manually here. FISTA converges for multiple images at the same time, so the slowest image in the batch determines the overall speed; running with batch size 1 is both faster and more accurate.
  * extract_dataset_unlearned_encodings - can be skipped (it was only used for experiments on alpha).
  * compression_psnr_graph / decompression_graph_psnr - take the PSNR tables produced above and generate plots.
  * hyperparameter_alpha_search - studies how alpha relates to the number of non-zeros, computed once for random data and once for the encodings produced by the script above. You should most likely not need to touch this script.
  * visualize_encodings - a sanity-check visualization script: it takes a model and a dataset, runs the model over the dataset, and saves the output images to verify the model is still valid. If all goes well, this file should not be needed.
* utils/
  * pyfista - FISTA sparse coding and task-driven dictionary learning are implemented here.
  * pyfista_test - generates synthetic data for testing the sparse coding. Since hyperparameter search is no longer performed, this file has no further uses.
  * pyomp - holds the forward implementation of OMP, one sample at a time (TODO: implement batch OMP if needed).
  * util_funcs - assorted helper functions: argument parsers, seeding, and general experiment setup (e.g. assigning an experiment name).
* dataset - definitions for the supported datasets; there is nothing to configure here.
* extract_code - entry point for code extraction (the second step of training).
* mt_sample - multi-threaded sampling. Currently broken.
* sample - receives a PixelSNAIL model and generates images.
* scheduler - scheduling definitions: the learning-rate schedulers, when to save a checkpoint file, etc.
* train_fista_pixelsnail / train_pixelsnail / train_vqvae - all of the supported network trainers.
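As a companion to the FISTA sketch above, here is roughly what the greedy OMP loop described in the Dictionary Learning section looks like. This is an illustrative single-signal version under the same notation as before (the repository's `utils/pyomp.py` holds the actual implementation):

```python
import torch

def omp(D, x, k):
    """Greedily select k atoms of the dictionary D (d, K) to approximate x (d,)."""
    residual, support = x.clone(), []
    coeffs = torch.zeros(0)
    for _ in range(k):
        # Pick the atom most correlated with the current residual
        # (i.e. the one with the smallest angle to it).
        correlations = (D.t() @ residual).abs()
        if support:
            correlations[support] = 0.0  # never re-select a chosen atom
        support.append(int(correlations.argmax()))
        D_s = D[:, support]
        # Least-squares fit on the selected atoms, then orthogonalize:
        # the updated residual is perpendicular to span(D_s).
        coeffs = torch.pinverse(D_s) @ x
        residual = x - D_s @ coeffs
    z = torch.zeros(D.shape[1])
    z[support] = coeffs  # exactly k non-zero entries after k iterations
    return z
```

Note the contrast with FISTA: exactly `k` atoms are active after `k` iterations, but the hard `argmax` selection has no useful gradient, which is why the FISTA path is the one used for end-to-end training.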
## Usage

1. Training the Sparse-VQVAE encoder-decoder:

Currently this codebase supports CIFAR10, CIFAR100, and ImageNet.

* Train with the original VQ-VAE:
```
python train_vqvae.py --experiment_name="experiment_vq" --selection_fn=vanilla
```

* Train with FISTA sparse-coding:

```
python train_vqvae.py --experiment_name="experiment_fista" --selection_fn=fista
```

* Train with OMP sparse-coding:

```
python train_vqvae.py --experiment_name="experiment_omp" --selection_fn=omp --num_strides=2
```


For synthesis, additional steps are required:

2. Extract codes for stage 2 training:

> python extract_code.py --ckpt checkpoint/[VQ-VAE CHECKPOINT] --name [LMDB NAME] [DATASET PATH]


3. Train stage 2 (PixelSNAIL):

> python train_pixelsnail.py [LMDB NAME]

-------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amzn/sparse-vqvae/33ca864b6a20c644c3c825ce958fd20a99349dda/__init__.py -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from collections import namedtuple 4 | 5 | import torch 6 | from torch.utils.data import Dataset 7 | from torchvision import datasets 8 | import lmdb 9 | 10 | 11 | CodeRow = namedtuple('CodeRow', ['top', 'bottom', 'filename']) 12 | 13 | 14 | class NamedDataset(Dataset): 15 | 16 | def __init__(self, dataset): 17 | self.dataset = dataset 18 | 19 | def __getitem__(self, index): 20 | return list(self.dataset[index]) + [index] 21 | 22 | def __len__(self): 23 | return len(self.dataset) 24 | 25 | 26 | class ImageFileDataset(datasets.ImageFolder): 27 | def __getitem__(self, index): 28 | sample, target = super().__getitem__(index) 29 | path, _ = self.samples[index] 30 | dirs, filename = os.path.split(path) 31 | _, class_name = os.path.split(dirs) 32 | filename = os.path.join(class_name, filename) 33 | 34 | return sample, target, filename 35 | 36 | 37 | class LMDBDataset(Dataset): 38 | def __init__(self, path, architecture): 39 | if architecture == 'vqvae' or architecture == 'vqvae2': 40 | self.architecture = architecture 41 | else: 42 | raise ValueError('Valid architectures are vqvae and vqvae2. 
Got: {}'.format(architecture)) 43 | 44 | 45 | self.env = lmdb.open( 46 | path, 47 | max_readers=32, 48 | readonly=True, 49 | lock=False, 50 | readahead=False, 51 | meminit=False, 52 | ) 53 | 54 | if not self.env: 55 | raise IOError('Cannot open lmdb dataset', path) 56 | 57 | with self.env.begin(write=False) as txn: 58 | self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8')) 59 | 60 | def __len__(self): 61 | return self.length 62 | 63 | def __getitem__(self, index): 64 | with self.env.begin(write=False) as txn: 65 | key = str(index).encode('utf-8') 66 | 67 | row = pickle.loads(txn.get(key)) 68 | 69 | if self.architecture == 'vqvae': 70 | return torch.from_numpy(row.bottom), torch.from_numpy(row.bottom), row.filename 71 | elif self.architecture == 'vqvae2': 72 | return torch.from_numpy(row.top), torch.from_numpy(row.bottom), row.filename -------------------------------------------------------------------------------- /extract_code.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | import os 4 | import torch 5 | from torch.utils.data import DataLoader 6 | # from torchvision import transforms 7 | import lmdb 8 | from tqdm import tqdm 9 | # from torchvision import datasets 10 | from dataset import CodeRow, NamedDataset 11 | # from models.vqvae import VQVAE 12 | # import torch.nn as nn 13 | from utils import util_funcs 14 | from models.model_utils import get_model, get_dataset 15 | # from torchvision import datasets, transforms, utils 16 | # import joblib 17 | 18 | 19 | def extract(lmdb_env, loader, model, device, phase='train'): 20 | index = 0 21 | 22 | with lmdb_env.begin(write=True) as txn: 23 | pbar = tqdm(loader, desc='Extracting for {} phase'.format(phase)) 24 | 25 | for img, _, filename in pbar: 26 | img = img.to(device) 27 | 28 | # Quantize the image and output the atom ids of the patches 29 | quant, _, id, _, _, _, _, _, _ = model.encode(img) 30 | id = id.detach().cpu().numpy() 31 | 32 | # Dump every patch separately 33 | for file, bottom in zip(filename, id): 34 | row = CodeRow(top=None, bottom=bottom, filename=file) 35 | txn.put(str(index).encode('utf-8'), pickle.dumps(row)) 36 | index += 1 37 | pbar.set_postfix({'Inserted': index}) 38 | 39 | txn.put('length'.encode('utf-8'), str(index).encode('utf-8')) 40 | 41 | 42 | def create_extraction_run(size, device, dataset, data_path, num_workers, num_embeddings, architecture, ckpt_epoch, neighborhood, selection_fn, embed_dim, **kwargs): 43 | train_dataset, test_dataset = get_dataset(dataset, data_path, size) 44 | 45 | print('Creating named datasets') 46 | # We don't really use the "Named" part, but I'm keeping it to stay close to the original code repository 47 | train_named_dataset = NamedDataset(train_dataset) 48 | test_named_dataset = NamedDataset(test_dataset) 49 | 50 | print('creating data loaders') 51 | train_loader = DataLoader(train_named_dataset, batch_size=kwargs['vae_batch'], shuffle=False, num_workers=num_workers) 52 | test_loader = DataLoader(test_named_dataset, batch_size=kwargs['vae_batch'], shuffle=False, num_workers=num_workers) 53 | 54 | # This is still the VQ-VAE experiment name and path 55 | experiment_name = util_funcs.create_experiment_name(architecture, dataset, num_embeddings, neighborhood, selection_fn, size, **kwargs) 56 | checkpoint_name = util_funcs.create_checkpoint_name(experiment_name, ckpt_epoch) 57 | checkpoint_path = f'checkpoint/{checkpoint_name}' 58 | 59 | print('Loading model') 60 | model = get_model(architecture, 
num_embeddings, device, neighborhood, selection_fn, embed_dim, parallel=False, **kwargs) 61 | model.load_state_dict(torch.load(checkpoint_path), strict=False) 62 | model = model.to(device) 63 | model.eval() 64 | 65 | 66 | print('Creating LMDB DBs') 67 | map_size = 100 * 1024 * 1024 * 1024 # This would be the maximum size of the databases 68 | db_name = checkpoint_name[:-3] + '_dataset[{}]'.format(dataset) # This comprises of the experiment name and the epoch the codes are taken from 69 | train_env = lmdb.open(os.path.join('codes', 'train_codes', db_name), map_size=map_size) # Will save the encodings of train samples 70 | test_env = lmdb.open(os.path.join('codes', 'test_codes', db_name), map_size=map_size) # Will save the encodings of test samples 71 | 72 | print('Extracting') 73 | if architecture == 'vqvae': 74 | extract(train_env, train_loader, model, device, 'train') 75 | extract(test_env, test_loader, model, device, 'test') 76 | 77 | 78 | if __name__ == '__main__': 79 | parser = argparse.ArgumentParser() 80 | parser = util_funcs.base_parser(parser) 81 | parser = util_funcs.vqvae_parser(parser) 82 | parser = util_funcs.code_extraction_parser(parser) 83 | args = parser.parse_args() 84 | 85 | print(args) 86 | 87 | util_funcs.seed_generators(args.seed) 88 | 89 | create_extraction_run(**vars(args)) 90 | # create_extraction_run(args.size, args.device, args.dataset, args.data_path, args.num_workers, args.num_embed, args.architecture, args.ckpt_epoch, args.neighborhood, args.selection_fn) 91 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amzn/sparse-vqvae/33ca864b6a20c644c3c825ce958fd20a99349dda/models/__init__.py -------------------------------------------------------------------------------- /models/fista_pixelsnail.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Xi Chen 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | # Borrowed from https://github.com/neocxi/pixelsnail-public and ported it to PyTorch 7 | 8 | # from math import sqrt 9 | # from functools import partial, lru_cache 10 | 11 | # import numpy as np 12 | # import torch 13 | # from torch import nn 14 | # from torch.nn import functional as F 15 | from models.pixelsnail import * 16 | # import joblib 17 | 18 | 19 | class FistaPixelSNAIL(PixelSNAIL): 20 | def __init__( 21 | self, 22 | shape, 23 | n_class, 24 | channel, 25 | kernel_size, 26 | n_block, 27 | n_res_block, 28 | res_channel, 29 | attention=True, 30 | dropout=0.1, 31 | n_cond_res_block=0, 32 | cond_res_channel=0, 33 | cond_res_kernel=3, 34 | n_out_res_block=0, 35 | ): 36 | super().__init__(shape, 37 | n_class, 38 | channel, 39 | kernel_size, 40 | n_block, 41 | n_res_block, 42 | res_channel, 43 | attention=attention, 44 | dropout=dropout, 45 | n_cond_res_block=n_cond_res_block, 46 | cond_res_channel=cond_res_channel, 47 | cond_res_kernel=cond_res_kernel, 48 | n_out_res_block=n_out_res_block) 49 | 50 | self.eps = np.finfo(float).eps * 10 51 | 52 | # Override base PixelSnail out module 53 | out = [] 54 | 55 | for i in range(n_out_res_block): 56 | out.append(GatedResBlock(channel, res_channel, 1)) 57 | 58 | self.out = nn.Sequential(*out) 59 | 60 | # Declare network heads 61 | self.sampling_head = nn.Sequential(*[nn.ELU(inplace=False), WNConv2d(channel, n_class, 1)]) 62 | self.nonzeros_head = nn.Sequential(*[nn.ELU(inplace=False), WNConv2d(channel, n_class, 1)]) 63 | 64 | # Create a joined tensor for the mu and sigma for the reparametrization trick 65 | self.reparamaterization_head = nn.Sequential(*[nn.ELU(inplace=False), WNConv2d(channel, n_class*n_class+n_class, 1)]) 66 | 67 | @staticmethod 68 | def reparameterize(mu, sigma): 69 | """ 70 | Reparameterization trick. 71 | 72 | Dimension notation: 73 | B: Batch 74 | S: Sparse code 75 | W: Width 76 | H: Height 77 | 78 | :param mu: Float Tensor. (B, W, H, S). Predicted median for coefficient per atom per patch 79 | :param sigma: Float Tensor. (B, W, H, S, S). Predicted covariance for coefficient per atom per patch 80 | :return: Float Tensor. (B, W, H, S). Sampled coefficients per atom per patch. 81 | """ 82 | 83 | # We change the view of the parameters otherwise having unique dimensions for the Width and Height 84 | # would interfere with the product operation 85 | mu_view = mu.contiguous().view([-1, mu.size()[-1]]) 86 | sigma_view = sigma.contiguous().view([-1, sigma.size()[-2], sigma.size()[-1]]) 87 | theta_0 = torch.randn(mu_view.size()).to(mu_view.device) 88 | 89 | theta_1 = (theta_0 - mu_view).unsqueeze(1).bmm(sigma_view) 90 | theta_1 = theta_1.view(mu.size()) 91 | return theta_1 92 | 93 | def prepare_inputs(self, Z): 94 | """ 95 | Extracts ground-truth from given sparse code. Useful only in training 96 | 97 | Dimension notation: 98 | B: Batch 99 | S: Sparse code 100 | W: Width 101 | H: Height 102 | 103 | :param Z: Float Tensor. (B, S, W, H). Given sparse code for every patch 104 | :return: 105 | 1. used_atoms_mask, Bool Tensor. (B, S, W, H). Map of atoms which are non-zero 106 | 2. gt_num_nonzeros, Long Tensor. (B, W, H). 
Number of non-zero atoms for each patch 107 | """ 108 | used_atoms_mask = torch.abs(Z) > self.eps 109 | gt_num_nonzeros = used_atoms_mask.sum(1) # sum over the sparse code dimension 110 | 111 | return used_atoms_mask, gt_num_nonzeros 112 | 113 | def forward(self, Z, used_atoms_mask, gt_num_nonzeros): 114 | """ 115 | Dimension notation: 116 | B: Batch 117 | S: Sparse code 118 | W: Width 119 | H: Height 120 | 121 | :param Z: Float Tensor. (B, S, W, H). Map of sparse code patches 122 | :param used_atoms_mask: Bool Tensor. (B, S, W, H). Map of atoms which are non-zero 123 | :param gt_num_nonzeros: Long Tensor. (B, W, H). Number of non-zero atoms for each patch 124 | """ 125 | 126 | batch, n_class, height, width = Z.size() 127 | assert n_class == self.n_class 128 | 129 | horizontal = shift_down(self.horizontal(Z)) 130 | vertical = shift_right(self.vertical(Z)) 131 | out = horizontal + vertical 132 | 133 | background = self.background[:, :, :height, :].expand(batch, 2, height, width) 134 | 135 | for block in self.blocks: 136 | out = block(out, background) 137 | 138 | out = self.out(out) 139 | sampled_atoms = self.sampling_head(out) # Output a probability for each atom to be used or not 140 | sampled_num_nonzeros = self.nonzeros_head(out) # Classify how many non-zero atoms there are 141 | 142 | # Binary mask of all selected atoms, per patch 143 | expanded_matrix_nonzero_inds_mask = used_atoms_mask.unsqueeze(dim=1).repeat(1, self.n_class, 1, 1, 1) 144 | 145 | # Output the mu and sigma from which we sample the coefficients of non-zero atoms 146 | reparametarization_results = self.reparamaterization_head(out) 147 | mu = reparametarization_results[:, :self.n_class, :, :] 148 | sigma_vector = reparametarization_results[:, self.n_class:, :, :] 149 | sigma_matrix = sigma_vector.view([sigma_vector.size()[0], self.n_class, self.n_class, sigma_vector.size()[2], sigma_vector.size()[3]]) 150 | 151 | # (B, S, W, H) -> (B, W, H, S): Push the coefficient data dimension to last for the reparameterization trick 152 | permuted_mu = mu.permute(0, 2, 3, 1) 153 | # (B, S, S, W, H) -> (B, W, H, S, S): Push the coefficient data dimension to last for the reparameterization trick 154 | permuted_sigma_matrix = sigma_matrix.permute(0, 3, 4, 1, 2) 155 | 156 | permuted_coefficients = self.reparameterize(permuted_mu, permuted_sigma_matrix) 157 | 158 | # Un-permute the results: (B, W, H, S) -> (B, S, W, H) 159 | coefficients = permuted_coefficients.permute(0, 3, 1, 2) 160 | 161 | # Apply used_atoms_mask 162 | masked_coefficients = coefficients * used_atoms_mask 163 | 164 | return sampled_atoms, sampled_num_nonzeros, masked_coefficients 165 | -------------------------------------------------------------------------------- /models/model_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | # import torch 3 | import torch.nn as nn 4 | from models.vqvae import VQVAE 5 | from torchvision import datasets, transforms, utils 6 | 7 | 8 | def get_model(architecture, num_embeddings, device, neighborhood, selection_fn, embed_dim, parallel=True, **kwargs): 9 | """ 10 | Creates a VQVAE object. 11 | 12 | :param architecture: Has to be 'vqvae'. 13 | :param num_embeddings: Int. Number of dictioanry atoms 14 | :param device: String. 'cpu', 'cuda' or 'cuda:device_number' 15 | :param neighborhood: Int. Not used. 16 | :param selection_fn: String. 'vanilla' or 'fista' 17 | :param embed_dim: Int. Size of latent space. 18 | :param parallel: Bool. Use DataParallel or not. 
19 | 20 | :return: VQVAE model or DataParallel(VQVAE model) 21 | """ 22 | if architecture == 'vqvae': 23 | model = VQVAE(n_embed=num_embeddings, neighborhood=neighborhood, selection_fn=selection_fn, embed_dim=embed_dim, **kwargs).to(device) 24 | else: 25 | raise ValueError('Valid architectures are vqvae. Got: {}'.format(architecture)) 26 | 27 | if parallel and device != 'cpu': 28 | model = nn.DataParallel(model) 29 | 30 | return model 31 | 32 | 33 | def get_dataset(dataset, data_path, size, download=False): 34 | """ 35 | Loads a dataset 36 | 37 | :param dataset: String. Name of dataset. Currently supports ['test', 'cifar10', 'cifar100', 'imagenet']. 38 | Note that 'test' loads the 'cifar10' 39 | :param data_path: String. Path to directory where the dataset is / will be saved. 40 | :param size: Int. Resize image to this size. 41 | :param download: Bool. True to download dataset IF needed. False will save time. 42 | :return: Dataset object. 43 | """ 44 | train_transform = transforms.Compose( 45 | [ 46 | transforms.Resize(size), 47 | transforms.CenterCrop(size), 48 | transforms.ToTensor(), 49 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), 50 | ] 51 | ) 52 | 53 | test_transform = transforms.Compose( 54 | [ 55 | transforms.Resize(size), 56 | transforms.CenterCrop(size), 57 | transforms.ToTensor(), 58 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), 59 | ] 60 | ) 61 | 62 | if dataset == 'test': 63 | train_dataset = datasets.CIFAR10(root=os.path.join(data_path, dataset), download=download, 64 | transform=train_transform) 65 | test_dataset = datasets.CIFAR10(root=os.path.join(data_path, dataset), download=download, 66 | transform=test_transform, train=False) 67 | elif dataset == 'cifar10': 68 | train_dataset = datasets.CIFAR10(root=os.path.join(data_path, dataset), download=download, 69 | transform=train_transform) 70 | test_dataset = datasets.CIFAR10(root=os.path.join(data_path, dataset), download=download, 71 | transform=test_transform, train=False) 72 | elif dataset == 'cifar100': 73 | train_dataset = datasets.CIFAR100(root=os.path.join(data_path, dataset), download=download, 74 | transform=train_transform) 75 | test_dataset = datasets.CIFAR100(root=os.path.join(data_path, dataset), download=download, 76 | transform=test_transform, train=False) 77 | elif dataset == 'imagenet': 78 | train_dataset = datasets.ImageNet(root=os.path.join(data_path, dataset), download=download, 79 | transform=train_transform, split='train') 80 | test_dataset = datasets.ImageNet(root=os.path.join(data_path, dataset), download=False, # Currently the ImageNet valiadation set is inaccessible 81 | transform=test_transform, split='val') 82 | else: 83 | raise ValueError('Valid datasets are cifar10, cifar100 and imagenet. Got: {}'.format(dataset)) 84 | 85 | return train_dataset, test_dataset -------------------------------------------------------------------------------- /models/pixelsnail.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Xi Chen 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | # Borrowed from https://github.com/neocxi/pixelsnail-public and ported it to PyTorch 7 | 8 | from math import sqrt 9 | from functools import partial, lru_cache 10 | 11 | import numpy as np 12 | import torch 13 | from torch import nn 14 | from torch.nn import functional as F 15 | 16 | 17 | def wn_linear(in_dim, out_dim): 18 | return nn.utils.weight_norm(nn.Linear(in_dim, out_dim)) 19 | 20 | 21 | class WNConv2d(nn.Module): 22 | def __init__( 23 | self, 24 | in_channel, 25 | out_channel, 26 | kernel_size, 27 | stride=1, 28 | padding=0, 29 | bias=True, 30 | activation=None, 31 | ): 32 | super().__init__() 33 | 34 | self.conv = nn.utils.weight_norm( 35 | nn.Conv2d( 36 | in_channel, 37 | out_channel, 38 | kernel_size, 39 | stride=stride, 40 | padding=padding, 41 | bias=bias, 42 | ) 43 | ) 44 | 45 | self.out_channel = out_channel 46 | 47 | if isinstance(kernel_size, int): 48 | kernel_size = [kernel_size, kernel_size] 49 | 50 | self.kernel_size = kernel_size 51 | 52 | self.activation = activation 53 | 54 | def forward(self, input): 55 | out = self.conv(input) 56 | 57 | if self.activation is not None: 58 | out = self.activation(out) 59 | # print(out.size()) 60 | return out 61 | 62 | 63 | def shift_down(input, size=1): 64 | """ 65 | Shifts given tensor down by `size` rows: 66 | 1. Pad the top of the tensor with `size` rows of zeros 67 | 2. Remove the bottom `size` rows 68 | :param input: Tensor. Tensor to be shifted 69 | :param size: Int. Number of rows to shift 70 | :return: Tensor. Shifted input 71 | """ 72 | return F.pad(input, [0, 0, size, 0])[:, :, : input.shape[2], :] 73 | 74 | 75 | def shift_right(input, size=1): 76 | """ 77 | Shifts given tensor right by `size` columns: 78 | 1. Pad the left of the tensor with `size` columns of zeros 79 | 2. Remove the rightmost `size` columns 80 | :param input: Tensor. Tensor to be shifted 81 | :param size: Int. Number of columns to shift 82 | :return: Tensor. 
Shifted input 83 | """ 84 | return F.pad(input, [size, 0, 0, 0])[:, :, :, : input.shape[3]] 85 | 86 | 87 | class CausalConv2d(nn.Module): 88 | def __init__( 89 | self, 90 | in_channel, 91 | out_channel, 92 | kernel_size, 93 | stride=1, 94 | padding='downright', 95 | activation=None, 96 | ): 97 | super().__init__() 98 | 99 | if isinstance(kernel_size, int): 100 | kernel_size = [kernel_size] * 2 101 | 102 | self.kernel_size = kernel_size 103 | 104 | if padding == 'downright': 105 | pad = [kernel_size[1] - 1, 0, kernel_size[0] - 1, 0] 106 | 107 | elif padding == 'down' or padding == 'causal': 108 | pad = kernel_size[1] // 2 109 | 110 | pad = [pad, pad, kernel_size[0] - 1, 0] 111 | 112 | self.causal = 0 113 | if padding == 'causal': 114 | self.causal = kernel_size[1] // 2 115 | 116 | self.pad = nn.ZeroPad2d(pad) 117 | 118 | self.conv = WNConv2d( 119 | in_channel, 120 | out_channel, 121 | kernel_size, 122 | stride=stride, 123 | padding=0, 124 | activation=activation, 125 | ) 126 | 127 | def forward(self, input): 128 | out = self.pad(input) 129 | 130 | if self.causal > 0: 131 | self.conv.conv.weight_v.data[:, :, -1, self.causal :].zero_() 132 | 133 | out = self.conv(out) 134 | 135 | return out 136 | 137 | 138 | class GatedResBlock(nn.Module): 139 | def __init__( 140 | self, 141 | in_channel, 142 | channel, 143 | kernel_size, 144 | conv='wnconv2d', 145 | activation=nn.ELU, 146 | dropout=0.1, 147 | auxiliary_channel=0, 148 | condition_dim=0, 149 | ): 150 | super().__init__() 151 | 152 | if conv == 'wnconv2d': 153 | conv_module = partial(WNConv2d, padding=kernel_size // 2) 154 | 155 | elif conv == 'causal_downright': 156 | conv_module = partial(CausalConv2d, padding='downright') 157 | 158 | elif conv == 'causal': 159 | conv_module = partial(CausalConv2d, padding='causal') 160 | 161 | self.activation = activation(inplace=True) 162 | self.conv1 = conv_module(in_channel, channel, kernel_size) 163 | 164 | if auxiliary_channel > 0: 165 | self.aux_conv = WNConv2d(auxiliary_channel, channel, 1) 166 | 167 | self.dropout = nn.Dropout(dropout) 168 | 169 | self.conv2 = conv_module(channel, in_channel * 2, kernel_size) 170 | 171 | if condition_dim > 0: 172 | # self.condition = nn.Linear(condition_dim, in_channel * 2, bias=False) 173 | self.condition = WNConv2d(condition_dim, in_channel * 2, 1, bias=False) 174 | 175 | self.gate = nn.GLU(1) 176 | 177 | def forward(self, input, aux_input=None, condition=None): 178 | out = self.conv1(self.activation(input)) 179 | 180 | if aux_input is not None: 181 | out = out + self.aux_conv(self.activation(aux_input)) 182 | 183 | out = self.activation(out) 184 | out = self.dropout(out) 185 | out = self.conv2(out) 186 | 187 | if condition is not None: 188 | condition = self.condition(condition) 189 | out += condition 190 | # out = out + condition.view(condition.shape[0], 1, 1, condition.shape[1]) 191 | 192 | out = self.gate(out) 193 | out += input 194 | 195 | return out 196 | 197 | 198 | @lru_cache(maxsize=64) 199 | def causal_mask(size): 200 | shape = [size, size] 201 | mask = np.triu(np.ones(shape), k=1).astype(np.uint8).T 202 | start_mask = np.ones(size).astype(np.float32) 203 | start_mask[0] = 0 204 | 205 | return ( 206 | torch.from_numpy(mask).unsqueeze(0), 207 | torch.from_numpy(start_mask).unsqueeze(1), 208 | ) 209 | 210 | 211 | class CausalAttention(nn.Module): 212 | def __init__(self, query_channel, key_channel, channel, n_head=8, dropout=0.1): 213 | super().__init__() 214 | 215 | self.query = wn_linear(query_channel, channel) 216 | self.key = wn_linear(key_channel, 
channel) 217 | self.value = wn_linear(key_channel, channel) 218 | 219 | self.dim_head = channel // n_head 220 | self.n_head = n_head 221 | 222 | self.dropout = nn.Dropout(dropout) 223 | 224 | def forward(self, query, key): 225 | batch, _, height, width = key.shape 226 | 227 | def reshape(input): 228 | return input.view(batch, -1, self.n_head, self.dim_head).transpose(1, 2) 229 | 230 | query_flat = query.view(batch, query.shape[1], -1).transpose(1, 2) 231 | key_flat = key.view(batch, key.shape[1], -1).transpose(1, 2) 232 | query = reshape(self.query(query_flat)) 233 | key = reshape(self.key(key_flat)).transpose(2, 3) 234 | value = reshape(self.value(key_flat)) 235 | 236 | attn = torch.matmul(query, key) / sqrt(self.dim_head) 237 | mask, start_mask = causal_mask(height * width) 238 | mask = mask.type_as(query) 239 | start_mask = start_mask.type_as(query) 240 | attn = attn.masked_fill(mask == 0, -1e4) 241 | attn = torch.softmax(attn, 3) * start_mask 242 | attn = self.dropout(attn) 243 | 244 | out = attn @ value 245 | out = out.transpose(1, 2).reshape( 246 | batch, height, width, self.dim_head * self.n_head 247 | ) 248 | out = out.permute(0, 3, 1, 2) 249 | 250 | return out 251 | 252 | 253 | class PixelBlock(nn.Module): 254 | def __init__( 255 | self, 256 | in_channel, 257 | channel, 258 | kernel_size, 259 | n_res_block, 260 | attention=True, 261 | dropout=0.1, 262 | condition_dim=0, 263 | ): 264 | super().__init__() 265 | 266 | resblocks = [] 267 | for i in range(n_res_block): 268 | resblocks.append( 269 | GatedResBlock( 270 | in_channel, 271 | channel, 272 | kernel_size, 273 | conv='causal', 274 | dropout=dropout, 275 | condition_dim=condition_dim, 276 | ) 277 | ) 278 | 279 | self.resblocks = nn.ModuleList(resblocks) 280 | 281 | self.attention = attention 282 | 283 | if attention: 284 | self.key_resblock = GatedResBlock( 285 | in_channel * 2 + 2, in_channel, 1, dropout=dropout 286 | ) 287 | self.query_resblock = GatedResBlock( 288 | in_channel + 2, in_channel, 1, dropout=dropout 289 | ) 290 | 291 | self.causal_attention = CausalAttention( 292 | in_channel + 2, in_channel * 2 + 2, in_channel // 2, dropout=dropout 293 | ) 294 | 295 | self.out_resblock = GatedResBlock( 296 | in_channel, 297 | in_channel, 298 | 1, 299 | auxiliary_channel=in_channel // 2, 300 | dropout=dropout, 301 | ) 302 | 303 | else: 304 | self.out = WNConv2d(in_channel + 2, in_channel, 1) 305 | 306 | def forward(self, input, background, condition=None): 307 | out = input 308 | 309 | for resblock in self.resblocks: 310 | out = resblock(out, condition=condition) 311 | 312 | if self.attention: 313 | key_cat = torch.cat([input, out, background], 1) 314 | key = self.key_resblock(key_cat) 315 | query_cat = torch.cat([out, background], 1) 316 | query = self.query_resblock(query_cat) 317 | attn_out = self.causal_attention(query, key) 318 | out = self.out_resblock(out, attn_out) 319 | 320 | else: 321 | bg_cat = torch.cat([out, background], 1) 322 | out = self.out(bg_cat) 323 | 324 | return out 325 | 326 | 327 | class CondResNet(nn.Module): 328 | def __init__(self, in_channel, channel, kernel_size, n_res_block): 329 | super().__init__() 330 | 331 | blocks = [WNConv2d(in_channel, channel, kernel_size, padding=kernel_size // 2)] 332 | 333 | for i in range(n_res_block): 334 | blocks.append(GatedResBlock(channel, channel, kernel_size)) 335 | 336 | self.blocks = nn.Sequential(*blocks) 337 | 338 | def forward(self, input): 339 | return self.blocks(input) 340 | 341 | 342 | class PixelSNAIL(nn.Module): 343 | def __init__( 344 | self, 345 | 
shape, 346 | n_class, 347 | channel, 348 | kernel_size, 349 | n_block, 350 | n_res_block, 351 | res_channel, 352 | attention=True, 353 | dropout=0.1, 354 | n_cond_res_block=0, 355 | cond_res_channel=0, 356 | cond_res_kernel=3, 357 | n_out_res_block=0, 358 | ): 359 | super().__init__() 360 | 361 | height, width = shape 362 | 363 | self.n_class = n_class 364 | 365 | if kernel_size % 2 == 0: 366 | kernel = kernel_size + 1 367 | 368 | else: 369 | kernel = kernel_size 370 | 371 | self.horizontal = CausalConv2d( 372 | n_class, channel, [kernel // 2, kernel], padding='down' 373 | ) 374 | self.vertical = CausalConv2d( 375 | n_class, channel, [(kernel + 1) // 2, kernel // 2], padding='downright' 376 | ) 377 | 378 | coord_x = (torch.arange(height).float() - height / 2) / height 379 | coord_x = coord_x.view(1, 1, height, 1).expand(1, 1, height, width) 380 | coord_y = (torch.arange(width).float() - width / 2) / width 381 | coord_y = coord_y.view(1, 1, 1, width).expand(1, 1, height, width) 382 | self.register_buffer('background', torch.cat([coord_x, coord_y], 1)) 383 | 384 | self.blocks = nn.ModuleList() 385 | 386 | for i in range(n_block): 387 | self.blocks.append( 388 | PixelBlock( 389 | channel, 390 | res_channel, 391 | kernel_size, 392 | n_res_block, 393 | attention=attention, 394 | dropout=dropout, 395 | condition_dim=cond_res_channel, 396 | ) 397 | ) 398 | 399 | if n_cond_res_block > 0: 400 | self.cond_resnet = CondResNet( 401 | n_class, cond_res_channel, cond_res_kernel, n_cond_res_block 402 | ) 403 | 404 | out = [] 405 | 406 | for i in range(n_out_res_block): 407 | out.append(GatedResBlock(channel, res_channel, 1)) 408 | 409 | out.extend([nn.ELU(inplace=True), WNConv2d(channel, n_class, 1)]) 410 | 411 | self.out = nn.Sequential(*out) 412 | 413 | def ids2vectors(self, ids): 414 | """ 415 | Transforms given tensor of ids to tensor of one-hot encodings 416 | """ 417 | vectors = ( 418 | F.one_hot(ids.clone(), self.n_class).permute(0, 3, 1, 2).type_as(self.background) 419 | ) 420 | 421 | return vectors 422 | 423 | def forward(self, input, condition=None, cache=None): 424 | if cache is None: 425 | cache = {} 426 | 427 | batch, height, width = input.shape 428 | input=self.ids2vectors(input) 429 | horizontal = shift_down(self.horizontal(input)) 430 | vertical = shift_right(self.vertical(input)) 431 | out = horizontal + vertical 432 | 433 | background = self.background[:, :, :height, :].expand(batch, 2, height, width) 434 | 435 | ## Following code is only relevant when doing VQ-VAE cascades 436 | # if condition is not None: 437 | # if 'condition' in cache: 438 | # condition = cache['condition'] 439 | # condition = condition[:, :, :height, :] 440 | # 441 | # else: 442 | # condition = ( 443 | # F.one_hot(condition, self.n_class) 444 | # .permute(0, 3, 1, 2) 445 | # .type_as(self.background) 446 | # ) 447 | # condition = self.cond_resnet(condition) 448 | # condition = F.interpolate(condition, scale_factor=2) 449 | # cache['condition'] = condition.detach().clone() 450 | # condition = condition[:, :, :height, :] 451 | 452 | for block in self.blocks: 453 | out = block(out, background, condition=condition) 454 | 455 | out = self.out(out) 456 | 457 | return out, cache 458 | -------------------------------------------------------------------------------- /models/quantizers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | # from torch.nn.functional import normalize 5 | 6 | from 
utils.util_funcs import has_value_and_true 7 | from utils.pyfista import FistaFunction 8 | from utils.pyomp import Batch_OMP 9 | 10 | import numpy as np 11 | 12 | 13 | class FistaQuantize(nn.Module): 14 | def __init__(self, dim, n_embed, decay=0.99, eps=1e-5, neighborhood=2, num_workers=4, alpha=0.1, **kwargs): 15 | """ 16 | 17 | :param dim: Int. Size of latent space. 18 | :param n_embed: Int. Size of dictionary. 19 | :param alpha: Float. Fista shrinkage value. 20 | """ 21 | super().__init__() 22 | 23 | self.num_workers=num_workers 24 | self.dim = dim 25 | self.n_embed = n_embed 26 | self.decay = decay 27 | self.eps = eps 28 | self.neighborhood = neighborhood 29 | self.alpha = alpha 30 | 31 | # Dictionary tensor 32 | self.dictionary = torch.nn.Parameter(torch.randn(dim, n_embed, requires_grad=True)) 33 | 34 | self.normalize_dict = has_value_and_true(kwargs, 'normalize_dict') # Normalize dictionary flag 35 | self.normalize_z = has_value_and_true(kwargs, 'normalize_z') # Normalize sparse code flag 36 | self.normalize_x = has_value_and_true(kwargs, 'normalize_x') # Normalize quantization input flag 37 | self.is_enforce_sparsity = has_value_and_true(kwargs, 'is_enforce_sparsity') # Flag to enforce sparsity by selecting top K sparse code values 38 | self.is_quantize_coefs = has_value_and_true(kwargs, 'is_quantize_coefs') # Flag to quantize sparse code for compression 39 | self.backward_dict = has_value_and_true(kwargs, 'backward_dict') # Flag to backprop with respect to the dictionary 40 | self.sample_gradients = kwargs['sample_gradients'] # Number of gradients to use when backproping through the dictionary 41 | self.use_backwards_simd = kwargs['use_backwards_simd'] # Use matrix backprop instead of loop backprop 42 | 43 | self.normalization_flag = self.normalize_dict 44 | self.debug = False 45 | 46 | def forward(self, input): 47 | """ 48 | Denote dimensions: 49 | B - number of patches 50 | d - size of latent space 51 | D - size of dictionary 52 | 53 | :param input: Signal to get the sparse code for 54 | :return: quantized input 55 | """ 56 | 57 | permuted_input = input.permute(0, 2, 3, 1) 58 | flatten = permuted_input.reshape(-1, self.dim) # Shape: Bxd 59 | 60 | if self.normalize_x: 61 | flatten = F.normalize(flatten, p=2, dim=1) 62 | 63 | if self.normalization_flag and self.normalize_dict: 64 | with torch.no_grad(): # Cannot directly change a module Parameter outside of no_grad 65 | self.dictionary.data = self.dictionary.data.__div__(torch.norm(self.dictionary.data,p=2,dim=0)) # Shape: dXD 66 | 67 | if not self.training: 68 | self.normalization_flag = False 69 | 70 | sparse_code, num_fista_steps = FistaFunction.apply(flatten.t(), self.dictionary, self.alpha, 0.01, -1, False, 71 | self.sample_gradients, self.use_backwards_simd) 72 | if self.debug: 73 | # We print this to understand what the range of the sparse code is to know when we quantize it in test time 74 | print('Sparse code value range: min: {} | max: {}'.format(sparse_code.min(), sparse_code.max())) 75 | 76 | if self.training and self.normalize_z and sparse_code.abs().max() > np.finfo(float).eps * 10: #TODO: Decide if we continue skipping or find another solution like reduce alpha and run again 77 | with torch.no_grad(): # We can normalize the coefs with learning through them but we choose not to for symmetry with embeddings 78 | sparse_code.data = sparse_code.data.__div__(torch.norm(sparse_code.data,p=2,dim=0)) # Shape: BXD 79 | 80 | if self.debug: 81 | print('Sparse code average L0 norm: {}'.format(sparse_code.norm(0, 0).mean())) 82 
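        # For reference: the sparse_code above is FISTA's solution to the lasso problem
        #     min_Z 0.5 * ||flatten.t() - dictionary @ Z||_F^2 + alpha * ||Z||_1
        # so larger alpha values give sparser codes (better compression) at the
        # cost of reconstruction error.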
| 83 | # This is only used when calculating PSNR to control over the compression rate 84 | if not self.training and self.is_enforce_sparsity: 85 | sparse_code = self.enforce_sparsity(sparse_code) 86 | 87 | # Quantize sparse code to only use a set number of bits 88 | if self.is_quantize_coefs: 89 | # print('Quantizing sparse code coefficients') 90 | sparse_code = self.hardcode_quantize(sparse_code) 91 | 92 | # Apply sparse code to input to get quantization of it 93 | quantize = sparse_code.t().float().mm(self.dictionary.t()).to(flatten.device) 94 | quantize = quantize.view(*permuted_input.shape) 95 | 96 | # We reshape the sparse code as well to conform to patches 97 | reshapes = list(permuted_input.shape) 98 | reshapes[-1] = sparse_code.size()[0] 99 | ids = sparse_code.t().view(*reshapes) 100 | 101 | if self.backward_dict: 102 | quantization_diff_for_encoder = (quantize.detach() - permuted_input).pow(2).mean() 103 | quantization_diff_for_dictionary = (quantize - permuted_input.detach()).pow(2).mean() 104 | quantize = permuted_input + (quantize - permuted_input).detach() 105 | else: 106 | # If we don't want to backprop through the dictionary we simply duplicate the quantization_diff_for_encoder 107 | # and detach to prevent double backprop 108 | quantization_diff_for_encoder = (quantize.detach() - permuted_input).pow(2).mean() 109 | quantization_diff_for_dictionary = (quantize.detach() - permuted_input).pow(2).mean().detach() 110 | quantize = permuted_input + (quantize - permuted_input).detach() 111 | 112 | # Reporting like this and not in a single object because PyTorch throws a fit 113 | norm_0 = sparse_code.norm(0, 0) 114 | num_quantization_steps = num_fista_steps.detach() 115 | mean_D = self.dictionary.abs().mean().detach() 116 | mean_Z = sparse_code.abs().mean().detach() 117 | norm_Z = norm_0.mean().detach() 118 | topk_num = max(1, int(len(norm_0)*0.01)) 119 | 120 | top_percentile = norm_0.topk(topk_num).values.min().detach() 121 | num_zeros = (norm_0==0).int().sum().float().detach() 122 | if self.debug: 123 | print('num zero: {}'.format(num_zeros)) 124 | 125 | return quantize.permute(0, 3, 1, 2), [quantization_diff_for_encoder, quantization_diff_for_dictionary], ids.permute(0, 3, 1, 2), num_quantization_steps, mean_D, mean_Z, norm_Z, top_percentile, num_zeros 126 | 127 | def enforce_sparsity(self, sparse_code, sparsity_size=10): 128 | """ 129 | This function is used to enforce a certain sparsity on the input sparse code by keeping only the top 130 | sparsity_size values non-zero. 131 | :param sparse_code: Tensor. Sparse code that we want to enforce sparsity on 132 | :param sparsity_size: Int. Hard limit on the non-zeros we allow in the sparse code 133 | :return: Tensor. The sparse_code with only the top sparsity_size values not zeroed out. 134 | """ 135 | tmp_coefs1 = torch.zeros(sparse_code.size()).to(sparse_code.device) 136 | tmp_coefs2 = torch.zeros(sparse_code.size()).to(sparse_code.device) 137 | tops = torch.topk(sparse_code.abs(), sparsity_size, 0) 138 | torch.gather(sparse_code.detach(), 0, tops[1], out=tmp_coefs1) 139 | tmp_coefs2.scatter_(0, tops[1], tmp_coefs1) 140 | sparse_code = tmp_coefs2 141 | return sparse_code 142 | 143 | def hardcode_quantize(self, sparse_code, min_val=-0.55, max_val=0.55, bits=8): 144 | """ 145 | Clamps the sparse code in the range (min_val, max_val) and quantizes to limited number of bits. 146 | :param sparse_code: Tensor. Sparse code to be quantized. 147 | :param min_val: Float. Lower range boundary. 148 | :param max_val: Float. 
Upper range boundary. 149 | :param bits: Int. Number of bits to quantize to. 150 | :return: Tensor. Quantized sparse code. 151 | """ 152 | sparse_code = sparse_code.clamp(min_val, max_val) 153 | sparse_code -= min_val 154 | sparse_code /= (max_val - min_val) 155 | sparse_code *= 2**bits 156 | sparse_code = sparse_code.round() 157 | sparse_code /= 2**bits 158 | sparse_code *= (max_val - min_val) 159 | sparse_code += min_val 160 | return sparse_code 161 | 162 | def embed_code(self, sparse_code): 163 | """ 164 | Performs a de-quantization operation for the given sparse code 165 | :param sparse_code: Tensor. Sparse code we desire to use. Dimensions: (Batch, Sparse code, Width, Height) 166 | :return: Tensor. Linear combination of the dictionary based on given sparse code 167 | """ 168 | 169 | # Transform to (Batch, Width, Height, Sparse code), then flatten the first three dimensions 170 | permuted_sparse_code = sparse_code.permute(0, 2, 3, 1) 171 | aligned_sparse_code = permuted_sparse_code.contiguous().view(-1, sparse_code.size()[1]) 172 | 173 | # Project dictionary on sparse code to create latent image 174 | result = aligned_sparse_code.mm(self.dictionary.t()).t() 175 | 176 | # Reshape latent image to (Batch, Width, Height, Latent) 177 | reshaped_result = result.view(-1, sparse_code.size()[0], sparse_code.size()[2], sparse_code.size()[3]) 178 | permuted_results = reshaped_result.permute(1, 2, 3, 0) 179 | 180 | return permuted_results 181 | 182 | 183 | class OMPQuantize(nn.Module): 184 | def __init__(self, dim, n_embed, num_nonzero=1, eps=1e-9, num_workers=4, **kwargs): 185 | super().__init__() 186 | 187 | self.num_workers = num_workers 188 | self.dim = dim 189 | self.n_embed = n_embed 190 | self.num_nonzero = num_nonzero 191 | self.eps = eps 192 | 193 | # Dictionary tensor 194 | self.dictionary = torch.nn.Parameter(torch.randn(dim, n_embed, requires_grad=True)) 195 | 196 | 197 | self.normalize_dict = has_value_and_true(kwargs, 'normalize_dict') # Normalize dictionary flag 198 | self.is_quantize_coefs = has_value_and_true(kwargs, 'is_quantize_coefs') # Flag to quantize sparse code for compression 199 | self.normalize_x = has_value_and_true(kwargs, 'normalize_x') # Normalize quantization input flag 200 | self.backward_dict = has_value_and_true(kwargs, 'backward_dict') # Flag to backprop with respect to the dictionary 201 | 202 | self._quantize_bits = 8 203 | self._quantize_max_val = 0.55 204 | 205 | def forward(self, _input): 206 | 207 | permuted_input = _input.permute(0, 2, 3, 1) 208 | flatten = permuted_input.reshape(-1, self.dim) # Shape: Bxd 209 | 210 | if self.normalize_x: 211 | flatten = F.normalize(flatten, p=2, dim=1) 212 | 213 | if self.normalize_dict: 214 | with torch.no_grad(): # Cannot directly change a module Parameter outside of no_grad 215 | self.dictionary.data = self.dictionary.data.__div__(torch.norm(self.dictionary.data,p=2,dim=0)) # Shape: dXD 216 | 217 | with torch.no_grad(): # OMP selection process 218 | sparse_code = Batch_OMP(flatten.t(), self.dictionary, self.num_nonzero, tolerance=self.eps) 219 | 220 | # Quantize sparse code to only use a predefined number of bits 221 | if self.is_quantize_coefs: 222 | # print('Quantizing sparse code coefficients') 223 | sparse_code = self.hardcode_quantize(sparse_code) 224 | 225 | # Apply sparse code to input to get quantization of it 226 | quantize = sparse_code.t().float().mm(self.dictionary.t()).to(_input.device) 227 | quantize = quantize.view(*permuted_input.shape) 228 
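        # Note on the straight-through estimator used below:
        # `permuted_input + (quantize - permuted_input).detach()` equals `quantize`
        # in the forward pass, but its gradient w.r.t. the encoder output is the
        # identity, letting gradients bypass the non-differentiable OMP selection.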
| 229 | # We reshape the sparse code as well to conform to patches 230 | reshapes = list(permuted_input.shape) 231 | reshapes[-1] = sparse_code.size()[0] 232 | ids = sparse_code.t().view(*reshapes) 233 | 234 | if self.backward_dict: 235 | quantization_diff_for_encoder = (quantize.detach() - permuted_input).pow(2).mean() 236 | quantization_diff_for_dictionary = (quantize - permuted_input.detach()).pow(2).mean() 237 | quantize = permuted_input + (quantize - permuted_input).detach() 238 | else: 239 | # If we don't want to backprop through the dictionary we simply duplicate the quantization_diff_for_encoder 240 | # and detach to prevent double backprop 241 | quantization_diff_for_encoder = (quantize.detach() - permuted_input).pow(2).mean() 242 | quantization_diff_for_dictionary = (quantize.detach() - permuted_input).pow(2).mean().detach() 243 | quantize = permuted_input + (quantize - permuted_input).detach() 244 | 245 | # Reporting like this and not in a single object because PyTorch throws a fit 246 | norm_0 = sparse_code.norm(0, 0) 247 | num_quantization_steps = torch.tensor(self.num_nonzero, dtype=torch.float).to(self.dictionary.device).detach() 248 | mean_D = self.dictionary.abs().mean().detach() 249 | mean_Z = sparse_code.abs().mean().detach() 250 | norm_Z = norm_0.mean().detach() 251 | topk_num = max(1, int(len(norm_0)*0.01)) 252 | top_percentile = norm_0.topk(topk_num).values.min().detach() 253 | num_zeros = (norm_0 == 0).int().sum().float().detach() 254 | 255 | return quantize.permute(0, 3, 1, 2), [quantization_diff_for_encoder, quantization_diff_for_dictionary], ids.permute(0, 3, 1, 2), num_quantization_steps, mean_D, mean_Z, norm_Z, top_percentile, num_zeros 256 | 257 | def embed_code(self, sparse_code): 258 | """ 259 | Performs a de-quantization operation for the given sparse code 260 | :param sparse_code: Tensor. Sparse code we desire to use. Dimensions: (Batch, Sparse code, Width, Height) 261 | :return: Tensor. Linear combination of the dictionary based on given sparse code 262 | """ 263 | 264 | # Transform to (Batch, Width, Height, Sparse code), then flatten the first three dimensions 265 | permuted_sparse_code = sparse_code.permute(0, 2, 3, 1) 266 | aligned_sparse_code = permuted_sparse_code.contiguous().view(-1, sparse_code.size()[1]) 267 | 268 | # Project dictionary on sparse code to create latent image 269 | result = aligned_sparse_code.mm(self.dictionary.t()).t() 270 | 271 | # Reshape latent image to (Batch, Width, Height, Latent) 272 | reshaped_result = result.view(-1, sparse_code.size()[0], sparse_code.size()[2], sparse_code.size()[3]) 273 | permuted_results = reshaped_result.permute(1, 2, 3, 0) 274 | 275 | return permuted_results 276 | 277 | def hardcode_quantize(self, sparse_code): 278 | """ 279 | Clamps the sparse code in the range (min_val, max_val) and quantizes to limited number of bits. 280 | :param sparse_code: Tensor. Sparse code to be quantized. 281 | :return: Tensor. Quantized sparse code. 
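        Illustrative round trip with the instance defaults (min_val=-0.55,
        max_val=0.55, bits=8, i.e. a step of (max_val - min_val) / 2**8 ~= 0.0043):
        0.3 -> (0.3 + 0.55) / 1.1 = 0.7727 -> round(197.82) = 198 -> 198 / 256 * 1.1 - 0.55 ~= 0.3008.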
282 | """ 283 | min_val = -self._quantize_max_val 284 | max_val = self._quantize_max_val 285 | sparse_code = sparse_code.clamp(min_val, max_val) 286 | sparse_code -= min_val 287 | sparse_code /= (max_val - min_val) 288 | sparse_code *= 2 ** self._quantize_bits 289 | sparse_code = sparse_code.round() 290 | sparse_code /= 2 ** self._quantize_bits 291 | sparse_code *= (max_val-min_val) 292 | sparse_code += min_val 293 | return sparse_code 294 | 295 | 296 | class VanillaQuantize(nn.Module): 297 | def __init__(self, dim, n_embed, decay=0.99, eps=1e-5, **kwargs): 298 | """ 299 | :param dim: Int. Size of latent space. 300 | :param n_embed: Int. Size of dictionary. 301 | """ 302 | super().__init__() 303 | 304 | self.dim = dim 305 | self.n_embed = n_embed 306 | self.decay = decay 307 | self.eps = eps 308 | 309 | dictionary = torch.randn(dim, n_embed) 310 | self.register_buffer('dictionary', dictionary) 311 | self.register_buffer('cluster_size', torch.zeros(n_embed)) 312 | self.register_buffer('dictionary_avg', dictionary.clone()) 313 | self.normalize_dict = has_value_and_true(kwargs, 'normalize_dict') 314 | self.normalize_x = has_value_and_true(kwargs, 'normalize_x') 315 | 316 | self.normalization_flag = self.normalize_dict 317 | 318 | def forward(self, _input): 319 | """ 320 | Denote dimensions: 321 | B - number of patches 322 | d - size of latent space 323 | D - size of dictionary 324 | 325 | :param input: Signal to get the sparse code for 326 | :return: quantized input 327 | """ 328 | 329 | if self.normalization_flag and self.normalize_dict: # We don't need the no_grad operation as dictionary if a buffer 330 | self.dictionary = self.dictionary.div(torch.norm(self.dictionary, p=2, dim=0).expand_as(self.dictionary)) 331 | 332 | if not self.training: 333 | self.normalization_flag = False 334 | 335 | permuted_input = _input.permute(0, 2, 3, 1) 336 | flatten = permuted_input.reshape(-1, self.dim) # Shape: Bxd 337 | 338 | if self.normalize_x: 339 | flatten = F.normalize(flatten, p=2, dim=1) 340 | 341 | dist = ( 342 | flatten.pow(2).sum(1, keepdim=True) 343 | - 2 * flatten @ self.dictionary 344 | + self.dictionary.pow(2).sum(0, keepdim=True) 345 | ) 346 | _, embed_ind = (-dist).max(1) 347 | embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype) 348 | embed_ind = embed_ind.view(*permuted_input.shape[:-1]) 349 | quantize = self.embed_code(embed_ind) 350 | 351 | # This is the EMA dictionary-optimization step 352 | if self.training: 353 | self.cluster_size.data.mul_(self.decay).add_( 354 | 1 - self.decay, embed_onehot.sum(0) 355 | ) 356 | embed_sum = flatten.transpose(0, 1) @ embed_onehot 357 | self.dictionary_avg.data.mul_(self.decay).add_(1 - self.decay, embed_sum) 358 | n = self.cluster_size.sum() 359 | cluster_size = ( 360 | (self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n 361 | ) 362 | embed_normalized = self.dictionary_avg / cluster_size.unsqueeze(0) 363 | self.dictionary.data.copy_(embed_normalized) 364 | 365 | diff = (quantize.detach() - permuted_input).pow(2).mean() 366 | quantize = permuted_input + (quantize - permuted_input).detach() 367 | 368 | # Reporting like this and not in a single object because PyTorch throws a fit 369 | num_quantization_steps = torch.ones(1).to(self.dictionary.device).detach() 370 | mean_D = self.dictionary.abs().mean().detach() 371 | mean_Z = torch.ones(1).to(self.dictionary.device).detach() 372 | norm_Z = torch.ones(1).to(self.dictionary.device).detach() 373 | top_percentile = torch.ones(1).to(self.dictionary.device).detach() 374 | 
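        # num_quantization_steps, mean_Z, norm_Z and top_percentile have no direct
        # analogue for hard vector quantization, so constant placeholders are
        # returned purely to keep the output signature aligned with FistaQuantize
        # and OMPQuantize.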
num_zeros = torch.zeros(1).to(self.dictionary.device).detach() 375 | 376 | return quantize.permute(0, 3, 1, 2), [diff, torch.zeros(0).to(self.dictionary.device)], embed_ind, num_quantization_steps, mean_D, mean_Z, norm_Z, top_percentile, num_zeros 377 | 378 | def embed_code(self, embed_id): 379 | """ 380 | Performs a lookup operation for the given id's atom 381 | :param embed_id: Int. Dictionary atom to look for 382 | :return: Tensor. Atom looked for 383 | """ 384 | return F.embedding(embed_id, self.dictionary.transpose(0, 1)) -------------------------------------------------------------------------------- /models/vqvae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from models.quantizers import VanillaQuantize, OMPQuantize, FistaQuantize 5 | from utils.util_funcs import has_value_and_true 6 | 7 | # Copyright 2018 The Sonnet Authors. All Rights Reserved. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 20 | # ============================================================================ 21 | 22 | 23 | # Borrowed from https://github.com/deepmind/sonnet and ported it to PyTorch 24 | 25 | class ResBlock(nn.Module): 26 | def __init__(self, in_channel, channel): 27 | super().__init__() 28 | 29 | self.conv = nn.Sequential( 30 | nn.ReLU(inplace=True), 31 | nn.Conv2d(in_channel, channel, 3, padding=1), 32 | nn.ReLU(inplace=True), 33 | nn.Conv2d(channel, in_channel, 1), 34 | ) 35 | 36 | def forward(self, input): 37 | out = self.conv(input) 38 | out += input 39 | 40 | return out 41 | 42 | 43 | class Encoder(nn.Module): 44 | def __init__(self, in_channel, channel, n_res_block, n_res_channel, stride, num_strides=1): 45 | super().__init__() 46 | 47 | blocks = [ 48 | nn.Conv2d(in_channel, channel // 2, 4, stride=2, padding=1), 49 | nn.ReLU(inplace=True)] 50 | 51 | # Each stride reduces the encoded matrix size by 4 52 | for i in range(num_strides-1): 53 | blocks += [ 54 | nn.Conv2d(channel // 2, channel // 2, 4, stride=2, padding=1), 55 | nn.ReLU(inplace=True) 56 | ] 57 | 58 | # Block to remove to recreate old decompression results from December 59 | blocks += [ 60 | nn.Conv2d(channel // 2, channel, 4, stride=2, padding=1), 61 | nn.ReLU(inplace=True), 62 | nn.Conv2d(channel, channel, 3, padding=1), 63 | ] 64 | 65 | # Use the following block instead of previous block 66 | # when you want to recreate old decompression results from December 67 | # blocks += [ 68 | # # nn.Conv2d(channel // 2, channel, 4, stride=2, padding=1), 69 | # # nn.ReLU(inplace=True), 70 | # nn.Conv2d(channel//2, channel, 3, padding=1), 71 | # ] 72 | 73 | for i in range(n_res_block): 74 | blocks.append(ResBlock(channel, n_res_channel)) 75 | 76 | blocks.append(nn.ReLU(inplace=True)) 77 | 78 | self.blocks = nn.Sequential(*blocks) 79 | 80 | def forward(self, input): 81 | return self.blocks(input) 82 | 83 | 84 | class Decoder(nn.Module): 85 | def __init__( 86 | self, in_channel, 
out_channel, channel, n_res_block, n_res_channel, stride, num_strides=1 87 | ): 88 | super().__init__() 89 | 90 | self.channel = channel 91 | self.in_channel = in_channel 92 | self.out_channel = out_channel 93 | 94 | blocks = [nn.Conv2d(in_channel, channel, 3, padding=1)] 95 | 96 | for i in range(n_res_block): 97 | blocks.append(ResBlock(channel, n_res_channel)) 98 | 99 | blocks.append(nn.ReLU(inplace=True)) 100 | 101 | # Block to remove to recreate old decompression results from December 102 | blocks.extend( 103 | [ 104 | nn.ConvTranspose2d(channel, channel // 2, 4, stride=2, padding=1), 105 | nn.ReLU(inplace=True) 106 | ] 107 | ) 108 | 109 | for i in range(num_strides-1): 110 | blocks.extend( [ 111 | nn.ConvTranspose2d(channel // 2, channel // 2, 4, stride=2, padding=1), 112 | nn.ReLU(inplace=True) 113 | ] 114 | ) 115 | 116 | blocks.extend( 117 | [ 118 | nn.ConvTranspose2d( 119 | channel // 2, out_channel, 4, stride=2, padding=1 120 | ), 121 | ] 122 | ) 123 | 124 | # Use the following block instead of previous block 125 | # when you want to recreate old decompression results from December 126 | # blocks.append( 127 | # nn.ConvTranspose2d(channel, out_channel, 4, stride=2, padding=1) 128 | # ) 129 | 130 | self.blocks = nn.Sequential(*blocks) 131 | 132 | def forward(self, input): 133 | return self.blocks(input) 134 | 135 | 136 | class VQVAE(nn.Module): 137 | def __init__( 138 | self, 139 | in_channel=3, 140 | channel=128, 141 | n_res_block=2, 142 | n_res_channel=32, 143 | embed_dim=64, 144 | n_embed=512, 145 | decay=0.99, 146 | num_nonzero=1, 147 | neighborhood=1, 148 | selection_fn='omp', 149 | alpha=0.1, 150 | num_strides=1, 151 | **kwargs 152 | ): 153 | super().__init__() 154 | 155 | self.embed_dim = embed_dim 156 | self.channel = channel 157 | self.alpha = alpha 158 | 159 | self.enc = Encoder(in_channel, channel, n_res_block, n_res_channel, stride=2, num_strides=num_strides) 160 | self.quantize_conv = nn.Conv2d(channel, embed_dim, 1) 161 | 162 | if selection_fn == 'omp': 163 | print('Using OMP selection function') 164 | self.quantize = OMPQuantize(embed_dim, n_embed, num_nonzero=num_nonzero, neighborhood=neighborhood, **kwargs) 165 | elif selection_fn == 'fista': 166 | print('Using fista selection function') 167 | self.quantize = FistaQuantize(embed_dim, n_embed, decay=decay, alpha=self.alpha, **kwargs) 168 | elif selection_fn == 'vanilla': 169 | print('Using vanilla selection function') 170 | self.quantize = VanillaQuantize(embed_dim, n_embed, decay=decay, **kwargs) 171 | else: 172 | raise ValueError('Got an illegal selection function: {}'.format(selection_fn)) 173 | 174 | self.dec = Decoder( 175 | embed_dim, 176 | in_channel, 177 | channel, 178 | n_res_block, 179 | n_res_channel, 180 | stride=2, 181 | num_strides=num_strides 182 | ) 183 | 184 | def forward(self, input): 185 | quant, diff, _, num_quantization_steps, mean_D, mean_Z, norm_Z, top_percentile, num_zeros = self.encode(input) 186 | dec = self.decode(quant) 187 | return dec, diff, num_quantization_steps, mean_D, mean_Z, norm_Z, top_percentile, num_zeros 188 | 189 | def encode(self, input): 190 | enc = self.enc(input) 191 | 192 | quant = self.quantize_conv(enc) 193 | quant, diff, id, num_quantization_steps, mean_D, mean_Z, norm_Z, top_percentile, num_zeros = self.quantize(quant) 194 | 195 | return quant, diff, id, num_quantization_steps, mean_D, mean_Z, norm_Z, top_percentile, num_zeros 196 | 197 | def decode(self, quant): 198 | dec = self.dec(quant) 199 | 200 | return dec 201 | 202 | def decode_code(self, code): 203 | """ 204 
| Given an atom id map - look up the atoms and decode the map 205 | :param code: Tensor. Matrix of dictionary atom ids. 206 | For Sparse code the dimensions should be: (Batch, Sparse code, Width, Height) 207 | 208 | :return: Tensor. The result of decoding the given id map 209 | """ 210 | quant = self.quantize.embed_code(code) 211 | quant = quant.permute(0, 3, 1, 2) 212 | 213 | dec = self.decode(quant) 214 | 215 | return dec 216 | 217 | def to(self, *args, **kwargs): 218 | self = super().to(*args, **kwargs) 219 | self.quantize = self.quantize.to(*args, **kwargs) 220 | return self -------------------------------------------------------------------------------- /mt_sample.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | 5 | import threading 6 | import logging 7 | import time 8 | 9 | import torch 10 | from torchvision.utils import save_image 11 | from tqdm import tqdm 12 | 13 | from models.vqvae import VQVAE2, VQVAE 14 | from models.pixelsnail import PixelSNAIL 15 | from models.model_utils import get_model 16 | 17 | from utils import util_funcs 18 | 19 | 20 | @torch.no_grad() 21 | def sample_model(thread_id, model, device, batch, size, temperature, condition=None): 22 | row = torch.zeros(batch, *size, dtype=torch.int64).to(device) 23 | cache = {} 24 | 25 | # for i in range(size[0]): 26 | for i in tqdm(range(size[0]), desc='Thread {}, sampling rows'.format(thread_id)): 27 | for j in range(size[1]): 28 | out, cache = model(row[:, : i + 1, :], condition=condition, cache=cache) 29 | prob = torch.softmax(out[:, :, i, j] / temperature, 1) 30 | sample = torch.multinomial(prob, 1).squeeze(-1) 31 | row[:, i, j] = sample 32 | 33 | return row 34 | 35 | 36 | def load_model(model, checkpoint, device, architecture=None, num_embeddings=None, neighborhood=None, selection_fn=None, 37 | **kwargs): 38 | ckpt = torch.load(os.path.join('checkpoint', checkpoint)) 39 | 40 | if 'args' in ckpt: 41 | args = ckpt['args'] 42 | 43 | if model == 'vqvae': 44 | model = get_model(architecture, num_embeddings, device, neighborhood, selection_fn, **kwargs) 45 | 46 | elif model == 'vqvae2': 47 | model = VQVAE2() 48 | 49 | elif model == 'pixelsnail_top': 50 | model = PixelSNAIL( 51 | [args.size//8, args.size//8], 52 | 512, 53 | args.channel, 54 | 5, 55 | 4, 56 | args.n_res_block, 57 | args.n_res_channel, 58 | dropout=args.dropout, 59 | n_out_res_block=args.n_out_res_block, 60 | ) 61 | 62 | elif model == 'pixelsnail_bottom': 63 | model = PixelSNAIL( 64 | [args.size//4, args.size//4], 65 | 512, 66 | args.channel, 67 | 5, 68 | 4, 69 | args.n_res_block, 70 | args.n_res_channel, 71 | attention=False, 72 | dropout=args.dropout, 73 | n_cond_res_block=args.n_cond_res_block, 74 | cond_res_channel=args.n_res_channel, 75 | ) 76 | 77 | if 'model' in ckpt: 78 | ckpt = ckpt['model'] 79 | 80 | model.load_state_dict(ckpt, strict=False) 81 | model = model.to(device) 82 | model.eval() 83 | 84 | return model 85 | 86 | 87 | def sample_from_range(thread_ind, min_ind, max_ind, sampled_directory, device, temp, batch, ckpt_epoch, pixelsnail_ckpt_epoch, hier, architecture, num_embeddings, neighborhood, selection_fn, size, dataset, **kwargs): 88 | logging.info("Sampling thread {}: starting with range [{},{}) on device {}".format(thread_ind, min_ind, max_ind, device)) 89 | pixelsnail_checkpoint_name, vqvae_checkpoint_name = get_checkpoint_names(architecture, ckpt_epoch, dataset, hier, 90 | kwargs, neighborhood, num_embeddings, 91 | pixelsnail_ckpt_epoch, 
selection_fn, size) 92 | 93 | model_bottom, model_vqvae = load_models(architecture, device, kwargs, neighborhood, num_embeddings, pixelsnail_checkpoint_name, selection_fn, vqvae_checkpoint_name) 94 | # print('Sampling in range {}-{}'.format(min_ind, max_ind)) 95 | # for sample_ind in tqdm(range(min_ind, max_ind), 'Sampling image for: {}'.format(pixelsnail_checkpoint_name)): 96 | for sample_ind in tqdm(range(min_ind, max_ind), 'Sampling image for: {} in range [{},{})'.format(pixelsnail_checkpoint_name, min_ind, max_ind)): 97 | logging.info('Thread {}, sample ind {}'.format(thread_ind, sample_ind)) 98 | bottom_sample = sample_model(thread_ind, model_bottom, device, batch, [size//4, size//4], temp, condition=None) 99 | 100 | decoded_sample = model_vqvae._modules['module'].decode_code(bottom_sample) 101 | decoded_sample = decoded_sample.clamp(-1, 1) 102 | 103 | filename = 'sampled_{}.png'.format(sample_ind) 104 | target_path = os.path.join(sampled_directory, filename) 105 | save_image(decoded_sample, target_path, normalize=True, range=(-1, 1)) 106 | 107 | 108 | def load_models(architecture, device, kwargs, neighborhood, num_embeddings, pixelsnail_checkpoint_name, selection_fn, 109 | vqvae_checkpoint_name): 110 | model_vqvae = load_model('vqvae', vqvae_checkpoint_name, device, architecture, num_embeddings, neighborhood, 111 | selection_fn, **kwargs) 112 | model_bottom = load_model('pixelsnail_bottom', pixelsnail_checkpoint_name, device, **kwargs) 113 | return model_bottom, model_vqvae 114 | 115 | 116 | def get_checkpoint_names(architecture, ckpt_epoch, dataset, hier, kwargs, neighborhood, num_embeddings, 117 | pixelsnail_ckpt_epoch, selection_fn, size): 118 | experiment_name = util_funcs.create_experiment_name(architecture, dataset, num_embeddings, neighborhood, selection_fn, size, 119 | **kwargs) 120 | vqvae_checkpoint_name = util_funcs.create_checkpoint_name(experiment_name, ckpt_epoch) 121 | pixelsnail_checkpoint_name = f'pixelsnail_{experiment_name}_{hier}_{str(pixelsnail_ckpt_epoch + 1).zfill(3)}.pt' 122 | return pixelsnail_checkpoint_name, vqvae_checkpoint_name 123 | 124 | 125 | def create_run(device, temp, batch, ckpt_epoch, pixelsnail_ckpt_epoch, hier, architecture, num_embeddings, neighborhood, 126 | selection_fn, dataset, num_threads, size, **kwargs): 127 | pixelsnail_checkpoint_name, _ = get_checkpoint_names(architecture, ckpt_epoch, dataset, hier, 128 | kwargs, neighborhood, num_embeddings, 129 | pixelsnail_ckpt_epoch, selection_fn, size) 130 | sampled_directory = os.path.join('sampled_images', pixelsnail_checkpoint_name).replace('.pt', '') 131 | if os.path.exists(sampled_directory): 132 | shutil.rmtree(sampled_directory) 133 | os.mkdir(sampled_directory) 134 | 135 | logging.basicConfig(format=format, level=logging.INFO, datefmt="%H:%M:%S") 136 | threads = list() 137 | num_samples = 50000 138 | min_ind = 0 139 | step_size = num_samples // num_threads 140 | for thread_index in range(num_threads): 141 | max_ind = min_ind + step_size 142 | x = threading.Thread(target=sample_from_range, args=(thread_index, min_ind, max_ind, sampled_directory, device + ':{}'.format(thread_index), temp, batch, ckpt_epoch, pixelsnail_ckpt_epoch, hier, architecture, num_embeddings, neighborhood, selection_fn, size, dataset), kwargs=kwargs) 143 | threads.append(x) 144 | x.start() 145 | min_ind = max_ind 146 | if thread_index + 1 == num_threads: 147 | min_ind = max(min_ind, num_samples - step_size) 148 | 149 | for thread_index, thread in enumerate(threads): 150 | thread.join() 151 | 152 | 153 | 154 | if 
__name__ == '__main__': 155 | parser = argparse.ArgumentParser() 156 | parser = util_funcs.base_parser(parser) 157 | parser = util_funcs.vqvae_parser(parser) 158 | parser = util_funcs.code_extraction_parser(parser) 159 | parser = util_funcs.pixelsnail_parser(parser) 160 | parser = util_funcs.sampling_parser(parser) 161 | args = parser.parse_args() 162 | 163 | print(args) 164 | 165 | create_run(**vars(args)) 166 | 167 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torchvision==0.4.0 2 | tensorboardX==1.9 3 | tqdm 4 | sklearn -------------------------------------------------------------------------------- /sample.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | 5 | import torch 6 | from torchvision.utils import save_image 7 | from tqdm import tqdm 8 | 9 | from models.vqvae import VQVAE2, VQVAE 10 | from models.pixelsnail import PixelSNAIL 11 | from models.model_utils import get_model 12 | 13 | from utils import util_funcs 14 | 15 | @torch.no_grad() 16 | def sample_model(model, device, batch, size, temperature, condition=None): 17 | row = torch.zeros(batch, *size, dtype=torch.int64).to(device) 18 | cache = {} 19 | 20 | for i in tqdm(range(size[0]), desc='Sampling rows'): 21 | for j in range(size[1]): 22 | out, cache = model(row[:, : i + 1, :], condition=condition, cache=cache) 23 | prob = torch.softmax(out[:, :, i, j] / temperature, 1) 24 | sample = torch.multinomial(prob, 1).squeeze(-1) 25 | row[:, i, j] = sample.detach() 26 | 27 | return row 28 | 29 | 30 | def load_model(model, checkpoint, device, architecture=None, num_embeddings=None, neighborhood=None, selection_fn=None, size=256, **kwargs): 31 | ckpt = torch.load(os.path.join('checkpoint', checkpoint)) 32 | 33 | if 'args' in ckpt: 34 | args = ckpt['args'] 35 | 36 | if model == 'vqvae': 37 | model = get_model(architecture, num_embeddings, device, neighborhood, selection_fn, **kwargs) 38 | 39 | elif model == 'vqvae2': 40 | model = VQVAE2() 41 | 42 | elif model == 'pixelsnail_top': 43 | model = PixelSNAIL( 44 | [size//8, size//8], 45 | 512, 46 | args.channel, 47 | 5, 48 | 4, 49 | args.n_res_block, 50 | args.n_res_channel, 51 | dropout=args.dropout, 52 | n_out_res_block=args.n_out_res_block, 53 | ) 54 | 55 | elif model == 'pixelsnail_bottom': 56 | model = PixelSNAIL( 57 | [size//4, size//4], 58 | 512, 59 | args.channel, 60 | 5, 61 | 4, 62 | args.n_res_block, 63 | args.n_res_channel, 64 | attention=False, 65 | dropout=args.dropout, 66 | n_cond_res_block=args.n_cond_res_block, 67 | cond_res_channel=args.n_res_channel, 68 | ) 69 | 70 | if 'model' in ckpt: 71 | ckpt = ckpt['model'] 72 | 73 | model.load_state_dict(ckpt, strict=False) 74 | model = model.to(device) 75 | model.eval() 76 | 77 | return model 78 | 79 | 80 | def create_run(device, temp, pixelsnail_batch, ckpt_epoch, pixelsnail_ckpt_epoch, hier, architecture, num_embeddings, neighborhood, selection_fn, dataset, size, **kwargs): 81 | experiment_name = util_funcs.create_experiment_name(architecture, dataset, num_embeddings, neighborhood, selection_fn, size, **kwargs) 82 | vqvae_checkpoint_name = util_funcs.create_checkpoint_name(experiment_name, ckpt_epoch) 83 | 84 | # pixelsnail_checkpoint_name = f'pixelsnail_{experiment_name}_{hier}_{str(pixelsnail_ckpt_epoch + 1).zfill(3)}.pt' 85 | pixelsnail_checkpoint_name = 
'pixelsnail_vqvae_imagenet_num_embeddings[512]_neighborhood[1]_selectionFN[vanilla]_size[128]_bottom_420.pt' 86 | 87 | model_vqvae = load_model('vqvae', vqvae_checkpoint_name, device, architecture, num_embeddings, neighborhood, selection_fn, **kwargs) 88 | # model_top = load_model('pixelsnail_top', args.top, device) 89 | model_bottom = load_model('pixelsnail_bottom', pixelsnail_checkpoint_name, device, size=size, **kwargs) 90 | 91 | num_samples = 50000 92 | sampled_directory = os.path.join('sampled_images', pixelsnail_checkpoint_name).replace('.pt', '') 93 | if os.path.exists(sampled_directory): 94 | shutil.rmtree(sampled_directory) 95 | os.mkdir(sampled_directory) 96 | 97 | for sample_ind in tqdm(range(num_samples), 'Sampling image for: {}'.format(pixelsnail_checkpoint_name)): 98 | 99 | # top_sample = sample_model(model_top, device, args.batch, [32, 32], args.temp) 100 | bottom_sample = sample_model( 101 | model_bottom, device, pixelsnail_batch, [size // 4, size // 4], temp, condition=None 102 | # model_bottom, device, args.batch, [64, 64], args.temp, condition=top_sample 103 | ) 104 | 105 | decoded_sample = model_vqvae._modules['module'].decode_code(bottom_sample) 106 | # decoded_sample = model_vqvae.decode_code(top_sample, bottom_sample) 107 | decoded_sample = decoded_sample.clamp(-1, 1) 108 | 109 | filename = 'sampled_{}.png'.format(sample_ind) 110 | target_path = os.path.join(sampled_directory, filename) 111 | save_image(decoded_sample, target_path, normalize=True, range=(-1, 1)) 112 | 113 | 114 | 115 | if __name__ == '__main__': 116 | parser = argparse.ArgumentParser() 117 | parser = util_funcs.base_parser(parser) 118 | parser = util_funcs.vqvae_parser(parser) 119 | parser = util_funcs.code_extraction_parser(parser) 120 | parser = util_funcs.pixelsnail_parser(parser) 121 | parser = util_funcs.sampling_parser(parser) 122 | args = parser.parse_args() 123 | 124 | print(args) 125 | 126 | create_run(**vars(args)) 127 | 128 | -------------------------------------------------------------------------------- /scheduler.py: -------------------------------------------------------------------------------- 1 | from math import cos, pi, floor, sin 2 | 3 | from torch.optim import lr_scheduler 4 | 5 | 6 | class CosineLR(lr_scheduler._LRScheduler): 7 | def __init__(self, optimizer, lr_min, lr_max, step_size): 8 | self.lr_min = lr_min 9 | self.lr_max = lr_max 10 | self.step_size = step_size 11 | self.iteration = 0 12 | 13 | super().__init__(optimizer, -1) 14 | 15 | def get_lr(self): 16 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 17 | 1 + cos(self.iteration / self.step_size * pi) 18 | ) 19 | self.iteration += 1 20 | 21 | if self.iteration == self.step_size: 22 | self.iteration = 0 23 | 24 | return [lr for base_lr in self.base_lrs] 25 | 26 | 27 | class PowerLR(lr_scheduler._LRScheduler): 28 | def __init__(self, optimizer, lr_min, lr_max, warmup): 29 | self.lr_min = lr_min 30 | self.lr_max = lr_max 31 | self.warmup = warmup 32 | self.iteration = 0 33 | 34 | super().__init__(optimizer, -1) 35 | 36 | def get_lr(self): 37 | if self.iteration < self.warmup: 38 | lr = ( 39 | self.lr_min + (self.lr_max - self.lr_min) / self.warmup * self.iteration 40 | ) 41 | 42 | else: 43 | lr = self.lr_max * (self.iteration - self.warmup + 1) ** -0.5 44 | 45 | self.iteration += 1 46 | 47 | return [lr for base_lr in self.base_lrs] 48 | 49 | 50 | class SineLR(lr_scheduler._LRScheduler): 51 | def __init__(self, optimizer, lr_min, lr_max, step_size): 52 | self.lr_min = lr_min 53 | self.lr_max = lr_max 54 
| self.step_size = step_size 55 | self.iteration = 0 56 | 57 | super().__init__(optimizer, -1) 58 | 59 | def get_lr(self): 60 | lr = self.lr_min + (self.lr_max - self.lr_min) * sin( 61 | self.iteration / self.step_size * pi 62 | ) 63 | self.iteration += 1 64 | 65 | if self.iteration == self.step_size: 66 | self.iteration = 0 67 | 68 | return [lr for base_lr in self.base_lrs] 69 | 70 | 71 | class LinearLR(lr_scheduler._LRScheduler): 72 | def __init__(self, optimizer, lr_min, lr_max, warmup, step_size): 73 | self.lr_min = lr_min 74 | self.lr_max = lr_max 75 | self.step_size = step_size 76 | self.warmup = warmup 77 | self.iteration = 0 78 | 79 | super().__init__(optimizer, -1) 80 | 81 | def get_lr(self): 82 | if self.iteration < self.warmup: 83 | lr = self.lr_max 84 | 85 | else: 86 | lr = self.lr_max + (self.iteration - self.warmup) * ( 87 | self.lr_min - self.lr_max 88 | ) / (self.step_size - self.warmup) 89 | self.iteration += 1 90 | 91 | if self.iteration == self.step_size: 92 | self.iteration = 0 93 | 94 | return [lr for base_lr in self.base_lrs] 95 | 96 | 97 | class CLR(lr_scheduler._LRScheduler): 98 | def __init__(self, optimizer, lr_min, lr_max, step_size): 99 | self.epoch = 0 100 | self.lr_min = lr_min 101 | self.lr_max = lr_max 102 | self.current_lr = lr_min 103 | self.step_size = step_size 104 | 105 | super().__init__(optimizer, -1) 106 | 107 | def get_lr(self): 108 | cycle = floor(1 + self.epoch / (2 * self.step_size)) 109 | x = abs(self.epoch / self.step_size - 2 * cycle + 1) 110 | lr = self.lr_min + (self.lr_max - self.lr_min) * max(0, 1 - x) 111 | self.current_lr = lr 112 | 113 | self.epoch += 1 114 | 115 | return [lr for base_lr in self.base_lrs] 116 | 117 | 118 | class Warmup(lr_scheduler._LRScheduler): 119 | def __init__(self, optimizer, model_dim, factor=1, warmup=16000): 120 | self.optimizer = optimizer 121 | self.model_dim = model_dim 122 | self.factor = factor 123 | self.warmup = warmup 124 | self.iteration = 0 125 | 126 | super().__init__(optimizer, -1) 127 | 128 | def get_lr(self): 129 | self.iteration += 1 130 | lr = ( 131 | self.factor 132 | * self.model_dim ** (-0.5) 133 | * min(self.iteration ** (-0.5), self.iteration * self.warmup ** (-1.5)) 134 | ) 135 | 136 | return [lr for base_lr in self.base_lrs] 137 | 138 | 139 | # Copyright 2019 fastai 140 | 141 | # Licensed under the Apache License, Version 2.0 (the "License"); 142 | # you may not use this file except in compliance with the License. 143 | # You may obtain a copy of the License at 144 | 145 | # http://www.apache.org/licenses/LICENSE-2.0 146 | 147 | # Unless required by applicable law or agreed to in writing, software 148 | # distributed under the License is distributed on an "AS IS" BASIS, 149 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 150 | # See the License for the specific language governing permissions and 151 | # limitations under the License. 
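
# --- Illustrative usage sketch (added for documentation; the hyperparameter
# values below are placeholders, not the ones used by the training scripts in
# this repository):
def _example_scheduler_usage():
    import torch

    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = torch.optim.SGD(params, lr=1e-3)

    # Warm up linearly to lr_max over the first 500 iterations,
    # then decay proportionally to 1 / sqrt(iteration).
    scheduler = PowerLR(optimizer, lr_min=1e-5, lr_max=3e-4, warmup=500)

    for _ in range(1000):
        optimizer.step()
        scheduler.step()  # writes the new lr into optimizer.param_groups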
152 | 153 | 154 | # Borrowed from https://github.com/fastai/fastai and changed to make it runs like PyTorch lr scheduler 155 | 156 | 157 | class CycleAnnealScheduler: 158 | def __init__( 159 | self, optimizer, lr_max, lr_divider, cut_point, step_size, momentum=None 160 | ): 161 | self.lr_max = lr_max 162 | self.lr_divider = lr_divider 163 | self.cut_point = step_size // cut_point 164 | self.step_size = step_size 165 | self.iteration = 0 166 | self.cycle_step = int(step_size * (1 - cut_point / 100) / 2) 167 | self.momentum = momentum 168 | self.optimizer = optimizer 169 | 170 | def get_lr(self): 171 | if self.iteration > 2 * self.cycle_step: 172 | cut = (self.iteration - 2 * self.cycle_step) / ( 173 | self.step_size - 2 * self.cycle_step 174 | ) 175 | lr = self.lr_max * (1 + (cut * (1 - 100) / 100)) / self.lr_divider 176 | 177 | elif self.iteration > self.cycle_step: 178 | cut = 1 - (self.iteration - self.cycle_step) / self.cycle_step 179 | lr = self.lr_max * (1 + cut * (self.lr_divider - 1)) / self.lr_divider 180 | 181 | else: 182 | cut = self.iteration / self.cycle_step 183 | lr = self.lr_max * (1 + cut * (self.lr_divider - 1)) / self.lr_divider 184 | 185 | return lr 186 | 187 | def get_momentum(self): 188 | if self.iteration > 2 * self.cycle_step: 189 | momentum = self.momentum[0] 190 | 191 | elif self.iteration > self.cycle_step: 192 | cut = 1 - (self.iteration - self.cycle_step) / self.cycle_step 193 | momentum = self.momentum[0] + cut * (self.momentum[1] - self.momentum[0]) 194 | 195 | else: 196 | cut = self.iteration / self.cycle_step 197 | momentum = self.momentum[0] + cut * (self.momentum[1] - self.momentum[0]) 198 | 199 | return momentum 200 | 201 | def step(self): 202 | lr = self.get_lr() 203 | 204 | if self.momentum is not None: 205 | momentum = self.get_momentum() 206 | 207 | self.iteration += 1 208 | 209 | if self.iteration == self.step_size: 210 | self.iteration = 0 211 | 212 | for group in self.optimizer.param_groups: 213 | group['lr'] = lr 214 | 215 | if self.momentum is not None: 216 | group['betas'] = (momentum, group['betas'][1]) 217 | 218 | return lr 219 | 220 | 221 | def anneal_linear(start, end, proportion): 222 | return start + proportion * (end - start) 223 | 224 | 225 | def anneal_cos(start, end, proportion): 226 | cos_val = cos(pi * proportion) + 1 227 | 228 | return end + (start - end) / 2 * cos_val 229 | 230 | 231 | class Phase: 232 | def __init__(self, start, end, n_iter, anneal_fn): 233 | self.start, self.end = start, end 234 | self.n_iter = n_iter 235 | self.anneal_fn = anneal_fn 236 | self.n = 0 237 | 238 | def step(self): 239 | self.n += 1 240 | 241 | return self.anneal_fn(self.start, self.end, self.n / self.n_iter) 242 | 243 | def reset(self): 244 | self.n = 0 245 | 246 | @property 247 | def is_done(self): 248 | return self.n >= self.n_iter 249 | 250 | 251 | class CycleScheduler: 252 | def __init__( 253 | self, 254 | optimizer, 255 | lr_max, 256 | n_iter, 257 | momentum=(0.95, 0.85), 258 | divider=25, 259 | warmup_proportion=0.3, 260 | phase=('linear', 'cos'), 261 | ): 262 | self.optimizer = optimizer 263 | 264 | phase1 = int(n_iter * warmup_proportion) 265 | phase2 = n_iter - phase1 266 | lr_min = lr_max / divider 267 | 268 | phase_map = {'linear': anneal_linear, 'cos': anneal_cos} 269 | 270 | self.lr_phase = [ 271 | Phase(lr_min, lr_max, phase1, phase_map[phase[0]]), 272 | Phase(lr_max, lr_min / 1e4, phase2, phase_map[phase[1]]), 273 | ] 274 | 275 | self.momentum = momentum 276 | 277 | if momentum is not None: 278 | mom1, mom2 = momentum 279 | 
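            # One-cycle momentum schedule: momentum anneals opposite to the lr,
            # from mom1 down to mom2 during warmup and back up during decay.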
self.momentum_phase = [ 280 | Phase(mom1, mom2, phase1, phase_map[phase[0]]), 281 | Phase(mom2, mom1, phase2, phase_map[phase[1]]), 282 | ] 283 | 284 | else: 285 | self.momentum_phase = [] 286 | 287 | self.phase = 0 288 | 289 | def step(self): 290 | lr = self.lr_phase[self.phase].step() 291 | 292 | if self.momentum is not None: 293 | momentum = self.momentum_phase[self.phase].step() 294 | 295 | else: 296 | momentum = None 297 | 298 | for group in self.optimizer.param_groups: 299 | group['lr'] = lr 300 | 301 | if self.momentum is not None: 302 | if 'betas' in group: 303 | group['betas'] = (momentum, group['betas'][1]) 304 | 305 | else: 306 | group['momentum'] = momentum 307 | 308 | if self.lr_phase[self.phase].is_done: 309 | self.phase += 1 310 | 311 | if self.phase >= len(self.lr_phase): 312 | for phase in self.lr_phase: 313 | phase.reset() 314 | 315 | for phase in self.momentum_phase: 316 | phase.reset() 317 | 318 | self.phase = 0 319 | 320 | return lr, momentum 321 | 322 | 323 | class LRFinder(lr_scheduler._LRScheduler): 324 | def __init__(self, optimizer, lr_min, lr_max, step_size, linear=False): 325 | ratio = lr_max / lr_min 326 | self.linear = linear 327 | self.lr_min = lr_min 328 | self.lr_mult = (ratio / step_size) if linear else ratio ** (1 / step_size) 329 | self.iteration = 0 330 | self.lrs = [] 331 | self.losses = [] 332 | 333 | super().__init__(optimizer, -1) 334 | 335 | def get_lr(self): 336 | lr = ( 337 | self.lr_mult * self.iteration 338 | if self.linear 339 | else self.lr_mult ** self.iteration 340 | ) 341 | lr = self.lr_min + lr if self.linear else self.lr_min * lr 342 | 343 | self.iteration += 1 344 | self.lrs.append(lr) 345 | 346 | return [lr for base_lr in self.base_lrs] 347 | 348 | def record(self, loss): 349 | self.losses.append(loss) 350 | 351 | def save(self, filename): 352 | with open(filename, 'w') as f: 353 | for lr, loss in zip(self.lrs, self.losses): 354 | f.write('{},{}\n'.format(lr, loss)) 355 | -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amzn/sparse-vqvae/33ca864b6a20c644c3c825ce958fd20a99349dda/scripts/__init__.py -------------------------------------------------------------------------------- /scripts/calculate_jpg_psnr.py: -------------------------------------------------------------------------------- 1 | """ 2 | Calculates PSNR for JPG with respect to a dataset. 
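Reported compression ratio is raw RGB bytes (size * size * 3) divided by the
JPEG-encoded byte length, averaged over the test set.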
3 | Note: Currently hard-coded for Cifar10 4 | """ 5 | import sys 6 | import os 7 | sys.path.append(os.path.abspath('..')) 8 | import argparse 9 | import numpy as np 10 | import torch 11 | import cv2 12 | from collections import namedtuple 13 | from PIL import Image 14 | from tqdm import tqdm 15 | 16 | from utils import util_funcs 17 | from models.model_utils import get_dataset 18 | 19 | 20 | result_tuple = namedtuple('JPEG_res', ['quality', 'ratio', 'psnr', ]) 21 | 22 | 23 | def tensor2img(tensor): 24 | ndarr = tensor.clone().mul_(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy() 25 | im = Image.fromarray(ndarr) 26 | return im 27 | 28 | 29 | def get_PSNR(quality, dataset, size): 30 | total_psnr = 0 31 | total_steps = 0 32 | MAX_i = 255 33 | psnr_term = 20 * np.log10(np.ones(1) * MAX_i) 34 | compression_ratio = 0 35 | 36 | _tqdm = tqdm(dataset, desc=f'Quality: {quality}') 37 | for batch in _tqdm: 38 | # Load data 39 | image = (batch[0] + 1) * MAX_i / 2 # from [-1,1] to [0, 255] 40 | assert image.max() <= MAX_i and image.min() >= 0, f'bad image valuse, in [{image.max()}, {image.min()}]' 41 | if image.shape[0] == 3: 42 | image = image.transpose(0, 2) 43 | rgb_image = np.array(image, dtype=np.uint8) 44 | open_cv_image = rgb_image[:, :, ::-1].copy() 45 | 46 | # Translate to and back from jpg 47 | jpg_str = cv2.imencode('.jpg', open_cv_image, [cv2.IMWRITE_JPEG_QUALITY, quality])[1].tostring() 48 | np_arr = np.fromstring(jpg_str, np.uint8) 49 | decoded_img = cv2.imdecode(np_arr, cv2.IMREAD_COLOR) 50 | 51 | # Calculate encoded length 52 | raw_size = size * size * 3 53 | # bmp_size = len(cv2.imencode('.bmp', open_cv_image)[1]) 54 | jpg_size = len(jpg_str) 55 | 56 | # Calculate compression ratio and PSNR 57 | compression_ratio += raw_size / float(jpg_size) 58 | 59 | mse = np.mean(np.square(np.array(decoded_img) - rgb_image)) 60 | psnr = psnr_term - 10 * np.log10(mse) 61 | total_psnr += psnr 62 | total_steps += 1.0 63 | 64 | if total_steps % 1000 == 0: 65 | _tqdm.set_postfix({'psnr': np.round(total_psnr/total_steps, 2), 'ratio': compression_ratio/total_steps }) 66 | 67 | print('Calculating for JPG with quality measure: {}'.format(quality)) 68 | print('PSNR: {}'.format(total_psnr / total_steps)) 69 | print('compression_ratio: {}'.format(compression_ratio / total_steps)) 70 | 71 | return result_tuple( 72 | quality=quality, 73 | psnr=total_psnr / total_steps, 74 | ratio=compression_ratio / total_steps, 75 | ) 76 | 77 | 78 | if __name__ == '__main__': 79 | parser = argparse.ArgumentParser() 80 | parser = util_funcs.base_parser(parser) 81 | parser = util_funcs.vqvae_parser(parser) 82 | parser = util_funcs.code_extraction_parser(parser) 83 | args = parser.parse_args() 84 | 85 | print('setting up datasets') 86 | _, test_dataset = get_dataset(args.dataset, args.data_path, args.size) 87 | 88 | all_res = list() 89 | for quality in range(0, 110, 10): 90 | res = get_PSNR(quality, test_dataset, args.size) 91 | all_res.append(res) 92 | 93 | [print(f'{r.quality}, {r.psnr}, {r.ratio}') for r in all_res] 94 | with open('/tmp/jpeg_res.csv', 'w') as fp: 95 | [fp.write(f'{r.quality}, {r.psnr}, {r.ratio} \n') for r in all_res] 96 | -------------------------------------------------------------------------------- /scripts/calculate_model_psnr.py: -------------------------------------------------------------------------------- 1 | """ 2 | Calculates PSNR for a given VQ-VAE model with respect to a dataset. 
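PSNR is computed as 20 * log10(255) - 10 * log10(MSE) after images are mapped
from [-1, 1] back to [0, 255].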
3 | Accepts same arguments as train_vqvae.py 4 | """ 5 | import sys 6 | import os 7 | sys.path.append(os.path.abspath('..')) 8 | 9 | import json 10 | import argparse 11 | import torch 12 | from torch.utils.data import DataLoader 13 | from torchvision import utils 14 | from torch import nn 15 | 16 | from tqdm import tqdm 17 | from datetime import datetime 18 | from PIL import Image 19 | 20 | from utils import util_funcs 21 | from models.model_utils import get_dataset, get_model 22 | 23 | 24 | _NOW = datetime.now().strftime('%Y_%m_%d__%H_%M') 25 | 26 | 27 | def _save_tensors(exp_name, image_in, image_out, sample_size=5): 28 | 29 | _save_root = os.path.join(args.sample_save_path, exp_name, '_save_tensors') 30 | os.makedirs(_save_root, exist_ok=True) 31 | 32 | utils.save_image( 33 | torch.cat([image_in[:sample_size], image_out[:sample_size]], 0), 34 | os.path.join(_save_root, f'sample_{_NOW}.png'), 35 | nrow=sample_size, 36 | normalize=True, 37 | range=(-1, 1), 38 | ) 39 | 40 | 41 | def tensor2img(tensor): 42 | ndarr = tensor.clone().mul_(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy() 43 | im = Image.fromarray(ndarr) 44 | return im 45 | 46 | 47 | def _validate_args(checkpoint_path): 48 | args_file = os.path.dirname(checkpoint_path) + '/args.txt' 49 | if not os.path.isfile(args_file): 50 | print(f'cannot locate "args.txt" in {args_file}') 51 | return 52 | 53 | with open(args_file) as fp: 54 | kw_strs = [line.split(' : ') for line in fp.readlines()] 55 | kw_dict = {l[0]: l[1].strip(' \n') for l in kw_strs} 56 | 57 | msg = '' 58 | FIELDS_TO_SKIP = {'today_str', 'num_workers', 'seed', 'ckpt_epoch', 'experiment_name', ''} 59 | for k, v in kw_dict.items(): 60 | if k in FIELDS_TO_SKIP: 61 | continue 62 | elif not hasattr(args, k): 63 | msg += f'"{k}" is in the args.txt file but not such parameter exits in input \n' 64 | elif v != str(getattr(args, k)): 65 | msg += f'"{k}" has input value of "{getattr(args, k)}" but in checkpoint it has value "{v}"\n' 66 | 67 | if msg != '': 68 | print(f'{30*"="}\nparameter inconsistency found') 69 | print(msg[:-1]) 70 | print(f'{30*"="}') 71 | 72 | 73 | def get_PSNR(size, device, dataset, data_path, num_workers, num_embeddings, architecture, ckpt_epoch, neighborhood, selection_fn, embed_dim, **kwargs): 74 | print('setting up dataset') 75 | _, test_dataset = get_dataset(dataset, data_path, size) 76 | 77 | print('creating data loaders') 78 | test_loader = DataLoader(test_dataset, batch_size=kwargs['vae_batch'], shuffle=True, num_workers=num_workers) 79 | 80 | experiment_name = util_funcs.create_experiment_name(architecture, dataset, num_embeddings, neighborhood, selection_fn, size, **kwargs) 81 | checkpoint_name = util_funcs.create_checkpoint_name(experiment_name, ckpt_epoch) 82 | checkpoint_path = f"{kwargs['checkpoint_path']}/{checkpoint_name}" 83 | _validate_args(checkpoint_path) 84 | 85 | print('Calculating PSNR for: {}'.format(checkpoint_name)) 86 | 87 | print('Loading model') 88 | model = get_model(architecture, num_embeddings, device, neighborhood, selection_fn, embed_dim, parallel=False, **kwargs) 89 | model.load_state_dict(torch.load(os.path.join('..', checkpoint_path), map_location='cuda:0'), strict=False) 90 | model = model.to(device) 91 | model.eval() 92 | 93 | mse = nn.MSELoss() 94 | MAX_i = 255 95 | to_MAX = lambda t: (t+1) * MAX_i / 2 96 | psnr_term = 20 * torch.log10(torch.ones(1, 1)*MAX_i) 97 | 98 | # calculate PSNR over test set 99 | sparsity = 0 100 | top_percentiles = 0 101 | num_zeros = 0 102 | psnrs = 0 103 | done = 
0 104 | 105 | for batch in tqdm(test_loader, desc='Calculating PSNR'): 106 | with torch.no_grad(): 107 | img = batch[0].to(device) 108 | out, _, num_quantization_steps, mean_D, mean_Z, norm_Z, top_percentile, batch_num_zeros = model(img) 109 | 110 | if done == 0 and os.path.isdir(args.sample_save_path): # save the first test batch 111 | _save_tensors(experiment_name, img, out, ) 112 | 113 | cur_psnr = psnr_term.item() - 10 * torch.log10(mse(to_MAX(out), to_MAX(img))) 114 | 115 | # Gather data 116 | psnrs += cur_psnr 117 | sparsity += norm_Z.mean() 118 | top_percentiles += top_percentile.mean() 119 | num_zeros += batch_num_zeros.mean() # accumulate without shadowing the per-batch tensor 120 | done += 1 121 | 122 | # Dump results 123 | print('sparsity: {}'.format(sparsity)) 124 | print('done: {}'.format(done)) 125 | print('(sparsity/float(done)).item(): {}'.format((sparsity/float(done)).item())) 126 | avg_psnr = (psnrs/float(done)).item() 127 | avg_top_percentiles = (top_percentiles/float(done)).item() 128 | avg_num_zeros = (num_zeros/float(done)).item() 129 | avg_sparsity = (sparsity/float(done)).item() 130 | 131 | print('#'*30) 132 | print('Experiment name: {}'.format(experiment_name)) 133 | print('Epoch name: {}'.format(ckpt_epoch)) 134 | print('avg_psnr: {}'.format(avg_psnr)) 135 | print('is_quantize_coefs: {}'.format(kwargs['is_quantize_coefs'])) 136 | 137 | # dump params and stats into a JSON file 138 | _save_root = os.path.join(args.sample_save_path, experiment_name) 139 | os.makedirs(_save_root, exist_ok=True) 140 | with open(f'{_save_root}/eval_result_{_NOW}.json', 'w') as fp: 141 | json.dump(dict( 142 | # --- params 143 | dataset=args.dataset, 144 | selection_fn=args.selection_fn, 145 | embed_dim=args.embed_dim, 146 | num_atoms=args.num_embeddings, 147 | image_size=args.size, 148 | batch_size=args.vae_batch, 149 | num_strides=args.num_strides, 150 | num_nonzero=args.num_nonzero, 151 | normalize_x=args.normalize_x, 152 | normalize_d=args.normalize_dict, 153 | epoch=args.ckpt_epoch, 154 | # --- stats 155 | psnr=avg_psnr, 156 | compression=None, # todo - calculate this 157 | atom_bits=None, # todo - calculate this 158 | non_zero_mean=avg_sparsity, 159 | non_zero_99pct=avg_top_percentiles, 160 | # tmp=args.tmp, 161 | ), fp, indent=1) 162 | 163 | 164 | if __name__ == '__main__': 165 | parser = argparse.ArgumentParser() 166 | parser = util_funcs.base_parser(parser) 167 | parser = util_funcs.vqvae_parser(parser) 168 | parser = util_funcs.code_extraction_parser(parser) 169 | parser.add_argument('--sample_save_path', default='.', 170 | type=str, help='directory to save sample images and evaluation results to') 171 | args = parser.parse_args() 172 | util_funcs.seed_generators(args.seed) 173 | get_PSNR(**vars(args)) 174 | -------------------------------------------------------------------------------- /scripts/calculate_model_psnr.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | conda activate pytorch_p36 4 | 5 | ROOT_CHEKCPOINT=/hiero_efs/HieroExperiments/Sparse_VAE/checkpoint 6 | # DATE_CHEKCPOINT=2020_10_28 7 | DATE_CHEKCPOINT=2020_11_04 8 | 9 | 10 | SEED="1" 11 | STRIDES=("1" "2") 12 | SELECT_FNS=("van" "omp") 13 | NUM_NONZERO=("1" "2" "4") 14 | 15 | function run_psnr_eval () { 16 | # parameters 17 | # 1 - experiment name prefix 18 | # 2 - selection function 19 | # 3 - stride 20 | # 4 - num_nonzeros 21 | # 5 - seed 22 | # 6 - X normalization 23 | # 7 - D normalization 24 | 25 | experiment_name="$DATE_CHEKCPOINT/$1" 26 | # --is_quantize_coefs -stride=2 -sel=omp -k=4 27 | declare ARGS="" 28 | declare 
ARGS=$ARGS"-n=$experiment_name " 29 | declare ARGS=$ARGS"-sel=$2 " 30 | declare ARGS=$ARGS"-stride=$3 " 31 | declare ARGS=$ARGS"-k=$4 " 32 | declare ARGS=$ARGS"--seed=$5 " 33 | if [ "$6" == "False" ]; then 34 | declare ARGS=$ARGS"--no_normalize_x " 35 | fi 36 | if [ "$7" == "False" ]; then 37 | declare ARGS=$ARGS"--no_normalize_dict " 38 | fi 39 | if [ "$2" == "omp" ]; then 40 | declare ARGS=$ARGS"--is_quantize_coefs " 41 | fi 42 | 43 | echo "#################" #| tee "$3"/run_params.txt -a 44 | echo "command args = $ARGS" #| tee "$3"/run_params.txt -a 45 | echo "#################" #| tee "$3"/run_params.txt -a 46 | 47 | python scripts/calculate_model_psnr.py $ARGS 48 | } 49 | 50 | 51 | 52 | # for JZ_NAME in "${!JZ_MODELS[@]}" 53 | for SELECT_FN in ${SELECT_FNS[@]}; do 54 | 55 | echo 56 | echo '**********************' 57 | echo '****' $SELECT_FN '****' 58 | echo '**********************' 59 | echo 60 | 61 | for STRIDE in ${STRIDES[@]}; do 62 | 63 | EXP_NAME_PREFIX="${SELECT_FN}_s${STRIDE}_${SEED}" 64 | case $SELECT_FN in 65 | "van") 66 | run_psnr_eval $EXP_NAME_PREFIX "vanilla" $STRIDE "1" $SEED False False 67 | run_psnr_eval $EXP_NAME_PREFIX"_nrm" "vanilla" $STRIDE "1" $SEED True True 68 | run_psnr_eval $EXP_NAME_PREFIX"_nrmD" "vanilla" $STRIDE "1" $SEED False True 69 | ;; 70 | "omp") 71 | for K in ${NUM_NONZERO[@]}; do 72 | run_psnr_eval $EXP_NAME_PREFIX"_k"$K "omp" $STRIDE $K $SEED True True 73 | run_psnr_eval $EXP_NAME_PREFIX"_k"$K"_nrmD" "omp" $STRIDE $K $SEED False True 74 | done 75 | ;; 76 | *) 77 | echo "bad SELECT_FN '${SELECT_FN}'" 78 | return 1 79 | ;; 80 | esac 81 | done 82 | done 83 | -------------------------------------------------------------------------------- /scripts/compression_psnr_graph.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script to generate PSNR graph 3 | """ 4 | import pandas as pd 5 | from matplotlib import pyplot as plt 6 | from matplotlib.cm import rainbow 7 | import numpy as np 8 | 9 | # Load data 10 | data = pd.read_csv('../psnr_results/psnr_compression_results.csv') 11 | 12 | # Generate color map 13 | colormap = rainbow(np.linspace(0, 1, 8)) 14 | fig = plt.figure() 15 | 16 | # Select fista data 17 | fista_none = data[(data['Selection function']=='fista') & (data['Test Stride Steps']==1) & (data['Normalized Dictionary'] == 'No') & (data['Normalized X'] == 'No')][['PSNR', 'Compression', 'Precision']] 18 | fista_X = data[(data['Selection function']=='fista') & (data['Test Stride Steps']==1) & (data['Normalized Dictionary'] == 'No') & (data['Normalized X'] == 'Yes')][['PSNR', 'Compression', 'Precision']] 19 | fista_D = data[(data['Selection function']=='fista') & (data['Test Stride Steps']==1) & (data['Normalized Dictionary'] == 'Yes') & (data['Normalized X'] == 'No')][['PSNR', 'Compression', 'Precision']] 20 | fista_both = data[(data['Selection function']=='fista') & (data['Test Stride Steps']==1) & (data['Normalized Dictionary'] == 'Yes') & (data['Normalized X'] == 'Yes')][['PSNR', 'Compression', 'Precision']] 21 | 22 | 23 | 24 | 25 | # Select vanilla data 26 | vanilla_stride1 = data[(data['Selection function']=='vanilla') & (data['Test Stride Steps']==1)][['PSNR', 'Compression']] 27 | vanilla_stride2 = data[(data['Selection function']=='vanilla') & (data['Test Stride Steps']==2)][['PSNR', 'Compression']] 28 | 29 | # Select jpg data 30 | jpg = data[data['Selection function']=='jpg'][['PSNR', 'Compression']] 31 | 32 | 33 | # Plot jpg data and set ax object 34 | ax = jpg.plot.scatter(x='Compression', 
34 | ax = jpg.plot.scatter(x='Compression', y='PSNR', marker='*', c='m')
35 | 
36 | # Plot vanilla data
37 | vanilla_stride1.plot.scatter(ax=ax, x='Compression', y='PSNR', marker='o', c='r')
38 | vanilla_stride2.plot.scatter(ax=ax, x='Compression', y='PSNR', marker='o', c='g')
39 | 
40 | # Plot fista data
41 | fista_none.plot.scatter(ax=ax, x='Compression', y='PSNR', marker='s', c='k')
42 | fista_X.plot.scatter(ax=ax, x='Compression', y='PSNR', marker='s', c='y')
43 | fista_D.plot.scatter(ax=ax, x='Compression', y='PSNR', marker='s', c='c')
44 | fista_both.plot.scatter(ax=ax, x='Compression', y='PSNR', marker='s', c='m')
45 | 
46 | # Set figure parameters
47 | plt.xscale('log')
48 | plt.grid(True)
49 | plt.title('Compression-PSNR w.r.t. normalization schemes')
50 | plt.legend(['JPG', 'Vanilla stride 1', 'Vanilla stride 2', 'FISTA stride 1, No normalization', 'FISTA stride 1, X normalization', 'FISTA stride 1, D normalization', 'FISTA stride 1, Both normalization'])
51 | # plt.legend(['FISTA stride 1, fp8', 'FISTA stride 1, fp32', 'FISTA stride 2'])
52 | plt.savefig('normalization_view_psnr_compression_results.png')
53 | 
54 | 
--------------------------------------------------------------------------------
/scripts/decompression_graph_psnr.py:
--------------------------------------------------------------------------------
1 | """
2 | Script to generate the decompression PSNR graph
3 | """
4 | import pandas as pd
5 | from matplotlib import pyplot as plt
6 | from matplotlib.cm import rainbow
7 | import numpy as np
8 | 
9 | # Load data
10 | data = pd.read_csv('../psnr_results/psnr_decompression_data.csv')
11 | 
12 | # Generate color map
13 | colormap = rainbow(np.linspace(0, 1, 8))
14 | fig = plt.figure()
15 | 
16 | # Select fista data
17 | fista_stride1 = data[(data['Selection function']=='fista') & (data['Test Stride Steps']=='1')][['PSNR', 'Compression', 'Precision']]
18 | fista_stride1_f32 = fista_stride1[fista_stride1['Precision']=='32']
19 | fista_stride1_f8 = fista_stride1[fista_stride1['Precision']=='8']
20 | fista_stride2 = data[(data['Selection function']=='fista') & (data['Test Stride Steps']=='2')][['PSNR', 'Compression']]
21 | 
22 | # Select vanilla data
23 | vanilla_stride1 = data[(data['Selection function']=='Vanilla') & (data['Test Stride Steps']=='1')][['PSNR', 'Compression']]
24 | vanilla_stride2 = data[(data['Selection function']=='Vanilla') & (data['Test Stride Steps']=='2')][['PSNR', 'Compression']]
25 | vanilla_stride3 = data[(data['Selection function']=='Vanilla') & (data['Test Stride Steps']=='3')][['PSNR', 'Compression']]
26 | 
27 | # Select jpg data
28 | jpg = data[data['Selection function']=='jpg'][['PSNR', 'Compression']]
29 | 
30 | 
31 | # Plot jpg data and set ax object
32 | ax = jpg.plot.scatter(x='Compression', y='PSNR', marker='*', c='m')
33 | 
34 | # Plot vanilla data
35 | vanilla_stride1.plot.scatter(ax=ax, x='Compression', y='PSNR', marker='o', c='r')
36 | vanilla_stride2.plot.scatter(ax=ax, x='Compression', y='PSNR', marker='o', c='g')
37 | vanilla_stride3.plot.scatter(ax=ax, x='Compression', y='PSNR', marker='o', c='b')
38 | 
39 | # Plot fista data
40 | fista_stride1_f8.plot.scatter(ax=ax, x='Compression', y='PSNR', marker='s', c='k')
41 | fista_stride1_f32.plot.scatter(ax=ax, x='Compression', y='PSNR', marker='s', c='y')
42 | fista_stride2.plot.scatter(ax=ax, x='Compression', y='PSNR', marker='s', c='c')
43 | 
44 | # Set figure parameters
45 | plt.xscale('log')
46 | plt.grid(True)
47 | plt.legend(['JPG', 'Vanilla stride 1', 'Vanilla stride 2', 'Vanilla stride 3', 'FISTA stride 1, fp8', 'FISTA stride 1, fp32', 'FISTA stride 2'])
48 | # plt.legend(['FISTA stride 1, fp8', 'FISTA stride 1, fp32', 'FISTA stride 2'])
49 | plt.savefig('full_psnr_decompression_results.png')
50 | 
51 | 
--------------------------------------------------------------------------------
/scripts/extract_dataset_unlearned_encodings.py:
--------------------------------------------------------------------------------
1 | from models.vqvae import Encoder
2 | from models.model_utils import get_dataset
3 | from torch.utils.data import DataLoader
4 | from tqdm import tqdm
5 | import numpy as np
6 | import torch.nn as nn
7 | 
8 | import os
9 | print(os.path.abspath("."))
10 | 
11 | train_dataset, _ = get_dataset('cifar10', '../../../data', 32)
12 | train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=8)
13 | 
14 | 
15 | i = 0
16 | for batch in tqdm(train_loader, desc='Extracting unlearned encodings'):
17 |     encoder = Encoder(in_channel=3, channel=128, n_res_block=2, n_res_channel=32, stride=2)
18 |     quantize_conv = nn.Conv2d(128, 64, 1)
19 |     enc = encoder(batch[0])
20 |     quant = quantize_conv(enc).permute(0, 2, 3, 1)
21 |     np.save('unlearned_encodings/unlearned_cifar10_{}'.format(i), quant.detach().numpy())
22 | 
23 |     del encoder
24 |     del quantize_conv
25 |     i += 1
26 | 
--------------------------------------------------------------------------------
/scripts/hyperparameter_alpha_search.py:
--------------------------------------------------------------------------------
1 | """
2 | Script to explore the effect of different alpha values on the sparse code
3 | """
4 | import torch
5 | import numpy as np
6 | import matplotlib.pyplot as plt
7 | 
8 | from utils.pyfista_test import create_normalized_noised_inputs, load_real_inputs
9 | import pandas as pd
10 | 
11 | from utils.pyfista import FISTA
12 | 
13 | 
14 | def score_fn(Z):
15 |     """
16 |     Evaluate the sparse code's L0 norm: its mean, its standard deviation, and its 99th-percentile value.
17 |     :param Z: Sparse code to evaluate.
18 |     :return: Float, Float, Float. L0 norm mean, L0 norm std, 99th-percentile L0 norm value.
19 |     """
20 |     norm = torch.norm(Z, p=0, dim=0)
21 |     return norm.mean(), norm.std(), norm.topk(int(len(norm)*0.01)).values.min()
22 | 
23 | 
24 | def histogram_fn(run_name, name, Z, alpha):
25 |     """
26 |     Create and save a histogram of the sparse code values for a given alpha
27 |     """
28 |     print('Doing histogram for experiment {} with alpha {}'.format(name, alpha))
29 |     plt.hist(Z)
30 |     plt.savefig('histograms/{}_{}_{}.png'.format(run_name, name, alpha))
31 |     plt.close()
32 | 
33 | 
34 | def create_alpha_graph_random_data(alphas, just_plot=False, sparsity_num=10):
35 |     """
36 |     Create an alpha/L0-norm graph on random data.
37 |     The data is generated by sampling a random dictionary and then sampling sparsity_num atoms to create
38 |     a linear combination of dictionary atoms.
39 |     :param alphas: tuple of (start, end, denominator). The graph covers the range (start / denominator, end / denominator).
40 |     :param just_plot: Bool. Use cached results if True, else recalculate results.
41 |     :param sparsity_num: Sparsity value used to create the random data.
42 |     """
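    # Illustrative usage (mirroring the commented-out call in __main__ below):
    #   create_alpha_graph_random_data(alphas=(100, 300, 1000), sparsity_num=5)
    # sweeps alpha over [100/1000, 300/1000) = [0.1, 0.3) in steps of 1/1000.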
42 | """ 43 | filename = 'alpha_search_random_data_sparsity_{}_alpahs-{},results'.format(sparsity_num, alphas) 44 | 45 | experiments = { 46 | 'no_normalization': create_normalized_noised_inputs(False, False, sparsity_num), 47 | 'd_normalization': create_normalized_noised_inputs(True, False, sparsity_num), 48 | 'x_normalization': create_normalized_noised_inputs(False, True, sparsity_num), 49 | 'full_normalization': create_normalized_noised_inputs(True, True, sparsity_num), 50 | } 51 | 52 | create_alpha_graph(experiments, filename, alphas, just_plot) 53 | 54 | 55 | def create_alpha_graph_real_data(just_plot=False, alphas=(0.01, 2, 0.01)): 56 | """ 57 | Create an alpha/L0 norm graph on real data and random dictionary. 58 | Data is generated by running an untrained encoder on real data. Results are calculated with respect to a random dictionary. 59 | :param just_plot: Bool. Use cached results if True, else recalculate results. 60 | :param alphas: tuple of (start, end, denominator). The graph will be over a the range(start / denominator, end / denominator). 61 | """ 62 | filename = 'alpha_search_real_data_alphas-{}_results'.format(alphas) 63 | experiments = { 64 | 'no_normalization': load_real_inputs(False, False), 65 | 'd_normalization': load_real_inputs(True, False), 66 | 'x_normalization': load_real_inputs(False, True), 67 | 'full_normalization': load_real_inputs(True, True), 68 | } 69 | 70 | create_alpha_graph(experiments, filename, alphas, just_plot=just_plot) 71 | 72 | 73 | def create_alpha_graph(experiments, filename, alphas, just_plot=False): 74 | """ 75 | Save experiment results to CSV file and graph 76 | :param experiments: List of (data, dictionary) for experiments ('no_normalization', 'd_normalization', 'x_normalization', 'full_normalization') 77 | :param filename: filename for the CSV and graph 78 | :param alphas: tuple of (start, end, denominator). The graph will be over a the range(start / denominator, end / denominator). 79 | :param just_plot: just_plot: Bool. Use cached results if True, else recalculate results. 80 | """ 81 | if not just_plot: 82 | with open('{}.csv'.format(filename), 'w') as f: 83 | title_row = 'Alpha,' \ 84 | 'no_normalization_score,no_normalization_std,no_normalization_percentile,' \ 85 | 'd_normalization_score,d_normalization_std,d_normalization_percentile,' \ 86 | 'x_normalization_score,x_normalization_std,x_normalization_percentile,' \ 87 | 'full_normalization_score,full_normalization_std,full_normalization_percentile\n' 88 | f.write(title_row) 89 | 90 | alphas = list([float(a) / alphas[2] for a in range(alphas[0], alphas[1])]) 91 | for alpha in alphas: 92 | process_alpha(alpha, experiments, filename) 93 | 94 | create_result_graph(filename) 95 | 96 | 97 | def process_alpha(alpha, experiments, filename): 98 | """ 99 | Save experiment results to CSV file 100 | :param alpha: Value of alpha to run for 101 | :param experiments: List of (data, dictionary) for experiments ('no_normalization', 'd_normalization', 'x_normalization', 'full_normalization') 102 | :param filename: filename for the CSV 103 | """ 104 | 105 | # Run experiments 106 | d_normalization_Z, full_normalization_Z, no_normalization_Z, x_normalization_Z = run_experiments(alpha, experiments) 107 | 108 | # Extract statistics. 
TODO: Turn output and input to dictionary 109 | d_normalization_percentile, d_normalization_score, d_normalization_std, \ 110 | full_normalization_percentile, full_normalization_score, full_normalization_std, \ 111 | no_normalization_percentile, no_normalization_score, no_normalization_std, \ 112 | x_normalization_percentile, x_normalization_score, x_normalization_std = \ 113 | calculate_statistics(d_normalization_Z, full_normalization_Z, no_normalization_Z, x_normalization_Z) 114 | 115 | # Create histograms 116 | histogram_fn(filename, 'no_normalization_Z', no_normalization_Z.view(-1), alpha) 117 | histogram_fn(filename, 'd_normalization_Z', d_normalization_Z.view(-1), alpha) 118 | histogram_fn(filename, 'x_normalization_Z', x_normalization_Z.view(-1), alpha) 119 | histogram_fn(filename, 'full_normalization_Z', full_normalization_Z.view(-1), alpha) 120 | 121 | # Create final row 122 | results_row = '{},{},{},{},{},{},{},{},{},{},{},{},{}\n'.format(alpha, 123 | no_normalization_score, no_normalization_std, 124 | no_normalization_percentile, 125 | d_normalization_score, d_normalization_std, 126 | d_normalization_percentile, 127 | x_normalization_score, x_normalization_std, 128 | x_normalization_percentile, 129 | full_normalization_score, 130 | full_normalization_std, 131 | full_normalization_percentile, 132 | ) 133 | 134 | with open('{}.csv'.format(filename), 'a') as f: 135 | f.write(results_row) 136 | 137 | 138 | def create_result_graph(filename): 139 | """ 140 | Save experiment results to graph 141 | :param filename: filename to load the data from 142 | """ 143 | 144 | data = pd.read_csv('{}.csv'.format(filename), sep=',') 145 | # plt.errorbar(data['Alpha'], data['no_normalization_score'], yerr=data['no_normalization_std']) 146 | plt.errorbar(data['Alpha'], data['d_normalization_score'], color='r') 147 | plt.errorbar(data['Alpha'], data['d_normalization_percentile'], color='r', linestyle='--') 148 | plt.errorbar(data['Alpha'], data['x_normalization_score'], color='g') 149 | plt.errorbar(data['Alpha'], data['x_normalization_percentile'], color='g', linestyle='--') 150 | plt.errorbar(data['Alpha'], data['full_normalization_score'], color='b') 151 | plt.errorbar(data['Alpha'], data['full_normalization_percentile'], color='b', linestyle='--') 152 | plt.legend( 153 | [ 154 | 'D normalization', 155 | 'D 99% percentile', 156 | 'X normalization', 157 | 'X 99% percentile', 158 | 'Full normalization', 159 | 'Full 99% percentile' 160 | ] 161 | ) 162 | plt.yscale('log') 163 | plt.xlabel('Alpha') 164 | plt.ylabel('Log norm0') 165 | plt.title('Alpha search graph for {}'.format(filename)) 166 | plt.savefig('{}_scores_with_errorbars.png'.format(filename)) 167 | 168 | 169 | def calculate_statistics(d_normalization_Z, full_normalization_Z, no_normalization_Z, x_normalization_Z): 170 | """ 171 | Calculates statistics for given Sparse Code coefficients 172 | :param d_normalization_Z: d_normalization experiment sparse code coefficients 173 | :param full_normalization_Z: full_normalization experiment sparse code coefficients 174 | :param no_normalization_Z: no_normalization experiment sparse code coefficients 175 | :param x_normalization_Z: x_normalization experiment sparse code coefficients 176 | :return: For each experiment: Value of the 99% percentile L0 norm, mean L0 norm, standard deviation of L0 norm 177 | """ 178 | no_normalization_score, no_normalization_std, no_normalization_percentile = score_fn(no_normalization_Z) 179 | d_normalization_score, d_normalization_std, d_normalization_percentile = 
score_fn(d_normalization_Z) 180 | x_normalization_score, x_normalization_std, x_normalization_percentile = score_fn(x_normalization_Z) 181 | full_normalization_score, full_normalization_std, full_normalization_percentile = score_fn(full_normalization_Z) 182 | return \ 183 | d_normalization_percentile, d_normalization_score, d_normalization_std, \ 184 | full_normalization_percentile, full_normalization_score, full_normalization_std, \ 185 | no_normalization_percentile, no_normalization_score, no_normalization_std, \ 186 | x_normalization_percentile, x_normalization_score, x_normalization_std 187 | 188 | 189 | def run_experiments(alpha, experiments): 190 | """ 191 | Runs FISTA for given experimental data 192 | :param alpha: value of data to use for FISTA 193 | :param experiments: Dictionary with fista input and dictionary with names of the experiments. 194 | expects experiments: no_normalization, d_normalization, x_normalization, full_normalization 195 | 196 | :return: Sparse code coefficients found for the experiments in the order: 197 | d_normalization, full_normalization, no_normalization, x_normalization 198 | """ 199 | no_normalization_Z, _ = FISTA(experiments['no_normalization'][0], experiments['no_normalization'][1], alpha, 0.01) 200 | d_normalization_Z, _ = FISTA(experiments['d_normalization'][0], experiments['d_normalization'][1], alpha, 0.01) 201 | x_normalization_Z, _ = FISTA(experiments['x_normalization'][0], experiments['x_normalization'][1], alpha, 0.01) 202 | full_normalization_Z, _ = FISTA(experiments['full_normalization'][0], experiments['full_normalization'][1], alpha, 0.01) 203 | return d_normalization_Z, full_normalization_Z, no_normalization_Z, x_normalization_Z 204 | 205 | 206 | if __name__ == '__main__': 207 | # create_alpha_graph_random_data(just_plot=False, sparsity_num=5, alphas=(100, 300, 1000)) 208 | create_alpha_graph_real_data(just_plot=False, alphas=(100, 300, 1000)) 209 | -------------------------------------------------------------------------------- /scripts/merge_model_psnrs_to_csv.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import pandas as pd 4 | 5 | 6 | osp = os.path 7 | 8 | 9 | def main(): 10 | eval_root = './eval' 11 | eval_folder = osp.join(eval_root, '2020_10_28') 12 | # eval_folder = osp.join(eval_root, '2020_11_04') 13 | 14 | data = dict() 15 | for p in os.walk(eval_folder): 16 | pth, dirs, files = p 17 | 18 | folder = osp.basename(pth) 19 | if len(files) == 0: 20 | continue 21 | if not any([f.endswith('json') for f in files]): 22 | continue 23 | 24 | print(f'parsing "{folder}"') 25 | with open(osp.join(pth, sorted(files)[-1])) as fp: 26 | stats = json.load(fp) 27 | 28 | columns = list(stats.keys()) 29 | data[folder] = stats 30 | 31 | print(f'saving csv') 32 | df = pd.DataFrame.from_dict(data, orient='index', columns=columns) 33 | df.to_csv(osp.join(eval_folder, 'summary.csv')) 34 | 35 | 36 | if __name__ == '__main__': 37 | main() -------------------------------------------------------------------------------- /scripts/spwan.py: -------------------------------------------------------------------------------- 1 | import os 2 | import boto3 3 | from time import sleep 4 | from typing import Iterable 5 | from itertools import product 6 | 7 | 8 | KEY_FILE = os.path.expanduser("KEY.PEM") 9 | INSTANCE_NAME_PREFIX = 'EC2_PREFIX' 10 | USE_USER_DATA = False 11 | 12 | 13 | def _get_runs_dict(): 14 | non_zero_vals = (1, 2, 4,) 15 | it = product( 16 | ('vanilla', 'omp',), # 
sel/selection_fn 17 | (0, ), # seed 18 | (1, 2, 3,), # stride/num_strides 19 | ) 20 | _runs = dict() 21 | for sel, seed, stride in it: 22 | _name_base = f'{sel[:3]}_s{stride}_{seed}' 23 | 24 | if sel == 'vanilla': 25 | _name = _name_base + '_nrm' 26 | _runs[_name] = f'-n={_name} --seed={seed} -sel={sel} -stride={stride}' 27 | _name = _name_base + '_nrmD' 28 | _runs[_name] = f'-n={_name} --seed={seed} -sel={sel} -stride={stride} --no_normalize_x' 29 | _name = _name_base 30 | _runs[_name] = f'-n={_name} --seed={seed} -sel={sel} -stride={stride} --no_normalize_dict --no_normalize_x' 31 | elif sel == 'omp': 32 | for k in non_zero_vals: 33 | _name = _name_base + f'_k{k}' 34 | _runs[_name] = f'-n={_name} --seed={seed} -sel={sel} -stride={stride} -k={k}' 35 | _name = _name_base + f'_k{k}' + '_nrmD' 36 | _runs[_name] = f'-n={_name} --seed={seed} -sel={sel} -stride={stride} -k={k} --no_normalize_x' 37 | 38 | return _runs 39 | 40 | 41 | ec2 = boto3.resource('ec2') 42 | client = boto3.client('ec2') 43 | 44 | 45 | def _image_id_from_name(name): 46 | response = client.describe_images(Owners=['self'], Filters=[ 47 | { 48 | 'Name': 'name', 49 | 'Values': [name] 50 | }, 51 | ]) 52 | images = response['Images'] 53 | assert len(images) 54 | return images[0]['ImageId'] 55 | 56 | 57 | def _create_instance(image_id: str, user_data: str, instance_name_suffix='vae-runner', instance_type='p3.2xlarge', 58 | description='') -> ec2.Instance: 59 | 60 | # create a new EC2 instance 61 | instances = ec2.create_instances( 62 | ImageId=image_id, 63 | UserData=user_data, 64 | InstanceType=instance_type, 65 | # 66 | MinCount=1, 67 | MaxCount=1, 68 | KeyName='KEY', 69 | InstanceInitiatedShutdownBehavior='terminate', 70 | # 71 | EbsOptimized=False, 72 | DryRun=False, 73 | ) 74 | 75 | # set name + description 76 | tags = [{'Key': 'Name', 'Value': INSTANCE_NAME_PREFIX + instance_name_suffix}] 77 | if description != '': 78 | tags.append({'Key': 'Description', 'Value': description}) 79 | instances[0].create_tags(Tags=tags) 80 | 81 | return instances[0] 82 | 83 | 84 | def _get_train_cmd(parameters, use_user_data=USE_USER_DATA): 85 | screen_name = 'train' 86 | train_cmds = [ 87 | f'screen -S {screen_name} -dm', 88 | f'screen -S {screen_name} -X stuff "cd PATH/^M"', 89 | f'screen -S {screen_name} -X stuff "conda activate pytorch_p36 ^M"', 90 | f'screen -S {screen_name} -X stuff "python train_vqvae.py {parameters} ^M"', 91 | f'screen -S {screen_name} -X stuff "sudo shutdown ^M"', 92 | ] 93 | 94 | if use_user_data: 95 | train_cmds.insert(0, '#!/bin/zsh') # add shebang 96 | train_cmds.insert(1, 'su - ubuntu') # user data is executed as root, we prefer to avoid this 97 | train_cmds = ' \n'.join(train_cmds) # concatenate into one string 98 | 99 | return train_cmds 100 | 101 | 102 | def send_cmds_via_ssh(cmds: Iterable, instance, wait_extra_min=0.5, key_file=KEY_FILE): 103 | import paramiko 104 | 105 | print(f'waiting until {instance.id} is running...') 106 | instance.wait_until_running() 107 | 108 | if wait_extra_min > 0: 109 | print(f'waiting another {wait_extra_min} min...') 110 | sleep(wait_extra_min * 60) 111 | else: 112 | print(f'waiting until {instance.id} is "OK" and reachable...') 113 | print('(note, this might take about two minutes)') 114 | waiter = client.get_waiter('instance_status_ok') 115 | waiter.wait(InstanceIds=[instance.id]) 116 | 117 | ssh_client = paramiko.SSHClient() 118 | ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) 119 | ssh_client.connect(instance.public_ip_address, username='ubuntu', 
key_filename=key_file) 120 | for cmd in cmds: 121 | print(f'sending to {instance.id} command "{cmd}"') 122 | stdin, stdout, stderr = ssh_client.exec_command(cmd) 123 | [print(line) for line in stderr.readlines()] 124 | [print(line) for line in stdout.readlines()] 125 | sleep(0.5) 126 | 127 | ssh_client.close() 128 | 129 | 130 | def main(): 131 | 132 | all_runs = _get_runs_dict() 133 | all_instances = dict() 134 | all_cmds = dict() 135 | for run_name, parameters in all_runs.items(): 136 | print(f'spawning run "{run_name}" with config file "{parameters}...') 137 | 138 | train_cmds = _get_train_cmd(parameters, USE_USER_DATA) 139 | 140 | instance = _create_instance(_image_id, 141 | user_data=train_cmds if USE_USER_DATA else '', 142 | instance_name_suffix=f"vae-runner-{run_name}", 143 | description=f'running with "{parameters}"', 144 | instance_type='g4dn.12xlarge', 145 | ) 146 | 147 | all_instances[run_name] = instance 148 | all_cmds[run_name] = train_cmds 149 | 150 | print(f'lunched instance {instance.id} for config file {parameters}') 151 | 152 | if not USE_USER_DATA: 153 | for run_name in all_runs.keys(): 154 | send_cmds_via_ssh(all_cmds[run_name], all_instances[run_name], wait_extra_min=0) 155 | pass 156 | 157 | 158 | if __name__ == '__main__': 159 | main() 160 | -------------------------------------------------------------------------------- /scripts/visualize_encodings.py: -------------------------------------------------------------------------------- 1 | """ 2 | Goal of this script is to visualize the encodings created using extract_code.py for debugging purposes 3 | """ 4 | import argparse 5 | # import pickle 6 | import os 7 | 8 | import torch 9 | from torch.utils.data import DataLoader 10 | # from torchvision import transforms 11 | # import lmdb 12 | # from tqdm import tqdm 13 | # from torchvision import datasets 14 | # from dataset import CodeRow, NamedDataset 15 | # from models.vqvae import VQVAE 16 | # import torch.nn as nn 17 | from utils import util_funcs 18 | from models.model_utils import get_model, get_dataset 19 | from torchvision import datasets, transforms, utils 20 | # import joblib 21 | from dataset import LMDBDataset 22 | 23 | 24 | def create_run(architecture, dataset, num_embeddings, num_workers, selection_fn, neighborhood, device, size, ckpt_epoch, embed_dim, **kwargs): 25 | global args, scheduler 26 | 27 | print('creating data loaders') 28 | experiment_name = util_funcs.create_experiment_name(architecture, dataset, num_embeddings, neighborhood, 29 | selection_fn, size, **kwargs) 30 | checkpoint_name = util_funcs.create_checkpoint_name(experiment_name, ckpt_epoch) 31 | checkpoint_path = f'checkpoint/{checkpoint_name}' 32 | 33 | test_loader, train_loader = load_datasets(args, experiment_name, num_workers, dataset) 34 | 35 | print('Loading model') 36 | model = get_model(architecture, num_embeddings, device, neighborhood, selection_fn, embed_dim, parallel=False, **kwargs) 37 | model.load_state_dict(torch.load(os.path.join('..', checkpoint_path)), strict=False) 38 | model = model.to(device) 39 | model.eval() 40 | 41 | for batch in train_loader: 42 | print('decoding') 43 | X = model.decode_code(batch[1].to(next(model.parameters()).device)) 44 | 45 | print('decoded') 46 | utils.save_image( 47 | torch.cat([X], 0), 48 | 'X_img.png', 49 | nrow=1, 50 | normalize=True, 51 | range=(-1, 1), 52 | ) 53 | a = 5 54 | exit() 55 | 56 | 57 | def load_datasets(args, experiment_name, num_workers, dataset): 58 | db_name = util_funcs.create_checkpoint_name(experiment_name, 
args.ckpt_epoch)[:-3] + '_dataset[{}]'.format(dataset) 59 | train_dataset = LMDBDataset(os.path.join('..', 'codes', 'train_codes', db_name), args.architecture) 60 | test_dataset = LMDBDataset(os.path.join('..', 'codes', 'test_codes', db_name), args.architecture) 61 | train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=num_workers, drop_last=True) 62 | test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True, num_workers=num_workers, drop_last=True) 63 | return test_loader, train_loader 64 | 65 | 66 | if __name__ == '__main__': 67 | parser = argparse.ArgumentParser() 68 | parser = util_funcs.base_parser(parser) 69 | parser = util_funcs.vqvae_parser(parser) 70 | parser = util_funcs.code_extraction_parser(parser) 71 | args = parser.parse_args() 72 | 73 | print(args) 74 | 75 | util_funcs.seed_generators(args.seed) 76 | 77 | create_run(**vars(args)) 78 | -------------------------------------------------------------------------------- /train_fista_pixelsnail.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | import torch 5 | from torch import nn, optim 6 | from torch.utils.data import DataLoader 7 | from tqdm import tqdm 8 | from torchvision import datasets, transforms, utils 9 | 10 | try: 11 | from apex import amp 12 | 13 | except ImportError: 14 | amp = None 15 | 16 | from dataset import LMDBDataset 17 | # from models.pixelsnail import PixelSNAIL 18 | from models.fista_pixelsnail import FistaPixelSNAIL 19 | from scheduler import CycleScheduler 20 | 21 | import argparse 22 | import pickle 23 | import os 24 | #h 25 | import torch 26 | from torch.utils.data import DataLoader 27 | from torchvision import transforms 28 | import lmdb 29 | from tqdm import tqdm 30 | from torchvision import datasets 31 | from dataset import CodeRow, NamedDataset 32 | from models.vqvae import VQVAE 33 | import torch.nn as nn 34 | from utils import util_funcs 35 | from models.model_utils import get_model, get_dataset 36 | from dataset import LMDBDataset 37 | import numpy as np 38 | from tensorboardX import SummaryWriter 39 | import datetime 40 | 41 | 42 | def train(args, epoch, loader, model, optimizer, scheduler, device, writer, experiment_name, vqvae_model): 43 | loader = tqdm(loader, desc='PixelSnail training {}'.format(experiment_name)) 44 | 45 | criterion = nn.CrossEntropyLoss() 46 | multilabel_criterion = nn.BCEWithLogitsLoss() 47 | kl_criterion = nn.KLDivLoss() 48 | 49 | total_coefficients_loss = 0 50 | total_num_nonzeros_loss = 0 51 | total_atom_loss = 0 52 | total_steps = 0 53 | total_loss = 0 54 | for i, (top, bottom, label) in enumerate(loader): 55 | model.zero_grad() 56 | 57 | top = top.to(device) 58 | 59 | if args.hier == 'top': 60 | top = top.to(device) 61 | target = top 62 | reconstruction, num_nonzeros, sigma_matrix, coefficients = model(top) 63 | 64 | elif args.hier == 'bottom': 65 | bottom = bottom.to(device) 66 | target = bottom 67 | 68 | if hasattr(model, 'prepare_inputs'): # False if using DataParallel 69 | used_atoms_mask, gt_num_nonzeros = model.prepare_inputs(bottom) 70 | else: 71 | used_atoms_mask, gt_num_nonzeros = model.module.prepare_inputs(bottom) 72 | 73 | sampled_atoms, sampled_num_nonzeros, coefficients = model(bottom, used_atoms_mask, gt_num_nonzeros) 74 | 75 | if i % 25 == 0: 76 | save_reconstruction(bottom, coefficients, epoch, vqvae_model, i, 'train') 77 | 78 | # Todo: Expose different loss weights as script parameters 79 | atom_loss = 
multilabel_criterion(sampled_atoms, used_atoms_mask.float()) 80 | num_nonzeros_loss = criterion(sampled_num_nonzeros, gt_num_nonzeros) 81 | coefficients_loss = kl_criterion(coefficients, target) 82 | loss = coefficients_loss 83 | # loss = atom_loss + num_nonzeros_loss + coefficients_loss 84 | 85 | loss.backward() 86 | 87 | if scheduler is not None: 88 | scheduler.step() 89 | optimizer.step() 90 | 91 | # TODO: Plan what we want to log 92 | total_steps += 1 93 | total_coefficients_loss += coefficients_loss.item() 94 | total_num_nonzeros_loss += num_nonzeros_loss.item() 95 | total_atom_loss += atom_loss.item() 96 | total_loss += loss.item() 97 | 98 | lr = optimizer.param_groups[0]['lr'] 99 | 100 | loader.set_postfix( 101 | { 102 | 'Epoch': epoch + 1, 103 | 'Loss': f'{loss.item():.5f}', 104 | 'Coefficients loss': f'{coefficients_loss.item():.5f}', 105 | 'Num nonzeros loss': f'{num_nonzeros_loss.item():.5f}', 106 | 'Atom selection loss': f'{atom_loss.item():.5f}', 107 | 'LR': f'{lr:.5f}' 108 | } 109 | ) 110 | 111 | loader.update(1) 112 | 113 | return total_coefficients_loss / total_steps, total_num_nonzeros_loss / total_steps, total_atom_loss / total_steps, total_loss / total_steps 114 | 115 | 116 | def save_reconstruction(inthing, out, epoch, vqvae_model, i, phase): 117 | X1 = vqvae_model.decode_code(out.to(next(vqvae_model.parameters()).device)) 118 | X2 = vqvae_model.decode_code(inthing.clone().detach().to(next(vqvae_model.parameters()).device)) 119 | utils.save_image( 120 | torch.cat([X1, X2], 0), 121 | 'dumps/fista_pixelsnail_dumps/pixelsnail_reconstrution_epoch[{}]_batch[{}]_phase[{}].png'.format(epoch,i , phase), 122 | nrow=2, 123 | normalize=True, 124 | range=(-1, 1), 125 | ) 126 | 127 | 128 | def test(args, epoch, loader, model, optimizer, scheduler, device, writer, experiment_name, vqvae_model): 129 | loader = tqdm(loader, desc='PixelSnail testing {}'.format(experiment_name)) 130 | model.eval() 131 | criterion = nn.CrossEntropyLoss() 132 | 133 | total_accuracy = 0 134 | total_steps = 0 135 | total_loss = 0 136 | for i, (top, bottom, label) in enumerate(loader): 137 | if args.hier == 'top': 138 | top = top.to(device) 139 | target = top 140 | out, _ = model(top) 141 | 142 | elif args.hier == 'bottom': 143 | bottom = bottom.to(device) 144 | target = bottom 145 | out, _ = model(bottom) 146 | # out, _ = model(bottom, condition=top) 147 | 148 | 149 | if i % 25 == 0: 150 | save_reconstruction(bottom, out, epoch, vqvae_model, i, 'train') 151 | 152 | loss = criterion(out, target) 153 | 154 | _, pred = out.max(1) 155 | correct = (pred == target).float() 156 | accuracy = correct.sum() / target.numel() 157 | total_accuracy += accuracy 158 | total_steps += 1 159 | total_loss += loss.item() 160 | 161 | loader.set_postfix( 162 | { 163 | 'Epoch': epoch + 1, 164 | 'Loss': f'{loss.item():.5f}', 165 | 'Acc': f'{accuracy:.5f}' 166 | } 167 | ) 168 | 169 | loader.update(1) 170 | 171 | return total_accuracy / total_steps, total_loss / total_steps 172 | 173 | 174 | class PixelTransform: 175 | def __init__(self): 176 | pass 177 | 178 | def __call__(self, input): 179 | ar = np.array(input) 180 | 181 | return torch.from_numpy(ar).long() 182 | 183 | 184 | def create_run(architecture, dataset, num_embeddings, num_workers, selection_fn, neighborhood, device, embed_dim, size, **kwargs): 185 | global args, scheduler 186 | 187 | # Get VQVAE experiment name 188 | experiment_name = util_funcs.create_experiment_name(architecture, dataset, num_embeddings, neighborhood, selection_fn=selection_fn, size=size, **kwargs) 189 
| 190 |     # Prepare logger
191 |     writer = SummaryWriter(os.path.join('runs', 'pixelsnail_' + experiment_name + '2', str(datetime.datetime.now())))
192 | 
193 |     # Load datasets
194 |     test_loader, train_loader = load_datasets(args, experiment_name, num_workers, dataset)
195 | 
196 |     # Create model and optimizer
197 |     model, optimizer = prepare_model_parts(train_loader)
198 | 
199 |     # Get checkpoint path for underlying VQ-VAE model
200 |     checkpoint_name = util_funcs.create_checkpoint_name(experiment_name, kwargs['ckpt_epoch'])
201 |     checkpoint_path = f'checkpoint/{checkpoint_name}'
202 | 
203 |     # Load underlying VQ-VAE model for logging purposes
204 |     vqvae_model = get_model(architecture, num_embeddings, device, neighborhood, selection_fn, embed_dim, parallel=False, **kwargs)
205 |     vqvae_model.load_state_dict(torch.load(os.path.join(checkpoint_path)), strict=False)
206 |     vqvae_model = vqvae_model.to(args.device)
207 |     vqvae_model.eval()
208 | 
209 |     # Train model
210 |     train_coefficients_loss, train_num_nonzeros_loss, train_atom_loss, train_losses, \
211 |         test_coefficients_loss, test_num_nonzeros_loss, test_atom_loss, test_losses = \
212 |         run_train(args, experiment_name, model, optimizer, scheduler, test_loader, train_loader, writer, vqvae_model)
213 | 
214 |     return train_coefficients_loss, train_num_nonzeros_loss, train_atom_loss, train_losses, \
215 |         test_coefficients_loss, test_num_nonzeros_loss, test_atom_loss, test_losses
216 | 
217 | 
218 | def run_train(args, experiment_name, model, optimizer, scheduler, test_loader, train_loader, writer, vqvae_model):
219 |     train_coefficients_loss = []
220 |     train_num_nonzeros_loss = []
221 |     train_atom_loss = []
222 |     train_losses = []
223 |     test_coefficients_loss = []
224 |     test_num_nonzeros_loss = []
225 |     test_atom_loss = []
226 |     test_losses = []
227 |     for i in range(args.pixelsnail_epoch):
228 |         # Train epoch
229 |         avg_train_coefficients_loss, avg_train_num_nonzeros_loss, avg_train_atom_loss, avg_train_losses = \
230 |             train(args, i, train_loader, model, optimizer, scheduler, args.device, writer, experiment_name, vqvae_model)
231 | 
232 |         # Test epoch (evaluated on the held-out test loader)
233 |         avg_test_coefficients_loss, avg_test_num_nonzeros_loss, avg_test_atom_loss, avg_test_losses = \
234 |             test(args, i, test_loader, model, optimizer, scheduler, args.device, writer, experiment_name, vqvae_model)
235 | 
236 |         # Log train outputs
237 |         train_coefficients_loss.append(avg_train_coefficients_loss)
238 |         train_num_nonzeros_loss.append(avg_train_num_nonzeros_loss)
239 |         train_atom_loss.append(avg_train_atom_loss)
240 |         train_losses.append(avg_train_losses)
241 | 
242 |         writer.add_scalar('train/coefficients_loss', avg_train_coefficients_loss, i)
243 |         writer.add_scalar('train/num_nonzeros_loss', avg_train_num_nonzeros_loss, i)
244 |         writer.add_scalar('train/atom_loss', avg_train_atom_loss, i)
245 |         writer.add_scalar('train/loss', avg_train_losses, i)
246 | 
247 |         # Log test outputs
248 |         test_coefficients_loss.append(avg_test_coefficients_loss)
249 |         test_num_nonzeros_loss.append(avg_test_num_nonzeros_loss)
250 |         test_atom_loss.append(avg_test_atom_loss)
251 |         test_losses.append(avg_test_losses)
252 | 
253 |         writer.add_scalar('test/coefficients_loss', avg_test_coefficients_loss, i)
254 |         writer.add_scalar('test/num_nonzeros_loss', avg_test_num_nonzeros_loss, i)
255 |         writer.add_scalar('test/atom_loss', avg_test_atom_loss, i)
256 |         writer.add_scalar('test/loss', avg_test_losses, i)
257 | 
258 |         # Create checkpoint
259 | 
torch.save( 260 | {'model': model.module.state_dict(), 'args': args}, 261 | f'checkpoint/pixelsnail_{experiment_name}_{args.hier}_{str(i + 1).zfill(3)}.pt', 262 | ) 263 | 264 | return train_coefficients_loss, train_num_nonzeros_loss, train_atom_loss, train_losses, \ 265 | test_coefficients_loss, test_num_nonzeros_loss, test_atom_loss, test_losses 266 | 267 | 268 | def prepare_model_parts(train_loader): 269 | global args, scheduler 270 | 271 | # Load specific checkpoint to continue training 272 | ckpt = {} 273 | if args.pixelsnail_ckpt is not None: 274 | ckpt = torch.load(args.pixelsnail_ckpt) 275 | args = ckpt['args'] 276 | 277 | # Create PixelSnail object 278 | if args.hier == 'top': 279 | model = FistaPixelSNAIL( 280 | [args.size // 8, args.size // 8], 281 | 512, 282 | args.pixelsnail_channel, 283 | 5, 284 | 4, 285 | args.pixelsnail_n_res_block, 286 | args.pixelsnail_n_res_channel, 287 | dropout=args.pixelsnail_dropout, 288 | n_out_res_block=args.pixelsnail_n_out_res_block, 289 | ) 290 | 291 | elif args.hier == 'bottom': 292 | model = FistaPixelSNAIL( 293 | [args.size // 4, args.size // 4], 294 | 512, 295 | args.pixelsnail_channel, 296 | 5, 297 | 4, 298 | args.pixelsnail_n_res_block, 299 | args.pixelsnail_n_res_channel, 300 | attention=False, 301 | dropout=args.pixelsnail_dropout, 302 | n_cond_res_block=args.pixelsnail_n_cond_res_block, 303 | cond_res_channel=args.pixelsnail_n_res_channel, 304 | ) 305 | 306 | # Load saved checkpoint into PixelSnail object 307 | if 'model' in ckpt: 308 | model.load_state_dict(ckpt['model']) 309 | 310 | # Parallelize training 311 | model = nn.DataParallel(model) 312 | 313 | # Move model to proper device 314 | model = model.to(args.device) 315 | 316 | # Create other training objects 317 | optimizer = optim.Adam(model.parameters(), lr=args.pixelsnail_lr) 318 | if amp is not None: 319 | model, optimizer = amp.initialize(model, optimizer, opt_level=args.amp) 320 | 321 | scheduler = None 322 | if args.pixelsnail_sched == 'cycle': 323 | scheduler = CycleScheduler( 324 | optimizer, args.pixelsnail_lr, n_iter=len(train_loader) * args.pixelsnail_epoch, momentum=None 325 | ) 326 | return model, optimizer 327 | 328 | 329 | def load_datasets(args, experiment_name, num_workers, dataset): 330 | """ 331 | Load LMDB datasets 332 | """ 333 | db_name = util_funcs.create_checkpoint_name(experiment_name, args.ckpt_epoch)[:-3] + '_dataset[{}]'.format(dataset) 334 | 335 | train_dataset = LMDBDataset(os.path.join('codes', 'train_codes', db_name), args.architecture) 336 | test_dataset = LMDBDataset(os.path.join('codes', 'test_codes', db_name ), args.architecture) 337 | 338 | train_loader = DataLoader(train_dataset, batch_size=args.pixelsnail_batch, shuffle=True, num_workers=num_workers) 339 | test_loader = DataLoader(test_dataset, batch_size=args.pixelsnail_batch, shuffle=True, num_workers=num_workers) 340 | return test_loader, train_loader 341 | 342 | 343 | def log_arguments(**arguments): 344 | experiment_name = util_funcs.create_experiment_name(**arguments) 345 | with open(os.path.join('checkpoint', experiment_name + '_args.txt'), 'w') as f: 346 | for key in arguments.keys(): 347 | f.write('{} : {} \n'.format(key, arguments[key])) 348 | 349 | 350 | if __name__ == '__main__': 351 | parser = argparse.ArgumentParser() 352 | parser = util_funcs.base_parser(parser) 353 | parser = util_funcs.vqvae_parser(parser) 354 | parser = util_funcs.code_extraction_parser(parser) 355 | parser = util_funcs.pixelsnail_parser(parser) 356 | args = parser.parse_args() 357 | 358 | print(args) 359 | 
log_arguments(**vars(args)) 360 | create_run(**vars(args)) 361 | -------------------------------------------------------------------------------- /train_pixelsnail.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | import torch 5 | from torch import nn, optim 6 | from torch.utils.data import DataLoader 7 | from tqdm import tqdm 8 | from torchvision import datasets, transforms, utils 9 | 10 | try: 11 | from apex import amp 12 | 13 | except ImportError: 14 | amp = None 15 | 16 | from dataset import LMDBDataset 17 | from models.pixelsnail import PixelSNAIL 18 | from scheduler import CycleScheduler 19 | 20 | import argparse 21 | import pickle 22 | import os 23 | #h 24 | import torch 25 | from torch.utils.data import DataLoader 26 | from torchvision import transforms 27 | import lmdb 28 | from tqdm import tqdm 29 | from torchvision import datasets 30 | from dataset import CodeRow, NamedDataset 31 | from models.vqvae import VQVAE 32 | import torch.nn as nn 33 | from utils import util_funcs 34 | from models.model_utils import get_model, get_dataset 35 | from dataset import LMDBDataset 36 | import numpy as np 37 | from tensorboardX import SummaryWriter 38 | import datetime 39 | 40 | 41 | def train(args, epoch, loader, model, optimizer, scheduler, device, writer, experiment_name, vqvae_model): 42 | loader = tqdm(loader, desc='PixelSnail training {}'.format(experiment_name)) 43 | 44 | criterion = nn.CrossEntropyLoss() 45 | 46 | total_accuracy = 0 47 | total_steps = 0 48 | total_loss = 0 49 | for i, (top, bottom, label) in enumerate(loader): 50 | model.zero_grad() 51 | 52 | # Forward 53 | if args.hier == 'top': 54 | top = top.to(device) 55 | target = top 56 | out, _ = model(top) 57 | 58 | elif args.hier == 'bottom': 59 | bottom = bottom.to(device) 60 | target = bottom 61 | out, _ = model(bottom) 62 | 63 | if i % 25 == 0: 64 | save_reconstruction(bottom, out, epoch, vqvae_model, i, 'train') 65 | 66 | loss = criterion(out, target) 67 | loss.backward() 68 | 69 | if scheduler is not None: 70 | scheduler.step() 71 | optimizer.step() 72 | 73 | _, pred = out.max(1) 74 | correct = (pred == target).float() 75 | accuracy = correct.sum() / target.numel() 76 | total_accuracy += accuracy 77 | total_steps += 1 78 | total_loss += loss.item() 79 | 80 | lr = optimizer.param_groups[0]['lr'] 81 | 82 | loader.set_postfix( 83 | { 84 | 'Epoch': epoch + 1, 85 | 'Loss': f'{loss.item():.5f}', 86 | 'Acc': f'{accuracy:.5f}', 87 | 'LR': f'{lr:.5f}' 88 | } 89 | ) 90 | 91 | loader.update(1) 92 | 93 | return total_accuracy / total_steps, total_loss / total_steps 94 | 95 | 96 | def save_reconstruction(inthing, out, epoch, vqvae_model, i, phase): 97 | maxed_out = out.clone().detach().argmax(1).long() 98 | X1 = vqvae_model.decode_code(maxed_out.to(next(vqvae_model.parameters()).device)) 99 | X2 = vqvae_model.decode_code(inthing.clone().detach().to(next(vqvae_model.parameters()).device)) 100 | utils.save_image( 101 | torch.cat([X1, X2], 0), 102 | 'dumps/pixelsnail_dumps/pixelsnail2_reconstrution_epoch[{}]_batch[{}]_phase[{}].png'.format(epoch,i , phase), 103 | nrow=2, 104 | normalize=True, 105 | range=(-1, 1), 106 | ) 107 | 108 | 109 | def test(args, epoch, loader, model, optimizer, scheduler, device, writer, experiment_name, vqvae_model): 110 | loader = tqdm(loader, desc='PixelSnail testing {}'.format(experiment_name)) 111 | model.eval() 112 | criterion = nn.CrossEntropyLoss() 113 | 114 | total_accuracy = 0 115 | total_steps = 0 116 | total_loss = 0 
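    # Note: model.eval() alone does not disable autograd. If gradients are not
    # needed during evaluation (a reasonable assumption here), wrapping each
    # forward pass, e.g.
    #     with torch.no_grad():
    #         out, _ = model(top)
    # would skip building the computation graph and reduce memory use.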
117 | for i, (top, bottom, label) in enumerate(loader): 118 | if args.hier == 'top': 119 | top = top.to(device) 120 | target = top 121 | out, _ = model(top) 122 | 123 | elif args.hier == 'bottom': 124 | bottom = bottom.to(device) 125 | target = bottom 126 | out, _ = model(bottom) 127 | # out, _ = model(bottom, condition=top) 128 | 129 | if i % 25 == 0: 130 | save_reconstruction(bottom, out, epoch, vqvae_model, i, 'train') 131 | 132 | loss = criterion(out, target) 133 | 134 | _, pred = out.max(1) 135 | correct = (pred == target).float() 136 | accuracy = correct.sum() / target.numel() 137 | total_accuracy += accuracy 138 | total_steps += 1 139 | total_loss += loss.item() 140 | 141 | loader.set_postfix( 142 | { 143 | 'Epoch': epoch + 1, 144 | 'Loss': f'{loss.item():.5f}', 145 | 'Acc': f'{accuracy:.5f}' 146 | } 147 | ) 148 | 149 | loader.update(1) 150 | 151 | return total_accuracy / total_steps, total_loss / total_steps 152 | 153 | 154 | class PixelTransform: 155 | def __init__(self): 156 | pass 157 | 158 | def __call__(self, input): 159 | ar = np.array(input) 160 | 161 | return torch.from_numpy(ar).long() 162 | 163 | 164 | def create_run(architecture, dataset, num_embeddings, num_workers, selection_fn, neighborhood, device, embed_dim, size, **kwargs): 165 | global args, scheduler 166 | 167 | # Get VQVAE experiment name 168 | experiment_name = util_funcs.create_experiment_name(architecture, dataset, num_embeddings, neighborhood, selection_fn=selection_fn, size=size, **kwargs) 169 | 170 | # Prepare logger 171 | writer = SummaryWriter(os.path.join('runs', 'pixelsnail_' + experiment_name + '2', str(datetime.datetime.now()))) 172 | 173 | # Load datasets 174 | test_loader, train_loader = load_datasets(args, experiment_name, num_workers, dataset) 175 | 176 | # Create model and optimizer 177 | model, optimizer = prepare_model_parts(train_loader) 178 | 179 | # Get checkpoint path for underlying VQ-VAE model 180 | checkpoint_name = util_funcs.create_checkpoint_name(experiment_name, kwargs['ckpt_epoch']) 181 | checkpoint_path = f'checkpoint/{checkpoint_name}' 182 | 183 | # Load underlying VQ-VAE model for logging purposes 184 | vqvae_model = get_model(architecture, num_embeddings, device, neighborhood, selection_fn, embed_dim, parallel=False, **kwargs) 185 | vqvae_model.load_state_dict(torch.load(os.path.join(checkpoint_path)), strict=False) 186 | vqvae_model = vqvae_model.to(args.device) 187 | vqvae_model.eval() 188 | 189 | test_accs, test_losses, train_accs, train_losses = run_train(args, experiment_name, model, optimizer, scheduler, test_loader, train_loader, writer, vqvae_model) 190 | 191 | return train_accs, train_losses, test_accs, test_losses 192 | 193 | 194 | def run_train(args, experiment_name, model, optimizer, scheduler, test_loader, train_loader, writer, vqvae_model): 195 | train_accs = [] 196 | train_losses = [] 197 | test_accs = [] 198 | test_losses = [] 199 | for i in range(args.pixelsnail_epoch): 200 | train_avg_acc, train_avg_loss = train(args, i, train_loader, model, optimizer, scheduler, args.device, writer, 201 | experiment_name, vqvae_model) 202 | test_avg_acc, test_avg_loss = test(args, i, test_loader, model, optimizer, scheduler, args.device, writer, 203 | experiment_name, vqvae_model) 204 | 205 | 206 | # Create logs 207 | train_accs.append(train_avg_acc) 208 | train_losses.append(train_avg_loss) 209 | writer.add_scalar('train/accuracy', train_avg_acc) 210 | writer.add_scalar('train/loss', train_avg_loss) 211 | 212 | test_accs.append(test_avg_acc) 213 | 
test_losses.append(test_avg_loss) 214 | writer.add_scalar('test/accuracy', test_avg_acc) 215 | writer.add_scalar('test/loss', test_avg_loss) 216 | 217 | # Save checkpoint 218 | torch.save( 219 | {'model': model.module.state_dict(), 'args': args}, 220 | f'checkpoint/pixelsnail_{experiment_name}_{args.hier}_{str(i + 1).zfill(3)}.pt', 221 | ) 222 | 223 | return test_accs, test_losses, train_accs, train_losses 224 | 225 | 226 | def prepare_model_parts(train_loader): 227 | global args, scheduler 228 | 229 | # Load specific checkpoint to continue training 230 | ckpt = {} 231 | if args.pixelsnail_ckpt is not None: 232 | ckpt = torch.load(args.pixelsnail_ckpt) 233 | args = ckpt['args'] 234 | 235 | # Create PixelSnail object 236 | if args.hier == 'top': 237 | model = PixelSNAIL( 238 | [args.size // 8, args.size // 8], 239 | 512, 240 | args.pixelsnail_channel, 241 | 5, 242 | 4, 243 | args.pixelsnail_n_res_block, 244 | args.pixelsnail_n_res_channel, 245 | dropout=args.pixelsnail_dropout, 246 | n_out_res_block=args.pixelsnail_n_out_res_block, 247 | ) 248 | 249 | elif args.hier == 'bottom': 250 | model = PixelSNAIL( 251 | [args.size // 4, args.size // 4], 252 | 512, 253 | args.pixelsnail_channel, 254 | 5, 255 | 4, 256 | args.pixelsnail_n_res_block, 257 | args.pixelsnail_n_res_channel, 258 | attention=False, 259 | dropout=args.pixelsnail_dropout, 260 | n_cond_res_block=args.pixelsnail_n_cond_res_block, 261 | cond_res_channel=args.pixelsnail_n_res_channel, 262 | ) 263 | 264 | # Load saved checkpoint into PixelSnail object 265 | if 'model' in ckpt: 266 | model.load_state_dict(ckpt['model']) 267 | 268 | # Parallelize training 269 | model = nn.DataParallel(model) 270 | 271 | # Move model to proper device 272 | model = model.to(args.device) 273 | 274 | # Create other training objects 275 | optimizer = optim.Adam(model.parameters(), lr=args.pixelsnail_lr) 276 | if amp is not None: 277 | model, optimizer = amp.initialize(model, optimizer, opt_level=args.amp) 278 | 279 | scheduler = None 280 | if args.pixelsnail_sched == 'cycle': 281 | scheduler = CycleScheduler( 282 | optimizer, args.pixelsnail_lr, n_iter=len(train_loader) * args.pixelsnail_epoch, momentum=None 283 | ) 284 | 285 | return model, optimizer 286 | 287 | 288 | def load_datasets(args, experiment_name, num_workers, dataset): 289 | """ 290 | Load LMDB datasets 291 | """ 292 | db_name = util_funcs.create_checkpoint_name(experiment_name, args.ckpt_epoch)[:-3] + '_dataset[{}]'.format(dataset) 293 | 294 | train_dataset = LMDBDataset(os.path.join('codes', 'train_codes', db_name), args.architecture) 295 | test_dataset = LMDBDataset(os.path.join('codes', 'test_codes', db_name ), args.architecture) 296 | 297 | train_loader = DataLoader(train_dataset, batch_size=args.pixelsnail_batch, shuffle=True, num_workers=num_workers) 298 | test_loader = DataLoader(test_dataset, batch_size=args.pixelsnail_batch, shuffle=True, num_workers=num_workers) 299 | return test_loader, train_loader 300 | 301 | 302 | def log_arguments(**arguments): 303 | experiment_name = util_funcs.create_experiment_name(**arguments) 304 | with open(os.path.join('checkpoint', experiment_name + '_args.txt'), 'w') as f: 305 | for key in arguments.keys(): 306 | f.write('{} : {} \n'.format(key, arguments[key])) 307 | 308 | 309 | if __name__ == '__main__': 310 | parser = argparse.ArgumentParser() 311 | parser = util_funcs.base_parser(parser) 312 | parser = util_funcs.vqvae_parser(parser) 313 | parser = util_funcs.code_extraction_parser(parser) 314 | parser = util_funcs.pixelsnail_parser(parser) 315 | 
args = parser.parse_args() 316 | 317 | print(args) 318 | log_arguments(**vars(args)) 319 | create_run(**vars(args)) 320 | -------------------------------------------------------------------------------- /train_vqvae.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import torch 5 | from torch import nn, optim 6 | from torch.utils.data import DataLoader 7 | from torchvision import utils 8 | from tqdm import tqdm 9 | 10 | from scheduler import CycleScheduler 11 | from tensorboardX import SummaryWriter 12 | 13 | from utils.util_funcs import * 14 | from models.model_utils import get_model, get_dataset 15 | from datetime import datetime 16 | import numpy as np 17 | 18 | 19 | _TODAY = datetime.now().strftime("%Y_%m_%d") 20 | 21 | 22 | def do_epoch(epoch_num, loader, model, writer, experiment_name, device, optimizer=None, scheduler=None, phase='train', 23 | dictionary_loss_weight=0, sampling_iter=50, sample_size=25): 24 | if phase == 'train': 25 | model.train() 26 | else: 27 | model.eval() 28 | 29 | _today = getattr(args, 'today_str', _TODAY) 30 | tqdm_loader = tqdm(loader, desc='{}: {}'.format(phase, experiment_name)) 31 | 32 | criterion = nn.MSELoss() 33 | 34 | latent_loss_weight = 0.25 35 | 36 | mse_sum = 0 37 | mse_n = 0 38 | loss_sum = 0 39 | total_steps = 0 40 | quantization_steps = 0 41 | avg_dictionary_embedding_size = 0 42 | avg_Z_size = 0 43 | avg_norm_Z = 0 44 | avg_top_percentile = 0 45 | avg_num_zeros = 0 46 | 47 | for i, (img, label) in enumerate(tqdm_loader): 48 | model.zero_grad() 49 | 50 | img = img.to(device) 51 | 52 | out, latent_loss, num_quantization_steps, mean_D, mean_Z, norm_Z, top_percentile, num_zeros = model(img) 53 | recon_loss = criterion(out, img) 54 | encoder_latent_loss = latent_loss[0].mean() 55 | dictionary_latent_loss = latent_loss[1].mean() 56 | loss = recon_loss + latent_loss_weight * encoder_latent_loss + dictionary_latent_loss * dictionary_loss_weight 57 | 58 | if phase == 'train': 59 | loss.backward() 60 | 61 | if scheduler is not None: 62 | scheduler.step() 63 | optimizer.step() 64 | 65 | mse_sum += recon_loss.item() 66 | loss_sum += encoder_latent_loss.item() 67 | total_steps += 1 68 | quantization_steps += num_quantization_steps.mean().item() 69 | avg_dictionary_embedding_size += mean_D.mean().item() 70 | avg_Z_size += mean_Z.mean().item() 71 | avg_norm_Z += norm_Z.mean().item() 72 | avg_top_percentile += top_percentile.mean().item() 73 | avg_num_zeros += num_zeros.mean().item() 74 | 75 | lr = optimizer.param_groups[0]['lr'] 76 | 77 | tqdm_loader.set_postfix({ 78 | # 'Experiment': experiment_name, 79 | 'Epoch': epoch_num, 80 | 'mse': recon_loss.item(), 81 | 'latent_loss': encoder_latent_loss.item(), 82 | # 'avg_norm_Z': norm_Z.mean().item(), 83 | }) 84 | tqdm_loader.update(1) 85 | 86 | # writer.add_scalar('{}/epoch_{}/latent_loss'.format(phase, epoch_num), encoder_latent_loss.item(), i) 87 | # writer.add_scalar('{}/epoch_{}/avg_mse'.format(phase, epoch_num), recon_loss.item(), i) 88 | # writer.add_scalar('{}/epoch_{}/norm_Z'.format(phase, epoch_num), norm_Z.mean(), i) 89 | # writer.add_scalar('{}/epoch_{}/top_percentile'.format(phase, epoch_num), top_percentile.mean(), i) 90 | # writer.add_scalar('{}/epoch_{}/num_zeros'.format(phase, epoch_num), num_zeros.mean(), i) 91 | 92 | if i % sampling_iter == 0: 93 | model.eval() 94 | 95 | sample = img[:sample_size] 96 | 97 | with torch.no_grad(): 98 | out = model(sample) 99 | out = out[0] 100 | 101 | sample_path = 
os.path.join(args.summary_path, _today, experiment_name, 'samples') 102 | if not os.path.exists(sample_path): 103 | os.makedirs(sample_path) 104 | 105 | utils.save_image( 106 | torch.cat([sample, out], 0), os.path.join(sample_path, f'Epoch_{str(epoch_num).zfill(5)}_batch_{str(i).zfill(5)}.png'), 107 | nrow=sample_size, 108 | normalize=True, 109 | range=(-1, 1), 110 | ) 111 | 112 | if phase == 'train': 113 | model.train() 114 | 115 | return mse_sum / total_steps, loss_sum / total_steps, quantization_steps / total_steps, avg_Z_size / total_steps, avg_dictionary_embedding_size / total_steps, avg_norm_Z / total_steps, avg_top_percentile / total_steps, avg_num_zeros / total_steps 116 | 117 | 118 | def create_training_run(size, num_epochs, lr, sched, dataset, architecture, data_path, device, num_embeddings, neighborhood, selection_fn, num_workers, vae_batch, eval_iter, embed_dim, parallelize, download, **kwargs): 119 | experiment_name = create_experiment_name(architecture, dataset, num_embeddings, neighborhood, selection_fn=selection_fn, size=size, lr=lr, **kwargs) 120 | _today = getattr(args, 'today_str', _TODAY) 121 | log_arguments(**vars(args)) 122 | writer = SummaryWriter(os.path.join(args.summary_path, _today, experiment_name)) 123 | 124 | print('Loading datasets') 125 | train_dataset, test_dataset = get_dataset(dataset, data_path, size, download) 126 | 127 | print('Creating loaders') 128 | train_loader = DataLoader(train_dataset, batch_size=vae_batch, shuffle=True, num_workers=num_workers) 129 | test_loader = DataLoader(test_dataset, batch_size=vae_batch, shuffle=True, num_workers=num_workers) 130 | 131 | print('Initializing models') 132 | model = get_model(architecture, num_embeddings, device, neighborhood, selection_fn, embed_dim, parallelize, **kwargs) 133 | optimizer = optim.Adam(model.parameters(), lr=lr) 134 | scheduler = None 135 | if sched == 'cycle': 136 | scheduler = CycleScheduler( 137 | optimizer, lr, n_iter=len(train_loader) * num_epochs, momentum=None 138 | ) 139 | 140 | train_mses = [] 141 | train_losses = [] 142 | test_mses = [] 143 | test_losses = [] 144 | 145 | for epoch_ind in range(1, num_epochs+1): 146 | avg_mse, avg_loss, avg_quantization_steps, avg_Z, avg_D, avg_norm_Z, avg_top_percentile, avg_num_zeros = do_epoch(epoch_ind, train_loader, model, writer, experiment_name, device, optimizer, scheduler, dictionary_loss_weight=kwargs['dictionary_loss_weight'], sampling_iter=kwargs['sampling_iter'],sample_size=kwargs['sample_size']) 147 | train_mses.append(avg_mse) 148 | train_losses.append(avg_loss) 149 | writer.add_scalar('train/avg_mse', avg_mse, epoch_ind) 150 | writer.add_scalar('train/avg_loss', avg_loss, epoch_ind) 151 | writer.add_scalar('train/avg_quantization_steps', avg_quantization_steps, epoch_ind) 152 | writer.add_scalar('train/avg_Z', avg_Z, epoch_ind) 153 | writer.add_scalar('train/avg_D', avg_D, epoch_ind) 154 | writer.add_scalar('train/avg_norm_Z', avg_norm_Z, epoch_ind) 155 | writer.add_scalar('train/avg_top_percentile', avg_top_percentile, epoch_ind) 156 | writer.add_scalar('train/avg_num_zeros', avg_num_zeros, epoch_ind) 157 | 158 | if epoch_ind % kwargs['checkpoint_freq'] == 0: 159 | 160 | cp_path = os.path.join(args.checkpoint_path, _today, create_checkpoint_name(experiment_name, epoch_ind)) 161 | os.makedirs(osp.dirname(cp_path), exist_ok=True) 162 | if parallelize: # If using DataParallel we need to access the inner module 163 | torch.save(model.module.state_dict(), cp_path) 164 | else: 165 | torch.save(model.state_dict(), cp_path) 166 | 167 | if 
epoch_ind % eval_iter == 0:
168 |             avg_mse, avg_loss, avg_quantization_steps, avg_Z, avg_D, avg_norm_Z, avg_top_percentile, avg_num_zeros = do_epoch(epoch_ind, test_loader, model, writer, experiment_name, device, optimizer, scheduler, phase='test')
169 | 
170 |             test_mses.append(avg_mse)
171 |             test_losses.append(avg_loss)
172 |             writer.add_scalar('test/avg_loss', avg_loss, epoch_ind)
173 |             writer.add_scalar('test/avg_mse', avg_mse, epoch_ind)
174 |             writer.add_scalar('test/avg_quantization_steps', avg_quantization_steps, epoch_ind)
175 |             writer.add_scalar('test/avg_Z', avg_Z, epoch_ind)
176 |             writer.add_scalar('test/avg_D', avg_D, epoch_ind)
177 |             writer.add_scalar('test/avg_norm_Z', avg_norm_Z, epoch_ind)
178 |             writer.add_scalar('test/avg_top_percentile', avg_top_percentile, epoch_ind)
179 |             writer.add_scalar('test/avg_num_zeros', avg_num_zeros, epoch_ind)
180 |             model.train()
181 | 
182 |     return train_mses, train_losses, test_mses, test_losses
183 | 
184 | 
185 | def log_arguments(**arguments):
186 |     _today = arguments.get('today_str', _TODAY)
187 |     experiment_name = create_experiment_name(**arguments)
188 |     cp_path = os.path.join(arguments['checkpoint_path'], _today, experiment_name)
189 |     os.makedirs(cp_path, exist_ok=True)
190 |     with open(os.path.join(cp_path, 'args.txt'), 'w') as f:
191 |         for key in arguments.keys():
192 |             f.write('{} : {}\n'.format(key, arguments[key]))
193 | 
194 | 
195 | if __name__ == '__main__':
196 |     parser = argparse.ArgumentParser()
197 |     parser = base_parser(parser)
198 |     parser = vqvae_parser(parser)
199 |     parser = code_extraction_parser(parser)
200 |     args = parser.parse_args()
201 | 
202 |     print(str(args).replace(',', ',\n\t'))
203 |     args.today_str = datetime.now().strftime('%Y_%m_%d')
204 |     seed_generators(args.seed)
205 |     device = args.device
206 | 
207 |     create_training_run(**vars(args))
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amzn/sparse-vqvae/33ca864b6a20c644c3c825ce958fd20a99349dda/utils/__init__.py
--------------------------------------------------------------------------------
/utils/pyfista_test.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities for generating synthetic data for tests.
3 | """
4 | import torch
5 | from utils.pyfista import FistaFunction
6 | import numpy as np
7 | 
8 | 
9 | def create_random_dictionary(normalize=False):
10 |     """
11 |     Creates a random dictionary with i.i.d. uniform entries.
12 |     :param normalize: Bool. Normalize each atom (column) to unit L2 norm if True.
13 |     :return: Tensor. Created dictionary
14 |     """
15 |     dictionary = torch.rand((64, 512))
16 |     if normalize:
17 |         dictionary = dictionary / torch.norm(dictionary, p=2, dim=0)
18 | 
19 |     return dictionary
20 | 
21 | 
22 | def create_normalized_noised_inputs(normalize_dictionary, normalize_x, sparsity_num=10, num_samples=1000):
23 |     """
24 |     Creates random data based on a random dictionary.
25 |     Each sample is a random linear combination of `sparsity_num` dictionary atoms; random (normal) noise is then added to the data.
26 |     :param normalize_dictionary: Bool. Underlying dictionary is normalized if True.
27 |     :param normalize_x: Bool. Normalize L2 norm of data if True.
28 |     :param sparsity_num: Int. Number of dictionary atoms to use in the creation of the data
29 |     :param num_samples: Int. Number of samples to generate. 
30 |     :return: (Tensor, Tensor). Created data and the dictionary it was built from.
31 |     """
32 | 
33 | 
34 |     dictionary = create_random_dictionary(normalize_dictionary)
35 |     X, atoms, coefs = create_sparse_input(dictionary, K=sparsity_num, num_samples=num_samples)
36 |     noise = torch.randn(X.size())
37 |     X += noise
38 | 
39 |     if normalize_x:
40 |         X = X / torch.norm(X, p=2, dim=0)
41 | 
42 |     return X, dictionary
43 | 
44 | 
45 | def load_real_inputs(normalize_dictionary, normalize_x, sample_id=5):
46 |     """
47 |     Creates a random dictionary and loads an encoding saved from an untrained encoder.
48 |     :param normalize_dictionary: Bool. Underlying dictionary is normalized if True.
49 |     :param normalize_x: Bool. Normalize L2 norm of data if True.
50 |     :param sample_id: Int. Id of data sample to load
51 |     :return: Tensor, Tensor. Created data, created dictionary
52 |     """
53 | 
54 |     # Create random dictionary
55 |     dictionary = create_random_dictionary(normalize_dictionary)
56 | 
57 |     # Load data
58 |     data = torch.from_numpy(np.load('unlearned_encodings/unlearned_cifar10_{}.npy'.format(sample_id))).reshape(-1, 64).t()
59 |     inds = list(np.random.choice(list(range(data.size()[1])), 1000, replace=False))
60 |     X = data[:, inds]
61 | 
62 |     if normalize_x:
63 |         X = X / torch.norm(X, p=2, dim=0)
64 | 
65 |     return X, dictionary
66 | 
67 | 
68 | def create_sparse_input(dictionary, K=1, num_samples=1):
69 |     """
70 |     Create sparse data given a dictionary and sparsity value.
71 |     :param dictionary: Tensor. Dictionary to base the data on.
72 |     :param K: Int. Sparsity value, use this number of atoms to create the data.
73 |     :param num_samples: Number of samples to create.
74 |     :return: (Tensor, Tensor, Tensor). Created data, selected atom indices, and coefficients.
75 |     """
76 |     atoms = torch.randint(dictionary.size()[1], (num_samples, K))
77 |     coefs = torch.randn((K, num_samples, 1))
78 | 
79 |     X = []
80 |     for sample_ind in range(num_samples):
81 |         sample = dictionary[:, atoms[sample_ind, :]].mm(coefs[:, sample_ind])
82 | 
83 |         X.append(sample)
84 | 
85 |     X = torch.stack(X, 1).squeeze(-1)
86 | 
87 |     return X, atoms, coefs
--------------------------------------------------------------------------------
/utils/pyomp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | 
4 | 
5 | def get_largest_eigenvalue(X):
6 |     eigs = torch.eig(X, eigenvectors=False).eigenvalues  # shape (n, 2): real and imaginary parts
7 |     max_eig = eigs.max(dim=0)
8 |     return max_eig.values[0]  # largest real part
9 | 
10 | 
11 | def shrink_function(Z, cutoff):  # soft-thresholding: sign(Z) * max(|Z| - cutoff, 0)
12 |     cutted = shrink1(Z, cutoff)
13 |     maxed = shrink2(Z, cutted)
14 |     signed = shrink3(Z, maxed)
15 |     return signed
16 | 
17 | 
18 | def shrink3(Z, maxed):
19 |     signed = maxed * torch.sign(Z)
20 |     return signed
21 | 
22 | 
23 | def shrink2(Z, cutted):
24 |     maxed = torch.max(cutted, torch.zeros_like(cutted))  # elementwise max(., 0) on the input's own device
25 |     return maxed
26 | 
27 | 
28 | def shrink1(Z, cutoff):
29 |     cutted = torch.abs(Z) - cutoff
30 |     return cutted
31 | 
32 | 
33 | def reconstruction_distance(D, cur_Z, last_Z):
34 |     distance = torch.norm(D.mm(last_Z - cur_Z), p=2, dim=0) / torch.norm(D.mm(last_Z), p=2, dim=0)
35 |     max_distance = distance.max()
36 |     return distance, max_distance
37 | 
38 | 
39 | def OMP(X, D, K, tolerance, debug=False):  # greedy pursuit over the columns of X (one signal per column)
40 |     Dt = D.t()
41 |     Dpinv = torch.pinverse(D)
42 |     r = X
43 |     I = []
44 |     stopping = False
45 |     last_sparse_code = torch.zeros((D.size()[1], X.size()[1]), dtype=X.dtype)
46 |     sparse_code = torch.zeros((D.size()[1], X.size()[1]), dtype=X.dtype)
47 | 
48 |     step = 0
49 |     while not stopping:
50 |         k_hat = torch.argmax(Dt.mm(r), 0)
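        # k_hat holds, for each signal (column of X), the index of the atom most
        # correlated with the current residual; appending it grows the support set I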
51 |         I.append(k_hat)
52 |         sparse_code = Dpinv.mm(X)  # known simplification: least-squares over the full dictionary; proper OMP restricts the pseudo-inverse to the selected support, e.g. torch.pinverse(D[:, I])
53 |         r = X - D.mm(sparse_code)
54 | 
55 |         distance, max_distance = reconstruction_distance(D, sparse_code, last_sparse_code)
56 |         stopping = len(I) >= K or max_distance < tolerance
57 |         last_sparse_code = sparse_code
58 | 
59 |         if debug:
60 |             print('Step {}, code improvement: {}, below tolerance: {}'.format(step, max_distance, (distance < tolerance).float().mean().item()))
61 | 
62 |         step += 1
63 | 
64 |     return sparse_code
65 | 
66 | 
67 | def _update_logical(logical, to_add):
68 |     running_idx = torch.arange(to_add.shape[0], device=to_add.device)
69 |     logical[running_idx, to_add] = 1
70 | 
71 | 
72 | def Batch_OMP(data, dictionary, max_nonzero, tolerance=1e-7, debug=False):
73 |     """
74 |     For details on variable names, see
75 |     https://sparse-plex.readthedocs.io/en/latest/book/pursuit/omp/batch_omp.html
76 |     or the original paper
77 |     http://www.cs.technion.ac.il/~ronrubin/Publications/KSVD-OMP-v2.pdf
78 | 
79 |     NOTE - the implementation below works on transposed versions of the input signal to make the batch size the first
80 |     coordinate, which is how PyTorch expects the data.
81 |     """
82 |     vector_dim, batch_size = data.size()
83 |     dictionary_t = dictionary.t()
84 |     G = dictionary_t.mm(dictionary)  # the Gram matrix of the dictionary
85 |     eps = torch.norm(data, dim=0)  # per-sample residual norm, initialized as the L2 norm of the signal
86 |     h_bar = dictionary_t.mm(data).t()  # initial correlation vector, transposed to make batch_size the first dimension
87 | 
88 |     # note - below this line we no longer use "data" or "dictionary"
89 | 
90 |     h = h_bar
91 |     x = torch.zeros_like(h_bar)  # the resulting sparse code
92 |     L = torch.ones(batch_size, 1, 1, device=h.device)  # progressive Cholesky factor of G restricted to the selected indices
93 |     I = torch.ones(batch_size, 0, device=h.device).long()
94 |     I_logic = torch.zeros_like(h_bar).bool()  # used to zero out elements in h before the argmax
95 |     delta = torch.zeros(batch_size, device=h.device)  # to track errors
96 | 
97 |     k = 0
98 |     while k < max_nonzero and eps.max() > tolerance:
99 |         k += 1
100 |         # use "I_logic" to make sure we do not select the same index twice
101 |         index = (h*(~I_logic).float()).abs().argmax(dim=1)  # todo - can we use "I" rather than "I_logic"
102 |         _update_logical(I_logic, index)
103 |         batch_idx = torch.arange(batch_size, device=G.device)
104 |         expanded_batch_idx = batch_idx.unsqueeze(0).expand(k, batch_size).t()
105 | 
106 |         if k > 1:  # Cholesky update
107 |             # Following line is equivalent to:
108 |             # G_stack = torch.stack([G[I[i, :], index[i]] for i in range(batch_size)], dim=0).view(batch_size, k-1, 1)
109 |             G_stack = G[I[batch_idx, :], index[expanded_batch_idx[..., :-1]]].view(batch_size, k-1, 1)
110 |             w = torch.triangular_solve(G_stack, L, upper=False).solution.view(-1, 1, k-1)
111 |             w_corner = torch.sqrt(1-(w**2).sum(dim=2, keepdim=True))  # <- L corner element: sqrt(1 - w.t().mm(w))
112 | 
113 |             # do concatenation into the new Cholesky: L <- [[L, 0], [w, w_corner]]
114 |             k_zeros = torch.zeros(batch_size, k-1, 1, device=h.device)
115 |             L = torch.cat((
116 |                 torch.cat((L, k_zeros), dim=2),
117 |                 torch.cat((w, w_corner), dim=2),
118 |             ), dim=1)
119 | 
120 |         # update non-zero indices
121 |         I = torch.cat([I, index.unsqueeze(1)], dim=1)
122 | 
123 |         # solve for x on the current support via the Cholesky factor
124 |         # The following line is equivalent to:
125 |         # h_stack = torch.stack([h_bar[i, I[i, :]] for i in range(batch_size)]).unsqueeze(2)
126 |         h_stack = h_bar[expanded_batch_idx, I[batch_idx, 
:]].view(batch_size, k, 1) 127 | x_stack = torch.cholesky_solve(h_stack, L) 128 | 129 | # de-stack x into the non-zero elements 130 | # The following line is equivalent to: 131 | # for i in range(batch_size): 132 | # x[i:i+1, I[i, :]] = x_stack[i, :].t() 133 | x[batch_idx.unsqueeze(1), I[batch_idx]] = x_stack[batch_idx].squeeze(-1) 134 | 135 | # beta = G_I * x_I 136 | # The following line is equivalent to: 137 | # beta = torch.cat([x[i:i+1, I[i, :]].mm(G[I[i, :], :]) for i in range(batch_size)], dim=0) 138 | beta = x[batch_idx.unsqueeze(1), I[batch_idx]].unsqueeze(1).bmm(G[I[batch_idx], :]).squeeze(1) 139 | 140 | h = h_bar - beta 141 | 142 | # update residual 143 | new_delta = (x * beta).sum(dim=1) 144 | eps += delta-new_delta 145 | delta = new_delta 146 | 147 | if debug and k % 1 == 0: 148 | print('Step {}, residual: {:.4f}, below tolerance: {:.4f}'.format(k, eps.max(), (eps < tolerance).float().mean().item())) 149 | 150 | return x.t() # transpose since sparse codes should be used as D * x 151 | 152 | 153 | if __name__ == '__main__': 154 | import time 155 | from tqdm import tqdm 156 | torch.manual_seed(0) 157 | use_gpu = torch.cuda.device_count() > 0 158 | device = 'cuda' if use_gpu else 'cpu' 159 | 160 | num_nonzeros = 4 161 | num_samples = int(1e4) 162 | num_atoms = 512 163 | embedding_size = 64 164 | 165 | Wd = torch.randn(embedding_size, num_atoms) 166 | Wd = torch.nn.functional.normalize(Wd, dim=0).to(device) 167 | 168 | codes = [] 169 | for i in tqdm(range(num_samples), desc='generating codes... '): 170 | tmp = torch.zeros(num_atoms).to(device) 171 | tmp[torch.randperm(num_atoms)[:num_nonzeros]] = 0.5 * torch.rand(num_nonzeros).to(device) + 0.5 172 | codes.append(tmp) 173 | codes = torch.stack(codes, dim=1) 174 | X = Wd.mm(codes) 175 | # X += torch.randn(X.size()) / 100 # add noise 176 | # X = torch.nn.functional.normalize(X, dim=0) # normalize signal 177 | 178 | if use_gpu: # warm start? 179 | print('doing warm start...') 180 | Batch_OMP(X[:, :min(num_nonzeros, 1000)], Wd, num_nonzeros) 181 | 182 | tic = time.time() 183 | Z2 = Batch_OMP(X, Wd, num_nonzeros, debug=True) 184 | Z2_time = time.time() - tic 185 | print(f'Z2, {torch.isclose(codes, Z2, rtol=1e-03, atol=1e-05).float().mean()}, time/sample={1e6*Z2_time/num_samples/num_nonzeros:.4f}usec') 186 | pass 187 | 188 | 189 | -------------------------------------------------------------------------------- /utils/util_funcs.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | import torch 5 | from argparse import ArgumentParser 6 | 7 | 8 | osp = os.path 9 | _e_root = '.' 
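# Heuristics: use a data-loader worker count that scales with the number of
# visible GPUs (12 per GPU, at least 8, capped at the CPU count), and prefer a
# local copy of the datasets under /home/ubuntu/data when one exists.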
10 | n_gpu = torch.cuda.device_count() if torch.cuda.is_available() else 0
11 | n_cpu = torch.multiprocessing.cpu_count()
12 | NUM_WORKERS = min(n_cpu, max(8, 12 * n_gpu))
13 | 
14 | data_root = './data'
15 | if osp.isdir('/home/ubuntu/data/imagenet') and osp.isdir('/home/ubuntu/data/cifar10'):
16 |     data_root = '/home/ubuntu/data'  # if data was copied locally
17 | 
18 | 
19 | def base_parser(parser: ArgumentParser):
20 |     parser.add_argument('--size', type=int, default=64, help='Resize samples to this size')
21 |     parser.add_argument('--seed', type=int, default=7, help='Random seed')
22 |     parser.add_argument('--num_workers', type=int, default=NUM_WORKERS, help='Number of data loading workers')
23 |     parser.add_argument('--device', default='cuda', type=str, help='Device to run on [cpu | cuda | cuda:device_number]')
24 |     parser.add_argument('--data_path', default=data_root, type=str, help='Path to dataset')
25 |     parser.add_argument('--checkpoint_path', default=osp.join(_e_root, 'checkpoint'),
26 |                         type=str, help='Path where checkpoints are saved')
27 |     parser.add_argument('--summary_path', default=osp.join(_e_root, 'summary'),
28 |                         type=str, help='Path where saved (tensorboard) summaries are written')
29 |     parser.add_argument('--dataset', default='cifar10', type=str,
30 |                         help='Name of dataset to use [cifar10 | cifar100 | imagenet]')
31 |     parser.add_argument('-n', '--experiment_name', default='', type=str, help='Name of experiment')
32 | 
33 |     return parser
34 | 
35 | 
36 | def vqvae_parser(parser: ArgumentParser):
37 |     # vqvae training params
38 |     parser.add_argument('--alpha', type=float, default=0.3, help='FISTA shrinkage parameter')
39 |     parser.add_argument('-k', '--num_nonzero', type=int, default=2, help='Maximal number of nonzeros for OMP')
40 | 
41 |     parser.add_argument('--num_epochs', type=int, default=200, help='Train VQ-VAE model for this number of epochs')
42 |     parser.add_argument('--embed_dim', type=int, default=64, help='Size of the latent space')
43 |     parser.add_argument('--sample_gradients', type=int, default=1000, help='Number of patches to backprop against in FISTA. 0 for all')
44 | 
45 |     parser.add_argument('--vae_batch', type=int, default=128, help='Batch size for the VQ-VAE model')
46 |     parser.add_argument('--lr', type=float, default=3e-4, help='VQ-VAE learning rate')
47 |     parser.add_argument('--dictionary_loss_weight', type=float, default=0.0, help='Weight of the dictionary loss term.')
48 |     parser.add_argument('--sched', type=str, help='Scheduler to use [cycle]. Defaults to none')
49 |     parser.add_argument('--architecture', default='vqvae', type=str, help='Name of architecture to use [vqvae]')
50 |     parser.add_argument('--num_embeddings', default=512, type=int, help='Number of embeddings in code book')
51 |     parser.add_argument('-sel', '--selection_fn', default='vanilla', type=str, help='Function to select dictionary vectors [omp | vanilla | fista]')
52 |     parser.add_argument('--neighborhood', default=1, type=int, help='Only relevant for OMP.')
53 |     parser.add_argument('--checkpoint_freq', default=5, type=int, help='Checkpoint the model every this many epochs.')
54 |     parser.add_argument('--backward_dict', default=1, type=int, help='1 to do backprop w.r.t. 
the dictionary, 0 otherwise.')
55 |     parser.add_argument('-stride', '--num_strides', default=2, type=int, help='Number of stride blocks; every block halves the size of the quantized image.')
56 |     parser.add_argument('--use_backwards_simd', default=True, type=bool, help='Flag to use matrix version of the FISTA backprop.')  # NOTE: argparse's type=bool treats any non-empty string as True; omit the flag to keep the default
57 |     parser.add_argument('--download', default=False, type=bool, help='Flag to download the dataset if needed.')  # NOTE: same type=bool caveat as above
58 | 
59 |     parser.add_argument('--parallelize', dest='parallelize', action='store_true', help='Flag to use DataParallel')
60 |     parser.add_argument('--no_parallelize', dest='parallelize', action='store_false', help='Flag not to use DataParallel')
61 |     parser.set_defaults(parallelize=True)
62 | 
63 |     parser.add_argument('--normalize_dict', dest='normalize_dict', action='store_true', help='Flag to normalize dictionary')
64 |     parser.add_argument('--no_normalize_dict', dest='normalize_dict', action='store_false', help='Flag not to normalize dictionary')
65 |     parser.set_defaults(normalize_dict=True)
66 | 
67 |     parser.add_argument('--normalize_z', dest='normalize_z', action='store_true', help='Flag to normalize found sparse code')
68 |     parser.add_argument('--no_normalize_z', dest='normalize_z', action='store_false', help='Flag not to normalize found sparse code')
69 |     parser.set_defaults(normalize_z=False)
70 | 
71 |     parser.add_argument('--normalize_x', dest='normalize_x', action='store_true', help='Flag to normalize quantization input')
72 |     parser.add_argument('--no_normalize_x', dest='normalize_x', action='store_false', help='Flag not to normalize quantization input')
73 |     parser.set_defaults(normalize_x=True)
74 | 
75 |     # Training evaluation and sampling parameters
76 |     parser.add_argument('--eval_iter', default=1, type=int, help='Eval every [value] iterations')
77 |     parser.add_argument('--sampling_iter', default=25, type=int, help='Sample every [value] batches')
78 |     parser.add_argument('--sample_size', default=25, type=int, help='Number of images to sample at each sampling step')
79 | 
80 |     # Test parameters
81 |     parser.add_argument('--is_enforce_sparsity', dest='is_enforce_sparsity', action='store_true',
82 |                         help='Flag to select only top-K sparse code values, needed only for FISTA')
83 |     parser.set_defaults(is_enforce_sparsity=False)
84 | 
85 |     parser.add_argument('--is_quantize_coefs', dest='is_quantize_coefs', action='store_true',
86 |                         help='Flag to quantize sparse code coefficients for compression')
87 |     parser.set_defaults(is_quantize_coefs=False)
88 | 
89 |     return parser
90 | 
91 | 
92 | def code_extraction_parser(parser: ArgumentParser):
93 |     parser.add_argument('--ckpt_epoch', type=int, default=200, help='Epoch number of the VQVAE model to load')
94 | 
95 |     return parser
96 | 
97 | 
98 | def pixelsnail_parser(parser: ArgumentParser):
99 |     parser.add_argument('--pixelsnail_batch', type=int, default=256, help='Size of PixelSnail batch')
100 |     parser.add_argument('--pixelsnail_epoch', type=int, default=420, help='Train PixelSnail model for this number of epochs')
101 |     parser.add_argument('--hier', type=str, default='bottom', help='Used for cascaded VQ-VAE. 
', choices=['bottom'])  # FIXME - Use only `bottom` for now
102 |     parser.add_argument('--pixelsnail_lr', type=float, default=3e-4, help='PixelSnail learning rate')
103 |     parser.add_argument('--pixelsnail_channel', type=int, default=256, help='Number of channels to expand to in the PixelSnail architecture')
104 |     parser.add_argument('--pixelsnail_n_res_block', type=int, default=4, help='Number of residual blocks in PixelSnail')
105 |     parser.add_argument('--pixelsnail_n_res_channel', type=int, default=256, help='Number of channels to expand to in residual blocks in PixelSnail')
106 |     parser.add_argument('--pixelsnail_n_out_res_block', type=int, default=0)
107 |     parser.add_argument('--pixelsnail_n_cond_res_block', type=int, default=3)
108 |     parser.add_argument('--pixelsnail_dropout', type=float, default=0.1)
109 |     parser.add_argument('--amp', type=str, default='O0')
110 |     parser.add_argument('--pixelsnail_sched', type=str)
111 |     parser.add_argument('--pixelsnail_ckpt', type=str, help='PixelSnail checkpoint to continue training from')
112 | 
113 |     return parser
114 | 
115 | 
116 | def sampling_parser(parser: ArgumentParser):
117 |     parser.add_argument('--pixelsnail_ckpt_epoch', type=int, default=420, help='PixelSnail epoch to load')
118 |     parser.add_argument('--temp', type=float, default=1.0)
119 |     parser.add_argument('--num_threads', type=int, default=3, help='Number of threads to multithread the sampling on')
120 | 
121 |     return parser
122 | 
123 | 
124 | def create_experiment_name(architecture, dataset, num_embeddings, neighborhood, selection_fn, size, lr=None, **kwargs):
125 |     additional_remarks = ''
126 | 
127 |     if kwargs['experiment_name'] == '':
128 |         # if 'old_experiment_name_format' in kwargs and kwargs['old_experiment_name_format']:
129 |         if 'experiment_name' in kwargs and len(kwargs['experiment_name']) > 0:
130 |             additional_remarks += '_experiment_name_{}'.format(kwargs['experiment_name'])
131 | 
132 |         if has_value_and_true(kwargs, 'normalize_dict'):
133 |             pass  # additional_remarks += '_normalize_dict_{}'.format(True)
134 |         else:
135 |             additional_remarks += '_normalize_dict_{}'.format(False)
136 | 
137 |         if has_value_and_true(kwargs, 'normalize_x'):
138 |             pass  # additional_remarks += '_normalize_x_{}'.format(True)
139 |         else:
140 |             additional_remarks += '_normalize_x_{}'.format(False)
141 | 
142 |         # if has_value_and_true(kwargs, 'normalize_z'):
143 |         #     additional_remarks += '_normalize_z_{}'.format(True)
144 |         # else:
145 |         #     additional_remarks += '_normalize_z_{}'.format(False)
146 | 
147 |         additional_remarks += '_size_{}'.format(size)
148 | 
149 |         if 'ckpt_epoch' in kwargs:
150 |             additional_remarks += '_ckpt_epoch_{}'.format(kwargs['ckpt_epoch'])
151 | 
152 |         if 'sample_gradients' in kwargs:
153 |             additional_remarks += '_sample_gradients_{}'.format(kwargs['sample_gradients'])
154 | 
155 |         if 'backward_dict' in kwargs and selection_fn in ('fista', 'omp'):
156 |             additional_remarks += '_backward_dict_{}'.format(kwargs['backward_dict'])
157 | 
158 |         if 'lr' in kwargs:
159 |             additional_remarks += 'lr_{}'.format(kwargs['lr'])
160 |         # else:
161 |         #     additional_remarks += 'lr_{}'.format(lr)
162 | 
163 |         alpha = kwargs['alpha']
164 |         experiment_name = f'{architecture}_{dataset}_num_embeddings{num_embeddings}_neighborhood{neighborhood}_selectionFN{selection_fn}_alpha{alpha}' + additional_remarks
165 |     else:
166 |         experiment_name = kwargs['experiment_name']
167 | 
168 |     return experiment_name
169 | 
170 | 
171 | def create_checkpoint_name(experiment_name, epoch_ind):
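    """Checkpoint file path relative to the dated checkpoint directory: '<experiment_name>/<zero-padded epoch>.pt'."""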
return f'{experiment_name}/{str(epoch_ind).zfill(3)}.pt' 173 | 174 | 175 | def has_value_and_true(dictionary, key): 176 | return key in dictionary and dictionary[key] 177 | 178 | 179 | def seed_generators(seed): 180 | random.seed(seed) 181 | np.random.seed(seed) 182 | torch.manual_seed(seed) 183 | 184 | 185 | --------------------------------------------------------------------------------
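One detail the listing above leaves implicit: the `shrink1`/`shrink2`/`shrink3` chain in `utils/pyomp.py` composes the soft-thresholding operator S_alpha(z) = sign(z) * max(|z| - alpha, 0) that FISTA applies after every gradient step, with the `--alpha` flag playing the role of the threshold. A minimal sanity check of that equivalence (not part of the repository) can compare it against PyTorch's built-in `softshrink`:

```python
import torch
import torch.nn.functional as F

from utils.pyomp import shrink_function

torch.manual_seed(0)
Z = torch.randn(64, 512)  # a batch of dense codes
alpha = 0.3               # the shrinkage threshold exposed as --alpha

# shrink_function(Z, alpha) == sign(Z) * max(|Z| - alpha, 0) == softshrink(Z, alpha)
assert torch.allclose(shrink_function(Z, alpha), F.softshrink(Z, alpha))
```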