├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── docs ├── Makefile ├── _static │ └── custom.css ├── make.bat ├── requirements.txt └── source │ ├── GATA1_example.rst │ ├── _autosummary │ ├── perturbnet.baselines.knn.rst │ ├── perturbnet.baselines.linear.rst │ ├── perturbnet.baselines.rst │ ├── perturbnet.chemicalvae.chemicalVAE.rst │ ├── perturbnet.chemicalvae.rst │ ├── perturbnet.cinn.FeatureAttr.rst │ ├── perturbnet.cinn.flow.rst │ ├── perturbnet.cinn.flow_generate.rst │ ├── perturbnet.cinn.rst │ ├── perturbnet.genotypevae.genotypeVAE.rst │ ├── perturbnet.genotypevae.rst │ ├── perturbnet.pytorch_scvi.distributions.rst │ ├── perturbnet.pytorch_scvi.rst │ ├── perturbnet.pytorch_scvi.scvi_generate_z.rst │ └── perturbnet.util.rst │ ├── _static │ └── custom.css │ ├── api.rst │ ├── chemical_perturbation.rst │ ├── coding_variant.rst │ ├── conf.py │ ├── feature_attribution.rst │ ├── genetic_perturbation.rst │ ├── index.rst │ ├── tutorials │ ├── GATA1_example.nblink │ ├── GATA1_prediction_analysis.ipynb │ ├── Integrated_gradients_example.ipynb │ ├── Tutorial_PerturbNet_Chemicals.ipynb │ ├── Tutorial_PerturbNet_Genetic.ipynb │ ├── Tutorial_PerturbNet_coding_variants.ipynb │ ├── chemical_perturbation.nblink │ ├── coding_variant.nblink │ ├── feature_attribution.nblink │ └── genetic_perturbation.nblink │ └── usage.rst ├── notebooks ├── Benchmark_Jorge_Example.ipynb ├── Benchmark_LINCS_Example.ipynb ├── Benchmark_Norman_Example.ipynb ├── Benchmark_Sciplex_Example.ipynb ├── Benchmark_Ursu_Example.ipynb ├── GATA1_prediction_analysis.ipynb ├── Integrated_gradients_example.ipynb ├── Tutorial_PerturbNet_Chemicals.ipynb ├── Tutorial_PerturbNet_Genetic.ipynb └── Tutorial_PerturbNet_coding_variants.ipynb ├── perturbnet ├── __init__.py ├── baselines │ ├── __init__.py │ ├── knn.py │ └── linear.py ├── chemicalvae │ ├── README.md │ ├── __init__.py │ └── chemicalVAE.py ├── cinn │ ├── FeatureAttr.py │ ├── __init__.py │ ├── flow.py │ └── flow_generate.py ├── data_vae │ ├── __init__.py │ ├── util.py │ └── vae.py ├── genotypevae │ ├── README.md │ ├── __init__.py │ └── genotypeVAE.py ├── net2net │ ├── __init__.py │ ├── ckpt_util.py │ ├── data │ │ ├── __init__.py │ │ ├── base.py │ │ ├── coco.py │ │ ├── faces.py │ │ ├── utils.py │ │ └── zcodes.py │ ├── models │ │ ├── __init__.py │ │ ├── autoencoder.py │ │ └── flows │ │ │ ├── __init__.py │ │ │ ├── flow.py │ │ │ ├── scviflow.py │ │ │ └── util.py │ └── modules │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── util.cpython-37.pyc │ │ └── util.cpython-39.pyc │ │ ├── autoencoder │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── __init__.cpython-39.pyc │ │ │ ├── basic.cpython-37.pyc │ │ │ └── basic.cpython-39.pyc │ │ ├── basic.py │ │ ├── decoder.py │ │ ├── encoder.py │ │ ├── loss.py │ │ └── lpips.py │ │ ├── captions │ │ ├── __init__.py │ │ ├── model.py │ │ └── models.py │ │ ├── discriminator │ │ ├── __init__.py │ │ └── model.py │ │ ├── distributions │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── __init__.cpython-39.pyc │ │ │ ├── distributions.cpython-37.pyc │ │ │ └── distributions.cpython-39.pyc │ │ └── distributions.py │ │ ├── facenet │ │ ├── __init__.py │ │ ├── inception_resnet_v1.py │ │ └── model.py │ │ ├── flow │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── __init__.cpython-39.pyc │ │ │ ├── base.cpython-37.pyc │ │ │ ├── base.cpython-39.pyc │ │ │ ├── blocks.cpython-37.pyc │ │ │ ├── 
blocks.cpython-39.pyc │ │ │ ├── flatflow.cpython-37.pyc │ │ │ ├── flatflow.cpython-39.pyc │ │ │ ├── loss.cpython-37.pyc │ │ │ └── loss.cpython-39.pyc │ │ ├── base.py │ │ ├── blocks.py │ │ ├── flatflow.py │ │ └── loss.py │ │ ├── gan │ │ ├── __init__.py │ │ ├── bigbigan.py │ │ └── biggan.py │ │ ├── labels │ │ ├── __init__.py │ │ └── model.py │ │ ├── mlp │ │ ├── __init__.py │ │ └── models.py │ │ ├── sbert │ │ ├── __init__.py │ │ └── model.py │ │ └── util.py ├── pytorch_scvi │ ├── __init__.py │ ├── __pycache__ │ │ ├── distributions.cpython-37.pyc │ │ ├── distributions.cpython-39.pyc │ │ ├── scvi_generate_z.cpython-37.pyc │ │ └── scvi_generate_z.cpython-39.pyc │ ├── distributions.py │ └── scvi_generate_z.py └── util.py ├── readthedocs.yml ├── requirements.txt └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | **/.DS_Store 2 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | exclude notebooks 2 | exclude example_data 3 | exclude pretrained_model 4 | include README.md 5 | include LICENSE -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PerturbNet 2 | 3 | PerturbNet is a deep generative model that can predict the distribution of cell states induced by chemical or genetic perturbation. Currently, you can refer to the preprint [PerturbNet predicts single-cell responses to unseen chemical and genetic perturbations](https://www.biorxiv.org/content/10.1101/2022.07.20.500854v2). We will submit an updated version of the paper soon. 4 | 5 | 6 | 7 | 8 | ## System Requirements and Installation 9 | 10 | The current version of PerturbNet requires Python 3.7. All required dependencies are listed in ```requirements.txt```. We recommend creating a clean Conda environment using the following command: 11 | 12 | ``` 13 | conda create -n "PerturbNet" python=3.7 14 | ``` 15 | After setting up the environment, you can install the package by running: 16 | ``` 17 | conda activate PerturbNet 18 | pip install --upgrade PerturbNet 19 | ``` 20 | We used **cuDNN 8.7.0 (cudnn/11.7-v8.7.0)** and **CUDA 11.7.1** for model training. 21 | 22 | We also provide an updated version that removes the dependency on TensorFlow by using Python 3.10. To install: 23 | ``` 24 | conda create -n "PerturbNet" python=3.10 25 | conda activate PerturbNet 26 | pip install PerturbNet==0.0.3b1 27 | ``` 28 | For reproducibility, we currently recommend using the stable version with Python 3.7. 29 | 30 | ## Core Repository Structure 31 | 32 | [`./perturbnet`](https://github.com/welch-lab/PerturbNet/tree/main/perturbnet) contains the core modules to train and benchmark the PerturbNet framework. 33 | 34 | [`./perturbnet/net2net`](https://github.com/welch-lab/PerturbNet/tree/main/perturbnet/net2net) contains the conditional invertible neural network (cINN) modules, adapted from the [GitHub](https://github.com/CompVis/net2net/tree/master/net2net) repository of [Network-to-Network Translation with Conditional Invertible Neural Networks](https://arxiv.org/abs/2005.13580). 35 | 36 | 37 | [`./perturbnet/pytorch_scvi`](https://github.com/welch-lab/PerturbNet/tree/main/perturbnet/pytorch_scvi) contains our adapted modules to decode latent representations to expression profiles based on scVI version 0.7.1.
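At prediction time, these components are combined through the generation classes in `perturbnet/cinn/flow_generate.py`. The sketch below is illustrative only: `flow_model` (a trained cINN), `scvi_decoder` (a trained scVI decoding wrapper from `perturbnet/pytorch_scvi`), `pert_embeddings` (one perturbation representation per cell to generate), and `library_sizes` are placeholder names for objects prepared as in the tutorial notebooks listed below, not fixed entry points of the package.

```python
import torch
from perturbnet.cinn.flow_generate import SCVIZ_CheckNet2Net

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Wrap the trained cINN flow and the scVI decoder into a single predictor
predictor = SCVIZ_CheckNet2Net(flow_model, device, scvi_decoder)

# Sample latent cell states and count responses for the given perturbation
sampled_z, sampled_counts = predictor.sample_data(pert_embeddings, library_sizes, batch_size=50)
```

See the tutorials below for complete, runnable workflows.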
38 | 39 | 40 | ## Tutorial and Reproducibility 41 | The [`./notebooks`] directory contains Jupyter notebooks demonstrating how to use **PerturbNet** and includes code to reproduce the results: 42 | * [Tutorial on using PerturbNet on chemical perturbations](https://github.com/welch-lab/PerturbNet/blob/main/notebooks/Tutorial_PerturbNet_Chemicals.ipynb) 43 | * [Tutorial on using PerturbNet on genetic perturbations](https://github.com/welch-lab/PerturbNet/blob/main/notebooks/Tutorial_PerturbNet_Genetic.ipynb) 44 | * [Tutorial on using PerturbNet on coding variants](https://github.com/welch-lab/PerturbNet/blob/main/notebooks/Tutorial_PerturbNet_coding_variants.ipynb) 45 | * [Tutorial on using integrated gradients to calculate feature scores for chemicals](https://github.com/welch-lab/PerturbNet/blob/main/notebooks/Integrated_gradients_example.ipynb) 46 | * [Benchmark on LINCS-Drug](https://github.com/welch-lab/PerturbNet/blob/main/notebooks/Benchmark_LINCS_Example.ipynb) 47 | * [Benchmark on sci-Plex](https://github.com/welch-lab/PerturbNet/blob/main/notebooks/Benchmark_Sciplex_Example.ipynb) 48 | * [Benchmark on Norman et al.](https://github.com/welch-lab/PerturbNet/blob/main/notebooks/Benchmark_Norman_Example.ipynb) 49 | * [Benchmark on Ursu et al.](https://github.com/welch-lab/PerturbNet/blob/main/notebooks/Benchmark_Ursu_Example.ipynb) 50 | * [Benchmark on Jorge et al.](https://github.com/welch-lab/PerturbNet/blob/main/notebooks/Benchmark_Jorge_Example.ipynb) 51 | * [Analysis of predicted novel GATA1 mutations](https://github.com/welch-lab/PerturbNet/blob/main/notebooks/GATA1_prediction_analysis.ipynb) 52 | 53 | The required data, toy examples, and model weights can be downloaded from [Hugging Face](https://huggingface.co/cyclopeta/PerturbNet_reproduce/tree/main). 54 | 55 | 56 | 57 | ## Reference 58 | 59 | Please consider citing 60 | 61 | ``` 62 | @article {Yu2022.07.20.500854, 63 | author = {Yu, Hengshi and Welch, Joshua D}, 64 | title = {PerturbNet predicts single-cell responses to unseen chemical and genetic perturbations}, 65 | elocation-id = {2022.07.20.500854}, 66 | year = {2022}, 67 | doi = {10.1101/2022.07.20.500854}, 68 | publisher = {Cold Spring Harbor Laboratory}, 69 | URL = {https://www.biorxiv.org/content/early/2022/07/22/2022.07.20.500854}, 70 | eprint = {https://www.biorxiv.org/content/early/2022/07/22/2022.07.20.500854.full.pdf}, 71 | journal = {bioRxiv} 72 | } 73 | 74 | ``` 75 | We appreciate your interest in our work. 76 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/_static /custom.css: -------------------------------------------------------------------------------- 1 | .wy-nav-content { 2 | max-width: 90% !important; 3 | } 4 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | setuptools 2 | setuptools_scm 3 | typing_extensions 4 | importlib_metadata 5 | ipykernel 6 | Sphinx>=4.0 7 | sphinx_rtd_theme>=0.5 8 | sphinx_autodoc_typehints>=1.12.0 9 | sphinx_autodoc_defaultargs==0.1.2 10 | nbsphinx>=0.7 11 | nbsphinx-link 12 | protobuf<=3.20 13 | -------------------------------------------------------------------------------- /docs/source/GATA1_example.rst: -------------------------------------------------------------------------------- 1 | GATA1 Example 2 | ============== 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | 7 | tutorials/GATA1_example -------------------------------------------------------------------------------- /docs/source/_autosummary/perturbnet.baselines.knn.rst: -------------------------------------------------------------------------------- 1 | perturbnet.baselines.knn 2 | ======================== 3 | 4 | .. automodule:: perturbnet.baselines.knn 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | .. rubric:: Classes 17 | 18 | .. autosummary:: 19 | 20 | samplefromNeighbors 21 | samplefromNeighborsCINN 22 | samplefromNeighborsGenotype 23 | samplefromNeighborsGenotypeCINN 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /docs/source/_autosummary/perturbnet.baselines.linear.rst: -------------------------------------------------------------------------------- 1 | perturbnet.baselines.linear 2 | =========================== 3 | 4 | .. automodule:: perturbnet.baselines.linear 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | .. rubric:: Functions 13 | 14 | .. 
autosummary:: 15 | 16 | solve_y_axb 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /docs/source/_autosummary/perturbnet.baselines.rst: -------------------------------------------------------------------------------- 1 | perturbnet.baselines 2 | ==================== 3 | 4 | .. automodule:: perturbnet.baselines 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | .. rubric:: Modules 25 | 26 | .. autosummary:: 27 | :toctree: 28 | :recursive: 29 | 30 | perturbnet.baselines.knn 31 | perturbnet.baselines.linear 32 | 33 | -------------------------------------------------------------------------------- /docs/source/_autosummary/perturbnet.chemicalvae.chemicalVAE.rst: -------------------------------------------------------------------------------- 1 | perturbnet.chemicalvae.chemicalVAE 2 | ================================== 3 | 4 | .. automodule:: perturbnet.chemicalvae.chemicalVAE 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | .. rubric:: Functions 13 | 14 | .. autosummary:: 15 | 16 | vae_loss 17 | 18 | 19 | 20 | 21 | 22 | .. rubric:: Classes 23 | 24 | .. autosummary:: 25 | 26 | ChemicalVAE 27 | ChemicalVAEFineTuneZLZ 28 | ChemicalVAETrain 29 | ConcatDatasetWithIndices 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | -------------------------------------------------------------------------------- /docs/source/_autosummary/perturbnet.chemicalvae.rst: -------------------------------------------------------------------------------- 1 | perturbnet.chemicalvae 2 | ====================== 3 | 4 | .. automodule:: perturbnet.chemicalvae 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | .. rubric:: Modules 25 | 26 | .. autosummary:: 27 | :toctree: 28 | :recursive: 29 | 30 | perturbnet.chemicalvae.chemicalVAE 31 | 32 | -------------------------------------------------------------------------------- /docs/source/_autosummary/perturbnet.cinn.FeatureAttr.rst: -------------------------------------------------------------------------------- 1 | perturbnet.cinn.FeatureAttr 2 | =========================== 3 | 4 | .. automodule:: perturbnet.cinn.FeatureAttr 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | .. rubric:: Functions 13 | 14 | .. autosummary:: 15 | 16 | ig_b_score_compute 17 | ig_y_score_compute 18 | plot_molecule_attribtuion_score 19 | 20 | 21 | 22 | 23 | 24 | .. rubric:: Classes 25 | 26 | .. autosummary:: 27 | 28 | BinaryCellStatesClass 29 | CellDataset 30 | FlowResizeLabelClass 31 | FlowResizeYLabelClass 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | -------------------------------------------------------------------------------- /docs/source/_autosummary/perturbnet.cinn.flow.rst: -------------------------------------------------------------------------------- 1 | perturbnet.cinn.flow 2 | ==================== 3 | 4 | .. automodule:: perturbnet.cinn.flow 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | .. rubric:: Classes 17 | 18 | .. 
autosummary:: 19 | 20 | ConcatDatasetWithIndices 21 | Net2NetFlow_TFVAEFixFlow 22 | Net2NetFlow_TFVAEFlow 23 | Net2NetFlow_TFVAENonStdFlow 24 | Net2NetFlow_TFVAE_Covariate_Flow 25 | Net2NetFlow_scVIChemFlow 26 | Net2NetFlow_scVIChemStdFlow 27 | Net2NetFlow_scVIChemStdStatesFlow 28 | Net2NetFlow_scVIFixFlow 29 | Net2NetFlow_scVIFix_Covariate_Flow 30 | Net2NetFlow_scVIGenoFlow 31 | Net2NetFlow_scVIGenoFlow_GIlayer 32 | Net2NetFlow_scVIGenoPerLibFlow 33 | Net2NetFlow_scVIGenoStatesFlow 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | -------------------------------------------------------------------------------- /docs/source/_autosummary/perturbnet.cinn.flow_generate.rst: -------------------------------------------------------------------------------- 1 | perturbnet.cinn.flow\_generate 2 | ============================== 3 | 4 | .. automodule:: perturbnet.cinn.flow_generate 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | .. rubric:: Classes 17 | 18 | .. autosummary:: 19 | 20 | SCVIZ_CheckNet2Net 21 | TFVAEZ_CheckNet2Net 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /docs/source/_autosummary/perturbnet.cinn.rst: -------------------------------------------------------------------------------- 1 | perturbnet.cinn 2 | =============== 3 | 4 | .. automodule:: perturbnet.cinn 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | .. rubric:: Modules 25 | 26 | .. autosummary:: 27 | :toctree: 28 | :recursive: 29 | 30 | perturbnet.cinn.FeatureAttr 31 | perturbnet.cinn.flow 32 | perturbnet.cinn.flow_generate 33 | 34 | -------------------------------------------------------------------------------- /docs/source/_autosummary/perturbnet.genotypevae.genotypeVAE.rst: -------------------------------------------------------------------------------- 1 | perturbnet.genotypevae.genotypeVAE 2 | ================================== 3 | 4 | .. automodule:: perturbnet.genotypevae.genotypeVAE 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | .. rubric:: Functions 13 | 14 | .. autosummary:: 15 | 16 | vae_loss 17 | 18 | 19 | 20 | 21 | 22 | .. rubric:: Classes 23 | 24 | .. autosummary:: 25 | 26 | ConcatDatasetWithIndices 27 | GenotypeVAE 28 | GenotypeVAETrain 29 | GenotypeVAEZLZFineTune 30 | GenotypeVAE_Customize 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | -------------------------------------------------------------------------------- /docs/source/_autosummary/perturbnet.genotypevae.rst: -------------------------------------------------------------------------------- 1 | perturbnet.genotypevae 2 | ====================== 3 | 4 | .. automodule:: perturbnet.genotypevae 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | .. rubric:: Modules 25 | 26 | .. autosummary:: 27 | :toctree: 28 | :recursive: 29 | 30 | perturbnet.genotypevae.genotypeVAE 31 | 32 | -------------------------------------------------------------------------------- /docs/source/_autosummary/perturbnet.pytorch_scvi.distributions.rst: -------------------------------------------------------------------------------- 1 | perturbnet.pytorch\_scvi.distributions 2 | ====================================== 3 | 4 | .. automodule:: perturbnet.pytorch_scvi.distributions 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | .. rubric:: Functions 13 | 14 | .. autosummary:: 15 | 16 | log_mixture_nb 17 | log_nb_positive 18 | log_zinb_positive 19 | 20 | 21 | 22 | 23 | 24 | .. rubric:: Classes 25 | 26 | .. 
autosummary:: 27 | 28 | NegativeBinomial 29 | NegativeBinomialMixture 30 | ZeroInflatedNegativeBinomial 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | -------------------------------------------------------------------------------- /docs/source/_autosummary/perturbnet.pytorch_scvi.rst: -------------------------------------------------------------------------------- 1 | perturbnet.pytorch\_scvi 2 | ======================== 3 | 4 | .. automodule:: perturbnet.pytorch_scvi 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | .. rubric:: Modules 25 | 26 | .. autosummary:: 27 | :toctree: 28 | :recursive: 29 | 30 | perturbnet.pytorch_scvi.distributions 31 | perturbnet.pytorch_scvi.scvi_generate_z 32 | 33 | -------------------------------------------------------------------------------- /docs/source/_autosummary/perturbnet.pytorch_scvi.scvi_generate_z.rst: -------------------------------------------------------------------------------- 1 | perturbnet.pytorch\_scvi.scvi\_generate\_z 2 | ========================================== 3 | 4 | .. automodule:: perturbnet.pytorch_scvi.scvi_generate_z 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | .. rubric:: Classes 17 | 18 | .. autosummary:: 19 | 20 | ConcatDataset 21 | scvi_predictive_z 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /docs/source/_autosummary/perturbnet.util.rst: -------------------------------------------------------------------------------- 1 | perturbnet.util 2 | =============== 3 | 4 | .. automodule:: perturbnet.util 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | .. rubric:: Functions 13 | 14 | .. autosummary:: 15 | 16 | Seq_to_Embed_ESM 17 | boxplot_metrics 18 | contourplot_space_mapping 19 | create_train_test_splits_by_key 20 | pad_smile 21 | prepare_embeddings_cinn 22 | smiles_to_hot 23 | umapPlot_latent_check 24 | 25 | 26 | 27 | 28 | 29 | .. rubric:: Classes 30 | 31 | .. autosummary:: 32 | 33 | ConcatDatasetWithIndices 34 | NormalizedRevisionRSquare 35 | Standardize 36 | StandardizeLoad 37 | fidscore 38 | fidscore_scgen_extend 39 | fidscore_scvi_extend 40 | fidscore_vae_extend 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | -------------------------------------------------------------------------------- /docs/source/_static /custom.css: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /docs/source/api.rst: -------------------------------------------------------------------------------- 1 | API 2 | ============= 3 | 4 | Cellular Representation Networks 5 | --------------------------------- 6 | .. autosummary:: 7 | :toctree: _autosummary 8 | 9 | 10 | perturbnet.data_vae.vae.VAE 11 | 12 | Perturbation Representation Networks 13 | ------------------------------------- 14 | .. autosummary:: 15 | :toctree: _autosummary 16 | 17 | 18 | perturbnet.chemicalvae.chemicalVAE.ChemicalVAE 19 | perturbnet.genotypevae.genotypeVAE.GenotypeVAE 20 | 21 | cINNs 22 | ------- 23 | .. autosummary:: 24 | :toctree: _autosummary 25 | 26 | perturbnet.cinn.flow.Net2NetFlow_TFVAEFlow 27 | perturbnet.cinn.flow.Net2NetFlow_TFVAE_Covariate_Flow 28 | perturbnet.cinn.flow.Net2NetFlow_scVIGenoFlow 29 | perturbnet.cinn.flow.Net2NetFlow_scVIFixFlow 30 | 31 | Final Generative Models 32 | --------------------------------- 33 | .. 
autosummary:: 34 | :toctree: _autosummary 35 | 36 | perturbnet.cinn.flow_generate.SCVIZ_CheckNet2Net 37 | perturbnet.cinn.flow_generate.TFVAEZ_CheckNet2Net 38 | 39 | Feature Attribution 40 | -------------------- 41 | .. autosummary:: 42 | :toctree: _autosummary 43 | 44 | 45 | perturbnet.cinn.FeatureAttr 46 | 47 | Tools & Plot 48 | ------------- 49 | .. autosummary:: 50 | :toctree: _autosummary 51 | 52 | 53 | perturbnet.util.create_train_test_splits_by_key 54 | perturbnet.util.prepare_embeddings_cinn 55 | perturbnet.util.smiles_to_hot 56 | perturbnet.util.contourplot_space_mapping 57 | perturbnet.util.Seq_to_Embed_ESM 58 | 59 | -------------------------------------------------------------------------------- /docs/source/chemical_perturbation.rst: -------------------------------------------------------------------------------- 1 | Chemical Perturbation 2 | ===================== 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | 7 | tutorials/chemical_perturbation -------------------------------------------------------------------------------- /docs/source/coding_variant.rst: -------------------------------------------------------------------------------- 1 | Coding Variant 2 | ============== 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | 7 | tutorials/coding_variant -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import logging 4 | from datetime import datetime 5 | from pathlib import Path 6 | from sphinx.ext import autosummary 7 | 8 | import matplotlib 9 | matplotlib.use('agg') 10 | 11 | HERE = Path(__file__).parent 12 | sys.path.insert(0, str(HERE.parent)) 13 | sys.path.insert(0, os.path.abspath('../..')) 14 | import perturbnet 15 | 16 | 17 | # -- Project information ----------------------------------------------------- 18 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 19 | 20 | project = 'perturbnet' 21 | copyright = '2025, Hengshi Yu, Weizhou Qian, Yuxuan Song, Joshua Welch' 22 | author = 'Hengshi Yu, Weizhou Qian, Yuxuan Song, Joshua Welch' 23 | release = '0.0.3' 24 | 25 | 26 | 27 | # -- General configuration 28 | 29 | extensions = [ 30 | 'sphinx.ext.doctest', 31 | 'sphinx.ext.autodoc', 32 | 'sphinx.ext.autosummary', 33 | 'sphinx.ext.intersphinx', 34 | 'sphinx.ext.napoleon', 35 | 'sphinx.ext.githubpages', 36 | 'nbsphinx', 37 | 'nbsphinx_link' 38 | ] 39 | 40 | autosummary_generate = True 41 | # Napoleon settings 42 | napoleon_google_docstring = False 43 | napoleon_numpy_docstring = True 44 | napoleon_include_init_with_doc = False 45 | napoleon_include_private_with_doc = False 46 | napoleon_include_special_with_doc = True 47 | napoleon_use_admonition_for_examples = False 48 | napoleon_use_admonition_for_notes = False 49 | napoleon_use_admonition_for_references = False 50 | napoleon_use_ivar = False 51 | napoleon_use_param = True 52 | napoleon_use_rtype = False 53 | napoleon_preprocess_types = False 54 | napoleon_type_aliases = None 55 | napoleon_attr_annotations = True 56 | 57 | html_theme = "sphinx_rtd_theme" 58 | html_static_path = ['_static'] 59 | html_css_files = [ 60 | 'custom.css', 61 | ] 62 | 63 | intersphinx_mapping = { 64 | 'python': ('https://docs.python.org/3/', None), 65 | 'sphinx': ('https://www.sphinx-doc.org/en/master/', None), 66 | } 67 | intersphinx_disabled_domains = ['std'] 68 | 69 | templates_path = ['_templates'] 70 | 71 | # -- Options for HTML output 72 | 73 | html_theme = 
'sphinx_rtd_theme' 74 | 75 | # -- Options for EPUB output 76 | epub_show_urls = 'footnote' 77 | 78 | suppress_warnings = ["config.cache"] -------------------------------------------------------------------------------- /docs/source/feature_attribution.rst: -------------------------------------------------------------------------------- 1 | Feature Attribution 2 | =================== 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | 7 | tutorials/feature_attribution -------------------------------------------------------------------------------- /docs/source/genetic_perturbation.rst: -------------------------------------------------------------------------------- 1 | Genetic Perturbation 2 | ==================== 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | 7 | tutorials/genetic_perturbation -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | |GitHubStars| |PyPI| |PyPIDownloads| 2 | 3 | PerturbNet 4 | ======================== 5 | 6 | PerturbNet is a deep generative framework designed to model and predict shifts in cell state—defined as changes in overall gene expression—in response to diverse cellular perturbations. PerturbNet consists of three trainable components: a perturbation representation network, a cellular representation network, and a conditional normalizing flow. These components work together to embed perturbations and cell states into a latent spaces and to learn a flexible mapping from perturbation features to gene expression distributions. 7 | 8 | Given a perturbation of interest—such as gene knockdown, gene overexpression, sequence mutation, or drug treatment—PerturbNet predicts the resulting distribution of single-cell gene expression states. Currently, you can refer to the preprint `PerturbNet predicts single-cell responses to unseen chemical and genetic perturbations `_. We will submit an updated version of the paper soon. 9 | 10 | .. toctree:: 11 | :caption: Main 12 | :maxdepth: 1 13 | :hidden: 14 | 15 | usage 16 | api 17 | 18 | .. toctree:: 19 | :caption: Tutorial 20 | :maxdepth: 1 21 | :hidden: 22 | 23 | chemical_perturbation 24 | genetic_perturbation 25 | coding_variant 26 | feature_attribution 27 | GATA1_example 28 | 29 | 30 | 31 | 32 | .. |GitHubStars| image:: https://img.shields.io/github/stars/welch-lab/PerturbNet?logo=GitHub&color=yellow 33 | :target: https://github.com/welch-lab/PerturbNet/stargazers 34 | 35 | .. |PyPI| image:: https://img.shields.io/pypi/v/perturbnet?logo=PyPI 36 | :target: https://pypi.org/project/perturbnet/ 37 | 38 | .. 
|PyPIDownloads| image:: https://pepy.tech/badge/perturbnet 39 | :target: https://pepy.tech/project/perturbnet 40 | -------------------------------------------------------------------------------- /docs/source/tutorials/GATA1_example.nblink: -------------------------------------------------------------------------------- 1 | { 2 | "path": "./GATA1_prediction_analysis.ipynb" 3 | } 4 | -------------------------------------------------------------------------------- /docs/source/tutorials/chemical_perturbation.nblink: -------------------------------------------------------------------------------- 1 | { 2 | "path": "./Tutorial_PerturbNet_Chemicals.ipynb" 3 | } 4 | -------------------------------------------------------------------------------- /docs/source/tutorials/coding_variant.nblink: -------------------------------------------------------------------------------- 1 | { 2 | "path": "./Tutorial_PerturbNet_coding_variants.ipynb" 3 | } 4 | -------------------------------------------------------------------------------- /docs/source/tutorials/feature_attribution.nblink: -------------------------------------------------------------------------------- 1 | { 2 | "path": "./Integrated_gradients_example.ipynb" 3 | } 4 | -------------------------------------------------------------------------------- /docs/source/tutorials/genetic_perturbation.nblink: -------------------------------------------------------------------------------- 1 | { 2 | "path": "./Tutorial_PerturbNet_Genetic.ipynb" 3 | } 4 | -------------------------------------------------------------------------------- /docs/source/usage.rst: -------------------------------------------------------------------------------- 1 | Usage 2 | ===== 3 | 4 | .. _installation: 5 | 6 | Installation 7 | ------------ 8 | 9 | The current version of PerturbNet requires Python 3.7. We recommend creating a clean Conda environment using the following command: 10 | 11 | .. code-block:: console 12 | 13 | $ conda create -n "PerturbNet" python=3.7 14 | 15 | After setting up the environment, you can install the package by running: 16 | 17 | .. code-block:: console 18 | 19 | $ conda activate PerturbNet 20 | $ pip install --upgrade PerturbNet 21 | 22 | 23 | We used cuDNN 8.7.0 (cudnn/11.7-v8.7.0) and CUDA 11.7.1 for model training. 24 | 25 | We also provide an updated version that removes the dependency on TensorFlow by using Python 3.10. To install: 26 | 27 | .. code-block:: console 28 | 29 | $ conda create -n "PerturbNet" python=3.10 30 | $ conda activate PerturbNet 31 | $ pip install PerturbNet==0.0.3b1 32 | 33 | 34 | 35 | Data and Model Availability 36 | --------------------------- 37 | 38 | The required data, toy examples, and model weights can be downloaded from `Hugging Face <https://huggingface.co/cyclopeta/PerturbNet_reproduce/tree/main>`_.
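As an optional convenience (PerturbNet itself does not depend on it), the same assets can also be fetched programmatically with the ``huggingface_hub`` client; the destination folder below is just an example:

.. code-block:: console

    $ pip install huggingface_hub
    $ python -c "from huggingface_hub import snapshot_download; snapshot_download(repo_id='cyclopeta/PerturbNet_reproduce', local_dir='PerturbNet_assets')"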
39 | 40 | 41 | -------------------------------------------------------------------------------- /perturbnet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/welch-lab/PerturbNet/96f38d8e2629cc4416c0f4c8e2051b16ec2a4816/perturbnet/__init__.py -------------------------------------------------------------------------------- /perturbnet/baselines/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/welch-lab/PerturbNet/96f38d8e2629cc4416c0f4c8e2051b16ec2a4816/perturbnet/baselines/__init__.py -------------------------------------------------------------------------------- /perturbnet/baselines/knn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | import numpy as np 4 | import pandas as pd 5 | from sklearn.decomposition import PCA 6 | 7 | import umap 8 | from plotnine import * 9 | 10 | import matplotlib 11 | matplotlib.use('agg') 12 | import matplotlib.pyplot as plt 13 | 14 | class samplefromNeighbors: 15 | """ 16 | KNN sampling module 17 | """ 18 | def __init__(self, distances, other_trts): 19 | super().__init__() 20 | self.distances = distances 21 | self.other_trts = other_trts 22 | self.softmax() 23 | self.pca_50 = PCA(n_components=50, random_state = 42) 24 | 25 | def softmax(self): 26 | prob_array = np.exp(-self.distances) 27 | prob_sum = prob_array.sum() 28 | prob_array /= prob_sum 29 | self.prob_array = prob_array 30 | 31 | def samplingCTrt(self, list_trt, cell_type, list_c_trt, n_sample = 300): 32 | 33 | needed_trts = np.random.choice(self.other_trts.squeeze(), n_sample, replace = True, p = self.prob_array.squeeze()) 34 | trt_sample_array = np.array(list_trt)[needed_trts] 35 | trt_sample_pd = pd.Series(trt_sample_array).value_counts() 36 | 37 | idx_sample = None 38 | for t in range(len(list(trt_sample_pd.keys()))): 39 | trt_sample = list(trt_sample_pd.keys())[t] 40 | trt_sample_count = trt_sample_pd.values[t] 41 | ctrt_sample = cell_type + "_" + trt_sample 42 | 43 | idx_sample_type = [i for i in range(len(list_c_trt)) if list_c_trt[i] == ctrt_sample] 44 | idx_ctrt_sample = np.random.choice(idx_sample_type, trt_sample_count, replace = True) 45 | 46 | if idx_sample is None: 47 | idx_sample = idx_ctrt_sample 48 | else: 49 | idx_sample = np.append(idx_sample, idx_ctrt_sample) 50 | 51 | return idx_sample 52 | 53 | def samplingTrt(self, list_trt, list_data_trt, n_sample = 300): 54 | 55 | needed_trts = np.random.choice(self.other_trts.squeeze(), n_sample, replace = True, p = self.prob_array.squeeze()) 56 | trt_sample_array = np.array(list_trt)[needed_trts] 57 | trt_sample_pd = pd.Series(trt_sample_array).value_counts() 58 | 59 | idx_sample = None 60 | for t in range(len(list(trt_sample_pd.keys()))): 61 | trt_sample = list(trt_sample_pd.keys())[t] 62 | trt_sample_count = trt_sample_pd.values[t] 63 | ctrt_sample = trt_sample 64 | 65 | idx_sample_type = [i for i in range(len(list_data_trt)) if list_data_trt[i] == ctrt_sample] 66 | idx_ctrt_sample = np.random.choice(idx_sample_type, trt_sample_count, replace = True) 67 | 68 | if idx_sample is None: 69 | idx_sample = idx_ctrt_sample 70 | else: 71 | idx_sample = np.append(idx_sample, idx_ctrt_sample) 72 | 73 | return idx_sample 74 | 75 | def PlotUMAP(self, real_data, fake_data, path_file_save): 76 | 77 | all_data = np.concatenate([fake_data, real_data], axis = 0) 78 | pca_all = 
self.pca_50.fit(real_data).transform(all_data) 79 | pca_result_real = pca_all[fake_data.shape[0]:] 80 | 81 | cat_t = ["1-Real"] * real_data.shape[0] 82 | cat_g = ["2-KNN-Sampled"] * fake_data.shape[0] 83 | cat_rf_gt = np.append(cat_g, cat_t) 84 | 85 | trans = umap.UMAP(random_state=42, min_dist = 0.5, n_neighbors=30).fit(pca_result_real) 86 | 87 | X_embedded_pr = trans.transform(pca_all) 88 | df_tsne_pr = X_embedded_pr.copy() 89 | df_tsne_pr = pd.DataFrame(df_tsne_pr) 90 | df_tsne_pr['x-umap'] = X_embedded_pr[:,0] 91 | df_tsne_pr['y-umap'] = X_embedded_pr[:,1] 92 | df_tsne_pr['category'] = cat_rf_gt 93 | 94 | chart_pr = ggplot(df_tsne_pr, aes(x= 'x-umap', y= 'y-umap', colour = 'category') ) \ 95 | + geom_point(size=0.5, alpha = 0.8) \ 96 | + ggtitle("UMAP dimensions") 97 | chart_pr.save(path_file_save, width=12, height=8, dpi=144) 98 | 99 | 100 | class samplefromNeighborsCINN: 101 | """ 102 | KNN sampling module within a constrained list 103 | """ 104 | def __init__(self, distances, other_trts): 105 | super().__init__() 106 | self.distances = distances 107 | self.other_trts = other_trts 108 | self.softmax() 109 | self.pca_50 = PCA(n_components=50, random_state=42) 110 | 111 | def softmax(self): 112 | prob_array = np.exp(-self.distances) 113 | prob_sum = prob_array.sum() 114 | prob_array /= prob_sum 115 | self.prob_array = prob_array 116 | 117 | def samplingTrt(self, data_onehot, n_sample=300): 118 | 119 | needed_trts = np.random.choice(self.other_trts.squeeze(), n_sample, replace=True, p=self.prob_array.squeeze()) 120 | onehot_data = data_onehot[needed_trts] 121 | 122 | return onehot_data 123 | 124 | def samplingTrtList(self, data_onehot, list_trt, list_data_trt, n_sample=300): 125 | list_trt_pd = pd.Series(list_trt) 126 | list_data_trt_pd = pd.Series(list_data_trt) 127 | 128 | indices = list_data_trt_pd.map(lambda x: np.where(list_trt_pd == x)[0][0]).tolist() 129 | onehot_data = data_onehot[indices] 130 | 131 | return onehot_data 132 | 133 | 134 | 135 | 136 | 137 | class samplefromNeighborsGenotype: 138 | """ 139 | KNN sampling module for genetic perturbations, especially for these with multiple target genes 140 | """ 141 | def __init__(self, distances, other_trts): 142 | super().__init__() 143 | self.distances = distances 144 | self.other_trts = other_trts 145 | self.softmax() 146 | self.pca_50 = PCA(n_components=50, random_state = 42) 147 | 148 | 149 | def softmax(self): 150 | prob_array = np.exp(-self.distances) 151 | prob_sum = prob_array.sum() 152 | prob_array /= prob_sum 153 | self.prob_array = prob_array 154 | 155 | def samplingCTrt(self, list_trt, cell_type, list_c_trt, n_sample = 300): 156 | 157 | needed_trts = np.random.choice(self.other_trts.squeeze(), n_sample, replace = True, p = self.prob_array.squeeze()) 158 | trt_sample_array = np.array(list_trt)[needed_trts] 159 | trt_sample_pd = pd.Series(trt_sample_array).value_counts() 160 | 161 | idx_sample = None 162 | for t in range(len(list(trt_sample_pd.keys()))): 163 | trt_sample = list(trt_sample_pd.keys())[t] 164 | trt_sample_count = trt_sample_pd.values[t] 165 | ctrt_sample = cell_type + "_" + trt_sample 166 | 167 | idx_sample_type = [i for i in range(len(list_c_trt)) if list_c_trt[i] == ctrt_sample] 168 | idx_ctrt_sample = np.random.choice(idx_sample_type, trt_sample_count, replace = True) 169 | 170 | if idx_sample is None: 171 | idx_sample = idx_ctrt_sample 172 | else: 173 | idx_sample = np.append(idx_sample, idx_ctrt_sample) 174 | 175 | return idx_sample 176 | 177 | def samplingTrt(self, list_trt, list_data_trt, n_sample = 
300): 178 | 179 | needed_trts = np.random.choice(self.other_trts.squeeze(), n_sample, replace = True, p = self.prob_array.squeeze()) 180 | trt_sample_array = np.array(list_trt)[needed_trts] 181 | trt_sample_pd = pd.Series(trt_sample_array).value_counts() 182 | 183 | idx_sample = None 184 | for t in range(len(list(trt_sample_pd.keys()))): 185 | trt_sample = list(trt_sample_pd.keys())[t] 186 | trt_sample_count = trt_sample_pd.values[t] 187 | ctrt_sample = trt_sample 188 | trt_sample1, trt_sample2 = trt_sample.split('/') 189 | ctrt_sample_other = trt_sample2 + '/' + trt_sample1 190 | 191 | idx_sample_type = [i for i in range(len(list_data_trt)) if list_data_trt[i] in [ctrt_sample, ctrt_sample_other]] 192 | idx_ctrt_sample = np.random.choice(idx_sample_type, trt_sample_count, replace = True) 193 | 194 | if idx_sample is None: 195 | idx_sample = idx_ctrt_sample 196 | else: 197 | idx_sample = np.append(idx_sample, idx_ctrt_sample) 198 | 199 | return idx_sample 200 | 201 | def PlotUMAP(self, real_data, fake_data, path_file_save): 202 | 203 | all_data = np.concatenate([fake_data, real_data], axis = 0) 204 | pca_all = self.pca_50.fit(real_data).transform(all_data) 205 | pca_result_real = pca_all[fake_data.shape[0]:] 206 | 207 | cat_t = ["1-Real"] * real_data.shape[0] 208 | cat_g = ["2-KNN-Sampled"] * fake_data.shape[0] 209 | cat_rf_gt = np.append(cat_g, cat_t) 210 | 211 | trans = umap.UMAP(random_state=42, min_dist = 0.5, n_neighbors=30).fit(pca_result_real) 212 | 213 | X_embedded_pr = trans.transform(pca_all) 214 | df_tsne_pr = X_embedded_pr.copy() 215 | df_tsne_pr = pd.DataFrame(df_tsne_pr) 216 | df_tsne_pr['x-umap'] = X_embedded_pr[:,0] 217 | df_tsne_pr['y-umap'] = X_embedded_pr[:,1] 218 | df_tsne_pr['category'] = cat_rf_gt 219 | 220 | chart_pr = ggplot(df_tsne_pr, aes(x= 'x-umap', y= 'y-umap', colour = 'category') ) \ 221 | + geom_point(size=0.5, alpha = 0.8) \ 222 | + ggtitle("UMAP dimensions") 223 | chart_pr.save(path_file_save, width=12, height=8, dpi=144) 224 | 225 | 226 | class samplefromNeighborsGenotypeCINN: 227 | """ 228 | KNN sampling module within a constrained list for genetic perturbations, 229 | especially for these with multiple target genes 230 | """ 231 | def __init__(self, distances, other_trts): 232 | super().__init__() 233 | self.distances = distances 234 | self.other_trts = other_trts 235 | self.softmax() 236 | self.pca_50 = PCA(n_components=50, random_state=42) 237 | 238 | def softmax(self): 239 | prob_array = np.exp(-self.distances) 240 | prob_sum = prob_array.sum() 241 | prob_array /= prob_sum 242 | self.prob_array = prob_array 243 | 244 | 245 | def samplingTrt(self, data_onehot, n_sample=300): 246 | 247 | needed_trts = np.random.choice(self.other_trts.squeeze(), n_sample, replace=True, p=self.prob_array.squeeze()) 248 | onehot_data = data_onehot[needed_trts] 249 | 250 | return onehot_data 251 | 252 | def samplingTrtList(self, data_onehot, perturbToOnehotLib, list_data_trt, n_sample=300): 253 | list_data_trt_pd = pd.Series(list_data_trt) 254 | 255 | indices = list_data_trt_pd.map(lambda x: perturbToOnehotLib[x]).tolist() 256 | onehot_data = data_onehot[indices] 257 | 258 | return onehot_data 259 | -------------------------------------------------------------------------------- /perturbnet/baselines/linear.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | 5 | 6 | 7 | def solve_y_axb(Y, A=None, B=None, A_ridge=0.1, B_ridge=0.1): 8 | if not isinstance(Y, np.ndarray): 9 | raise ValueError("Y must be a numpy 
array or matrix.") 10 | 11 | if A is not None and not isinstance(A, np.ndarray): 12 | raise ValueError("A must be None or a numpy array.") 13 | if B is not None and not isinstance(B, np.ndarray): 14 | raise ValueError("B must be None or a numpy array.") 15 | 16 | center = np.mean(Y, axis=1, keepdims=True) 17 | Y = Y - center 18 | 19 | if A is not None and B is not None: 20 | if Y.shape[0] != A.shape[0]: 21 | raise ValueError("Number of rows of Y must be equal to number of rows of A.") 22 | if Y.shape[1] != B.shape[1]: 23 | raise ValueError("Number of columns of Y must be equal to number of columns of B.") 24 | 25 | tmp = np.linalg.inv(A.T @ A + np.eye(A.shape[1]) * A_ridge) @ A.T @ Y @ B.T @ np.linalg.inv(B @ B.T + np.eye(B.shape[0]) * B_ridge) 26 | 27 | elif B is None: 28 | tmp = np.linalg.inv(A.T @ A + np.eye(A.shape[1]) * A_ridge) @ A.T @ Y 29 | 30 | elif A is None: 31 | tmp = Y @ B.T @ np.linalg.inv(B @ B.T + np.eye(B.shape[0]) * B_ridge) 32 | 33 | else: 34 | raise ValueError("Either A or B must be non-null") 35 | tmp = np.nan_to_num(tmp) 36 | 37 | return {"K": tmp, "center": center} -------------------------------------------------------------------------------- /perturbnet/chemicalvae/README.md: -------------------------------------------------------------------------------- 1 | # List of Files 2 | 3 | - `chemicalVAE.py` has the modules of training ChemicalVAE and fine-tuning it with the regularization term on the Laplacian matrix L; 4 | 5 | - `train_LINCSDrug.py` has the implementation of training ChemicalVAE with epoch defined from LINCS-Drug; 6 | 7 | - `generateKNNFIDMatrix.py` has the implementation of computing the Laplacian matrix of perturbations; 8 | 9 | - `fineTuneZLZ_LINCSDrug.py` has the implementation of fine-tuning ChemicalVAE with the Laplacian matrix; 10 | -------------------------------------------------------------------------------- /perturbnet/chemicalvae/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/welch-lab/PerturbNet/96f38d8e2629cc4416c0f4c8e2051b16ec2a4816/perturbnet/chemicalvae/__init__.py -------------------------------------------------------------------------------- /perturbnet/cinn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/welch-lab/PerturbNet/96f38d8e2629cc4416c0f4c8e2051b16ec2a4816/perturbnet/cinn/__init__.py -------------------------------------------------------------------------------- /perturbnet/cinn/flow_generate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | import numpy as np 4 | import torch 5 | import matplotlib 6 | matplotlib.use('agg') 7 | import matplotlib.pyplot as plt 8 | 9 | 10 | class SCVIZ_CheckNet2Net: 11 | """Class to use PerturbNet to predict cellular representations and 12 | count responses (scVI) from perturbation representations 13 | """ 14 | def __init__(self, model, device, scvi_model_decode): 15 | super().__init__() 16 | self.model = model 17 | self.device = device 18 | self.scvi_model_decode = scvi_model_decode 19 | 20 | @torch.no_grad() 21 | def sample_data(self, condition_data, library_data, batch_size = 50): 22 | #torch.manual_seed(2021) 23 | sampled_z = self.model.sample_conditional(torch.tensor(condition_data)\ 24 | .float().to(self.device)).squeeze(-1).squeeze(-1).cpu().detach().numpy() 25 | 26 | sampled_data = 
self.scvi_model_decode.posterior_predictive_sample_from_Z(sampled_z, library_data, batch_size = batch_size) 27 | 28 | return sampled_z, sampled_data 29 | 30 | @torch.no_grad() 31 | def recon_v(self, latent, condition): 32 | recon_val, _ = self.model.flow(torch.tensor(latent).float().to(self.device).unsqueeze(-1).unsqueeze(-1), 33 | torch.tensor(condition).float().to(self.device)) 34 | return recon_val.squeeze(-1).squeeze(-1).cpu().detach().numpy() 35 | 36 | @torch.no_grad() 37 | def trans_data(self, latent, condition, condition_new, library_data, batch_size = 50): 38 | trans_z = self.model.generate_zprime(torch.tensor(latent).float().to(self.device).unsqueeze(-1).unsqueeze(-1), 39 | torch.tensor(condition).float().to(self.device), 40 | torch.tensor(condition_new).float().to(self.device) 41 | ).squeeze(-1).squeeze(-1).cpu().detach().numpy() 42 | 43 | trans_data = self.scvi_model_decode.posterior_predictive_sample_from_Z(trans_z, library_data, batch_size = batch_size) 44 | 45 | return trans_z, trans_data 46 | 47 | @torch.no_grad() 48 | def recon_data(self, latent, condition, library_data, batch_size = 50): 49 | 50 | rec_z = self.model.generate_zrec(torch.tensor(latent).float().to(self.device).unsqueeze(-1).unsqueeze(-1), 51 | torch.tensor(condition).float().to(self.device))\ 52 | .squeeze(-1).squeeze(-1).cpu().detach().numpy() 53 | 54 | rec_data = self.scvi_model_decode.posterior_predictive_sample_from_Z(rec_z, library_data, batch_size = batch_size) 55 | 56 | return rec_z, rec_data 57 | 58 | @torch.no_grad() 59 | def recon_data_with_y(self, latent, condition, library_data, y_data, batch_size = 50): 60 | 61 | rec_z = self.model.generate_zrec(torch.tensor(latent).float().to(self.device).unsqueeze(-1).unsqueeze(-1), 62 | torch.tensor(condition).float().to(self.device))\ 63 | .squeeze(-1).squeeze(-1).cpu().detach().numpy() 64 | 65 | rec_data = self.scvi_model_decode.posterior_predictive_sample_from_Z_with_y(rec_z, library_data, y_data, batch_size = batch_size) 66 | 67 | return rec_z, rec_data 68 | 69 | @torch.no_grad() 70 | def sample_data_with_y(self, condition_data, library_data, y_data, batch_size = 50): 71 | sampled_z = self.model.sample_conditional(torch.tensor(condition_data)\ 72 | .float().to(self.device)).squeeze(-1).squeeze(-1).cpu().detach().numpy() 73 | 74 | sampled_data = self.scvi_model_decode.posterior_predictive_sample_from_Z_with_y(sampled_z, library_data, y_data, batch_size = batch_size) 75 | 76 | return sampled_z, sampled_data 77 | 78 | @torch.no_grad() 79 | def trans_data_with_y(self, latent, condition, condition_new, library_data, y_data, batch_size = 50): 80 | trans_z = self.model.generate_zprime(torch.tensor(latent).float().to(self.device).unsqueeze(-1).unsqueeze(-1), 81 | torch.tensor(condition).float().to(self.device), 82 | torch.tensor(condition_new).float().to(self.device) 83 | ).squeeze(-1).squeeze(-1).cpu().detach().numpy() 84 | 85 | trans_data = self.scvi_model_decode.posterior_predictive_sample_from_Z_with_y(trans_z, library_data, y_data, batch_size = batch_size) 86 | 87 | return trans_z, trans_data 88 | 89 | 90 | @torch.no_grad() 91 | def recon_data_with_batch(self, latent, condition, library_data, batch_data, batch_size = 50): 92 | 93 | rec_z = self.model.generate_zrec(torch.tensor(latent).float().to(self.device).unsqueeze(-1).unsqueeze(-1), 94 | torch.tensor(condition).float().to(self.device))\ 95 | .squeeze(-1).squeeze(-1).cpu().detach().numpy() 96 | 97 | rec_data = self.scvi_model_decode.posterior_predictive_sample_from_Z_with_batch(rec_z, library_data, 
batch_data, batch_size = batch_size) 98 | 99 | return rec_z, rec_data 100 | 101 | @torch.no_grad() 102 | def sample_data_with_batch(self, condition_data, library_data, batch_data, batch_size = 50): 103 | sampled_z = self.model.sample_conditional(torch.tensor(condition_data)\ 104 | .float().to(self.device)).squeeze(-1).squeeze(-1).cpu().detach().numpy() 105 | 106 | sampled_data = self.scvi_model_decode.posterior_predictive_sample_from_Z_with_batch(sampled_z, library_data, batch_data, batch_size = batch_size) 107 | 108 | return sampled_z, sampled_data 109 | 110 | @torch.no_grad() 111 | def trans_data_with_batch(self, latent, condition, condition_new, library_data, batch_data, batch_size = 50): 112 | trans_z = self.model.generate_zprime(torch.tensor(latent).float().to(self.device).unsqueeze(-1).unsqueeze(-1), 113 | torch.tensor(condition).float().to(self.device), 114 | torch.tensor(condition_new).float().to(self.device) 115 | ).squeeze(-1).squeeze(-1).cpu().detach().numpy() 116 | 117 | trans_data = self.scvi_model_decode.posterior_predictive_sample_from_Z_with_batch(trans_z, library_data, batch_data, batch_size = batch_size) 118 | 119 | return trans_z, trans_data 120 | 121 | 122 | 123 | class TFVAEZ_CheckNet2Net: 124 | """Class to use PerturbNet to predict cellular representations and 125 | normalized responses (VAE) from perturbation representations 126 | """ 127 | def __init__(self, model, device, tf_sess, tf_x_rec_data, tf_z_latent, is_training): 128 | super().__init__() 129 | self.model = model 130 | self.device = device 131 | self.tf_sess = tf_sess 132 | self.tf_x_rec_data = tf_x_rec_data 133 | self.tf_z_latent = tf_z_latent 134 | self.is_training = is_training 135 | 136 | @torch.no_grad() 137 | def sample_data(self, condition_data): 138 | 139 | sampled_z = self.model.sample_conditional(torch.tensor(condition_data)\ 140 | .float().to(self.device)).squeeze(-1).squeeze(-1).cpu().detach().numpy() 141 | 142 | feed_dict = {self.tf_z_latent: sampled_z, self.is_training: False} 143 | sampled_data = self.tf_sess.run(self.tf_x_rec_data, feed_dict = feed_dict) 144 | 145 | return sampled_z, sampled_data 146 | 147 | @torch.no_grad() 148 | def recon_v(self, latent, condition): 149 | recon_val, _ = self.model.flow(torch.tensor(latent).float().to(self.device).unsqueeze(-1).unsqueeze(-1), 150 | torch.tensor(condition).float().to(self.device)) 151 | return recon_val.squeeze(-1).squeeze(-1).cpu().detach().numpy() 152 | 153 | @torch.no_grad() 154 | def trans_data(self, latent, condition, condition_new): 155 | trans_z = self.model.generate_zprime(torch.tensor(latent).float().to(self.device).unsqueeze(-1).unsqueeze(-1), 156 | torch.tensor(condition).float().to(self.device), 157 | torch.tensor(condition_new).float().to(self.device) 158 | ).squeeze(-1).squeeze(-1).cpu().detach().numpy() 159 | 160 | feed_dict = {self.tf_z_latent: trans_z, self.is_training: False} 161 | trans_data = self.tf_sess.run(self.tf_x_rec_data, feed_dict = feed_dict) 162 | 163 | return trans_z, trans_data 164 | 165 | @torch.no_grad() 166 | def recon_data(self, latent, condition): 167 | 168 | rec_z = self.model.generate_zrec(torch.tensor(latent).float().to(self.device).unsqueeze(-1).unsqueeze(-1), 169 | torch.tensor(condition).float().to(self.device))\ 170 | .squeeze(-1).squeeze(-1).cpu().detach().numpy() 171 | 172 | feed_dict = {self.tf_z_latent: rec_z, self.is_training: False} 173 | rec_data = self.tf_sess.run(self.tf_x_rec_data, feed_dict = feed_dict) 174 | 175 | return rec_z, rec_data 176 | 177 | 178 | 
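# ---------------------------------------------------------------------------
# Illustrative usage sketch (documentation only; not called by the library).
# It assumes objects prepared as in the tutorial notebooks:
#   flow_model   - a trained cINN (see perturbnet.cinn.flow)
#   device       - the torch.device holding the flow model
#   scvi_decoder - a wrapper around a trained scVI model, as in
#                  perturbnet.pytorch_scvi.scvi_generate_z
#   z_observed   - latent representations of observed cells
#   pert_source / pert_target - perturbation representations of the observed
#                  and the new (target) perturbation, one row per cell
#   lib_sizes    - per-cell library sizes passed through to the scVI decoder
# ---------------------------------------------------------------------------
def example_translate_counts(flow_model, device, scvi_decoder,
                             z_observed, pert_source, pert_target,
                             lib_sizes, batch_size=50):
    """Translate observed cells from a source to a target perturbation and decode counts."""
    predictor = SCVIZ_CheckNet2Net(flow_model, device, scvi_decoder)
    trans_z, trans_counts = predictor.trans_data(z_observed, pert_source, pert_target,
                                                 lib_sizes, batch_size=batch_size)
    return trans_z, trans_counts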
-------------------------------------------------------------------------------- /perturbnet/data_vae/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/welch-lab/PerturbNet/96f38d8e2629cc4416c0f4c8e2051b16ec2a4816/perturbnet/data_vae/__init__.py -------------------------------------------------------------------------------- /perturbnet/genotypevae/README.md: -------------------------------------------------------------------------------- 1 | # List of Files 2 | 3 | - `genotypeVAE.py` has the modules of training GenotypeVAE and fine-tuning it with the regularization term on the Laplacian matrix L; 4 | 5 | - `train_LINCSGene.py` has the implementation of training GenotypeVAE with epoch defined from LINCS-Gene; 6 | 7 | - `train_ZINC.py` has the implementation of training GenotypeVAE with epoch defined from ZINC database 8 | 9 | - `generateKNNFIDMatrix.py` has the implementation of computing the Laplacian matrix of perturbations; 10 | 11 | - `fineTuneZLZ_LINCSGene.py` has the implementation of fine-tuning GenotypeVAE with the Laplacian matrix; 12 | 13 | 14 | -------------------------------------------------------------------------------- /perturbnet/genotypevae/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/welch-lab/PerturbNet/96f38d8e2629cc4416c0f4c8e2051b16ec2a4816/perturbnet/genotypevae/__init__.py -------------------------------------------------------------------------------- /perturbnet/net2net/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/welch-lab/PerturbNet/96f38d8e2629cc4416c0f4c8e2051b16ec2a4816/perturbnet/net2net/__init__.py -------------------------------------------------------------------------------- /perturbnet/net2net/ckpt_util.py: -------------------------------------------------------------------------------- 1 | import os, hashlib 2 | import requests 3 | from tqdm import tqdm 4 | 5 | URL_MAP = { 6 | "biggan_128": "https://heibox.uni-heidelberg.de/f/56ed256209fd40968864/?dl=1", 7 | "biggan_256": "https://heibox.uni-heidelberg.de/f/437b501944874bcc92a4/?dl=1", 8 | "dequant_vae": "https://heibox.uni-heidelberg.de/f/e7c8959b50a64f40826e/?dl=1", 9 | "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1", 10 | "coco_captioner": "https://heibox.uni-heidelberg.de/f/b03aae864a0f42f1a2c3/?dl=1", 11 | "coco_word_map": "https://heibox.uni-heidelberg.de/f/1518aa8461d94e0cb3eb/?dl=1" 12 | } 13 | 14 | CKPT_MAP = { 15 | "biggan_128": "biggan-128.pth", 16 | "biggan_256": "biggan-256.pth", 17 | "dequant_vae": "dequantvae-20000.ckpt", 18 | "vgg_lpips": "autoencoders/lpips/vgg.pth", 19 | "coco_captioner": "captioning_model_pt16.ckpt", 20 | } 21 | 22 | MD5_MAP = { 23 | "biggan_128": "a2148cf64807444113fac5eede060d28", 24 | "biggan_256": "e23db3caa34ac4c4ae922a75258dcb8d", 25 | "dequant_vae": "5c2a6fe765142cbdd9f10f15d65a68b6", 26 | "vgg_lpips": "d507d7349b931f0638a25a48a722f98a", 27 | "coco_captioner": "db185e0f6791e60d27c00de0f40c376c", 28 | } 29 | 30 | 31 | def download(url, local_path, chunk_size=1024): 32 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 33 | with requests.get(url, stream=True) as r: 34 | total_size = int(r.headers.get("content-length", 0)) 35 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 36 | with open(local_path, "wb") as f: 37 | for data in r.iter_content(chunk_size=chunk_size): 38 
| if data: 39 | f.write(data) 40 | pbar.update(chunk_size) 41 | 42 | 43 | def md5_hash(path): 44 | with open(path, "rb") as f: 45 | content = f.read() 46 | return hashlib.md5(content).hexdigest() 47 | 48 | 49 | def get_ckpt_path(name, root, check=False): 50 | assert name in URL_MAP 51 | path = os.path.join(root, CKPT_MAP[name]) 52 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): 53 | print("Downloading {} from {} to {}".format(name, URL_MAP[name], path)) 54 | download(URL_MAP[name], path) 55 | md5 = md5_hash(path) 56 | assert md5 == MD5_MAP[name], md5 57 | return path 58 | -------------------------------------------------------------------------------- /perturbnet/net2net/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/welch-lab/PerturbNet/96f38d8e2629cc4416c0f4c8e2051b16ec2a4816/perturbnet/net2net/data/__init__.py -------------------------------------------------------------------------------- /perturbnet/net2net/data/base.py: -------------------------------------------------------------------------------- 1 | import os, bisect 2 | import numpy as np 3 | import albumentations 4 | from PIL import Image 5 | from torch.utils.data import Dataset, ConcatDataset 6 | 7 | 8 | class ConcatDatasetWithIndex(ConcatDataset): 9 | """Modified from original pytorch code to return dataset idx""" 10 | def __getitem__(self, idx): 11 | if idx < 0: 12 | if -idx > len(self): 13 | raise ValueError("absolute value of index should not exceed dataset length") 14 | idx = len(self) + idx 15 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) 16 | if dataset_idx == 0: 17 | sample_idx = idx 18 | else: 19 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] 20 | return self.datasets[dataset_idx][sample_idx], dataset_idx 21 | 22 | 23 | class ImagePaths(Dataset): 24 | def __init__(self, paths, size=None, random_crop=False): 25 | self.size = size 26 | self.random_crop = random_crop 27 | 28 | self.labels = dict() 29 | self.labels["file_path_"] = paths 30 | self._length = len(paths) 31 | 32 | if self.size is not None and self.size > 0: 33 | self.rescaler = albumentations.SmallestMaxSize(max_size = self.size) 34 | if not self.random_crop: 35 | self.cropper = albumentations.CenterCrop(height=self.size,width=self.size) 36 | else: 37 | self.cropper = albumentations.RandomCrop(height=self.size,width=self.size) 38 | self.preprocessor = albumentations.Compose([self.rescaler, self.cropper]) 39 | else: 40 | self.preprocessor = lambda **kwargs: kwargs 41 | 42 | def __len__(self): 43 | return self._length 44 | 45 | def preprocess_image(self, image_path): 46 | image = Image.open(image_path) 47 | if not image.mode == "RGB": 48 | image = image.convert("RGB") 49 | image = np.array(image).astype(np.uint8) 50 | image = self.preprocessor(image=image)["image"] 51 | image = (image/127.5 - 1.0).astype(np.float32) 52 | return image 53 | 54 | def __getitem__(self, i): 55 | example = dict() 56 | example["image"] = self.preprocess_image(self.labels["file_path_"][i]) 57 | for k in self.labels: 58 | example[k] = self.labels[k][i] 59 | return example 60 | 61 | 62 | class NumpyPaths(ImagePaths): 63 | def preprocess_image(self, image_path): 64 | image = np.load(image_path).squeeze(0) # 3 x 1024 x 1024 65 | image = np.transpose(image, (1,2,0)) 66 | image = Image.fromarray(image, mode="RGB") 67 | image = np.array(image).astype(np.uint8) 68 | image = self.preprocessor(image=image)["image"] 69 | image = (image/127.5 - 
1.0).astype(np.float32) 70 | return image 71 | -------------------------------------------------------------------------------- /perturbnet/net2net/data/coco.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import albumentations 4 | import numpy as np 5 | from PIL import Image 6 | from tqdm import tqdm 7 | from torch.utils.data import Dataset 8 | 9 | 10 | class CocoBase(Dataset): 11 | """needed for (image, caption, segmentation) pairs""" 12 | def __init__(self, size=None, dataroot="", datajson="", onehot_segmentation=False, use_stuffthing=False, 13 | crop_size=None, force_no_crop=False): 14 | self.split = self.get_split() 15 | self.size = size 16 | if crop_size is None: 17 | self.crop_size = size 18 | else: 19 | self.crop_size = crop_size 20 | 21 | self.onehot = onehot_segmentation # return segmentation as rgb or one hot 22 | self.stuffthing = use_stuffthing # include thing in segmentation 23 | if self.onehot and not self.stuffthing: 24 | raise NotImplemented("One hot mode is only supported for the " 25 | "stuffthings version because labels are stored " 26 | "a bit different.") 27 | 28 | data_json = datajson 29 | with open(data_json) as json_file: 30 | self.json_data = json.load(json_file) 31 | self.img_id_to_captions = dict() 32 | self.img_id_to_filepath = dict() 33 | self.img_id_to_segmentation_filepath = dict() 34 | 35 | assert data_json.split("/")[-1] in ["captions_train2017.json", 36 | "captions_val2017.json"] 37 | 38 | if self.stuffthing: 39 | self.segmentation_prefix = ( 40 | "data/cocostuffthings/val2017" if 41 | data_json.endswith("captions_val2017.json") else 42 | "data/cocostuffthings/train2017") 43 | else: 44 | self.segmentation_prefix = ( 45 | "data/coco/annotations/stuff_val2017_pixelmaps" if 46 | data_json.endswith("captions_val2017.json") else 47 | "data/coco/annotations/stuff_train2017_pixelmaps") 48 | 49 | imagedirs = self.json_data["images"] 50 | self.labels = {"image_ids": list()} 51 | for imgdir in tqdm(imagedirs, desc="ImgToPath"): 52 | self.img_id_to_filepath[imgdir["id"]] = os.path.join(dataroot, imgdir["file_name"]) 53 | self.img_id_to_captions[imgdir["id"]] = list() 54 | pngfilename = imgdir["file_name"].replace("jpg", "png") 55 | self.img_id_to_segmentation_filepath[imgdir["id"]] = os.path.join( 56 | self.segmentation_prefix, pngfilename) 57 | self.labels["image_ids"].append(imgdir["id"]) 58 | 59 | capdirs = self.json_data["annotations"] 60 | for capdir in tqdm(capdirs, desc="ImgToCaptions"): 61 | # there are in average 5 captions per image 62 | self.img_id_to_captions[capdir["image_id"]].append(np.array([capdir["caption"]])) 63 | 64 | self.rescaler = albumentations.SmallestMaxSize(max_size=self.size) 65 | if self.split=="validation": 66 | self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size) 67 | else: 68 | self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size) 69 | self.preprocessor = albumentations.Compose( 70 | [self.rescaler, self.cropper], 71 | additional_targets={"segmentation": "image"}) 72 | if force_no_crop: 73 | self.rescaler = albumentations.Resize(height=self.size, width=self.size) 74 | self.preprocessor = albumentations.Compose( 75 | [self.rescaler], 76 | additional_targets={"segmentation": "image"}) 77 | 78 | def __len__(self): 79 | return len(self.labels["image_ids"]) 80 | 81 | def preprocess_image(self, image_path, segmentation_path): 82 | image = Image.open(image_path) 83 | if not image.mode == "RGB": 84 | image = 
image.convert("RGB") 85 | image = np.array(image).astype(np.uint8) 86 | 87 | segmentation = Image.open(segmentation_path) 88 | if not self.onehot and not segmentation.mode == "RGB": 89 | segmentation = segmentation.convert("RGB") 90 | segmentation = np.array(segmentation).astype(np.uint8) 91 | if self.onehot: 92 | assert self.stuffthing 93 | # stored in caffe format: unlabeled==255. stuff and thing from 94 | # 0-181. to be compatible with the labels in 95 | # https://github.com/nightrome/cocostuff/blob/master/labels.txt 96 | # we shift stuffthing one to the right and put unlabeled in zero 97 | # as long as segmentation is uint8 shifting to right handles the 98 | # latter too 99 | assert segmentation.dtype == np.uint8 100 | segmentation = segmentation + 1 101 | 102 | processed = self.preprocessor(image=image, segmentation=segmentation) 103 | image, segmentation = processed["image"], processed["segmentation"] 104 | image = (image / 127.5 - 1.0).astype(np.float32) 105 | 106 | if self.onehot: 107 | assert segmentation.dtype == np.uint8 108 | # make it one hot 109 | n_labels = 183 110 | flatseg = np.ravel(segmentation) 111 | onehot = np.zeros((flatseg.size, n_labels), dtype=np.bool) 112 | onehot[np.arange(flatseg.size), flatseg] = True 113 | onehot = onehot.reshape(segmentation.shape + (n_labels,)).astype(int) 114 | segmentation = onehot 115 | else: 116 | segmentation = (segmentation / 127.5 - 1.0).astype(np.float32) 117 | return image, segmentation 118 | 119 | def __getitem__(self, i): 120 | img_path = self.img_id_to_filepath[self.labels["image_ids"][i]] 121 | seg_path = self.img_id_to_segmentation_filepath[self.labels["image_ids"][i]] 122 | image, segmentation = self.preprocess_image(img_path, seg_path) 123 | captions = self.img_id_to_captions[self.labels["image_ids"][i]] 124 | # randomly draw one of all available captions per image 125 | caption = captions[np.random.randint(0, len(captions))] 126 | example = {"image": image, 127 | "caption": [str(caption[0])], 128 | "segmentation": segmentation, 129 | "img_path": img_path, 130 | "seg_path": seg_path 131 | } 132 | return example 133 | 134 | 135 | class CocoImagesAndCaptionsTrain(CocoBase): 136 | """returns a pair of (image, caption)""" 137 | def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False): 138 | super().__init__(size=size, 139 | dataroot="data/coco/train2017", 140 | datajson="data/coco/annotations/captions_train2017.json", 141 | onehot_segmentation=onehot_segmentation, 142 | use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop) 143 | 144 | def get_split(self): 145 | return "train" 146 | 147 | 148 | class CocoImagesAndCaptionsValidation(CocoBase): 149 | """returns a pair of (image, caption)""" 150 | def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False): 151 | super().__init__(size=size, 152 | dataroot="data/coco/val2017", 153 | datajson="data/coco/annotations/captions_val2017.json", 154 | onehot_segmentation=onehot_segmentation, 155 | use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop) 156 | 157 | def get_split(self): 158 | return "validation" 159 | 160 | -------------------------------------------------------------------------------- /perturbnet/net2net/data/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import urllib 4 | import tarfile, zipfile 5 | from pathlib import Path 6 | from 
tqdm import tqdm 7 | 8 | 9 | def unpack(path): 10 | if path.endswith("tar.gz"): 11 | with tarfile.open(path, "r:gz") as tar: 12 | tar.extractall(path=os.path.split(path)[0]) 13 | elif path.endswith("tar"): 14 | with tarfile.open(path, "r:") as tar: 15 | tar.extractall(path=os.path.split(path)[0]) 16 | elif path.endswith("zip"): 17 | with zipfile.ZipFile(path, "r") as f: 18 | f.extractall(path=os.path.split(path)[0]) 19 | else: 20 | raise NotImplementedError( 21 | "Unknown file extension: {}".format(os.path.splitext(path)[1]) 22 | ) 23 | 24 | 25 | def reporthook(bar): 26 | """tqdm progress bar for downloads.""" 27 | 28 | def hook(b=1, bsize=1, tsize=None): 29 | if tsize is not None: 30 | bar.total = tsize 31 | bar.update(b * bsize - bar.n) 32 | 33 | return hook 34 | 35 | 36 | def get_root(name): 37 | base = "data/" 38 | root = os.path.join(base, name) 39 | os.makedirs(root, exist_ok=True) 40 | return root 41 | 42 | 43 | def is_prepared(root): 44 | return Path(root).joinpath(".ready").exists() 45 | 46 | 47 | def mark_prepared(root): 48 | Path(root).joinpath(".ready").touch() 49 | 50 | 51 | def prompt_download(file_, source, target_dir, content_dir=None): 52 | targetpath = os.path.join(target_dir, file_) 53 | while not os.path.exists(targetpath): 54 | if content_dir is not None and os.path.exists( 55 | os.path.join(target_dir, content_dir) 56 | ): 57 | break 58 | print( 59 | "Please download '{}' from '{}' to '{}'.".format(file_, source, targetpath) 60 | ) 61 | if content_dir is not None: 62 | print( 63 | "Or place its content into '{}'.".format( 64 | os.path.join(target_dir, content_dir) 65 | ) 66 | ) 67 | input("Press Enter when done...") 68 | return targetpath 69 | 70 | 71 | def download_url(file_, url, target_dir): 72 | targetpath = os.path.join(target_dir, file_) 73 | os.makedirs(target_dir, exist_ok=True) 74 | with tqdm( 75 | unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=file_ 76 | ) as bar: 77 | urllib.request.urlretrieve(url, targetpath, reporthook=reporthook(bar)) 78 | return targetpath 79 | 80 | 81 | def download_urls(urls, target_dir): 82 | paths = dict() 83 | for fname, url in urls.items(): 84 | outpath = download_url(fname, url, target_dir) 85 | paths[fname] = outpath 86 | return paths 87 | 88 | 89 | def quadratic_crop(x, bbox, alpha=1.0): 90 | """bbox is xmin, ymin, xmax, ymax""" 91 | im_h, im_w = x.shape[:2] 92 | bbox = np.array(bbox, dtype=np.float32) 93 | bbox = np.clip(bbox, 0, max(im_h, im_w)) 94 | center = 0.5 * (bbox[0] + bbox[2]), 0.5 * (bbox[1] + bbox[3]) 95 | w = bbox[2] - bbox[0] 96 | h = bbox[3] - bbox[1] 97 | l = int(alpha * max(w, h)) 98 | l = max(l, 2) 99 | 100 | required_padding = -1 * min( 101 | center[0] - l, center[1] - l, im_w - (center[0] + l), im_h - (center[1] + l) 102 | ) 103 | required_padding = int(np.ceil(required_padding)) 104 | if required_padding > 0: 105 | padding = [ 106 | [required_padding, required_padding], 107 | [required_padding, required_padding], 108 | ] 109 | padding += [[0, 0]] * (len(x.shape) - 2) 110 | x = np.pad(x, padding, "reflect") 111 | center = center[0] + required_padding, center[1] + required_padding 112 | xmin = int(center[0] - l / 2) 113 | ymin = int(center[1] - l / 2) 114 | return np.array(x[ymin : ymin + l, xmin : xmin + l, ...]) 115 | -------------------------------------------------------------------------------- /perturbnet/net2net/data/zcodes.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from tqdm import tqdm 4 | from 
torch.utils.data import Dataset 5 | 6 | 7 | class PRNGMixin(object): 8 | """Adds a prng property which is a numpy RandomState which gets 9 | reinitialized whenever the pid changes to avoid synchronized sampling 10 | behavior when used in conjunction with multiprocessing.""" 11 | 12 | @property 13 | def prng(self): 14 | currentpid = os.getpid() 15 | if getattr(self, "_initpid", None) != currentpid: 16 | self._initpid = currentpid 17 | self._prng = np.random.RandomState() 18 | return self._prng 19 | 20 | 21 | class TrainSamples(Dataset, PRNGMixin): 22 | def __init__(self, n_samples, z_shape, n_classes, truncation=0): 23 | self.n_samples = n_samples 24 | self.z_shape = z_shape 25 | self.n_classes = n_classes 26 | self.truncation_threshold = truncation 27 | if self.truncation_threshold > 0: 28 | print("Applying truncation at level {}".format(self.truncation_threshold)) 29 | 30 | def __len__(self): 31 | return self.n_samples 32 | 33 | def __getitem__(self, i): 34 | z = self.prng.randn(*self.z_shape) 35 | if self.truncation_threshold > 0: 36 | for k, zi in enumerate(z): 37 | while abs(zi) > self.truncation_threshold: 38 | zi = self.prng.randn(1) 39 | z[k] = zi 40 | cls = self.prng.randint(self.n_classes) 41 | return {"z": z.astype(np.float32), "class": cls} 42 | 43 | 44 | class TestSamples(Dataset): 45 | def __init__(self, n_samples, z_shape, n_classes, truncation=0): 46 | self.prng = np.random.RandomState(1) 47 | self.n_samples = n_samples 48 | self.z_shape = z_shape 49 | self.n_classes = n_classes 50 | self.truncation_threshold = truncation 51 | if self.truncation_threshold > 0: 52 | print("Applying truncation at level {}".format(self.truncation_threshold)) 53 | self.zs = self.prng.randn(self.n_samples, *self.z_shape) 54 | if self.truncation_threshold > 0: 55 | print("Applying truncation at level {}".format(self.truncation_threshold)) 56 | ix = 0 57 | for z in tqdm(self.zs, desc="Truncation:"): 58 | for k, zi in enumerate(z): 59 | while abs(zi) > self.truncation_threshold: 60 | zi = self.prng.randn(1) 61 | z[k] = zi 62 | self.zs[ix] = z 63 | ix += 1 64 | print("Created truncated test data.") 65 | self.clss = self.prng.randint(self.n_classes, size=(self.n_samples,)) 66 | 67 | def __len__(self): 68 | return self.n_samples 69 | 70 | def __getitem__(self, i): 71 | return {"z": self.zs[i].astype(np.float32), "class": self.clss[i]} 72 | 73 | 74 | class RestrictedTrainSamples(Dataset, PRNGMixin): 75 | def __init__(self, n_samples, z_shape, truncation=0): 76 | index_path = "data/coco_imagenet_overlap_idx.txt" 77 | self.n_samples = n_samples 78 | self.z_shape = z_shape 79 | self.classes = np.loadtxt(index_path).astype(int) 80 | self.truncation_threshold = truncation 81 | if self.truncation_threshold > 0: 82 | print("Applying truncation at level {}".format(self.truncation_threshold)) 83 | 84 | def __len__(self): 85 | return self.n_samples 86 | 87 | def __getitem__(self, i): 88 | z = self.prng.randn(*self.z_shape) 89 | if self.truncation_threshold > 0: 90 | for k, zi in enumerate(z): 91 | while abs(zi) > self.truncation_threshold: 92 | zi = self.prng.randn(1) 93 | z[k] = zi 94 | cls = self.prng.choice(self.classes) 95 | return {"z": z.astype(np.float32), "class": cls} 96 | 97 | 98 | class RestrictedTestSamples(Dataset): 99 | def __init__(self, n_samples, z_shape, truncation=0): 100 | index_path = "data/coco_imagenet_overlap_idx.txt" 101 | 102 | self.prng = np.random.RandomState(1) 103 | self.n_samples = n_samples 104 | self.z_shape = z_shape 105 | 106 | self.classes = np.loadtxt(index_path).astype(int) 
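
The sampler classes above (`TrainSamples`, `TestSamples`) generate `(z, class)` pairs on the fly, optionally resampling each latent coordinate until it falls inside the truncation threshold. A minimal usage sketch, where the module path is taken from this repository layout and the shapes, sample counts, and batch size are illustrative assumptions:

```
# Hedged sketch: wrapping the latent samplers in a DataLoader.
# z_shape, n_classes, n_samples and batch_size are illustrative assumptions.
from torch.utils.data import DataLoader
from perturbnet.net2net.data.zcodes import TrainSamples

train_ds = TrainSamples(n_samples=50_000, z_shape=(140,), n_classes=1000, truncation=1.0)
loader = DataLoader(train_ds, batch_size=32, num_workers=0)

batch = next(iter(loader))
# batch["z"] has shape (32, 140); batch["class"] holds integer class indices.
```
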
107 | self.clss = self.prng.choice(self.classes, size=(self.n_samples,), replace=True) 108 | self.truncation_threshold = truncation 109 | self.zs = self.prng.randn(self.n_samples, *self.z_shape) 110 | if self.truncation_threshold > 0: 111 | print("Applying truncation at level {}".format(self.truncation_threshold)) 112 | ix = 0 113 | for z in tqdm(self.zs, desc="Truncation:"): 114 | for k, zi in enumerate(z): 115 | while abs(zi) > self.truncation_threshold: 116 | zi = self.prng.randn(1) 117 | z[k] = zi 118 | self.zs[ix] = z 119 | ix += 1 120 | print("Created truncated test data.") 121 | 122 | def __len__(self): 123 | return self.n_samples 124 | 125 | def __getitem__(self, i): 126 | return {"z": self.zs[i].astype(np.float32), "class": self.clss[i]} 127 | 128 | 129 | -------------------------------------------------------------------------------- /perturbnet/net2net/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/welch-lab/PerturbNet/96f38d8e2629cc4416c0f4c8e2051b16ec2a4816/perturbnet/net2net/models/__init__.py -------------------------------------------------------------------------------- /perturbnet/net2net/models/autoencoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytorch_lightning as pl 3 | 4 | from perturbnet.net2net.modules.distributions.distributions import DiagonalGaussianDistribution 5 | from translation import instantiate_from_config 6 | 7 | 8 | class BigAE(pl.LightningModule): 9 | def __init__(self, 10 | encoder_config, 11 | decoder_config, 12 | loss_config, 13 | ckpt_path=None, 14 | ignore_keys=[] 15 | ): 16 | super().__init__() 17 | self.encoder = instantiate_from_config(encoder_config) 18 | self.decoder = instantiate_from_config(decoder_config) 19 | self.loss = instantiate_from_config(loss_config) 20 | 21 | if ckpt_path is not None: 22 | print("Loading model from {}".format(ckpt_path)) 23 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 24 | 25 | def init_from_ckpt(self, path, ignore_keys=list()): 26 | try: 27 | sd = torch.load(path, map_location="cpu")["state_dict"] 28 | except KeyError: 29 | sd = torch.load(path, map_location="cpu") 30 | 31 | keys = list(sd.keys()) 32 | for k in keys: 33 | for ik in ignore_keys: 34 | if k.startswith(ik): 35 | print("Deleting key {} from state_dict.".format(k)) 36 | del sd[k] 37 | missing, unexpected = self.load_state_dict(sd, strict=False) 38 | if len(missing) > 0: 39 | print(f"Missing keys in state dict: {missing}") 40 | if len(unexpected) > 0: 41 | print(f"Unexpected keys in state dict: {unexpected}") 42 | 43 | def encode(self, x, return_mode=False): 44 | moments = self.encoder(x) 45 | posterior = DiagonalGaussianDistribution(moments, deterministic=False) 46 | if return_mode: 47 | return posterior.mode() 48 | return posterior.sample() 49 | 50 | def decode(self, z): 51 | if len(z.shape) == 4: 52 | z = z.squeeze(-1).squeeze(-1) 53 | return self.decoder(z) 54 | 55 | def forward(self, x): 56 | moments = self.encoder(x) 57 | posterior = DiagonalGaussianDistribution(moments) 58 | h = posterior.sample() 59 | reconstructions = self.decoder(h.squeeze(-1).squeeze(-1)) 60 | return reconstructions, posterior 61 | 62 | def get_last_layer(self): 63 | return getattr(self.decoder.decoder.colorize.module, 'weight_bar') 64 | 65 | def log_images(self, batch, split=""): 66 | log = dict() 67 | inputs = batch["image"].permute(0, 3, 1, 2) 68 | inputs = inputs.to(self.device) 69 | reconstructions, 
posterior = self(inputs) 70 | log["inputs"] = inputs 71 | log["reconstructions"] = reconstructions 72 | return log 73 | 74 | def training_step(self, batch, batch_idx, optimizer_idx): 75 | inputs = batch["image"].permute(0, 3, 1, 2) 76 | reconstructions, posterior = self(inputs) 77 | 78 | if optimizer_idx == 0: 79 | # train encoder+decoder+logvar 80 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 81 | last_layer=self.get_last_layer(), split="train") 82 | output = pl.TrainResult(minimize=aeloss, checkpoint_on=aeloss) 83 | output.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 84 | output.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) 85 | return output 86 | 87 | if optimizer_idx == 1: 88 | # train the discriminator 89 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 90 | last_layer=self.get_last_layer(), split="train") 91 | output = pl.TrainResult(minimize=discloss) 92 | output.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 93 | output.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) 94 | del output["checkpoint_on"] # NOTE pl currently sets checkpoint_on=minimize by default TODO 95 | return output 96 | 97 | def validation_step(self, batch, batch_idx): 98 | inputs = batch["image"].permute(0, 3, 1, 2) 99 | reconstructions, posterior = self(inputs) 100 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, 101 | last_layer=self.get_last_layer(), split="val") 102 | 103 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, 104 | last_layer=self.get_last_layer(), split="val") 105 | output = pl.EvalResult(checkpoint_on=aeloss) 106 | output.log_dict(log_dict_ae) 107 | output.log_dict(log_dict_disc) 108 | return output 109 | 110 | def configure_optimizers(self): 111 | lr = self.learning_rate 112 | opt_ae = torch.optim.Adam(list(self.encoder.parameters())+list(self.decoder.parameters())+[self.loss.logvar], 113 | lr=lr, betas=(0.5, 0.9)) 114 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)) 115 | return [opt_ae, opt_disc], [] 116 | 117 | def on_epoch_end(self): 118 | pass 119 | 120 | 121 | class BasicAE(pl.LightningModule): 122 | def __init__(self, ae_config, loss_config, ckpt_path=None, ignore_keys=[]): 123 | super().__init__() 124 | self.autoencoder = instantiate_from_config(ae_config) 125 | self.loss = instantiate_from_config(loss_config) 126 | if ckpt_path is not None: 127 | print("Loading model from {}".format(ckpt_path)) 128 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 129 | 130 | def init_from_ckpt(self, path, ignore_keys=list()): 131 | try: 132 | sd = torch.load(path, map_location="cpu")["state_dict"] 133 | except KeyError: 134 | sd = torch.load(path, map_location="cpu") 135 | 136 | keys = list(sd.keys()) 137 | for k in keys: 138 | for ik in ignore_keys: 139 | if k.startswith(ik): 140 | print("Deleting key {} from state_dict.".format(k)) 141 | del sd[k] 142 | self.load_state_dict(sd, strict=False) 143 | 144 | def forward(self, x): 145 | posterior = self.autoencoder.encode(x) 146 | h = posterior.sample() 147 | reconstructions = self.autoencoder.decode(h) 148 | return reconstructions, posterior 149 | 150 | def encode(self, x): 151 | posterior = self.autoencoder.encode(x) 152 | h = posterior.sample() 153 | return h 154 | 
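
`BigAE.encode` returns either a sample or the mode of the diagonal Gaussian posterior, and `decode` accepts the resulting latent (squeezing trailing singleton spatial dimensions). A minimal round-trip sketch, assuming `ae` is an already constructed and trained `BigAE` (it is not built here) and that inputs are images scaled to [-1, 1]:

```
import torch

def reconstruct(ae, x):
    """Round-trip a batch of (N, 3, H, W) images in [-1, 1] through BigAE.

    `ae` is assumed to be a constructed, trained BigAE; it is not created here.
    """
    with torch.no_grad():
        z = ae.encode(x, return_mode=True)  # deterministic: use the posterior mode
        return ae.decode(z)
```
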
155 | def get_last_layer(self): 156 | return self.autoencoder.get_last_layer() 157 | 158 | def log_images(self, batch, split=""): 159 | log = dict() 160 | inputs = batch["image"].permute(0, 3, 1, 2) 161 | inputs = inputs.to(self.device) 162 | reconstructions, posterior = self(inputs) 163 | log["inputs"] = inputs 164 | log["reconstructions"] = reconstructions 165 | return log 166 | 167 | def training_step(self, batch, batch_idx, optimizer_idx): 168 | inputs = batch["image"].permute(0, 3, 1, 2) 169 | reconstructions, posterior = self(inputs) 170 | 171 | if optimizer_idx == 0: 172 | # train encoder+decoder+logvar 173 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 174 | last_layer=self.get_last_layer(), split="train") 175 | output = pl.TrainResult(minimize=aeloss, checkpoint_on=aeloss) 176 | output.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 177 | output.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) 178 | return output 179 | 180 | if optimizer_idx == 1: 181 | # train the discriminator 182 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 183 | last_layer=self.get_last_layer(), split="train") 184 | output = pl.TrainResult(minimize=discloss) 185 | output.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 186 | output.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) 187 | del output["checkpoint_on"] # NOTE pl currently sets checkpoint_on=minimize by default TODO 188 | return output 189 | 190 | def validation_step(self, batch, batch_idx): 191 | inputs = batch["image"].permute(0, 3, 1, 2) 192 | reconstructions, posterior = self(inputs) 193 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, 194 | last_layer=self.get_last_layer(), split="val") 195 | 196 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, 197 | last_layer=self.get_last_layer(), split="val") 198 | output = pl.EvalResult(checkpoint_on=aeloss) 199 | output.log_dict(log_dict_ae) 200 | output.log_dict(log_dict_disc) 201 | return output 202 | 203 | def configure_optimizers(self): 204 | lr = self.learning_rate 205 | opt_ae = torch.optim.Adam(list(self.autoencoder.parameters())+[self.loss.logvar], 206 | lr=lr, betas=(0.5, 0.9)) 207 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)) 208 | return [opt_ae, opt_disc], [] 209 | -------------------------------------------------------------------------------- /perturbnet/net2net/models/flows/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/welch-lab/PerturbNet/96f38d8e2629cc4416c0f4c8e2051b16ec2a4816/perturbnet/net2net/models/flows/__init__.py -------------------------------------------------------------------------------- /perturbnet/net2net/models/flows/flow.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import pytorch_lightning as pl 4 | 5 | from translation import instantiate_from_config 6 | from perturbnet.net2net.modules.flow.loss import NLL 7 | from perturbnet.net2net.ckpt_util import get_ckpt_path 8 | from perturbnet.net2net.modules.util import log_txt_as_img 9 | 10 | 11 | def disabled_train(self, mode=True): 12 | """Overwrite model.train with this 
function to make sure train/eval mode 13 | does not change anymore.""" 14 | return self 15 | 16 | 17 | class Flow(pl.LightningModule): 18 | def __init__(self, flow_config): 19 | super().__init__() 20 | self.flow = instantiate_from_config(config=flow_config) 21 | self.loss = NLL() 22 | 23 | def forward(self, x): 24 | zz, logdet = self.flow(x) 25 | return zz, logdet 26 | 27 | def sample_like(self, query): 28 | z = self.flow.sample(query.shape[0], device=query.device).float() 29 | return z 30 | 31 | def shared_step(self, batch, batch_idx): 32 | x, labels = batch 33 | x = x.float() 34 | zz, logdet = self(x) 35 | loss, log_dict = self.loss(zz, logdet) 36 | return loss, log_dict 37 | 38 | def training_step(self, batch, batch_idx): 39 | loss, log_dict = self.shared_step(batch, batch_idx) 40 | output = pl.TrainResult(minimize=loss, checkpoint_on=loss) 41 | output.log_dict(log_dict, prog_bar=False, on_epoch=True) 42 | return output 43 | 44 | def validation_step(self, batch, batch_idx): 45 | loss, log_dict = self.shared_step(batch, batch_idx) 46 | output = pl.EvalResult(checkpoint_on=loss) 47 | output.log_dict(log_dict, prog_bar=False) 48 | 49 | x, _ = batch 50 | x = x.float() 51 | sample = self.sample_like(x) 52 | output.sample_like = sample 53 | output.input = x.clone() 54 | 55 | return output 56 | 57 | def configure_optimizers(self): 58 | opt = torch.optim.Adam((self.flow.parameters()),lr=self.learning_rate, betas=(0.5, 0.9)) 59 | return [opt], [] 60 | 61 | 62 | class Net2NetFlow(pl.LightningModule): 63 | def __init__(self, 64 | flow_config, 65 | first_stage_config, 66 | cond_stage_config, 67 | ckpt_path=None, 68 | ignore_keys=[], 69 | first_stage_key="image", 70 | cond_stage_key="image", 71 | interpolate_cond_size=-1 72 | ): 73 | super().__init__() 74 | self.init_first_stage_from_ckpt(first_stage_config) 75 | self.init_cond_stage_from_ckpt(cond_stage_config) 76 | self.flow = instantiate_from_config(config=flow_config) 77 | self.loss = NLL() 78 | self.first_stage_key = first_stage_key 79 | self.cond_stage_key = cond_stage_key 80 | self.interpolate_cond_size = interpolate_cond_size 81 | if ckpt_path is not None: 82 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 83 | 84 | def init_from_ckpt(self, path, ignore_keys=list()): 85 | sd = torch.load(path, map_location="cpu")["state_dict"] 86 | for k in sd.keys(): 87 | for ik in ignore_keys: 88 | if k.startswith(ik): 89 | self.print("Deleting key {} from state_dict.".format(k)) 90 | del sd[k] 91 | self.load_state_dict(sd, strict=False) 92 | print(f"Restored from {path}") 93 | 94 | def init_first_stage_from_ckpt(self, config): 95 | model = instantiate_from_config(config) 96 | model = model.eval() 97 | model.train = disabled_train 98 | self.first_stage_model = model 99 | 100 | def init_cond_stage_from_ckpt(self, config): 101 | model = instantiate_from_config(config) 102 | model = model.eval() 103 | model.train = disabled_train 104 | self.cond_stage_model = model 105 | 106 | def forward(self, x, c): 107 | c = self.encode_to_c(c) 108 | q = self.encode_to_z(x) 109 | zz, logdet = self.flow(q, c) 110 | return zz, logdet 111 | 112 | @torch.no_grad() 113 | def sample_conditional(self, c): 114 | z = self.flow.sample(c) 115 | return z 116 | 117 | @torch.no_grad() 118 | def encode_to_z(self, x): 119 | z = self.first_stage_model.encode(x).detach() 120 | return z 121 | 122 | @torch.no_grad() 123 | def encode_to_c(self, c): 124 | c = self.cond_stage_model.encode(c).detach() 125 | return c 126 | 127 | @torch.no_grad() 128 | def decode_to_img(self, z): 129 | x 
= self.first_stage_model.decode(z.detach()) 130 | return x 131 | 132 | @torch.no_grad() 133 | def log_images(self, batch, split=""): 134 | log = dict() 135 | x = self.get_input(self.first_stage_key, batch).to(self.device) 136 | xc = self.get_input(self.cond_stage_key, batch, is_conditioning=True) 137 | if self.cond_stage_key not in ["text", "caption"]: 138 | xc = xc.to(self.device) 139 | 140 | z = self.encode_to_z(x) 141 | c = self.encode_to_c(xc) 142 | 143 | zz, _ = self.flow(z, c) 144 | zrec = self.flow.reverse(zz, c) 145 | xrec = self.decode_to_img(zrec) 146 | z_sample = self.sample_conditional(c) 147 | xsample = self.decode_to_img(z_sample) 148 | 149 | cshift = torch.cat((c[1:],c[:1]),dim=0) 150 | zshift = self.flow.reverse(zz, cshift) 151 | xshift = self.decode_to_img(zshift) 152 | 153 | log["inputs"] = x 154 | if self.cond_stage_key not in ["text", "caption", "class"]: 155 | log["conditioning"] = xc 156 | else: 157 | _,_,h,w = x.shape 158 | log["conditioning"] = log_txt_as_img((w,h), xc) 159 | 160 | log["reconstructions"] = xrec 161 | log["shift"] = xshift 162 | log["samples"] = xsample 163 | return log 164 | 165 | def get_input(self, key, batch, is_conditioning = False): 166 | x = batch[key] 167 | if key in ["caption", "text"]: 168 | x = list(x[0]) 169 | elif key in ["class"]: 170 | pass 171 | else: 172 | if len(x.shape) == 3: 173 | x = x[..., None] 174 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format) 175 | if is_conditioning: 176 | if self.interpolate_cond_size > -1: 177 | x = F.interpolate(x, size=(self.interpolate_cond_size, self.interpolate_cond_size)) 178 | return x 179 | 180 | def shared_step(self, batch, batch_idx, split="train"): 181 | x = self.get_input(self.first_stage_key, batch) 182 | c = self.get_input(self.cond_stage_key, batch, is_conditioning=True) 183 | zz, logdet = self(x, c) 184 | loss, log_dict = self.loss(zz, logdet, split=split) 185 | return loss, log_dict 186 | 187 | def training_step(self, batch, batch_idx): 188 | loss, log_dict = self.shared_step(batch, batch_idx, split="train") 189 | output = pl.TrainResult(minimize=loss, checkpoint_on=loss) 190 | output.log_dict(log_dict, prog_bar=False, on_epoch=True, logger=True, on_step=True) 191 | return output 192 | 193 | def validation_step(self, batch, batch_idx): 194 | loss, log_dict = self.shared_step(batch, batch_idx, split="val") 195 | output = pl.EvalResult(checkpoint_on=loss) 196 | output.log_dict(log_dict, prog_bar=False, logger=True) 197 | return output 198 | 199 | def configure_optimizers(self): 200 | opt = torch.optim.Adam((self.flow.parameters()), 201 | lr=self.learning_rate, 202 | betas=(0.5, 0.9), 203 | amsgrad=True) 204 | return [opt], [] 205 | 206 | 207 | class Net2BigGANFlow(Net2NetFlow): 208 | def __init__(self, 209 | flow_config, 210 | gan_config, 211 | cond_stage_config, 212 | make_cond_config, 213 | ckpt_path=None, 214 | ignore_keys=[], 215 | cond_stage_key="caption" 216 | ): 217 | super().__init__(flow_config=flow_config, 218 | first_stage_config=gan_config, cond_stage_config=cond_stage_config, 219 | ckpt_path=ckpt_path, ignore_keys=ignore_keys, cond_stage_key=cond_stage_key 220 | ) 221 | 222 | self.init_to_c_model(make_cond_config) 223 | self.init_preprocessing() 224 | 225 | @torch.no_grad() 226 | def get_input(self, batch, move_to_device=False): 227 | zin = batch["z"] 228 | cin = batch["class"] 229 | if move_to_device: 230 | zin, cin = zin.to(self.device), cin.to(self.device) 231 | # dequantize the discrete class code 232 | cin = self.first_stage_model.embed_labels(cin, 
labels_are_one_hot=False) 233 | split_sizes = [zin.shape[1], cin.shape[1]] 234 | xin = self.first_stage_model.generate_from_embedding(zin, cin) 235 | cin = self.dequantizer(cin) 236 | xc = self.to_c_model(xin) 237 | zflow = torch.cat([zin, cin.detach()], dim=1)[:, :, None, None] # this will be flowed 238 | return {"zcode": zflow, 239 | "xgen": xin, 240 | "xcon": xc, 241 | "split_sizes": split_sizes 242 | } 243 | 244 | def init_to_c_model(self, config): 245 | model = instantiate_from_config(config) 246 | model = model.eval() 247 | model.train = disabled_train 248 | self.to_c_model = model 249 | 250 | def init_preprocessing(self): 251 | dqcfg = {"target": "net2net.modules.autoencoder.basic.BasicFullyConnectedVAE"} 252 | self.dequantizer = instantiate_from_config(dqcfg) 253 | ckpt = get_ckpt_path("dequant_vae", "net2net/modules/autoencoder/dequant_vae") 254 | self.dequantizer.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 255 | self.dequantizer.eval() 256 | self.dequantizer.train = disabled_train 257 | 258 | def shared_step(self, batch, batch_idx, split="train"): 259 | data = self.get_input(batch) 260 | z, c = data["zcode"], data["xcon"] 261 | zz, logdet = self(z, c) 262 | loss, log_dict = self.loss(zz, logdet, split=split) 263 | return loss, log_dict 264 | 265 | def forward(self, z, c): 266 | c = self.encode_to_c(c) 267 | zz, logdet = self.flow(z, c) 268 | return zz, logdet 269 | 270 | @torch.no_grad() 271 | def log_images(self, batch, split=""): 272 | log = dict() 273 | data = self.get_input(batch, move_to_device=True) 274 | z, xc, x = data["zcode"], data["xcon"], data["xgen"] 275 | c = self.encode_to_c(xc) 276 | zz, _ = self.flow(z, c) 277 | 278 | z_sample = self.sample_conditional(c) 279 | zdec, cdec = torch.split(z_sample, data["split_sizes"], dim=1) 280 | xsample = self.first_stage_model.generate_from_embedding(zdec.squeeze(-1).squeeze(-1), 281 | cdec.squeeze(-1).squeeze(-1)) 282 | 283 | cshift = torch.cat((c[1:],c[:1]),dim=0) 284 | zshift = self.flow.reverse(zz, cshift) 285 | zshift, cshift = torch.split(zshift, data["split_sizes"], dim=1) 286 | xshift = self.first_stage_model.generate_from_embedding(zshift.squeeze(-1).squeeze(-1), 287 | cshift.squeeze(-1).squeeze(-1)) 288 | 289 | log["inputs"] = x 290 | if self.cond_stage_key not in ["text", "caption", "class"]: 291 | log["conditioning"] = xc 292 | else: 293 | _,_,h,w = x.shape 294 | log["conditioning"] = log_txt_as_img((w,h), xc) 295 | 296 | log["shift"] = xshift 297 | log["samples"] = xsample 298 | return log 299 | 300 | @torch.no_grad() 301 | def sample_conditional(self, c): 302 | z = self.flow.sample(c) 303 | return z 304 | -------------------------------------------------------------------------------- /perturbnet/net2net/models/flows/scviflow.py: -------------------------------------------------------------------------------- 1 | """"flow file for scvi latet space""" 2 | import torch 3 | import torch.nn.functional as F 4 | import pytorch_lightning as pl 5 | 6 | from translation import instantiate_from_config 7 | from perturbnet.net2net.modules.flow.loss import NLL 8 | from perturbnet.net2net.ckpt_util import get_ckpt_path 9 | from perturbnet.net2net.modules.util import log_txt_as_img 10 | 11 | 12 | class Net2NetFlow(pl.LightningModule): 13 | def __init__(self, 14 | flow_config, 15 | first_stage_config, 16 | cond_stage_config, 17 | ckpt_path=None, 18 | ignore_keys=[], 19 | first_stage_key="image", 20 | cond_stage_key="image", 21 | interpolate_cond_size=-1 22 | ): 23 | super().__init__() 24 | 
self.init_first_stage_from_ckpt(first_stage_config) 25 | self.init_cond_stage_from_ckpt(cond_stage_config) 26 | self.flow = instantiate_from_config(config = flow_config) 27 | self.loss = NLL() 28 | self.first_stage_key = first_stage_key 29 | self.cond_stage_key = cond_stage_key 30 | self.interpolate_cond_size = interpolate_cond_size 31 | if ckpt_path is not None: 32 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 33 | 34 | def init_from_ckpt(self, path, ignore_keys=list()): 35 | sd = torch.load(path, map_location="cpu")["state_dict"] 36 | for k in sd.keys(): 37 | for ik in ignore_keys: 38 | if k.startswith(ik): 39 | self.print("Deleting key {} from state_dict.".format(k)) 40 | del sd[k] 41 | self.load_state_dict(sd, strict=False) 42 | print(f"Restored from {path}") 43 | 44 | def init_first_stage_from_ckpt(self, config): 45 | model = instantiate_from_config(config) 46 | model = model.eval() 47 | model.train = disabled_train 48 | self.first_stage_model = model 49 | 50 | def init_cond_stage_from_ckpt(self, config): 51 | model = instantiate_from_config(config) 52 | model = model.eval() 53 | model.train = disabled_train 54 | self.cond_stage_model = model 55 | 56 | def forward(self, x, c): 57 | c = self.encode_to_c(c) 58 | q = self.encode_to_z(x) 59 | zz, logdet = self.flow(q, c) 60 | return zz, logdet 61 | 62 | @torch.no_grad() 63 | def sample_conditional(self, c): 64 | z = self.flow.sample(c) 65 | return z 66 | 67 | @torch.no_grad() 68 | def encode_to_z(self, x): 69 | z = self.first_stage_model.encode(x).detach() 70 | return z 71 | 72 | @torch.no_grad() 73 | def encode_to_c(self, c): 74 | c = self.cond_stage_model.encode(c).detach() 75 | return c 76 | 77 | @torch.no_grad() 78 | def decode_to_img(self, z): 79 | x = self.first_stage_model.decode(z.detach()) 80 | return x 81 | 82 | @torch.no_grad() 83 | def log_images(self, batch, split=""): 84 | log = dict() 85 | x = self.get_input(self.first_stage_key, batch).to(self.device) 86 | xc = self.get_input(self.cond_stage_key, batch, is_conditioning=True) 87 | if self.cond_stage_key not in ["text", "caption"]: 88 | xc = xc.to(self.device) 89 | 90 | z = self.encode_to_z(x) 91 | c = self.encode_to_c(xc) 92 | 93 | zz, _ = self.flow(z, c) 94 | zrec = self.flow.reverse(zz, c) 95 | xrec = self.decode_to_img(zrec) 96 | z_sample = self.sample_conditional(c) 97 | xsample = self.decode_to_img(z_sample) 98 | 99 | cshift = torch.cat((c[1:],c[:1]),dim=0) 100 | zshift = self.flow.reverse(zz, cshift) 101 | xshift = self.decode_to_img(zshift) 102 | 103 | log["inputs"] = x 104 | if self.cond_stage_key not in ["text", "caption", "class"]: 105 | log["conditioning"] = xc 106 | else: 107 | _,_,h,w = x.shape 108 | log["conditioning"] = log_txt_as_img((w,h), xc) 109 | 110 | log["reconstructions"] = xrec 111 | log["shift"] = xshift 112 | log["samples"] = xsample 113 | return log 114 | 115 | def get_input(self, key, batch, is_conditioning=False): 116 | x = batch[key] 117 | if key in ["caption", "text"]: 118 | x = list(x[0]) 119 | elif key in ["class"]: 120 | pass 121 | else: 122 | if len(x.shape) == 3: 123 | x = x[..., None] 124 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format) 125 | if is_conditioning: 126 | if self.interpolate_cond_size > -1: 127 | x = F.interpolate(x, size=(self.interpolate_cond_size, self.interpolate_cond_size)) 128 | return x 129 | 130 | def shared_step(self, batch, batch_idx, split="train"): 131 | x = self.get_input(self.first_stage_key, batch) 132 | c = self.get_input(self.cond_stage_key, batch, is_conditioning=True) 133 | 
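
At inference time the trained conditional flow is typically run in the reverse direction: encode a condition, draw latents from the conditional prior, and decode them with the first-stage model. A hedged sketch built only from the methods defined in this class; `model`, the form of `raw_condition`, and the helper name are assumptions, not repository code:

```
import torch

@torch.no_grad()
def sample_given_condition(model, raw_condition):
    """Conditional generation with a trained Net2NetFlow.

    `model` is assumed to be a trained Net2NetFlow whose cond_stage_model encodes
    the condition and whose first_stage_model decodes flow latents back to the
    data space.
    """
    c = model.encode_to_c(raw_condition)  # condition -> embedding
    z = model.sample_conditional(c)       # sample latents from the flow given c
    return model.decode_to_img(z)         # decode with the first-stage model
```
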
zz, logdet = self(x, c) 134 | loss, log_dict = self.loss(zz, logdet, split=split) 135 | return loss, log_dict 136 | 137 | def training_step(self, batch, batch_idx): 138 | loss, log_dict = self.shared_step(batch, batch_idx, split="train") 139 | output = pl.TrainResult(minimize=loss, checkpoint_on=loss) 140 | output.log_dict(log_dict, prog_bar=False, on_epoch=True, logger=True, on_step=True) 141 | return output 142 | 143 | def validation_step(self, batch, batch_idx): 144 | loss, log_dict = self.shared_step(batch, batch_idx, split="val") 145 | output = pl.EvalResult(checkpoint_on=loss) 146 | output.log_dict(log_dict, prog_bar=False, logger=True) 147 | return output 148 | 149 | def configure_optimizers(self): 150 | opt = torch.optim.Adam((self.flow.parameters()), 151 | lr=self.learning_rate, 152 | betas=(0.5, 0.9), 153 | amsgrad=True) 154 | return [opt], [] 155 | -------------------------------------------------------------------------------- /perturbnet/net2net/models/flows/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | #from sklearn.neighbors import KernelDensity 6 | 7 | 8 | def kde2D(x, y, bandwidth, xbins=250j, ybins=250j, **kwargs): 9 | """Build 2D kernel density estimate (KDE).""" 10 | 11 | # create grid of sample locations (default: 100x100) 12 | xx, yy = np.mgrid[x.min():x.max():xbins, 13 | y.min():y.max():ybins] 14 | 15 | xy_sample = np.vstack([yy.ravel(), xx.ravel()]).T 16 | xy_train = np.vstack([y, x]).T 17 | 18 | kde_skl = KernelDensity(bandwidth=bandwidth, **kwargs) 19 | kde_skl.fit(xy_train) 20 | 21 | # score_samples() returns the log-likelihood of the samples 22 | z = np.exp(kde_skl.score_samples(xy_sample)) 23 | return xx, yy, np.reshape(z, xx.shape) 24 | 25 | 26 | def plot2d(x, savepath=None): 27 | """make a scatter plot of x and return an Image of it""" 28 | x = x.cpu().numpy().squeeze() 29 | fig = plt.figure(dpi=300) 30 | xx, yy, zz = kde2D(x[:,0], x[:,1], 0.1) 31 | plt.pcolormesh(xx, yy, zz) 32 | plt.scatter(x[:,0], x[:, 1], s=0.1, c='mistyrose') 33 | if savepath is not None: 34 | plt.savefig(savepath, dpi=300) 35 | return fig 36 | 37 | 38 | def reshape_to_grid(x, num_samples=16, iw=28, ih=28, nc=1): 39 | x = x[:num_samples] 40 | x = x.detach().cpu() 41 | x = torch.reshape(x, (x.shape[0], nc, iw, ih)) 42 | xgrid = torchvision.utils.make_grid(x) 43 | return xgrid -------------------------------------------------------------------------------- /perturbnet/net2net/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/welch-lab/PerturbNet/96f38d8e2629cc4416c0f4c8e2051b16ec2a4816/perturbnet/net2net/modules/__init__.py -------------------------------------------------------------------------------- /perturbnet/net2net/modules/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/welch-lab/PerturbNet/96f38d8e2629cc4416c0f4c8e2051b16ec2a4816/perturbnet/net2net/modules/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /perturbnet/net2net/modules/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/welch-lab/PerturbNet/96f38d8e2629cc4416c0f4c8e2051b16ec2a4816/perturbnet/net2net/modules/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /perturbnet/net2net/modules/__pycache__/util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/welch-lab/PerturbNet/96f38d8e2629cc4416c0f4c8e2051b16ec2a4816/perturbnet/net2net/modules/__pycache__/util.cpython-37.pyc -------------------------------------------------------------------------------- /perturbnet/net2net/modules/__pycache__/util.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/welch-lab/PerturbNet/96f38d8e2629cc4416c0f4c8e2051b16ec2a4816/perturbnet/net2net/modules/__pycache__/util.cpython-39.pyc -------------------------------------------------------------------------------- /perturbnet/net2net/modules/autoencoder/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/welch-lab/PerturbNet/96f38d8e2629cc4416c0f4c8e2051b16ec2a4816/perturbnet/net2net/modules/autoencoder/__init__.py -------------------------------------------------------------------------------- /perturbnet/net2net/modules/autoencoder/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/welch-lab/PerturbNet/96f38d8e2629cc4416c0f4c8e2051b16ec2a4816/perturbnet/net2net/modules/autoencoder/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /perturbnet/net2net/modules/autoencoder/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/welch-lab/PerturbNet/96f38d8e2629cc4416c0f4c8e2051b16ec2a4816/perturbnet/net2net/modules/autoencoder/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /perturbnet/net2net/modules/autoencoder/__pycache__/basic.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/welch-lab/PerturbNet/96f38d8e2629cc4416c0f4c8e2051b16ec2a4816/perturbnet/net2net/modules/autoencoder/__pycache__/basic.cpython-37.pyc -------------------------------------------------------------------------------- /perturbnet/net2net/modules/autoencoder/__pycache__/basic.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/welch-lab/PerturbNet/96f38d8e2629cc4416c0f4c8e2051b16ec2a4816/perturbnet/net2net/modules/autoencoder/__pycache__/basic.cpython-39.pyc -------------------------------------------------------------------------------- /perturbnet/net2net/modules/autoencoder/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from perturbnet.net2net.modules.gan.biggan import load_variable_latsize_generator 5 | 6 | class ClassUp(nn.Module): 7 | def __init__(self, dim, depth, hidden_dim=256, use_sigmoid=False, out_dim=None): 8 | super().__init__() 9 | layers = [] 10 | layers.append(nn.Linear(dim, hidden_dim)) 11 | layers.append(nn.LeakyReLU()) 12 | for d in range(depth): 13 | 
layers.append(nn.Linear(hidden_dim, hidden_dim)) 14 | layers.append(nn.LeakyReLU()) 15 | layers.append(nn.Linear(hidden_dim, dim if out_dim is None else out_dim)) 16 | if use_sigmoid: 17 | layers.append(nn.Sigmoid()) 18 | self.main = nn.Sequential(*layers) 19 | 20 | def forward(self, x): 21 | x = self.main(x.squeeze(-1).squeeze(-1)) 22 | x = torch.nn.functional.softmax(x, dim=1) 23 | return x 24 | 25 | 26 | class BigGANDecoderWrapper(nn.Module): 27 | """Wraps a BigGAN into our autoencoding framework""" 28 | def __init__(self, z_dim, in_size=128, use_actnorm_in_dec=False, extra_z_dims=list()): 29 | super().__init__() 30 | self.z_dim = z_dim 31 | class_embedding_dim = 1000 32 | self.extra_z_dims = extra_z_dims 33 | self.map_to_class_embedding = ClassUp(z_dim, depth=2, hidden_dim=2*class_embedding_dim, 34 | use_sigmoid=False, out_dim=class_embedding_dim) 35 | self.decoder = load_variable_latsize_generator(in_size, z_dim, 36 | use_actnorm=use_actnorm_in_dec, 37 | n_class=class_embedding_dim, 38 | extra_z_dims=self.extra_z_dims) 39 | 40 | def forward(self, x, labels=None): 41 | emb = self.map_to_class_embedding(x[:,:self.z_dim,...]) 42 | x = self.decoder(x, emb) 43 | return x -------------------------------------------------------------------------------- /perturbnet/net2net/modules/autoencoder/encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | from torchvision import models 5 | 6 | from perturbnet.net2net.modules.autoencoder.basic import ActNorm, DenseEncoderLayer 7 | 8 | 9 | class ResnetEncoder(nn.Module): 10 | def __init__(self, z_dim, in_size, in_channels=3, 11 | pretrained=False, type="resnet50", 12 | double_z=True, pre_process=True, 13 | ): 14 | super().__init__() 15 | __possible_resnets = { 16 | 'resnet18': models.resnet18, 17 | 'resnet34': models.resnet34, 18 | 'resnet50': models.resnet50, 19 | 'resnet101': models.resnet101 20 | } 21 | self.use_preprocess = pre_process 22 | self.in_channels = in_channels 23 | norm_layer = ActNorm 24 | self.z_dim = z_dim 25 | self.model = __possible_resnets[type](pretrained=pretrained, norm_layer=norm_layer) 26 | 27 | self.image_transform = torchvision.transforms.Compose( 28 | [torchvision.transforms.Lambda(self.normscale)] 29 | ) 30 | 31 | size_pre_fc = self.get_spatial_size(in_size) 32 | assert size_pre_fc[2]==size_pre_fc[3], 'Output spatial size is not quadratic' 33 | spatial_size = size_pre_fc[2] 34 | num_channels_pre_fc = size_pre_fc[1] 35 | # replace last fc 36 | self.model.fc = DenseEncoderLayer(0, 37 | spatial_size=spatial_size, 38 | out_size=2*z_dim if double_z else z_dim, 39 | in_channels=num_channels_pre_fc) 40 | if self.in_channels != 3: 41 | self.model.in_ch_match = nn.Conv2d(self.in_channels, 3, 3, 1) 42 | 43 | def forward(self, x): 44 | if self.use_preprocess: 45 | x = self.pre_process(x) 46 | if self.in_channels != 3: 47 | assert not self.use_preprocess 48 | x = self.model.in_ch_match(x) 49 | features = self.features(x) 50 | encoding = self.model.fc(features) 51 | return encoding 52 | 53 | def rescale(self, x): 54 | return 0.5 * (x + 1) 55 | 56 | def normscale(self, image): 57 | normalize = torchvision.transforms.Normalize(mean=self.mean, std=self.std) 58 | return torch.stack([normalize(self.rescale(x)) for x in image]) 59 | 60 | def features(self, x): 61 | if self.use_preprocess: 62 | x = self.pre_process(x) 63 | x = self.model.conv1(x) 64 | x = self.model.bn1(x) 65 | x = self.model.relu(x) 66 | x = self.model.maxpool(x) 67 | x 
= self.model.layer1(x) 68 | x = self.model.layer2(x) 69 | x = self.model.layer3(x) 70 | x = self.model.layer4(x) 71 | x = self.model.avgpool(x) 72 | return x 73 | 74 | def post_features(self, x): 75 | x = self.model.fc(x) 76 | return x 77 | 78 | def pre_process(self, x): 79 | x = self.image_transform(x) 80 | return x 81 | 82 | def get_spatial_size(self, ipt_size): 83 | x = torch.randn(1, 3, ipt_size, ipt_size) 84 | return self.features(x).size() 85 | 86 | @property 87 | def mean(self): 88 | return [0.485, 0.456, 0.406] 89 | 90 | @property 91 | def std(self): 92 | return [0.229, 0.224, 0.225] 93 | 94 | @property 95 | def input_size(self): 96 | return [3, 224, 224] 97 | -------------------------------------------------------------------------------- /perturbnet/net2net/modules/autoencoder/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from perturbnet.net2net.modules.autoencoder.lpips import LPIPS # LPIPS loss 6 | from perturbnet.net2net.modules.discriminator.model import NLayerDiscriminator, weights_init 7 | 8 | 9 | def adopt_weight(weight, global_step, threshold=0, value=0.): 10 | if global_step < threshold: 11 | weight = value 12 | return weight 13 | 14 | 15 | def hinge_d_loss(logits_real, logits_fake): 16 | loss_real = torch.mean(F.relu(1. - logits_real)) 17 | loss_fake = torch.mean(F.relu(1. + logits_fake)) 18 | d_loss = 0.5 * (loss_real + loss_fake) 19 | return d_loss 20 | 21 | 22 | def vanilla_d_loss(logits_real, logits_fake): 23 | d_loss = 0.5 * ( 24 | torch.mean(torch.nn.functional.softplus(-logits_real)) + 25 | torch.mean(torch.nn.functional.softplus(logits_fake))) 26 | return d_loss 27 | 28 | 29 | class LPIPSWithDiscriminator(nn.Module): 30 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, 31 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 32 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 33 | evo_disc=False, disc_loss="hinge"): 34 | 35 | super().__init__() 36 | assert disc_loss in ["hinge", "vanilla"] 37 | self.kl_weight = kl_weight 38 | self.pixel_weight = pixelloss_weight 39 | self.perceptual_loss = LPIPS().eval() 40 | self.perceptual_weight = perceptual_weight 41 | # output log variance 42 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) 43 | 44 | if evo_disc: 45 | self.discriminator = NLayerDiscriminatorEvoNorm(input_nc=disc_in_channels, 46 | n_layers=disc_num_layers 47 | ).apply(weights_init) 48 | else: 49 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 50 | n_layers=disc_num_layers, 51 | use_actnorm=use_actnorm 52 | ).apply(weights_init) 53 | self.discriminator_iter_start = disc_start 54 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss 55 | self.disc_factor = disc_factor 56 | self.discriminator_weight = disc_weight 57 | self.disc_conditional = disc_conditional 58 | 59 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 60 | if last_layer is not None: 61 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 62 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 63 | else: 64 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 65 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 66 | 67 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 68 | d_weight = 
torch.clamp(d_weight, 0.0, 1e4).detach() 69 | d_weight = d_weight * self.discriminator_weight 70 | return d_weight 71 | 72 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx, 73 | global_step, last_layer=None, cond=None, split="train", 74 | side_outputs=None, weights=None): 75 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 76 | if self.perceptual_weight > 0: 77 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 78 | rec_loss = rec_loss + self.perceptual_weight * p_loss 79 | 80 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 81 | weighted_nll_loss = nll_loss 82 | if weights is not None: 83 | weighted_nll_loss = weights*nll_loss 84 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] 85 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 86 | kl_loss = posteriors.kl() 87 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 88 | 89 | # now the GAN part 90 | if optimizer_idx == 0: 91 | # generator update 92 | if cond is None: 93 | assert not self.disc_conditional 94 | logits_fake = self.discriminator(reconstructions.contiguous()) 95 | else: 96 | assert self.disc_conditional 97 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 98 | g_loss = -torch.mean(logits_fake) 99 | 100 | if self.disc_factor > 0.0: 101 | try: 102 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 103 | except RuntimeError: 104 | assert not self.training 105 | d_weight = torch.tensor(0.0) 106 | else: 107 | d_weight = torch.tensor(0.0) 108 | 109 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 110 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss 111 | 112 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), 113 | "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), 114 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 115 | "{}/d_weight".format(split): d_weight.detach(), 116 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 117 | "{}/g_loss".format(split): g_loss.detach().mean(), 118 | } 119 | return loss, log 120 | 121 | if optimizer_idx == 1: 122 | # second pass for discriminator update 123 | if cond is None: 124 | logits_real = self.discriminator(inputs.contiguous().detach()) 125 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 126 | else: 127 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 128 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 129 | 130 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 131 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 132 | 133 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 134 | "{}/logits_real".format(split): logits_real.detach().mean(), 135 | "{}/logits_fake".format(split): logits_fake.detach().mean() 136 | } 137 | return d_loss, log 138 | 139 | class DummyLoss: 140 | pass -------------------------------------------------------------------------------- /perturbnet/net2net/modules/autoencoder/lpips.py: -------------------------------------------------------------------------------- 1 | """Stripped version of 
https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torchvision import models 6 | from collections import namedtuple 7 | 8 | from perturbnet.net2net.ckpt_util import get_ckpt_path 9 | 10 | 11 | class LPIPS(nn.Module): 12 | # Learned perceptual metric 13 | def __init__(self, use_dropout=True): 14 | super().__init__() 15 | self.scaling_layer = ScalingLayer() 16 | self.chns = [64, 128, 256, 512, 512] # vg16 features 17 | self.net = vgg16(pretrained=True, requires_grad=False) 18 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 19 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 20 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 21 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 22 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 23 | self.load_from_pretrained() 24 | for param in self.parameters(): 25 | param.requires_grad = False 26 | 27 | def load_from_pretrained(self, name="vgg_lpips"): 28 | ckpt = get_ckpt_path(name, "net2net/modules/autoencoder/lpips") 29 | self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 30 | print("loaded pretrained LPIPS loss from {}".format(ckpt)) 31 | 32 | @classmethod 33 | def from_pretrained(cls, name="vgg_lpips"): 34 | if name is not "vgg_lpips": 35 | raise NotImplementedError 36 | model = cls() 37 | ckpt = get_ckpt_path(name) 38 | model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 39 | return model 40 | 41 | def forward(self, input, target): 42 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) 43 | outs0, outs1 = self.net(in0_input), self.net(in1_input) 44 | feats0, feats1, diffs = {}, {}, {} 45 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 46 | for kk in range(len(self.chns)): 47 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) 48 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 49 | 50 | res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))] 51 | val = res[0] 52 | for l in range(1, len(self.chns)): 53 | val += res[l] 54 | return val 55 | 56 | 57 | class ScalingLayer(nn.Module): 58 | def __init__(self): 59 | super(ScalingLayer, self).__init__() 60 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 61 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) 62 | 63 | def forward(self, inp): 64 | return (inp - self.shift) / self.scale 65 | 66 | 67 | class NetLinLayer(nn.Module): 68 | """ A single linear layer which does a 1x1 conv """ 69 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 70 | super(NetLinLayer, self).__init__() 71 | layers = [nn.Dropout(), ] if (use_dropout) else [] 72 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] 73 | self.model = nn.Sequential(*layers) 74 | 75 | 76 | class vgg16(torch.nn.Module): 77 | def __init__(self, requires_grad=False, pretrained=True): 78 | super(vgg16, self).__init__() 79 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features 80 | self.slice1 = torch.nn.Sequential() 81 | self.slice2 = torch.nn.Sequential() 82 | self.slice3 = torch.nn.Sequential() 83 | self.slice4 = torch.nn.Sequential() 84 | self.slice5 = torch.nn.Sequential() 85 | self.N_slices = 5 86 | for x in range(4): 87 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 88 | for x 
in range(4, 9): 89 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 90 | for x in range(9, 16): 91 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 92 | for x in range(16, 23): 93 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 94 | for x in range(23, 30): 95 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 96 | if not requires_grad: 97 | for param in self.parameters(): 98 | param.requires_grad = False 99 | 100 | def forward(self, X): 101 | h = self.slice1(X) 102 | h_relu1_2 = h 103 | h = self.slice2(h) 104 | h_relu2_2 = h 105 | h = self.slice3(h) 106 | h_relu3_3 = h 107 | h = self.slice4(h) 108 | h_relu4_3 = h 109 | h = self.slice5(h) 110 | h_relu5_3 = h 111 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 112 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 113 | return out 114 | 115 | 116 | def normalize_tensor(x,eps=1e-10): 117 | norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True)) 118 | return x/(norm_factor+eps) 119 | 120 | 121 | def spatial_average(x, keepdim=True): 122 | return x.mean([2,3],keepdim=keepdim) 123 | -------------------------------------------------------------------------------- /perturbnet/net2net/modules/captions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/welch-lab/PerturbNet/96f38d8e2629cc4416c0f4c8e2051b16ec2a4816/perturbnet/net2net/modules/captions/__init__.py -------------------------------------------------------------------------------- /perturbnet/net2net/modules/captions/model.py: -------------------------------------------------------------------------------- 1 | """Code is based on https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning""" 2 | 3 | import os, sys 4 | import json 5 | import torch 6 | import torch.nn as nn 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | import torchvision 10 | import torch.nn.functional as F 11 | from torchvision import transforms 12 | from PIL import Image 13 | 14 | from perturbnet.net2net.ckpt_util import get_ckpt_path 15 | 16 | #import warnings 17 | #warnings.filterwarnings("ignore") 18 | 19 | from net2net.modules.captions.models import Encoder, DecoderWithAttention 20 | 21 | 22 | rescale = lambda x: 0.5*(x+1) 23 | 24 | 25 | def imresize(img, size): 26 | return np.array(Image.fromarray(img).resize(size)) 27 | 28 | 29 | class Img2Text(nn.Module): 30 | def __init__(self): 31 | super().__init__() 32 | model_path = get_ckpt_path("coco_captioner", "net2net/modules/captions") 33 | word_map_path = "data/WORDMAP_coco_5_cap_per_img_5_min_word_freq.json" 34 | 35 | # Load word map (word2ix) 36 | with open(word_map_path, 'r') as j: 37 | word_map = json.load(j) 38 | rev_word_map = {v: k for k, v in word_map.items()} # ix2word 39 | self.word_map = word_map 40 | self.rev_word_map = rev_word_map 41 | 42 | checkpoint = torch.load(model_path) 43 | 44 | self.encoder = Encoder() 45 | self.decoder = DecoderWithAttention(embed_dim=512, decoder_dim=512, attention_dim=512, vocab_size=9490) 46 | missing, unexpected = self.load_state_dict(checkpoint, strict=False) 47 | if len(missing) > 0: 48 | print(f"Missing keys in state-dict: {missing}") 49 | if len(unexpected) > 0: 50 | print(f"Unexpected keys in state-dict: {unexpected}") 51 | self.encoder.eval() 52 | self.decoder.eval() 53 | 54 | resize = transforms.Lambda(lambda image: F.interpolate(image, size=(256, 256), mode="bilinear")) 55 | normalize = 
torchvision.transforms.Normalize(mean=self.mean, std=self.std) 56 | norm = torchvision.transforms.Lambda(lambda image: torch.stack([normalize(rescale(x)) for x in image])) 57 | self.img_transform = transforms.Compose([resize, norm]) 58 | self.device = "cuda" 59 | 60 | def _pre_process(self, x): 61 | x = self.img_transform(x) 62 | return x 63 | 64 | @property 65 | def mean(self): 66 | return [0.485, 0.456, 0.406] 67 | 68 | @property 69 | def std(self): 70 | return [0.229, 0.224, 0.225] 71 | 72 | def forward(self, x): 73 | captions = list() 74 | for subx in x: 75 | subx = subx.unsqueeze(0) 76 | captions.append(self.make_single_caption(subx)) 77 | return captions 78 | 79 | def make_single_caption(self, x): 80 | seq = self.caption_image_beam_search(x)[0][0] 81 | words = [self.rev_word_map[ind] for ind in seq] 82 | words = words[:50] 83 | #if len(words) > 50: 84 | # return np.array(['']) 85 | text = '' 86 | for word in words: 87 | text += word + ' ' 88 | return text 89 | 90 | def caption_image_beam_search(self, image, beam_size=3): 91 | """ 92 | Reads a batch of images and captions each of it with beam search. 93 | :param image: batch of pytorch images 94 | :param beam_size: number of sequences to consider at each decode-step 95 | :return: caption, weights for visualization 96 | """ 97 | 98 | k = beam_size 99 | vocab_size = len(self.word_map) 100 | 101 | # Encode 102 | # image is a batch of images 103 | encoder_out_ = self.encoder(image) # (b, enc_image_size, enc_image_size, encoder_dim) 104 | enc_image_size = encoder_out_.size(1) 105 | encoder_dim = encoder_out_.size(3) 106 | batch_size = encoder_out_.size(0) 107 | 108 | # Flatten encoding 109 | encoder_out_ = encoder_out_.view(batch_size, -1, encoder_dim) # (1, num_pixels, encoder_dim) 110 | num_pixels = encoder_out_.size(1) 111 | 112 | sequences = list() 113 | alphas_ = list() 114 | # We'll treat the problem as having a batch size of k per example 115 | for single_example in encoder_out_: 116 | single_example = single_example[None, ...] 
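        # --- illustrative aside (not part of the original file) ---
        # Each encoded image is expanded to a pseudo-batch of k beams. At every
        # decode step the (k, vocab_size) score matrix is flattened and a single
        # top-k picks the best (beam, next-word) pairs; the flat indices are then
        # unrolled back into a beam index and a word index. A self-contained
        # demonstration of that unrolling step (demo-only names):
        #
        #   demo_scores = torch.log_softmax(torch.randn(3, 5), dim=1)   # (k=3, vocab=5)
        #   vals, flat_idx = demo_scores.view(-1).topk(3)
        #   beam_idx, word_idx = flat_idx // 5, flat_idx % 5
        #   assert torch.equal(demo_scores[beam_idx, word_idx], vals)
        # -----------------------------------------------------------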
117 | encoder_out = single_example.expand(k, num_pixels, encoder_dim) # (k, num_pixels, encoder_dim) 118 | 119 | # Tensor to store top k previous words at each step; now they're just 120 | k_prev_words = torch.LongTensor([[self.word_map['']]] * k).to(self.device) # (k, 1) 121 | 122 | # Tensor to store top k sequences; now they're just 123 | seqs = k_prev_words # (k, 1) 124 | 125 | # Tensor to store top k sequences' scores; now they're just 0 126 | top_k_scores = torch.zeros(k, 1).to(self.device) # (k, 1) 127 | 128 | # Tensor to store top k sequences' alphas; now they're just 1s 129 | seqs_alpha = torch.ones(k, 1, enc_image_size, enc_image_size).to(self.device) # (k, 1, enc_image_size, enc_image_size) 130 | 131 | # Lists to store completed sequences, their alphas and scores 132 | complete_seqs = list() 133 | complete_seqs_alpha = list() 134 | complete_seqs_scores = list() 135 | 136 | # Start decoding 137 | step = 1 138 | h, c = self.decoder.init_hidden_state(encoder_out) 139 | 140 | # s is a number less than or equal to k, because sequences are removed from this process once they hit 141 | while True: 142 | embeddings = self.decoder.embedding(k_prev_words).squeeze(1) # (s, embed_dim) 143 | awe, alpha = self.decoder.attention(encoder_out, h) # (s, encoder_dim), (s, num_pixels) 144 | alpha = alpha.view(-1, enc_image_size, enc_image_size) # (s, enc_image_size, enc_image_size) 145 | gate = self.decoder.sigmoid(self.decoder.f_beta(h)) # gating scalar, (s, encoder_dim) 146 | awe = gate * awe 147 | h, c = self.decoder.decode_step(torch.cat([embeddings, awe], dim=1), (h, c)) # (s, decoder_dim) 148 | scores = self.decoder.fc(h) # (s, vocab_size) 149 | scores = F.log_softmax(scores, dim=1) 150 | # Add 151 | scores = top_k_scores.expand_as(scores) + scores # (s, vocab_size) 152 | 153 | # For the first step, all k points will have the same scores (since same k previous words, h, c) 154 | if step == 1: 155 | top_k_scores, top_k_words = scores[0].topk(k, 0, True, True) # (s) 156 | else: 157 | # Unroll and find top scores, and their unrolled indices 158 | top_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True) # (s) 159 | 160 | # Convert unrolled indices to actual indices of scores 161 | prev_word_inds = top_k_words // vocab_size # (s) 162 | next_word_inds = top_k_words % vocab_size # (s) 163 | 164 | # Add new words to sequences, alphas 165 | seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1) # (s, step+1) 166 | seqs_alpha = torch.cat([seqs_alpha[prev_word_inds], alpha[prev_word_inds].unsqueeze(1)], 167 | dim=1) # (s, step+1, enc_image_size, enc_image_size) 168 | 169 | # Which sequences are incomplete (didn't reach )? 
170 | incomplete_inds = [ind for ind, next_word in enumerate(next_word_inds) if 171 | next_word != self.word_map['']] 172 | complete_inds = list(set(range(len(next_word_inds))) - set(incomplete_inds)) 173 | 174 | # Set aside complete sequences 175 | if len(complete_inds) > 0: 176 | complete_seqs.extend(seqs[complete_inds].tolist()) 177 | complete_seqs_alpha.extend(seqs_alpha[complete_inds].tolist()) 178 | complete_seqs_scores.extend(top_k_scores[complete_inds]) 179 | k -= len(complete_inds) # reduce beam length accordingly 180 | 181 | # Proceed with incomplete sequences 182 | if k == 0: 183 | break 184 | seqs = seqs[incomplete_inds] 185 | seqs_alpha = seqs_alpha[incomplete_inds] 186 | h = h[prev_word_inds[incomplete_inds]] 187 | c = c[prev_word_inds[incomplete_inds]] 188 | encoder_out = encoder_out[prev_word_inds[incomplete_inds]] 189 | top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1) 190 | k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1) 191 | 192 | # Break if things have been going on too long 193 | if step > 50: 194 | break 195 | step += 1 196 | 197 | try: 198 | i = complete_seqs_scores.index(max(complete_seqs_scores)) 199 | seq = complete_seqs[i] 200 | alphas = complete_seqs_alpha[i] 201 | except ValueError: 202 | print("Catching an empty sequence.") 203 | try: 204 | len_ = len(sequences[-1]) 205 | seq = [0]*len_ 206 | alphas = None 207 | except: 208 | seq = [0]*9 209 | alphas = None 210 | 211 | sequences.append(seq) 212 | alphas_.append(alphas) 213 | 214 | return sequences, alphas_ 215 | 216 | def visualize_text(self, root, images, sequences, n_row=5, img_name='examples'): 217 | """ 218 | plot the text corresponding to the given images in a matplotlib figure. 219 | images are a batch of pytorch images 220 | """ 221 | 222 | n_img = images.size(0) 223 | n_col = max(n_img // n_row + 1, 2) 224 | 225 | fig, ax = plt.subplots(n_row, n_col) 226 | 227 | i = 0 228 | j = 0 229 | for image, seq in zip(images, sequences): 230 | if i == n_row: 231 | i = 0 232 | j += 1 233 | image = image.cpu().numpy().transpose(1, 2, 0) 234 | image = 255*(0.5*(image+1)) 235 | image = Image.fromarray(image.astype('uint8')) 236 | image = image.resize([14 * 24, 14 * 24], Image.LANCZOS) 237 | words = [self.rev_word_map[ind] for ind in seq] 238 | if len(words) > 50: 239 | return 240 | text = '' 241 | for word in words: 242 | text += word + ' ' 243 | 244 | ax[i, j].text(0, 1, '%s' % (text), color='black', backgroundcolor='white', fontsize=12) 245 | ax[i, j].imshow(image) 246 | ax[i, j].axis('off') 247 | 248 | plt.savefig(os.path.join(root, img_name + '.png')) 249 | 250 | 251 | if __name__ == '__main__': 252 | model = Img2Text() 253 | print("done.") 254 | -------------------------------------------------------------------------------- /perturbnet/net2net/modules/captions/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torchvision 4 | 5 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 6 | 7 | 8 | class Encoder(nn.Module): 9 | """ 10 | Encoder. 
11 | """ 12 | 13 | def __init__(self, encoded_image_size=14): 14 | super(Encoder, self).__init__() 15 | self.enc_image_size = encoded_image_size 16 | 17 | resnet = torchvision.models.resnet101(pretrained=True) # pretrained ImageNet ResNet-101 18 | 19 | # Remove linear and pool layers (since we're not doing classification) 20 | modules = list(resnet.children())[:-2] 21 | self.resnet = nn.Sequential(*modules) 22 | 23 | # Resize image to fixed size to allow input images of variable size 24 | self.adaptive_pool = nn.AdaptiveAvgPool2d((encoded_image_size, encoded_image_size)) 25 | 26 | self.fine_tune() 27 | 28 | def forward(self, images): 29 | """ 30 | Forward propagation. 31 | 32 | :param images: images, a tensor of dimensions (batch_size, 3, image_size, image_size) 33 | :return: encoded images 34 | """ 35 | out = self.resnet(images) # (batch_size, 2048, image_size/32, image_size/32) 36 | out = self.adaptive_pool(out) # (batch_size, 2048, encoded_image_size, encoded_image_size) 37 | out = out.permute(0, 2, 3, 1) # (batch_size, encoded_image_size, encoded_image_size, 2048) 38 | return out 39 | 40 | def fine_tune(self, fine_tune=True): 41 | """ 42 | Allow or prevent the computation of gradients for convolutional blocks 2 through 4 of the encoder. 43 | 44 | :param fine_tune: Allow? 45 | """ 46 | for p in self.resnet.parameters(): 47 | p.requires_grad = False 48 | # If fine-tuning, only fine-tune convolutional blocks 2 through 4 49 | for c in list(self.resnet.children())[5:]: 50 | for p in c.parameters(): 51 | p.requires_grad = fine_tune 52 | 53 | 54 | class Attention(nn.Module): 55 | """ 56 | Attention Network. 57 | """ 58 | 59 | def __init__(self, encoder_dim, decoder_dim, attention_dim): 60 | """ 61 | :param encoder_dim: feature size of encoded images 62 | :param decoder_dim: size of decoder's RNN 63 | :param attention_dim: size of the attention network 64 | """ 65 | super(Attention, self).__init__() 66 | self.encoder_att = nn.Linear(encoder_dim, attention_dim) # linear layer to transform encoded image 67 | self.decoder_att = nn.Linear(decoder_dim, attention_dim) # linear layer to transform decoder's output 68 | self.full_att = nn.Linear(attention_dim, 1) # linear layer to calculate values to be softmax-ed 69 | self.relu = nn.ReLU() 70 | #self.softmax = nn.Softmax(dim=1) # softmax layer to calculate weights 71 | 72 | def forward(self, encoder_out, decoder_hidden): 73 | """ 74 | Forward propagation. 75 | 76 | :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim) 77 | :param decoder_hidden: previous decoder output, a tensor of dimension (batch_size, decoder_dim) 78 | :return: attention weighted encoding, weights 79 | """ 80 | att1 = self.encoder_att(encoder_out) # (batch_size, num_pixels, attention_dim) 81 | att2 = self.decoder_att(decoder_hidden) # (batch_size, attention_dim) 82 | att = self.full_att(self.relu(att1 + att2.unsqueeze(1))).squeeze(2) # (batch_size, num_pixels) 83 | #alpha = self.softmax(att) # (batch_size, num_pixels) 84 | alpha = torch.nn.functional.softmax(att, dim=1) 85 | attention_weighted_encoding = (encoder_out * alpha.unsqueeze(2)).sum(dim=1) # (batch_size, encoder_dim) 86 | 87 | return attention_weighted_encoding, alpha 88 | 89 | 90 | class DecoderWithAttention(nn.Module): 91 | """ 92 | Decoder. 
93 | """ 94 | 95 | def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size, encoder_dim=2048, dropout=0.5): 96 | """ 97 | :param attention_dim: size of attention network 98 | :param embed_dim: embedding size 99 | :param decoder_dim: size of decoder's RNN 100 | :param vocab_size: size of vocabulary 101 | :param encoder_dim: feature size of encoded images 102 | :param dropout: dropout 103 | """ 104 | super(DecoderWithAttention, self).__init__() 105 | 106 | self.encoder_dim = encoder_dim 107 | self.attention_dim = attention_dim 108 | self.embed_dim = embed_dim 109 | self.decoder_dim = decoder_dim 110 | self.vocab_size = vocab_size 111 | self.dropout = dropout 112 | 113 | self.attention = Attention(encoder_dim, decoder_dim, attention_dim) # attention network 114 | 115 | self.embedding = nn.Embedding(vocab_size, embed_dim) # embedding layer 116 | self.dropout = nn.Dropout(p=self.dropout) 117 | self.decode_step = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim, bias=True) # decoding LSTMCell 118 | self.init_h = nn.Linear(encoder_dim, decoder_dim) # linear layer to find initial hidden state of LSTMCell 119 | self.init_c = nn.Linear(encoder_dim, decoder_dim) # linear layer to find initial cell state of LSTMCell 120 | self.f_beta = nn.Linear(decoder_dim, encoder_dim) # linear layer to create a sigmoid-activated gate 121 | self.sigmoid = nn.Sigmoid() 122 | self.fc = nn.Linear(decoder_dim, vocab_size) # linear layer to find scores over vocabulary 123 | self.init_weights() # initialize some layers with the uniform distribution 124 | 125 | def init_weights(self): 126 | """ 127 | Initializes some parameters with values from the uniform distribution, for easier convergence. 128 | """ 129 | self.embedding.weight.data.uniform_(-0.1, 0.1) 130 | self.fc.bias.data.fill_(0) 131 | self.fc.weight.data.uniform_(-0.1, 0.1) 132 | 133 | def load_pretrained_embeddings(self, embeddings): 134 | """ 135 | Loads embedding layer with pre-trained embeddings. 136 | 137 | :param embeddings: pre-trained embeddings 138 | """ 139 | self.embedding.weight = nn.Parameter(embeddings) 140 | 141 | def fine_tune_embeddings(self, fine_tune=True): 142 | """ 143 | Allow fine-tuning of embedding layer? (Only makes sense to not-allow if using pre-trained embeddings). 144 | 145 | :param fine_tune: Allow? 146 | """ 147 | for p in self.embedding.parameters(): 148 | p.requires_grad = fine_tune 149 | 150 | def init_hidden_state(self, encoder_out): 151 | """ 152 | Creates the initial hidden and cell states for the decoder's LSTM based on the encoded images. 153 | 154 | :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim) 155 | :return: hidden state, cell state 156 | """ 157 | mean_encoder_out = encoder_out.mean(dim=1) 158 | h = self.init_h(mean_encoder_out) # (batch_size, decoder_dim) 159 | c = self.init_c(mean_encoder_out) 160 | return h, c 161 | 162 | def forward(self, encoder_out, encoded_captions, caption_lengths): 163 | """ 164 | Forward propagation. 
165 | 166 | :param encoder_out: encoded images, a tensor of dimension (batch_size, enc_image_size, enc_image_size, encoder_dim) 167 | :param encoded_captions: encoded captions, a tensor of dimension (batch_size, max_caption_length) 168 | :param caption_lengths: caption lengths, a tensor of dimension (batch_size, 1) 169 | :return: scores for vocabulary, sorted encoded captions, decode lengths, weights, sort indices 170 | """ 171 | 172 | batch_size = encoder_out.size(0) 173 | encoder_dim = encoder_out.size(-1) 174 | vocab_size = self.vocab_size 175 | 176 | # Flatten image 177 | encoder_out = encoder_out.view(batch_size, -1, encoder_dim) # (batch_size, num_pixels, encoder_dim) 178 | num_pixels = encoder_out.size(1) 179 | 180 | # Sort input data by decreasing lengths; why? apparent below 181 | caption_lengths, sort_ind = caption_lengths.squeeze(1).sort(dim=0, descending=True) 182 | encoder_out = encoder_out[sort_ind] 183 | encoded_captions = encoded_captions[sort_ind] 184 | 185 | # Embedding 186 | embeddings = self.embedding(encoded_captions) # (batch_size, max_caption_length, embed_dim) 187 | 188 | # Initialize LSTM state 189 | h, c = self.init_hidden_state(encoder_out) # (batch_size, decoder_dim) 190 | 191 | # We won't decode at the position, since we've finished generating as soon as we generate 192 | # So, decoding lengths are actual lengths - 1 193 | decode_lengths = (caption_lengths - 1).tolist() 194 | 195 | # Create tensors to hold word predicion scores and alphas 196 | predictions = torch.zeros(batch_size, max(decode_lengths), vocab_size).to(device) 197 | alphas = torch.zeros(batch_size, max(decode_lengths), num_pixels).to(device) 198 | 199 | # At each time-step, decode by 200 | # attention-weighing the encoder's output based on the decoder's previous hidden state output 201 | # then generate a new word in the decoder with the previous word and the attention weighted encoding 202 | for t in range(max(decode_lengths)): 203 | batch_size_t = sum([l > t for l in decode_lengths]) 204 | attention_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t], 205 | h[:batch_size_t]) 206 | gate = self.sigmoid(self.f_beta(h[:batch_size_t])) # gating scalar, (batch_size_t, encoder_dim) 207 | attention_weighted_encoding = gate * attention_weighted_encoding 208 | h, c = self.decode_step( 209 | torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim=1), 210 | (h[:batch_size_t], c[:batch_size_t])) # (batch_size_t, decoder_dim) 211 | preds = self.fc(self.dropout(h)) # (batch_size_t, vocab_size) 212 | predictions[:batch_size_t, t, :] = preds 213 | alphas[:batch_size_t, t, :] = alpha 214 | 215 | return predictions, encoded_captions, decode_lengths, alphas, sort_ind 216 | -------------------------------------------------------------------------------- /perturbnet/net2net/modules/discriminator/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/welch-lab/PerturbNet/96f38d8e2629cc4416c0f4c8e2051b16ec2a4816/perturbnet/net2net/modules/discriminator/__init__.py -------------------------------------------------------------------------------- /perturbnet/net2net/modules/discriminator/model.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch.nn as nn 3 | 4 | from perturbnet.net2net.modules.autoencoder.basic import ActNorm 5 | 6 | 7 | def weights_init(m): 8 | classname = m.__class__.__name__ 9 | if classname.find('Conv') != -1: 10 | 
nn.init.normal_(m.weight.data, 0.0, 0.02) 11 | elif classname.find('BatchNorm') != -1: 12 | nn.init.normal_(m.weight.data, 1.0, 0.02) 13 | nn.init.constant_(m.bias.data, 0) 14 | 15 | 16 | class NLayerDiscriminator(nn.Module): 17 | """Defines a PatchGAN discriminator as in Pix2Pix 18 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 19 | """ 20 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): 21 | """Construct a PatchGAN discriminator 22 | Parameters: 23 | input_nc (int) -- the number of channels in input images 24 | ndf (int) -- the number of filters in the last conv layer 25 | n_layers (int) -- the number of conv layers in the discriminator 26 | norm_layer -- normalization layer 27 | """ 28 | super(NLayerDiscriminator, self).__init__() 29 | if not use_actnorm: 30 | norm_layer = nn.BatchNorm2d 31 | else: 32 | norm_layer = ActNorm 33 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 34 | use_bias = norm_layer.func != nn.BatchNorm2d 35 | else: 36 | use_bias = norm_layer != nn.BatchNorm2d 37 | 38 | kw = 4 39 | padw = 1 40 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 41 | nf_mult = 1 42 | nf_mult_prev = 1 43 | for n in range(1, n_layers): # gradually increase the number of filters 44 | nf_mult_prev = nf_mult 45 | nf_mult = min(2 ** n, 8) 46 | sequence += [ 47 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 48 | norm_layer(ndf * nf_mult), 49 | nn.LeakyReLU(0.2, True) 50 | ] 51 | 52 | nf_mult_prev = nf_mult 53 | nf_mult = min(2 ** n_layers, 8) 54 | sequence += [ 55 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), 56 | norm_layer(ndf * nf_mult), 57 | nn.LeakyReLU(0.2, True) 58 | ] 59 | 60 | sequence += [ 61 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map 62 | self.main = nn.Sequential(*sequence) 63 | 64 | def forward(self, input): 65 | """Standard forward.""" 66 | return self.main(input) 67 | -------------------------------------------------------------------------------- /perturbnet/net2net/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/welch-lab/PerturbNet/96f38d8e2629cc4416c0f4c8e2051b16ec2a4816/perturbnet/net2net/modules/distributions/__init__.py -------------------------------------------------------------------------------- /perturbnet/net2net/modules/distributions/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/welch-lab/PerturbNet/96f38d8e2629cc4416c0f4c8e2051b16ec2a4816/perturbnet/net2net/modules/distributions/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /perturbnet/net2net/modules/distributions/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/welch-lab/PerturbNet/96f38d8e2629cc4416c0f4c8e2051b16ec2a4816/perturbnet/net2net/modules/distributions/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /perturbnet/net2net/modules/distributions/__pycache__/distributions.cpython-37.pyc: 
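
The `NLayerDiscriminator` defined above in `discriminator/model.py` returns a spatial map of patch logits rather than a single scalar, and `hinge_d_loss` in `autoencoder/loss.py` simply averages over that map. A minimal sketch of how the two pieces fit together (batch size, image size, and the resulting 30x30 patch grid are illustrative choices, not values prescribed by the repository):

```
import torch
from perturbnet.net2net.modules.discriminator.model import NLayerDiscriminator, weights_init

disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3).apply(weights_init)

real = torch.randn(4, 3, 256, 256)
fake = torch.randn(4, 3, 256, 256)

logits_real = disc(real)   # (4, 1, 30, 30): one real/fake score per image patch
logits_fake = disc(fake)

# hinge loss, as used by LPIPSWithDiscriminator when disc_loss="hinge"
d_loss = 0.5 * (torch.mean(torch.relu(1. - logits_real)) +
                torch.mean(torch.relu(1. + logits_fake)))
```
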
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/welch-lab/PerturbNet/96f38d8e2629cc4416c0f4c8e2051b16ec2a4816/perturbnet/net2net/modules/distributions/__pycache__/distributions.cpython-37.pyc -------------------------------------------------------------------------------- /perturbnet/net2net/modules/distributions/__pycache__/distributions.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/welch-lab/PerturbNet/96f38d8e2629cc4416c0f4c8e2051b16ec2a4816/perturbnet/net2net/modules/distributions/__pycache__/distributions.cpython-39.pyc -------------------------------------------------------------------------------- /perturbnet/net2net/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 10.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=[1, 2, 3]) 60 | 61 | def mode(self): 62 | return self.mean 63 | -------------------------------------------------------------------------------- /perturbnet/net2net/modules/facenet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/welch-lab/PerturbNet/96f38d8e2629cc4416c0f4c8e2051b16ec2a4816/perturbnet/net2net/modules/facenet/__init__.py -------------------------------------------------------------------------------- /perturbnet/net2net/modules/facenet/inception_resnet_v1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | import requests 5 | from requests.adapters import HTTPAdapter 6 | import os 7 | 8 | 9 | class BasicConv2d(nn.Module): 10 | 11 | def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0): 12 
| super().__init__() 13 | self.conv = nn.Conv2d( 14 | in_planes, out_planes, 15 | kernel_size=kernel_size, stride=stride, 16 | padding=padding, bias=False 17 | ) # verify bias false 18 | self.bn = nn.BatchNorm2d( 19 | out_planes, 20 | eps=0.001, # value found in tensorflow 21 | momentum=0.1, # default pytorch value 22 | affine=True 23 | ) 24 | self.relu = nn.ReLU(inplace=False) 25 | 26 | def forward(self, x): 27 | x = self.conv(x) 28 | x = self.bn(x) 29 | x = self.relu(x) 30 | return x 31 | 32 | 33 | class Block35(nn.Module): 34 | 35 | def __init__(self, scale=1.0): 36 | super().__init__() 37 | 38 | self.scale = scale 39 | 40 | self.branch0 = BasicConv2d(256, 32, kernel_size=1, stride=1) 41 | 42 | self.branch1 = nn.Sequential( 43 | BasicConv2d(256, 32, kernel_size=1, stride=1), 44 | BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1) 45 | ) 46 | 47 | self.branch2 = nn.Sequential( 48 | BasicConv2d(256, 32, kernel_size=1, stride=1), 49 | BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1), 50 | BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1) 51 | ) 52 | 53 | self.conv2d = nn.Conv2d(96, 256, kernel_size=1, stride=1) 54 | self.relu = nn.ReLU(inplace=False) 55 | 56 | def forward(self, x): 57 | x0 = self.branch0(x) 58 | x1 = self.branch1(x) 59 | x2 = self.branch2(x) 60 | out = torch.cat((x0, x1, x2), 1) 61 | out = self.conv2d(out) 62 | out = out * self.scale + x 63 | out = self.relu(out) 64 | return out 65 | 66 | 67 | class Block17(nn.Module): 68 | 69 | def __init__(self, scale=1.0): 70 | super().__init__() 71 | 72 | self.scale = scale 73 | 74 | self.branch0 = BasicConv2d(896, 128, kernel_size=1, stride=1) 75 | 76 | self.branch1 = nn.Sequential( 77 | BasicConv2d(896, 128, kernel_size=1, stride=1), 78 | BasicConv2d(128, 128, kernel_size=(1,7), stride=1, padding=(0,3)), 79 | BasicConv2d(128, 128, kernel_size=(7,1), stride=1, padding=(3,0)) 80 | ) 81 | 82 | self.conv2d = nn.Conv2d(256, 896, kernel_size=1, stride=1) 83 | self.relu = nn.ReLU(inplace=False) 84 | 85 | def forward(self, x): 86 | x0 = self.branch0(x) 87 | x1 = self.branch1(x) 88 | out = torch.cat((x0, x1), 1) 89 | out = self.conv2d(out) 90 | out = out * self.scale + x 91 | out = self.relu(out) 92 | return out 93 | 94 | 95 | class Block8(nn.Module): 96 | 97 | def __init__(self, scale=1.0, noReLU=False): 98 | super().__init__() 99 | 100 | self.scale = scale 101 | self.noReLU = noReLU 102 | 103 | self.branch0 = BasicConv2d(1792, 192, kernel_size=1, stride=1) 104 | 105 | self.branch1 = nn.Sequential( 106 | BasicConv2d(1792, 192, kernel_size=1, stride=1), 107 | BasicConv2d(192, 192, kernel_size=(1,3), stride=1, padding=(0,1)), 108 | BasicConv2d(192, 192, kernel_size=(3,1), stride=1, padding=(1,0)) 109 | ) 110 | 111 | self.conv2d = nn.Conv2d(384, 1792, kernel_size=1, stride=1) 112 | if not self.noReLU: 113 | self.relu = nn.ReLU(inplace=False) 114 | 115 | def forward(self, x): 116 | x0 = self.branch0(x) 117 | x1 = self.branch1(x) 118 | out = torch.cat((x0, x1), 1) 119 | out = self.conv2d(out) 120 | out = out * self.scale + x 121 | if not self.noReLU: 122 | out = self.relu(out) 123 | return out 124 | 125 | 126 | class Mixed_6a(nn.Module): 127 | 128 | def __init__(self): 129 | super().__init__() 130 | 131 | self.branch0 = BasicConv2d(256, 384, kernel_size=3, stride=2) 132 | 133 | self.branch1 = nn.Sequential( 134 | BasicConv2d(256, 192, kernel_size=1, stride=1), 135 | BasicConv2d(192, 192, kernel_size=3, stride=1, padding=1), 136 | BasicConv2d(192, 256, kernel_size=3, stride=2) 137 | ) 138 | 139 | self.branch2 = nn.MaxPool2d(3, 
stride=2) 140 | 141 | def forward(self, x): 142 | x0 = self.branch0(x) 143 | x1 = self.branch1(x) 144 | x2 = self.branch2(x) 145 | out = torch.cat((x0, x1, x2), 1) 146 | return out 147 | 148 | 149 | class Mixed_7a(nn.Module): 150 | 151 | def __init__(self): 152 | super().__init__() 153 | 154 | self.branch0 = nn.Sequential( 155 | BasicConv2d(896, 256, kernel_size=1, stride=1), 156 | BasicConv2d(256, 384, kernel_size=3, stride=2) 157 | ) 158 | 159 | self.branch1 = nn.Sequential( 160 | BasicConv2d(896, 256, kernel_size=1, stride=1), 161 | BasicConv2d(256, 256, kernel_size=3, stride=2) 162 | ) 163 | 164 | self.branch2 = nn.Sequential( 165 | BasicConv2d(896, 256, kernel_size=1, stride=1), 166 | BasicConv2d(256, 256, kernel_size=3, stride=1, padding=1), 167 | BasicConv2d(256, 256, kernel_size=3, stride=2) 168 | ) 169 | 170 | self.branch3 = nn.MaxPool2d(3, stride=2) 171 | 172 | def forward(self, x): 173 | x0 = self.branch0(x) 174 | x1 = self.branch1(x) 175 | x2 = self.branch2(x) 176 | x3 = self.branch3(x) 177 | out = torch.cat((x0, x1, x2, x3), 1) 178 | return out 179 | 180 | 181 | class InceptionResnetV1(nn.Module): 182 | """Inception Resnet V1 model with optional loading of pretrained weights. 183 | 184 | Model parameters can be loaded based on pretraining on the VGGFace2 or CASIA-Webface 185 | datasets. Pretrained state_dicts are automatically downloaded on model instantiation if 186 | requested and cached in the torch cache. Subsequent instantiations use the cache rather than 187 | redownloading. 188 | 189 | Keyword Arguments: 190 | pretrained {str} -- Optional pretraining dataset. Either 'vggface2' or 'casia-webface'. 191 | (default: {None}) 192 | classify {bool} -- Whether the model should output classification probabilities or feature 193 | embeddings. (default: {False}) 194 | num_classes {int} -- Number of output classes. If 'pretrained' is set and num_classes not 195 | equal to that used for the pretrained model, the final linear layer will be randomly 196 | initialized. (default: {None}) 197 | dropout_prob {float} -- Dropout probability. 
(default: {0.6}) 198 | """ 199 | def __init__(self, pretrained=None, classify=False, num_classes=None, dropout_prob=0.6, device=None): 200 | super().__init__() 201 | 202 | # Set simple attributes 203 | self.pretrained = pretrained 204 | self.classify = classify 205 | self.num_classes = num_classes 206 | 207 | if pretrained == 'vggface2': 208 | tmp_classes = 8631 209 | elif pretrained == 'casia-webface': 210 | tmp_classes = 10575 211 | elif pretrained is None and self.num_classes is None: 212 | raise Exception('At least one of "pretrained" or "num_classes" must be specified') 213 | else: 214 | tmp_classes = self.num_classes 215 | 216 | 217 | # Define layers 218 | self.conv2d_1a = BasicConv2d(3, 32, kernel_size=3, stride=2) 219 | self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1) 220 | self.conv2d_2b = BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1) 221 | self.maxpool_3a = nn.MaxPool2d(3, stride=2) 222 | self.conv2d_3b = BasicConv2d(64, 80, kernel_size=1, stride=1) 223 | self.conv2d_4a = BasicConv2d(80, 192, kernel_size=3, stride=1) 224 | self.conv2d_4b = BasicConv2d(192, 256, kernel_size=3, stride=2) 225 | self.repeat_1 = nn.Sequential( 226 | Block35(scale=0.17), 227 | Block35(scale=0.17), 228 | Block35(scale=0.17), 229 | Block35(scale=0.17), 230 | Block35(scale=0.17), 231 | ) 232 | self.mixed_6a = Mixed_6a() 233 | self.repeat_2 = nn.Sequential( 234 | Block17(scale=0.10), 235 | Block17(scale=0.10), 236 | Block17(scale=0.10), 237 | Block17(scale=0.10), 238 | Block17(scale=0.10), 239 | Block17(scale=0.10), 240 | Block17(scale=0.10), 241 | Block17(scale=0.10), 242 | Block17(scale=0.10), 243 | Block17(scale=0.10), 244 | ) 245 | self.mixed_7a = Mixed_7a() 246 | self.repeat_3 = nn.Sequential( 247 | Block8(scale=0.20), 248 | Block8(scale=0.20), 249 | Block8(scale=0.20), 250 | Block8(scale=0.20), 251 | Block8(scale=0.20), 252 | ) 253 | self.block8 = Block8(noReLU=True) 254 | self.avgpool_1a = nn.AdaptiveAvgPool2d(1) 255 | self.dropout = nn.Dropout(dropout_prob) 256 | self.last_linear = nn.Linear(1792, 512, bias=False) 257 | self.last_bn = nn.BatchNorm1d(512, eps=0.001, momentum=0.1, affine=True) 258 | self.logits = nn.Linear(512, tmp_classes) 259 | 260 | if pretrained is not None: 261 | load_weights(self, pretrained) 262 | 263 | if self.num_classes is not None: 264 | self.logits = nn.Linear(512, self.num_classes) 265 | 266 | self.device = torch.device('cpu') 267 | if device is not None: 268 | self.device = device 269 | self.to(device) 270 | 271 | def forward(self, x): 272 | """Calculate embeddings or probabilities given a batch of input image tensors. 273 | 274 | Arguments: 275 | x {torch.tensor} -- Batch of image tensors representing faces. 276 | 277 | Returns: 278 | torch.tensor -- Batch of embeddings or softmax probabilities. 279 | """ 280 | x = self.conv2d_1a(x) 281 | x = self.conv2d_2a(x) 282 | x = self.conv2d_2b(x) 283 | x = self.maxpool_3a(x) 284 | x = self.conv2d_3b(x) 285 | x = self.conv2d_4a(x) 286 | x = self.conv2d_4b(x) 287 | x = self.repeat_1(x) 288 | x = self.mixed_6a(x) 289 | x = self.repeat_2(x) 290 | x = self.mixed_7a(x) 291 | x = self.repeat_3(x) 292 | x = self.block8(x) 293 | x = self.avgpool_1a(x) 294 | x = self.dropout(x) 295 | x = self.last_linear(x.view(x.shape[0], -1)) 296 | x = self.last_bn(x) 297 | x = F.normalize(x, p=2, dim=1) 298 | if self.classify: 299 | x = self.logits(x) 300 | return x 301 | 302 | 303 | def load_weights(mdl, name): 304 | """Download pretrained state_dict and load into model. 
305 | 306 | Arguments: 307 | mdl {torch.nn.Module} -- Pytorch model. 308 | name {str} -- Name of dataset that was used to generate pretrained state_dict. 309 | 310 | Raises: 311 | ValueError: If 'pretrained' not equal to 'vggface2' or 'casia-webface'. 312 | """ 313 | if name == 'vggface2': 314 | features_path = 'https://drive.google.com/uc?export=download&id=1cWLH_hPns8kSfMz9kKl9PsG5aNV2VSMn' 315 | logits_path = 'https://drive.google.com/uc?export=download&id=1mAie3nzZeno9UIzFXvmVZrDG3kwML46X' 316 | elif name == 'casia-webface': 317 | features_path = 'https://drive.google.com/uc?export=download&id=1LSHHee_IQj5W3vjBcRyVaALv4py1XaGy' 318 | logits_path = 'https://drive.google.com/uc?export=download&id=1QrhPgn1bGlDxAil2uc07ctunCQoDnCzT' 319 | else: 320 | raise ValueError('Pretrained models only exist for "vggface2" and "casia-webface"') 321 | 322 | model_dir = os.path.join(get_torch_home(), 'checkpoints') 323 | os.makedirs(model_dir, exist_ok=True) 324 | 325 | state_dict = {} 326 | for i, path in enumerate([features_path, logits_path]): 327 | cached_file = os.path.join(model_dir, '{}_{}.pt'.format(name, path[-10:])) 328 | if not os.path.exists(cached_file): 329 | print('Downloading parameters ({}/2)'.format(i+1)) 330 | s = requests.Session() 331 | s.mount('https://', HTTPAdapter(max_retries=10)) 332 | r = s.get(path, allow_redirects=True) 333 | with open(cached_file, 'wb') as f: 334 | f.write(r.content) 335 | state_dict.update(torch.load(cached_file)) 336 | 337 | mdl.load_state_dict(state_dict) 338 | 339 | 340 | def get_torch_home(): 341 | torch_home = os.path.expanduser( 342 | os.getenv( 343 | 'TORCH_HOME', 344 | os.path.join(os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch') 345 | ) 346 | ) 347 | return torch_home 348 | -------------------------------------------------------------------------------- /perturbnet/net2net/modules/facenet/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | 5 | from perturbnet.net2net.modules.facenet.inception_resnet_v1 import InceptionResnetV1 6 | 7 | 8 | """FaceNet adopted from https://github.com/timesler/facenet-pytorch""" 9 | 10 | 11 | class FaceNet(nn.Module): 12 | def __init__(self): 13 | super().__init__() 14 | # InceptionResnetV1 has a bottleneck of size 512 15 | self.net = InceptionResnetV1(pretrained='vggface2').eval() 16 | 17 | def _pre_process(self, x): 18 | # TODO: neccessary for InceptionResnetV1? 
19 | # seems like mtcnn (multi-task cnn) preprocessing is neccessary, but not 100% sure 20 | return x 21 | 22 | def forward(self, x, return_logits=False): 23 | # output are logits of size 8631 or embeddings of size 512 24 | x = self._pre_process(x) 25 | emb = self.net(x) 26 | if return_logits: 27 | return self.net.logits(emb) 28 | return emb 29 | 30 | def encode(self, x): 31 | return self(x) 32 | 33 | def return_features(self, x): 34 | """ returned features have the following dimensions: 35 | 36 | torch.Size([11, 3, 128, 128]), x 49152 37 | torch.Size([11, 192, 28, 28]), x 150528 38 | torch.Size([11, 896, 6, 6]), x 32256 39 | torch.Size([11, 1792, 1, 1]), x 1792 40 | torch.Size([11, 512]) x 512 41 | logits (8xxx) x 8xxx 42 | """ 43 | 44 | x = self._pre_process(x) 45 | features = [x] # this 46 | x = self.net.conv2d_1a(x) 47 | x = self.net.conv2d_2a(x) 48 | x = self.net.conv2d_2b(x) 49 | x = self.net.maxpool_3a(x) 50 | x = self.net.conv2d_3b(x) 51 | x = self.net.conv2d_4a(x) 52 | features.append(x) # this 53 | x = self.net.conv2d_4b(x) 54 | x = self.net.repeat_1(x) 55 | x = self.net.mixed_6a(x) 56 | features.append(x) # this 57 | x = self.net.repeat_2(x) 58 | x = self.net.mixed_7a(x) 59 | x = self.net.repeat_3(x) 60 | x = self.net.block8(x) 61 | x = self.net.avgpool_1a(x) 62 | features.append(x) # this 63 | x = self.net.dropout(x) 64 | x = self.net.last_linear(x.view(x.shape[0], -1)) 65 | x = self.net.last_bn(x) 66 | emb = F.normalize(x, p=2, dim=1) # the final embeddings 67 | features.append(emb[..., None, None]) # need extra dimensions for flow later 68 | features.append(self.net.logits(emb).unsqueeze(-1).unsqueeze(-1)) 69 | return features # has 6 elements as of now 70 | 71 | -------------------------------------------------------------------------------- /perturbnet/net2net/modules/flow/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/welch-lab/PerturbNet/96f38d8e2629cc4416c0f4c8e2051b16ec2a4816/perturbnet/net2net/modules/flow/__init__.py -------------------------------------------------------------------------------- /perturbnet/net2net/modules/flow/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/welch-lab/PerturbNet/96f38d8e2629cc4416c0f4c8e2051b16ec2a4816/perturbnet/net2net/modules/flow/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /perturbnet/net2net/modules/flow/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/welch-lab/PerturbNet/96f38d8e2629cc4416c0f4c8e2051b16ec2a4816/perturbnet/net2net/modules/flow/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /perturbnet/net2net/modules/flow/__pycache__/base.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/welch-lab/PerturbNet/96f38d8e2629cc4416c0f4c8e2051b16ec2a4816/perturbnet/net2net/modules/flow/__pycache__/base.cpython-37.pyc -------------------------------------------------------------------------------- /perturbnet/net2net/modules/flow/__pycache__/base.cpython-39.pyc: -------------------------------------------------------------------------------- 
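
The `FaceNet` wrapper in `facenet/model.py` above exposes both a single 512-dimensional embedding (`forward`/`encode`) and, via `return_features`, the pyramid of intermediate activations listed in its docstring. A minimal usage sketch (pretrained VGGFace2 weights are downloaded on first instantiation; the [-1, 1] input range and 128x128 crop size follow the docstring and are assumptions, not checked by the module):

```
import torch
from perturbnet.net2net.modules.facenet.model import FaceNet

facenet = FaceNet()  # wraps InceptionResnetV1(pretrained='vggface2').eval()

x = torch.rand(2, 3, 128, 128) * 2 - 1       # face crops scaled to [-1, 1] (assumed)
with torch.no_grad():
    emb = facenet.encode(x)                  # (2, 512) L2-normalized embeddings
    feats = facenet.return_features(x)       # list of 6 feature tensors, coarse to fine

print(emb.shape, len(feats))
```
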
https://raw.githubusercontent.com/welch-lab/PerturbNet/96f38d8e2629cc4416c0f4c8e2051b16ec2a4816/perturbnet/net2net/modules/flow/__pycache__/base.cpython-39.pyc -------------------------------------------------------------------------------- /perturbnet/net2net/modules/flow/__pycache__/blocks.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/welch-lab/PerturbNet/96f38d8e2629cc4416c0f4c8e2051b16ec2a4816/perturbnet/net2net/modules/flow/__pycache__/blocks.cpython-37.pyc -------------------------------------------------------------------------------- /perturbnet/net2net/modules/flow/__pycache__/blocks.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/welch-lab/PerturbNet/96f38d8e2629cc4416c0f4c8e2051b16ec2a4816/perturbnet/net2net/modules/flow/__pycache__/blocks.cpython-39.pyc -------------------------------------------------------------------------------- /perturbnet/net2net/modules/flow/__pycache__/flatflow.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/welch-lab/PerturbNet/96f38d8e2629cc4416c0f4c8e2051b16ec2a4816/perturbnet/net2net/modules/flow/__pycache__/flatflow.cpython-37.pyc -------------------------------------------------------------------------------- /perturbnet/net2net/modules/flow/__pycache__/flatflow.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/welch-lab/PerturbNet/96f38d8e2629cc4416c0f4c8e2051b16ec2a4816/perturbnet/net2net/modules/flow/__pycache__/flatflow.cpython-39.pyc -------------------------------------------------------------------------------- /perturbnet/net2net/modules/flow/__pycache__/loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/welch-lab/PerturbNet/96f38d8e2629cc4416c0f4c8e2051b16ec2a4816/perturbnet/net2net/modules/flow/__pycache__/loss.cpython-37.pyc -------------------------------------------------------------------------------- /perturbnet/net2net/modules/flow/__pycache__/loss.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/welch-lab/PerturbNet/96f38d8e2629cc4416c0f4c8e2051b16ec2a4816/perturbnet/net2net/modules/flow/__pycache__/loss.cpython-39.pyc -------------------------------------------------------------------------------- /perturbnet/net2net/modules/flow/base.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class NormalizingFlow(nn.Module): 5 | def __init__(self, *args, **kwargs): 6 | super().__init__() 7 | 8 | def forward(self, *args, **kwargs): 9 | # return transformed, logdet 10 | raise NotImplementedError 11 | 12 | def reverse(self, *args, **kwargs): 13 | # return transformed_reverse 14 | raise NotImplementedError 15 | 16 | def sample(self, *args, **kwargs): 17 | # return sample 18 | raise NotImplementedError 19 | -------------------------------------------------------------------------------- /perturbnet/net2net/modules/flow/flatflow.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | from perturbnet.net2net.modules.autoencoder.basic import ActNorm 6 | from 
perturbnet.net2net.modules.flow.blocks import UnconditionalFlatDoubleCouplingFlowBlock, PureAffineDoubleCouplingFlowBlock, \ 7 | ConditionalFlatDoubleCouplingFlowBlock 8 | from perturbnet.net2net.modules.flow.base import NormalizingFlow 9 | from perturbnet.net2net.modules.autoencoder.basic import FeatureLayer, DenseEncoderLayer 10 | from perturbnet.net2net.modules.flow.blocks import BasicFullyConnectedNet 11 | 12 | 13 | class UnconditionalFlatCouplingFlow(NormalizingFlow): 14 | """Flat, multiple blocks of ActNorm, DoubleAffineCoupling, Shuffle""" 15 | def __init__(self, in_channels, n_flows, hidden_dim, hidden_depth): 16 | super().__init__() 17 | self.in_channels = in_channels 18 | self.mid_channels = hidden_dim 19 | self.num_blocks = hidden_depth 20 | self.n_flows = n_flows 21 | self.sub_layers = nn.ModuleList() 22 | 23 | for flow in range(self.n_flows): 24 | self.sub_layers.append(UnconditionalFlatDoubleCouplingFlowBlock( 25 | self.in_channels, self.mid_channels, 26 | self.num_blocks) 27 | ) 28 | 29 | def forward(self, x, reverse=False): 30 | if len(x.shape) == 2: 31 | x = x[:,:,None,None] 32 | self.last_outs = [] 33 | self.last_logdets = [] 34 | if not reverse: 35 | logdet = 0.0 36 | for i in range(self.n_flows): 37 | x, logdet_ = self.sub_layers[i](x) 38 | logdet = logdet + logdet_ 39 | self.last_outs.append(x) 40 | self.last_logdets.append(logdet) 41 | return x, logdet 42 | else: 43 | for i in reversed(range(self.n_flows)): 44 | x = self.sub_layers[i](x, reverse=True) 45 | return x 46 | 47 | def reverse(self, out): 48 | if len(out.shape) == 2: 49 | out = out[:,:,None,None] 50 | return self(out, reverse=True) 51 | 52 | def sample(self, num_samples, device="cpu"): 53 | zz = torch.randn(num_samples, self.in_channels, 1, 1).to(device) 54 | return self.reverse(zz) 55 | 56 | def get_last_layer(self): 57 | return getattr(self.sub_layers[-1].coupling.t[-1].main[-1], 'weight') 58 | 59 | 60 | class PureAffineFlatCouplingFlow(UnconditionalFlatCouplingFlow): 61 | """Flat, multiple blocks of DoubleAffineCoupling""" 62 | def __init__(self, in_channels, n_flows, hidden_dim, hidden_depth): 63 | super().__init__(in_channels, n_flows, hidden_dim, hidden_depth) 64 | del self.sub_layers 65 | self.sub_layers = nn.ModuleList() 66 | for flow in range(self.n_flows): 67 | self.sub_layers.append(PureAffineDoubleCouplingFlowBlock( 68 | self.in_channels, self.mid_channels, 69 | self.num_blocks) 70 | ) 71 | 72 | 73 | class DenseEmbedder(nn.Module): 74 | """Supposed to map small-scale features (e.g. 
labels) to some given latent dim""" 75 | def __init__(self, in_dim, up_dim, depth=4, given_dims=None): 76 | super().__init__() 77 | self.net = nn.ModuleList() 78 | if given_dims is not None: 79 | assert given_dims[0] == in_dim 80 | assert given_dims[-1] == up_dim 81 | dims = given_dims 82 | else: 83 | dims = np.linspace(in_dim, up_dim, depth).astype(int) 84 | for l in range(len(dims)-2): 85 | self.net.append(nn.Conv2d(dims[l], dims[l + 1], 1)) 86 | self.net.append(ActNorm(dims[l + 1])) 87 | self.net.append(nn.LeakyReLU(0.2)) 88 | 89 | self.net.append(nn.Conv2d(dims[-2], dims[-1], 1)) 90 | 91 | def forward(self, x): 92 | for layer in self.net: 93 | x = layer(x) 94 | return x.squeeze(-1).squeeze(-1) 95 | 96 | 97 | class Embedder(nn.Module): 98 | """Embeds a 4-dim tensor onto dense latent code, much like the classic encoder.""" 99 | def __init__(self, in_spatial_size, in_channels, emb_dim, n_down=4): 100 | super().__init__() 101 | self.feature_layers = nn.ModuleList() 102 | norm = 'an' # hard coded yes 103 | bottleneck_size = in_spatial_size // 2**n_down 104 | self.feature_layers.append(FeatureLayer(0, in_channels=in_channels, norm=norm)) 105 | for scale in range(1, n_down): 106 | self.feature_layers.append(FeatureLayer(scale, norm=norm)) 107 | self.dense_encode = DenseEncoderLayer(n_down, bottleneck_size, emb_dim) 108 | if n_down == 1: 109 | # add some extra parameters to make model a little more powerful ? 110 | print(" Warning: Embedder for ConditionalTransformer has only one down-sampling step. You might want to " 111 | "increase its capacity.") 112 | 113 | def forward(self, input): 114 | h = input 115 | for layer in self.feature_layers: 116 | h = layer(h) 117 | h = self.dense_encode(h) 118 | return h.squeeze(-1).squeeze(-1) 119 | 120 | 121 | class ConditionalFlatCouplingFlow(nn.Module): 122 | """Flat version. Feeds an embedding into the flow in every block""" 123 | def __init__(self, in_channels, conditioning_dim, embedding_dim, hidden_dim, hidden_depth, 124 | n_flows, conditioning_option="none", activation='lrelu', 125 | conditioning_hidden_dim=256, conditioning_depth=2, conditioner_use_bn=False, 126 | conditioner_use_an=False): 127 | super().__init__() 128 | self.in_channels = in_channels 129 | self.cond_channels = embedding_dim 130 | self.mid_channels = hidden_dim 131 | self.num_blocks = hidden_depth 132 | self.n_flows = n_flows 133 | self.conditioning_option = conditioning_option 134 | # TODO: also for spatial inputs... 135 | if conditioner_use_bn: 136 | assert not conditioner_use_an, 'Can not use ActNorm and BatchNorm simultaneously in Embedder.' 137 | print("Note: Conditioning network uses batch-normalization. 
" 138 | "Make sure to train with a sufficiently large batch size") 139 | 140 | self.embedder = BasicFullyConnectedNet(dim=conditioning_dim, 141 | depth=conditioning_depth, 142 | out_dim=embedding_dim, 143 | hidden_dim=conditioning_hidden_dim, 144 | use_bn=conditioner_use_bn, 145 | use_an=conditioner_use_an) 146 | 147 | self.sub_layers = nn.ModuleList() 148 | if self.conditioning_option.lower() != "none": 149 | self.conditioning_layers = nn.ModuleList() 150 | for flow in range(self.n_flows): 151 | self.sub_layers.append(ConditionalFlatDoubleCouplingFlowBlock( 152 | self.in_channels, self.cond_channels, self.mid_channels, 153 | self.num_blocks, activation = activation) 154 | ) 155 | if self.conditioning_option.lower() != "none": 156 | self.conditioning_layers.append(nn.Conv2d(self.cond_channels, self.cond_channels, 1)) 157 | 158 | def forward(self, x, cond, reverse=False): 159 | hconds = list() 160 | if len(cond.shape) == 4: 161 | if cond.shape[2] == 1: 162 | assert cond.shape[3] == 1 163 | cond = cond.squeeze(-1).squeeze(-1) 164 | else: 165 | raise ValueError("Spatial conditionings not yet supported. TODO") 166 | embedding = self.embedder(cond.float()) 167 | hcond = embedding[:, :, None, None] 168 | for i in range(self.n_flows): 169 | if self.conditioning_option.lower() == "parallel": 170 | hcond = self.conditioning_layers[i](embedding) 171 | elif self.conditioning_option.lower() == "sequential": 172 | hcond = self.conditioning_layers[i](hcond) 173 | hconds.append(hcond) 174 | if not reverse: 175 | logdet = 0.0 176 | for i in range(self.n_flows): 177 | x, logdet_ = self.sub_layers[i](x, hconds[i]) 178 | logdet = logdet + logdet_ 179 | return x, logdet 180 | else: 181 | for i in reversed(range(self.n_flows)): 182 | x = self.sub_layers[i](x, hconds[i], reverse=True) 183 | return x 184 | 185 | def reverse(self, out, xcond): 186 | return self(out, xcond, reverse=True) 187 | 188 | def sample(self, xc): 189 | zz = torch.randn(xc.shape[0], self.in_channels, 1, 1).to(xc) 190 | return self.reverse(zz, xc) 191 | 192 | 193 | -------------------------------------------------------------------------------- /perturbnet/net2net/modules/flow/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def nll(sample): 6 | if len(sample.shape) == 2: 7 | sample = sample[:,:,None,None] 8 | return 0.5*torch.sum(torch.pow(sample, 2), dim=[1,2,3]) 9 | 10 | 11 | class NLL(nn.Module): 12 | def __init__(self): 13 | super().__init__() 14 | 15 | def forward(self, sample, logdet, split = "train"): 16 | nll_loss = torch.mean(nll(sample)) 17 | assert len(logdet.shape) == 1 18 | nlogdet_loss = -torch.mean(logdet) 19 | loss = nll_loss + nlogdet_loss 20 | reference_nll_loss = torch.mean(nll(torch.randn_like(sample))) 21 | log = {f"{split}/total_loss": loss, f"{split}/reference_nll_loss": reference_nll_loss, 22 | f"{split}/nlogdet_loss": nlogdet_loss, f"{split}/nll_loss": nll_loss, 23 | } 24 | return loss, log -------------------------------------------------------------------------------- /perturbnet/net2net/modules/gan/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/welch-lab/PerturbNet/96f38d8e2629cc4416c0f4c8e2051b16ec2a4816/perturbnet/net2net/modules/gan/__init__.py -------------------------------------------------------------------------------- /perturbnet/net2net/modules/gan/bigbigan.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | import tensorflow.compat.v1 as tf 5 | tf.disable_v2_behavior() 6 | 7 | import tensorflow_hub as hub 8 | 9 | 10 | class BigBiGAN(object): 11 | def __init__(self, 12 | module_path='https://tfhub.dev/deepmind/bigbigan-resnet50/1', 13 | allow_growth=True): 14 | """Initialize a BigBiGAN from the given TF Hub module.""" 15 | self._module = hub.Module(module_path) 16 | 17 | # encode graph 18 | self.enc_ph = self.make_encoder_ph() 19 | self.z_sample = self.encode_graph(self.enc_ph) 20 | self.z_mean = self.encode_graph(self.enc_ph, return_all_features=True)['z_mean'] 21 | 22 | # decode graph 23 | self.gen_ph = self.make_generator_ph() 24 | self.gen_samples = self.generate_graph(self.gen_ph, upsample=True) 25 | 26 | # session 27 | init = tf.global_variables_initializer() 28 | gpu_options = tf.GPUOptions(allow_growth=allow_growth) 29 | self.sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) 30 | self.sess.run(init) 31 | 32 | def generate_graph(self, z, upsample=False): 33 | """Run a batch of latents z through the generator to generate images. 34 | 35 | Args: 36 | z: A batch of 120D Gaussian latents, shape [N, 120]. 37 | 38 | Returns: a batch of generated RGB images, shape [N, 128, 128, 3], range 39 | [-1, 1]. 40 | """ 41 | outputs = self._module(z, signature='generate', as_dict=True) 42 | return outputs['upsampled' if upsample else 'default'] 43 | 44 | def make_generator_ph(self): 45 | """Creates a tf.placeholder with the dtype & shape of generator inputs.""" 46 | info = self._module.get_input_info_dict('generate')['z'] 47 | return tf.placeholder(dtype=info.dtype, shape=info.get_shape()) 48 | 49 | def encode_graph(self, x, return_all_features=False): 50 | """Run a batch of images x through the encoder. 51 | 52 | Args: 53 | x: A batch of data (256x256 RGB images), shape [N, 256, 256, 3], range 54 | [-1, 1]. 55 | return_all_features: If True, return all features computed by the encoder. 56 | Otherwise (default) just return a sample z_hat. 57 | 58 | Returns: the sample z_hat of shape [N, 120] (or a dict of all features if 59 | return_all_features). 
60 | """ 61 | outputs = self._module(x, signature='encode', as_dict=True) 62 | return outputs if return_all_features else outputs['z_sample'] 63 | 64 | def make_encoder_ph(self): 65 | """Creates a tf.placeholder with the dtype & shape of encoder inputs.""" 66 | info = self._module.get_input_info_dict('encode')['x'] 67 | return tf.placeholder(dtype=info.dtype, shape=info.get_shape()) 68 | 69 | @torch.no_grad() 70 | def encode(self, x_torch): 71 | x_np = x_torch.detach().permute(0,2,3,1).cpu().numpy() 72 | feed_dict = {self.enc_ph: x_np} 73 | z = self.sess.run(self.z_sample, feed_dict=feed_dict) 74 | z_torch = torch.tensor(z).to(device=x_torch.device) 75 | return z_torch.unsqueeze(-1).unsqueeze(-1) 76 | 77 | @torch.no_grad() 78 | def decode(self, z_torch): 79 | z_np = z_torch.detach().squeeze(-1).squeeze(-1).cpu().numpy() 80 | feed_dict = {self.gen_ph: z_np} 81 | x = self.sess.run(self.gen_samples, feed_dict=feed_dict) 82 | x = x.transpose(0,3,1,2) 83 | x_torch = torch.tensor(x).to(device=z_torch.device) 84 | return x_torch 85 | 86 | def eval(self): 87 | # interface requirement 88 | return self 89 | -------------------------------------------------------------------------------- /perturbnet/net2net/modules/labels/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/welch-lab/PerturbNet/96f38d8e2629cc4416c0f4c8e2051b16ec2a4816/perturbnet/net2net/modules/labels/__init__.py -------------------------------------------------------------------------------- /perturbnet/net2net/modules/labels/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | 7 | class Labelator(nn.Module): 8 | def __init__(self, num_classes, as_one_hot=True): 9 | super().__init__() 10 | self.num_classes = num_classes 11 | self.as_one_hot = as_one_hot 12 | 13 | def encode(self, x): 14 | if self.as_one_hot: 15 | x = self.make_one_hot(x) 16 | return x 17 | 18 | def other_label(self, given_label): 19 | # if only two classes are present, inverts them 20 | others = [] 21 | for l in given_label: 22 | other = int(np.random.choice(np.arange(self.num_classes))) 23 | while other == l: 24 | other = int(np.random.choice(np.arange(self.num_classes))) 25 | others.append(other) 26 | return torch.LongTensor(others) 27 | 28 | def make_one_hot(self, label): 29 | one_hot = F.one_hot(label, num_classes=self.num_classes) 30 | return one_hot 31 | -------------------------------------------------------------------------------- /perturbnet/net2net/modules/mlp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/welch-lab/PerturbNet/96f38d8e2629cc4416c0f4c8e2051b16ec2a4816/perturbnet/net2net/modules/mlp/__init__.py -------------------------------------------------------------------------------- /perturbnet/net2net/modules/mlp/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import torch.nn.functional as F 5 | 6 | 7 | class NeRFMLP(nn.Module): 8 | """basic MLP from the NERF paper""" 9 | def __init__(self, in_dim, out_dim, width=256): 10 | super().__init__() 11 | self.D = 8 12 | self.W = width # hidden dim 13 | self.skips = [4] 14 | self.n_in = in_dim 15 | self.out_dim = out_dim 16 | 17 | self.layers = nn.ModuleList() 18 | 
self.layers.append(nn.Linear(self.n_in, self.W)) 19 | for i in range(1, self.D): 20 | if i-1 in self.skips: 21 | nin = self.W + self.n_in 22 | else: 23 | nin = self.W 24 | self.layers.append(nn.Linear(nin, self.W)) 25 | self.out_layer = nn.Linear(self.W, self.out_dim) 26 | 27 | def forward(self, z): 28 | h = z 29 | for i in range(self.D): 30 | h = self.layers[i](h) 31 | h = F.relu(h) 32 | if i in self.skips: 33 | h = torch.cat([h,z], dim=1) 34 | h = self.out_layer(h) 35 | return h 36 | 37 | 38 | class SineLayer(nn.Module): 39 | # See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of omega_0. 40 | 41 | # If is_first=True, omega_0 is a frequency factor which simply multiplies the activations before the 42 | # nonlinearity. Different signals may require different omega_0 in the first layer - this is a 43 | # hyperparameter. 44 | 45 | # If is_first=False, then the weights will be divided by omega_0 so as to keep the magnitude of 46 | # activations constant, but boost gradients to the weight matrix (see supplement Sec. 1.5) 47 | 48 | def __init__(self, in_features, out_features, bias=True, 49 | is_first=False, omega_0=30): 50 | super().__init__() 51 | self.omega_0 = omega_0 52 | self.is_first = is_first 53 | self.in_features = in_features 54 | self.linear = nn.Linear(in_features, out_features, bias=bias) 55 | self.init_weights() 56 | 57 | def init_weights(self): 58 | with torch.no_grad(): 59 | if self.is_first: 60 | self.linear.weight.uniform_(-1 / self.in_features, 61 | 1 / self.in_features) 62 | else: 63 | self.linear.weight.uniform_(-np.sqrt(6 / self.in_features) / self.omega_0, 64 | np.sqrt(6 / self.in_features) / self.omega_0) 65 | 66 | def forward(self, input): 67 | return torch.sin(self.omega_0 * self.linear(input)) 68 | 69 | def forward_with_intermediate(self, input): 70 | # For visualization of activation distributions 71 | intermediate = self.omega_0 * self.linear(input) 72 | return torch.sin(intermediate), intermediate 73 | 74 | 75 | class Siren(nn.Module): 76 | def __init__(self, in_features, hidden_features, hidden_layers, out_features, outermost_linear=False, 77 | first_omega_0=30, hidden_omega_0=30.): 78 | super().__init__() 79 | 80 | self.net = [] 81 | self.net.append(SineLayer(in_features, hidden_features, 82 | is_first=True, omega_0=first_omega_0)) 83 | 84 | for i in range(hidden_layers): 85 | self.net.append(SineLayer(hidden_features, hidden_features, 86 | is_first=False, omega_0=hidden_omega_0)) 87 | 88 | if outermost_linear: 89 | final_linear = nn.Linear(hidden_features, out_features) 90 | 91 | with torch.no_grad(): 92 | final_linear.weight.uniform_(-np.sqrt(6 / hidden_features) / hidden_omega_0, 93 | np.sqrt(6 / hidden_features) / hidden_omega_0) 94 | 95 | self.net.append(final_linear) 96 | else: 97 | self.net.append(SineLayer(hidden_features, out_features, 98 | is_first=False, omega_0=hidden_omega_0)) 99 | 100 | self.net = nn.Sequential(*self.net) 101 | 102 | def forward(self, coords): 103 | #coords = coords.clone().detach().requires_grad_(True) # allows to take derivative w.r.t. input 104 | output = self.net(coords) 105 | #return output, coords 106 | return output 107 | 108 | def forward_with_activations(self, coords, retain_grad=False): 109 | '''Returns not only model output, but also intermediate activations. 
110 | Only used for visualizing activations later!''' 111 | activations = {}  # insertion-ordered dict; OrderedDict is not imported in this module 112 | activation_count = 0 113 | x = coords.clone().detach().requires_grad_(True) 114 | activations['input'] = x 115 | for i, layer in enumerate(self.net): 116 | if isinstance(layer, SineLayer): 117 | x, intermed = layer.forward_with_intermediate(x) 118 | 119 | if retain_grad: 120 | x.retain_grad() 121 | intermed.retain_grad() 122 | 123 | activations['_'.join((str(layer.__class__), "%d" % activation_count))] = intermed 124 | activation_count += 1 125 | else: 126 | x = layer(x) 127 | 128 | if retain_grad: 129 | x.retain_grad() 130 | 131 | activations['_'.join((str(layer.__class__), "%d" % activation_count))] = x 132 | activation_count += 1 133 | return activations 134 | 135 | 136 | # And finally, differential operators that allow us to leverage autograd to 137 | # compute gradients, the laplacian, etc. 138 | 139 | def laplace(y, x): 140 | grad = gradient(y, x) 141 | return divergence(grad, x) 142 | 143 | 144 | def divergence(y, x): 145 | div = 0. 146 | for i in range(y.shape[-1]): 147 | div += torch.autograd.grad(y[..., i], x, torch.ones_like(y[..., i]), create_graph=True)[0][..., i:i+1] 148 | return div 149 | 150 | 151 | def gradient(y, x, grad_outputs=None): 152 | if grad_outputs is None: 153 | grad_outputs = torch.ones_like(y) 154 | grad = torch.autograd.grad(y, [x], grad_outputs=grad_outputs, create_graph=True)[0] 155 | return grad 156 | 157 | 158 | if __name__ == "__main__": 159 | siren = Siren(2, 64, 2, 2) 160 | x = torch.randn(11, 2) 161 | x.requires_grad = True 162 | y = siren(x).mean() 163 | grad1 = torch.autograd.grad(y, x)[0] 164 | print(grad1.shape) 165 | print("done.") -------------------------------------------------------------------------------- /perturbnet/net2net/modules/sbert/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/welch-lab/PerturbNet/96f38d8e2629cc4416c0f4c8e2051b16ec2a4816/perturbnet/net2net/modules/sbert/__init__.py -------------------------------------------------------------------------------- /perturbnet/net2net/modules/sbert/model.py: -------------------------------------------------------------------------------- 1 | # check out https://github.com/UKPLab/sentence-transformers, 2 | # list of pretrained models @ https://www.sbert.net/docs/pretrained_models.html 3 | 4 | from sentence_transformers import SentenceTransformer 5 | import numpy as np 6 | import torch.nn as nn 7 | 8 | 9 | class SentenceEmbedder(nn.Module): 10 | def __init__(self, version='bert-large-nli-stsb-mean-tokens'): 11 | super().__init__() 12 | np.set_printoptions(threshold=100) 13 | # Load Sentence model (based on BERT) from URL 14 | self.model = SentenceTransformer(version, device="cuda") 15 | self.model.eval() 16 | 17 | def forward(self, sentences): 18 | """sentences are expected to be a list of strings, e.g. 19 | sentences = ['This framework generates embeddings for each input sentence', 20 | 'Sentences are passed as a list of string.', 21 | 'The quick brown fox jumps over the lazy dog.'
22 | ] 23 | """ 24 | sentence_embeddings = self.model.encode(sentences, batch_size=len(sentences), show_progress_bar=False, 25 | convert_to_tensor=True) 26 | return sentence_embeddings.cuda() 27 | 28 | def encode(self, sentences): 29 | embeddings = self(sentences) 30 | return embeddings[:,:,None,None] 31 | 32 | 33 | if __name__ == '__main__': 34 | model = SentenceEmbedder(version='distilroberta-base-paraphrase-v1') 35 | sentences = ['This framework generates embeddings for each input sentence', 36 | 'Sentences are passed as a list of string.', 37 | 'The quick brown fox jumps over the lazy dog.' 38 | ] 39 | emb = model.encode(sentences) 40 | print(emb.shape) 41 | print("done.") 42 | -------------------------------------------------------------------------------- /perturbnet/net2net/modules/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.nn import init 6 | from torchvision import models 7 | import os 8 | from PIL import Image, ImageDraw 9 | 10 | 11 | def log_txt_as_img(wh, xc): 12 | b = len(xc) 13 | txts = list() 14 | for bi in range(b): 15 | txt = Image.new("RGB", wh, color="white") 16 | draw = ImageDraw.Draw(txt) 17 | nc = int(40 * (wh[0]/256)) 18 | lines = "\n".join(xc[bi][start:start+nc] for start in range(0, len(xc[bi]), nc)) 19 | draw.text((0,0), lines, fill="black") 20 | txt = np.array(txt).transpose(2,0,1)/127.5-1.0 21 | txts.append(txt) 22 | txts = np.stack(txts) 23 | txts = torch.tensor(txts) 24 | return txts 25 | 26 | 27 | class Downscale(nn.Module): 28 | def __init__(self, mode="bilinear", size=32): 29 | super().__init__() 30 | self.mode = mode 31 | self.out_size = size 32 | 33 | def forward(self, x): 34 | x = F.interpolate(x, mode=self.mode, size=self.out_size) 35 | return x 36 | 37 | 38 | class DownscaleUpscale(nn.Module): 39 | def __init__(self, mode_down="bilinear", mode_up="bilinear", size=32): 40 | super().__init__() 41 | self.mode_down = mode_down 42 | self.mode_up = mode_up 43 | self.out_size = size 44 | 45 | def forward(self, x): 46 | assert len(x.shape) == 4 47 | z = F.interpolate(x, mode=self.mode_down, size=self.out_size) 48 | x = F.interpolate(z, mode=self.mode_up, size=x.shape[2]) 49 | return x 50 | 51 | 52 | class TpsGridGen(nn.Module): 53 | def __init__(self, out_h=256, out_w=192, use_regular_grid=True, grid_size=3, reg_factor=0, use_cuda=True): 54 | super(TpsGridGen, self).__init__() 55 | self.out_h, self.out_w = out_h, out_w 56 | self.reg_factor = reg_factor 57 | self.use_cuda = use_cuda 58 | 59 | # create grid in numpy 60 | self.grid = np.zeros( [self.out_h, self.out_w, 3], dtype=np.float32) 61 | # sampling grid with dim-0 coords (Y) 62 | self.grid_X,self.grid_Y = np.meshgrid(np.linspace(-1,1,out_w),np.linspace(-1,1,out_h)) 63 | # grid_X,grid_Y: size [1,H,W,1,1] 64 | self.grid_X = torch.FloatTensor(self.grid_X).unsqueeze(0).unsqueeze(3) 65 | self.grid_Y = torch.FloatTensor(self.grid_Y).unsqueeze(0).unsqueeze(3) 66 | if use_cuda: 67 | self.grid_X = self.grid_X.cuda() 68 | self.grid_Y = self.grid_Y.cuda() 69 | 70 | # initialize regular grid for control points P_i 71 | if use_regular_grid: 72 | axis_coords = np.linspace(-1,1,grid_size) 73 | self.N = grid_size*grid_size 74 | P_Y,P_X = np.meshgrid(axis_coords,axis_coords) 75 | P_X = np.reshape(P_X,(-1,1)) # size (N,1) 76 | P_Y = np.reshape(P_Y,(-1,1)) # size (N,1) 77 | P_X = torch.FloatTensor(P_X) 78 | P_Y = torch.FloatTensor(P_Y) 79 | self.P_X_base = 
P_X.clone() 80 | self.P_Y_base = P_Y.clone() 81 | self.Li = self.compute_L_inverse(P_X,P_Y).unsqueeze(0) 82 | self.P_X = P_X.unsqueeze(2).unsqueeze(3).unsqueeze(4).transpose(0,4) 83 | self.P_Y = P_Y.unsqueeze(2).unsqueeze(3).unsqueeze(4).transpose(0,4) 84 | if use_cuda: 85 | self.P_X = self.P_X.cuda() 86 | self.P_Y = self.P_Y.cuda() 87 | self.P_X_base = self.P_X_base.cuda() 88 | self.P_Y_base = self.P_Y_base.cuda() 89 | 90 | 91 | def forward(self, theta): 92 | warped_grid = self.apply_transformation(theta,torch.cat((self.grid_X,self.grid_Y),3)) 93 | 94 | return warped_grid 95 | 96 | def compute_L_inverse(self,X,Y): 97 | N = X.size()[0] # num of points (along dim 0) 98 | # construct matrix K 99 | Xmat = X.expand(N,N) 100 | Ymat = Y.expand(N,N) 101 | P_dist_squared = torch.pow(Xmat-Xmat.transpose(0,1),2)+torch.pow(Ymat-Ymat.transpose(0,1),2) 102 | P_dist_squared[P_dist_squared==0]=1 # make diagonal 1 to avoid NaN in log computation 103 | K = torch.mul(P_dist_squared,torch.log(P_dist_squared)) 104 | # construct matrix L 105 | O = torch.FloatTensor(N,1).fill_(1) 106 | Z = torch.FloatTensor(3,3).fill_(0) 107 | P = torch.cat((O,X,Y),1) 108 | L = torch.cat((torch.cat((K,P),1),torch.cat((P.transpose(0,1),Z),1)),0) 109 | Li = torch.inverse(L) 110 | if self.use_cuda: 111 | Li = Li.cuda() 112 | return Li 113 | 114 | def apply_transformation(self,theta,points): 115 | if theta.dim()==2: 116 | theta = theta.unsqueeze(2).unsqueeze(3) 117 | # points should be in the [B,H,W,2] format, 118 | # where points[:,:,:,0] are the X coords 119 | # and points[:,:,:,1] are the Y coords 120 | 121 | # input are the corresponding control points P_i 122 | batch_size = theta.size()[0] 123 | # split theta into point coordinates 124 | Q_X=theta[:,:self.N,:,:].squeeze(3) 125 | Q_Y=theta[:,self.N:,:,:].squeeze(3) 126 | Q_X = Q_X + self.P_X_base.expand_as(Q_X) 127 | Q_Y = Q_Y + self.P_Y_base.expand_as(Q_Y) 128 | 129 | # get spatial dimensions of points 130 | points_b = points.size()[0] 131 | points_h = points.size()[1] 132 | points_w = points.size()[2] 133 | 134 | # repeat pre-defined control points along spatial dimensions of points to be transformed 135 | P_X = self.P_X.expand((1,points_h,points_w,1,self.N)) 136 | P_Y = self.P_Y.expand((1,points_h,points_w,1,self.N)) 137 | 138 | # compute weigths for non-linear part 139 | W_X = torch.bmm(self.Li[:,:self.N,:self.N].expand((batch_size,self.N,self.N)),Q_X) 140 | W_Y = torch.bmm(self.Li[:,:self.N,:self.N].expand((batch_size,self.N,self.N)),Q_Y) 141 | # reshape 142 | # W_X,W,Y: size [B,H,W,1,N] 143 | W_X = W_X.unsqueeze(3).unsqueeze(4).transpose(1,4).repeat(1,points_h,points_w,1,1) 144 | W_Y = W_Y.unsqueeze(3).unsqueeze(4).transpose(1,4).repeat(1,points_h,points_w,1,1) 145 | # compute weights for affine part 146 | A_X = torch.bmm(self.Li[:,self.N:,:self.N].expand((batch_size,3,self.N)),Q_X) 147 | A_Y = torch.bmm(self.Li[:,self.N:,:self.N].expand((batch_size,3,self.N)),Q_Y) 148 | # reshape 149 | # A_X,A,Y: size [B,H,W,1,3] 150 | A_X = A_X.unsqueeze(3).unsqueeze(4).transpose(1,4).repeat(1,points_h,points_w,1,1) 151 | A_Y = A_Y.unsqueeze(3).unsqueeze(4).transpose(1,4).repeat(1,points_h,points_w,1,1) 152 | 153 | # compute distance P_i - (grid_X,grid_Y) 154 | # grid is expanded in point dim 4, but not in batch dim 0, as points P_X,P_Y are fixed for all batch 155 | points_X_for_summation = points[:,:,:,0].unsqueeze(3).unsqueeze(4).expand(points[:,:,:,0].size()+(1,self.N)) 156 | points_Y_for_summation = points[:,:,:,1].unsqueeze(3).unsqueeze(4).expand(points[:,:,:,1].size()+(1,self.N)) 
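        # the block below evaluates the thin-plate spline radial basis
        # U = r^2 * log(r^2) at the squared distance between every output grid
        # location and every control point P_i; zero distances are first set to 1
        # so the log term stays finite (1 * log(1) = 0, matching the U(0) = 0 limit)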
157 | 158 | if points_b==1: 159 | delta_X = points_X_for_summation-P_X 160 | delta_Y = points_Y_for_summation-P_Y 161 | else: 162 | # use expanded P_X,P_Y in batch dimension 163 | delta_X = points_X_for_summation-P_X.expand_as(points_X_for_summation) 164 | delta_Y = points_Y_for_summation-P_Y.expand_as(points_Y_for_summation) 165 | 166 | dist_squared = torch.pow(delta_X,2)+torch.pow(delta_Y,2) 167 | # U: size [1,H,W,1,N] 168 | dist_squared[dist_squared==0]=1 # avoid NaN in log computation 169 | U = torch.mul(dist_squared,torch.log(dist_squared)) 170 | 171 | # expand grid in batch dimension if necessary 172 | points_X_batch = points[:,:,:,0].unsqueeze(3) 173 | points_Y_batch = points[:,:,:,1].unsqueeze(3) 174 | if points_b==1: 175 | points_X_batch = points_X_batch.expand((batch_size,)+points_X_batch.size()[1:]) 176 | points_Y_batch = points_Y_batch.expand((batch_size,)+points_Y_batch.size()[1:]) 177 | 178 | points_X_prime = A_X[:,:,:,:,0]+ \ 179 | torch.mul(A_X[:,:,:,:,1],points_X_batch) + \ 180 | torch.mul(A_X[:,:,:,:,2],points_Y_batch) + \ 181 | torch.sum(torch.mul(W_X,U.expand_as(W_X)),4) 182 | 183 | points_Y_prime = A_Y[:,:,:,:,0]+ \ 184 | torch.mul(A_Y[:,:,:,:,1],points_X_batch) + \ 185 | torch.mul(A_Y[:,:,:,:,2],points_Y_batch) + \ 186 | torch.sum(torch.mul(W_Y,U.expand_as(W_Y)),4) 187 | 188 | return torch.cat((points_X_prime,points_Y_prime),3) 189 | 190 | 191 | def random_tps(*args, grid_size=4, reg_factor=0, strength_factor=1.0): 192 | """Random TPS. Device and size determined from first argument, all 193 | remaining arguments transformed with the same parameters.""" 194 | x = args[0] 195 | is_np = type(x) == np.ndarray 196 | no_batch = len(x.shape) == 3 197 | if no_batch: 198 | args = [x[None,...] for x in args] 199 | if is_np: 200 | args = [torch.tensor(x.copy()).permute(0,3,1,2) for x in args] 201 | x = args[0] 202 | use_cuda = x.is_cuda 203 | b,c,h,w = x.shape 204 | grid_size = 4 205 | tps = TpsGridGen(out_h=h, 206 | out_w=w, 207 | use_regular_grid=True, 208 | grid_size=grid_size, 209 | reg_factor=reg_factor, 210 | use_cuda=use_cuda) 211 | # theta = b,2*N - first N = X, second N = Y, 212 | control = torch.cat([tps.P_X_base, tps.P_Y_base], dim=0) 213 | control = control[None,:,0][b*[0],...] 214 | theta = (torch.rand(b,2*grid_size*grid_size)*2-1.0) / grid_size * strength_factor 215 | final = control+theta.to(control) 216 | final = torch.clamp(final, -1.0, 1.0) 217 | final = final - control 218 | final[control==-1.0] = 0.0 219 | final[control==1.0] = 0.0 220 | grid = tps(final) 221 | 222 | is_uint8 = [x.dtype == torch.uint8 for x in args] 223 | args = [args[i].to(grid)+0.5 if is_uint8[i] else args[i] for i in range(len(args))] 224 | out = [torch.nn.functional.grid_sample(x, grid, align_corners=True) for x in args] 225 | out = [out[i].to(torch.uint8) if is_uint8[i] else out[i] for i in range(len(out))] 226 | if is_np: 227 | out = [x.permute(0,2,3,1).numpy() for x in out] 228 | if no_batch: 229 | out = [x[0,...] 
for x in out] 230 | if len(out) == 1: 231 | out = out[0] 232 | return out 233 | 234 | 235 | def count_params(model): 236 | total_params = sum(p.numel() for p in model.parameters()) 237 | return total_params 238 | -------------------------------------------------------------------------------- /perturbnet/pytorch_scvi/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/welch-lab/PerturbNet/96f38d8e2629cc4416c0f4c8e2051b16ec2a4816/perturbnet/pytorch_scvi/__init__.py -------------------------------------------------------------------------------- /perturbnet/pytorch_scvi/__pycache__/distributions.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/welch-lab/PerturbNet/96f38d8e2629cc4416c0f4c8e2051b16ec2a4816/perturbnet/pytorch_scvi/__pycache__/distributions.cpython-37.pyc -------------------------------------------------------------------------------- /perturbnet/pytorch_scvi/__pycache__/distributions.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/welch-lab/PerturbNet/96f38d8e2629cc4416c0f4c8e2051b16ec2a4816/perturbnet/pytorch_scvi/__pycache__/distributions.cpython-39.pyc -------------------------------------------------------------------------------- /perturbnet/pytorch_scvi/__pycache__/scvi_generate_z.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/welch-lab/PerturbNet/96f38d8e2629cc4416c0f4c8e2051b16ec2a4816/perturbnet/pytorch_scvi/__pycache__/scvi_generate_z.cpython-37.pyc -------------------------------------------------------------------------------- /perturbnet/pytorch_scvi/__pycache__/scvi_generate_z.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/welch-lab/PerturbNet/96f38d8e2629cc4416c0f4c8e2051b16ec2a4816/perturbnet/pytorch_scvi/__pycache__/scvi_generate_z.cpython-39.pyc -------------------------------------------------------------------------------- /perturbnet/pytorch_scvi/scvi_generate_z.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | import torch 4 | import sys 5 | import torch.nn.functional as F 6 | import torch.nn as nn 7 | from perturbnet.pytorch_scvi.distributions import * 8 | 9 | class ConcatDataset(torch.utils.data.Dataset): 10 | """ 11 | data structure with sample indices of two datasets 12 | """ 13 | def __init__(self, *datasets): 14 | self.datasets = datasets 15 | 16 | def __getitem__(self, i): 17 | return tuple(d[i] for d in self.datasets) 18 | 19 | def __len__(self): 20 | return min(len(d) for d in self.datasets) 21 | 22 | class scvi_predictive_z: 23 | """ 24 | class to generate the gene expression data from latent variables of scVI 25 | """ 26 | def __init__(self, model): 27 | super().__init__() 28 | self.model = model 29 | 30 | def one_hot(self, index, n_cat): 31 | onehot = torch.zeros(index.size(0), n_cat, device = index.device) 32 | onehot.scatter_(1, index.type(torch.long), 1) 33 | return onehot.type(torch.float32) 34 | 35 | 36 | def decoder_inference(self, z, library, batch_index = None, y = None, n_samples = 1, 37 | transform_batch = None): 38 | """ 39 | a function employed on the scVI.model object, currently only allow n_samples == 1 40 | """ 41 | if transform_batch is not 
None: 42 | dec_batch_index = transform_batch * torch.ones_like(batch_index) 43 | else: 44 | dec_batch_index = batch_index 45 | 46 | px_scale, px_r, px_rate, px_dropout = self.model.model.decoder(self.model.model.dispersion, z, library, dec_batch_index, y) 47 | if self.model.model.dispersion == "gene-label": 48 | px_r = F.linear( 49 | self.one_hot(y, self.model.model.n_labels), self.model.model.px_r 50 | ) # px_r gets transposed - last dimension is nb genes 51 | elif self.model.model.dispersion == "gene-batch": 52 | px_r = F.linear(self.one_hot(dec_batch_index, self.model.model.n_batch), self.model.model.px_r) 53 | elif self.model.model.dispersion == "gene": 54 | px_r = self.model.model.px_r 55 | px_r = torch.exp(px_r) 56 | 57 | return dict( 58 | px_scale = px_scale, 59 | px_r = px_r, 60 | px_rate = px_rate, 61 | px_dropout = px_dropout) 62 | 63 | 64 | @torch.no_grad() 65 | def posterior_predictive_sample_from_Z( 66 | self, 67 | z_sample, 68 | l_sample, 69 | n_samples: int = 1, 70 | batch_size = None 71 | ): 72 | 73 | if self.model.model.gene_likelihood not in ["zinb", "nb", "poisson"]: 74 | raise ValueError("Invalid gene_likelihood.") 75 | 76 | if batch_size is None: 77 | batch_size = 32 78 | 79 | data_loader = torch.utils.data.DataLoader( 80 | ConcatDataset(z_sample, l_sample), 81 | batch_size = batch_size, 82 | shuffle = False) 83 | 84 | x_new = [] 85 | for batch_idx, (batch_z, batch_l) in enumerate(data_loader): 86 | 87 | labels = None # currently only support unsupervised learning 88 | 89 | outputs = self.decoder_inference( 90 | batch_z, batch_l, batch_index = batch_idx, y = labels, n_samples = n_samples 91 | ) 92 | px_r = outputs["px_r"] 93 | px_rate = outputs["px_rate"] 94 | px_dropout = outputs["px_dropout"] 95 | 96 | if self.model.model.gene_likelihood == "poisson": 97 | l_train = px_rate 98 | l_train = torch.clamp(l_train, max=1e8) 99 | dist = torch.distributions.Poisson( 100 | l_train 101 | ) # Shape : (n_samples, n_cells_batch, n_genes) 102 | elif self.model.model.gene_likelihood == "nb": 103 | dist = NegativeBinomial(mu = px_rate, theta = px_r) 104 | elif self.model.model.gene_likelihood == "zinb": 105 | dist = ZeroInflatedNegativeBinomial( 106 | mu = px_rate, theta = px_r, zi_logits = px_dropout 107 | ) 108 | 109 | else: 110 | raise ValueError( 111 | "{} reconstruction error not handled right now".format( 112 | self.model.model.gene_likelihood 113 | ) 114 | ) 115 | 116 | if n_samples > 1: 117 | exprs = dist.sample().permute( 118 | [1, 2, 0] 119 | ) # Shape : (n_cells_batch, n_genes, n_samples) 120 | else: 121 | exprs = dist.sample() 122 | 123 | x_new.append(exprs.cpu()) 124 | x_new = torch.cat(x_new) # Shape (n_cells, n_genes, n_samples) 125 | 126 | return x_new.numpy() 127 | 128 | @torch.no_grad() 129 | def posterior_predictive_sample_from_Z_with_y( 130 | self, 131 | z_sample, 132 | l_sample, 133 | y_sample, 134 | n_samples: int = 1, 135 | batch_size = None 136 | ): 137 | 138 | if self.model.model.gene_likelihood not in ["zinb", "nb", "poisson"]: 139 | raise ValueError("Invalid gene_likelihood.") 140 | 141 | if batch_size is None: 142 | batch_size = 32 143 | 144 | data_loader = torch.utils.data.DataLoader( 145 | ConcatDataset(z_sample, l_sample, y_sample), 146 | batch_size = batch_size, 147 | shuffle = False) 148 | 149 | x_new = [] 150 | for batch_idx, (batch_z, batch_l, batch_y) in enumerate(data_loader): 151 | 152 | outputs = self.decoder_inference( 153 | batch_z, batch_l, batch_index = batch_idx, y = batch_y, n_samples = n_samples 154 | ) 155 | px_r = outputs["px_r"] 156 | 
px_rate = outputs["px_rate"] 157 | px_dropout = outputs["px_dropout"] 158 | 159 | if self.model.model.gene_likelihood == "poisson": 160 | l_train = px_rate 161 | l_train = torch.clamp(l_train, max=1e8) 162 | dist = torch.distributions.Poisson( 163 | l_train 164 | ) # Shape : (n_samples, n_cells_batch, n_genes) 165 | elif self.model.model.gene_likelihood == "nb": 166 | dist = NegativeBinomial(mu = px_rate, theta = px_r) 167 | elif self.model.model.gene_likelihood == "zinb": 168 | dist = ZeroInflatedNegativeBinomial( 169 | mu = px_rate, theta = px_r, zi_logits = px_dropout 170 | ) 171 | 172 | else: 173 | raise ValueError( 174 | "{} reconstruction error not handled right now".format( 175 | self.model.model.gene_likelihood 176 | ) 177 | ) 178 | 179 | if n_samples > 1: 180 | exprs = dist.sample().permute( 181 | [1, 2, 0] 182 | ) # Shape : (n_cells_batch, n_genes, n_samples) 183 | else: 184 | exprs = dist.sample() 185 | 186 | x_new.append(exprs.cpu()) 187 | x_new = torch.cat(x_new) # Shape (n_cells, n_genes, n_samples) 188 | 189 | return x_new.numpy() 190 | 191 | @torch.no_grad() 192 | def posterior_predictive_sample_from_Z_with_batch( 193 | self, 194 | z_sample, 195 | l_sample, 196 | batch_sample, 197 | n_samples: int = 1, 198 | batch_size = None 199 | ): 200 | 201 | if self.model.model.gene_likelihood not in ["zinb", "nb", "poisson"]: 202 | raise ValueError("Invalid gene_likelihood.") 203 | 204 | if batch_size is None: 205 | batch_size = 32 206 | 207 | data_loader = torch.utils.data.DataLoader( 208 | ConcatDataset(z_sample, l_sample, batch_sample), 209 | batch_size = batch_size, 210 | shuffle = False) 211 | 212 | x_new = [] 213 | for batch_idx, (batch_z, batch_l, batch_batch) in enumerate(data_loader): 214 | 215 | labels = None # currently only support unsupervised learning 216 | 217 | outputs = self.decoder_inference( 218 | batch_z, batch_l, batch_index = batch_batch.view(batch_batch.shape[0], -1), y = labels, n_samples = n_samples 219 | ) 220 | px_r = outputs["px_r"] 221 | px_rate = outputs["px_rate"] 222 | px_dropout = outputs["px_dropout"] 223 | 224 | if self.model.model.gene_likelihood == "poisson": 225 | l_train = px_rate 226 | l_train = torch.clamp(l_train, max=1e8) 227 | dist = torch.distributions.Poisson( 228 | l_train 229 | ) # Shape : (n_samples, n_cells_batch, n_genes) 230 | elif self.model.model.gene_likelihood == "nb": 231 | dist = NegativeBinomial(mu = px_rate, theta = px_r) 232 | elif self.model.model.gene_likelihood == "zinb": 233 | dist = ZeroInflatedNegativeBinomial( 234 | mu = px_rate, theta = px_r, zi_logits = px_dropout 235 | ) 236 | 237 | else: 238 | raise ValueError( 239 | "{} reconstruction error not handled right now".format( 240 | self.model.model.gene_likelihood 241 | ) 242 | ) 243 | 244 | if n_samples > 1: 245 | exprs = dist.sample().permute( 246 | [1, 2, 0] 247 | ) # Shape : (n_cells_batch, n_genes, n_samples) 248 | else: 249 | exprs = dist.sample() 250 | 251 | x_new.append(exprs.cpu()) 252 | x_new = torch.cat(x_new) # Shape (n_cells, n_genes, n_samples) 253 | 254 | return x_new.numpy() 255 | -------------------------------------------------------------------------------- /readthedocs.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | build: 4 | os: ubuntu-20.04 5 | tools: 6 | python: "3.7" 7 | 8 | sphinx: 9 | configuration: docs/source/conf.py 10 | fail_on_warning: false 11 | 12 | python: 13 | install: 14 | - method: pip 15 | path: . 
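      # the docs build installs the package itself first; Sphinx-specific dependencies follow from docs/requirements.txt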
16 | - requirements: docs/requirements.txt 17 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scvi-tools==0.7.1 2 | tensorflow==1.15.0 3 | torch==1.13.1 4 | umap-learn==0.4.6 5 | numpy==1.19.5 6 | torchvision==0.2.0 7 | numba==0.49.1 8 | llvmlite==0.32.1 9 | captum==0.7.0 10 | rdkit 11 | pickle5 12 | scikit-learn 13 | scipy 14 | scanpy 15 | tqdm 16 | seaborn 17 | fair-esm 18 | requests 19 | plotnine 20 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | # read the contents of README file 4 | from os import path 5 | from io import open # for Python 2 and 3 compatibility 6 | 7 | this_directory = path.abspath(path.dirname(__file__)) 8 | 9 | 10 | # read the contents of README.md 11 | def readme(): 12 | with open(path.join(this_directory, 'README.md'), encoding='utf-8') as f: 13 | return f.read() 14 | 15 | 16 | # read the contents of requirements.txt 17 | with open(path.join(this_directory, 'requirements.txt'), 18 | encoding='utf-8') as f: 19 | requirements = f.read().splitlines() 20 | 21 | VERSION = "0.0.3" 22 | 23 | setup( name='PerturbNet', 24 | version=VERSION, 25 | license='GPL-3.0', 26 | description='PerturbNet', 27 | long_description=readme(), 28 | long_description_content_type='text/markdown', 29 | url='https://github.com/welch-lab/PerturbNet', 30 | author='Hengshi Yu, Weizhou Qian, Yuxuan Song, Joshua Welch', 31 | packages=find_packages(exclude=['test']), 32 | zip_safe=False, 33 | include_package_data=True, 34 | install_requires=requirements, 35 | python_requires=">=3.7,<3.8" 36 | #setup_requires=['setuptools>=38.6.0'] 37 | ) 38 | --------------------------------------------------------------------------------
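For reference, a minimal, hypothetical usage sketch of the `scvi_predictive_z` decoding path defined in `perturbnet/pytorch_scvi/scvi_generate_z.py`. The `trained_scvi` object, latent codes, and library sizes below are placeholders; the class only assumes an scVI 0.7.1 model wrapper whose underlying torch module is reachable as `.model` (with a `zinb`, `nb`, or `poisson` gene likelihood), as in the tutorials.

```
# Hypothetical sketch: decode latent codes plus library sizes into count data.
# `trained_scvi` is a placeholder for an scVI 0.7.1 model wrapper that exposes
# the torch VAE as `trained_scvi.model`, which is what scvi_predictive_z expects.
import torch
from perturbnet.pytorch_scvi.scvi_generate_z import scvi_predictive_z

decoder = scvi_predictive_z(trained_scvi)

n_cells, n_latent = 128, 10
z_sample = torch.randn(n_cells, n_latent)   # latent codes, e.g. sampled from the cINN
l_sample = torch.full((n_cells, 1), 8.0)    # per-cell (log) library size used by the scVI decoder

# returns a numpy array of sampled counts, shape (n_cells, n_genes) when n_samples == 1
x_counts = decoder.posterior_predictive_sample_from_Z(z_sample, l_sample, batch_size=64)
print(x_counts.shape)
```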