├── .dockerignore ├── .flake8 ├── .github ├── ISSUE_TEMPLATE │ └── manuscript.md └── workflows │ └── build.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CONTRIBUTING.md ├── Dockerfile ├── README.md ├── doc ├── benchmarks.md ├── getting_started.md ├── images │ ├── model_arch.png │ └── stats_apex_vs_not_apex.png ├── training_examples.md └── training_notes.md ├── download_data.py ├── examples ├── 3D_benchmark_tests │ ├── run_3D.sh │ ├── run_3D_big.sh │ ├── run_3D_small.sh │ ├── run_apex_tests.py │ ├── run_benchmark_tests.py │ ├── run_docker.sh │ └── run_docker_apex.sh ├── latent_variation.ipynb ├── loading_models.ipynb ├── paper_plos_2020 │ ├── 1_model_compare_2D.ipynb │ ├── 2_model_compare_3D.ipynb │ ├── 3_ELBO_demo 3D_single_channel.ipynb │ ├── 4_ELBO_ablation.ipynb │ ├── 5_drug_embedding.ipynb │ ├── 6_3D_cell_mage_printing.ipynb │ └── 7_latent_space_visualization.ipynb ├── plot_error.ipynb └── training_scripts │ ├── ae2D.sh │ ├── ae3D.sh │ ├── bvae2D.sh │ ├── bvae3D.sh │ ├── cbvae2D.sh │ ├── cbvae3D.sh │ ├── cbvaegan2D_target.sh │ └── cbvaegan2D_target2.sh ├── integrated_cell ├── __init__.py ├── arc_walk.py ├── bin │ ├── __init__.py │ ├── train_model.py │ └── train_multi.py ├── corr_stats.py ├── data_providers │ ├── DataProvider.py │ ├── DataProvider3D_label_free.py │ ├── DataProviderABC.py │ ├── GraphDataProvider.py │ ├── MaskedChannelsDataProvider.py │ ├── MultiScaleDataProvider.py │ ├── PatchDataProvider.py │ ├── ProjectedDataProvider.py │ ├── RefDataProvider.py │ ├── RescaledIntensityDataProvider.py │ ├── RescaledIntensityRefDataProvider.py │ ├── RescaledIntensityTargetDataProvider.py │ ├── TargetDataProvider.py │ └── __init__.py ├── external │ ├── __init__.py │ └── pytorch_fid │ │ ├── LICENSE │ │ ├── README.md │ │ ├── fid_score.py │ │ └── inception.py ├── imgToProjection.py ├── layers.py ├── losses.py ├── metrics │ ├── __init__.py │ ├── embeddings.py │ ├── embeddings_reference.py │ ├── embeddings_target.py │ ├── features2D.py │ └── precision_recall.py ├── model_utils.py ├── models │ ├── __init__.py │ ├── ae.py │ ├── base_model.py │ ├── bvae.py │ ├── bvae_cond.py │ ├── bvae_cond_apex.py │ ├── cbvaae.py │ ├── cbvae.py │ ├── cbvae2.py │ ├── cbvae2_gan2.py │ ├── cbvae2_gan_apex.py │ ├── cbvae2_gan_model_parallel.py │ ├── cbvae2_gan_ref.py │ ├── cbvae2_progressive.py │ ├── cbvae2_ref.py │ ├── cbvae2_target.py │ ├── cbvae_apex.py │ ├── cbvae_gan.py │ ├── cbvae_multi.py │ ├── cbvaegan_target.py │ └── cbvaegan_target2.py ├── networks │ ├── __init__.py │ ├── ae2D_residual.py │ ├── ae3D_residual.py │ ├── common.py │ ├── cvaegan2D_residual.py │ ├── cvaegan3D_residual.py │ ├── old │ │ ├── aaegan.py │ │ ├── aaegan3D.py │ │ ├── aaegan3D_resid.py │ │ ├── aaegan3Dv2.py │ │ ├── aaegan3Dv3.py │ │ ├── aaegan3Dv4-elu.py │ │ ├── aaegan3Dv4-relu.py │ │ ├── aaegan3Dv4.py │ │ ├── aaegan3Dv5.py │ │ ├── aaegan_256v2.py │ │ ├── aaegan_256v3.py │ │ ├── aaegan_compact.py │ │ ├── aaegan_compact_lrelu.py │ │ ├── aaegan_compact_lrelu2.py │ │ ├── aaegan_v2.py │ │ ├── aaegan_v3.py │ │ ├── waaegan_256v2.py │ │ ├── waaegan_v2.py │ │ └── waaegan_v3_improved.py │ ├── proto │ │ └── __init__.py │ ├── ref_target_autoencoder.py │ ├── vaegan2D_cgan_target.py │ ├── vaegan3D_cgan_target2.py │ └── vaegan3D_cgan_tsm.py ├── param_search.py ├── simplelogger.py ├── tests │ ├── __init__.py │ ├── conftest.py │ ├── resources │ │ ├── img2D.png │ │ ├── img3D.tiff │ │ └── test_data.csv │ ├── test_dataprovider.py │ ├── test_dummy.py │ └── test_kld.py └── utils │ ├── __init__.py │ ├── build_control_images.py │ ├── features.py │ ├── 
image_transforms.py │ ├── imgToProjection.py │ ├── plots.py │ ├── reference │ └── __init__.py │ ├── spectral_norm.py │ ├── target │ ├── __init__.py │ ├── plots.py │ └── utils.py │ └── utils.py ├── requirements.txt ├── scripts ├── aaegan_short.sh ├── bvae.sh ├── bvae_short.sh ├── bvaegan_short.sh ├── multi_pred.sh ├── multi_pred_short.sh └── test_build_control_images.py └── setup.py /.dockerignore: -------------------------------------------------------------------------------- 1 | data/ 2 | examples/ 3 | .git/ 4 | .* 5 | *.egg-info 6 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | 3 | exclude = 4 | .git, 5 | __pycache__, 6 | docs/source/conf.py, 7 | build 8 | ignore = 9 | E501, # max line length 10 | W503, # breaking before binary operators 11 | E302 # multi-line empty space / comments bug 12 | builtins = 13 | _, 14 | NAME, 15 | MAINTAINER, 16 | MAINTAINER_EMAIL, 17 | DESCRIPTION, 18 | LONG_DESCRIPTION, 19 | URL, 20 | DOWNLOAD_URL, 21 | LICENSE, 22 | CLASSIFIERS, 23 | AUTHOR, 24 | AUTHOR_EMAIL, 25 | PLATFORMS, 26 | VERSION, 27 | PACKAGES, 28 | PACKAGE_DATA, 29 | REQUIRES 30 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/manuscript.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Manuscript 3 | about: Issue to address in manuscript 4 | title: '' 5 | labels: manuscript 6 | assignees: donovanr 7 | 8 | --- 9 | 10 | ## Issue summary 11 | 12 | 13 | ## Details 14 | 15 | 16 | ## TODO 17 | - [ ] code 18 | - [ ] manuscript text 19 | - [ ] figure 20 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: Build 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | build: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v1 10 | - name: Build Docker Image 11 | run: | 12 | docker build -t aics/pytorch_integrated_cell -f Dockerfile . 
13 | - name: Run tests 14 | run: | 15 | docker run aics/pytorch_integrated_cell pytest integrated_cell/tests/ 16 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | ._.DS_Store 3 | **.pyc 4 | **.pyt 5 | .ipynb_checkpoints 6 | .vscode 7 | .bash_history 8 | train_modules/__pycache__/* 9 | *.egg-info/ 10 | examples/3D_benchmark_tests/*/* 11 | results/ 12 | data/ 13 | integrated_cell/networks/old 14 | integrated_cell/networks/proto 15 | integrated_cell/models/proto 16 | examples/paper_plos_2020/images 17 | examples/paper_plos_2020/demo_imgs 18 | examples/paper_plos_2020/*.tiff 19 | examples/training_scripts/*/ 20 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/ambv/black 3 | rev: stable 4 | hooks: 5 | - id: black 6 | language_version: python3.6 7 | - repo: https://github.com/pre-commit/pre-commit-hooks 8 | rev: v1.3.0 9 | hooks: 10 | - id: flake8 11 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Allen Institute Contribution Agreement 2 | 3 | This document describes the terms under which you may make "Contributions" — which may include without limitation, software additions, revisions, bug fixes, configuration changes, documentation, or any other materials — to any of the projects owned or managed by the Allen Institute. If you have questions about these terms, please contact us at terms@alleninstitute.org. 4 | 5 | You certify that: 6 | 7 | * Your Contributions are either: 8 | 1. Created in whole or in part by you and you have the right to submit them under the designated license (described below); or 9 | 2. Based upon previous work that, to the best of your knowledge, is covered under an appropriate open source license and you have the right under that license to submit that work with modifications, whether created in whole or in part by you, under the designated license; or 10 | 3. Provided directly to you by some other person who certified (1) or (2) and you have not modified them. 11 | 12 | * You are granting your Contributions to the Allen Institute under the terms of the [2-Clause BSD license](https://opensource.org/licenses/BSD-2-Clause) (the “designated license”). 13 | 14 | * You understand and agree that the Allen Institute projects and your Contributions are public and that a record of the Contributions (including all metadata and personal information you submit with them) is maintained indefinitely and may be redistributed consistent with the Allen Institute’s mission and the 2-Clause BSD license. 15 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # From pytorch compiled from source 2 | FROM nvcr.io/nvidia/pytorch:19.09-py3 3 | 4 | COPY ./ /root/projects/pytorch_integrated_cell 5 | 6 | WORKDIR /root/projects/pytorch_integrated_cell 7 | 8 | RUN pip install -e . 
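# A note on this base image, hedged from the README itself: installation method (B)
# in the README builds from this file, and the NGC image above
# (nvcr.io/nvidia/pytorch:19.09-py3) ships with NVIDIA Apex, which is why the
# Docker route needs no separate Apex install step, unlike installation method (A).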
9 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Pytorch 3D Integrated Cell
2 | ===============================
3 | 
4 | ![Model Architecture](doc/images/model_arch.png?raw=true "Model Architecture")
5 | 
6 | Building a 3D Integrated Cell: https://www.biorxiv.org/content/early/2017/12/21/238378
7 | 
8 | ### For the 2D manuscript and software:
9 | 
10 | **Generative Modeling with Conditional Autoencoders: Building an Integrated Cell**
11 | Manuscript: https://arxiv.org/abs/1705.00092
12 | GitHub: https://github.com/AllenCellModeling/torch_integrated_cell
13 | 
14 | ## Todo Items
15 | - GIT
16 |     - [x] Remove old big files from git
17 | 
18 | - Jupyter notebooks
19 |     - [x] Check-in current state to git
20 |     - [x] Make sure notebooks all run and can produce figures
21 |     - [x] Annotate notebooks (notebook purpose)
22 |     - [x] Clear outputs
23 |     - [x] Check-in final state to git
24 | 
25 | - Data
26 |     - [x] Make sure current Quilt data works
27 |     - [ ] Check-in manuscript data to Quilt
28 | 
29 | - Code
30 |     - [x] Check-in current state to git
31 |     - [x] Clear unused code
32 |     - [ ] Clean up and annotate main functions
33 |     - [ ] Check-in final state to git
34 | 
35 | - Demos/Docs
36 |     - [x] Installation instructions
37 |     - [x] Getting Started doc
38 |     - [ ] Demos for different training methods
39 |     - [ ] Update doc figures
40 | 
41 | ## Support
42 | 
43 | This code is in active development and is used within our organization. We are currently not supporting this code for external use and are simply releasing it to the community AS IS. The community is welcome to submit issues, but you should not expect an active response.
44 | 
45 | 
46 | ## System requirements
47 | 
48 | We recommend installation on Linux with an NVIDIA graphics card with 10+ GB of GPU memory (e.g., an NVIDIA Titan X Pascal) and the latest drivers installed.
49 | 
50 | ## Installation
51 | 
52 | Installing on Linux is recommended.
53 | 
54 | - Install Python 3.6+/Docker/etc. if necessary.
55 | - All commands listed below assume the bash shell.
56 | 
57 | ### **Installation method (A): In Existing Workspace**
58 | (Optional) Make a fresh conda environment. (This can mess up some libraries when run inside some of Nvidia's Docker images.)
59 | ```shell
60 | conda create --name pytorch_integrated_cell python=3.7
61 | conda activate pytorch_integrated_cell
62 | ```
63 | Clone and install the repo:
64 | ```shell
65 | git clone https://github.com/AllenCellModeling/pytorch_integrated_cell
66 | cd pytorch_integrated_cell
67 | pip install -e .
68 | ```
69 | 
70 | If you want to do some development, install the pre-commit hooks:
71 | ```shell
72 | pip install pre-commit
73 | pre-commit install
74 | ```
75 | 
76 | (Optional) Clone and install Nvidia Apex for half-precision computation.
77 | Please follow the instructions on the Nvidia Apex github page:
78 | https://github.com/NVIDIA/apex
79 | 
80 | ### **Installation method (B): Docker**
81 | We build on Nvidia Docker images. In our tests this runs very slightly slower than (A), although your mileage may vary. This comes with Nvidia Apex.
82 | ```shell
83 | git clone https://github.com/AllenCellModeling/pytorch_integrated_cell
84 | cd pytorch_integrated_cell
85 | docker build -t aics/pytorch_integrated_cell -f Dockerfile .
86 | ```
87 | 
88 | ## Data
89 | Data can be downloaded via Quilt3. The following script will dump the complete 2D and 3D dataset into `./data/`.
This may take a long time depending on your connection.
90 | ```shell
91 | python download_data.py
92 | ```
93 | The dataset is about 250 GB.
94 | 
95 | ## Training Models
96 | Models are trained via the command line. A typical training call looks something like:
97 | ```shell
98 | ic_train_model \
99 |     --gpu_ids 0 \
100 |     --model_type ae \
101 |     --save_parent ./ \
102 |     --lr_enc 2E-4 --lr_dec 2E-4 \
103 |     --data_save_path ./data.pyt \
104 |     --imdir ./data/ \
105 |     --crit_recon integrated_cell.losses.BatchMSELoss \
106 |     --kwargs_crit_recon '{}' \
107 |     --network_name vaegan2D_cgan \
108 |     --kwargs_enc '{"n_classes": 24, "ch_ref": [0, 2], "ch_target": [1], "n_channels": 2, "n_channels_target": 1, "n_latent_dim": 512, "n_ref": 512}' \
109 |     --kwargs_enc_optim '{"betas": [0.9, 0.999]}' \
110 |     --kwargs_dec '{"n_classes": 24, "n_channels": 2, "n_channels_target": 1, "ch_ref": [0, 2], "ch_target": [1], "n_latent_dim": 512, "n_ref": 512, "output_padding": [1,1], "activation_last": "softplus"}' \
111 |     --kwargs_dec_optim '{"betas": [0.9, 0.999]}' \
112 |     --kwargs_model '{"kld_reduction": "mean_batch", "objective": "H", "beta": 1E-2}' \
113 |     --train_module cbvae2 \
114 |     --dataProvider DataProvider \
115 |     --kwargs_dp '{"crop_to": [160,96], "return2D": 1, "check_files": 0, "make_controls": 0, "csv_name": "controls/data_plus_controls.csv", "normalize_intensity": "avg_intensity"}' \
116 |     --saveStateIter 1 --saveProgressIter 1 \
117 |     --channels_pt1 0 1 2 \
118 |     --batch_size 64 \
119 |     --nepochs 300 \
120 | ```
121 | 
122 | This automatically creates a timestamped directory in the current directory `./`.
123 | 
124 | For details on how to modify training options, please see [the training documentation](doc/training_examples.md).
125 | 
126 | ## Loading Models
127 | Models are loaded via the Python API. A typical loading call looks something like:
128 | ```python
129 | from integrated_cell import utils
130 | 
131 | model_dir = '/my_parent_directory/model_type/date_time/'
132 | parent_dir = '/my_parent_directory/'
133 | 
134 | networks, data_provider, args = utils.load_network_from_dir(model_dir, parent_dir)
135 | 
136 | target_enc = networks['enc']
137 | target_dec = networks['dec']
138 | 
139 | ```
140 | 
141 | `networks` is a dictionary of the model subcomponents.
142 | `data_provider` is an object that contains train, validate, and test data.
143 | `args` is a dictionary of the arguments passed to the model.
144 | 
145 | For details on how to modify loading options, please see [the loading documentation](doc/getting_started.md).
146 | 
147 | 
148 | 
149 | ## Project website
150 | Example outputs of this model can be viewed at http://www.allencell.org.
151 | 
152 | ## Examples ##
153 | Examples of how to run the code can be found in the [3D benchmarks section](doc/benchmarks.md).
154 | 
155 | ## Citation
156 | If you find this code useful in your research, please consider citing the following paper:
157 | 
158 |     @article {Johnson238378,
159 |         author = {Johnson, Gregory R. and Donovan-Maiye, Rory M.
and Maleckar, Mary M.},
160 |         title = {Building a 3D Integrated Cell},
161 |         year = {2017},
162 |         doi = {10.1101/238378},
163 |         publisher = {Cold Spring Harbor Laboratory},
164 |         URL = {https://www.biorxiv.org/content/early/2017/12/21/238378},
165 |         eprint = {https://www.biorxiv.org/content/early/2017/12/21/238378.full.pdf},
166 |         journal = {bioRxiv}
167 |     }
168 | 
169 | ## Contact
170 | Gregory Johnson
171 | E-mail: gregj@alleninstitute.org
172 | 
173 | ## License
174 | This program is free software: you can redistribute it and/or modify
175 | it under the terms of the GNU General Public License as published by
176 | the Free Software Foundation, either version 3 of the License, or
177 | (at your option) any later version.
178 | 
179 | This program is distributed in the hope that it will be useful,
180 | but WITHOUT ANY WARRANTY; without even the implied warranty of
181 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
182 | GNU General Public License for more details.
183 | 
184 | You should have received a copy of the GNU General Public License
185 | along with this program. If not, see <https://www.gnu.org/licenses/>.
186 | 
--------------------------------------------------------------------------------
/doc/benchmarks.md:
--------------------------------------------------------------------------------
1 | Running 3D Benchmarks
2 | ===============================
3 | ## About
4 | In the directory
5 | `examples/3D_benchmark_tests` there are some scripts to run benchmarks on your system. You can configure which type of model you want to run, which GPUs you want to run on, etc.
6 | 
7 | ## How to Run
8 | 
9 | ### System requirements
10 | 
11 | Benchmarks are set to run on a box with 8 GPUs. Instructions to reconfigure the benchmarks for your setup are at the bottom of this page.
12 | 
13 | ### Installation
14 | Install the Integrated Cell code via _both_ method (A) (with Nvidia Apex) and method (B). Download the data too.
15 | 
16 | ### Running scripts
17 | From the root directory of this repo, cd into the 3D_benchmark_tests directory and run `run_benchmark_tests.py`:
18 | 
19 | ```shell
20 | cd examples/3D_benchmark_tests/
21 | python run_benchmark_tests.py
22 | ```
23 | 
24 | This will run the model configuration in `run_3D.sh` on different GPUs, with and without Apex, and with and without Docker. Plots showing the iteration time vs. batch size will appear in this directory, e.g.
25 | ![apex](images/stats_apex_vs_not_apex.png?raw=true "apex vs not apex")
26 | 
27 | This figure shows the relationship between the largest models we can run with and without Nvidia Apex.
28 | 
29 | ### Changing the configuration
30 | The primary configuration section is in this block in the `run_benchmark_tests.py` file:
31 | 
32 | ```python
33 | experiment_dict = {}
34 | experiment_dict["function_call"] = ["bash run_docker.sh", "bash run_3D.sh"]
35 | experiment_dict["trainer_type"] = ["cbvae_apex", "cbvae"]
36 | experiment_dict["gpu_id"] = [
37 |     [2],
38 |     [2, 3],
39 |     [3, 4],
40 |     [0, 1, 2, 3],
41 |     [2, 3, 4, 5],
42 |     [0, 1, 2, 3, 4, 5, 6, 7],
43 | ]
44 | experiment_dict["batch_size"] = [8, 16, 32, 64, 128, 256]
45 | ```
46 | 
47 | If you want to change the number or subsets of GPUs to try, change the `"gpu_id"` list; change the batch sizes with the `"batch_size"` list; and so on.
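For orientation, below is a minimal sketch of how a grid like `experiment_dict` can be expanded into individual benchmark runs. This is an illustration, not the actual logic of `run_benchmark_tests.py`; the save-directory naming and the `ndat = 2 * batch_size` rule of thumb (borrowed from the training notes) are assumptions, while the positional arguments (`gpu_ids save_dir trainer_type batch_size ndat`) match what `run_3D.sh` expects.

```python
import itertools
import subprocess

# Illustrative sketch only -- not the contents of run_benchmark_tests.py.
experiment_dict = {
    "function_call": ["bash run_docker.sh", "bash run_3D.sh"],
    "trainer_type": ["cbvae_apex", "cbvae"],
    "gpu_id": [[2], [2, 3], [0, 1, 2, 3]],
    "batch_size": [8, 16, 32],
}

keys = list(experiment_dict)
for values in itertools.product(*(experiment_dict[k] for k in keys)):
    cfg = dict(zip(keys, values))

    gpu_ids = " ".join(str(g) for g in cfg["gpu_id"])  # run_3D.sh takes a quoted id list
    save_dir = f"./bench_{cfg['trainer_type']}_bs{cfg['batch_size']}_ngpu{len(cfg['gpu_id'])}"
    ndat = 2 * cfg["batch_size"]  # short epochs for benchmarking, per the training notes

    cmd = (
        f"{cfg['function_call']} '{gpu_ids}' {save_dir} "
        f"{cfg['trainer_type']} {cfg['batch_size']} {ndat}"
    )
    subprocess.run(cmd, shell=True, check=False)
```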
--------------------------------------------------------------------------------
/doc/getting_started.md:
--------------------------------------------------------------------------------
1 | # Getting Started
2 | 
3 | The Integrated Cell is a collection of tools for training deep generative models of cell organization.
4 | This document serves as a basic guide for how the code is organized and how to train and interface with trained networks.
5 | 
6 | **This code is NOT for production. It is for research. It still needs a lot of work.**
7 | 
8 | ## Prerequisites
9 | This guide assumes you have a working knowledge of Python as well as the PyTorch framework.
10 | 
11 | You will need a CUDA-capable GPU as well. If it is installed with up-to-date drivers, installation should go smoothly.
12 | 
13 | ## Code and Model Organization
14 | The major components to train a model are as follows:
15 | ### 1) A **data provider**
16 | (`integrated_cell/data_providers`)
17 | 
18 | These are objects that provide data via a `get_sample` method. The general means by which to interface with them is via this command:
19 | 
20 | ```x, classes, ref = data_provider.get_sample()```
21 | 
22 | where `x` is a (`batch_size` by `channel_dimension` by `spatial_dimensions`) batch of images, `classes` is a list of integer values corresponding to the class label of `x` (usually corresponding to the so-called "structure channel"), and `ref` is reserved for reference information; this is sometimes the "reference channels", and sometimes it is empty.
23 | 
24 | The "base" data provider is in `integrated_cell/data_providers/DataProvider.py`. There are also child data providers that do different types of augmentation, etc.
25 | 
26 | ### 2) A **network** to optimize
27 | (`integrated_cell/networks`)
28 | 
29 | These are the networks that are being optimized. Each network type (e.g. `integrated_cell/networks/vaegan3D_cgan_target2.py`) generally contains an encoder and decoder object (`Enc` and `Dec`), and depending on the type, `Enc` may produce multiple outputs (in the case of a Variational Autoencoder).
30 | Other objects may be present in the file as well; a decoder discriminator (`DecD`) or encoder discriminator (`EncD`) may be present for adversarial models.
31 | 
32 | The constructors of these sub-networks may have parameters that define the shape or attributes of the network. These are typical Pytorch `torch.nn.Module` objects.
33 | 
34 | ### 3) A **loss** function
35 | (`integrated_cell/losses.py` or [loss functions from Pytorch](https://pytorch.org/docs/stable/nn.html#loss-functions))
36 | 
37 | These are custom losses used for model training.
38 | There may be multiple loss objects for a **network** that may be passed into a **model**.
39 | Typically we use pixel-wise mean squared error, binary cross entropy, or L1-loss for images.
40 | 
41 | ### 4) A **model**-type trainer
42 | (`integrated_cell/models`)
43 | 
44 | These are objects that train specific types of models.
45 | They take in all of the above components and perform backprop steps until some number of iterations or epochs have been completed. They also control saving model state, running checks on validation data, etc.
46 | 
47 | ## Model training
48 | Kicking off model training is usually performed via the command line with `ic_train_model`.
49 | This command allows you to pass in ALL of the parameters necessary to define a training session.
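To make the division of labor concrete, here is a schematic of what a **model** trainer does with the other three components on each iteration. This is a hand-written sketch for orientation only; the real loops live in `integrated_cell/models/` and differ in detail, and the function and optimizer names below are illustrative rather than the repo's API.

```python
def train_epoch_sketch(data_provider, enc, dec, crit_recon, opt_enc, opt_dec, n_iters):
    # Schematic of one epoch: a "model" trainer consumes a data provider,
    # a network (enc/dec), and a loss, and performs backprop steps.
    for _ in range(n_iters):
        x, classes, ref = data_provider.get_sample()  # the contract described above
        x = x.cuda()

        opt_enc.zero_grad()
        opt_dec.zero_grad()

        z = enc(x)                   # network: encoder
        x_hat = dec(z)               # network: decoder
        loss = crit_recon(x_hat, x)  # loss: e.g. pixel-wise MSE

        loss.backward()              # model trainer: backprop and bookkeeping
        opt_enc.step()
        opt_dec.step()
```

In practice you do not write this loop yourself; `ic_train_model` assembles all of these pieces from command-line arguments.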
50 | 
51 | A command looks like:
52 | ```shell
53 | ic_train_model \
54 |     --gpu_ids 0 \
55 |     --model_type ae \
56 |     --save_parent ./ \
57 |     --lr_enc 2E-4 --lr_dec 2E-4 \
58 |     --data_save_path ./data_target.pyt \
59 |     --crit_recon integrated_cell.losses.BatchMSELoss \
60 |     --kwargs_crit_recon '{}' \
61 |     --network_name vaegan3D_cgan_target2 \
62 |     --kwargs_enc '{"n_classes": 24, "n_latent_dim": 512}' \
63 |     --kwargs_enc_optim '{"betas": [0.9, 0.999]}' \
64 |     --kwargs_dec '{"n_classes": 24, "n_latent_dim": 512, "proj_z": 0, "proj_z_ref_to_target": 0, "activation_last": "softplus"}' \
65 |     --kwargs_dec_optim '{"betas": [0.9, 0.999]}' \
66 |     --train_module cbvae2_target \
67 |     --kwargs_model '{"beta_min": 1E-6, "beta_start": -1, "beta_step": 3E-5, "objective": "A", "kld_reduction": "batch"}' \
68 |     --imdir /raid/shared/ipp/scp_19_04_10/ \
69 |     --dataProvider TargetDataProvider \
70 |     --kwargs_dp '{"crop_to": [160, 96, 64], "return2D": 0, "check_files": 0, "make_controls": 0, "csv_name": "controls/data_plus_controls.csv"}' \
71 |     --saveStateIter 1 --saveProgressIter 1 \
72 |     --channels_pt1 0 1 2 \
73 |     --batch_size 32 \
74 |     --nepochs 300 \
75 | ```
76 | 
77 | You can see that there are a lot of settings that may be configured.
78 | This command is usually run from a `.sh` file, so one can pass in arguments via the command line.
79 | 
80 | It should be noted that the parameter `--save_parent` is the directory in which the model save directory is constructed.
81 | Running the above command will create a date-timestamped directory for that particular run.
82 | If you would like to specify the explicit save directory, use the `--save_dir` parameter instead.
83 | 
84 | ## Viewing model status
85 | Viewing model status is usually performed with the jupyter notebook `examples/plot_error.ipynb`.
86 | It is meant to be interactive.
87 | There is a variable `dirs_to_search` that allows one to specify the specific model directories to view.
88 | 
89 | For each model directory, the current checkpoint results are displayed. These generally include a print-off of images for train and validation, as well as loss curves on the train set and latent space embeddings.
90 | 
91 | ## Loading and using a trained model
92 | 
93 | To programmatically access a model, the function `integrated_cell/utils/utils.py::load_network_from_dir` may be used. Typical usage is as follows:
94 | 
95 | ```python
96 | from integrated_cell import utils
97 | 
98 | model_dir = "my/model/save/dir/"
99 | networks, data_provider, args = utils.load_network_from_dir(model_dir)
100 | 
101 | encoder = networks['Enc']
102 | decoder = networks['Dec']
103 | ```
104 | 
105 | If one moves the model save directory, or one uses a relative path in the `ic_train_model` call for the `--data_save_path`, relative components may be overwritten with the parameter `parent_dir`.
106 | Also, if you would like to load a specific checkpoint, one can use the parameter `suffix`.
An example is as follows:
107 | 
108 | ```python
109 | model_dir = "/my/model/save/dir/"
110 | parent_dir = "/my/model"
111 | suffix = "_93300"  # this is the iteration at which the model was saved
112 | 
113 | networks, data_provider, args = utils.load_network_from_dir(model_dir, parent_dir=parent_dir, suffix=suffix)
114 | ```
115 | 
116 | ## More:
117 | [Training different types of models](./training_examples.md)
--------------------------------------------------------------------------------
/doc/images/model_arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AllenCellModeling/pytorch_integrated_cell/8a83fc6f8dc79037f4b681d9d7ef0bc5b91e9948/doc/images/model_arch.png
--------------------------------------------------------------------------------
/doc/images/stats_apex_vs_not_apex.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AllenCellModeling/pytorch_integrated_cell/8a83fc6f8dc79037f4b681d9d7ef0bc5b91e9948/doc/images/stats_apex_vs_not_apex.png
--------------------------------------------------------------------------------
/doc/training_examples.md:
--------------------------------------------------------------------------------
1 | # Training Examples
2 | 
3 | The integrated cell code base supports training of many types of models.
4 | We use a simple taxonomy to identify how to load different network components:
5 | 
6 | Autoencoders (`ae`) contain an encoder and a decoder. Examples of this are:
7 | - vanilla autoencoders
8 | - [variational autoencoders](https://arxiv.org/abs/1312.6114)
9 | - [conditional variational autoencoders](https://papers.nips.cc/paper/5775-learning-structured-output-representation-using-deep-conditional-generative-models)
10 | 
11 | Adversarial Autoencoders (`aae`) contain an adversary on the encoder in addition to the decoder.
12 | - [adversarial autoencoders](https://arxiv.org/abs/1511.05644)
13 | 
14 | Autoencoder GANs (`aegan`) contain a GAN attached to the decoder.
15 | - [Autoencoder GANs](https://arxiv.org/abs/1512.09300)
16 | 
17 | Adversarial Autoencoder GANs (`aaegan`) contain an adversary attached to the encoder and a different adversary attached to the decoder.
18 | 
19 | Examples of how to configure training for each of these model types can be found in the scripts in `examples/training_scripts/`.
20 | 
21 | 
22 | 
--------------------------------------------------------------------------------
/doc/training_notes.md:
--------------------------------------------------------------------------------
1 | # Training Tips and Thoughts
2 | 
3 | ### Compute
4 | 
5 | You generally need a machine with a GPU.
6 | `dgx-aics-dcp-001.corp.alleninstitute.org` is a good place to start
7 | 
8 | ### General
9 | - For short training sessions (for debugging) utilize the CLI command `--ndat`. This overwrites the number of data in the `data_provider` object for shorter epoch times. When used, it's generally set to 2x batch size.
10 | 
11 | - Launch jobs in a screen session in case you lose connection to the host machine
12 | 
13 | - From `integrated_cell.models.base_model.Model`, an "epoch" is defined as `np.ceil(len(data_provider) / data_provider.batch_size)` training-step iterations. `data_provider.get_sample()` controls how samples are returned (usually sampling with replacement).
14 | 
15 | - To specify a specific directory to save in, use the CLI command `--save_dir`.
To not have to specify a new dir every run, use the CLI command `--save_parent`, and the code will dump everything in a date-timestamped directory. 16 | 17 | ### Recommended Usage 18 | - `integrated_cell` installed in a fresh Conda env. 19 | - Run training from a screen session that is running on the host machine. 20 | - Interrogate training models with `examples/plot_error.ipynb`. 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /download_data.py: -------------------------------------------------------------------------------- 1 | import quilt3 2 | 3 | pkg = quilt3.Package.browse( 4 | "aics/pipeline_integrated_single_cell", 5 | registry="s3://allencell", 6 | top_hash="7fd488f05ec41968607c7263cb13b3e70812972a24e832ef6f72195bdd35f1b2", 7 | ) 8 | 9 | pkg.fetch("./data/") 10 | -------------------------------------------------------------------------------- /examples/3D_benchmark_tests/run_3D.sh: -------------------------------------------------------------------------------- 1 | gpu_ids=$1 2 | save_dir=$2 3 | trainer_type=$3 4 | batch_size=$4 5 | ndat=$5 6 | 7 | image_dir=../../data/ 8 | csv_name=metadata.csv 9 | 10 | 11 | ic_train_model \ 12 | --gpu_ids $gpu_ids \ 13 | --model_type ae \ 14 | --save_dir $save_dir \ 15 | --lr_enc 2E-4 --lr_dec 2E-4 \ 16 | --data_save_path $save_dir/data.pyt \ 17 | --crit_recon integrated_cell.losses.BatchMSELoss \ 18 | --kwargs_crit_recon '{}' \ 19 | --network_name vaegan3D_cgan_p \ 20 | --kwargs_enc '{"n_classes": 24, "n_channels": 2, "n_channels_target": 1, "n_latent_dim": 512, "n_ref": 512}' \ 21 | --kwargs_enc_optim '{"betas": [0.9, 0.999]}' \ 22 | --kwargs_dec '{"n_classes": 24, "n_channels": 2, "n_channels_target": 1, "n_latent_dim": 512, "n_ref": 512, "proj_z": 0, "proj_z_ref_to_target": 0, "activation_last": "softplus"}' \ 23 | --kwargs_dec_optim '{"betas": [0.9, 0.999]}' \ 24 | --kwargs_model '{"beta": 1, "objective": "H", "save_state_iter": 1E9, "save_progress_iter": 1E9}' \ 25 | --train_module $trainer_type \ 26 | --imdir $image_dir \ 27 | --dataProvider DataProvider \ 28 | --kwargs_dp '{"crop_to": [160, 96, 64], "check_files": 0, "csv_name": "'$csv_name'"}' \ 29 | --saveStateIter 1 --saveProgressIter 1 \ 30 | --channels_pt1 0 1 2 \ 31 | --batch_size $batch_size \ 32 | --nepochs 1 \ 33 | --ndat $ndat 34 | -------------------------------------------------------------------------------- /examples/3D_benchmark_tests/run_3D_big.sh: -------------------------------------------------------------------------------- 1 | gpu_ids=$1 2 | save_dir=$2 3 | opt_level=$3 4 | batch_size=16 5 | 6 | # image_dir=../../data/ 7 | # csv_name=metadata.csv 8 | 9 | image_dir=/raid/shared/ipp/scp_19_04_10/ 10 | csv_name=controls/data_plus_controls.csv 11 | 12 | ic_train_model \ 13 | --gpu_ids $gpu_ids \ 14 | --model_type ae \ 15 | --save_dir $save_dir \ 16 | --lr_enc 2E-4 --lr_dec 2E-4 \ 17 | --data_save_path $save_dir/data.pyt \ 18 | --crit_recon torch.nn.MSELoss \ 19 | --kwargs_crit_recon '{"reduction": "sum"}' \ 20 | --network_name vaegan3D_cgan_p \ 21 | --kwargs_enc '{"n_classes": 24, "n_channels": 2, "n_channels_target": 1, "n_latent_dim": 512, "n_ref": 512}' \ 22 | --kwargs_enc_optim '{"betas": [0.9, 0.999]}' \ 23 | --kwargs_dec '{"n_classes": 24, "n_channels": 2, "n_channels_target": 1, "n_latent_dim": 512, "n_ref": 512, "proj_z": 0, "proj_z_ref_to_target": 0, "activation_last": "softplus"}' \ 24 | --kwargs_dec_optim '{"betas": [0.9, 0.999]}' \ 25 | --kwargs_model '{"beta": 1, "objective": "H", 
"save_state_iter": 1, "save_progress_iter": 1, "opt_level": "'$opt_level'"}' \ 26 | --train_module cbvae_apex \ 27 | --imdir $image_dir \ 28 | --dataProvider DataProvider \ 29 | --kwargs_dp '{"crop_to": [160, 96, 64], "check_files": 0, "csv_name": "'$csv_name'"}' \ 30 | --saveStateIter 1 --saveProgressIter 1 \ 31 | --channels_pt1 0 1 2 \ 32 | --batch_size $batch_size \ 33 | --nepochs 3 \ 34 | -------------------------------------------------------------------------------- /examples/3D_benchmark_tests/run_3D_small.sh: -------------------------------------------------------------------------------- 1 | gpu_ids=$1 2 | save_dir=$2 3 | opt_level=$3 4 | batch_size=16 5 | 6 | # image_dir=../../data/ 7 | # csv_name=metadata.csv 8 | 9 | image_dir=/raid/shared/ipp/scp_19_04_10/ 10 | csv_name=controls/data_plus_controls.csv 11 | 12 | ic_train_model \ 13 | --gpu_ids $gpu_ids \ 14 | --model_type ae \ 15 | --save_dir $save_dir \ 16 | --lr_enc 2E-4 --lr_dec 2E-4 \ 17 | --data_save_path $save_dir/data.pyt \ 18 | --crit_recon torch.nn.MSELoss \ 19 | --kwargs_crit_recon '{"reduction": "sum"}' \ 20 | --network_name vaegan3D_cgan_p \ 21 | --kwargs_enc '{"n_classes": 24, "n_channels": 2, "n_channels_target": 1, "n_latent_dim": 512, "n_ref": 512, "conv_channels_list": [32, 64, 128]}' \ 22 | --kwargs_enc_optim '{"betas": [0.9, 0.999]}' \ 23 | --kwargs_dec '{"n_classes": 24, "n_channels": 2, "n_channels_target": 1, "n_latent_dim": 512, "n_ref": 512, "proj_z": 0, "proj_z_ref_to_target": 0, "activation_last": "softplus", "conv_channels_list": [128, 64, 32]}' \ 24 | --kwargs_dec_optim '{"betas": [0.9, 0.999]}' \ 25 | --kwargs_model '{"beta": 1, "objective": "H", "save_state_iter": 1, "save_progress_iter": 1, "opt_level": "'$opt_level'"}' \ 26 | --train_module cbvae_apex \ 27 | --imdir $image_dir \ 28 | --dataProvider DataProvider \ 29 | --kwargs_dp '{"rescale_to": 0.25, "crop_to": [40, 24, 16], "check_files": 0, "csv_name": "'$csv_name'"}' \ 30 | --saveStateIter 1 --saveProgressIter 1 \ 31 | --channels_pt1 0 1 2 \ 32 | --batch_size $batch_size \ 33 | --nepochs 5 \ 34 | -------------------------------------------------------------------------------- /examples/3D_benchmark_tests/run_docker.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | docker run --runtime=nvidia \ 4 | -v $PWD/../../:$PWD/../../ \ 5 | -v $PWD/../../:/root/projects/pytorch_integrated_cell \ 6 | -v /raid/shared:/raid/shared \ 7 | aics/pytorch_integrated_cell \ 8 | /bin/bash -c " cd $PWD; bash run_3D.sh '$1' $2 $3 $4 $5" 9 | -------------------------------------------------------------------------------- /examples/3D_benchmark_tests/run_docker_apex.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | docker run --detach --runtime=nvidia \ 4 | -v $PWD/../../:$PWD/../../ \ 5 | -v $PWD/../../:/root/projects/pytorch_integrated_cell \ 6 | -v /raid/shared:/raid/shared \ 7 | aics/pytorch_integrated_cell \ 8 | /bin/bash -c " cd $PWD; bash $1 '$2' $3 '$4'" 9 | -------------------------------------------------------------------------------- /examples/loading_models.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Loading Models Demo\n", 8 | "\n", 9 | "This is a demo for how to load and use serialized models from a directory. 
\n", 10 | "\n", 11 | "Run a model in the `./training_scripts` directory, then use this demo. Here we assume you ran the `ae2D.sh` " 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "import os\n", 21 | "from integrated_cell import utils\n", 22 | "\n", 23 | "#If you need to, use this to change the GPU with which to load a model\n", 24 | "gpu_ids = [0]\n", 25 | "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \",\".join([str(ID) for ID in gpu_ids])\n", 26 | "\n", 27 | "#Specify the model directory to use\n", 28 | "model_dir = \"./training_scripts/ae2D/\"\n", 29 | "\n", 30 | "#Load the model. These variables shouldn't change as a function of model type\n", 31 | "networks, data_provider, args = utils.load_network_from_dir(model_dir)\n", 32 | "\n", 33 | "encoder = networks['enc'].cuda()\n", 34 | "decoder = networks['dec'].cuda()\n", 35 | "\n", 36 | "encoder.train(False)\n", 37 | "decoder.train(False)\n", 38 | "\n", 39 | "#If you're not clear on how to use the network, then the best thing to do is to training model and check out how it is used:\n", 40 | "f\"integrated_cell/models/{args['train_module']}.py\"" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "from integrated_cell.utils.plots import imshow\n", 50 | "import integrated_cell.utils as utils\n", 51 | "import torch \n", 52 | "import matplotlib\n", 53 | "%matplotlib inline\n", 54 | "\n", 55 | "x, class_labels, ref = data_provider.get_sample('test', [0])\n", 56 | "\n", 57 | "x = x.cuda()\n", 58 | "class_labels = class_labels.cuda().long()\n", 59 | "class_labels_onehot = utils.index_to_onehot(class_labels, data_provider.get_n_classes())\n", 60 | "ref = ref.cuda()\n", 61 | "\n", 62 | "# Typical autoencoder\n", 63 | "if args['train_module'] == 'ae':\n", 64 | " \n", 65 | " # z is the latent representation of the image x\n", 66 | " z = encoder(x)\n", 67 | " xHat = decoder(z)\n", 68 | "\n", 69 | "# Beta variational autoencoder\n", 70 | "elif args['train_module'] == 'bvae':\n", 71 | "\n", 72 | " # z is the latent representation of the image x\n", 73 | " z_mu, z_sigma = encoder(x)\n", 74 | " \n", 75 | " # here we can either the \"average\" by passing z_mu into the decoder or\n", 76 | " # we can resample from N(z_mu, z_sigma)\n", 77 | " z_sampled = utils.reparameterize(z_mu, z_sigma)\n", 78 | " xHat = decoder(z_sampled)\n", 79 | " \n", 80 | "# Conditional beta variational autoencoders, sometimes with advarsarial loss\n", 81 | "elif args['train_module'] in ['cbvae2_target', 'cbvaegan_target2', 'cbvaegan_target']:\n", 82 | " # These are the conditional models: We provide the class labels and reference structures to both the encoder and decoder\n", 83 | " # and the model (hopefully) doesn't \n", 84 | " \n", 85 | " # z is the latent representation of the image x\n", 86 | " z_mu, z_sigma = encoder(x, ref, class_labels_onehot)\n", 87 | " \n", 88 | " # here we can either the \"average\" by passing z_mu into the decoder or\n", 89 | " # we can resample from N(z_mu, z_sigma)\n", 90 | " z_sampled = utils.reparameterize(z_mu, z_sigma)\n", 91 | " \n", 92 | " xHat = decoder(z_sampled, ref, class_labels_onehot)\n", 93 | " \n", 94 | "im_out = torch.cat([x, xHat], axis = 3)\n", 95 | "imshow(im_out)\n" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": null, 101 | "metadata": {}, 102 | "outputs": [], 103 | "source": [ 104 | "\n" 105 | ] 106 | } 107 | ], 108 | "metadata": { 
109 | "kernelspec": { 110 | "display_name": "Python [conda env:pytorch_integrated_cell]", 111 | "language": "python", 112 | "name": "conda-env-pytorch_integrated_cell-py" 113 | }, 114 | "language_info": { 115 | "codemirror_mode": { 116 | "name": "ipython", 117 | "version": 3 118 | }, 119 | "file_extension": ".py", 120 | "mimetype": "text/x-python", 121 | "name": "python", 122 | "nbconvert_exporter": "python", 123 | "pygments_lexer": "ipython3", 124 | "version": "3.6.8" 125 | } 126 | }, 127 | "nbformat": 4, 128 | "nbformat_minor": 4 129 | } 130 | -------------------------------------------------------------------------------- /examples/plot_error.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# This is the main notebook for interrogating the results of models as they train" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "## import integrated_cell.SimpleLogger as SimpleLogger\n", 17 | "import torch\n", 18 | "from IPython.core.display import display\n", 19 | "import pickle\n", 20 | "import os\n", 21 | "import PIL.Image\n", 22 | "import numpy as np\n", 23 | "import scipy.misc as misc\n", 24 | "import glob\n", 25 | "import json\n", 26 | "from natsort import natsorted\n", 27 | "\n", 28 | "%matplotlib inline\n", 29 | "\n", 30 | "class RenamingUnpickler(pickle.Unpickler):\n", 31 | " def find_class(self, module, name):\n", 32 | " if module == 'integrated_cell.SimpleLogger':\n", 33 | " module = 'integrated_cell.simplelogger'\n", 34 | " return super().find_class(module, name)\n", 35 | "\n", 36 | "\n", 37 | "import warnings\n", 38 | "warnings.filterwarnings(\"ignore\")\n", 39 | "\n", 40 | "dirs_to_search = [\n", 41 | " \n", 42 | "# '/allen/aics/modeling/gregj/results/integrated_cell/test_cbvae_3D_avg_inten/*',\n", 43 | "# '/allen/aics/modeling/gregj/results/integrated_cell/test_cbvae_beta_ref/*',\n", 44 | " './training_scripts/*/'\n", 45 | " \n", 46 | " ]\n", 47 | "\n", 48 | "model_dirs = list()\n", 49 | "\n", 50 | "for my_dir in dirs_to_search:\n", 51 | " model_dirs += glob.glob(my_dir)\n", 52 | "\n", 53 | "model_dirs = natsorted(model_dirs)[::-1] \n", 54 | " \n", 55 | "\n", 56 | "def show_dir(model_dir): \n", 57 | " logger_file = '{0}/logger_tmp.pkl'.format(model_dir)\n", 58 | " \n", 59 | " if not os.path.exists(logger_file):\n", 60 | " print('Could not find logger at ' + logger_file)\n", 61 | " return\n", 62 | "\n", 63 | " print(model_dir)\n", 64 | " \n", 65 | " try:\n", 66 | " logger = RenamingUnpickler(open( logger_file, \"rb\" )).load()\n", 67 | "# logger = pickle.load( open( logger_file, \"rb\" ) )\n", 68 | " epoch = max(logger.log['epoch'])\n", 69 | " except:\n", 70 | " pass\n", 71 | " \n", 72 | " try:\n", 73 | " opt = pickle.load( open( '{0}/opt.pkl'.format(model_dir), \"rb\" ))\n", 74 | " except:\n", 75 | " opt = json.load(open( '{0}/../args.json'.format(model_dir), \"rb\" ))\n", 76 | "\n", 77 | " print(opt)\n", 78 | " print('Epoch: ' + str(epoch))\n", 79 | "\n", 80 | " im_progress_path = '{0}/progress_{1}.png'.format(model_dir, int(epoch))\n", 81 | " im_progress = misc.imread(im_progress_path)\n", 82 | "\n", 83 | " im_history_path = '{0}/history.png'.format(model_dir)\n", 84 | " im_history = misc.imread(im_history_path) \n", 85 | "\n", 86 | " im_history_short_path = '{0}/history_short.png'.format(model_dir)\n", 87 | " im_history_short = 
misc.imread(im_history_short_path) \n", 88 | "\n", 89 | " im_embedding_path = '{0}/embedding.png'.format(model_dir)\n", 90 | " im_embedding = misc.imread(im_embedding_path) \n", 91 | "\n", 92 | " display(PIL.Image.fromarray(im_progress))\n", 93 | " display(PIL.Image.fromarray(np.concatenate((im_history, im_history_short, im_embedding), 1)))\n", 94 | " print(' ')\n", 95 | " \n", 96 | "for model_dir in model_dirs:\n", 97 | " \n", 98 | " sub_dirs = glob.glob(os.path.join(model_dir, 'ref_model/'))\n", 99 | " \n", 100 | " for sub_dir in sub_dirs:\n", 101 | " try:\n", 102 | " show_dir(sub_dir)\n", 103 | " except:\n", 104 | " print('could not load ' + model_dir)\n", 105 | " " 106 | ] 107 | } 108 | ], 109 | "metadata": { 110 | "kernelspec": { 111 | "display_name": "Python [conda env:pytorch_integrated_cell]", 112 | "language": "python", 113 | "name": "conda-env-pytorch_integrated_cell-py" 114 | }, 115 | "language_info": { 116 | "codemirror_mode": { 117 | "name": "ipython", 118 | "version": 3 119 | }, 120 | "file_extension": ".py", 121 | "mimetype": "text/x-python", 122 | "name": "python", 123 | "nbconvert_exporter": "python", 124 | "pygments_lexer": "ipython3", 125 | "version": "3.6.8" 126 | } 127 | }, 128 | "nbformat": 4, 129 | "nbformat_minor": 4 130 | } 131 | -------------------------------------------------------------------------------- /examples/training_scripts/ae2D.sh: -------------------------------------------------------------------------------- 1 | ic_train_model \ 2 | --gpu_ids $1 \ 3 | --model_type ae \ 4 | --save_dir $PWD/ae2D \ 5 | --lr_enc 2E-4 --lr_dec 2E-4 \ 6 | --data_save_path $PWD/ae2D/data_ae.pyt \ 7 | --crit_recon integrated_cell.losses.BatchMSELoss \ 8 | --kwargs_crit_recon '{}' \ 9 | --network_name ae2D_residual \ 10 | --kwargs_enc '{"n_latent_dim": 512, "n_ch": 3}' \ 11 | --kwargs_enc_optim '{"betas": [0.9, 0.999]}' \ 12 | --kwargs_dec '{"n_latent_dim": 512, "activation_last": "softplus", "n_ch": 3}' \ 13 | --kwargs_dec_optim '{"betas": [0.9, 0.999]}' \ 14 | --kwargs_model '{}' \ 15 | --train_module ae \ 16 | --imdir $PWD/../../data/ \ 17 | --dataProvider DataProvider \ 18 | --kwargs_dp '{"crop_to": [160, 96], "return2D": 1, "check_files": 0, "csv_name": "metadata.csv"}' \ 19 | --saveStateIter 1 --saveProgressIter 1 \ 20 | --channels_pt1 0 1 2 \ 21 | --batch_size 32 \ 22 | --nepochs 300 \ 23 | -------------------------------------------------------------------------------- /examples/training_scripts/ae3D.sh: -------------------------------------------------------------------------------- 1 | ic_train_model \ 2 | --gpu_ids $1 \ 3 | --model_type ae \ 4 | --save_dir $PWD/ae3D \ 5 | --lr_enc 2E-4 --lr_dec 2E-4 \ 6 | --data_save_path $PWD/ae3D/data_ae.pyt \ 7 | --crit_recon integrated_cell.losses.BatchMSELoss \ 8 | --kwargs_crit_recon '{}' \ 9 | --network_name ae3D_residual \ 10 | --kwargs_enc '{"n_latent_dim": 512, "n_ch": 3}' \ 11 | --kwargs_enc_optim '{"betas": [0.9, 0.999]}' \ 12 | --kwargs_dec '{"n_latent_dim": 512, "activation_last": "softplus", "n_ch": 3}' \ 13 | --kwargs_dec_optim '{"betas": [0.9, 0.999]}' \ 14 | --kwargs_model '{}' \ 15 | --train_module ae \ 16 | --imdir $PWD/../../data/ \ 17 | --dataProvider DataProvider \ 18 | --kwargs_dp '{"crop_to": [160, 96, 64], "return2D": 0, "check_files": 0, "csv_name": "metadata.csv"}' \ 19 | --saveStateIter 1 --saveProgressIter 1 \ 20 | --channels_pt1 0 1 2 \ 21 | --batch_size 32 \ 22 | --nepochs 300 \ 23 | -------------------------------------------------------------------------------- 
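Between these training scripts, a quick note on how their output is consumed. The sketch below launches the 2D autoencoder script and then loads the trained networks; it assumes the working directory is `examples/training_scripts/` with `download_data.py` already run, and it follows the loading pattern from `examples/loading_models.ipynb`.

```python
import subprocess
from integrated_cell import utils

# Launch the training script on GPU 0; this is long-running.
subprocess.run(["bash", "ae2D.sh", "0"], check=True)

# ae2D.sh passes --save_dir $PWD/ae2D, so the model directory is:
networks, data_provider, args = utils.load_network_from_dir("./ae2D/")
encoder, decoder = networks["enc"], networks["dec"]
```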
/examples/training_scripts/bvae2D.sh: -------------------------------------------------------------------------------- 1 | ic_train_model \ 2 | --gpu_ids $1 \ 3 | --model_type ae \ 4 | --save_dir $PWD/bvae2D \ 5 | --lr_enc 2E-4 --lr_dec 2E-4 \ 6 | --data_save_path $PWD/bvae2D/data_ae.pyt \ 7 | --crit_recon integrated_cell.losses.BatchMSELoss \ 8 | --kwargs_crit_recon '{}' \ 9 | --network_name cvaegan2D_residual \ 10 | --kwargs_enc '{"n_latent_dim": 512, "n_ch_target": 3, "n_ch_ref": 0, "n_classes": 0}' \ 11 | --kwargs_enc_optim '{"betas": [0.9, 0.999]}' \ 12 | --kwargs_dec '{"n_latent_dim": 512, "activation_last": "softplus", "n_ch_target": 3, "n_ch_ref": 0, "n_classes": 0}' \ 13 | --kwargs_dec_optim '{"betas": [0.9, 0.999]}' \ 14 | --kwargs_model '{"beta": 1}' \ 15 | --train_module bvae \ 16 | --imdir $PWD/../../data/ \ 17 | --dataProvider DataProvider \ 18 | --kwargs_dp '{"crop_to": [160, 96], "return2D": 1, "check_files": 0, "csv_name": "metadata.csv"}' \ 19 | --saveStateIter 1 --saveProgressIter 1 \ 20 | --channels_pt1 0 1 2 \ 21 | --batch_size 32 \ 22 | --nepochs 300 \ 23 | -------------------------------------------------------------------------------- /examples/training_scripts/bvae3D.sh: -------------------------------------------------------------------------------- 1 | ic_train_model \ 2 | --gpu_ids $1 \ 3 | --model_type ae \ 4 | --save_dir $PWD/bvae3D \ 5 | --lr_enc 2E-4 --lr_dec 2E-4 \ 6 | --data_save_path $PWD/bvae3D/data_ae.pyt \ 7 | --crit_recon integrated_cell.losses.BatchMSELoss \ 8 | --kwargs_crit_recon '{}' \ 9 | --network_name cvaegan3D_residual \ 10 | --kwargs_enc '{"n_latent_dim": 512, "n_ch_target": 3, "n_ch_ref": 0, "n_classes": 0}' \ 11 | --kwargs_enc_optim '{"betas": [0.9, 0.999]}' \ 12 | --kwargs_dec '{"n_latent_dim": 512, "activation_last": "softplus", "n_ch_target": 3, "n_ch_ref": 0, "n_classes": 0}' \ 13 | --kwargs_dec_optim '{"betas": [0.9, 0.999]}' \ 14 | --kwargs_model '{"beta": 1}' \ 15 | --train_module bvae \ 16 | --imdir $PWD/../../data/ \ 17 | --dataProvider DataProvider \ 18 | --kwargs_dp '{"crop_to": [160, 96, 64], "return2D": 0, "check_files": 0, "csv_name": "metadata.csv"}' \ 19 | --saveStateIter 1 --saveProgressIter 1 \ 20 | --channels_pt1 0 1 2 \ 21 | --batch_size 32 \ 22 | --nepochs 300 \ 23 | -------------------------------------------------------------------------------- /examples/training_scripts/cbvae2D.sh: -------------------------------------------------------------------------------- 1 | ic_train_model \ 2 | --gpu_ids $1 \ 3 | --model_type ae \ 4 | --save_dir $PWD/cbvae2D \ 5 | --lr_enc 2E-4 --lr_dec 2E-4 \ 6 | --data_save_path $PWD/cbvae2D/data.pyt \ 7 | --crit_recon integrated_cell.losses.BatchMSELoss \ 8 | --kwargs_crit_recon '{}' \ 9 | --network_name cvaegan2D_residual \ 10 | --kwargs_enc '{"n_latent_dim": 512, "n_ch_target": 1, "n_ch_ref": 2, "n_classes": 24}' \ 11 | --kwargs_enc_optim '{"betas": [0.9, 0.999]}' \ 12 | --kwargs_dec '{"n_latent_dim": 512, "activation_last": "softplus", "n_ch_target": 1, "n_ch_ref": 2, "n_classes": 24}' \ 13 | --kwargs_dec_optim '{"betas": [0.9, 0.999]}' \ 14 | --kwargs_model '{"beta": 1}' \ 15 | --train_module cbvae2_target \ 16 | --imdir ../../data/ \ 17 | --dataProvider TargetDataProvider \ 18 | --kwargs_dp '{"crop_to": [160, 96], "return2D": 1, "check_files": 0, "csv_name": "metadata.csv"}' \ 19 | --saveStateIter 1 --saveProgressIter 1 \ 20 | --channels_pt1 0 1 2 \ 21 | --batch_size 32 \ 22 | --nepochs 300 \ 23 | -------------------------------------------------------------------------------- 
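The `--kwargs_model '{"beta": 1}'` setting in these bvae/cbvae scripts weights the KL term of the usual beta-VAE objective. For reference, here is a minimal sketch of that loss in its textbook form; this is not a copy of the repo's training code, which may differ in reduction (see `kld_reduction`) and in how the variance is parameterized:

```python
import torch
import torch.nn.functional as F

def bvae_loss_sketch(x_hat, x, z_mu, z_logvar, beta=1.0):
    # Reconstruction term plus beta * KL(q(z|x) || N(0, I)), summed over the batch.
    recon = F.mse_loss(x_hat, x, reduction="sum")
    kld = -0.5 * torch.sum(1 + z_logvar - z_mu.pow(2) - z_logvar.exp())
    return recon + beta * kld
```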
/examples/training_scripts/cbvae3D.sh: -------------------------------------------------------------------------------- 1 | ic_train_model \ 2 | --gpu_ids $1 \ 3 | --model_type ae \ 4 | --save_dir $PWD/cbvae3D \ 5 | --lr_enc 2E-4 --lr_dec 2E-4 \ 6 | --data_save_path $PWD/cbvae3D/data.pyt \ 7 | --crit_recon integrated_cell.losses.BatchMSELoss \ 8 | --kwargs_crit_recon '{}' \ 9 | --network_name cvaegan3D_residual \ 10 | --kwargs_enc '{"n_latent_dim": 512, "n_ch_target": 1, "n_ch_ref": 2, "n_classes": 24}' \ 11 | --kwargs_enc_optim '{"betas": [0.9, 0.999]}' \ 12 | --kwargs_dec '{"n_latent_dim": 512, "activation_last": "softplus", "n_ch_target": 1, "n_ch_ref": 2, "n_classes": 24}' \ 13 | --kwargs_dec_optim '{"betas": [0.9, 0.999]}' \ 14 | --kwargs_model '{"beta": 1}' \ 15 | --train_module cbvae2_target \ 16 | --imdir $PWD/../../data/ \ 17 | --dataProvider TargetDataProvider \ 18 | --kwargs_dp '{"crop_to": [160, 96, 64], "return2D": 0, "check_files": 0, "csv_name": "metadata.csv"}' \ 19 | --saveStateIter 1 --saveProgressIter 1 \ 20 | --channels_pt1 0 1 2 \ 21 | --batch_size 32 \ 22 | --nepochs 300 \ 23 | -------------------------------------------------------------------------------- /examples/training_scripts/cbvaegan2D_target.sh: -------------------------------------------------------------------------------- 1 | ic_train_model \ 2 | --gpu_ids $1 \ 3 | --model_type aegan \ 4 | --save_dir $PWD/cbvaegan2D_target \ 5 | --lr_enc 2E-4 --lr_dec 2E-4 --lr_decD 5E-4 \ 6 | --data_save_path $PWD/cbvaegan2D_target/data.pyt \ 7 | --crit_recon integrated_cell.losses.BatchMSELoss \ 8 | --kwargs_crit_recon '{}' \ 9 | --crit_decD torch.nn.BCEWithLogitsLoss \ 10 | --kwargs_crit_decD '{"reduction": "sum"}' \ 11 | --network_name cvaegan2D_residual \ 12 | --kwargs_enc '{"n_latent_dim": 512, "n_ch_target": 1, "n_ch_ref": 2, "n_classes": 24}' \ 13 | --kwargs_enc_optim '{"betas": [0.5, 0.9]}' \ 14 | --kwargs_dec '{"n_latent_dim": 512, "activation_last": "softplus", "n_ch_target": 1, "n_ch_ref": 2, "n_classes": 24}' \ 15 | --kwargs_dec_optim '{"betas": [0.5, 0.9]}' \ 16 | --kwargs_decD '{"n_ch_target": 1, "n_ch_ref": 2, "n_classes": 24}' \ 17 | --kwargs_decD_optim '{"betas": [0.5, 0.9]}' \ 18 | --kwargs_model '{"beta": 1, "lambda_decD_loss": 1}' \ 19 | --train_module cbvaegan_target \ 20 | --imdir $PWD/../../data/ \ 21 | --dataProvider TargetDataProvider \ 22 | --kwargs_dp '{"crop_to": [160, 96], "return2D": 1, "check_files": 0, "csv_name": "metadata.csv"}' \ 23 | --saveStateIter 1 --saveProgressIter 1 \ 24 | --channels_pt1 0 1 2 \ 25 | --batch_size 64 \ 26 | --nepochs 300 \ 27 | -------------------------------------------------------------------------------- /examples/training_scripts/cbvaegan2D_target2.sh: -------------------------------------------------------------------------------- 1 | ic_train_model \ 2 | --gpu_ids $1 \ 3 | --model_type aegan \ 4 | --save_dir $PWD/cbvaegan2D_target2 \ 5 | --lr_enc 2E-4 --lr_dec 2E-4 --lr_decD 5E-4 \ 6 | --data_save_path $PWD/cbvaegan2D_target2/data.pyt \ 7 | --crit_recon integrated_cell.losses.BatchMSELoss \ 8 | --kwargs_crit_recon '{}' \ 9 | --crit_decD torch.nn.CrossEntropyLoss \ 10 | --kwargs_crit_decD '{"reduction": "sum"}' \ 11 | --network_name cvaegan2D_residual \ 12 | --kwargs_enc '{"n_latent_dim": 512, "n_ch_target": 1, "n_ch_ref": 2, "n_classes": 24}' \ 13 | --kwargs_enc_optim '{"betas": [0.5, 0.9]}' \ 14 | --kwargs_dec '{"n_latent_dim": 512, "activation_last": "softplus", "n_ch_target": 1, "n_ch_ref": 2, "n_classes": 24}' \ 15 | --kwargs_dec_optim '{"betas": [0.5, 
0.9]}' \ 16 | --kwargs_decD '{"n_ch_target": 1, "n_ch_ref": 2, "n_classes": 0, "n_classes_out": 25}' \ 17 | --kwargs_decD_optim '{"betas": [0.5, 0.9]}' \ 18 | --kwargs_model '{"beta": 1, "lambda_decD_loss": 1}' \ 19 | --train_module cbvaegan_target2 \ 20 | --imdir $PWD/../../data/ \ 21 | --dataProvider TargetDataProvider \ 22 | --kwargs_dp '{"crop_to": [160, 96], "return2D": 1, "check_files": 0, "csv_name": "metadata.csv"}' \ 23 | --saveStateIter 1 --saveProgressIter 1 \ 24 | --channels_pt1 0 1 2 \ 25 | --batch_size 64 \ 26 | --nepochs 300 \ 27 | -------------------------------------------------------------------------------- /integrated_cell/__init__.py: -------------------------------------------------------------------------------- 1 | from .imgToProjection import imgtoprojection # noqa 2 | from .simplelogger import SimpleLogger # noqa 3 | -------------------------------------------------------------------------------- /integrated_cell/arc_walk.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def cart_from_pol(r,theta): 4 | 5 | x = np.zeros(len(theta)+1) 6 | sin = np.sin(theta) 7 | cos = np.cos(theta) 8 | 9 | x[0] = r * cos[0] 10 | 11 | for i in range(1,len(theta)): 12 | x[i] = r * cos[i] * np.prod(sin[:i]) 13 | 14 | x[-1] = r * np.prod(sin) 15 | 16 | return x 17 | 18 | 19 | def pol_from_cart(x): 20 | 21 | x_sq=x**2 22 | r = np.sqrt(np.sum(x_sq)) 23 | 24 | if len(x) > 1: 25 | theta = np.zeros(len(x)-1) 26 | 27 | for i,_ in enumerate(theta): 28 | theta[i] = np.arccos( x[i]/np.sqrt(np.sum(x_sq[i:])) ) 29 | 30 | if x[-1] < 0: 31 | theta[-1] *= -1.0 32 | 33 | return (r,theta) 34 | 35 | 36 | def shortest_angular_path(theta_start, theta_end, N_points): 37 | 38 | theta_end %= (2*np.pi) 39 | theta_start %= (2*np.pi) 40 | 41 | # print('theta_end = ', theta_end) 42 | # print('theta_start = ', theta_start) 43 | 44 | swap = theta_end < theta_start 45 | if swap: 46 | theta_end, theta_start = theta_start, theta_end 47 | 48 | theta = theta_end - theta_start 49 | # print('theta = ', theta) 50 | 51 | if theta <= np.pi: 52 | path = np.linspace(0, theta, N_points) 53 | else: 54 | path = np.linspace(theta, 2*np.pi, N_points) 55 | path %= (2*np.pi) 56 | path = np.flipud(path) 57 | 58 | path += theta_start 59 | path %= (2*np.pi) 60 | 61 | if swap: 62 | path = np.flipud(path) 63 | 64 | return(path) 65 | 66 | 67 | def linspace_sph_pol(x, y, N_points): 68 | 69 | assert len(x) == len(y) 70 | 71 | r1,theta1 = pol_from_cart(x) 72 | r2,theta2 = pol_from_cart(y) 73 | 74 | r_path = np.linspace(r1, r2, N_points) 75 | 76 | theta_path = np.zeros([N_points, len(theta1)]) 77 | for i in range(len(theta1)-1): 78 | theta_path[:,i] = np.linspace(theta1[i], theta2[i], N_points) 79 | 80 | theta_path[:,-1] = shortest_angular_path(theta1[-1],theta2[-1],N_points) 81 | 82 | cart_path = np.zeros([N_points, len(x)]) 83 | for i in range(N_points): 84 | x_i = cart_from_pol(r_path[i],theta_path[i,:]) 85 | cart_path[i,:] = x_i 86 | 87 | return cart_path 88 | -------------------------------------------------------------------------------- /integrated_cell/bin/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AllenCellModeling/pytorch_integrated_cell/8a83fc6f8dc79037f4b681d9d7ef0bc5b91e9948/integrated_cell/bin/__init__.py -------------------------------------------------------------------------------- /integrated_cell/corr_stats.py: 
--------------------------------------------------------------------------------
1 | # copy and pasted from https://github.com/pytorch/pytorch/issues/1254
2 | 
3 | import torch
4 | 
5 | 
6 | 
7 | def pearsonr(x, y):
8 |     """
9 |     Mimics `scipy.stats.pearsonr`
10 | 
11 |     Arguments
12 |     ---------
13 |     x : 1D torch.Tensor
14 |     y : 1D torch.Tensor
15 | 
16 |     Returns
17 |     -------
18 |     r_val : float
19 |         pearsonr correlation coefficient between x and y
20 | 
21 |     Scipy docs ref:
22 |         https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.pearsonr.html
23 | 
24 |     Scipy code ref:
25 |         https://github.com/scipy/scipy/blob/v0.19.0/scipy/stats/stats.py#L2975-L3033
26 |     Example:
27 |         >>> x = np.random.randn(100)
28 |         >>> y = np.random.randn(100)
29 |         >>> sp_corr = scipy.stats.pearsonr(x, y)[0]
30 |         >>> th_corr = pearsonr(torch.from_numpy(x), torch.from_numpy(y))
31 |         >>> np.allclose(sp_corr, th_corr)
32 |     """
33 |     mean_x = torch.mean(x)
34 |     mean_y = torch.mean(y)
35 |     xm = x.sub(mean_x)  # mean_x is a 0-dim tensor; broadcasting handles the subtraction
36 |     ym = y.sub(mean_y)
37 |     r_num = xm.dot(ym)
38 |     r_den = torch.norm(xm, 2) * torch.norm(ym, 2)
39 |     r_val = r_num / r_den
40 |     return r_val
41 | 
42 | def corrcoef(x):
43 |     """
44 |     Mimics `np.corrcoef`
45 | 
46 |     Arguments
47 |     ---------
48 |     x : 2D torch.Tensor
49 | 
50 |     Returns
51 |     -------
52 |     c : torch.Tensor
53 |         if x.size() = (5, 100), then return val will be of size (5,5)
54 | 
55 |     Numpy docs ref:
56 |         https://docs.scipy.org/doc/numpy/reference/generated/numpy.corrcoef.html
57 |     Numpy code ref:
58 |         https://github.com/numpy/numpy/blob/v1.12.0/numpy/lib/function_base.py#L2933-L3013
59 | 
60 |     Example:
61 |         >>> x = np.random.randn(5,120)
62 |         # result is a (5,5) matrix of correlations between rows
63 |         >>> np_corr = np.corrcoef(x)
64 |         >>> th_corr = corrcoef(torch.from_numpy(x))
65 |         >>> np.allclose(np_corr, th_corr.numpy())
66 |         # [out]: True
67 |     """
68 |     # calculate covariance matrix of rows
69 |     mean_x = torch.mean(x, 1).unsqueeze(1)
70 |     xm = x.sub(mean_x.expand_as(x))
71 |     c = xm.mm(xm.t())
72 |     c = c / (x.size(1) - 1)
73 | 
74 |     # normalize covariance matrix
75 |     d = torch.diag(c)
76 |     stddev = torch.pow(d, 0.5)
77 |     c = c.div(stddev.expand_as(c))
78 |     c = c.div(stddev.expand_as(c).t())
79 | 
80 |     # clamp between -1 and 1
81 |     # probably not necessary but numpy does it
82 |     c = torch.clamp(c, -1.0, 1.0)
83 | 
84 |     return c
--------------------------------------------------------------------------------
/integrated_cell/data_providers/DataProviderABC.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | 
3 | class DataProviderABC(ABC):
4 | 
5 |     @abstractmethod
6 |     def get_n_dat(self, train_or_test):
7 |         pass
8 | 
9 |     @abstractmethod
10 |     def get_n_classes(self):
11 |         pass
12 | 
13 |     @abstractmethod
14 |     def get_image_paths(self, inds, train_or_test):
15 |         pass
16 | 
17 |     @abstractmethod
18 |     def get_images(self, inds, train_or_test):
19 |         pass
20 | 
21 |     @abstractmethod
22 |     def get_classes(self, inds, train_or_test):
23 |         pass
24 | 
--------------------------------------------------------------------------------
/integrated_cell/data_providers/GraphDataProvider.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from integrated_cell.data_providers.DataProvider import (
3 |     DataProvider as ParentDataProvider,
4 | )  # ugh im sorry
5 | 
6 | from .. import utils
-------------------------------------------------------------------------------- /integrated_cell/data_providers/DataProviderABC.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | class DataProviderABC(ABC): 4 | 5 | @abstractmethod 6 | def get_n_dat(self, train_or_test): 7 | pass 8 | 9 | @abstractmethod 10 | def get_n_classes(self): 11 | pass 12 | 13 | @abstractmethod 14 | def get_image_paths(self, inds, train_or_test): 15 | pass 16 | 17 | @abstractmethod 18 | def get_images(self, inds, train_or_test): 19 | pass 20 | 21 | @abstractmethod 22 | def get_classes(self, inds, train_or_test): 23 | pass 24 | -------------------------------------------------------------------------------- /integrated_cell/data_providers/GraphDataProvider.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from integrated_cell.data_providers.DataProvider import ( 3 | DataProvider as ParentDataProvider, 4 | ) # ugh im sorry 5 | 6 | from .. import utils 7 | 8 | 9 | class DataProvider(ParentDataProvider): 10 | # Like DataProvider, but returns one-hot class labels for the (membrane, structure, dna) channels and a random DAG over channels 11 | def __init__(self, masked_channels=[], **kwargs): 12 | 13 | super().__init__(**kwargs) 14 | 15 | self.masked_channels = masked_channels # note: unused by this provider 16 | 17 | def get_sample(self, train_or_test="train", inds=None): 18 | # returns 19 | # x is b by c by y by x 20 | # x_class is b by c by #classes 21 | # graph is b by c by c - a random dag over channels 22 | 23 | x, classes, _ = super().get_sample(train_or_test=train_or_test, inds=inds) 24 | 25 | n_b = x.shape[0] 26 | n_c = x.shape[1] 27 | 28 | classes = classes.type_as(x).long() 29 | classes = utils.index_to_onehot(classes, self.get_n_classes()) 30 | 31 | classes_mem = torch.ones(n_b) * self.get_n_classes() - 2 32 | classes_mem = utils.index_to_onehot(classes_mem, self.get_n_classes()) 33 | 34 | classes_dna = torch.ones(n_b) * self.get_n_classes() - 1 35 | classes_dna = utils.index_to_onehot(classes_dna, self.get_n_classes()) 36 | 37 | classes = [torch.unsqueeze(c, 1) for c in [classes_mem, classes, classes_dna]] 38 | classes = torch.cat(classes, 1) 39 | 40 | # sample an independent random DAG over channels for each batch element 41 | 42 | graphs = list() 43 | for i in range(n_b): 44 | graph = ( 45 | torch.ones([n_c, n_c]).triu(1) * torch.zeros([n_c, n_c]).bernoulli_() 46 | ) 47 | rand_inds = torch.randperm(n_c) 48 | 49 | graph = graph[:, rand_inds][rand_inds, :] 50 | graphs.append(graph) 51 | 52 | graph = torch.stack(graphs, 0).long() 53 | 54 | return x, classes, graph 55 | 56 | def get_n_classes(self): 57 | return super().get_n_classes() + 2 --------------------------------------------------------------------------------
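The per-sample graph construction above, isolated for clarity: Bernoulli(0.5) draws masked to the strict upper triangle give the adjacency matrix of a DAG whose topological order is 0..n_c-1; conjugating by a random permutation randomizes that order while preserving acyclicity.

```python
import torch

n_c = 3  # number of channels

# random DAG adjacency, as in GraphDataProvider.get_sample
graph = torch.ones([n_c, n_c]).triu(1) * torch.zeros([n_c, n_c]).bernoulli_()
rand_inds = torch.randperm(n_c)
graph = graph[:, rand_inds][rand_inds, :]  # permute rows and columns together

print(graph.long())  # acyclic by construction
```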
-------------------------------------------------------------------------------- /integrated_cell/data_providers/MaskedChannelsDataProvider.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from integrated_cell.data_providers.DataProvider import ( 3 | DataProvider as ParentDataProvider, 4 | ) # ugh im sorry 5 | 6 | 7 | class DataProvider(ParentDataProvider): 8 | # Same as DataProvider but zeros out channels indicated by the variable 'masked_channels' 9 | def __init__(self, masked_channels=[], **kwargs): 10 | 11 | super().__init__(**kwargs) 12 | 13 | self.masked_channels = masked_channels 14 | 15 | def get_sample(self, train_or_test="train", inds=None): 16 | 17 | x, classes, ref = super().get_sample(train_or_test=train_or_test, inds=inds) 18 | 19 | # build a mask for channels 20 | n_channels = x.shape[1] 21 | channel_mask = torch.zeros(n_channels).byte() 22 | channel_mask[self.masked_channels] = 1 23 | 24 | x[:, channel_mask] = 0 25 | 26 | return x, classes, ref 27 | -------------------------------------------------------------------------------- /integrated_cell/data_providers/MultiScaleDataProvider.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from integrated_cell.data_providers.DataProvider import ( 3 | DataProvider as ParentDataProvider, 4 | ) # ugh im sorry 5 | 6 | 7 | class DataProvider(ParentDataProvider): 8 | # Like DataProvider, but returns the image batch at n_x_sub successively halved resolutions 9 | def __init__(self, n_x_sub=4, **kwargs): 10 | 11 | super().__init__(**kwargs) 12 | 13 | self.n_x_sub = n_x_sub 14 | 15 | def get_sample(self, train_or_test="train", inds=None, patched=True): 16 | # returns 17 | # x - a list of the image batch at n_x_sub scales, full resolution first 18 | # x_class is b by c by #classes 19 | # ref - unchanged from the parent provider 20 | 21 | x, classes, ref = super().get_sample(train_or_test=train_or_test, inds=inds) 22 | 23 | # do the interpolation on the GPU - this will be faster 24 | x = x.cuda() 25 | 26 | scales = 1 / (2 ** torch.arange(0, self.n_x_sub).float()) 27 | 28 | x = [x] + [ 29 | torch.nn.functional.interpolate(x, scale_factor=scale.item()) 30 | for scale in scales[1:] 31 | ] 32 | 33 | return x, classes, ref 34 | -------------------------------------------------------------------------------- /integrated_cell/data_providers/PatchDataProvider.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from integrated_cell.data_providers.DataProvider import ( 3 | DataProvider as ParentDataProvider, 4 | ) # ugh im sorry 5 | 6 | 7 | def get_patch_slice(shape, patch_size): 8 | patch_size = torch.tensor(patch_size) 9 | 10 | shape_dims = torch.tensor(shape) - patch_size 11 | 12 | starts = torch.cat( 13 | [torch.randint(i, [1]) if i > 0 else torch.tensor([0]) for i in shape_dims], 0 14 | ) 15 | ends = starts + patch_size 16 | 17 | slices = tuple(slice(s, e) for s, e in zip(starts, ends)) 18 | 19 | 20 | 21 | return slices 22 | 23 | 24 | class DataProvider(ParentDataProvider): 25 | # Returns random patch_size crops of each sample, optionally with a centered coordinate mesh as the reference 26 | def __init__( 27 | self, 28 | patch_size=[3, 64, 64, 32], 29 | default_return_mesh=False, 30 | default_return_patch=True, 31 | **kwargs 32 | ): 33 | 34 | super().__init__(**kwargs) 35 | 36 | self.patch_size = patch_size 37 | self.default_return_mesh = default_return_mesh 38 | self.default_return_patch = default_return_patch 39 | 40 | def get_sample( 41 | self, train_or_test="train", inds=None, patched=None, return_mesh=None 42 | ): 43 | # returns 44 | # x - random patches (or the full images if patched=False) 45 | # x_class is b by c by #classes 46 | # ref - a coordinate mesh when return_mesh=True, otherwise the parent's ref 47 | 48 | if return_mesh is None: 49 | return_mesh = self.default_return_mesh 50 | 51 | if patched is None: 52 | patched = self.default_return_patch 53 | 54 | x, classes, ref = super().get_sample(train_or_test=train_or_test, inds=inds) 55 | 56 | shape = x[0].shape 57 | 58 | if return_mesh: 59 | 60 | meshes = [torch.arange(0, i) - (i // 2) for i in list(shape[1:])] 61 | mesh = torch.meshgrid(meshes) 62 | mesh = torch.stack([m.float() for m in mesh], 0) 63 | 64 | if not patched: 65 | ref = torch.stack([mesh for i in range(x.shape[0])], 0) 66 | 67 | if patched: 68 | 69 | slices = [ 70 | get_patch_slice(shape, self.patch_size) for i in range(x.shape[0]) 71 | ] 72 | 73 | x = torch.stack( 74 | [x_sub[slices_sub] for x_sub, slices_sub in zip(x, slices)], 0 75 | ) 76 | 77 | if return_mesh: 78 | ref = torch.stack( 79 | [ 80 | mesh[[slice(0, mesh.shape[0])] + list(slices_sub[1:])] # use each sample's own slice (was slices[0]) 81 | for slices_sub in slices 82 | ], 83 | 0, 84 | ) 85 | 86 | return x, classes, ref 87 | --------------------------------------------------------------------------------
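A usage sketch for get_patch_slice above; the volume and patch shapes are illustrative, not from the repo:

```python
import torch

from integrated_cell.data_providers.PatchDataProvider import get_patch_slice

# crop a random (3, 64, 64, 32) patch out of a (3, 128, 96, 64) CZYX volume;
# axes where the patch spans the whole axis get a fixed slice starting at 0
x = torch.randn(3, 128, 96, 64)
slices = get_patch_slice(x.shape, patch_size=[3, 64, 64, 32])
patch = x[slices]

print(patch.shape)  # expected: torch.Size([3, 64, 64, 32])
```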
/integrated_cell/data_providers/ProjectedDataProvider.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from integrated_cell.data_providers.DataProvider import ( 4 | DataProvider as ParentDataProvider, 5 | ) # ugh im sorry 6 | 7 | 8 | class DataProvider(ParentDataProvider): 9 | # Returns 2D center slices or max projections (zx, zy, xy) of each 3D volume, optionally with per-channel intensity rescaling 10 | def __init__(self, channel_intensity_values=None, slice_or_proj="slice", **kwargs): 11 | 12 | super().__init__(**kwargs) 13 | 14 | self.channel_intensity_values = channel_intensity_values 15 | self.slice_or_proj = slice_or_proj 16 | 17 | def get_sample(self, train_or_test="train", inds=None): 18 | # returns 19 | # x - a list of 2D views [zx, zy, xy] 20 | # x_class is b by c by #classes 21 | # ref - unchanged from the parent provider 22 | 23 | x, classes, ref = super().get_sample(train_or_test=train_or_test, inds=inds) 24 | 25 | if self.channel_intensity_values: 26 | for i in range(x.shape[0]): 27 | for j in range(x.shape[1]): 28 | if torch.sum(x[i, j]) > 0: 29 | x[i, j] = x[i, j] * ( 30 | self.channel_intensity_values[j] / torch.sum(x[i, j]) 31 | ) 32 | 33 | if self.slice_or_proj == "slice": 34 | center_of_image = torch.tensor(x.shape[2:]) // 2 # integer division - float tensors are not valid indices 35 | 36 | zx = x[:, :, center_of_image[0], :, :] 37 | zy = x[:, :, :, center_of_image[1], :] 38 | xy = x[:, :, :, :, center_of_image[2]] 39 | 40 | x = [zx, zy, xy] 41 | 42 | elif self.slice_or_proj == "proj": 43 | zx = torch.max(x, 2)[0] 44 | zy = torch.max(x, 3)[0] 45 | xy = torch.max(x, 4)[0] 46 | 47 | x = [zx, zy, xy] 48 | 49 | return x, classes, ref 50 | -------------------------------------------------------------------------------- /integrated_cell/data_providers/RefDataProvider.py: -------------------------------------------------------------------------------- 1 | from integrated_cell.data_providers.DataProvider import ( 2 | DataProvider as ParentDataProvider, 3 | ) # ugh im sorry 4 | 5 | 6 | class DataProvider(ParentDataProvider): 7 | # Returns only the reference channels [0, 2] of each image 8 | def __init__(self, **kwargs): 9 | 10 | super().__init__(**kwargs) 11 | 12 | def get_sample(self, train_or_test="train", inds=None): 13 | # returns 14 | # x restricted to the reference channels [0, 2]; 15 | # the class labels and ref output of the parent 16 | # provider are discarded 17 | 18 | x, classes, ref = super().get_sample(train_or_test=train_or_test, inds=inds) 19 | 20 | x = x[:, [0, 2]] 21 | 22 | return x 23 | -------------------------------------------------------------------------------- /integrated_cell/data_providers/RescaledIntensityDataProvider.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from integrated_cell.data_providers.DataProvider import ( 4 | DataProvider as ParentDataProvider, 5 | ) # ugh im sorry 6 | 7 | 8 | class DataProvider(ParentDataProvider): 9 | # Rescales each nonzero channel so its summed intensity equals the corresponding entry of channel_intensity_values 10 | def __init__( 11 | self, channel_intensity_values=[2596.9521, 2596.9521, 2596.9521], **kwargs 12 | ): 13 | 14 | super().__init__(**kwargs) 15 | 16 | self.channel_intensity_values = channel_intensity_values 17 | 18 | def get_sample(self, train_or_test="train", inds=None): 19 | # returns 20 | # x is b by c by y by x, intensity-rescaled per channel 21 | # x_class is b by c by #classes 22 | # ref - unchanged from the parent provider 23 | 24 | x, classes, ref = super().get_sample(train_or_test=train_or_test, inds=inds) 25 | 26 | for i in range(x.shape[0]): 27 | for j in range(x.shape[1]): 28 | if torch.sum(x[i, j]) > 0: 29 | x[i, j] = x[i, j] * ( 30 | self.channel_intensity_values[j] / torch.sum(x[i, j]) 31 | ) 32 | 33 | return x, classes, ref 34 | --------------------------------------------------------------------------------
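The normalization applied by RescaledIntensityDataProvider above, shown standalone: any channel with nonzero signal is scaled so its summed intensity equals the corresponding target value (tensor sizes here are illustrative).

```python
import torch

channel_intensity_values = [2596.9521, 2596.9521, 2596.9521]
x = torch.rand(2, 3, 16, 16)  # b x c x y x x

for i in range(x.shape[0]):
    for j in range(x.shape[1]):
        if torch.sum(x[i, j]) > 0:
            x[i, j] = x[i, j] * (channel_intensity_values[j] / torch.sum(x[i, j]))

print(x.sum(dim=(2, 3)))  # every channel now sums to ~2596.9521
```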
/integrated_cell/data_providers/RescaledIntensityRefDataProvider.py: -------------------------------------------------------------------------------- 1 | from integrated_cell.data_providers.RescaledIntensityDataProvider import ( 2 | DataProvider as ParentDataProvider, 3 | ) # ugh im sorry 4 | 5 | 6 | class DataProvider(ParentDataProvider): 7 | # Applies the parent's per-channel intensity rescaling, then returns only the reference channels [0, 2] 8 | def __init__(self, **kwargs): 9 | 10 | super().__init__(**kwargs) 11 | 12 | def get_sample(self, train_or_test="train", inds=None): 13 | # returns 14 | # x restricted to the reference channels [0, 2]; 15 | # the class labels and ref output of the parent 16 | # provider are discarded 17 | 18 | x, classes, ref = super().get_sample(train_or_test=train_or_test, inds=inds) 19 | 20 | x = x[:, [0, 2]] 21 | 22 | return x 23 | -------------------------------------------------------------------------------- /integrated_cell/data_providers/RescaledIntensityTargetDataProvider.py: -------------------------------------------------------------------------------- 1 | from integrated_cell.data_providers.RescaledIntensityDataProvider import ( 2 | DataProvider as ParentDataProvider, 3 | ) # ugh im sorry 4 | 5 | 6 | class DataProvider(ParentDataProvider): 7 | # Applies the parent's intensity rescaling, then splits each image into target and reference channels 8 | def __init__(self, target_inds=[1], ref_inds=[0, 2], **kwargs): 9 | 10 | super().__init__(**kwargs) 11 | self.target_inds, self.ref_inds = target_inds, ref_inds # these were previously accepted but never stored 12 | def get_sample(self, train_or_test="train", inds=None): 13 | # returns 14 | # x - the target channel(s), target_inds 15 | # classes - the class labels 16 | # ref - the reference channels, ref_inds 17 | 18 | x, classes, ref = super().get_sample(train_or_test=train_or_test, inds=inds) 19 | 20 | ref = x[:, self.ref_inds] 21 | x = x[:, self.target_inds] 22 | 23 | return x, classes, ref 24 | -------------------------------------------------------------------------------- /integrated_cell/data_providers/TargetDataProvider.py: -------------------------------------------------------------------------------- 1 | from integrated_cell.data_providers.DataProvider import ( 2 | DataProvider as ParentDataProvider, 3 | ) # ugh im sorry 4 | 5 | 6 | class DataProvider(ParentDataProvider): 7 | # Splits each image into the target channel (1) and the reference channels (0, 2) 8 | def __init__(self, **kwargs): 9 | 10 | super().__init__(**kwargs) 11 | 12 | def get_sample(self, train_or_test="train", inds=None): 13 | # returns 14 | # x - the target channel, index 1 15 | # classes - the class labels 16 | # ref - the reference channels, indices 0 and 2 17 | 18 | x, classes, ref = super().get_sample(train_or_test=train_or_test, inds=inds) 19 | 20 | ref = x[:, [0, 2]] 21 | x = x[:, [1]] 22 | 23 | return x, classes, ref 24 | --------------------------------------------------------------------------------
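The channel split performed by TargetDataProvider above, shown standalone: channel 1 is the structure ("target") channel and channels 0 and 2 (in this codebase, typically the membrane and DNA channels) are the reference channels.

```python
import torch

x = torch.randn(4, 3, 64, 64)  # b x c x y x x, illustrative sizes

ref = x[:, [0, 2]]    # reference channels
target = x[:, [1]]    # target channel

print(target.shape, ref.shape)
# torch.Size([4, 1, 64, 64]) torch.Size([4, 2, 64, 64])
```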
/integrated_cell/data_providers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AllenCellModeling/pytorch_integrated_cell/8a83fc6f8dc79037f4b681d9d7ef0bc5b91e9948/integrated_cell/data_providers/__init__.py -------------------------------------------------------------------------------- /integrated_cell/external/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AllenCellModeling/pytorch_integrated_cell/8a83fc6f8dc79037f4b681d9d7ef0bc5b91e9948/integrated_cell/external/__init__.py -------------------------------------------------------------------------------- /integrated_cell/external/pytorch_fid/README.md: -------------------------------------------------------------------------------- 1 | Original Repository: [pytorch-fid](https://github.com/mseitzer/pytorch-fid) 2 | 3 | # Fréchet Inception Distance (FID score) in PyTorch 4 | 5 | This is a port of the official implementation of [Fréchet Inception Distance](https://arxiv.org/abs/1706.08500) to PyTorch. 6 | See [https://github.com/bioinf-jku/TTUR](https://github.com/bioinf-jku/TTUR) for the original implementation using Tensorflow. 7 | 8 | FID is a measure of similarity between two datasets of images. 9 | It was shown to correlate well with human judgement of visual quality and is most often used to evaluate the quality of samples of Generative Adversarial Networks. 10 | FID is calculated by computing the [Fréchet distance](https://en.wikipedia.org/wiki/Fr%C3%A9chet_distance) between two Gaussians fitted to feature representations of the Inception network. 11 | 12 | Further insights and an independent evaluation of the FID score can be found in [Are GANs Created Equal? A Large-Scale Study](https://arxiv.org/abs/1711.10337). 13 | 14 | **Note that the official implementation gives slightly different scores.** If you report FID scores in your paper, and you want them to be exactly comparable to FID scores reported in other papers, you should use [the official Tensorflow implementation](https://github.com/bioinf-jku/TTUR). 15 | You can still use this version if you want a quick FID estimate without installing Tensorflow. 16 | 17 | **Update:** The weights and the model are now exactly the same as in the official Tensorflow implementation, and I verified them to give the same results (around `1e-8` mean absolute error) on single inputs on my platform. However, due to differences in the image interpolation implementation and library backends, FID results might still differ slightly from the original implementation. A test I ran (details are to come) resulted in `.08` absolute error and `0.0009` relative error. 18 | 19 | ## Usage 20 | 21 | Requirements: 22 | - python3 23 | - pytorch 24 | - torchvision 25 | - numpy 26 | - scipy 27 | 28 | To compute the FID score between two datasets, where images of each dataset are contained in an individual folder: 29 | ``` 30 | ./fid_score.py path/to/dataset1 path/to/dataset2 31 | ``` 32 | 33 | To run the evaluation on GPU, use the flag `--gpu N`, where `N` is the index of the GPU to use. 34 | 35 | ### Using different layers for feature maps 36 | 37 | Unlike the official implementation, you can choose to use a different feature layer of the Inception network instead of the default `pool3` layer. 38 | As the lower layer features still have spatial extent, the features are first global average pooled to a vector before estimating mean and covariance. 39 | 40 | This might be useful if the datasets you want to compare contain fewer than the otherwise required 2048 images. 41 | Note that this changes the magnitude of the FID score, and you cannot compare it against scores calculated at another dimensionality. 42 | The resulting scores might also no longer correlate with visual quality. 43 | 44 | You can select the dimensionality of features to use with the flag `--dims N`, where N is the dimensionality of features. 45 | The choices are: 46 | - 64: first max pooling features 47 | - 192: second max pooling features 48 | - 768: pre-aux classifier features 49 | - 2048: final average pooling features (this is the default) 50 | 51 | ## License 52 | 53 | This implementation is licensed under the Apache License 2.0. 
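As an aside to the CLI usage above, the vendored copy can also be driven from Python. A hedged sketch, assuming the port keeps upstream's `calculate_fid_given_paths(paths, batch_size, cuda, dims)` signature from the era it was taken; the dataset paths are placeholders:

```python
from integrated_cell.external.pytorch_fid import fid_score

# each directory holds the .jpg/.png images of one dataset (hypothetical paths)
fid = fid_score.calculate_fid_given_paths(
    ["path/to/dataset1", "path/to/dataset2"],
    batch_size=50,
    cuda=True,
    dims=2048,
)
print("FID:", fid)
```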
54 | 55 | FID was introduced by Martin Heusel, Hubert Ramsauer, Thomas Unterthiner, Bernhard Nessler and Sepp Hochreiter in "GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium", see [https://arxiv.org/abs/1706.08500](https://arxiv.org/abs/1706.08500) 56 | 57 | The original implementation is by the Institute of Bioinformatics, JKU Linz, licensed under the Apache License 2.0. 58 | See [https://github.com/bioinf-jku/TTUR](https://github.com/bioinf-jku/TTUR). 59 | -------------------------------------------------------------------------------- /integrated_cell/imgToProjection.py: -------------------------------------------------------------------------------- 1 | # Author: Evan Wiederspan 2 | 3 | import numpy as np 4 | import matplotlib as mpl 5 | 6 | mpl.use("Agg") # noqa 7 | import matplotlib.pyplot as pplot 8 | 9 | 10 | def matproj(im, dim, method="max", slice_index=0): 11 | if method == "max": 12 | im = np.max(im, dim) 13 | elif method == "mean": 14 | im = np.mean(im, dim) 15 | elif method == "sum": 16 | im = np.sum(im, dim) 17 | elif method == "slice": 18 | im = im[slice_index] 19 | else: 20 | raise ValueError("Invalid projection method") 21 | return im 22 | 23 | 24 | def imgtoprojection( 25 | im1, 26 | proj_all=False, 27 | proj_method="max", 28 | colors=lambda i: [1, 1, 1], 29 | global_adjust=False, 30 | local_adjust=False, 31 | ): 32 | """ 33 | Outputs projections of a 4d CZYX numpy array into a CYX numpy array, allowing for color masks for each input channel 34 | as well as adjustment options 35 | :param im1: Either a 4d numpy array or a list of 3D or 2D numpy arrays. The input that will be projected 36 | :param proj_all: boolean. True outputs XY, YZ, and XZ projections in a grid, False just outputs XY. False by default 37 | :param proj_method: string. Method by which to do projections. 'Max' by default 38 | :param colors: Can be either a string which corresponds to a cmap function in matplotlib, a function that 39 | takes in the channel index and returns a list of numbers, or a list of lists containing the color multipliers. 40 | :param global_adjust: boolean. If true, scales each color channel to set its max to be 255 41 | after combining all channels. False by default 42 | :param local_adjust: boolean. If true, performs contrast adjustment on each channel individually. 
False by default 43 | :return: a CYX numpy array containing the requested projections 44 | """ 45 | 46 | # turn list of 2d or 3d arrays into single 4d array if needed 47 | try: 48 | if isinstance(im1, (type([]), type(()))): 49 | # if only YX, add a single Z dimen 50 | if im1[0].ndim == 2: 51 | im1 = [np.expand_dims(c, axis=0) for c in im1] 52 | elif im1[0].ndim != 3: 53 | raise ValueError("im1 must be a list of 2d or 3d arrays") 54 | # combine list into 4d array 55 | im = np.stack(im1) 56 | else: 57 | if im1.ndim != 4: 58 | raise ValueError("Invalid dimensions for im1") 59 | im = im1 60 | 61 | except (AttributeError, IndexError): 62 | # its not a list of np arrays 63 | raise ValueError( 64 | "im1 must be either a 4d numpy array or a list of numpy arrays" 65 | ) 66 | 67 | # color processing code 68 | if isinstance(colors, str): 69 | # pass it in to matplotlib 70 | try: 71 | colors = pplot.get_cmap(colors)(np.linspace(0, 1, im.shape[0])) 72 | except ValueError: 73 | # thrown when string is not valid function 74 | raise ValueError("Invalid cmap string") 75 | elif callable(colors): 76 | # if its a function 77 | try: 78 | colors = [colors(i) for i in range(im.shape[0])] 79 | except: # noqa 80 | raise ValueError("Invalid color function") 81 | 82 | # else, were assuming it's a list 83 | # scale colors down to 0-1 range if they're bigger than 1 84 | if any(v > 1 for v in np.array(colors).flatten()): 85 | colors = [[v / 255 for v in c] for c in colors] 86 | 87 | # create final image 88 | if not proj_all: 89 | img_final = np.zeros((3, im.shape[2], im.shape[3])) 90 | else: 91 | # y + z, x + z 92 | img_final = np.zeros((3, im.shape[2] + im.shape[1], im.shape[3] + im.shape[1])) 93 | img_piece = np.zeros(img_final.shape) 94 | # loop through all channels 95 | for i, img_c in enumerate(im): 96 | try: 97 | proj_z = matproj(img_c, 0, proj_method, img_c.shape[0] // 2) 98 | if proj_all: 99 | proj_y, proj_x = ( 100 | matproj(img_c, axis, proj_method, img_c.shape[axis] // 2) 101 | for axis in range(1, 3) 102 | ) 103 | # flipping to get them facing the right way 104 | # proj_x = np.fliplr(np.transpose(proj_x, (1, 0))) 105 | # proj_y = np.flipud(proj_y) 106 | proj_x = np.transpose(proj_x, (1, 0)) 107 | proj_y = np.flipud(proj_y) 108 | 109 | _, sy, sz = proj_z.shape[1], proj_z.shape[0], proj_y.shape[0] # noqa 110 | img_piece[:, :sy, :sz] = proj_x 111 | img_piece[:, :sy, sz:] = proj_z 112 | img_piece[:, sy:, sz:] = proj_y 113 | else: 114 | img_piece[:] = proj_z 115 | except ValueError: 116 | raise ValueError("Invalid projection function") 117 | 118 | for c in range(3): 119 | img_piece[c] *= colors[i][c] 120 | 121 | # local contrast adjustment, minus the min, divide the max 122 | if local_adjust: 123 | img_piece -= np.min(img_piece) 124 | img_piece /= np.max(img_piece) 125 | img_final += img_piece 126 | 127 | # color range adjustment, ensure that max value is 255 128 | if global_adjust: 129 | # scale color channels independently 130 | img_final /= np.max(img_final) 131 | 132 | return img_final 133 | -------------------------------------------------------------------------------- /integrated_cell/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class ChannelSoftmax(nn.Module): 6 | # normalizes each channel to sum to one 7 | def __init__(self): 8 | super(ChannelSoftmax, self).__init__() 9 | 10 | def forward(self, x): 11 | 12 | # https://stackoverflow.com/questions/44081007/logsoftmax-stability 13 | # or even better 14 | # 
https://timvieira.github.io/blog/post/2014/02/11/exp-normalize-trick/ 15 | 16 | n_dims = len(x.shape) - 2 17 | 18 | b = torch.max(x.view(x.shape[0], x.shape[1], -1), dim=2)[0] 19 | 20 | for i in range(n_dims): 21 | b = torch.unsqueeze(b, -1) 22 | 23 | x_exp = torch.exp(x - b) 24 | 25 | normalize = x_exp.clone() 26 | for i in range(n_dims): 27 | normalize = torch.sum(normalize, -1) 28 | 29 | for i in range(n_dims): 30 | normalize = torch.unsqueeze(normalize, -1) 31 | 32 | softmaxed = x_exp / normalize 33 | 34 | return softmaxed 35 | -------------------------------------------------------------------------------- /integrated_cell/metrics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AllenCellModeling/pytorch_integrated_cell/8a83fc6f8dc79037f4b681d9d7ef0bc5b91e9948/integrated_cell/metrics/__init__.py -------------------------------------------------------------------------------- /integrated_cell/metrics/embeddings.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import torch 3 | 4 | from ..utils import utils, reparameterize 5 | 6 | from ..losses import KLDLoss 7 | 8 | 9 | def get_latent_embeddings( 10 | enc, 11 | dec, 12 | dp, 13 | recon_loss, 14 | modes=["test"], 15 | batch_size=256, 16 | n_recon_samples=10, 17 | sampler=None, 18 | beta=1, 19 | channels_ref=[0, 2], 20 | channels_target=[1], 21 | ): 22 | 23 | if sampler is None: 24 | 25 | def sampler(mode, inds): 26 | return dp.get_sample(mode, inds) 27 | 28 | enc.eval() 29 | dec.eval() 30 | 31 | embedding = dict() 32 | 33 | for mode in modes: 34 | ndat = dp.get_n_dat(mode) 35 | 36 | x, classes, ref = sampler(mode, [0]) 37 | 38 | x = x.cuda() 39 | classes_onehot = utils.index_to_onehot(classes, dp.get_n_classes()).cuda() 40 | 41 | with torch.no_grad(): 42 | zAll = enc(x, classes_onehot) 43 | 44 | ndims = torch.prod(torch.tensor(zAll[0][0].shape[1:])) 45 | 46 | embeddings_ref_mu = torch.zeros(ndat, ndims) 47 | embeddings_ref_sigma = torch.zeros(ndat, ndims) 48 | 49 | embeddings_ref_recons = torch.zeros(ndat, n_recon_samples) 50 | embeddings_ref_kld = torch.zeros(ndat) 51 | 52 | embeddings_target_mu = torch.zeros(ndat, ndims) 53 | embeddings_target_sigma = torch.zeros(ndat, ndims) 54 | 55 | embeddings_target_recons = torch.zeros(ndat, n_recon_samples) 56 | embeddings_target_kld = torch.zeros(ndat) 57 | 58 | embeddings_classes = torch.zeros(ndat).long() 59 | 60 | inds = list(range(0, ndat)) 61 | data_iter = [ 62 | torch.LongTensor(inds[i : i + batch_size]) # noqa 63 | for i in range(0, len(inds), batch_size) # noqa 64 | ] 65 | 66 | for i in tqdm(range(0, len(data_iter))): 67 | batch_size = len(data_iter[i]) 68 | 69 | x, classes, ref = sampler(mode, data_iter[i]) 70 | 71 | x = x.cuda() 72 | classes_onehot = utils.index_to_onehot(classes, dp.get_n_classes()).cuda() 73 | 74 | with torch.no_grad(): 75 | zAll = enc(x, classes_onehot) 76 | 77 | embeddings_ref_mu.index_copy_( 78 | 0, data_iter[i], zAll[0][0].cpu().view([batch_size, -1]) 79 | ) 80 | embeddings_ref_sigma.index_copy_( 81 | 0, data_iter[i], zAll[0][1].cpu().view([batch_size, -1]) 82 | ) 83 | 84 | embeddings_target_mu.index_copy_( 85 | 0, data_iter[i], zAll[1][0].cpu().view([batch_size, -1]) 86 | ) 87 | embeddings_target_sigma.index_copy_( 88 | 0, data_iter[i], zAll[1][1].cpu().view([batch_size, -1]) 89 | ) 90 | 91 | embeddings_classes.index_copy_(0, data_iter[i], classes) 92 | 93 | recons_ref, recons_target = get_recons( 94 | x, 95 | 
classes_onehot, 96 | dec, 97 | zAll, 98 | n_recon_samples, 99 | recon_loss=recon_loss, 100 | channels_ref=channels_ref, 101 | channels_target=channels_target, 102 | ) 103 | 104 | embeddings_ref_kld.index_copy_( 105 | 0, 106 | torch.LongTensor(data_iter[i]), 107 | get_klds(zAll[0][0], zAll[0][1]).data[:].cpu(), 108 | ) 109 | embeddings_ref_recons.index_copy_( 110 | 0, torch.LongTensor(data_iter[i]), recons_ref 111 | ) 112 | 113 | embeddings_target_kld.index_copy_( 114 | 0, torch.LongTensor(data_iter[i]), get_klds(zAll[1][0], zAll[1][1]) 115 | ) 116 | embeddings_target_recons.index_copy_( 117 | 0, torch.LongTensor(data_iter[i]), recons_target 118 | ) 119 | 120 | embedding[mode] = {} 121 | embedding[mode]["ref"] = {} 122 | embedding[mode]["ref"]["mu"] = embeddings_ref_mu 123 | embedding[mode]["ref"]["sigma"] = embeddings_ref_sigma 124 | 125 | embedding[mode]["ref"]["kld"] = embeddings_ref_kld 126 | embedding[mode]["ref"]["recon"] = embeddings_ref_recons 127 | embedding[mode]["ref"]["elbo"] = -( 128 | torch.mean(embeddings_ref_recons, 1) + beta * embeddings_ref_kld 129 | ) 130 | 131 | embedding[mode]["target"] = {} 132 | embedding[mode]["target"]["mu"] = embeddings_target_mu 133 | embedding[mode]["target"]["sigma"] = embeddings_target_sigma 134 | embedding[mode]["target"]["class"] = embeddings_classes 135 | 136 | embedding[mode]["target"]["kld"] = embeddings_target_kld 137 | embedding[mode]["target"]["recon"] = embeddings_target_recons 138 | 139 | embedding[mode]["target"]["elbo"] = -( 140 | torch.mean(embeddings_target_recons, 1) + beta * embeddings_target_kld 141 | ) 142 | 143 | for mode in embedding: 144 | for component in embedding[mode]: 145 | for thing in embedding[mode][component]: 146 | embedding[mode][component][thing] = ( 147 | embedding[mode][component][thing].cpu().detach() 148 | ) 149 | 150 | return embedding 151 | 152 | 153 | def get_klds(mus, sigmas): 154 | 155 | kld_loss = KLDLoss(reduction="sum") 156 | 157 | klds = torch.zeros(mus.shape[0]) 158 | 159 | for i, (mu, sigma) in enumerate(zip(mus, sigmas)): 160 | 161 | kld, _, _ = kld_loss(mu.unsqueeze(0), sigma.unsqueeze(0)) 162 | 163 | klds[i] = kld[0] 164 | 165 | return klds 166 | 167 | 168 | def get_recons( 169 | x, 170 | classes_onehot, 171 | dec, 172 | zAll, 173 | n_recon_samples, 174 | recon_loss, 175 | channels_ref=[0, 2], 176 | channels_target=[1], 177 | ): 178 | 179 | recons_ref = torch.zeros(x.shape[0], n_recon_samples) 180 | recons_target = torch.zeros(x.shape[0], n_recon_samples) 181 | 182 | for i in range(n_recon_samples): 183 | zOut = [reparameterize(z[0], z[1]) for z in zAll] 184 | 185 | with torch.no_grad(): 186 | xHat = dec([classes_onehot] + zOut) 187 | 188 | recons_ref[:, i] = torch.stack( 189 | [ 190 | recon_loss(xHat[[ind], [channels_ref]], x[[ind], [channels_ref]]) 191 | for ind in range(len(x)) 192 | ] 193 | ) 194 | 195 | recons_target[:, i] = torch.stack( 196 | [ 197 | recon_loss(xHat[[ind], [channels_target]], x[[ind], [channels_target]]) 198 | for ind in range(len(x)) 199 | ] 200 | ) 201 | 202 | return recons_ref, recons_target 203 | -------------------------------------------------------------------------------- /integrated_cell/metrics/embeddings_reference.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import torch 3 | 4 | from ..utils import reparameterize 5 | 6 | from ..losses import KLDLoss 7 | 8 | 9 | def get_latent_embeddings( 10 | enc, 11 | dec, 12 | dp, 13 | recon_loss, 14 | modes=["test"], 15 | batch_size=256, 16 | 
n_recon_samples=10, 17 | sampler=None, 18 | beta=1, 19 | channels_ref=[0, 2], 20 | channels_target=[1], 21 | ): 22 | 23 | kld_loss = KLDLoss(reduction="sum") 24 | 25 | if sampler is None: 26 | 27 | def sampler(mode, inds): 28 | return dp.get_sample(mode, inds) 29 | 30 | enc.eval() 31 | dec.eval() 32 | 33 | embedding = dict() 34 | 35 | for mode in modes: 36 | ndat = dp.get_n_dat(mode) 37 | 38 | x = sampler(mode, [0]) 39 | x = x.cuda() 40 | 41 | with torch.no_grad(): 42 | zAll = enc(x) 43 | 44 | ndims = torch.prod(torch.tensor(zAll[0].shape[1:])) 45 | 46 | embeddings_ref_mu = torch.zeros(ndat, ndims) 47 | embeddings_ref_sigma = torch.zeros(ndat, ndims) 48 | 49 | embeddings_ref_recons = torch.zeros(ndat, n_recon_samples) 50 | embeddings_ref_kld = torch.zeros(ndat) 51 | 52 | inds = list(range(0, ndat)) 53 | data_iter = [ 54 | torch.LongTensor(inds[i : i + batch_size]) # noqa 55 | for i in range(0, len(inds), batch_size) # noqa 56 | ] 57 | 58 | for i in tqdm(range(0, len(data_iter))): 59 | batch_size = len(data_iter[i]) 60 | 61 | x = sampler(mode, data_iter[i]) 62 | x = x.cuda() 63 | 64 | with torch.no_grad(): 65 | zAll = enc(x) 66 | 67 | recons_ref = get_recons( 68 | x, dec, zAll, n_recon_samples, recon_loss=recon_loss 69 | ) 70 | 71 | klds = torch.stack( 72 | [kld_loss(mu, sigma) for mu, sigma in zip(zAll[0], zAll[1])] 73 | ) 74 | 75 | embeddings_ref_mu.index_copy_( 76 | 0, torch.LongTensor(data_iter[i]), zAll[0].cpu() 77 | ) 78 | 79 | embeddings_ref_sigma.index_copy_( 80 | 0, torch.LongTensor(data_iter[i]), zAll[1].cpu() 81 | ) 82 | 83 | embeddings_ref_kld.index_copy_( 84 | 0, torch.LongTensor(data_iter[i]), klds.cpu() 85 | ) 86 | 87 | embeddings_ref_recons.index_copy_( 88 | 0, torch.LongTensor(data_iter[i]), recons_ref 89 | ) 90 | 91 | embedding[mode] = {} 92 | embedding[mode]["ref"] = {} 93 | embedding[mode]["ref"]["mu"] = embeddings_ref_mu 94 | embedding[mode]["ref"]["sigma"] = embeddings_ref_sigma 95 | 96 | embedding[mode]["ref"]["kld"] = embeddings_ref_kld 97 | embedding[mode]["ref"]["recon"] = embeddings_ref_recons 98 | 99 | embedding[mode]["ref"]["elbo"] = -( 100 | torch.mean(embeddings_ref_recons, 1) + beta * embeddings_ref_kld 101 | ) 102 | 103 | for mode in embedding: 104 | for component in embedding[mode]: 105 | for thing in embedding[mode][component]: 106 | embedding[mode][component][thing] = ( 107 | embedding[mode][component][thing].cpu().detach() 108 | ) 109 | 110 | return embedding 111 | 112 | 113 | def get_recons(x, dec, zAll, n_recon_samples, recon_loss): 114 | 115 | recons = torch.zeros(x.shape[0], n_recon_samples) 116 | 117 | for i in range(n_recon_samples): 118 | zOut = reparameterize(zAll[0], zAll[1]) 119 | 120 | with torch.no_grad(): 121 | xHat = dec(zOut) 122 | 123 | recons[:, i] = torch.stack( 124 | [recon_loss(xHat[[ind]], x[[ind]]) for ind in range(len(x))] 125 | ) 126 | 127 | return recons 128 | -------------------------------------------------------------------------------- /integrated_cell/metrics/embeddings_target.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import torch 3 | 4 | from ..utils import utils, reparameterize 5 | 6 | from ..losses import KLDLoss 7 | 8 | 9 | def get_latent_embeddings( 10 | enc, 11 | dec, 12 | dp, 13 | recon_loss, 14 | modes=["test"], 15 | batch_size=256, 16 | n_recon_samples=10, 17 | sampler=None, 18 | beta=1, 19 | channels_ref=[0, 2], 20 | channels_target=[1], 21 | ): 22 | 23 | kld_loss = KLDLoss(reduction="sum") 24 | 25 | if sampler is None: 26 | 27 | def 
sampler(mode, inds): 28 | return dp.get_sample(mode, inds) 29 | 30 | enc.eval() 31 | dec.eval() 32 | 33 | embedding = dict() 34 | 35 | for mode in modes: 36 | ndat = dp.get_n_dat(mode) 37 | 38 | x, classes, ref = sampler(mode, [0]) 39 | 40 | x = x.cuda() 41 | ref = ref.cuda() 42 | classes_onehot = utils.index_to_onehot(classes, dp.get_n_classes()).cuda() 43 | 44 | with torch.no_grad(): 45 | zAll = enc(x, ref, classes_onehot) 46 | 47 | ndims = torch.prod(torch.tensor(zAll[0].shape[1:])) 48 | 49 | embeddings_target_mu = torch.zeros(ndat, ndims) 50 | embeddings_target_sigma = torch.zeros(ndat, ndims) 51 | 52 | embeddings_target_recons = torch.zeros(ndat, n_recon_samples) 53 | embeddings_target_kld = torch.zeros(ndat) 54 | 55 | embeddings_classes = torch.zeros(ndat).long() 56 | 57 | inds = list(range(0, ndat)) 58 | data_iter = [ 59 | torch.LongTensor(inds[i : i + batch_size]) # noqa 60 | for i in range(0, len(inds), batch_size) # noqa 61 | ] 62 | 63 | for i in tqdm(range(0, len(data_iter))): 64 | batch_size = len(data_iter[i]) 65 | 66 | x, classes, ref = sampler(mode, data_iter[i]) 67 | 68 | x = x.cuda() 69 | ref = ref.cuda() 70 | classes_onehot = utils.index_to_onehot(classes, dp.get_n_classes()).cuda() 71 | 72 | with torch.no_grad(): 73 | zAll = enc(x, ref, classes_onehot) 74 | 75 | recons_target = get_recons( 76 | x, 77 | ref, 78 | classes_onehot, 79 | dec, 80 | zAll, 81 | n_recon_samples, 82 | recon_loss=recon_loss, 83 | ) 84 | 85 | klds = torch.stack( 86 | [kld_loss(mu, sigma) for mu, sigma in zip(zAll[0], zAll[1])] 87 | ) 88 | 89 | embeddings_target_mu.index_copy_( 90 | 0, torch.LongTensor(data_iter[i]), zAll[0].cpu() 91 | ) 92 | 93 | embeddings_target_sigma.index_copy_( 94 | 0, torch.LongTensor(data_iter[i]), zAll[1].cpu() 95 | ) 96 | 97 | embeddings_classes.index_copy_(0, data_iter[i], classes) 98 | 99 | embeddings_target_kld.index_copy_( 100 | 0, torch.LongTensor(data_iter[i]), klds.cpu() 101 | ) 102 | 103 | embeddings_target_recons.index_copy_( 104 | 0, torch.LongTensor(data_iter[i]), recons_target 105 | ) 106 | 107 | embedding[mode] = {} 108 | embedding[mode]["target"] = {} 109 | embedding[mode]["target"]["mu"] = embeddings_target_mu 110 | embedding[mode]["target"]["sigma"] = embeddings_target_sigma 111 | embedding[mode]["target"]["class"] = embeddings_classes 112 | 113 | embedding[mode]["target"]["kld"] = embeddings_target_kld 114 | embedding[mode]["target"]["recon"] = embeddings_target_recons 115 | 116 | embedding[mode]["target"]["elbo"] = -( 117 | torch.mean(embeddings_target_recons, 1) + beta * embeddings_target_kld 118 | ) 119 | 120 | for mode in embedding: 121 | for component in embedding[mode]: 122 | for thing in embedding[mode][component]: 123 | embedding[mode][component][thing] = ( 124 | embedding[mode][component][thing].cpu().detach() 125 | ) 126 | 127 | return embedding 128 | 129 | 130 | def get_recons(x, ref, classes_onehot, dec, zAll, n_recon_samples, recon_loss): 131 | 132 | recons = torch.zeros(x.shape[0], n_recon_samples) 133 | 134 | for i in range(n_recon_samples): 135 | zOut = reparameterize(zAll[0], zAll[1]) 136 | 137 | with torch.no_grad(): 138 | xHat = dec(zOut, ref, classes_onehot) 139 | 140 | recons[:, i] = torch.stack( 141 | [recon_loss(xHat[[ind]], x[[ind]]) for ind in range(len(x))] 142 | ) 143 | 144 | return recons 145 | -------------------------------------------------------------------------------- /integrated_cell/metrics/features2D.py: -------------------------------------------------------------------------------- 1 | import skimage.measure as measure 
2 | import skimage.filters as filters 3 | import numpy as np 4 | 5 | 6 | def props2summary(props): 7 | 8 | summary_props = [ 9 | "area", 10 | "bbox_area", 11 | "convex_area", 12 | "centroid", 13 | "equivalent_diameter", 14 | "euler_number", 15 | "extent", 16 | "filled_area", 17 | "major_axis_length", 18 | "max_intensity", 19 | "mean_intensity", 20 | "min_intensity", 21 | "minor_axis_length", 22 | "moments", 23 | "moments_central", 24 | "moments_hu", 25 | "moments_normalized", 26 | "orientation", 27 | "perimeter", 28 | "solidity", 29 | "weighted_centroid", 30 | "weighted_moments_central", 31 | "weighted_moments_hu", 32 | "weighted_moments_normalized", 33 | ] 34 | 35 | prop_new = {} 36 | 37 | for k in summary_props: 38 | prop_list = list() 39 | 40 | for prop in props: 41 | p = np.array(prop[k]) 42 | prop_list.append(p) 43 | 44 | if len(prop_list) > 1: 45 | prop_stack = np.stack(prop_list, -1) 46 | 47 | prop_mean = np.mean(prop_stack, -1) 48 | prop_std = np.std(prop_stack, -1) 49 | 50 | prop_total = np.sum(prop_stack, -1) 51 | 52 | else: 53 | prop_stack = prop_list[0] 54 | 55 | prop_total = prop_stack 56 | prop_mean = prop_stack 57 | prop_std = np.zeros(prop_stack.shape) 58 | 59 | prop_new[k + "_total"] = prop_total 60 | prop_new[k + "_mean"] = prop_mean 61 | prop_new[k + "_std"] = prop_std 62 | 63 | prop_new["num_objs"] = len(props) 64 | 65 | return prop_new 66 | 67 | 68 | def find_main_obj(im_bw): 69 | im_label = measure.label(im_bw > 0) 70 | 71 | ulabels = np.unique(im_label[im_label > 0]) 72 | 73 | label_counts = np.zeros(len(ulabels)) 74 | 75 | for i, label in enumerate(ulabels): 76 | label_counts[i] = np.sum(im_label == label) 77 | 78 | return im_label == ulabels[np.argmax(label_counts)] 79 | 80 | 81 | def ch_feats(im_bw, bg_thresh=1e-2): 82 | 83 | im_bg_sub = im_bw > bg_thresh 84 | 85 | props = {} 86 | 87 | im_bg_sub = find_main_obj(im_bg_sub) 88 | 89 | main_props = measure.regionprops(im_bg_sub.astype("uint8"), im_bw) 90 | main_props = props2summary(main_props) 91 | 92 | for k in main_props: 93 | props["main_" + k] = main_props[k] 94 | 95 | thresh = filters.threshold_otsu(im_bw[im_bg_sub]) 96 | im_obj = measure.label(im_bw > thresh) 97 | 98 | thresh_props = measure.regionprops(im_obj, im_bw) 99 | thresh_props = props2summary(thresh_props) 100 | 101 | for k in thresh_props: 102 | props["obj_" + k] = thresh_props[k] 103 | 104 | return props 105 | 106 | 107 | def im2feats(im, ch_names, bg_thresh=1e-2): 108 | 109 | feats = {} 110 | for i, ch_name in enumerate(ch_names): 111 | im_ch = im[i] 112 | 113 | feats[ch_name] = ch_feats(im_ch, bg_thresh) 114 | 115 | return feats 116 | -------------------------------------------------------------------------------- /integrated_cell/metrics/precision_recall.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from ..external.pytorch_fid.inception import InceptionV3 3 | from ..external.pytorch_fid import fid_score 4 | import os 5 | import pathlib 6 | 7 | """ 8 | Implements Improved Precision and Recall Metric for Assessing Generative Models 9 | https://arxiv.org/abs/1904.06991 10 | """ 11 | 12 | 13 | def my_cdist(x1, x2): 14 | x1_norm = x1.pow(2).sum(dim=-1, keepdim=True) 15 | x2_norm = x2.pow(2).sum(dim=-1, keepdim=True) 16 | res = torch.addmm( 17 | x2_norm.transpose(-2, -1), x1, x2.transpose(-2, -1), alpha=-2 18 | ).add_(x1_norm) 19 | res = res.clamp_min_(1e-30).sqrt_() 20 | return res 21 | 22 | 23 | def manifold_estimate(phi_a, phi_b, k=[3], batch_size=None): 24 | """ 25 | phi_a and phi_b are N_a 
x D and N_b x D sets of features, 26 | k is the nearest neighbor setting for knn manifold estimation, 27 | batch_size is the number of pairwise distance computations done at once (None = do them all at once) 28 | """ 29 | if batch_size is None: 30 | d_aa = my_cdist(phi_a, phi_a) 31 | r_phi = {kappa: torch.kthvalue(d_aa, kappa + 1, dim=1)[0] for kappa in k} 32 | d_ba = my_cdist(phi_b, phi_a) 33 | b_in_a = { 34 | kappa: torch.any(d_ba <= rp, dim=1).float() for kappa, rp in r_phi.items() 35 | } 36 | else: 37 | r_phi = {kappa: torch.empty_like(phi_a[:, 0]) for kappa in k} 38 | for batch_inds in torch.split(torch.arange(len(phi_a)), batch_size): 39 | d_aa = my_cdist(phi_a[batch_inds], phi_a) 40 | for kappa in k: 41 | r_phi[kappa][batch_inds] = torch.kthvalue(d_aa, kappa + 1, dim=1)[0] 42 | b_in_a = {kappa: torch.empty_like(phi_b[:, 0]) for kappa in k} 43 | for batch_inds in torch.split(torch.arange(len(phi_b)), batch_size): 44 | d_ba = my_cdist(phi_b[batch_inds], phi_a) 45 | for kappa in k: 46 | b_in_a[kappa][batch_inds] = torch.any( 47 | d_ba <= r_phi[kappa], dim=1 48 | ).float() 49 | return {key: value.mean().item() for key, value in b_in_a.items()} 50 | 51 | 52 | def precision_recall(phi_r, phi_g, k=[3], batch_size=None): 53 | """ 54 | phi_r and phi_g are N_a x D and N_b x D sets of features 55 | from real and generated images respectively, 56 | k is a list of integers for the nearest neighbor setting for knn manifold estimation. 57 | returns a dict {k_1:{'precision':float, 'recall':float}, k_2:{'precision':float, 'recall':float}, ...} 58 | """ 59 | d = { 60 | "precision": manifold_estimate(phi_r, phi_g, k=k, batch_size=batch_size), 61 | "recall": manifold_estimate(phi_g, phi_r, k=k, batch_size=batch_size), 62 | } 63 | 64 | result = {} 65 | for k1, subdict in d.items(): 66 | for k2, v in subdict.items(): 67 | result.setdefault(k2, {})[k1] = v 68 | 69 | return result
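A quick usage sketch for precision_recall on synthetic features (the sizes are arbitrary; with matched real and generated distributions both precision and recall should come out high):

```python
import torch

from integrated_cell.metrics.precision_recall import precision_recall

# random "real" and "generated" feature sets, e.g. Inception activations
phi_r = torch.randn(1000, 64)
phi_g = torch.randn(1000, 64)

print(precision_recall(phi_r, phi_g, k=[3]))
# e.g. {3: {'precision': 0.9..., 'recall': 0.9...}}
```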
70 | 71 | 72 | def calculate_inception_pr_given_paths( 73 | paths, k=[3], batch_size=1, cuda=True, dims=2048, verbose=False 74 | ): 75 | # this is constructed to imitate `calculate_fid_given_paths` from integrated_cell.external.pytorch_fid.fid_score 76 | 77 | for p in paths: 78 | if not os.path.exists(p): 79 | raise RuntimeError("Invalid path: %s" % p) 80 | 81 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] 82 | 83 | model = InceptionV3([block_idx]) 84 | if cuda: 85 | model.cuda() 86 | 87 | path = pathlib.Path(paths[0]) 88 | files = list(path.glob("*.jpg")) + list(path.glob("*.png")) 89 | act_real = fid_score.get_activations(files, model, batch_size, dims, cuda, verbose) 90 | 91 | path = pathlib.Path(paths[1]) 92 | files = list(path.glob("*.jpg")) + list(path.glob("*.png")) 93 | act_gen = fid_score.get_activations(files, model, batch_size, dims, cuda, verbose) 94 | 95 | act_real = torch.tensor(act_real) 96 | act_gen = torch.tensor(act_gen) 97 | 98 | return precision_recall(act_real, act_gen, k, batch_size=None) 99 | 100 | 101 | def calculate_inception_pr_given_paired_paths( 102 | paths_real, paths_gen, k=[3], batch_size=1, cuda=True, dims=2048, verbose=False 103 | ): 104 | # this is constructed to imitate `calculate_fid_given_paths` from integrated_cell.external.pytorch_fid.fid_score 105 | 106 | for p in paths_real + paths_gen: 107 | if not os.path.exists(p): 108 | raise RuntimeError("Invalid path: %s" % p) 109 | 110 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] 111 | 112 | model = InceptionV3([block_idx]) 113 | if cuda: 114 | model.cuda() 115 | 116 | prs = list() 117 | for i, [path_real, path_gen] in enumerate(zip(paths_real, paths_gen)): 118 | path = pathlib.Path(path_real) 119 | files = list(path.glob("*.jpg")) + list(path.glob("*.png")) 120 | act_real = fid_score.get_activations( 121 | files, model, batch_size, dims, cuda, verbose 122 | ) 123 | 124 | path = pathlib.Path(path_gen) 125 | files = list(path.glob("*.jpg")) + list(path.glob("*.png")) 126 | act_gen = fid_score.get_activations( 127 | files, model, batch_size, dims, cuda, verbose 128 | ) 129 | 130 | act_real = torch.tensor(act_real) 131 | act_gen = torch.tensor(act_gen) 132 | 133 | prs.append(precision_recall(act_real, act_gen, k, batch_size=None)) 134 | 135 | return prs 136 | -------------------------------------------------------------------------------- /integrated_cell/model_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | 4 | # from .utils import plots 5 | import matplotlib as mpl 6 | 7 | mpl.use("Agg") # noqa 8 | 9 | import warnings 10 | 11 | 12 | def init_opts(opt, opt_default): 13 | vars_default = vars(opt_default) 14 | for var in vars_default: 15 | if not hasattr(opt, var): 16 | setattr(opt, var, getattr(opt_default, var)) 17 | return opt 18 | 19 | 20 | def set_gpu_recursive(var, gpu_id): 21 | for key in var: 22 | if isinstance(var[key], dict): 23 | var[key] = set_gpu_recursive(var[key], gpu_id) 24 | else: 25 | try: 26 | if gpu_id != -1: 27 | var[key] = var[key].cuda(gpu_id) 28 | else: 29 | var[key] = var[key].cpu() 30 | except AttributeError: 31 | pass 32 | return var 33 | 34 | 35 | def sampleUniform(batsize, nlatentdim): 36 | return torch.Tensor(batsize, nlatentdim).uniform_(-1, 1) 37 | 38 | 39 | def sampleGaussian(batsize, nlatentdim): 40 | return torch.Tensor(batsize, nlatentdim).normal_() 41 | 42 | 43 | def tensor2img(img): 44 | warnings.warn( 45 | "integrated_cell.model_utils.tensor2img is deprecated. Please use integrated_cell.utils.plots.tensor2im instead." 
46 | ) 47 | # return plots.tensor2im(img) 48 | 49 | 50 | def load_embeddings(embeddings_path, enc=None, dp=None): 51 | 52 | if os.path.exists(embeddings_path): 53 | 54 | embeddings = torch.load(embeddings_path) 55 | else: 56 | embeddings = get_latent_embeddings(enc, dp) 57 | torch.save(embeddings, embeddings_path) 58 | 59 | return embeddings 60 | 61 | 62 | def get_latent_embeddings(enc, dp): 63 | enc.eval() 64 | gpu_id = enc.gpu_ids[0] 65 | 66 | modes = ("test", "train") 67 | 68 | embedding = dict() 69 | 70 | for mode in modes: 71 | ndat = dp.get_n_dat(mode) 72 | embeddings = torch.zeros(ndat, enc.n_latent_dim) 73 | 74 | inds = list(range(0, ndat)) 75 | data_iter = [ 76 | inds[i : i + dp.batch_size] # noqa 77 | for i in range(0, len(inds), dp.batch_size) 78 | ] 79 | 80 | for i in range(0, len(data_iter)): 81 | print(str(i) + "/" + str(len(data_iter))) 82 | x = dp.get_images(data_iter[i], mode).cuda(gpu_id) 83 | 84 | with torch.no_grad(): 85 | zAll = enc(x) 86 | 87 | embeddings.index_copy_( 88 | 0, torch.LongTensor(data_iter[i]), zAll[-1].data[:].cpu() 89 | ) 90 | 91 | embedding[mode] = embeddings 92 | 93 | return embedding 94 | 95 | 96 | def load_state(model, optimizer, path, gpu_id): 97 | # device = torch.device('cpu') 98 | 99 | checkpoint = torch.load(path) 100 | 101 | model.load_state_dict(checkpoint["model"]) 102 | optimizer.load_state_dict(checkpoint["optimizer"]) 103 | 104 | # model.cuda(gpu_id) 105 | 106 | # optimizer.state = set_gpu_recursive(optimizer.state, gpu_id) 107 | 108 | 109 | def save_state(model, optimizer, path, gpu_id): 110 | 111 | # model = model.cpu() 112 | # optimizer.state = set_gpu_recursive(optimizer.state, -1) 113 | 114 | checkpoint = {"model": model.state_dict(), "optimizer": optimizer.state_dict()} 115 | torch.save(checkpoint, path) 116 | 117 | # model = model.cuda(gpu_id) 118 | # optimizer.state = set_gpu_recursive(optimizer.state, gpu_id) 119 | -------------------------------------------------------------------------------- /integrated_cell/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AllenCellModeling/pytorch_integrated_cell/8a83fc6f8dc79037f4b681d9d7ef0bc5b91e9948/integrated_cell/models/__init__.py -------------------------------------------------------------------------------- /integrated_cell/models/ae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from . import base_model 4 | from .. import SimpleLogger 5 | 6 | import scipy 7 | 8 | from ..utils.plots import tensor2im 9 | from integrated_cell.utils import plots as plots 10 | from .. 
import model_utils 11 | 12 | import os 13 | import pickle 14 | 15 | import shutil 16 | 17 | 18 | class Model(base_model.Model): 19 | def __init__( 20 | self, enc, dec, opt_enc, opt_dec, crit_recon, n_display_imgs=10, **kwargs 21 | ): 22 | # Train an autoencoder 23 | 24 | super(Model, self).__init__(**kwargs) 25 | 26 | self.enc = enc 27 | self.dec = dec 28 | self.opt_enc = opt_enc 29 | self.opt_dec = opt_dec 30 | 31 | self.crit_recon = crit_recon 32 | 33 | self.n_display_imgs = n_display_imgs 34 | 35 | logger_path = "{}/logger.pkl".format(self.save_dir) 36 | 37 | if os.path.exists(logger_path): 38 | self.logger = pickle.load(open(logger_path, "rb")) 39 | else: 40 | columns = ("epoch", "iter", "reconLoss", "time") 41 | print_str = "[%d][%d] reconLoss: %.6f time: %.2f" 42 | self.logger = SimpleLogger(columns, print_str) 43 | 44 | def iteration(self): 45 | 46 | torch.cuda.empty_cache() 47 | 48 | enc, dec = self.enc, self.dec 49 | opt_enc, opt_dec = self.opt_enc, self.opt_dec 50 | crit_recon = self.crit_recon 51 | 52 | enc.train(True) 53 | dec.train(True) 54 | 55 | for p in enc.parameters(): 56 | p.requires_grad = True 57 | 58 | for p in dec.parameters(): 59 | p.requires_grad = True 60 | 61 | # Ignore class labels and reference information 62 | x, _, _ = self.data_provider.get_sample() 63 | x = x.cuda() 64 | 65 | opt_enc.zero_grad() 66 | opt_dec.zero_grad() 67 | 68 | ##################### 69 | # train autoencoder 70 | ##################### 71 | 72 | # Forward passes 73 | z = enc(x) 74 | 75 | xHat = dec(z) 76 | 77 | # Update the image reconstruction 78 | recon_loss = crit_recon(xHat, x) 79 | 80 | recon_loss.backward() 81 | 82 | opt_enc.step() 83 | opt_dec.step() 84 | 85 | zLatent = z.data.cpu() 86 | recon_loss = recon_loss.item() 87 | 88 | errors = [recon_loss] 89 | 90 | return errors, zLatent 91 | 92 | def save_progress(self): 93 | gpu_id = self.gpu_ids[0] 94 | epoch = self.get_current_epoch() 95 | 96 | data_provider = self.data_provider 97 | enc = self.enc 98 | dec = self.dec 99 | 100 | enc.train(False) 101 | dec.train(False) 102 | 103 | ############### 104 | # TRAINING DATA 105 | ############### 106 | img_inds = np.arange(self.n_display_imgs) 107 | 108 | x, _, _ = data_provider.get_sample("train", img_inds) 109 | x = x.cuda(gpu_id) 110 | 111 | with torch.no_grad(): 112 | z = enc(x) 113 | xHat = dec(z) 114 | 115 | imgX = tensor2im(x.data.cpu()) 116 | imgXHat = tensor2im(xHat.data.cpu()) 117 | imgTrainOut = np.concatenate((imgX, imgXHat), 0) 118 | 119 | ############### 120 | # TESTING DATA 121 | ############### 122 | x, _, _ = data_provider.get_sample("validate", img_inds) 123 | x = x.cuda(gpu_id) 124 | 125 | with torch.no_grad(): 126 | z = enc(x) 127 | xHat = dec(z) 128 | 129 | imgX = tensor2im(x.data.cpu()) 130 | imgXHat = tensor2im(xHat.data.cpu()) 131 | 132 | imgTestOut = np.concatenate((imgX, imgXHat), 0) 133 | 134 | imgOut = np.concatenate((imgTrainOut, imgTestOut)) 135 | 136 | scipy.misc.imsave( 137 | "{0}/progress_{1}.png".format(self.save_dir, int(epoch - 1)), imgOut 138 | ) 139 | 140 | embeddings_train = np.concatenate(self.zAll, 0) 141 | 142 | pickle.dump( 143 | embeddings_train, open("{0}/embedding.pth".format(self.save_dir), "wb") 144 | ) 145 | pickle.dump( 146 | embeddings_train, 147 | open( 148 | "{0}/embedding_{1}.pth".format(self.save_dir, self.get_current_iter()), 149 | "wb", 150 | ), 151 | ) 152 | 153 | pickle.dump(self.logger, open("{0}/logger_tmp.pkl".format(self.save_dir), "wb")) 154 | 155 | # History 156 | plots.history(self.logger, 
"{0}/history.png".format(self.save_dir)) 157 | 158 | # Short History 159 | plots.short_history(self.logger, "{0}/history_short.png".format(self.save_dir)) 160 | 161 | # Embedding figure 162 | plots.embeddings(embeddings_train, "{0}/embedding.png".format(self.save_dir)) 163 | 164 | enc.train(True) 165 | dec.train(True) 166 | 167 | def save(self, save_dir): 168 | # for saving and loading see: 169 | # https://discuss.pytorch.org/t/how-to-save-load-torch-models/718 170 | 171 | gpu_id = self.gpu_ids[0] 172 | 173 | n_iters = self.get_current_iter() 174 | 175 | img_embeddings = np.concatenate(self.zAll, 0) 176 | pickle.dump(img_embeddings, open("{0}/embedding.pth".format(save_dir), "wb")) 177 | pickle.dump( 178 | img_embeddings, 179 | open("{0}/embedding_{1}.pth".format(save_dir, n_iters), "wb"), 180 | ) 181 | 182 | enc_save_path_tmp = "{0}/enc.pth".format(save_dir) 183 | enc_save_path_final = "{0}/enc_{1}.pth".format(save_dir, n_iters) 184 | dec_save_path_tmp = "{0}/dec.pth".format(save_dir) 185 | dec_save_path_final = "{0}/dec_{1}.pth".format(save_dir, n_iters) 186 | 187 | model_utils.save_state(self.enc, self.opt_enc, enc_save_path_tmp, gpu_id) 188 | shutil.copyfile(enc_save_path_tmp, enc_save_path_final) 189 | 190 | model_utils.save_state(self.dec, self.opt_dec, dec_save_path_tmp, gpu_id) 191 | shutil.copyfile(dec_save_path_tmp, dec_save_path_final) 192 | 193 | pickle.dump(self.logger, open("{0}/logger.pkl".format(save_dir), "wb")) 194 | pickle.dump( 195 | self.logger, open("{0}/logger_{1}.pkl".format(save_dir, n_iters), "wb") 196 | ) 197 | -------------------------------------------------------------------------------- /integrated_cell/models/base_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | import time 4 | 5 | # This is the base class for trainers 6 | class Model(object): 7 | def __init__( 8 | self, 9 | data_provider, 10 | n_epochs, 11 | gpu_ids, 12 | save_dir, 13 | save_state_iter=1, 14 | save_progress_iter=1, 15 | ): 16 | 17 | self.data_provider = data_provider 18 | self.n_epochs = n_epochs 19 | 20 | self.gpu_ids = gpu_ids 21 | 22 | self.save_dir = save_dir 23 | 24 | self.save_state_iter = save_state_iter 25 | self.save_progress_iter = save_progress_iter 26 | 27 | self.iters_per_epoch = np.ceil(len(data_provider) / data_provider.batch_size) 28 | 29 | self.zAll = list() 30 | 31 | def iteration(self): 32 | raise NotImplementedError 33 | 34 | def get_current_iter(self): 35 | return len(self.logger) 36 | 37 | def get_current_epoch(self, iteration=-1): 38 | 39 | if iteration == -1: 40 | iteration = self.get_current_iter() 41 | 42 | return np.floor(iteration / self.iters_per_epoch) 43 | 44 | def load(self): 45 | # This is where we load the model 46 | raise NotImplementedError 47 | 48 | def save(self, save_dir): 49 | # This is where we save the model 50 | raise NotImplementedError 51 | 52 | def maybe_save(self): 53 | 54 | epoch = self.get_current_epoch(self.get_current_iter() - 1) 55 | epoch_next = self.get_current_epoch(self.get_current_iter()) 56 | 57 | saved = False 58 | if epoch != epoch_next: 59 | # save the logger every epoch 60 | pickle.dump( 61 | self.logger, open("{0}/logger_tmp.pkl".format(self.save_dir), "wb") 62 | ) 63 | 64 | if (epoch_next % self.save_progress_iter) == 0: 65 | print("saving progress") 66 | self.save_progress() 67 | 68 | if (epoch_next % self.save_state_iter) == 0: 69 | print("saving state") 70 | self.save(self.save_dir) 71 | 72 | saved = True 73 | 74 | return saved 75 | 76 | 
def save_progress(self): 77 | raise NotImplementedError 78 | 79 | def total_iters(self): 80 | return int(np.ceil(self.iters_per_epoch) * self.n_epochs) 81 | 82 | def train(self): 83 | start_iter = self.get_current_iter() 84 | 85 | for this_iter in range(int(start_iter), self.total_iters()): 86 | 87 | start = time.time() 88 | 89 | errors, zLatent = self.iteration() 90 | 91 | stop = time.time() 92 | deltaT = stop - start 93 | 94 | self.logger.add( 95 | [self.get_current_epoch(), self.get_current_iter()] + errors + [deltaT] 96 | ) 97 | self.zAll.append(zLatent.data.cpu().detach().numpy()) 98 | 99 | if self.maybe_save(): 100 | self.zAll = list() 101 | --------------------------------------------------------------------------------
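base_model.Model is the abstract trainer: subclasses implement iteration(), save(), and save_progress(), and inherit the epoch bookkeeping and the train() loop above. A minimal hypothetical subclass sketching the contract (the logger is assumed to be a SimpleLogger, as in the concrete trainers in this package; get_current_iter() is len(self.logger), so one is required):

```python
import pickle

import torch

from integrated_cell import SimpleLogger
from integrated_cell.models import base_model


class Trainer(base_model.Model):
    def __init__(self, net, opt, crit, **kwargs):
        # kwargs carry data_provider, n_epochs, gpu_ids, save_dir, ...
        super().__init__(**kwargs)
        self.net, self.opt, self.crit = net, opt, crit
        self.logger = SimpleLogger(("epoch", "iter", "loss", "time"),
                                   "[%d][%d] loss: %.6f time: %.2f")

    def iteration(self):
        x, _, _ = self.data_provider.get_sample()
        x = x.cuda()
        self.opt.zero_grad()
        loss = self.crit(self.net(x), x)  # e.g. an autoencoding loss
        loss.backward()
        self.opt.step()
        # train() expects (errors, zLatent); a dummy latent is returned here
        return [loss.item()], x.new_zeros(x.shape[0], 1)

    def save_progress(self):
        pass  # progress images / plots would go here

    def save(self, save_dir):
        torch.save(self.net.state_dict(), "{0}/net.pth".format(save_dir))
        pickle.dump(self.logger, open("{0}/logger.pkl".format(save_dir), "wb"))
```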
torch.zeros(zFake.shape).type_as(zFake).normal_() 103 | 104 | opt_enc.zero_grad() 105 | opt_dec.zero_grad() 106 | opt_encD.zero_grad() 107 | 108 | ############### 109 | # train encD 110 | ############### 111 | 112 | # train with real 113 | yHat_zReal = encD(zReal) 114 | errEncD_real = crit_encD(yHat_zReal, y_zReal) 115 | 116 | # train with fake 117 | yHat_zFake = encD(zFake) 118 | errEncD_fake = crit_encD(yHat_zFake, y_zFake) 119 | 120 | encD_loss_tmp = (errEncD_real + errEncD_fake) / 2 121 | encD_loss_tmp.backward() 122 | 123 | opt_encD.step() 124 | 125 | encD_loss += encD_loss_tmp.item() 126 | 127 | for p in enc.parameters(): 128 | p.requires_grad = True 129 | 130 | for p in dec.parameters(): 131 | p.requires_grad = True 132 | 133 | for p in encD.parameters(): 134 | p.requires_grad = False 135 | 136 | opt_enc.zero_grad() 137 | opt_dec.zero_grad() 138 | opt_encD.zero_grad() 139 | 140 | ##################### 141 | # train autoencoder 142 | ##################### 143 | 144 | # Forward passes 145 | z_ref, z_struct = enc(x, classes_onehot) 146 | 147 | kld_ref = self.kld_loss(z_ref[0], z_ref[1]) 148 | kld_struct = self.kld_loss(z_struct[0], z_struct[1]) 149 | 150 | kld_loss = kld_ref + kld_struct 151 | 152 | zAll = [z_ref, z_struct] 153 | for i in range(len(zAll)): 154 | zAll[i] = self.reparameterize(zAll[i][0], zAll[i][1]) 155 | 156 | xHat = dec([classes_onehot] + zAll) 157 | 158 | recon_loss = crit_recon(xHat, x) 159 | 160 | if self.lambda_encD_loss > 0: 161 | yHat_zFake = encD(zAll[-1]) 162 | minimax_encD_loss = crit_encD(yHat_zFake, y_xReal) 163 | else: 164 | minimax_encD_loss = 0 165 | 166 | beta_vae_loss = ( 167 | self.vae_loss(recon_loss, kld_loss) 168 | + self.lambda_encD_loss * minimax_encD_loss 169 | ) 170 | 171 | beta_vae_loss.backward() 172 | opt_enc.step() 173 | opt_dec.step() 174 | 175 | # Log a bunch of stuff 176 | recon_loss = recon_loss.item() 177 | minimax_encD_loss = minimax_encD_loss.item() if self.lambda_encD_loss > 0 else 0 178 | 179 | kld_loss_ref = kld_ref.item() 180 | kld_loss_struct = kld_struct.item() 181 | 182 | zLatent = z_struct[0].data.cpu() 183 | 184 | errors = [ 185 | recon_loss, 186 | kld_loss_ref, 187 | kld_loss_struct, 188 | minimax_encD_loss, 189 | encD_loss, 190 | ] 191 | 192 | return errors, zLatent 193 | 194 | def save(self, save_dir): 195 | # for saving and loading see: 196 | # https://discuss.pytorch.org/t/how-to-save-load-torch-models/718 197 | 198 | gpu_id = self.gpu_ids[0] 199 | 200 | n_iters = self.get_current_iter() 201 | 202 | embeddings = np.concatenate(self.zAll, 0) 203 | pickle.dump(embeddings, open("{0}/embedding.pth".format(save_dir), "wb")) 204 | pickle.dump( 205 | embeddings, open("{0}/embedding_{1}.pth".format(save_dir, n_iters), "wb") 206 | ) 207 | 208 | enc_save_path_tmp = "{0}/enc.pth".format(save_dir) 209 | enc_save_path_final = "{0}/enc_{1}.pth".format(save_dir, n_iters) 210 | dec_save_path_tmp = "{0}/dec.pth".format(save_dir) 211 | dec_save_path_final = "{0}/dec_{1}.pth".format(save_dir, n_iters) 212 | 213 | encD_save_path_tmp = "{0}/encD.pth".format(save_dir) 214 | encD_save_path_final = "{0}/encD_{1}.pth".format(save_dir, n_iters) 215 | 216 | model_utils.save_state(self.enc, self.opt_enc, enc_save_path_tmp, gpu_id) 217 | shutil.copyfile(enc_save_path_tmp, enc_save_path_final) 218 | 219 | model_utils.save_state(self.dec, self.opt_dec, dec_save_path_tmp, gpu_id) 220 | shutil.copyfile(dec_save_path_tmp, dec_save_path_final) 221 | 222 | model_utils.save_state(self.encD, self.opt_encD, encD_save_path_tmp, gpu_id) 223 | shutil.copyfile(encD_save_path_tmp, 
encD_save_path_final) 224 | 225 | pickle.dump(self.logger, open("{0}/logger.pkl".format(save_dir), "wb")) 226 | pickle.dump( 227 | self.logger, open("{0}/logger_{1}.pkl".format(save_dir, n_iters), "wb") 228 | ) 229 | -------------------------------------------------------------------------------- /integrated_cell/models/cbvae2_gan_model_parallel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import pickle 4 | import shutil 5 | 6 | from .. import utils 7 | from . import cbvae2_gan 8 | from .. import model_utils 9 | 10 | 11 | class Model(cbvae2_gan.Model): 12 | def __init__(self, **kwargs): 13 | 14 | super(Model, self).__init__(**kwargs) 15 | 16 | def iteration(self): 17 | 18 | torch.cuda.empty_cache() 19 | 20 | enc, dec, decD = self.enc, self.dec, self.decD 21 | opt_enc, opt_dec, opt_decD = self.opt_enc, self.opt_dec, self.opt_decD 22 | crit_recon, crit_decD = (self.crit_recon, self.crit_decD) 23 | 24 | # do this just in case anything upstream changes these values 25 | enc.train(True) 26 | dec.train(True) 27 | decD.train(True) 28 | 29 | gpu_enc = self.enc.gpu_ids[0] 30 | gpu_dec = self.dec.gpu_ids[0] 31 | gpu_decD = self.decD.gpu_ids[0] 32 | 33 | # update the discriminator 34 | # maximize log(DecD(x)) + log(1 - DecD(xHat)) 35 | 36 | for p in decD.parameters(): 37 | p.requires_grad = True 38 | 39 | for p in enc.parameters(): 40 | p.requires_grad = False 41 | 42 | for p in dec.parameters(): 43 | p.requires_grad = False 44 | 45 | decDLoss = 0 46 | for step_num in range(self.n_decD_steps): 47 | x, classes, ref = self.data_provider.get_sample() 48 | x = x.cuda(gpu_enc) 49 | 50 | classes = classes.type_as(x).long() 51 | classes_onehot = utils.index_to_onehot( 52 | classes, self.data_provider.get_n_classes() 53 | ) 54 | 55 | y_xReal = classes.clone().cuda(gpu_decD) 56 | y_xFake = ( 57 | torch.zeros(classes.shape) 58 | .fill_(self.data_provider.get_n_classes()) 59 | .type_as(x) 60 | .long() 61 | .cuda(gpu_decD) 62 | ) 63 | with torch.no_grad(): 64 | zAll = enc(x, classes_onehot) 65 | 66 | for i in range(len(zAll)): 67 | zAll[i] = self.reparameterize(zAll[i][0], zAll[i][1]).cuda(gpu_dec) 68 | zAll[i].detach_() 69 | 70 | with torch.no_grad(): 71 | xHat = dec([classes_onehot.cuda(gpu_dec)] + zAll) 72 | 73 | opt_enc.zero_grad() 74 | opt_dec.zero_grad() 75 | opt_decD.zero_grad() 76 | 77 | ############## 78 | # Train decD 79 | ############## 80 | 81 | # train with real 82 | yHat_xReal = decD(x.cuda(gpu_decD)) 83 | errDecD_real = crit_decD(yHat_xReal, y_xReal) 84 | 85 | # train with fake, reconstructed 86 | yHat_xFake = decD(xHat.cuda(gpu_decD)) 87 | errDecD_fake = crit_decD(yHat_xFake, y_xFake) 88 | 89 | # train with fake, sampled and decoded 90 | for z in zAll: 91 | z.normal_() 92 | 93 | with torch.no_grad(): 94 | xHat = dec([classes_onehot.cuda(gpu_dec)] + zAll) 95 | 96 | yHat_xFake2 = decD(xHat.cuda(gpu_decD)) 97 | errDecD_fake2 = crit_decD(yHat_xFake2, y_xFake) 98 | 99 | decDLoss_tmp = (errDecD_real + (errDecD_fake + errDecD_fake2) / 2) / 2 100 | decDLoss_tmp.backward() 101 | 102 | opt_decD.step() 103 | 104 | decDLoss += decDLoss_tmp.item() 105 | 106 | errDecD_real = None 107 | errDecD_fake = None 108 | errDecD_fake2 = None 109 | 110 | for p in enc.parameters(): 111 | p.requires_grad = True 112 | 113 | for p in dec.parameters(): 114 | p.requires_grad = True 115 | 116 | for p in decD.parameters(): 117 | p.requires_grad = False 118 | 119 | opt_enc.zero_grad() 120 | opt_dec.zero_grad() 121 | opt_decD.zero_grad() 122 | 123 
| ##################### 124 | # train autoencoder 125 | ##################### 126 | torch.cuda.empty_cache() 127 | 128 | # Forward passes 129 | z_ref, z_struct = enc(x, classes_onehot) 130 | 131 | kld_ref = self.kld_loss(z_ref[0], z_ref[1]) 132 | kld_struct = self.kld_loss(z_struct[0], z_struct[1]) 133 | 134 | kld_loss = kld_ref + kld_struct 135 | 136 | kld_loss_ref = kld_ref.item() 137 | kld_loss_struct = kld_struct.item() 138 | 139 | zLatent = z_struct[0].data.cpu() 140 | 141 | zAll = [z_ref, z_struct] 142 | for i in range(len(zAll)): 143 | zAll[i] = self.reparameterize(zAll[i][0], zAll[i][1]).cuda(gpu_dec) 144 | 145 | xHat = dec([classes_onehot.cuda(gpu_dec)] + zAll) 146 | 147 | # resample from the structure space and make sure that the reference channel stays the same 148 | # shuffle_inds = torch.randperm(x.shape[0]) 149 | # zAll[-1].normal_() 150 | # xHat2 = dec([classes_onehot[shuffle_inds]] + zAll) 151 | 152 | # Update the image reconstruction 153 | recon_loss = crit_recon(xHat.cuda(gpu_enc), x) 154 | 155 | beta_vae_loss = self.vae_loss(recon_loss, kld_loss) 156 | 157 | beta_vae_loss.backward(retain_graph=True) 158 | 159 | recon_loss = recon_loss.item() 160 | 161 | opt_enc.step() 162 | 163 | if self.lambda_decD_loss > 0: 164 | for p in enc.parameters(): 165 | p.requires_grad = False 166 | 167 | # update wrt decD(dec(enc(X))) 168 | yHat_xFake = decD(xHat.cuda(gpu_decD)) 169 | minimaxDecDLoss = crit_decD(yHat_xFake, y_xReal) 170 | 171 | shuffle_inds = torch.randperm(x.shape[0]) 172 | xHat = dec([classes_onehot[shuffle_inds].cuda(gpu_dec)] + zAll) 173 | 174 | yHat_xFake2 = decD(xHat.cuda(gpu_decD)) 175 | minimaxDecDLoss2 = crit_decD(yHat_xFake2, y_xReal[shuffle_inds]) 176 | 177 | minimaxDecLoss = (minimaxDecDLoss + minimaxDecDLoss2) / 2 178 | minimaxDecLoss.mul(self.lambda_decD_loss).backward() 179 | minimaxDecLoss = minimaxDecLoss.item() 180 | else: 181 | minimaxDecLoss = 0 182 | 183 | opt_dec.step() 184 | 185 | errors = [recon_loss, kld_loss_ref, kld_loss_struct, minimaxDecLoss, decDLoss] 186 | 187 | return errors, zLatent 188 | 189 | def save(self, save_dir): 190 | # for saving and loading see: 191 | # https://discuss.pytorch.org/t/how-to-save-load-torch-models/718 192 | 193 | gpu_id = self.gpu_ids[0] 194 | 195 | n_iters = self.get_current_iter() 196 | 197 | embeddings = np.concatenate(self.zAll, 0) 198 | pickle.dump(embeddings, open("{0}/embedding.pth".format(save_dir), "wb")) 199 | pickle.dump( 200 | embeddings, open("{0}/embedding_{1}.pth".format(save_dir, n_iters), "wb") 201 | ) 202 | 203 | enc_save_path_tmp = "{0}/enc.pth".format(save_dir) 204 | enc_save_path_final = "{0}/enc_{1}.pth".format(save_dir, n_iters) 205 | dec_save_path_tmp = "{0}/dec.pth".format(save_dir) 206 | dec_save_path_final = "{0}/dec_{1}.pth".format(save_dir, n_iters) 207 | 208 | decD_save_path_tmp = "{0}/decD.pth".format(save_dir) 209 | decD_save_path_final = "{0}/decD_{1}.pth".format(save_dir, n_iters) 210 | 211 | model_utils.save_state(self.enc, self.opt_enc, enc_save_path_tmp, gpu_id) 212 | shutil.copyfile(enc_save_path_tmp, enc_save_path_final) 213 | 214 | model_utils.save_state(self.dec, self.opt_dec, dec_save_path_tmp, gpu_id) 215 | shutil.copyfile(dec_save_path_tmp, dec_save_path_final) 216 | 217 | model_utils.save_state(self.decD, self.opt_decD, decD_save_path_tmp, gpu_id) 218 | shutil.copyfile(decD_save_path_tmp, decD_save_path_final) 219 | 220 | pickle.dump(self.logger, open("{0}/logger.pkl".format(save_dir), "wb")) 221 | pickle.dump( 222 | self.logger, open("{0}/logger_{1}.pkl".format(save_dir, 
n_iters), "wb") 223 | ) 224 | -------------------------------------------------------------------------------- /integrated_cell/models/cbvae_apex.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .. import utils 3 | from . import cbvae 4 | from .bvae import reparameterize, kl_divergence 5 | 6 | from apex import amp 7 | 8 | 9 | class Model(cbvae.Model): 10 | def __init__(self, opt_level="O1", **kwargs): 11 | 12 | super(Model, self).__init__(**kwargs) 13 | 14 | is_data_parallel = False 15 | # do some hack if set up for DataParallel 16 | if ( 17 | str(self.enc.__class__) 18 | == "" 19 | ): 20 | is_data_parallel = True 21 | device_ids_enc = self.enc.device_ids 22 | device_ids_dec = self.dec.device_ids 23 | 24 | self.enc = self.enc.module 25 | self.dec = self.dec.module 26 | 27 | [self.enc, self.dec], [self.opt_enc, self.opt_dec] = amp.initialize( 28 | [self.enc, self.dec], [self.opt_enc, self.opt_dec], opt_level=opt_level 29 | ) 30 | 31 | if is_data_parallel: 32 | self.enc = torch.nn.DataParallel(self.enc, device_ids_enc) 33 | self.dec = torch.nn.DataParallel(self.dec, device_ids_dec) 34 | 35 | def iteration(self): 36 | 37 | torch.cuda.empty_cache() 38 | 39 | gpu_id = self.gpu_ids[0] 40 | 41 | enc, dec = self.enc, self.dec 42 | opt_enc, opt_dec = self.opt_enc, self.opt_dec 43 | crit_recon = self.crit_recon 44 | 45 | # do this just incase anything upstream changes these values 46 | enc.train(True) 47 | dec.train(True) 48 | 49 | x, classes, ref = self.data_provider.get_sample() 50 | x = x.cuda(gpu_id) 51 | 52 | classes = classes.type_as(x).long() 53 | classes_onehot = utils.index_to_onehot( 54 | classes, self.data_provider.get_n_classes() 55 | ) 56 | 57 | for p in enc.parameters(): 58 | p.requires_grad = True 59 | 60 | for p in dec.parameters(): 61 | p.requires_grad = True 62 | 63 | opt_enc.zero_grad() 64 | opt_dec.zero_grad() 65 | 66 | ##################### 67 | # train autoencoder 68 | ##################### 69 | 70 | # Forward passes 71 | z_ref, z_struct = enc(x, classes_onehot) 72 | 73 | kld_ref, _, _ = kl_divergence(z_ref[0], z_ref[1]) 74 | kld_struct, _, _ = kl_divergence(z_struct[0], z_struct[1]) 75 | 76 | kld = kld_ref + kld_struct 77 | 78 | kld_loss_ref = kld_ref.item() 79 | kld_loss_struct = kld_struct.item() 80 | 81 | zLatent = z_struct[0].data.cpu() 82 | 83 | zAll = [z_ref, z_struct] 84 | for i in range(len(zAll)): 85 | zAll[i] = reparameterize(zAll[i][0], zAll[i][1]) 86 | 87 | xHat = dec([classes_onehot] + zAll) 88 | 89 | # Update the image reconstruction 90 | recon_loss = crit_recon(xHat, x) 91 | 92 | if self.objective == "H": 93 | beta_vae_loss = recon_loss + self.beta * kld 94 | elif self.objective == "H_eps": 95 | beta_vae_loss = ( 96 | recon_loss 97 | + torch.abs((self.beta * kld_ref) - x.shape[0] * 0.1) 98 | + torch.abs((self.beta * kld_struct) - x.shape[0] * 0.1) 99 | ) 100 | elif self.objective == "B": 101 | C = torch.clamp( 102 | torch.Tensor( 103 | [self.c_max / self.c_iters_max * len(self.logger)] 104 | ).type_as(x), 105 | 0, 106 | self.c_max, 107 | ) 108 | beta_vae_loss = recon_loss + self.gamma * (kld - C).abs() 109 | 110 | elif self.objective == "B_eps": 111 | C = torch.clamp( 112 | torch.Tensor( 113 | [self.c_max / self.c_iters_max * len(self.logger)] 114 | ).type_as(x), 115 | 0, 116 | self.c_max, 117 | ) 118 | beta_vae_loss = recon_loss + self.gamma * (kld - C).abs() 119 | 120 | elif self.objective == "A": 121 | # warmup mode 122 | beta_mult = self.beta_start + self.beta_step * self.get_current_iter() 123 | 
if beta_mult > self.beta_max: 124 | beta_mult = self.beta_max 125 | 126 | if beta_mult < self.beta_min: 127 | beta_mult = self.beta_min 128 | 129 | beta_vae_loss = recon_loss + beta_mult * kld 130 | 131 | with amp.scale_loss(beta_vae_loss, [opt_enc, opt_dec]) as scaled_loss: 132 | scaled_loss.backward() 133 | 134 | recon_loss = recon_loss.item() 135 | 136 | opt_enc.step() 137 | opt_dec.step() 138 | 139 | errors = [recon_loss, kld_loss_ref, kld_loss_struct] 140 | 141 | return errors, zLatent 142 | -------------------------------------------------------------------------------- /integrated_cell/models/cbvaegan_target.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | import pickle 5 | import shutil 6 | 7 | from .. import utils 8 | from . import cbvae2_target 9 | from .. import SimpleLogger 10 | from .. import model_utils 11 | 12 | 13 | class Model(cbvae2_target.Model): 14 | def __init__( 15 | self, decD, opt_decD, crit_decD, lambda_decD_loss=1, n_decD_steps=1, **kwargs 16 | ): 17 | 18 | super(Model, self).__init__(**kwargs) 19 | 20 | self.decD = decD 21 | self.opt_decD = opt_decD 22 | 23 | self.crit_decD = crit_decD 24 | 25 | self.lambda_decD_loss = lambda_decD_loss 26 | self.n_decD_steps = n_decD_steps 27 | 28 | logger_path = "{}/logger.pkl".format(self.save_dir) 29 | if os.path.exists(logger_path): 30 | self.logger = pickle.load(open(logger_path, "rb")) 31 | else: 32 | columns = ("epoch", "iter", "reconLoss") 33 | print_str = "[%d][%d] reconLoss: %.6f" 34 | 35 | columns += ("kldLoss", "minimaxDecDLoss", "decDLoss", "time") 36 | print_str += " kld: %.6f mmDecD: %.6f decD: %.6f time: %.2f" 37 | self.logger = SimpleLogger(columns, print_str) 38 | 39 | def iteration(self): 40 | 41 | torch.cuda.empty_cache() 42 | 43 | gpu_id = self.gpu_ids[0] 44 | 45 | enc, dec, decD = self.enc, self.dec, self.decD 46 | opt_enc, opt_dec, opt_decD = self.opt_enc, self.opt_dec, self.opt_decD 47 | crit_recon, crit_decD = (self.crit_recon, self.crit_decD) 48 | 49 | # do this just in case anything upstream changes these values 50 | enc.train(True) 51 | dec.train(True) 52 | decD.train(True) 53 | 54 | # update the discriminator 55 | # maximize log(DecD(x)) + log(1 - DecD(xHat)) 56 | 57 | for p in decD.parameters(): 58 | p.requires_grad = True 59 | 60 | for p in enc.parameters(): 61 | p.requires_grad = False 62 | 63 | for p in dec.parameters(): 64 | p.requires_grad = False 65 | 66 | decDLoss = 0 67 | for step_num in range(self.n_decD_steps): 68 | x, classes, ref = self.data_provider.get_sample() 69 | x = x.cuda(gpu_id) 70 | 71 | classes = classes.type_as(x).long() 72 | classes_onehot = utils.index_to_onehot( 73 | classes, self.data_provider.get_n_classes() 74 | ) 75 | 76 | ref = ref.cuda(gpu_id) 77 | 78 | y_xReal = torch.ones(classes.shape[0], 1).type_as(x) 79 | y_xFake = torch.zeros(classes.shape[0], 1).type_as(x) 80 | 81 | with torch.no_grad(): 82 | mu, logsigma = enc(x, ref, classes_onehot) 83 | z = self.reparameterize(mu, logsigma) 84 | xHat = dec(z, ref, classes_onehot) 85 | 86 | opt_enc.zero_grad() 87 | opt_dec.zero_grad() 88 | opt_decD.zero_grad() 89 | 90 | ############## 91 | # Train decD 92 | ############## 93 | 94 | # train with real 95 | yHat_xReal = decD(x, ref, classes_onehot) 96 | errDecD_real = crit_decD(yHat_xReal, y_xReal) 97 | 98 | # train with fake, reconstructed 99 | yHat_xFake = decD(xHat, ref, classes_onehot) 100 | errDecD_fake = crit_decD(yHat_xFake, y_xFake) 101 | 102 | # train with fake, sampled and 
decoded 103 | z.normal_() 104 | shuffle_inds = torch.randperm(x.shape[0]) 105 | 106 | with torch.no_grad(): 107 | xHat = dec(z, ref, classes_onehot[shuffle_inds]) 108 | 109 | yHat_xFake2 = decD(xHat, ref, classes_onehot[shuffle_inds]) 110 | errDecD_fake2 = crit_decD(yHat_xFake2, y_xFake) 111 | 112 | decDLoss_tmp = (errDecD_real + (errDecD_fake + errDecD_fake2) / 2) / 2 113 | decDLoss_tmp.backward() 114 | 115 | opt_decD.step() 116 | 117 | decDLoss += decDLoss_tmp.item() 118 | 119 | errDecD_real = None 120 | errDecD_fake = None 121 | errDecD_fake2 = None 122 | 123 | for p in enc.parameters(): 124 | p.requires_grad = True 125 | 126 | for p in dec.parameters(): 127 | p.requires_grad = True 128 | 129 | for p in decD.parameters(): 130 | p.requires_grad = False 131 | 132 | opt_enc.zero_grad() 133 | opt_dec.zero_grad() 134 | opt_decD.zero_grad() 135 | 136 | ##################### 137 | # train autoencoder 138 | ##################### 139 | torch.cuda.empty_cache() 140 | 141 | # Forward passes 142 | mu, logsigma = enc(x, ref, classes_onehot) 143 | 144 | z = self.reparameterize(mu, logsigma) 145 | 146 | kld_loss = self.kld_loss(mu, logsigma) 147 | 148 | xHat = dec(z, ref, classes_onehot) 149 | 150 | recon_loss = crit_recon(xHat, x) 151 | beta_vae_loss = self.vae_loss(recon_loss, kld_loss) 152 | beta_vae_loss.backward(retain_graph=True) 153 | 154 | recon_loss = recon_loss.item() 155 | kld_loss = kld_loss.item() 156 | zLatent = mu.data.cpu() 157 | 158 | opt_enc.step() 159 | 160 | if self.lambda_decD_loss > 0: 161 | for p in enc.parameters(): 162 | p.requires_grad = False 163 | 164 | # update wrt reconstructed images 165 | yHat_xFake = decD(xHat, ref, classes_onehot) 166 | minimaxDecDLoss = crit_decD(yHat_xFake, y_xReal) 167 | 168 | shuffle_inds = torch.randperm(x.shape[0]) 169 | 170 | # update wrt generated images 171 | z = torch.Tensor(z.shape).normal_().type_as(x) 172 | xHat = dec(z, ref, classes_onehot[shuffle_inds]) 173 | 174 | yHat_xFake2 = decD(xHat, ref, classes_onehot[shuffle_inds]) 175 | minimaxDecDLoss2 = crit_decD(yHat_xFake2, y_xReal) 176 | 177 | minimaxDecDLoss = (minimaxDecDLoss + minimaxDecDLoss2) / 2 178 | minimaxDecDLoss.mul(self.lambda_decD_loss).backward() 179 | minimaxDecDLoss = minimaxDecDLoss.item() 180 | else: 181 | minimaxDecDLoss = 0 182 | 183 | opt_dec.step() 184 | 185 | errors = [recon_loss, kld_loss, minimaxDecDLoss, decDLoss] 186 | 187 | return errors, zLatent 188 | 189 | def save(self, save_dir): 190 | # for saving and loading see: 191 | # https://discuss.pytorch.org/t/how-to-save-load-torch-models/718 192 | 193 | gpu_id = self.gpu_ids[0] 194 | 195 | n_iters = self.get_current_iter() 196 | 197 | embeddings = np.concatenate(self.zAll, 0) 198 | pickle.dump(embeddings, open("{0}/embedding.pth".format(save_dir), "wb")) 199 | pickle.dump( 200 | embeddings, open("{0}/embedding_{1}.pth".format(save_dir, n_iters), "wb") 201 | ) 202 | 203 | enc_save_path_tmp = "{0}/enc.pth".format(save_dir) 204 | enc_save_path_final = "{0}/enc_{1}.pth".format(save_dir, n_iters) 205 | dec_save_path_tmp = "{0}/dec.pth".format(save_dir) 206 | dec_save_path_final = "{0}/dec_{1}.pth".format(save_dir, n_iters) 207 | 208 | decD_save_path_tmp = "{0}/decD.pth".format(save_dir) 209 | decD_save_path_final = "{0}/decD_{1}.pth".format(save_dir, n_iters) 210 | 211 | model_utils.save_state(self.enc, self.opt_enc, enc_save_path_tmp, gpu_id) 212 | shutil.copyfile(enc_save_path_tmp, enc_save_path_final) 213 | 214 | model_utils.save_state(self.dec, self.opt_dec, dec_save_path_tmp, gpu_id) 215 | 
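# Each network is checkpointed in two steps: model_utils.save_state writes to
# a stable name (e.g. dec.pth) that always holds the newest weights, and
# shutil.copyfile then duplicates it to an iteration-stamped name
# (dec_<n_iters>.pth) so earlier checkpoints stay available for rollback.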
shutil.copyfile(dec_save_path_tmp, dec_save_path_final) 216 | 217 | model_utils.save_state(self.decD, self.opt_decD, decD_save_path_tmp, gpu_id) 218 | shutil.copyfile(decD_save_path_tmp, decD_save_path_final) 219 | 220 | pickle.dump(self.logger, open("{0}/logger.pkl".format(save_dir), "wb")) 221 | pickle.dump( 222 | self.logger, open("{0}/logger_{1}.pkl".format(save_dir, n_iters), "wb") 223 | ) 224 | -------------------------------------------------------------------------------- /integrated_cell/networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AllenCellModeling/pytorch_integrated_cell/8a83fc6f8dc79037f4b681d9d7ef0bc5b91e9948/integrated_cell/networks/__init__.py -------------------------------------------------------------------------------- /integrated_cell/networks/ae2D_residual.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import numpy as np 3 | 4 | from ..utils import spectral_norm 5 | from ..utils import get_activation 6 | 7 | # An autoencoder for 2D images using modified residual layers 8 | # This is effectively a 2D simplified version of vaegan3D_cgan_target2.py 9 | 10 | 11 | class PadLayer(nn.Module): 12 | def __init__(self, pad_dims): 13 | super(PadLayer, self).__init__() 14 | 15 | self.pad_dims = pad_dims 16 | 17 | def forward(self, x): 18 | if np.sum(self.pad_dims) == 0: 19 | return x 20 | else: 21 | return nn.functional.pad( 22 | x, 23 | [0, self.pad_dims[2], 0, self.pad_dims[1], 0, self.pad_dims[0]], 24 | "constant", 25 | 0, 26 | ) 27 | 28 | 29 | class DownLayerResidual(nn.Module): 30 | def __init__(self, ch_in, ch_out, activation="relu", activation_last=None): 31 | super(DownLayerResidual, self).__init__() 32 | 33 | if activation_last is None: 34 | activation_last = activation 35 | 36 | self.bypass = nn.Sequential( 37 | nn.AvgPool2d(2, stride=2, padding=0), 38 | spectral_norm(nn.Conv2d(ch_in, ch_out, 1, 1, padding=0, bias=True)), 39 | ) 40 | 41 | self.resid = nn.Sequential( 42 | spectral_norm(nn.Conv2d(ch_in, ch_in, 4, 2, padding=1, bias=True)), 43 | nn.BatchNorm2d(ch_in), 44 | get_activation(activation), 45 | spectral_norm(nn.Conv2d(ch_in, ch_out, 3, 1, padding=1, bias=True)), 46 | nn.BatchNorm2d(ch_out), 47 | ) 48 | 49 | self.activation = get_activation(activation_last) 50 | 51 | def forward(self, x, x_proj=None, x_class=None): 52 | 53 | x = self.bypass(x) + self.resid(x) 54 | 55 | x = self.activation(x) 56 | 57 | return x 58 | 59 | 60 | class UpLayerResidual(nn.Module): 61 | def __init__( 62 | self, ch_in, ch_out, activation="relu", output_padding=0, activation_last=None 63 | ): 64 | super(UpLayerResidual, self).__init__() 65 | 66 | if activation_last is None: 67 | activation_last = activation 68 | 69 | self.bypass = nn.Sequential( 70 | spectral_norm(nn.Conv2d(ch_in, ch_out, 1, 1, padding=0, bias=True)), 71 | nn.Upsample(scale_factor=2), 72 | PadLayer(output_padding), 73 | ) 74 | 75 | self.resid = nn.Sequential( 76 | spectral_norm( 77 | nn.ConvTranspose2d( 78 | ch_in, 79 | ch_in, 80 | 4, 81 | 2, 82 | padding=1, 83 | output_padding=output_padding, 84 | bias=True, 85 | ) 86 | ), 87 | nn.BatchNorm2d(ch_in), 88 | get_activation(activation), 89 | spectral_norm(nn.Conv2d(ch_in, ch_out, 3, 1, padding=1, bias=True)), 90 | nn.BatchNorm2d(ch_out), 91 | ) 92 | 93 | self.activation = get_activation(activation_last) 94 | 95 | def forward(self, x): 96 | x = self.bypass(x) + self.resid(x) 97 | 98 | x = self.activation(x) 99 | 100 | 
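# Residual up-block: a cheap bypass (1x1 conv + 2x upsample) is summed with
# the learned path (stride-2 transposed conv + 3x3 conv), and one shared
# activation is applied to the sum; the `return x` below hands back a
# feature map at twice the input resolution.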
return x 101 | 102 | 103 | class Enc(nn.Module): 104 | def __init__( 105 | self, 106 | n_latent_dim, 107 | gpu_ids, 108 | n_ch=3, 109 | conv_channels_list=[32, 64, 128, 256, 512], 110 | imsize_compressed=[5, 3], 111 | ): 112 | super(Enc, self).__init__() 113 | 114 | self.gpu_ids = gpu_ids 115 | self.n_latent_dim = n_latent_dim 116 | 117 | self.target_path = nn.ModuleList( 118 | [DownLayerResidual(n_ch, conv_channels_list[0])] 119 | ) 120 | 121 | for ch_in, ch_out in zip(conv_channels_list[0:-1], conv_channels_list[1:]): 122 | self.target_path.append(DownLayerResidual(ch_in, ch_out)) 123 | 124 | ch_in = ch_out 125 | 126 | self.latent_out = spectral_norm( 127 | nn.Linear( 128 | ch_in * int(np.prod(imsize_compressed)), self.n_latent_dim, bias=True 129 | ) 130 | ) 131 | 132 | def forward(self, x_target): 133 | 134 | for target_path in self.target_path: 135 | x_target = target_path(x_target) 136 | 137 | x_target = x_target.view(x_target.size()[0], -1) 138 | 139 | z = self.latent_out(x_target) 140 | 141 | return z 142 | 143 | 144 | class Dec(nn.Module): 145 | def __init__( 146 | self, 147 | n_latent_dim, 148 | gpu_ids, 149 | padding_latent=[0, 0], 150 | imsize_compressed=[5, 3], 151 | n_ch=3, 152 | conv_channels_list=[512, 256, 128, 64, 32], 153 | activation_last="sigmoid", 154 | ): 155 | 156 | super(Dec, self).__init__() 157 | 158 | self.gpu_ids = gpu_ids 159 | self.padding_latent = padding_latent 160 | self.imsize_compressed = imsize_compressed 161 | 162 | self.ch_first = conv_channels_list[0] 163 | 164 | self.n_latent_dim = n_latent_dim 165 | 166 | self.n_channels = n_ch 167 | 168 | self.target_fc = spectral_norm( 169 | nn.Linear( 170 | self.n_latent_dim, 171 | conv_channels_list[0] * int(np.prod(self.imsize_compressed)), 172 | bias=True, 173 | ) 174 | ) 175 | 176 | self.target_bn_relu = nn.Sequential( 177 | nn.BatchNorm2d(conv_channels_list[0]), nn.ReLU(inplace=True) 178 | ) 179 | 180 | self.target_path = nn.ModuleList([]) 181 | 182 | l_sizes = conv_channels_list 183 | for i in range(len(l_sizes) - 1): 184 | if i == 0: 185 | padding = padding_latent 186 | else: 187 | padding = 0 188 | 189 | self.target_path.append( 190 | UpLayerResidual(l_sizes[i], l_sizes[i + 1], output_padding=padding) 191 | ) 192 | 193 | self.target_path.append( 194 | UpLayerResidual(l_sizes[i + 1], n_ch, activation_last=activation_last) 195 | ) 196 | 197 | def forward(self, z_target): 198 | # gpu_ids = None 199 | # if isinstance(x.data, torch.cuda.FloatTensor) and len(self.gpu_ids) > 1: 200 | 201 | x_target = self.target_fc(z_target).view( 202 | z_target.size()[0], 203 | self.ch_first, 204 | self.imsize_compressed[0], 205 | self.imsize_compressed[1], 206 | ) 207 | x_target = self.target_bn_relu(x_target) 208 | 209 | for target_path in self.target_path: 210 | 211 | x_target = target_path(x_target) 212 | 213 | return x_target 214 | -------------------------------------------------------------------------------- /integrated_cell/networks/ae3D_residual.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import numpy as np 3 | 4 | from ..utils import spectral_norm 5 | from ..utils import get_activation 6 | 7 | # An autoencoder for 3D images using modified residual layers 8 | # This is effectively a simplified version of vaegan3D_cgan_target2.py 9 | 10 | 11 | class PadLayer(nn.Module): 12 | def __init__(self, pad_dims): 13 | super(PadLayer, self).__init__() 14 | 15 | self.pad_dims = pad_dims 16 | 17 | def forward(self, x): 18 | if np.sum(self.pad_dims) == 0: 19 | 
return x 20 | else: 21 | return nn.functional.pad( 22 | x, 23 | [0, self.pad_dims[2], 0, self.pad_dims[1], 0, self.pad_dims[0]], 24 | "constant", 25 | 0, 26 | ) 27 | 28 | 29 | class DownLayerResidual(nn.Module): 30 | def __init__(self, ch_in, ch_out, activation="relu", activation_last=None): 31 | super(DownLayerResidual, self).__init__() 32 | 33 | if activation_last is None: 34 | activation_last = activation 35 | 36 | self.bypass = nn.Sequential( 37 | nn.AvgPool3d(2, stride=2, padding=0), 38 | spectral_norm(nn.Conv3d(ch_in, ch_out, 1, 1, padding=0, bias=True)), 39 | ) 40 | 41 | self.resid = nn.Sequential( 42 | spectral_norm(nn.Conv3d(ch_in, ch_in, 4, 2, padding=1, bias=True)), 43 | nn.BatchNorm3d(ch_in), 44 | get_activation(activation), 45 | spectral_norm(nn.Conv3d(ch_in, ch_out, 3, 1, padding=1, bias=True)), 46 | nn.BatchNorm3d(ch_out), 47 | ) 48 | 49 | self.activation = get_activation(activation_last) 50 | 51 | def forward(self, x, x_proj=None, x_class=None): 52 | 53 | x = self.bypass(x) + self.resid(x) 54 | 55 | x = self.activation(x) 56 | 57 | return x 58 | 59 | 60 | class UpLayerResidual(nn.Module): 61 | def __init__( 62 | self, ch_in, ch_out, activation="relu", output_padding=0, activation_last=None 63 | ): 64 | super(UpLayerResidual, self).__init__() 65 | 66 | if activation_last is None: 67 | activation_last = activation 68 | 69 | self.bypass = nn.Sequential( 70 | spectral_norm(nn.Conv3d(ch_in, ch_out, 1, 1, padding=0, bias=True)), 71 | nn.Upsample(scale_factor=2), 72 | PadLayer(output_padding), 73 | ) 74 | 75 | self.resid = nn.Sequential( 76 | spectral_norm( 77 | nn.ConvTranspose3d( 78 | ch_in, 79 | ch_in, 80 | 4, 81 | 2, 82 | padding=1, 83 | output_padding=output_padding, 84 | bias=True, 85 | ) 86 | ), 87 | nn.BatchNorm3d(ch_in), 88 | get_activation(activation), 89 | spectral_norm(nn.Conv3d(ch_in, ch_out, 3, 1, padding=1, bias=True)), 90 | nn.BatchNorm3d(ch_out), 91 | ) 92 | 93 | self.activation = get_activation(activation_last) 94 | 95 | def forward(self, x): 96 | x = self.bypass(x) + self.resid(x) 97 | 98 | x = self.activation(x) 99 | 100 | return x 101 | 102 | 103 | class Enc(nn.Module): 104 | def __init__( 105 | self, 106 | n_latent_dim, 107 | gpu_ids, 108 | n_ch=3, 109 | conv_channels_list=[32, 64, 128, 256, 512], 110 | imsize_compressed=[5, 3, 2], 111 | ): 112 | super(Enc, self).__init__() 113 | 114 | self.gpu_ids = gpu_ids 115 | self.n_latent_dim = n_latent_dim 116 | 117 | self.target_path = nn.ModuleList( 118 | [DownLayerResidual(n_ch, conv_channels_list[0])] 119 | ) 120 | 121 | for ch_in, ch_out in zip(conv_channels_list[0:-1], conv_channels_list[1:]): 122 | self.target_path.append(DownLayerResidual(ch_in, ch_out)) 123 | 124 | ch_in = ch_out 125 | 126 | self.latent_out = spectral_norm( 127 | nn.Linear( 128 | ch_in * int(np.prod(imsize_compressed)), self.n_latent_dim, bias=True 129 | ) 130 | ) 131 | 132 | def forward(self, x_target): 133 | 134 | for target_path in self.target_path: 135 | x_target = target_path(x_target) 136 | 137 | x_target = x_target.view(x_target.size()[0], -1) 138 | 139 | z = self.latent_out(x_target) 140 | 141 | return z 142 | 143 | 144 | class Dec(nn.Module): 145 | def __init__( 146 | self, 147 | n_latent_dim, 148 | gpu_ids, 149 | padding_latent=[0, 0, 0], 150 | imsize_compressed=[5, 3, 2], 151 | n_ch=3, 152 | conv_channels_list=[512, 256, 128, 64, 32], 153 | activation_last="sigmoid", 154 | ): 155 | 156 | super(Dec, self).__init__() 157 | 158 | self.gpu_ids = gpu_ids 159 | self.padding_latent = padding_latent 160 | self.imsize_compressed = 
imsize_compressed 161 | 162 | self.ch_first = conv_channels_list[0] 163 | 164 | self.n_latent_dim = n_latent_dim 165 | 166 | self.n_channels = n_ch 167 | 168 | self.target_fc = spectral_norm( 169 | nn.Linear( 170 | self.n_latent_dim, 171 | conv_channels_list[0] * int(np.prod(self.imsize_compressed)), 172 | bias=True, 173 | ) 174 | ) 175 | 176 | self.target_bn_relu = nn.Sequential( 177 | nn.BatchNorm3d(conv_channels_list[0]), nn.ReLU(inplace=True) 178 | ) 179 | 180 | self.target_path = nn.ModuleList([]) 181 | 182 | l_sizes = conv_channels_list 183 | for i in range(len(l_sizes) - 1): 184 | if i == 0: 185 | padding = padding_latent 186 | else: 187 | padding = 0 188 | 189 | self.target_path.append( 190 | UpLayerResidual(l_sizes[i], l_sizes[i + 1], output_padding=padding) 191 | ) 192 | 193 | self.target_path.append( 194 | UpLayerResidual(l_sizes[i + 1], n_ch, activation_last=activation_last) 195 | ) 196 | 197 | def forward(self, z_target): 198 | # gpu_ids = None 199 | # if isinstance(x.data, torch.cuda.FloatTensor) and len(self.gpu_ids) > 1: 200 | 201 | x_target = self.target_fc(z_target).view( 202 | z_target.size()[0], 203 | self.ch_first, 204 | self.imsize_compressed[0], 205 | self.imsize_compressed[1], 206 | self.imsize_compressed[2], 207 | ) 208 | x_target = self.target_bn_relu(x_target) 209 | 210 | for target_path in self.target_path: 211 | 212 | x_target = target_path(x_target) 213 | 214 | return x_target 215 | -------------------------------------------------------------------------------- /integrated_cell/networks/old/aaegan_compact.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | import pdb 4 | from model_utils import init_opts 5 | 6 | ksize = 4 7 | dstep = 2 8 | 9 | class Enc(nn.Module): 10 | def __init__(self, nLatentDim, nClasses, nRef, insize, nch, gpu_ids, opt=None): 11 | super(Enc, self).__init__() 12 | 13 | self.gpu_ids = gpu_ids 14 | self.fcsize = insize/64 15 | 16 | self.nLatentDim = nLatentDim 17 | self.nClasses = nClasses 18 | self.nRef = nRef 19 | 20 | self.main = nn.Sequential( 21 | nn.Conv2d(nch, 64, ksize, dstep, 1), 22 | nn.BatchNorm2d(64), 23 | 24 | nn.PReLU(), 25 | nn.Conv2d(64, 128, ksize, dstep, 1), 26 | nn.BatchNorm2d(128), 27 | 28 | nn.PReLU(), 29 | nn.Conv2d(128, 256, ksize, dstep, 1), 30 | nn.BatchNorm2d(256), 31 | 32 | nn.PReLU(), 33 | nn.Conv2d(256, 512, ksize, dstep, 1), 34 | nn.BatchNorm2d(512), 35 | 36 | nn.PReLU(), 37 | nn.Conv2d(512, 1024, ksize, dstep, 1), 38 | nn.BatchNorm2d(1024), 39 | 40 | nn.PReLU(), 41 | nn.Conv2d(1024, 1024, ksize, dstep, 1), 42 | nn.BatchNorm2d(1024), 43 | 44 | nn.PReLU() 45 | ) 46 | 47 | # discriminator output 48 | self.discrim = nn.Sequential( 49 | nn.Linear(1024*int(self.fcsize**2), self.nClasses+1), 50 | # nn.BatchNorm1d(self.nClasses+1), 51 | ) 52 | 53 | if self.nClasses > 0: 54 | self.discrim.add_module('end', nn.LogSoftmax()) 55 | else: 56 | self.discrim.add_module('end', nn.Sigmoid()) 57 | 58 | if self.nClasses > 0: 59 | self.classOut = nn.Sequential( 60 | nn.Linear(1024*int(self.fcsize**2), self.nClasses), 61 | nn.BatchNorm1d(self.nClasses), 62 | nn.LogSoftmax() 63 | ) 64 | 65 | if self.nRef > 0: 66 | self.refOut = nn.Sequential( 67 | nn.Linear(1024*int(self.fcsize**2), self.nRef), 68 | nn.BatchNorm1d(self.nRef) 69 | ) 70 | 71 | if self.nLatentDim > 0: 72 | self.latentOut = nn.Sequential( 73 | nn.Linear(1024*int(self.fcsize**2), self.nLatentDim), 74 | nn.BatchNorm1d(self.nLatentDim) 75 | ) 76 | 77 | def forward(self, x): 78 | # gpu_ids = 
None 79 | # if isinstance(x.data, torch.cuda.FloatTensor) and len(self.gpu_ids) > 1: 80 | gpu_ids = self.gpu_ids 81 | 82 | x = nn.parallel.data_parallel(self.main, x, gpu_ids) 83 | x = x.view(x.size()[0], 1024*int(self.fcsize**2)) 84 | 85 | xOut = list() 86 | 87 | xDiscrim = nn.parallel.data_parallel(self.discrim, x, gpu_ids) 88 | # xOut.append(xDiscrim) 89 | 90 | if self.nClasses > 0: 91 | xClasses = nn.parallel.data_parallel(self.classOut, x, gpu_ids) 92 | xOut.append(xClasses) 93 | 94 | if self.nRef > 0: 95 | xRef = nn.parallel.data_parallel(self.refOut, x, gpu_ids) 96 | xOut.append(xRef) 97 | 98 | if self.nLatentDim > 0: 99 | xLatent = nn.parallel.data_parallel(self.latentOut, x, gpu_ids) 100 | xOut.append(xLatent) 101 | 102 | return xOut, xDiscrim 103 | 104 | class Dec(nn.Module): 105 | def __init__(self, nLatentDim, nClasses, nRef, insize, nch, gpu_ids, opt=None): 106 | super(Dec, self).__init__() 107 | 108 | self.gpu_ids = gpu_ids 109 | self.fcsize = int(insize/64) 110 | 111 | self.nLatentDim = nLatentDim 112 | self.nClasses = nClasses 113 | self.nRef = nRef 114 | 115 | self.fc = nn.Linear(self.nLatentDim + self.nClasses + self.nRef, 1024*int(self.fcsize**2)) 116 | 117 | self.main = nn.Sequential( 118 | # nn.BatchNorm2d(1024), 119 | 120 | nn.PReLU(), 121 | nn.ConvTranspose2d(1024, 1024, ksize, dstep, 1), 122 | nn.BatchNorm2d(1024), 123 | 124 | nn.PReLU(), 125 | nn.ConvTranspose2d(1024, 512, ksize, dstep, 1), 126 | nn.BatchNorm2d(512), 127 | 128 | nn.PReLU(), 129 | nn.ConvTranspose2d(512, 256, ksize, dstep, 1), 130 | nn.BatchNorm2d(256), 131 | 132 | nn.PReLU(), 133 | nn.ConvTranspose2d(256, 128, ksize, dstep, 1), 134 | nn.BatchNorm2d(128), 135 | 136 | nn.PReLU(), 137 | nn.ConvTranspose2d(128, 64, ksize, dstep, 1), 138 | nn.BatchNorm2d(64), 139 | 140 | nn.PReLU(), 141 | nn.ConvTranspose2d(64, nch, ksize, dstep, 1), 142 | # nn.BatchNorm2d(nch), 143 | nn.Sigmoid() 144 | ) 145 | 146 | def forward(self, xIn): 147 | # gpu_ids = None 148 | # if isinstance(x.data, torch.cuda.FloatTensor) and len(self.gpu_ids) > 1: 149 | gpu_ids = self.gpu_ids 150 | 151 | x = torch.cat(xIn, 1) 152 | 153 | x = self.fc(x) 154 | x = x.view(x.size()[0], 1024, self.fcsize, self.fcsize) 155 | x = nn.parallel.data_parallel(self.main, x, gpu_ids) 156 | 157 | return x 158 | 159 | class EncD(nn.Module): 160 | def __init__(self, nlatentdim, gpu_ids, opt=None): 161 | super(EncD, self).__init__() 162 | 163 | nfc = 1024 164 | 165 | self.gpu_ids = gpu_ids 166 | 167 | self.main = nn.Sequential( 168 | nn.Linear(nlatentdim, nfc), 169 | nn.BatchNorm1d(nfc), 170 | nn.LeakyReLU(0.2, inplace=True), 171 | 172 | nn.Linear(nfc, nfc), 173 | nn.BatchNorm1d(nfc), 174 | nn.LeakyReLU(0.2, inplace=True), 175 | 176 | nn.Linear(nfc, 512), 177 | nn.BatchNorm1d(512), 178 | nn.LeakyReLU(0.2, inplace=True), 179 | 180 | nn.Linear(512, 1), 181 | nn.Sigmoid() 182 | ) 183 | 184 | def forward(self, x): 185 | # gpu_ids = None 186 | # if isinstance(x.data, torch.cuda.FloatTensor) and len(self.gpu_ids) > 1: 187 | gpu_ids = self.gpu_ids 188 | 189 | x = nn.parallel.data_parallel(self.main, x, gpu_ids) 190 | 191 | return x 192 | 193 | class DecD(nn.Module): 194 | def __init__(self, nout, insize, nch, gpu_ids, opt=None): 195 | super(DecD, self).__init__() 196 | 197 | self.linear = nn.Linear(1,1) 198 | 199 | def forward(self, x): 200 | return x 201 | 202 | 203 | -------------------------------------------------------------------------------- /integrated_cell/networks/old/aaegan_compact_lrelu.py: 
-------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | import pdb 4 | from model_utils import init_opts 5 | 6 | ksize = 4 7 | dstep = 2 8 | 9 | class Enc(nn.Module): 10 | def __init__(self, nLatentDim, nClasses, nRef, insize, nch, gpu_ids, opt=None): 11 | super(Enc, self).__init__() 12 | 13 | self.gpu_ids = gpu_ids 14 | self.fcsize = insize/64 15 | 16 | self.nLatentDim = nLatentDim 17 | self.nClasses = nClasses 18 | self.nRef = nRef 19 | 20 | self.main = nn.Sequential( 21 | nn.Conv2d(nch, 64, ksize, dstep, 1), 22 | nn.BatchNorm2d(64), 23 | 24 | nn.LeakyReLU(0.2, inplace=True), 25 | nn.Conv2d(64, 128, ksize, dstep, 1), 26 | nn.BatchNorm2d(128), 27 | 28 | nn.LeakyReLU(0.2, inplace=True), 29 | nn.Conv2d(128, 256, ksize, dstep, 1), 30 | nn.BatchNorm2d(256), 31 | 32 | nn.LeakyReLU(0.2, inplace=True), 33 | nn.Conv2d(256, 512, ksize, dstep, 1), 34 | nn.BatchNorm2d(512), 35 | 36 | nn.LeakyReLU(0.2, inplace=True), 37 | nn.Conv2d(512, 1024, ksize, dstep, 1), 38 | nn.BatchNorm2d(1024), 39 | 40 | nn.LeakyReLU(0.2, inplace=True), 41 | nn.Conv2d(1024, 1024, ksize, dstep, 1), 42 | nn.BatchNorm2d(1024), 43 | 44 | nn.LeakyReLU(0.2, inplace=True) 45 | ) 46 | 47 | # discriminator output 48 | self.discrim = nn.Sequential( 49 | nn.Linear(1024*int(self.fcsize**2), self.nClasses+1), 50 | # nn.BatchNorm1d(self.nClasses+1), 51 | ) 52 | 53 | if self.nClasses > 0: 54 | self.discrim.add_module('end', nn.LogSoftmax()) 55 | else: 56 | self.discrim.add_module('end', nn.Sigmoid()) 57 | 58 | if self.nClasses > 0: 59 | self.classOut = nn.Sequential( 60 | nn.Linear(1024*int(self.fcsize**2), self.nClasses), 61 | nn.BatchNorm1d(self.nClasses), 62 | nn.LogSoftmax() 63 | ) 64 | 65 | if self.nRef > 0: 66 | self.refOut = nn.Sequential( 67 | nn.Linear(1024*int(self.fcsize**2), self.nRef), 68 | nn.BatchNorm1d(self.nRef) 69 | ) 70 | 71 | if self.nLatentDim > 0: 72 | self.latentOut = nn.Sequential( 73 | nn.Linear(1024*int(self.fcsize**2), self.nLatentDim), 74 | nn.BatchNorm1d(self.nLatentDim) 75 | ) 76 | 77 | def forward(self, x): 78 | # gpu_ids = None 79 | # if isinstance(x.data, torch.cuda.FloatTensor) and len(self.gpu_ids) > 1: 80 | gpu_ids = self.gpu_ids 81 | 82 | x = nn.parallel.data_parallel(self.main, x, gpu_ids) 83 | x = x.view(x.size()[0], 1024*int(self.fcsize**2)) 84 | 85 | xOut = list() 86 | 87 | xDiscrim = nn.parallel.data_parallel(self.discrim, x, gpu_ids) 88 | # xOut.append(xDiscrim) 89 | 90 | if self.nClasses > 0: 91 | xClasses = nn.parallel.data_parallel(self.classOut, x, gpu_ids) 92 | xOut.append(xClasses) 93 | 94 | if self.nRef > 0: 95 | xRef = nn.parallel.data_parallel(self.refOut, x, gpu_ids) 96 | xOut.append(xRef) 97 | 98 | if self.nLatentDim > 0: 99 | xLatent = nn.parallel.data_parallel(self.latentOut, x, gpu_ids) 100 | xOut.append(xLatent) 101 | 102 | return xOut, xDiscrim 103 | 104 | class Dec(nn.Module): 105 | def __init__(self, nLatentDim, nClasses, nRef, insize, nch, gpu_ids, opt=None): 106 | super(Dec, self).__init__() 107 | 108 | self.gpu_ids = gpu_ids 109 | self.fcsize = int(insize/64) 110 | 111 | self.nLatentDim = nLatentDim 112 | self.nClasses = nClasses 113 | self.nRef = nRef 114 | 115 | self.fc = nn.Linear(self.nLatentDim + self.nClasses + self.nRef, 1024*int(self.fcsize**2)) 116 | 117 | self.main = nn.Sequential( 118 | # nn.BatchNorm2d(1024), 119 | 120 | nn.LeakyReLU(0.2, inplace=True), 121 | nn.ConvTranspose2d(1024, 1024, ksize, dstep, 1), 122 | nn.BatchNorm2d(1024), 123 | 124 | nn.LeakyReLU(0.2, inplace=True), 125 | 
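# Each ConvTranspose2d in this stack uses ksize=4, dstep=2, so every layer
# doubles the spatial size; the six transposed convs together upsample by
# 2**6 = 64, mirroring the encoder's fcsize = insize/64 compression.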
nn.ConvTranspose2d(1024, 512, ksize, dstep, 1), 126 | nn.BatchNorm2d(512), 127 | 128 | nn.LeakyReLU(0.2, inplace=True), 129 | nn.ConvTranspose2d(512, 256, ksize, dstep, 1), 130 | nn.BatchNorm2d(256), 131 | 132 | nn.LeakyReLU(0.2, inplace=True), 133 | nn.ConvTranspose2d(256, 128, ksize, dstep, 1), 134 | nn.BatchNorm2d(128), 135 | 136 | nn.LeakyReLU(0.2, inplace=True), 137 | nn.ConvTranspose2d(128, 64, ksize, dstep, 1), 138 | nn.BatchNorm2d(64), 139 | 140 | nn.LeakyReLU(0.2, inplace=True), 141 | nn.ConvTranspose2d(64, nch, ksize, dstep, 1), 142 | # nn.BatchNorm2d(nch), 143 | nn.Sigmoid() 144 | ) 145 | 146 | def forward(self, xIn): 147 | # gpu_ids = None 148 | # if isinstance(x.data, torch.cuda.FloatTensor) and len(self.gpu_ids) > 1: 149 | gpu_ids = self.gpu_ids 150 | 151 | x = torch.cat(xIn, 1) 152 | 153 | x = self.fc(x) 154 | x = x.view(x.size()[0], 1024, self.fcsize, self.fcsize) 155 | x = nn.parallel.data_parallel(self.main, x, gpu_ids) 156 | 157 | return x 158 | 159 | class EncD(nn.Module): 160 | def __init__(self, nlatentdim, gpu_ids, opt=None): 161 | super(EncD, self).__init__() 162 | 163 | nfc = 1024 164 | 165 | self.gpu_ids = gpu_ids 166 | 167 | self.main = nn.Sequential( 168 | nn.Linear(nlatentdim, nfc), 169 | nn.BatchNorm1d(nfc), 170 | nn.LeakyReLU(0.2, inplace=True), 171 | 172 | nn.Linear(nfc, nfc), 173 | nn.BatchNorm1d(nfc), 174 | nn.LeakyReLU(0.2, inplace=True), 175 | 176 | nn.Linear(nfc, 512), 177 | nn.BatchNorm1d(512), 178 | nn.LeakyReLU(0.2, inplace=True), 179 | 180 | nn.Linear(512, 1), 181 | nn.Sigmoid() 182 | ) 183 | 184 | def forward(self, x): 185 | # gpu_ids = None 186 | # if isinstance(x.data, torch.cuda.FloatTensor) and len(self.gpu_ids) > 1: 187 | gpu_ids = self.gpu_ids 188 | 189 | x = nn.parallel.data_parallel(self.main, x, gpu_ids) 190 | 191 | return x 192 | 193 | class DecD(nn.Module): 194 | def __init__(self, nout, insize, nch, gpu_ids, opt=None): 195 | super(DecD, self).__init__() 196 | 197 | self.linear = nn.Linear(1,1) 198 | 199 | def forward(self, x): 200 | return x 201 | 202 | 203 | -------------------------------------------------------------------------------- /integrated_cell/networks/old/aaegan_compact_lrelu2.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | import pdb 4 | from model_utils import init_opts 5 | 6 | ksize = 4 7 | dstep = 2 8 | 9 | class Enc(nn.Module): 10 | def __init__(self, nLatentDim, nClasses, nRef, insize, nch, gpu_ids, opt=None): 11 | super(Enc, self).__init__() 12 | 13 | self.gpu_ids = gpu_ids 14 | self.fcsize = insize/64 15 | 16 | self.nLatentDim = nLatentDim 17 | self.nClasses = nClasses 18 | self.nRef = nRef 19 | 20 | self.main = nn.Sequential( 21 | nn.Conv2d(nch, 64, ksize, dstep, 1), 22 | nn.BatchNorm2d(64), 23 | 24 | nn.LeakyReLU(0.2, inplace=True), 25 | nn.Conv2d(64, 128, ksize, dstep, 1), 26 | nn.BatchNorm2d(128), 27 | 28 | nn.LeakyReLU(0.2, inplace=True), 29 | nn.Conv2d(128, 256, ksize, dstep, 1), 30 | nn.BatchNorm2d(256), 31 | 32 | nn.LeakyReLU(0.2, inplace=True), 33 | nn.Conv2d(256, 512, ksize, dstep, 1), 34 | nn.BatchNorm2d(512), 35 | 36 | nn.LeakyReLU(0.2, inplace=True), 37 | nn.Conv2d(512, 1024, ksize, dstep, 1), 38 | nn.BatchNorm2d(1024), 39 | 40 | nn.LeakyReLU(0.2, inplace=True), 41 | nn.Conv2d(1024, 1024, ksize, dstep, 1), 42 | nn.BatchNorm2d(1024), 43 | 44 | nn.LeakyReLU(0.2, inplace=True) 45 | ) 46 | 47 | # discriminator output 48 | self.discrim = nn.Sequential( 49 | nn.Linear(1024*int(self.fcsize**2), self.nClasses+1), 50 | # 
nn.BatchNorm1d(self.nClasses+1), 51 | ) 52 | 53 | if self.nClasses > 0: 54 | self.discrim.add_module('end', nn.LogSoftmax()) 55 | else: 56 | self.discrim.add_module('end', nn.Sigmoid()) 57 | 58 | if self.nClasses > 0: 59 | self.classOut = nn.Sequential( 60 | nn.Linear(1024*int(self.fcsize**2), self.nClasses), 61 | nn.BatchNorm1d(self.nClasses), 62 | nn.LogSoftmax() 63 | ) 64 | 65 | if self.nRef > 0: 66 | self.refOut = nn.Sequential( 67 | nn.Linear(1024*int(self.fcsize**2), self.nRef), 68 | nn.BatchNorm1d(self.nRef) 69 | ) 70 | 71 | if self.nLatentDim > 0: 72 | self.latentOut = nn.Sequential( 73 | nn.Linear(1024*int(self.fcsize**2), self.nLatentDim), 74 | nn.BatchNorm1d(self.nLatentDim) 75 | ) 76 | 77 | def forward(self, x): 78 | # gpu_ids = None 79 | # if isinstance(x.data, torch.cuda.FloatTensor) and len(self.gpu_ids) > 1: 80 | gpu_ids = self.gpu_ids 81 | 82 | x = nn.parallel.data_parallel(self.main, x, gpu_ids) 83 | x = x.view(x.size()[0], 1024*int(self.fcsize**2)) 84 | 85 | xOut = list() 86 | 87 | xDiscrim = nn.parallel.data_parallel(self.discrim, x, gpu_ids) 88 | # xOut.append(xDiscrim) 89 | 90 | if self.nClasses > 0: 91 | xClasses = nn.parallel.data_parallel(self.classOut, x, gpu_ids) 92 | xOut.append(xClasses) 93 | 94 | if self.nRef > 0: 95 | xRef = nn.parallel.data_parallel(self.refOut, x, gpu_ids) 96 | xOut.append(xRef) 97 | 98 | if self.nLatentDim > 0: 99 | xLatent = nn.parallel.data_parallel(self.latentOut, x, gpu_ids) 100 | xOut.append(xLatent) 101 | 102 | return xOut, xDiscrim 103 | 104 | class Dec(nn.Module): 105 | def __init__(self, nLatentDim, nClasses, nRef, insize, nch, gpu_ids, opt=None): 106 | super(Dec, self).__init__() 107 | 108 | self.gpu_ids = gpu_ids 109 | self.fcsize = int(insize/64) 110 | 111 | self.nLatentDim = nLatentDim 112 | self.nClasses = nClasses 113 | self.nRef = nRef 114 | 115 | self.fc = nn.Linear(self.nLatentDim + self.nClasses + self.nRef, 1024*int(self.fcsize**2)) 116 | 117 | self.main = nn.Sequential( 118 | # nn.BatchNorm2d(1024), 119 | 120 | nn.ReLU(inplace=True), 121 | nn.ConvTranspose2d(1024, 1024, ksize, dstep, 1), 122 | nn.BatchNorm2d(1024), 123 | 124 | nn.ReLU(inplace=True), 125 | nn.ConvTranspose2d(1024, 512, ksize, dstep, 1), 126 | nn.BatchNorm2d(512), 127 | 128 | nn.ReLU(inplace=True), 129 | nn.ConvTranspose2d(512, 256, ksize, dstep, 1), 130 | nn.BatchNorm2d(256), 131 | 132 | nn.ReLU(inplace=True), 133 | nn.ConvTranspose2d(256, 128, ksize, dstep, 1), 134 | nn.BatchNorm2d(128), 135 | 136 | nn.ReLU(inplace=True), 137 | nn.ConvTranspose2d(128, 64, ksize, dstep, 1), 138 | nn.BatchNorm2d(64), 139 | 140 | nn.ReLU(inplace=True), 141 | nn.ConvTranspose2d(64, nch, ksize, dstep, 1), 142 | # nn.BatchNorm2d(nch), 143 | nn.Sigmoid() 144 | ) 145 | 146 | def forward(self, xIn): 147 | # gpu_ids = None 148 | # if isinstance(x.data, torch.cuda.FloatTensor) and len(self.gpu_ids) > 1: 149 | gpu_ids = self.gpu_ids 150 | 151 | x = torch.cat(xIn, 1) 152 | 153 | x = self.fc(x) 154 | x = x.view(x.size()[0], 1024, self.fcsize, self.fcsize) 155 | x = nn.parallel.data_parallel(self.main, x, gpu_ids) 156 | 157 | return x 158 | 159 | class EncD(nn.Module): 160 | def __init__(self, nlatentdim, gpu_ids, opt=None): 161 | super(EncD, self).__init__() 162 | 163 | nfc = 1024 164 | 165 | self.gpu_ids = gpu_ids 166 | 167 | self.main = nn.Sequential( 168 | nn.Linear(nlatentdim, nfc), 169 | nn.BatchNorm1d(nfc), 170 | nn.LeakyReLU(0.2, inplace=True), 171 | 172 | nn.Linear(nfc, nfc), 173 | nn.BatchNorm1d(nfc), 174 | nn.LeakyReLU(0.2, inplace=True), 175 | 176 | nn.Linear(nfc, 512), 
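# EncD is the latent-space adversary: a plain MLP that maps a latent vector
# through 1024 -> 1024 -> 512 -> 1 with a final Sigmoid, scoring how likely
# a code is to have come from the prior rather than the encoder.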
177 | nn.BatchNorm1d(512), 178 | nn.LeakyReLU(0.2, inplace=True), 179 | 180 | nn.Linear(512, 1), 181 | nn.Sigmoid() 182 | ) 183 | 184 | def forward(self, x): 185 | # gpu_ids = None 186 | # if isinstance(x.data, torch.cuda.FloatTensor) and len(self.gpu_ids) > 1: 187 | gpu_ids = self.gpu_ids 188 | 189 | x = nn.parallel.data_parallel(self.main, x, gpu_ids) 190 | 191 | return x 192 | 193 | class DecD(nn.Module): 194 | def __init__(self, nout, insize, nch, gpu_ids, opt=None): 195 | super(DecD, self).__init__() 196 | 197 | self.linear = nn.Linear(1,1) 198 | 199 | def forward(self, x): 200 | return x 201 | 202 | 203 | -------------------------------------------------------------------------------- /integrated_cell/networks/proto/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AllenCellModeling/pytorch_integrated_cell/8a83fc6f8dc79037f4b681d9d7ef0bc5b91e9948/integrated_cell/networks/proto/__init__.py -------------------------------------------------------------------------------- /integrated_cell/networks/ref_target_autoencoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .. import utils 4 | 5 | 6 | class Autoencoder(nn.Module): 7 | def __init__(self, ref_enc, ref_dec, target_enc, target_dec): 8 | super(Autoencoder, self).__init__() 9 | 10 | self.ref_enc = ref_enc 11 | self.ref_dec = ref_dec 12 | self.ref_n_latent_dim = ref_enc.n_latent_dim 13 | 14 | self.target_enc = target_enc 15 | self.target_dec = target_dec 16 | self.target_n_latent_dim = target_enc.n_latent_dim 17 | 18 | def encode(self, target, ref, labels): 19 | batch_size = labels.shape[0] 20 | 21 | if ref is None: 22 | # generate ref image 23 | self.z_ref = ( 24 | torch.zeros([batch_size, self.ref_n_latent_dim]) 25 | .type_as(labels) 26 | .normal_() 27 | ) 28 | 29 | else: 30 | self.z_ref = self.ref_enc(ref) 31 | 32 | if target is None: 33 | 34 | self.z_target = ( 35 | torch.zeros([batch_size, self.target_n_latent_dim]) 36 | .type_as(labels) 37 | .normal_() 38 | ) 39 | 40 | else: 41 | self.z_target = self.target_enc(target, ref, labels) 42 | 43 | return self.z_ref, self.z_target 44 | 45 | def forward(self, target, ref, labels): 46 | 47 | batch_size = labels.shape[0] 48 | 49 | if ref is None: 50 | # generate ref image 51 | self.z_ref = ( 52 | torch.zeros([batch_size, self.ref_n_latent_dim]) 53 | .type_as(labels) 54 | .normal_() 55 | ) 56 | 57 | ref_hat = self.ref_dec(self.z_ref) 58 | ref = ref_hat 59 | else: 60 | 61 | self.z_ref = self.ref_enc(ref) 62 | self.z_ref_sampled = utils.reparameterize(self.z_ref[0], self.z_ref[1]) 63 | 64 | ref_hat = self.ref_dec(self.z_ref_sampled) 65 | 66 | if target is None: 67 | 68 | self.z_target = ( 69 | torch.zeros([batch_size, self.target_n_latent_dim]) 70 | .type_as(labels) 71 | .normal_() 72 | ) 73 | 74 | target_hat = self.target_dec(self.z_target, ref, labels) 75 | else: 76 | self.z_target = self.target_enc(target, ref, labels) 77 | self.z_target_sampled = utils.reparameterize( 78 | self.z_target[0], self.z_target[1] 79 | ) 80 | target_hat = self.target_dec(self.z_target_sampled, ref, labels) 81 | 82 | return target_hat, ref_hat 83 | -------------------------------------------------------------------------------- /integrated_cell/param_search.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import subprocess 4 | import time 5 | import tqdm 6 | import 
itertools 7 | import pandas as pd 8 | 9 | import pdb 10 | 11 | parent_dir = '/root/results/integrated_cell/param_search' 12 | if not os.path.exists(parent_dir): 13 | os.makedirs(parent_dir) 14 | 15 | scripts_dir = parent_dir + os.sep + 'scripts' 16 | if not os.path.exists(scripts_dir): 17 | os.makedirs(scripts_dir) 18 | 19 | 20 | job_list_path = parent_dir + os.sep + 'job_list.csv' 21 | 22 | job_template = 'cd .. \n' \ 23 | 'python /root/projects/pytorch_integrated_cell/train_model.py \\\n' \ 24 | '\t--gpu_ids {0} \\\n' \ 25 | '\t--save_dir {1} \\\n' \ 26 | '\t--model_name aaegan3Dv6-exp-half \\\n' \ 27 | '\t--lrEnc {2} --lrDec {2} --lrEncD {3} --lrDecD {2} \\\n' \ 28 | '\t--encDRatio {4} --decDRatio {5} \\\n' \ 29 | '\t--noise {6} \\\n' \ 30 | '\t--nlatentdim 128 \\\n' \ 31 | '\t--batch_size 30 \\\n' \ 32 | '\t--nepochs 25 --nepochs_pt2 0 \\\n' \ 33 | '\t--train_module aaegan_trainv6 \\\n' \ 34 | '\t--imdir /root/data/ipp/ipp_17_10_25 \\\n' \ 35 | '\t--dataProvider DataProvider3Dh5-half \\\n' \ 36 | '\t--saveStateIter 1 --saveProgressIter 1 \\\n' \ 37 | '\t--channels_pt1 0 2 --channels_pt2 0 1 2 \\\n' 38 | 39 | def format_job(string, gpu_id, row): 40 | return string.format(gpu_id, row['job_dir'], row['lr'], row['lrDec'], row['encDRatio'], row['decDRatio'], row['noise']) 41 | 42 | 43 | # job_template = 'sleep {0!s}' 44 | # def format_job(string): 45 | # return string.format(np.random.choice(10, 1)[0]) 46 | job_path_template = scripts_dir + os.sep + 'param_job_{0!s}.sh' 47 | job_dir_template = parent_dir + os.sep + 'param_job_{0!s}' 48 | 49 | if os.path.exists(job_list_path): 50 | job_list = pd.read_csv(job_list_path) 51 | else: 52 | lr_range = [5E-4, 2E-4, 1E-4, 5E-5, 1E-5, 5E-6] 53 | lrDec_range = [1E-1, 5E-2, 1E-2, 5E-3, 1E-3, 1E-4, 5E-5, 1E-5] 54 | encDRatio_range = [1E-1, 5E-2, 1E-2, 5E-3, 1E-3, 5E-4, 1E-4, 5E-5, 1E-5, 5E-6, 1E-6] 55 | decDRatio_range = encDRatio_range 56 | noise_range = [0.1, 0.01, 0.001] 57 | 58 | param_opts = itertools.product(lr_range, lrDec_range, encDRatio_range, decDRatio_range, noise_range) 59 | 60 | opts_list = [list(i) for i in param_opts] 61 | 62 | job_scripts = [job_path_template.format(i) for i in range(0, len(opts_list))] 63 | job_dirs = [job_dir_template.format(i) for i in range(0, len(opts_list))] 64 | 65 | all_opts = [opt+[job_dir]+[job_scripts] for opt, job_dir, job_scripts in zip(opts_list, job_dirs, job_scripts)] 66 | 67 | job_list = pd.DataFrame(all_opts, columns = ['lr', 'lrDec', 'encDRatio', 'decDRatio', 'noise', 'job_dir', 'job_script']) 68 | 69 | job_list.to_csv(job_list_path) 70 | 71 | 72 | # randomly shuffle 73 | job_list = job_list.sample(frac=1) 74 | 75 | 76 | 77 | 78 | processes = range(0, 8) 79 | job_process_list = [-1] * len(processes) 80 | 81 | #for every job 82 | for index, row in job_list.iterrows(): 83 | 84 | #check to see if the job exists 85 | job_script = row['job_script'] 86 | if os.path.exists(job_script): 87 | continue 88 | 89 | #if it doesn't exist yet, wait for a free process slot 90 | job_status = None 91 | 92 | #look for a process that isn't running 93 | while job_status is None: 94 | for job_process, process_id in zip(job_process_list, processes): 95 | 96 | #check process status 97 | if job_process == -1: 98 | job_status = True 99 | break 100 | else: 101 | job_status = job_process.poll() 102 | 103 | if job_status is not None: 104 | break 105 | 106 | if job_status is not None: 107 | break 108 | 109 | print('Waiting for next job') 110 | time.sleep(3) 111 | 112 | #start a job 113 | #get the job string 114 | job_str = format_job(job_template, process_id, row) 
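#note: the polling loop above is a hand-rolled process pool; each of the 8
#slots holds either -1 (free) or a subprocess.Popen handle, and Popen.poll()
#returns None while the child is still running, so a non-None poll() marks
#the slot as reusable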
115 | 116 | #write the bash file to disk 117 | 118 | print('starting ' + job_script + ' on process ' + str(process_id)) 119 | with open(job_script, 'w') as text_file: 120 | text_file.write(job_str) 121 | 122 | # pdb.set_trace() 123 | job_process_list[process_id] = subprocess.Popen('bash ' + job_script, stdout=subprocess.PIPE, shell=True) 124 | 125 | print(job_template) -------------------------------------------------------------------------------- /integrated_cell/simplelogger.py: -------------------------------------------------------------------------------- 1 | class SimpleLogger: 2 | def __init__(self, fields, print_format=""): 3 | if isinstance(print_format, str) and not print_format: 4 | printstr = "" 5 | for field in fields: 6 | printstr = printstr + field + ": %f " 7 | print_format = printstr 8 | 9 | self.print_format = print_format 10 | 11 | self.fields = fields 12 | 13 | self.log = dict() 14 | for field in fields: 15 | self.log[field] = [] 16 | 17 | def add(self, inputs): 18 | 19 | assert len(inputs) == len(self.fields) 20 | 21 | for i in range(0, len(self.fields)): 22 | self.log[self.fields[i]].append(inputs[i]) 23 | 24 | if isinstance(self.print_format, str): 25 | print(self.print_format % tuple(inputs)) 26 | 27 | def __len__(self): 28 | return len(self.log[self.fields[0]]) 29 | -------------------------------------------------------------------------------- /integrated_cell/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AllenCellModeling/pytorch_integrated_cell/8a83fc6f8dc79037f4b681d9d7ef0bc5b91e9948/integrated_cell/tests/__init__.py -------------------------------------------------------------------------------- /integrated_cell/tests/conftest.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from pathlib import Path 5 | import pytest 6 | 7 | 8 | @pytest.fixture 9 | def data_dir() -> Path: 10 | return Path(__file__).parent / "resources" 11 | -------------------------------------------------------------------------------- /integrated_cell/tests/resources/img2D.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AllenCellModeling/pytorch_integrated_cell/8a83fc6f8dc79037f4b681d9d7ef0bc5b91e9948/integrated_cell/tests/resources/img2D.png -------------------------------------------------------------------------------- /integrated_cell/tests/resources/img3D.tiff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AllenCellModeling/pytorch_integrated_cell/8a83fc6f8dc79037f4b681d9d7ef0bc5b91e9948/integrated_cell/tests/resources/img3D.tiff -------------------------------------------------------------------------------- /integrated_cell/tests/resources/test_data.csv: -------------------------------------------------------------------------------- 1 | CellId,StructureId/Name,CellLineId/Name,save_reg_path,save_reg_path_flat,ch_dna,ch_memb,ch_seg_cell,ch_seg_nuc,ch_struct,ch_trans 2 | 0,Lysotubules,AICS-Zero,./img3D.tiff,././img2D.png,2,3,1,0,4,5 3 | 1,Lysotubules,AICS-Zero,./img3D.tiff,././img2D.png,2,3,1,0,4,5 4 | 2,Lysotubules,AICS-Zero,./img3D.tiff,././img2D.png,2,3,1,0,4,5 5 | 3,Lysotubules,AICS-Zero,./img3D.tiff,././img2D.png,2,3,1,0,4,5 6 | 4,Lysotubules,AICS-Zero,./img3D.tiff,././img2D.png,2,3,1,0,4,5 7 | 5,Mitotubules,AICS-One,./img3D.tiff,././img2D.png,2,3,1,0,4,5 8 | 
 9 | 7,Mitotubules,AICS-One,./img3D.tiff,././img2D.png,2,3,1,0,4,5
10 | 8,Mitotubules,AICS-One,./img3D.tiff,././img2D.png,2,3,1,0,4,5
11 | 9,Mitotubules,AICS-One,./img3D.tiff,././img2D.png,2,3,1,0,4,5
12 | 10,Riboplasmictubules,AICS-Three,./img3D.tiff,././img2D.png,2,3,1,0,4,5
13 | 11,Riboplasmictubules,AICS-Three,./img3D.tiff,././img2D.png,2,3,1,0,4,5
14 | 12,Riboplasmictubules,AICS-Three,./img3D.tiff,././img2D.png,2,3,1,0,4,5
15 | 13,Riboplasmictubules,AICS-Three,./img3D.tiff,././img2D.png,2,3,1,0,4,5
16 | 14,Riboplasmictubules,AICS-Three,./img3D.tiff,././img2D.png,2,3,1,0,4,5
17 | 15,FocalAdhesionTubules,AICS-Eleven,./img3D.tiff,././img2D.png,2,3,1,0,4,5
18 | 16,FocalAdhesionTubules,AICS-Eleven,./img3D.tiff,././img2D.png,2,3,1,0,4,5
19 | 17,FocalAdhesionTubules,AICS-Eleven,./img3D.tiff,././img2D.png,2,3,1,0,4,5
20 | 18,FocalAdhesionTubules,AICS-Eleven,./img3D.tiff,././img2D.png,2,3,1,0,4,5
21 | 19,FocalAdhesionTubules,AICS-Eleven,./img3D.tiff,././img2D.png,2,3,1,0,4,5
--------------------------------------------------------------------------------
/integrated_cell/tests/test_dataprovider.py:
--------------------------------------------------------------------------------
 1 | import pytest
 2 | 
 3 | from ..data_providers import DataProvider
 4 | 
 5 | import pandas as pd
 6 | 
 7 | # TODO: Add cropping, rescaling, normalization tests
 8 | 
 9 | 
10 | @pytest.mark.parametrize(
11 |     "batch_size, n_dat, hold_out, channelInds, return2D, rescale_to, crop_to, normalize_intensity",
12 |     [
13 |         [1, -1, 0.1, [0, 1, 2], False, None, None, False],
14 |         [2, -1, 0.1, [0, 1, 2], False, None, None, False],
15 |         [2, -1, 0.1, [0, 1, 2], True, None, None, False],
16 |     ],
17 | )
18 | def test_dataprovider(
19 |     data_dir,
20 |     batch_size,
21 |     n_dat,
22 |     hold_out,
23 |     channelInds,
24 |     return2D,
25 |     rescale_to,
26 |     crop_to,
27 |     normalize_intensity,
28 | ):
29 |     csv_name = "test_data.csv"
30 |     test_data_csv = "{}/{}".format(data_dir, csv_name)
31 | 
32 |     df = pd.read_csv(test_data_csv)
33 | 
34 |     dp = DataProvider.DataProvider(
35 |         image_parent=data_dir,
36 |         csv_name=csv_name,
37 |         batch_size=batch_size,
38 |         n_dat=n_dat,
39 |         hold_out=hold_out,
40 |         channelInds=channelInds,
41 |         return2D=return2D,
42 |         rescale_to=rescale_to,
43 |         crop_to=crop_to,
44 |         normalize_intensity=normalize_intensity,
45 |     )
46 | 
47 |     n_dat_total = 0  # renamed to avoid shadowing the parametrized n_dat argument
48 |     for group in ["train", "test", "validate"]:
49 |         n_dat_total += dp.get_n_dat(group)
50 | 
51 |     assert n_dat_total == len(df)
52 | 
53 |     x, classes, ref = dp.get_sample()
54 | 
55 |     assert x.shape[0] == batch_size
56 |     assert x.shape[1] == len(channelInds)
57 | 
58 |     if return2D:
59 |         assert len(x.shape) == 4
60 |     else:
61 |         assert len(x.shape) == 5
62 | 
--------------------------------------------------------------------------------
/integrated_cell/tests/test_dummy.py:
--------------------------------------------------------------------------------
 1 | def test_dummy():
 2 |     assert True
 3 | 
--------------------------------------------------------------------------------
/integrated_cell/tests/test_kld.py:
--------------------------------------------------------------------------------
 1 | import pytest
 2 | 
 3 | 
 4 | def test_kld():
 5 |     with pytest.raises(NotImplementedError):
 6 |         raise NotImplementedError
 7 | 
--------------------------------------------------------------------------------
/integrated_cell/utils/__init__.py:
--------------------------------------------------------------------------------
 1 | from .spectral_norm import spectral_norm, remove_spectral_norm  # noqa
 2 | from .utils import *  # noqa
 3 | 
--------------------------------------------------------------------------------
/integrated_cell/utils/build_control_images.py:
--------------------------------------------------------------------------------
 1 | import tifffile
 2 | import pandas as pd
 3 | import os
 4 | from tqdm import tqdm
 5 | import numpy as np
 6 | import scipy.misc
 7 | from scipy.stats import norm
 8 | 
 9 | from .utils import str2rand
10 | from .imgToProjection import imgtoprojection
11 | 
12 | 
13 | # I mostly copied over the parts from the IPP to clone the processing to make these control images.
14 | # This should be moved into the IPP later - GRJ 05/2019
15 | 
16 | 
17 | def build_control_images(
18 |     save_dir,
19 |     csv_path,
20 |     image_parent,
21 |     split_seed=615,
22 |     verbose=True,
23 |     target_col="StructureId/Name",
24 | ):
25 |     # constructs "control" images that contain the following content as the "structure" channel:
26 |     # DNA label
27 |     # Membrane label
28 |     # Random structure from a different cell
29 |     # Blank (all zeros)
30 |     # Gaussian noise (the code below draws from np.random.normal, not Poisson)
31 | 
32 |     # make a dataframe out of the csv log file
33 |     if verbose:
34 |         print("reading csv manifest")
35 |     csv_df = pd.read_csv(csv_path)
36 | 
37 |     image_col = "save_reg_path"
38 | 
39 |     csv_df["ch_memb"] = 3
40 |     csv_df["ch_struct"] = 4
41 |     csv_df["ch_dna"] = 2
42 |     csv_df["ch_seg_cell"] = 1
43 |     csv_df["ch_seg_nuc"] = 0
44 |     csv_df["ch_trans"] = 5
45 | 
46 |     ch_names = ["nuc", "cell", "dna", "memb", "struct", "trans"]
47 | 
48 |     im_paths = list()
49 | 
50 |     for i, im_path in enumerate(csv_df[image_col]):
51 |         splits = np.array(im_path.split("/"))
52 |         lens = np.array([len(s) for s in splits])
53 |         splits = splits[lens > 0]
54 |         im_paths += ["/".join(splits[-2::])]
55 | 
56 |     csv_df[image_col] = im_paths
57 | 
58 |     # csv_df = check_files(csv_df, image_parent, image_col, verbose)
59 | 
60 |     # pick some cells at random
61 |     image_classes = list(csv_df[target_col])
62 |     [label_names, labels] = np.unique(image_classes, return_inverse=True)
63 | 
64 |     n_classes = len(label_names)
65 |     fract_sample_labels = 1 / n_classes
66 | 
67 |     rand = str2rand(csv_df["CellId"], split_seed)
68 |     rand_struct = np.argsort(str2rand(csv_df["CellId"], split_seed + 1))
69 | 
70 |     duplicate_inds = rand <= fract_sample_labels
71 | 
72 |     df_controls = csv_df[duplicate_inds].copy().reset_index()
73 | 
74 |     random_scale = norm.ppf(0.999)
75 | 
76 |     channels_to_make = ["DNA", "Memb", "Blank", "Noise", "Random"]
77 | 
78 |     if not os.path.exists(save_dir):
79 |         os.makedirs(save_dir)
80 | 
81 |     im = load_image(csv_df.iloc[0], image_parent, image_col)
82 |     n_ch = im.shape[0]
83 | 
84 |     for index in tqdm(range(len(df_controls))):
85 |         row = df_controls.iloc[index]
86 | 
87 |         save_path = "{}/control_{}.tiff".format(save_dir, index)
88 | 
89 |         save_path_flats = [
90 |             "{}/control_{}_{}_flat.png".format(save_dir, index, ch_name)
91 |             for ch_name in channels_to_make
92 |         ]
93 |         save_path_flat_projs = [
94 |             "{}/control_{}_{}_flat_proj.png".format(save_dir, index, ch_name)
95 |             for ch_name in channels_to_make
96 |         ]
97 | 
98 |         try:
99 |             im = load_image(row, image_parent, image_col)
100 |         except Exception:
101 |             print("Skipping image {}".format(row[image_col]))
102 |             continue
103 | 
104 |         n_ch = im.shape[0]
105 | 
106 |         ch_dna = im[row["ch_dna"]] * (im[row["ch_seg_nuc"]] > 0)
107 |         ch_memb = im[row["ch_memb"]]
108 | 
109 |         ch_blank = np.zeros(im[0].shape)
110 | 
111 |         ch_noise = (np.random.normal(0, 1, im[0].shape) / random_scale) * 255
112 |         ch_noise[ch_noise < 0] = 0
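        # Editorial note: random_scale = norm.ppf(0.999) ≈ 3.09, so the unit Gaussian
        # draws above are scaled such that roughly 99.9% of values fall below 255;
        # this clipping (here and on the next line) saturates the remaining tails
        # into the valid [0, 255] intensity range.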
113 |         ch_noise[ch_noise > 255] = 255
114 | 
115 |         random_image_row = csv_df.iloc[rand_struct[index]]
116 |         ch_random = load_image(random_image_row, image_parent, image_col)[
117 |             random_image_row["ch_struct"]
118 |         ]
119 | 
120 |         im_out = np.concatenate(
121 |             [im]
122 |             + [
123 |                 np.expand_dims(ch, 0)
124 |                 for ch in [ch_dna, ch_memb, ch_blank, ch_noise, ch_random]
125 |             ],
126 |             0,
127 |         )
128 | 
129 |         im_tmp = im_out.astype("uint8")
130 |         tifffile.imsave(save_path, im_tmp)
131 | 
132 |         im_crop = crop_cell_nuc(im_out, ch_names + channels_to_make)
133 | 
134 |         row_list = list()
135 |         for i, channel in enumerate(channels_to_make):
136 |             name = "Control - {}".format(channel)
137 | 
138 |             row = row.copy()
139 | 
140 |             row["StructureDisplayName"] = name
141 |             row["StructureId"] = -i
142 |             row["StructureShortName"] = name
143 |             row["ProteinId/Name"] = name
144 |             row["StructureId/Name"] = name
145 | 
146 |             row["save_reg_path"] = save_path
147 |             row["save_reg_path_flat"] = save_path_flats[i]
148 |             row["save_reg_path_flat_proj"] = save_path_flat_projs[i]
149 | 
150 |             row["ch_struct"] = n_ch + i
151 | 
152 |             row_list.append(row)
153 | 
154 |             im_tmp = im_crop[[2, 3, n_ch + i]]
155 |             colors = [[0, 0, 1], [1, 0, 0], [0, 1, 0]]
156 | 
157 |             im_flat = imgtoprojection(
158 |                 im_tmp / 255,
159 |                 proj_all=True,
160 |                 proj_method="max",
161 |                 local_adjust=False,
162 |                 global_adjust=True,
163 |                 colors=colors,
164 |             )
165 |             scipy.misc.imsave(
166 |                 row["save_reg_path_flat_proj"], im_flat.transpose(1, 2, 0)
167 |             )
168 | 
169 |             im_flat = imgtoprojection(
170 |                 im_tmp / 255,
171 |                 proj_all=False,
172 |                 proj_method="max",
173 |                 local_adjust=False,
174 |                 global_adjust=True,
175 |                 colors=colors,
176 |             )
177 |             scipy.misc.imsave(row["save_reg_path_flat"], im_flat.transpose(1, 2, 0))
178 | 
179 |         csv_df = csv_df.append(row_list)
180 | 
181 |     csv_df.to_csv("{}/data_plus_controls.csv".format(save_dir))
182 | 
183 | 
184 | def check_files(csv_df, image_parent, image_col="save_reg_path", verbose=True):
185 |     # TODO add checking to make sure number of keys in h5 file matches number of lines in csv file
186 |     if verbose:
187 |         print("Checking the existence of files")
188 | 
189 |     for index in tqdm(range(0, csv_df.shape[0]), desc="Checking files", ascii=True):
190 |         is_good_row = True
191 | 
192 |         row = csv_df.loc[index]
193 | 
194 |         image_path = image_parent + os.sep + row[image_col]
195 | 
196 |         try:
197 |             load_image(row, image_parent, image_col)
198 |         except Exception:
199 |             print("Could not load from image. " + image_path)
200 |             is_good_row = False
201 | 
202 |         csv_df.loc[index, "valid_row"] = is_good_row
203 | 
204 |     # only work with valid rows
205 |     n_old_rows = len(csv_df)
206 |     csv_df = csv_df.loc[csv_df["valid_row"] == True]  # noqa
207 |     csv_df = csv_df.drop("valid_row", axis=1)
208 |     n_new_rows = len(csv_df)
209 |     if verbose:
210 |         print("{0}/{1} samples have all files present".format(n_new_rows, n_old_rows))
211 | 
212 |     return csv_df
213 | 
214 | 
215 | def load_image(df_row, image_parent, image_col):
216 | 
217 |     im_path = image_parent + os.sep + df_row[image_col]
218 | 
219 |     im = tifffile.imread(im_path)
220 | 
221 |     return im
222 | 
223 | 
224 | def crop_cell_nuc(im, channel_names):
225 | 
226 |     nuc_ind = np.array(channel_names) == "nuc"
227 |     cell_ind = np.array(channel_names) == "cell"
228 | 
229 |     dna_ind = np.array(channel_names) == "dna"
230 |     # trans_ind = np.array(channel_names) == "trans"
231 | 
232 |     other_channel_inds = np.ones(len(channel_names))
233 |     other_channel_inds[nuc_ind | cell_ind | dna_ind] = 0
234 | 
235 |     im[dna_ind] = im[dna_ind] * im[nuc_ind]
236 | 
237 |     for i in np.where(other_channel_inds)[0]:
238 |         im[i] = im[i] * im[cell_ind]
239 | 
240 |     return im
241 | 
--------------------------------------------------------------------------------
/integrated_cell/utils/features.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | 
 3 | from skimage.measure import label
 4 | from skimage.filters import gaussian, threshold_otsu
 5 | from scipy.ndimage.morphology import binary_fill_holes
 6 | from aicsfeature.extractor import dna, cell
 7 | 
 8 | 
 9 | def im2feats(
10 |     im_cell,
11 |     im_nuc,
12 |     im_structures,
13 |     extra_features=["io_intensity", "bright_spots", "intensity", "skeleton", "texture"],
14 | ):
15 |     # im_cell = Y x X x Z binary numpy array
16 |     # im_nuc = Y x X x Z binary numpy array
17 |     # im_structures = C x Y x X x Z numpy array
18 | 
19 |     # im_cell, im_nuc, im_structures, im_structures_seg, seg_cell, seg_nuc = im_process(
20 |     #     im_cell, im_nuc, im_structures
21 |     # )
22 | 
23 |     nuc_feats = dna.get_features(im_nuc, extra_features=extra_features)
24 |     cell_feats = cell.get_features(im_cell, extra_features=extra_features)
25 | 
26 |     # feats_out = aicsfeature.kitchen_sink.kitchen_sink(
27 |     #     im_cell=im_cell,
28 |     #     im_nuc=im_nuc,
29 |     #     im_structures=im_structures,
30 |     #     seg_cell=seg_cell,
31 |     #     seg_nuc=seg_nuc,
32 |     #     extra_features=extra_features,
33 |     # )
34 |     # feats_out_seg = aicsfeature.kitchen_sink.kitchen_sink(
35 |     #     im_cell=im_cell,
36 |     #     im_nuc=im_nuc,
37 |     #     im_structures=im_structures_seg,
38 |     #     seg_cell=seg_cell,
39 |     #     seg_nuc=seg_nuc,
40 |     #     extra_features=extra_features,
41 |     # )
42 | 
43 |     return [nuc_feats, cell_feats]
44 | 
45 | 
46 | def im_process(im_cell, im_nuc, im_structures):
47 |     seg_nuc = binary_fill_holes(im_nuc)
48 |     seg_cell = binary_fill_holes(im_cell)
49 |     seg_cell[seg_nuc] = 1
50 | 
51 |     im_nuc = (im_nuc * (255)).astype("uint16") * seg_nuc
52 |     im_cell = (im_cell * (255)).astype("uint16") * seg_cell
53 | 
54 |     im_structures = [(im_ch * (255)).astype("uint16") for im_ch in im_structures]
55 | 
56 |     im_structures_seg = list()
57 | 
58 |     for i, im_structure in enumerate(im_structures):
59 |         im_blur = gaussian(im_structure, 1)
60 | 
61 |         im_pix = im_structure[im_cell > 0]
62 |         if np.all(im_pix == 0):
63 |             im_structures_seg.append(im_structure)
64 |             continue
65 | 
66 |         im_structures_seg.append(
67 |             im_structure * (im_blur > threshold_otsu(im_blur[im_cell > 0]))
68 |         )
69 | 
70 |     return im_cell, im_nuc, im_structures, im_structures_seg, seg_cell, seg_nuc
71 | 
72 | 
73 | def find_main_obj(im_seg):
74 |     im_label = label(im_seg)
75 | 
76 |     obj_index = -1
77 |     max_cell_obj_size = -1
78 |     for i in range(1, np.max(im_label) + 1):
79 |         obj_size = np.sum(im_label == i)
80 |         if obj_size > max_cell_obj_size:
81 |             max_cell_obj_size = obj_size
82 |             obj_index = i
83 | 
84 |     main_obj = im_label == obj_index
85 | 
86 |     return main_obj
87 | 
--------------------------------------------------------------------------------
/integrated_cell/utils/image_transforms.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | import torch
 3 | 
 4 | 
 5 | def crop_to(images, crop_to):  # center-crops (and zero-pads, where the target is larger) the trailing spatial dims
 6 |     crop = (np.array(images.shape[2:]) - np.array(crop_to)) / 2
 7 |     crop_pre = np.floor(crop).astype(int)
 8 |     crop_post = np.ceil(crop).astype(int)
 9 | 
10 |     pad_pre = -crop_pre
11 |     pad_pre[pad_pre < 0] = 0
12 | 
13 |     pad_post = -crop_post
14 |     pad_post[pad_post < 0] = 0
15 | 
16 |     crop_pre[crop_pre < 0] = 0
17 | 
18 |     crop_post[crop_post < 0] = 0
19 |     crop_post[crop_post == 0] = -np.array(images.shape[2:])[crop_post == 0]
20 | 
21 |     if len(crop_pre) == 2:
22 |         images = images[
23 |             :,
24 |             :,
25 |             crop_pre[0] : -crop_post[0],  # noqa
26 |             crop_pre[1] : -crop_post[1],  # noqa
27 |         ]
28 | 
29 |     elif len(crop_pre) == 3:
30 |         images = images[
31 |             :,
32 |             :,
33 |             crop_pre[0] : -crop_post[0],  # noqa
34 |             crop_pre[1] : -crop_post[1],  # noqa
35 |             crop_pre[2] : -crop_post[2],  # noqa
36 |         ]
37 | 
38 |     pad_pre = np.hstack([np.zeros(2), pad_pre])
39 |     pad_post = np.hstack([np.zeros(2), pad_post])
40 | 
41 |     padding = np.vstack([pad_pre, pad_post]).transpose().astype("int")
42 | 
43 |     images = np.pad(images, padding, mode="constant", constant_values=0)
44 | 
45 |     return torch.tensor(images)  # the original assigned without returning, so callers always got None
46 | 
--------------------------------------------------------------------------------
/integrated_cell/utils/imgToProjection.py:
--------------------------------------------------------------------------------
 1 | # Author: Evan Wiederspan
 2 | 
 3 | import numpy as np
 4 | import matplotlib
 5 | 
 6 | matplotlib.use("agg")
 7 | import matplotlib.pyplot as pplot  # noqa
 8 | 
 9 | 
10 | def matproj(im, dim, method="max", slice_index=0):
11 |     if method == "max":
12 |         im = np.max(im, dim)
13 |     elif method == "mean":
14 |         im = np.mean(im, dim)
15 |     elif method == "sum":
16 |         im = np.sum(im, dim)
17 |     elif method == "slice":
18 |         im = im[slice_index]
19 |     else:
20 |         raise ValueError("Invalid projection method")
21 |     return im
22 | 
23 | 
24 | def imgtoprojection(
25 |     im1,
26 |     proj_all=False,
27 |     proj_method="max",
28 |     colors=lambda i: [1, 1, 1],
29 |     global_adjust=False,
30 |     local_adjust=False,
31 | ):
32 |     """
33 |     Outputs projections of a 4d CZYX numpy array into a CYX numpy array, allowing for color masks for each input channel
34 |     as well as adjustment options
35 |     :param im1: Either a 4d numpy array or a list of 3D or 2D numpy arrays. The input that will be projected
36 |     :param proj_all: boolean. True outputs XY, YZ, and XZ projections in a grid, False just outputs XY. False by default
37 |     :param proj_method: string. Method by which to do projections. 'max' by default
38 |     :param colors: Can be either a string which corresponds to a cmap function in matplotlib, a function that
39 |     takes in the channel index and returns a list of numbers, or a list of lists containing the color multipliers.
40 |     :param global_adjust: boolean. If true, scales each color channel to set its max to be 255
41 |     after combining all channels. False by default
42 |     :param local_adjust: boolean. If true, performs contrast adjustment on each channel individually. False by default
43 |     :return: a CYX numpy array containing the requested projections
44 |     """
45 | 
46 |     # turn list of 2d or 3d arrays into single 4d array if needed
47 |     try:
48 |         if isinstance(im1, (list, tuple)):
49 |             # if only YX, add a single Z dimension
50 |             if im1[0].ndim == 2:
51 |                 im1 = [np.expand_dims(c, axis=0) for c in im1]
52 |             elif im1[0].ndim != 3:
53 |                 raise ValueError("im1 must be a list of 2d or 3d arrays")
54 |             # combine list into 4d array
55 |             im = np.stack(im1)
56 |         else:
57 |             if im1.ndim != 4:
58 |                 raise ValueError("Invalid dimensions for im1")
59 |             im = im1
60 | 
61 |     except (AttributeError, IndexError):
62 |         # it's not a list of np arrays
63 |         raise ValueError(
64 |             "im1 must be either a 4d numpy array or a list of numpy arrays"
65 |         )
66 | 
67 |     # color processing code
68 |     if isinstance(colors, str):
69 |         # pass it in to matplotlib
70 |         try:
71 |             colors = pplot.get_cmap(colors)(np.linspace(0, 1, im.shape[0]))
72 |         except ValueError:
73 |             # thrown when the string is not a valid cmap name
74 |             raise ValueError("Invalid cmap string")
75 |     elif callable(colors):
76 |         # if it's a function
77 |         try:
78 |             colors = [colors(i) for i in range(im.shape[0])]
79 |         except ValueError:
80 |             raise ValueError("Invalid color function")
81 | 
82 |     # else, we're assuming it's a list
83 |     # scale colors down to 0-1 range if they're bigger than 1
84 |     if any(v > 1 for v in np.array(colors).flatten()):
85 |         colors = [[v / 255.0 for v in c] for c in colors]
86 | 
87 |     # create final image
88 |     if not proj_all:
89 |         img_final = np.zeros((3, im.shape[2], im.shape[3]))
90 |     else:
91 |         # y + z, x + z
92 |         img_final = np.zeros((3, im.shape[2] + im.shape[1], im.shape[3] + im.shape[1]))
93 |     img_piece = np.zeros(img_final.shape)
94 |     # loop through all channels
95 |     for i, img_c in enumerate(im):
96 |         try:
97 |             proj_z = matproj(img_c, 0, proj_method, img_c.shape[0] // 2)
98 |             if proj_all:
99 |                 proj_y, proj_x = (
100 |                     matproj(img_c, axis, proj_method, img_c.shape[axis] // 2)
101 |                     for axis in range(1, 3)
102 |                 )
103 |                 # flipping to get them facing the right way
104 |                 proj_x = np.transpose(proj_x, (1, 0))
105 |                 proj_y = np.flipud(proj_y)
106 |                 _, sy, sz = proj_z.shape[1], proj_z.shape[0], proj_y.shape[0]  # noqa
107 |                 img_piece[:, :sy, :sz] = proj_x
108 |                 img_piece[:, :sy, sz:] = proj_z
109 |                 img_piece[:, sy:, sz:] = proj_y
110 |             else:
111 |                 img_piece[:] = proj_z
112 |         except ValueError:
113 |             raise ValueError("Invalid projection function")
114 | 
115 |         for c in range(3):
116 |             img_piece[c] *= colors[i][c]
117 | 
118 |         # local contrast adjustment: subtract the min, divide by the max
119 |         if local_adjust:
120 |             img_piece -= np.min(img_piece)
121 |             img_max = np.max(img_piece)
122 |             if img_max > 0:
123 |                 img_piece /= img_max
124 |         # img_final += img_piece
125 |         img_final += (1 - img_final) * img_piece
126 | 
127 |     # color range adjustment, ensure that max value is 255
128 |     if global_adjust:
129 |         # scale color channels independently
130 |         for c in range(3):
131 |             max_val = np.max(img_final[c].flatten())
132 |             if max_val > 0:
133 |                 img_final[c] *= 255.0 / max_val
134 | 
135 |     return img_final
136 | 
--------------------------------------------------------------------------------
/integrated_cell/utils/reference/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AllenCellModeling/pytorch_integrated_cell/8a83fc6f8dc79037f4b681d9d7ef0bc5b91e9948/integrated_cell/utils/reference/__init__.py
--------------------------------------------------------------------------------
/integrated_cell/utils/spectral_norm.py:
--------------------------------------------------------------------------------
 1 | # stolen from pytorch 0.5
 2 | # https://pytorch.org/docs/master/_modules/torch/nn/utils/spectral_norm.html#spectral_norm
 3 | 
 4 | """
 5 | Spectral Normalization from https://arxiv.org/abs/1802.05957
 6 | """
 7 | import torch
 8 | from torch.nn.functional import normalize
 9 | 
10 | 
11 | class SpectralNorm(object):
12 |     def __init__(self, name="weight", n_power_iterations=1, eps=1e-12):
13 |         self.name = name
14 |         if n_power_iterations <= 0:
15 |             raise ValueError(
16 |                 "Expected n_power_iterations to be positive, but "
17 |                 "got n_power_iterations={}".format(n_power_iterations)
18 |             )
19 |         self.n_power_iterations = n_power_iterations
20 |         self.eps = eps
21 | 
22 |     def compute_weight(self, module):
23 |         weight = getattr(module, self.name + "_org")
24 |         u = getattr(module, self.name + "_u")
25 |         height = weight.size(0)
26 |         weight_mat = weight.view(height, -1)
27 |         with torch.no_grad():
28 |             for _ in range(self.n_power_iterations):
29 |                 # Spectral norm of weight equals to `u^T W v`, where `u` and `v`
30 |                 # are the first left and right singular vectors.
31 |                 # This power iteration produces approximations of `u` and `v`.
32 |                 v = normalize(torch.matmul(weight_mat.t(), u), dim=0, eps=self.eps)
33 |                 u = normalize(torch.matmul(weight_mat, v), dim=0, eps=self.eps)
34 | 
35 |         # import pdb
36 |         # pdb.set_trace()
37 | 
38 |         sigma = torch.dot(
39 |             u.float(), torch.matmul(weight_mat.float(), v.float()).float()
40 |         )
41 |         weight = weight / sigma
42 |         return weight, u
43 | 
44 |     def remove(self, module):
45 |         weight = module._parameters[self.name + "_org"]
46 |         delattr(module, self.name)
47 |         delattr(module, self.name + "_u")
48 |         delattr(module, self.name + "_org")
49 |         module.register_parameter(self.name, weight)
50 | 
51 |     def __call__(self, module, inputs):
52 |         weight, u = self.compute_weight(module)
53 |         setattr(module, self.name, weight)
54 |         setattr(module, self.name + "_u", u)
55 | 
56 |     @staticmethod
57 |     def apply(module, name, n_power_iterations, eps):
58 |         fn = SpectralNorm(name, n_power_iterations, eps)
59 |         weight = module._parameters[name]
60 |         height = weight.size(0)
61 | 
62 |         u = normalize(weight.new_empty(height).normal_(0, 1), dim=0, eps=fn.eps)
63 |         delattr(module, fn.name)
64 |         module.register_parameter(fn.name + "_org", weight)
65 |         module.register_buffer(fn.name, weight)
66 |         module.register_buffer(fn.name + "_u", u)
67 | 
68 |         module.register_forward_pre_hook(fn)
69 |         return fn
70 | 
71 | 
72 | def spectral_norm(module, name="weight", n_power_iterations=1, eps=1e-12):
73 |     r"""Applies spectral normalization to a parameter in the given module.
74 | 
75 |     .. math::
76 |         \mathbf{W} &= \dfrac{\mathbf{W}}{\sigma(\mathbf{W})} \\
77 |         \sigma(\mathbf{W}) &= \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}
78 | 
79 |     Spectral normalization stabilizes the training of discriminators (critics)
80 |     in Generative Adversarial Networks (GANs) by rescaling the weight tensor
81 |     with spectral norm :math:`\sigma` of the weight matrix calculated using
82 |     power iteration method. If the dimension of the weight tensor is greater
83 |     than 2, it is reshaped to 2D in power iteration method to get spectral
84 |     norm. This is implemented via a hook that calculates spectral norm and
85 |     rescales weight before every :meth:`~Module.forward` call.
86 | 
87 |     See `Spectral Normalization for Generative Adversarial Networks`_ .
88 | 
89 |     .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957
90 | 
91 |     Args:
92 |         module (nn.Module): containing module
93 |         name (str, optional): name of weight parameter
94 |         n_power_iterations (int, optional): number of power iterations to
95 |             calculate spectral norm
96 |         eps (float, optional): epsilon for numerical stability in
97 |             calculating norms
98 | 
99 |     Returns:
100 |         The original module with the spectral norm hook
101 | 
102 |     Example::
103 | 
104 |         >>> m = spectral_norm(nn.Linear(20, 40))
105 |         Linear (20 -> 40)
106 |         >>> m.weight_u.size()
107 |         torch.Size([20])
108 | 
109 |     """
110 |     SpectralNorm.apply(module, name, n_power_iterations, eps)
111 |     return module
112 | 
113 | 
114 | def remove_spectral_norm(module, name="weight"):
115 |     r"""Removes the spectral normalization reparameterization from a module.
116 | 
117 |     Args:
118 |         module (nn.Module): containing module
119 |         name (str, optional): name of weight parameter
120 | 
121 |     Example:
122 |         >>> m = spectral_norm(nn.Linear(40, 10))
123 |         >>> remove_spectral_norm(m)
124 |     """
125 |     for k, hook in module._forward_pre_hooks.items():
126 |         if isinstance(hook, SpectralNorm) and hook.name == name:
127 |             hook.remove(module)
128 |             del module._forward_pre_hooks[k]
129 |             return module
130 | 
131 |     raise ValueError("spectral_norm of '{}' not found in {}".format(name, module))
132 | 
--------------------------------------------------------------------------------
/integrated_cell/utils/target/__init__.py:
--------------------------------------------------------------------------------
 1 | from .utils import *  # noqa
 2 | 
--------------------------------------------------------------------------------
/integrated_cell/utils/target/utils.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | 
 3 | 
 4 | def sample2im(x, ref):
 5 |     n_channels = x.shape[1] + ref.shape[1]
 6 | 
 7 |     im = torch.zeros([x.shape[0], n_channels, x.shape[2], x.shape[3], x.shape[4]])
 8 | 
 9 |     im[:, [0, 2]] = ref  # assumes ref holds exactly two reference channels, placed at indices 0 and 2
10 |     im[:, 1] = x  # and that x is a single target/structure channel, placed at index 1
11 | 
12 |     return im
13 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AllenCellModeling/pytorch_integrated_cell/8a83fc6f8dc79037f4b681d9d7ef0bc5b91e9948/requirements.txt
--------------------------------------------------------------------------------
/scripts/aaegan_short.sh:
--------------------------------------------------------------------------------
 1 | cd ..
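# Editorial note (assumption, inferred from the $1 below): this script takes a GPU id
# as its first argument, so the expected invocation is e.g. `bash aaegan_short.sh 0`.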
 2 | python /root/projects/pytorch_integrated_cell/train_model.py \
 3 |     --gpu_ids $1 \
 4 |     --save_dir ./results/aaegan_short \
 5 |     --data_save_path ./results/data.pyt \
 6 |     --lrEnc 2E-4 --lrDec 2E-4 \
 7 |     --lrEncD 2E-2 --lrDecD 1E-4 \
 8 |     --model_name aaegan3Dv6-relu-exp \
 9 |     --train_module aaegan_trainv8 \
10 |     --kwargs_encD '{"noise_std": 0}' \
11 |     --kwargs_decD '{"noise_std": 2E-1}' \
12 |     --kwargs_optim '{"betas": [0, 0.9]}' \
13 |     --kwargs_model '{"lambda_encD_loss": 10, "lambda_decD_loss": 1, "lambda_class_loss": 1000, "lambda_ref_loss": 1000, "provide_decoder_vars": 1}' \
14 |     --imdir /root/results/ipp/ipp_17_10_25 \
15 |     --dataProvider DataProvider3Dh5 \
16 |     --saveStateIter 1 --saveProgressIter 1 \
17 |     --channels_pt1 0 2 --channels_pt2 0 1 2 \
18 |     --batch_size 16 \
19 |     --nlatentdim 128 \
20 |     --nepochs 1 \
21 |     --nepochs_pt2 1 \
22 |     --ndat 32 \
23 |     --overwrite_opts True \
24 | 
--------------------------------------------------------------------------------
/scripts/bvae.sh:
--------------------------------------------------------------------------------
 1 | cd ..
 2 | python ./train_model.py \
 3 |     --gpu_ids $1 \
 4 |     --save_parent ./results/test_bvae/ \
 5 |     --lrEnc 1E-4 --lrDec 1E-4 \
 6 |     --data_save_path ./results/data.pyt \
 7 |     --critRecon BCELoss \
 8 |     --model_name vaaegan3D \
 9 |     --kwargs_model '{"beta": 1}' \
10 |     --train_module bvae \
11 |     --kwargs_optim '{"betas": [0.9, 0.999]}' \
12 |     --imdir /root/results/ipp/ipp_17_10_25 \
13 |     --dataProvider DataProvider3Dh5 \
14 |     --saveStateIter 1 --saveProgressIter 1 \
15 |     --channels_pt1 0 2 --channels_pt2 0 1 2 \
16 |     --batch_size 16 \
17 |     --nlatentdim 128 \
18 |     --nepochs 50 \
19 |     --nepochs_pt2 50 \
20 | 
--------------------------------------------------------------------------------
/scripts/bvae_short.sh:
--------------------------------------------------------------------------------
 1 | cd ..
 2 | 
 3 | 
 4 | python /root/projects/pytorch_integrated_cell/train_model.py \
 5 |     --gpu_ids $1 \
 6 |     --save_dir ./results/bvae_short \
 7 |     --data_save_path ./results/data.pyt \
 8 |     --lrEnc 2E-4 --lrDec 2E-4 \
 9 |     --model_name vaaegan3D \
10 |     --train_module bvae \
11 |     --kwargs_optim '{"betas": [0.9, 0.999]}' \
12 |     --imdir /root/results/ipp/ipp_17_10_25 \
13 |     --dataProvider DataProvider3Dh5 \
14 |     --saveStateIter 1 --saveProgressIter 1 \
15 |     --channels_pt1 0 2 --channels_pt2 0 1 2 \
16 |     --batch_size 16 \
17 |     --nlatentdim 128 \
18 |     --nepochs 5 \
19 |     --nepochs_pt2 5 \
20 |     --ndat 32 \
21 |     --overwrite_opts True \
22 | 
--------------------------------------------------------------------------------
/scripts/bvaegan_short.sh:
--------------------------------------------------------------------------------
 1 | cd ..
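# Editorial note (inference from the scripts in this directory, not verified against
# train_model.py): the --kwargs_* flags take JSON dictionaries that are forwarded to
# the matching component (model, optimizer, encoder/decoder discriminators).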
 2 | python /root/projects/pytorch_integrated_cell/train_model.py \
 3 |     --gpu_ids $1 \
 4 |     --save_dir ./results/bvaegan_short \
 5 |     --data_save_path ./results/data.pyt \
 6 |     --lrEnc 2E-4 --lrDec 2E-4 \
 7 |     --lrDecD 2E-4 \
 8 |     --model_name vaaegan3D \
 9 |     --train_module bvaegan \
10 |     --kwargs_decD '{"noise_std": 1E-1}' \
11 |     --kwargs_optim '{"betas": [0, 0.9]}' \
12 |     --imdir /root/results/ipp/ipp_17_10_25 \
13 |     --dataProvider DataProvider3Dh5 \
14 |     --saveStateIter 1 --saveProgressIter 1 \
15 |     --channels_pt1 0 2 --channels_pt2 0 1 2 \
16 |     --batch_size 16 \
17 |     --nlatentdim 128 \
18 |     --nepochs 1 \
19 |     --nepochs_pt2 1 \
20 |     --ndat 32 \
21 |     --overwrite_opts True \
22 | 
--------------------------------------------------------------------------------
/scripts/multi_pred.sh:
--------------------------------------------------------------------------------
 1 | cd ..
 2 | 
 3 | 
 4 | python /root/projects/pytorch_integrated_cell/train_model.py \
 5 |     --gpu_ids $1 \
 6 |     --save_dir ./results/multi_pred \
 7 |     --data_save_path ./results/data.pyt \
 8 |     --lrEnc 2E-4 --lrDec 2E-4 \
 9 |     --model_name multi_pred \
10 |     --train_module multi_pred \
11 |     --kwargs_optim '{"betas": [0.9, 0.999]}' \
12 |     --imdir /root/results/ipp/ipp_17_10_25 \
13 |     --dataProvider DataProvider3Dh5 \
14 |     --saveStateIter 1 --saveProgressIter 1 \
15 |     --channels_pt1 0 2 3 4 5 1 --channels_pt2 0 \
16 |     --batch_size 16 \
17 |     --nlatentdim 128 \
18 |     --nepochs 100 \
19 |     --nepochs_pt2 0 \
20 |     --overwrite_opts True \
21 | 
--------------------------------------------------------------------------------
/scripts/multi_pred_short.sh:
--------------------------------------------------------------------------------
 1 | cd ..
 2 | 
 3 | 
 4 | python /root/projects/pytorch_integrated_cell/train_model.py \
 5 |     --gpu_ids $1 \
 6 |     --save_dir ./results/multi_pred_short \
 7 |     --data_save_path ./results/data.pyt \
 8 |     --lrEnc 2E-4 --lrDec 2E-4 \
 9 |     --model_name multi_pred \
10 |     --train_module multi_pred \
11 |     --kwargs_optim '{"betas": [0.9, 0.999]}' \
12 |     --imdir /root/results/ipp/ipp_17_10_25 \
13 |     --dataProvider DataProvider3Dh5 \
14 |     --saveStateIter 1 --saveProgressIter 1 \
15 |     --channels_pt1 0 2 3 4 5 1 --channels_pt2 0 \
16 |     --batch_size 16 \
17 |     --nlatentdim 128 \
18 |     --nepochs 5 \
19 |     --nepochs_pt2 0 \
20 |     --ndat 32 \
21 |     --overwrite_opts True \
22 | 
--------------------------------------------------------------------------------
/scripts/test_build_control_images.py:
--------------------------------------------------------------------------------
 1 | from integrated_cell.utils.build_control_images import build_control_images
 2 | 
 3 | save_dir = "/allen/aics/modeling/gregj/results/ipp/scp_19_04_10/controls"
 4 | csv_path = "/allen/aics/modeling/gregj/results/ipp/scp_19_04_10/data_jobs_out.csv"
 5 | image_parent = "/allen/aics/modeling/gregj/results/ipp/scp_19_04_10/"
 6 | 
 7 | build_control_images(save_dir, csv_path, image_parent)
 8 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
 1 | from setuptools import find_packages, setup
 2 | 
 3 | exclude_dirs = ["examples", "doc", "scripts"]
 4 | 
 5 | PACKAGES = find_packages(exclude=exclude_dirs)
 6 | 
 7 | setup(
 8 |     name="pytorch_integrated_cell",
 9 |     version="0.1",
10 |     packages=PACKAGES,
11 |     entry_points={
12 |         "console_scripts": ["ic_train_model=integrated_cell.bin.train_model:main"]
13 |     },
14 |     install_requires=[
15 |         "torch==1.2.0",
16 |         "torchvision==0.2.1",
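        # Editorial note: these exact pins presumably mirror the CUDA/PyTorch stack
        # baked into the repository's Dockerfile; loosening them is untested.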
"matplotlib==2.2.2", 18 | "numpy>=1.15.0", 19 | "pandas>=0.23.4", 20 | "pip", 21 | "pillow==5.2.0", 22 | "scikit-image==0.15.0", 23 | "scipy==1.1.0", 24 | "tqdm>=4.28.1", 25 | "natsort==5.3.3", 26 | "ipykernel", 27 | "aicsimageio==3.0.7", 28 | "msgpack<0.6.0,>=0.5.6", 29 | "imageio==2.6.0", 30 | "quilt3==3.1.1", 31 | "seaborn", 32 | "brokenaxes", 33 | "lkaccess", 34 | "aicsfeature==0.2.1", 35 | ], 36 | ) 37 | --------------------------------------------------------------------------------