├── .dockerignore
├── .flake8
├── .github
│   ├── ISSUE_TEMPLATE
│   │   └── manuscript.md
│   └── workflows
│       └── build.yml
├── .gitignore
├── .pre-commit-config.yaml
├── CONTRIBUTING.md
├── Dockerfile
├── README.md
├── doc
│   ├── benchmarks.md
│   ├── getting_started.md
│   ├── images
│   │   ├── model_arch.png
│   │   └── stats_apex_vs_not_apex.png
│   ├── training_examples.md
│   └── training_notes.md
├── download_data.py
├── examples
│   ├── 3D_benchmark_tests
│   │   ├── run_3D.sh
│   │   ├── run_3D_big.sh
│   │   ├── run_3D_small.sh
│   │   ├── run_apex_tests.py
│   │   ├── run_benchmark_tests.py
│   │   ├── run_docker.sh
│   │   └── run_docker_apex.sh
│   ├── latent_variation.ipynb
│   ├── loading_models.ipynb
│   ├── paper_plos_2020
│   │   ├── 1_model_compare_2D.ipynb
│   │   ├── 2_model_compare_3D.ipynb
│   │   ├── 3_ELBO_demo 3D_single_channel.ipynb
│   │   ├── 4_ELBO_ablation.ipynb
│   │   ├── 5_drug_embedding.ipynb
│   │   ├── 6_3D_cell_mage_printing.ipynb
│   │   └── 7_latent_space_visualization.ipynb
│   ├── plot_error.ipynb
│   └── training_scripts
│       ├── ae2D.sh
│       ├── ae3D.sh
│       ├── bvae2D.sh
│       ├── bvae3D.sh
│       ├── cbvae2D.sh
│       ├── cbvae3D.sh
│       ├── cbvaegan2D_target.sh
│       └── cbvaegan2D_target2.sh
├── integrated_cell
│   ├── __init__.py
│   ├── arc_walk.py
│   ├── bin
│   │   ├── __init__.py
│   │   ├── train_model.py
│   │   └── train_multi.py
│   ├── corr_stats.py
│   ├── data_providers
│   │   ├── DataProvider.py
│   │   ├── DataProvider3D_label_free.py
│   │   ├── DataProviderABC.py
│   │   ├── GraphDataProvider.py
│   │   ├── MaskedChannelsDataProvider.py
│   │   ├── MultiScaleDataProvider.py
│   │   ├── PatchDataProvider.py
│   │   ├── ProjectedDataProvider.py
│   │   ├── RefDataProvider.py
│   │   ├── RescaledIntensityDataProvider.py
│   │   ├── RescaledIntensityRefDataProvider.py
│   │   ├── RescaledIntensityTargetDataProvider.py
│   │   ├── TargetDataProvider.py
│   │   └── __init__.py
│   ├── external
│   │   ├── __init__.py
│   │   └── pytorch_fid
│   │       ├── LICENSE
│   │       ├── README.md
│   │       ├── fid_score.py
│   │       └── inception.py
│   ├── imgToProjection.py
│   ├── layers.py
│   ├── losses.py
│   ├── metrics
│   │   ├── __init__.py
│   │   ├── embeddings.py
│   │   ├── embeddings_reference.py
│   │   ├── embeddings_target.py
│   │   ├── features2D.py
│   │   └── precision_recall.py
│   ├── model_utils.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── ae.py
│   │   ├── base_model.py
│   │   ├── bvae.py
│   │   ├── bvae_cond.py
│   │   ├── bvae_cond_apex.py
│   │   ├── cbvaae.py
│   │   ├── cbvae.py
│   │   ├── cbvae2.py
│   │   ├── cbvae2_gan2.py
│   │   ├── cbvae2_gan_apex.py
│   │   ├── cbvae2_gan_model_parallel.py
│   │   ├── cbvae2_gan_ref.py
│   │   ├── cbvae2_progressive.py
│   │   ├── cbvae2_ref.py
│   │   ├── cbvae2_target.py
│   │   ├── cbvae_apex.py
│   │   ├── cbvae_gan.py
│   │   ├── cbvae_multi.py
│   │   ├── cbvaegan_target.py
│   │   └── cbvaegan_target2.py
│   ├── networks
│   │   ├── __init__.py
│   │   ├── ae2D_residual.py
│   │   ├── ae3D_residual.py
│   │   ├── common.py
│   │   ├── cvaegan2D_residual.py
│   │   ├── cvaegan3D_residual.py
│   │   ├── old
│   │   │   ├── aaegan.py
│   │   │   ├── aaegan3D.py
│   │   │   ├── aaegan3D_resid.py
│   │   │   ├── aaegan3Dv2.py
│   │   │   ├── aaegan3Dv3.py
│   │   │   ├── aaegan3Dv4-elu.py
│   │   │   ├── aaegan3Dv4-relu.py
│   │   │   ├── aaegan3Dv4.py
│   │   │   ├── aaegan3Dv5.py
│   │   │   ├── aaegan_256v2.py
│   │   │   ├── aaegan_256v3.py
│   │   │   ├── aaegan_compact.py
│   │   │   ├── aaegan_compact_lrelu.py
│   │   │   ├── aaegan_compact_lrelu2.py
│   │   │   ├── aaegan_v2.py
│   │   │   ├── aaegan_v3.py
│   │   │   ├── waaegan_256v2.py
│   │   │   ├── waaegan_v2.py
│   │   │   └── waaegan_v3_improved.py
│   │   ├── proto
│   │   │   └── __init__.py
│   │   ├── ref_target_autoencoder.py
│   │   ├── vaegan2D_cgan_target.py
│   │   ├── vaegan3D_cgan_target2.py
│   │   └── vaegan3D_cgan_tsm.py
│   ├── param_search.py
│   ├── simplelogger.py
│   ├── tests
│   │   ├── __init__.py
│   │   ├── conftest.py
│   │   ├── resources
│   │   │   ├── img2D.png
│   │   │   ├── img3D.tiff
│   │   │   └── test_data.csv
│   │   ├── test_dataprovider.py
│   │   ├── test_dummy.py
│   │   └── test_kld.py
│   └── utils
│       ├── __init__.py
│       ├── build_control_images.py
│       ├── features.py
│       ├── image_transforms.py
│       ├── imgToProjection.py
│       ├── plots.py
│       ├── reference
│       │   └── __init__.py
│       ├── spectral_norm.py
│       ├── target
│       │   ├── __init__.py
│       │   ├── plots.py
│       │   └── utils.py
│       └── utils.py
├── requirements.txt
├── scripts
│   ├── aaegan_short.sh
│   ├── bvae.sh
│   ├── bvae_short.sh
│   ├── bvaegan_short.sh
│   ├── multi_pred.sh
│   ├── multi_pred_short.sh
│   └── test_build_control_images.py
└── setup.py
/.dockerignore:
--------------------------------------------------------------------------------
1 | data/
2 | examples/
3 | .git/
4 | .*
5 | *.egg-info
6 |
--------------------------------------------------------------------------------
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 |
3 | exclude =
4 |     .git,
5 |     __pycache__,
6 |     docs/source/conf.py,
7 |     build
8 | ignore =
9 |     E501, # max line length
10 |     W503, # breaking before binary operators
11 |     E302 # multi-line empty space / comments bug
12 | builtins =
13 |     _,
14 |     NAME,
15 |     MAINTAINER,
16 |     MAINTAINER_EMAIL,
17 |     DESCRIPTION,
18 |     LONG_DESCRIPTION,
19 |     URL,
20 |     DOWNLOAD_URL,
21 |     LICENSE,
22 |     CLASSIFIERS,
23 |     AUTHOR,
24 |     AUTHOR_EMAIL,
25 |     PLATFORMS,
26 |     VERSION,
27 |     PACKAGES,
28 |     PACKAGE_DATA,
29 |     REQUIRES
30 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/manuscript.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Manuscript
3 | about: Issue to address in manuscript
4 | title: ''
5 | labels: manuscript
6 | assignees: donovanr
7 |
8 | ---
9 |
10 | ## Issue summary
11 |
12 |
13 | ## Details
14 |
15 |
16 | ## TODO
17 | - [ ] code
18 | - [ ] manuscript text
19 | - [ ] figure
20 |
--------------------------------------------------------------------------------
/.github/workflows/build.yml:
--------------------------------------------------------------------------------
1 | name: Build
2 |
3 | on: [push, pull_request]
4 |
5 | jobs:
6 |   build:
7 |     runs-on: ubuntu-latest
8 |     steps:
9 |       - uses: actions/checkout@v1
10 |       - name: Build Docker Image
11 |         run: |
12 |           docker build -t aics/pytorch_integrated_cell -f Dockerfile .
13 |       - name: Run tests
14 |         run: |
15 |           docker run aics/pytorch_integrated_cell pytest integrated_cell/tests/
16 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | ._.DS_Store
3 | **.pyc
4 | **.pyt
5 | .ipynb_checkpoints
6 | .vscode
7 | .bash_history
8 | train_modules/__pycache__/*
9 | *.egg-info/
10 | examples/3D_benchmark_tests/*/*
11 | results/
12 | data/
13 | integrated_cell/networks/old
14 | integrated_cell/networks/proto
15 | integrated_cell/models/proto
16 | examples/paper_plos_2020/images
17 | examples/paper_plos_2020/demo_imgs
18 | examples/paper_plos_2020/*.tiff
19 | examples/training_scripts/*/
20 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 |   - repo: https://github.com/ambv/black
3 |     rev: stable
4 |     hooks:
5 |       - id: black
6 |         language_version: python3.6
7 |   - repo: https://github.com/pre-commit/pre-commit-hooks
8 |     rev: v1.3.0
9 |     hooks:
10 |       - id: flake8
11 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Allen Institute Contribution Agreement
2 |
3 | This document describes the terms under which you may make "Contributions" — which may include without limitation, software additions, revisions, bug fixes, configuration changes, documentation, or any other materials — to any of the projects owned or managed by the Allen Institute. If you have questions about these terms, please contact us at terms@alleninstitute.org.
4 |
5 | You certify that:
6 |
7 | * Your Contributions are either:
8 | 1. Created in whole or in part by you and you have the right to submit them under the designated license (described below); or
9 | 2. Based upon previous work that, to the best of your knowledge, is covered under an appropriate open source license and you have the right under that license to submit that work with modifications, whether created in whole or in part by you, under the designated license; or
10 | 3. Provided directly to you by some other person who certified (1) or (2) and you have not modified them.
11 |
12 | * You are granting your Contributions to the Allen Institute under the terms of the [2-Clause BSD license](https://opensource.org/licenses/BSD-2-Clause) (the “designated license”).
13 |
14 | * You understand and agree that the Allen Institute projects and your Contributions are public and that a record of the Contributions (including all metadata and personal information you submit with them) is maintained indefinitely and may be redistributed consistent with the Allen Institute’s mission and the 2-Clause BSD license.
15 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | # From pytorch compiled from source
2 | FROM nvcr.io/nvidia/pytorch:19.09-py3
3 |
4 | COPY ./ /root/projects/pytorch_integrated_cell
5 |
6 | WORKDIR /root/projects/pytorch_integrated_cell
7 |
8 | RUN pip install -e .
9 |
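10 | # Example usage (the same commands run in .github/workflows/build.yml):
11 | #   docker build -t aics/pytorch_integrated_cell -f Dockerfile .
12 | #   docker run aics/pytorch_integrated_cell pytest integrated_cell/tests/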
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Pytorch 3D Integrated Cell
2 | ===============================
3 |
4 | ![Model architecture](doc/images/model_arch.png)
5 |
6 | Building a 3D Integrated Cell: https://www.biorxiv.org/content/early/2017/12/21/238378
7 |
8 | ### For the 2D manuscript and software:
9 |
10 | **Generative Modeling with Conditional Autoencoders: Building an Integrated Cell**
11 | Manuscript: https://arxiv.org/abs/1705.00092
12 | GitHub: https://github.com/AllenCellModeling/torch_integrated_cell
13 |
14 | ## Todo Items
15 | - GIT
16 | - [x] Remove old big files from git
17 |
18 | - Jupyter notebooks
19 | - [x] Check-in current state to git
20 | - [x] Make sure notebooks all run and can produce figures
21 | - [x] Annotate notebooks (notebook purpose)
22 | - [x] Clear outputs
23 | - [x] Check-in final state to git
24 |
25 | - Data
26 | - [x] Make sure current Quilt data works
27 | - [ ] Check-in manuscript data to Quilt
28 |
29 | - Code
30 | - [x] Check-in current state to git
31 | - [x] Clear unused code
32 | - [ ] Clean up and annotate main functions
33 | - [ ] Check-in final state to git
34 |
35 | - Demos/Docs
36 | - [x] Installation instructions
37 | - [x] Getting Started doc
38 | - [ ] Demos for different training methods
39 | - [ ] Update doc figures
40 |
41 | ## Support
42 |
43 | This code is in active development and is used within our organization. We are currently not supporting this code for external use and are simply releasing the code to the community AS IS. The community is welcome to submit issues, but you should not expect an active response.
44 |
45 |
46 | ## System requirements
47 |
48 | We recommend installation on Linux and an NVIDIA graphics card with 10+ GB of memory (e.g., NVIDIA Titan X Pascal) with the latest drivers installed.
49 |
50 | ## Installation
51 |
52 | Installing on Linux is recommended.
53 |
54 | - Install Python 3.6+ (and Docker, if using installation method (B) below).
55 | - All commands listed below assume the bash shell.
56 |
57 | ### **Installation method (A): In Existing Workspace**
58 | (Optional) Make a fresh conda environment. (This may break some libraries if run inside some of Nvidia's Docker images.)
59 | ```shell
60 | conda create --name pytorch_integrated_cell python=3.7
61 | conda activate pytorch_integrated_cell
62 | ```
63 | Clone and install the repo
64 | ```shell
65 | git clone https://github.com/AllenCellModeling/pytorch_integrated_cell
66 | cd pytorch_integrated_cell
67 | pip install -e .
68 | ```
69 |
70 | If you want to do some development, install the pre-commit hooks:
71 | ```shell
72 | pip install pre-commit
73 | pre-commit install
74 | ```
75 |
76 | (Optional) Clone and install Nvidia Apex for half-precision computation
77 | Please follow the instructions on the Nvidia Apex github page:
78 | https://github.com/NVIDIA/apex
79 |
80 | ### **Installation method (B): Docker**
81 | We build on Nvidia Docker images. In our tests this runs very slightly slower than (A) although your mileage may vary. This comes with Nvidia Apex.
82 | ```shell
83 | git clone https://github.com/AllenCellModeling/pytorch_integrated_cell
84 | cd pytorch_integrated_cell
85 | docker build -t aics/pytorch_integrated_cell -f Dockerfile .
86 | ```
87 |
88 | ## Data
89 | Data can be downloaded via Quilt (the `quilt3` package). The following script will dump the complete 2D and 3D dataset into `./data/`. This may take a long time depending on your connection.
90 | ```shell
91 | python download_data.py
92 | ```
93 | The dataset is about 250 GB.
94 |
95 | ## Training Models
96 | Models are trained via command-line arguments. A typical training call looks something like:
97 | ```shell
98 | ic_train_model \
99 | --gpu_ids 0 \
100 | --model_type ae \
101 | --save_parent ./ \
102 | --lr_enc 2E-4 --lr_dec 2E-4 \
103 | --data_save_path ./data.pyt \
104 | --imdir ./data/ \
105 | --crit_recon integrated_cell.losses.BatchMSELoss \
106 | --kwargs_crit_recon '{}' \
107 | --network_name vaegan2D_cgan \
108 | --kwargs_enc '{"n_classes": 24, "ch_ref": [0, 2], "ch_target": [1], "n_channels": 2, "n_channels_target": 1, "n_latent_dim": 512, "n_ref": 512}' \
109 | --kwargs_enc_optim '{"betas": [0.9, 0.999]}' \
110 | --kwargs_dec '{"n_classes": 24, "n_channels": 2, "n_channels_target": 1, "ch_ref": [0, 2], "ch_target": [1], "n_latent_dim": 512, "n_ref": 512, "output_padding": [1,1], "activation_last": "softplus"}' \
111 | --kwargs_dec_optim '{"betas": [0.9, 0.999]}' \
112 | --kwargs_model '{"kld_reduction": "mean_batch", "objective": "H", "beta": 1E-2}' \
113 | --train_module cbvae2 \
114 | --dataProvider DataProvider \
115 | --kwargs_dp '{"crop_to": [160,96], "return2D": 1, "check_files": 0, "make_controls": 0, "csv_name": "controls/data_plus_controls.csv", "normalize_intensity": "avg_intensity"}' \
116 | --saveStateIter 1 --saveProgressIter 1 \
117 | --channels_pt1 0 1 2 \
118 | --batch_size 64 \
119 | --nepochs 300 \
120 | ```
121 |
122 | This automatically creates a timestamped directory in the current directory `./`.
123 |
124 | For details on how to modify training options, please see [the getting started documentation](doc/getting_started.md)
125 |
126 | ## Loading Models
127 | Models are loaded via the Python API. A typical loading call looks something like:
128 | ```python
129 | from integrated_cell import utils
130 |
131 | model_dir = '/my_parent_directory/model_type/date_time/'
132 | parent_dir = '/my_parent_directory/'
133 |
134 | networks, data_provider, args = utils.load_network_from_dir(model_dir, parent_dir)
135 |
136 | target_enc = networks['enc']
137 | target_dec = networks['dec']
138 |
139 | ```
140 |
141 | `networks` is a dictionary of the model subcomponents.
142 | `data_provider` is an object that contains train, validate, and test data.
143 | `args` is a dictionary of the arguments passed to the model.
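144 |
145 | A minimal usage sketch of these components (note that the dictionary key capitalization depends on the network definition; some networks use `'Enc'`/`'Dec'` instead):
146 | ```python
147 | # encode / decode one test image
148 | x, classes, ref = data_provider.get_sample('test', [0])
149 | z = target_enc(x.cuda())  # for a VAE-style encoder this is a tuple, e.g. (z_mu, z_sigma)
150 | xHat = target_dec(z)      # reconstruction of x
151 | ```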
144 |
145 | For details on how to modify loading options, please see [the getting started documentation](doc/getting_started.md)
146 |
147 |
148 |
149 | ## Project website
150 | Example outputs of this model can be viewed at http://www.allencell.org.
151 |
152 | ## Examples ##
153 | Examples of how to run the code can be found in the [3D benchmarks section](doc/benchmarks.md).
154 |
155 | ## Citation
156 | If you find this code useful in your research, please consider citing the following paper:
157 |
158 | @article {Johnson238378,
159 | author = {Johnson, Gregory R. and Donovan-Maiye, Rory M. and Maleckar, Mary M.},
160 | title = {Building a 3D Integrated Cell},
161 | year = {2017},
162 | doi = {10.1101/238378},
163 | publisher = {Cold Spring Harbor Laboratory},
164 | URL = {https://www.biorxiv.org/content/early/2017/12/21/238378},
165 | eprint = {https://www.biorxiv.org/content/early/2017/12/21/238378.full.pdf},
166 | journal = {bioRxiv}
167 | }
168 |
169 | ## Contact
170 | Gregory Johnson
171 | E-mail: gregj@alleninstitute.org
172 |
173 | ## License
174 | This program is free software: you can redistribute it and/or modify
175 | it under the terms of the GNU General Public License as published by
176 | the Free Software Foundation, either version 3 of the License, or
177 | (at your option) any later version.
178 |
179 | This program is distributed in the hope that it will be useful,
180 | but WITHOUT ANY WARRANTY; without even the implied warranty of
181 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
182 | GNU General Public License for more details.
183 |
184 | You should have received a copy of the GNU General Public License
185 | along with this program. If not, see <https://www.gnu.org/licenses/>.
186 |
--------------------------------------------------------------------------------
/doc/benchmarks.md:
--------------------------------------------------------------------------------
1 | Running 3D Benchmarks
2 | ===============================
3 | ## About
4 | In the directory
5 | `examples/3D_benchmark_tests` there are some scripts to run benchmarks on your system. You can configure which type of model you want to run, which gpus you want to run on, etc.
6 |
7 | ## How to Run
8 |
9 | ### System requirements
10 |
11 | Benchmarks are set to run on a box with 8 GPUs. Instructions to reconfigure your code are at the bottom of this page.
12 |
13 | ### Installation
14 | Install the Integrated Cell code via _both_ method (A) (with Nvidia Apex) and method (B). Download the data too.
15 |
16 | ### Running scripts
17 | From the root directory of this repo, cd into the 3D_benchmark_tests directory, and run `run_benchmark_tests.py`
18 |
19 | ```shell
20 | cd examples/3D_benchmark_tests/
21 | python run_benchmark_tests.py
22 | ```
23 |
24 | This will run the model configuration in `run_3D.sh` on different GPUs, with and without Apex, and with and without Docker. Plots showing the iteration time vs batch size will appear in this directory, e.g.
25 | ![Iteration time vs. batch size, with and without Apex](images/stats_apex_vs_not_apex.png)
26 |
27 | This figure shows the largest model configurations we can run with and without Nvidia Apex.
28 |
29 | ### Changing the configuration
30 | The primary configuration section is in this block in the `run_benchmark_tests.py` file:
31 |
32 | ```python
33 | experiment_dict = {}
34 | experiment_dict["function_call"] = ["bash run_docker.sh", "bash run_3D.sh"]
35 | experiment_dict["trainer_type"] = ["cbvae_apex", "cbvae"]
36 | experiment_dict["gpu_id"] = [
37 | [2],
38 | [2, 3],
39 | [3, 4],
40 | [0, 1, 2, 3],
41 | [2, 3, 4, 5],
42 | [0, 1, 2, 3, 4, 5, 6, 7],
43 | ]
44 | experiment_dict["batch_size"] = [8, 16, 32, 64, 128, 256]
45 | ```
46 |
47 | If you want to change the number or subsets of GPUs to try, change the `"gpu_id"` list, batch size with the `"batch_size"` list, etc.
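48 |
49 | For example, a hypothetical reduced sweep for a 2-GPU machine might look like:
50 |
51 | ```python
52 | experiment_dict["gpu_id"] = [[0], [0, 1]]
53 | experiment_dict["batch_size"] = [8, 16, 32]
54 | ```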
--------------------------------------------------------------------------------
/doc/getting_started.md:
--------------------------------------------------------------------------------
1 | # Getting Started
2 |
3 | The Integrated Cell is a collection of tools for training deep generative models of cell organization.
4 | This document serves as a basic guide for how the code is organized and how to train and interface with trained networks.
5 |
6 | **This code is NOT for production. It is for research. It still needs a lot of work.**
7 |
8 | ## Prerequesites
9 | This guide assumes you have a medium-level handle on Python as well as the PyTorch framework.
10 |
11 | You will need a CUDA-capable GPU with up-to-date drivers installed as well.
12 |
13 | ## Code and Model Organization
14 | The major components to train a model are as follows:
15 | ### 1) A **data provider**
16 | (`integrated_cell/data_providers`)
17 |
18 | These are objects that provide data via a `get_sample` method. The general means by which to interface with them is via this command:
19 |
20 | ```x, classes, ref = data_provider.get_sample()```
21 |
22 | where `x` is a (`batch_size` by `channel_dimension` by `spatial_dimensions`) batch of images, `classes` is a list of integer values corresponding to the class label of `x` (usually corresponding to the so-called "structure channel"), and `ref` is reserved for reference information: sometimes this is the "reference channels", and sometimes it is empty.
23 |
24 | The "base" data provider is in `integrated_cell/data_providers/DataProvider.py`. There are also child data providers that do different types of augmentation, etc.
25 |
26 | ### 2) A **network** to optimize
27 | (`integrated_cell/networks`)
28 |
29 | These are the networks that are being optimized. Each network type (e.g. `integrated_cell/networks/vaegan3D_cgan_target2.py`) generally contains an encoder and a decoder object (`Enc` and `Dec`), and depending on the type, `Enc` may produce multiple outputs (as in a variational autoencoder).
30 | Other objects may be present in the file as well; a decoder discriminator (`DecD`) or encoder discriminator (`EncD`) may be present for adversarial models.
31 |
32 | The constructors of these sub-networks may have parameters that define the shape or attributes of the network. These are typical PyTorch `torch.nn.Module` objects.
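33 |
34 | A sketch of how these sub-networks are called once instantiated (this mirrors `examples/loading_models.ipynb`); for a VAE-type network, `Enc` returns the parameters of the latent distribution:
35 |
36 | ```python
37 | from integrated_cell import utils
38 |
39 | z_mu, z_sigma = enc(x)                   # variational encoder: two outputs
40 | z = utils.reparameterize(z_mu, z_sigma)  # sample z ~ N(z_mu, z_sigma)
41 | x_hat = dec(z)                           # decode back to image space
42 | ```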
33 |
34 | ### 3) A **loss** function
35 | (`integrated_cell/losses.py` or [loss functions from Pytorch](https://pytorch.org/docs/stable/nn.html#loss-functions))
36 |
37 | These are custom losses used for model training.
38 | There may be multiple loss objects for a **network** that may be passed into a **model**.
39 | Typically we use pixel-wise mean squared error, binary cross entropy, or L1-loss for images.
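40 |
41 | For instance, `integrated_cell.losses.BatchMSELoss` (used in the README's training call) is a pixel-wise MSE. A hypothetical sketch of such a batch-reduced criterion (not the repo's exact implementation):
42 |
43 | ```python
44 | import torch
45 |
46 |
47 | class MyBatchMSELoss(torch.nn.Module):
48 |     # hypothetical: squared error summed per image, then averaged over the batch
49 |     def forward(self, x_hat, x):
50 |         per_image = ((x_hat - x) ** 2).view(x.shape[0], -1).sum(dim=1)
51 |         return per_image.mean()
52 | ```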
40 |
41 | ### 4) A **model**-type trainer
42 | (`integrated_cell/models`)
43 |
44 | These are objects that train specific types of models.
45 | They intake all of the above components and perform backprop steps until some number of iterations or epochs has been completed. They also control saving model state, running checks on validation data, etc.
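46 |
47 | Schematically, one training iteration looks something like this (a hypothetical sketch, not the actual code in `integrated_cell/models/base_model.py`):
48 |
49 | ```python
50 | x, classes, ref = data_provider.get_sample("train")
51 |
52 | opt_enc.zero_grad()
53 | opt_dec.zero_grad()
54 |
55 | x_hat = dec(enc(x.cuda()))          # forward pass
56 | loss = crit_recon(x_hat, x.cuda())  # reconstruction criterion from above
57 | loss.backward()
58 |
59 | opt_enc.step()
60 | opt_dec.step()
61 | ```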
46 |
47 | ## Model training
48 | Kicking off model training is usually performed via the command-line with `ic_train_model`.
49 | This command allows you to pass in ALL of the parameters necessary to define a training session.
50 |
51 | A command looks like:
52 | ```shell
53 | ic_train_model \
54 | --gpu_ids 0 \
55 | --model_type ae \
56 | --save_parent ./ \
57 | --lr_enc 2E-4 --lr_dec 2E-4 \
58 | --data_save_path ./data_target.pyt \
59 | --crit_recon integrated_cell.losses.BatchMSELoss \
60 | --kwargs_crit_recon '{}' \
61 | --network_name vaegan3D_cgan_target2 \
62 | --kwargs_enc '{"n_classes": 24, "n_latent_dim": 512}' \
63 | --kwargs_enc_optim '{"betas": [0.9, 0.999]}' \
64 | --kwargs_dec '{"n_classes": 24, "n_latent_dim": 512, "proj_z": 0, "proj_z_ref_to_target": 0, "activation_last": "softplus"}' \
65 | --kwargs_dec_optim '{"betas": [0.9, 0.999]}' \
66 | --train_module cbvae2_target \
67 | --kwargs_model '{"beta_min": 1E-6, "beta_start": -1, "beta_step": 3E-5, "objective": "A", "kld_reduction": "batch"}' \
68 | --imdir /raid/shared/ipp/scp_19_04_10/ \
69 | --dataProvider TargetDataProvider \
70 | --kwargs_dp '{"crop_to": [160, 96, 64], "return2D": 0, "check_files": 0, "make_controls": 0, "csv_name": "controls/data_plus_controls.csv"}' \
71 | --saveStateIter 1 --saveProgressIter 1 \
72 | --channels_pt1 0 1 2 \
73 | --batch_size 32 \
74 | --nepochs 300 \
75 | ```
76 |
77 | You can see that there are a lot of settings that may be configured.
78 | This command is usually run from a `.sh` file so that arguments can be passed in via the command line.
79 |
80 | It should be noted that the parameter `--save_parent` is the directory in which the model save directory is constructed.
81 | Running the above command will create a date-timestamped directory for that particular run.
82 | If you would like to specify the explicit save directory, use the `--save_dir` parameter instead.
83 |
84 | ## Viewing model status
85 | Viewing model status is usually performed with the jupyter notebook `examples/plot_error.ipynb`.
86 | It is meant to be interactive.
87 | There is a variable `dirs_to_search` that allows one to specify the specific model directories to view.
88 |
89 | For each model directory, the current checkpoint results are displayed. These generally include a print-off of images for train and validation, as well as loss curves on the train set and latent space embeddings.
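90 |
91 | For example, to point the notebook at your own runs (hypothetical paths):
92 |
93 | ```python
94 | dirs_to_search = [
95 |     "./training_scripts/*/",    # runs from the example scripts
96 |     "/my_parent_directory/*/",  # any other --save_parent directory
97 | ]
98 | ```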
90 |
91 | ## Loading and using a trained model
92 |
93 | To programmatically access a model, the function `integrated_cell/utils/utils.py::load_network_from_dir` may be used. Typical usage is as follows:
94 |
95 | ```python
96 | from integrated_cell import utils
97 |
98 | model_dir = "my/model/save/dir/"
99 | networks, data_provider, args = utils.load_network_from_dir(model_dir)
100 |
101 | encoder = networks['Enc']
102 | decoder = networks['Dec']
103 | ```
104 |
105 | If the model save directory has been moved, or a relative path was used for `--data_save_path` in the `ic_train_model` call, the relative components may be overridden with the parameter `parent_dir`.
106 | If you would like to load a specific checkpoint, use the parameter `suffix`. An example is as follows:
107 |
108 | ```python
109 | model_dir = "/my/model/save/dir/"
110 | parent_dir = "/my/model"
111 | suffix = "_93300" # this is the iteration at which the model was saved
112 |
113 | networks, data_provider, args = utils.load_network_from_dir(model_dir, parent_dir=parent_dir, suffix=suffix)
114 | ```
115 |
116 | ## More:
117 | [Training different types of models](./training_examples.md)
--------------------------------------------------------------------------------
/doc/images/model_arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AllenCellModeling/pytorch_integrated_cell/8a83fc6f8dc79037f4b681d9d7ef0bc5b91e9948/doc/images/model_arch.png
--------------------------------------------------------------------------------
/doc/images/stats_apex_vs_not_apex.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AllenCellModeling/pytorch_integrated_cell/8a83fc6f8dc79037f4b681d9d7ef0bc5b91e9948/doc/images/stats_apex_vs_not_apex.png
--------------------------------------------------------------------------------
/doc/training_examples.md:
--------------------------------------------------------------------------------
1 | # Training Examples
2 |
3 | The integrated cell code base supports training of many types of models.
4 | We use a simple taxonomy to identify how to load different network components:
5 |
6 | Autoencoders (`ae`) contain an encoder and a decoder. Examples of this are:
7 | - vanilla autoencoders
8 | - [variational autoencoders](https://arxiv.org/abs/1312.6114)
9 | - [conditional variational autoencoders](https://papers.nips.cc/paper/5775-learning-structured-output-representation-using-deep-conditional-generative-models)
10 |
11 | Adversarial Autoencoders (`aae`) contain an adversary on the encoder in addition to the decoder.
12 | - [adversarial autoencoders](https://arxiv.org/abs/1511.05644)
13 |
14 | Autoencoder GANs (`aegan`) contain a GAN attached to the decoder.
15 | - [Autoencoder GANs](https://arxiv.org/abs/1512.09300)
16 |
17 | Adversarial Autoencoder GANs (`aaegan`) contain one adversary attached to the encoder and a different adversary attached to the decoder.
18 |
19 | Examples of how to configure training for these model types can be found in `examples/training_scripts/`: e.g. `ae2D.sh` and `bvae2D.sh` for autoencoders, and `cbvaegan2D_target.sh` for autoencoder GANs.
20 |
21 |
22 |
--------------------------------------------------------------------------------
/doc/training_notes.md:
--------------------------------------------------------------------------------
1 | # Training Tips and Thoughts
2 |
3 | ### Compute
4 |
5 | You generally need a machine with a GPU.
6 | `dgx-aics-dcp-001.corp.alleninstitute.org` is a good place to start.
7 |
8 | ### General
9 | - For short training sessions (e.g. for debugging), use the CLI flag `--ndat`. This overrides the number of data points in the `data_provider` object for shorter epoch times. When used, it's generally set to 2x the batch size (see the sketch at the bottom of this page).
10 |
11 | - Launch jobs in a screen session in case you lose connection to the host machine
12 |
13 | - From `integrated_cell.models.base_model.Model`, an "epoch" is defined as `np.ceil(len(data_provider) / data_provider.batch_size)` training-step iterations. `data_provider.get_sample()` controls how samples are returned (usually sampling with replacement).
14 |
15 | - To specify a specific directory to save in, use the CLI command `--save_dir`. To not have to specify a new dir every run, use the CLI command `--save_parent`, and the code will dump everything in a date-timestamped directory.
16 |
17 | ### Recommended Usage
18 | - `integrated_cell` installed in a fresh Conda env.
19 | - Run training from a screen session that is running on the host machine.
20 | - Interrogate training models with `examples/plot_error.ipynb`.
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
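29 | A quick sketch of the epoch arithmetic above (hypothetical numbers):
30 |
31 | ```python
32 | import numpy as np
33 |
34 | # with --ndat 64 and a batch size of 32, a debug "epoch" is just 2 iterations
35 | iters_per_epoch = int(np.ceil(64 / 32))
36 | print(iters_per_epoch)  # 2
37 | ```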
--------------------------------------------------------------------------------
/download_data.py:
--------------------------------------------------------------------------------
1 | import quilt3
2 |
3 | pkg = quilt3.Package.browse(
4 | "aics/pipeline_integrated_single_cell",
5 | registry="s3://allencell",
6 | top_hash="7fd488f05ec41968607c7263cb13b3e70812972a24e832ef6f72195bdd35f1b2",
7 | )
8 |
9 | pkg.fetch("./data/")
10 |
--------------------------------------------------------------------------------
/examples/3D_benchmark_tests/run_3D.sh:
--------------------------------------------------------------------------------
1 | gpu_ids=$1
2 | save_dir=$2
3 | trainer_type=$3
4 | batch_size=$4
5 | ndat=$5
6 |
7 | image_dir=../../data/
8 | csv_name=metadata.csv
9 |
10 |
11 | ic_train_model \
12 | --gpu_ids $gpu_ids \
13 | --model_type ae \
14 | --save_dir $save_dir \
15 | --lr_enc 2E-4 --lr_dec 2E-4 \
16 | --data_save_path $save_dir/data.pyt \
17 | --crit_recon integrated_cell.losses.BatchMSELoss \
18 | --kwargs_crit_recon '{}' \
19 | --network_name vaegan3D_cgan_p \
20 | --kwargs_enc '{"n_classes": 24, "n_channels": 2, "n_channels_target": 1, "n_latent_dim": 512, "n_ref": 512}' \
21 | --kwargs_enc_optim '{"betas": [0.9, 0.999]}' \
22 | --kwargs_dec '{"n_classes": 24, "n_channels": 2, "n_channels_target": 1, "n_latent_dim": 512, "n_ref": 512, "proj_z": 0, "proj_z_ref_to_target": 0, "activation_last": "softplus"}' \
23 | --kwargs_dec_optim '{"betas": [0.9, 0.999]}' \
24 | --kwargs_model '{"beta": 1, "objective": "H", "save_state_iter": 1E9, "save_progress_iter": 1E9}' \
25 | --train_module $trainer_type \
26 | --imdir $image_dir \
27 | --dataProvider DataProvider \
28 | --kwargs_dp '{"crop_to": [160, 96, 64], "check_files": 0, "csv_name": "'$csv_name'"}' \
29 | --saveStateIter 1 --saveProgressIter 1 \
30 | --channels_pt1 0 1 2 \
31 | --batch_size $batch_size \
32 | --nepochs 1 \
33 | --ndat $ndat
34 |
--------------------------------------------------------------------------------
/examples/3D_benchmark_tests/run_3D_big.sh:
--------------------------------------------------------------------------------
1 | gpu_ids=$1
2 | save_dir=$2
3 | opt_level=$3
4 | batch_size=16
5 |
6 | # image_dir=../../data/
7 | # csv_name=metadata.csv
8 |
9 | image_dir=/raid/shared/ipp/scp_19_04_10/
10 | csv_name=controls/data_plus_controls.csv
11 |
12 | ic_train_model \
13 | --gpu_ids $gpu_ids \
14 | --model_type ae \
15 | --save_dir $save_dir \
16 | --lr_enc 2E-4 --lr_dec 2E-4 \
17 | --data_save_path $save_dir/data.pyt \
18 | --crit_recon torch.nn.MSELoss \
19 | --kwargs_crit_recon '{"reduction": "sum"}' \
20 | --network_name vaegan3D_cgan_p \
21 | --kwargs_enc '{"n_classes": 24, "n_channels": 2, "n_channels_target": 1, "n_latent_dim": 512, "n_ref": 512}' \
22 | --kwargs_enc_optim '{"betas": [0.9, 0.999]}' \
23 | --kwargs_dec '{"n_classes": 24, "n_channels": 2, "n_channels_target": 1, "n_latent_dim": 512, "n_ref": 512, "proj_z": 0, "proj_z_ref_to_target": 0, "activation_last": "softplus"}' \
24 | --kwargs_dec_optim '{"betas": [0.9, 0.999]}' \
25 | --kwargs_model '{"beta": 1, "objective": "H", "save_state_iter": 1, "save_progress_iter": 1, "opt_level": "'$opt_level'"}' \
26 | --train_module cbvae_apex \
27 | --imdir $image_dir \
28 | --dataProvider DataProvider \
29 | --kwargs_dp '{"crop_to": [160, 96, 64], "check_files": 0, "csv_name": "'$csv_name'"}' \
30 | --saveStateIter 1 --saveProgressIter 1 \
31 | --channels_pt1 0 1 2 \
32 | --batch_size $batch_size \
33 | --nepochs 3 \
34 |
--------------------------------------------------------------------------------
/examples/3D_benchmark_tests/run_3D_small.sh:
--------------------------------------------------------------------------------
1 | gpu_ids=$1
2 | save_dir=$2
3 | opt_level=$3
4 | batch_size=16
5 |
6 | # image_dir=../../data/
7 | # csv_name=metadata.csv
8 |
9 | image_dir=/raid/shared/ipp/scp_19_04_10/
10 | csv_name=controls/data_plus_controls.csv
11 |
12 | ic_train_model \
13 | --gpu_ids $gpu_ids \
14 | --model_type ae \
15 | --save_dir $save_dir \
16 | --lr_enc 2E-4 --lr_dec 2E-4 \
17 | --data_save_path $save_dir/data.pyt \
18 | --crit_recon torch.nn.MSELoss \
19 | --kwargs_crit_recon '{"reduction": "sum"}' \
20 | --network_name vaegan3D_cgan_p \
21 | --kwargs_enc '{"n_classes": 24, "n_channels": 2, "n_channels_target": 1, "n_latent_dim": 512, "n_ref": 512, "conv_channels_list": [32, 64, 128]}' \
22 | --kwargs_enc_optim '{"betas": [0.9, 0.999]}' \
23 | --kwargs_dec '{"n_classes": 24, "n_channels": 2, "n_channels_target": 1, "n_latent_dim": 512, "n_ref": 512, "proj_z": 0, "proj_z_ref_to_target": 0, "activation_last": "softplus", "conv_channels_list": [128, 64, 32]}' \
24 | --kwargs_dec_optim '{"betas": [0.9, 0.999]}' \
25 | --kwargs_model '{"beta": 1, "objective": "H", "save_state_iter": 1, "save_progress_iter": 1, "opt_level": "'$opt_level'"}' \
26 | --train_module cbvae_apex \
27 | --imdir $image_dir \
28 | --dataProvider DataProvider \
29 | --kwargs_dp '{"rescale_to": 0.25, "crop_to": [40, 24, 16], "check_files": 0, "csv_name": "'$csv_name'"}' \
30 | --saveStateIter 1 --saveProgressIter 1 \
31 | --channels_pt1 0 1 2 \
32 | --batch_size $batch_size \
33 | --nepochs 5 \
34 |
--------------------------------------------------------------------------------
/examples/3D_benchmark_tests/run_docker.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | docker run --runtime=nvidia \
4 | -v $PWD/../../:$PWD/../../ \
5 | -v $PWD/../../:/root/projects/pytorch_integrated_cell \
6 | -v /raid/shared:/raid/shared \
7 | aics/pytorch_integrated_cell \
8 | /bin/bash -c " cd $PWD; bash run_3D.sh '$1' $2 $3 $4 $5"
9 |
--------------------------------------------------------------------------------
/examples/3D_benchmark_tests/run_docker_apex.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | docker run --detach --runtime=nvidia \
4 | -v $PWD/../../:$PWD/../../ \
5 | -v $PWD/../../:/root/projects/pytorch_integrated_cell \
6 | -v /raid/shared:/raid/shared \
7 | aics/pytorch_integrated_cell \
8 | /bin/bash -c " cd $PWD; bash $1 '$2' $3 '$4'"
9 |
--------------------------------------------------------------------------------
/examples/loading_models.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "## Loading Models Demo\n",
8 | "\n",
9 | "This is a demo for how to load and use serialized models from a directory. \n",
10 | "\n",
11 | "Run a model in the `./training_scripts` directory, then use this demo. Here we assume you ran the `ae2D.sh` "
12 | ]
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": null,
17 | "metadata": {},
18 | "outputs": [],
19 | "source": [
20 | "import os\n",
21 | "from integrated_cell import utils\n",
22 | "\n",
23 | "#If you need to, use this to change the GPU with which to load a model\n",
24 | "gpu_ids = [0]\n",
25 | "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \",\".join([str(ID) for ID in gpu_ids])\n",
26 | "\n",
27 | "#Specify the model directory to use\n",
28 | "model_dir = \"./training_scripts/ae2D/\"\n",
29 | "\n",
30 | "#Load the model. These variables shouldn't change as a function of model type\n",
31 | "networks, data_provider, args = utils.load_network_from_dir(model_dir)\n",
32 | "\n",
33 | "encoder = networks['enc'].cuda()\n",
34 | "decoder = networks['dec'].cuda()\n",
35 | "\n",
36 | "encoder.train(False)\n",
37 | "decoder.train(False)\n",
38 | "\n",
39 | "#If you're not clear on how to use the network, then the best thing to do is to training model and check out how it is used:\n",
40 | "f\"integrated_cell/models/{args['train_module']}.py\""
41 | ]
42 | },
43 | {
44 | "cell_type": "code",
45 | "execution_count": null,
46 | "metadata": {},
47 | "outputs": [],
48 | "source": [
49 | "from integrated_cell.utils.plots import imshow\n",
50 | "import integrated_cell.utils as utils\n",
51 | "import torch \n",
52 | "import matplotlib\n",
53 | "%matplotlib inline\n",
54 | "\n",
55 | "x, class_labels, ref = data_provider.get_sample('test', [0])\n",
56 | "\n",
57 | "x = x.cuda()\n",
58 | "class_labels = class_labels.cuda().long()\n",
59 | "class_labels_onehot = utils.index_to_onehot(class_labels, data_provider.get_n_classes())\n",
60 | "ref = ref.cuda()\n",
61 | "\n",
62 | "# Typical autoencoder\n",
63 | "if args['train_module'] == 'ae':\n",
64 | " \n",
65 | " # z is the latent representation of the image x\n",
66 | " z = encoder(x)\n",
67 | " xHat = decoder(z)\n",
68 | "\n",
69 | "# Beta variational autoencoder\n",
70 | "elif args['train_module'] == 'bvae':\n",
71 | "\n",
72 | " # z is the latent representation of the image x\n",
73 | " z_mu, z_sigma = encoder(x)\n",
74 | " \n",
75 | " # here we can either the \"average\" by passing z_mu into the decoder or\n",
76 | " # we can resample from N(z_mu, z_sigma)\n",
77 | " z_sampled = utils.reparameterize(z_mu, z_sigma)\n",
78 | " xHat = decoder(z_sampled)\n",
79 | " \n",
80 | "# Conditional beta variational autoencoders, sometimes with advarsarial loss\n",
81 | "elif args['train_module'] in ['cbvae2_target', 'cbvaegan_target2', 'cbvaegan_target']:\n",
82 | " # These are the conditional models: We provide the class labels and reference structures to both the encoder and decoder\n",
83 | " # and the model (hopefully) doesn't \n",
84 | " \n",
85 | " # z is the latent representation of the image x\n",
86 | " z_mu, z_sigma = encoder(x, ref, class_labels_onehot)\n",
87 | " \n",
88 | " # here we can either the \"average\" by passing z_mu into the decoder or\n",
89 | " # we can resample from N(z_mu, z_sigma)\n",
90 | " z_sampled = utils.reparameterize(z_mu, z_sigma)\n",
91 | " \n",
92 | " xHat = decoder(z_sampled, ref, class_labels_onehot)\n",
93 | " \n",
94 | "im_out = torch.cat([x, xHat], axis = 3)\n",
95 | "imshow(im_out)\n"
96 | ]
97 | },
98 | {
99 | "cell_type": "code",
100 | "execution_count": null,
101 | "metadata": {},
102 | "outputs": [],
103 | "source": [
104 | "\n"
105 | ]
106 | }
107 | ],
108 | "metadata": {
109 | "kernelspec": {
110 | "display_name": "Python [conda env:pytorch_integrated_cell]",
111 | "language": "python",
112 | "name": "conda-env-pytorch_integrated_cell-py"
113 | },
114 | "language_info": {
115 | "codemirror_mode": {
116 | "name": "ipython",
117 | "version": 3
118 | },
119 | "file_extension": ".py",
120 | "mimetype": "text/x-python",
121 | "name": "python",
122 | "nbconvert_exporter": "python",
123 | "pygments_lexer": "ipython3",
124 | "version": "3.6.8"
125 | }
126 | },
127 | "nbformat": 4,
128 | "nbformat_minor": 4
129 | }
130 |
--------------------------------------------------------------------------------
/examples/plot_error.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# This is the main notebook for interrogating the results of models as they train"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": null,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "## import integrated_cell.SimpleLogger as SimpleLogger\n",
17 | "import torch\n",
18 | "from IPython.core.display import display\n",
19 | "import pickle\n",
20 | "import os\n",
21 | "import PIL.Image\n",
22 | "import numpy as np\n",
23 | "import scipy.misc as misc\n",
24 | "import glob\n",
25 | "import json\n",
26 | "from natsort import natsorted\n",
27 | "\n",
28 | "%matplotlib inline\n",
29 | "\n",
30 | "class RenamingUnpickler(pickle.Unpickler):\n",
31 | " def find_class(self, module, name):\n",
32 | " if module == 'integrated_cell.SimpleLogger':\n",
33 | " module = 'integrated_cell.simplelogger'\n",
34 | " return super().find_class(module, name)\n",
35 | "\n",
36 | "\n",
37 | "import warnings\n",
38 | "warnings.filterwarnings(\"ignore\")\n",
39 | "\n",
40 | "dirs_to_search = [\n",
41 | " \n",
42 | "# '/allen/aics/modeling/gregj/results/integrated_cell/test_cbvae_3D_avg_inten/*',\n",
43 | "# '/allen/aics/modeling/gregj/results/integrated_cell/test_cbvae_beta_ref/*',\n",
44 | " './training_scripts/*/'\n",
45 | " \n",
46 | " ]\n",
47 | "\n",
48 | "model_dirs = list()\n",
49 | "\n",
50 | "for my_dir in dirs_to_search:\n",
51 | " model_dirs += glob.glob(my_dir)\n",
52 | "\n",
53 | "model_dirs = natsorted(model_dirs)[::-1] \n",
54 | " \n",
55 | "\n",
56 | "def show_dir(model_dir): \n",
57 | " logger_file = '{0}/logger_tmp.pkl'.format(model_dir)\n",
58 | " \n",
59 | " if not os.path.exists(logger_file):\n",
60 | " print('Could not find logger at ' + logger_file)\n",
61 | " return\n",
62 | "\n",
63 | " print(model_dir)\n",
64 | " \n",
65 | " try:\n",
66 | " logger = RenamingUnpickler(open( logger_file, \"rb\" )).load()\n",
67 | "# logger = pickle.load( open( logger_file, \"rb\" ) )\n",
68 | " epoch = max(logger.log['epoch'])\n",
69 | " except:\n",
70 | " pass\n",
71 | " \n",
72 | " try:\n",
73 | " opt = pickle.load( open( '{0}/opt.pkl'.format(model_dir), \"rb\" ))\n",
74 | " except:\n",
75 | " opt = json.load(open( '{0}/../args.json'.format(model_dir), \"rb\" ))\n",
76 | "\n",
77 | " print(opt)\n",
78 | " print('Epoch: ' + str(epoch))\n",
79 | "\n",
80 | " im_progress_path = '{0}/progress_{1}.png'.format(model_dir, int(epoch))\n",
81 | " im_progress = misc.imread(im_progress_path)\n",
82 | "\n",
83 | " im_history_path = '{0}/history.png'.format(model_dir)\n",
84 | " im_history = misc.imread(im_history_path) \n",
85 | "\n",
86 | " im_history_short_path = '{0}/history_short.png'.format(model_dir)\n",
87 | " im_history_short = misc.imread(im_history_short_path) \n",
88 | "\n",
89 | " im_embedding_path = '{0}/embedding.png'.format(model_dir)\n",
90 | " im_embedding = misc.imread(im_embedding_path) \n",
91 | "\n",
92 | " display(PIL.Image.fromarray(im_progress))\n",
93 | " display(PIL.Image.fromarray(np.concatenate((im_history, im_history_short, im_embedding), 1)))\n",
94 | " print(' ')\n",
95 | " \n",
96 | "for model_dir in model_dirs:\n",
97 | " \n",
98 | " sub_dirs = glob.glob(os.path.join(model_dir, 'ref_model/'))\n",
99 | " \n",
100 | " for sub_dir in sub_dirs:\n",
101 | " try:\n",
102 | " show_dir(sub_dir)\n",
103 | " except:\n",
104 | " print('could not load ' + model_dir)\n",
105 | " "
106 | ]
107 | }
108 | ],
109 | "metadata": {
110 | "kernelspec": {
111 | "display_name": "Python [conda env:pytorch_integrated_cell]",
112 | "language": "python",
113 | "name": "conda-env-pytorch_integrated_cell-py"
114 | },
115 | "language_info": {
116 | "codemirror_mode": {
117 | "name": "ipython",
118 | "version": 3
119 | },
120 | "file_extension": ".py",
121 | "mimetype": "text/x-python",
122 | "name": "python",
123 | "nbconvert_exporter": "python",
124 | "pygments_lexer": "ipython3",
125 | "version": "3.6.8"
126 | }
127 | },
128 | "nbformat": 4,
129 | "nbformat_minor": 4
130 | }
131 |
--------------------------------------------------------------------------------
/examples/training_scripts/ae2D.sh:
--------------------------------------------------------------------------------
1 | ic_train_model \
2 | --gpu_ids $1 \
3 | --model_type ae \
4 | --save_dir $PWD/ae2D \
5 | --lr_enc 2E-4 --lr_dec 2E-4 \
6 | --data_save_path $PWD/ae2D/data_ae.pyt \
7 | --crit_recon integrated_cell.losses.BatchMSELoss \
8 | --kwargs_crit_recon '{}' \
9 | --network_name ae2D_residual \
10 | --kwargs_enc '{"n_latent_dim": 512, "n_ch": 3}' \
11 | --kwargs_enc_optim '{"betas": [0.9, 0.999]}' \
12 | --kwargs_dec '{"n_latent_dim": 512, "activation_last": "softplus", "n_ch": 3}' \
13 | --kwargs_dec_optim '{"betas": [0.9, 0.999]}' \
14 | --kwargs_model '{}' \
15 | --train_module ae \
16 | --imdir $PWD/../../data/ \
17 | --dataProvider DataProvider \
18 | --kwargs_dp '{"crop_to": [160, 96], "return2D": 1, "check_files": 0, "csv_name": "metadata.csv"}' \
19 | --saveStateIter 1 --saveProgressIter 1 \
20 | --channels_pt1 0 1 2 \
21 | --batch_size 32 \
22 | --nepochs 300 \
23 |
--------------------------------------------------------------------------------
/examples/training_scripts/ae3D.sh:
--------------------------------------------------------------------------------
1 | ic_train_model \
2 | --gpu_ids $1 \
3 | --model_type ae \
4 | --save_dir $PWD/ae3D \
5 | --lr_enc 2E-4 --lr_dec 2E-4 \
6 | --data_save_path $PWD/ae3D/data_ae.pyt \
7 | --crit_recon integrated_cell.losses.BatchMSELoss \
8 | --kwargs_crit_recon '{}' \
9 | --network_name ae3D_residual \
10 | --kwargs_enc '{"n_latent_dim": 512, "n_ch": 3}' \
11 | --kwargs_enc_optim '{"betas": [0.9, 0.999]}' \
12 | --kwargs_dec '{"n_latent_dim": 512, "activation_last": "softplus", "n_ch": 3}' \
13 | --kwargs_dec_optim '{"betas": [0.9, 0.999]}' \
14 | --kwargs_model '{}' \
15 | --train_module ae \
16 | --imdir $PWD/../../data/ \
17 | --dataProvider DataProvider \
18 | --kwargs_dp '{"crop_to": [160, 96, 64], "return2D": 0, "check_files": 0, "csv_name": "metadata.csv"}' \
19 | --saveStateIter 1 --saveProgressIter 1 \
20 | --channels_pt1 0 1 2 \
21 | --batch_size 32 \
22 | --nepochs 300 \
23 |
--------------------------------------------------------------------------------
/examples/training_scripts/bvae2D.sh:
--------------------------------------------------------------------------------
1 | ic_train_model \
2 | --gpu_ids $1 \
3 | --model_type ae \
4 | --save_dir $PWD/bvae2D \
5 | --lr_enc 2E-4 --lr_dec 2E-4 \
6 | --data_save_path $PWD/bvae2D/data_ae.pyt \
7 | --crit_recon integrated_cell.losses.BatchMSELoss \
8 | --kwargs_crit_recon '{}' \
9 | --network_name cvaegan2D_residual \
10 | --kwargs_enc '{"n_latent_dim": 512, "n_ch_target": 3, "n_ch_ref": 0, "n_classes": 0}' \
11 | --kwargs_enc_optim '{"betas": [0.9, 0.999]}' \
12 | --kwargs_dec '{"n_latent_dim": 512, "activation_last": "softplus", "n_ch_target": 3, "n_ch_ref": 0, "n_classes": 0}' \
13 | --kwargs_dec_optim '{"betas": [0.9, 0.999]}' \
14 | --kwargs_model '{"beta": 1}' \
15 | --train_module bvae \
16 | --imdir $PWD/../../data/ \
17 | --dataProvider DataProvider \
18 | --kwargs_dp '{"crop_to": [160, 96], "return2D": 1, "check_files": 0, "csv_name": "metadata.csv"}' \
19 | --saveStateIter 1 --saveProgressIter 1 \
20 | --channels_pt1 0 1 2 \
21 | --batch_size 32 \
22 | --nepochs 300 \
23 |
--------------------------------------------------------------------------------
/examples/training_scripts/bvae3D.sh:
--------------------------------------------------------------------------------
1 | ic_train_model \
2 | --gpu_ids $1 \
3 | --model_type ae \
4 | --save_dir $PWD/bvae3D \
5 | --lr_enc 2E-4 --lr_dec 2E-4 \
6 | --data_save_path $PWD/bvae3D/data_ae.pyt \
7 | --crit_recon integrated_cell.losses.BatchMSELoss \
8 | --kwargs_crit_recon '{}' \
9 | --network_name cvaegan3D_residual \
10 | --kwargs_enc '{"n_latent_dim": 512, "n_ch_target": 3, "n_ch_ref": 0, "n_classes": 0}' \
11 | --kwargs_enc_optim '{"betas": [0.9, 0.999]}' \
12 | --kwargs_dec '{"n_latent_dim": 512, "activation_last": "softplus", "n_ch_target": 3, "n_ch_ref": 0, "n_classes": 0}' \
13 | --kwargs_dec_optim '{"betas": [0.9, 0.999]}' \
14 | --kwargs_model '{"beta": 1}' \
15 | --train_module bvae \
16 | --imdir $PWD/../../data/ \
17 | --dataProvider DataProvider \
18 | --kwargs_dp '{"crop_to": [160, 96, 64], "return2D": 0, "check_files": 0, "csv_name": "metadata.csv"}' \
19 | --saveStateIter 1 --saveProgressIter 1 \
20 | --channels_pt1 0 1 2 \
21 | --batch_size 32 \
22 | --nepochs 300 \
23 |
--------------------------------------------------------------------------------
/examples/training_scripts/cbvae2D.sh:
--------------------------------------------------------------------------------
1 | ic_train_model \
2 | --gpu_ids $1 \
3 | --model_type ae \
4 | --save_dir $PWD/cbvae2D \
5 | --lr_enc 2E-4 --lr_dec 2E-4 \
6 | --data_save_path $PWD/cbvae2D/data.pyt \
7 | --crit_recon integrated_cell.losses.BatchMSELoss \
8 | --kwargs_crit_recon '{}' \
9 | --network_name cvaegan2D_residual \
10 | --kwargs_enc '{"n_latent_dim": 512, "n_ch_target": 1, "n_ch_ref": 2, "n_classes": 24}' \
11 | --kwargs_enc_optim '{"betas": [0.9, 0.999]}' \
12 | --kwargs_dec '{"n_latent_dim": 512, "activation_last": "softplus", "n_ch_target": 1, "n_ch_ref": 2, "n_classes": 24}' \
13 | --kwargs_dec_optim '{"betas": [0.9, 0.999]}' \
14 | --kwargs_model '{"beta": 1}' \
15 | --train_module cbvae2_target \
16 | --imdir ../../data/ \
17 | --dataProvider TargetDataProvider \
18 | --kwargs_dp '{"crop_to": [160, 96], "return2D": 1, "check_files": 0, "csv_name": "metadata.csv"}' \
19 | --saveStateIter 1 --saveProgressIter 1 \
20 | --channels_pt1 0 1 2 \
21 | --batch_size 32 \
22 | --nepochs 300 \
23 |
--------------------------------------------------------------------------------
/examples/training_scripts/cbvae3D.sh:
--------------------------------------------------------------------------------
1 | ic_train_model \
2 | --gpu_ids $1 \
3 | --model_type ae \
4 | --save_dir $PWD/cbvae3D \
5 | --lr_enc 2E-4 --lr_dec 2E-4 \
6 | --data_save_path $PWD/cbvae3D/data.pyt \
7 | --crit_recon integrated_cell.losses.BatchMSELoss \
8 | --kwargs_crit_recon '{}' \
9 | --network_name cvaegan3D_residual \
10 | --kwargs_enc '{"n_latent_dim": 512, "n_ch_target": 1, "n_ch_ref": 2, "n_classes": 24}' \
11 | --kwargs_enc_optim '{"betas": [0.9, 0.999]}' \
12 | --kwargs_dec '{"n_latent_dim": 512, "activation_last": "softplus", "n_ch_target": 1, "n_ch_ref": 2, "n_classes": 24}' \
13 | --kwargs_dec_optim '{"betas": [0.9, 0.999]}' \
14 | --kwargs_model '{"beta": 1}' \
15 | --train_module cbvae2_target \
16 | --imdir $PWD/../../data/ \
17 | --dataProvider TargetDataProvider \
18 | --kwargs_dp '{"crop_to": [160, 96, 64], "return2D": 0, "check_files": 0, "csv_name": "metadata.csv"}' \
19 | --saveStateIter 1 --saveProgressIter 1 \
20 | --channels_pt1 0 1 2 \
21 | --batch_size 32 \
22 | --nepochs 300 \
23 |
--------------------------------------------------------------------------------
/examples/training_scripts/cbvaegan2D_target.sh:
--------------------------------------------------------------------------------
1 | ic_train_model \
2 | --gpu_ids $1 \
3 | --model_type aegan \
4 | --save_dir $PWD/cbvaegan2D_target \
5 | --lr_enc 2E-4 --lr_dec 2E-4 --lr_decD 5E-4 \
6 | --data_save_path $PWD/cbvaegan2D_target/data.pyt \
7 | --crit_recon integrated_cell.losses.BatchMSELoss \
8 | --kwargs_crit_recon '{}' \
9 | --crit_decD torch.nn.BCEWithLogitsLoss \
10 | --kwargs_crit_decD '{"reduction": "sum"}' \
11 | --network_name cvaegan2D_residual \
12 | --kwargs_enc '{"n_latent_dim": 512, "n_ch_target": 1, "n_ch_ref": 2, "n_classes": 24}' \
13 | --kwargs_enc_optim '{"betas": [0.5, 0.9]}' \
14 | --kwargs_dec '{"n_latent_dim": 512, "activation_last": "softplus", "n_ch_target": 1, "n_ch_ref": 2, "n_classes": 24}' \
15 | --kwargs_dec_optim '{"betas": [0.5, 0.9]}' \
16 | --kwargs_decD '{"n_ch_target": 1, "n_ch_ref": 2, "n_classes": 24}' \
17 | --kwargs_decD_optim '{"betas": [0.5, 0.9]}' \
18 | --kwargs_model '{"beta": 1, "lambda_decD_loss": 1}' \
19 | --train_module cbvaegan_target \
20 | --imdir $PWD/../../data/ \
21 | --dataProvider TargetDataProvider \
22 | --kwargs_dp '{"crop_to": [160, 96], "return2D": 1, "check_files": 0, "csv_name": "metadata.csv"}' \
23 | --saveStateIter 1 --saveProgressIter 1 \
24 | --channels_pt1 0 1 2 \
25 | --batch_size 64 \
26 | --nepochs 300 \
27 |
--------------------------------------------------------------------------------
/examples/training_scripts/cbvaegan2D_target2.sh:
--------------------------------------------------------------------------------
1 | ic_train_model \
2 | --gpu_ids $1 \
3 | --model_type aegan \
4 | --save_dir $PWD/cbvaegan2D_target2 \
5 | --lr_enc 2E-4 --lr_dec 2E-4 --lr_decD 5E-4 \
6 | --data_save_path $PWD/cbvaegan2D_target2/data.pyt \
7 | --crit_recon integrated_cell.losses.BatchMSELoss \
8 | --kwargs_crit_recon '{}' \
9 | --crit_decD torch.nn.CrossEntropyLoss \
10 | --kwargs_crit_decD '{"reduction": "sum"}' \
11 | --network_name cvaegan2D_residual \
12 | --kwargs_enc '{"n_latent_dim": 512, "n_ch_target": 1, "n_ch_ref": 2, "n_classes": 24}' \
13 | --kwargs_enc_optim '{"betas": [0.5, 0.9]}' \
14 | --kwargs_dec '{"n_latent_dim": 512, "activation_last": "softplus", "n_ch_target": 1, "n_ch_ref": 2, "n_classes": 24}' \
15 | --kwargs_dec_optim '{"betas": [0.5, 0.9]}' \
16 | --kwargs_decD '{"n_ch_target": 1, "n_ch_ref": 2, "n_classes": 0, "n_classes_out": 25}' \
17 | --kwargs_decD_optim '{"betas": [0.5, 0.9]}' \
18 | --kwargs_model '{"beta": 1, "lambda_decD_loss": 1}' \
19 | --train_module cbvaegan_target2 \
20 | --imdir $PWD/../../data/ \
21 | --dataProvider TargetDataProvider \
22 | --kwargs_dp '{"crop_to": [160, 96], "return2D": 1, "check_files": 0, "csv_name": "metadata.csv"}' \
23 | --saveStateIter 1 --saveProgressIter 1 \
24 | --channels_pt1 0 1 2 \
25 | --batch_size 64 \
26 | --nepochs 300 \
27 |
--------------------------------------------------------------------------------
/integrated_cell/__init__.py:
--------------------------------------------------------------------------------
1 | from .imgToProjection import imgtoprojection # noqa
2 | from .simplelogger import SimpleLogger # noqa
3 |
--------------------------------------------------------------------------------
/integrated_cell/arc_walk.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def cart_from_pol(r, theta):
5 |     # hyperspherical (r, theta) -> Cartesian coordinates
6 |     x = np.zeros(len(theta) + 1)
7 |     sin = np.sin(theta)
8 |     cos = np.cos(theta)
9 |
10 |     x[0] = r * cos[0]
11 |
12 |     for i in range(1, len(theta)):
13 |         x[i] = r * cos[i] * np.prod(sin[:i])
14 |
15 |     x[-1] = r * np.prod(sin)
16 |
17 |     return x
18 |
19 |
20 | def pol_from_cart(x):
21 |     # Cartesian -> hyperspherical (r, theta) coordinates
22 |     x_sq = x ** 2
23 |     r = np.sqrt(np.sum(x_sq))
24 |
25 |     theta = np.zeros(len(x) - 1)
26 |
27 |     if len(x) > 1:
28 |         for i, _ in enumerate(theta):
29 |             theta[i] = np.arccos(x[i] / np.sqrt(np.sum(x_sq[i:])))
30 |
31 |         if x[-1] < 0:
32 |             theta[-1] *= -1.0
33 |
34 |     return (r, theta)
35 |
36 |
37 | def shortest_angular_path(theta_start, theta_end, N_points):
38 |     # walk between two angles along the shorter arc around the circle
39 |     theta_end %= 2 * np.pi
40 |     theta_start %= 2 * np.pi
41 |
42 |     swap = theta_end < theta_start
43 |     if swap:
44 |         theta_end, theta_start = theta_start, theta_end
45 |
46 |     theta = theta_end - theta_start
47 |
48 |     if theta <= np.pi:
49 |         path = np.linspace(0, theta, N_points)
50 |     else:
51 |         # shorter to go the other way around, through 2*pi
52 |         path = np.linspace(theta, 2 * np.pi, N_points)
53 |         path %= 2 * np.pi
54 |         path = np.flipud(path)
55 |
56 |     path += theta_start
57 |     path %= 2 * np.pi
58 |
59 |     if swap:
60 |         path = np.flipud(path)
61 |
62 |     return path
63 |
64 |
65 | def linspace_sph_pol(x, y, N_points):
66 |     # interpolate between points x and y in hyperspherical coordinates,
67 |     # taking the shortest angular path in the final coordinate
68 |     assert len(x) == len(y)
69 |
70 |     r1, theta1 = pol_from_cart(x)
71 |     r2, theta2 = pol_from_cart(y)
72 |
73 |     r_path = np.linspace(r1, r2, N_points)
74 |
75 |     theta_path = np.zeros([N_points, len(theta1)])
76 |     for i in range(len(theta1) - 1):
77 |         theta_path[:, i] = np.linspace(theta1[i], theta2[i], N_points)
78 |
79 |     theta_path[:, -1] = shortest_angular_path(theta1[-1], theta2[-1], N_points)
80 |
81 |     cart_path = np.zeros([N_points, len(x)])
82 |     for i in range(N_points):
83 |         x_i = cart_from_pol(r_path[i], theta_path[i, :])
84 |         cart_path[i, :] = x_i
85 |
86 |     return cart_path
87 |
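88 |
89 | # A hypothetical usage sketch (not part of the original module): interpolate
90 | # between two latent vectors along an arc rather than a straight chord.
91 | if __name__ == "__main__":
92 |     z_start = np.random.randn(512)
93 |     z_end = np.random.randn(512)
94 |     path = linspace_sph_pol(z_start, z_end, N_points=10)
95 |     print(path.shape)  # (10, 512)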
--------------------------------------------------------------------------------
/integrated_cell/bin/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AllenCellModeling/pytorch_integrated_cell/8a83fc6f8dc79037f4b681d9d7ef0bc5b91e9948/integrated_cell/bin/__init__.py
--------------------------------------------------------------------------------
/integrated_cell/corr_stats.py:
--------------------------------------------------------------------------------
1 | # copy and pasted from https://github.com/pytorch/pytorch/issues/1254
2 |
3 | import torch
4 |
5 | import pdb
6 |
7 | def pearsonr(x, y):
8 | """
9 | Mimics `scipy.stats.pearsonr`
10 |
11 | Arguments
12 | ---------
13 | x : 1D torch.Tensor
14 | y : 1D torch.Tensor
15 |
16 | Returns
17 | -------
18 | r_val : float
19 | pearsonr correlation coefficient between x and y
20 |
21 | Scipy docs ref:
22 | https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.pearsonr.html
23 |
24 | Scipy code ref:
25 | https://github.com/scipy/scipy/blob/v0.19.0/scipy/stats/stats.py#L2975-L3033
26 | Example:
27 | >>> x = np.random.randn(100)
28 | >>> y = np.random.randn(100)
29 | >>> sp_corr = scipy.stats.pearsonr(x, y)[0]
30 | >>> th_corr = pearsonr(torch.from_numpy(x), torch.from_numpy(y))
31 | >>> np.allclose(sp_corr, th_corr)
32 | """
33 | mean_x = torch.mean(x)
34 | mean_y = torch.mean(y)
35 |     xm = x.sub(mean_x)
36 |     ym = y.sub(mean_y)
37 | r_num = xm.dot(ym)
38 | r_den = torch.norm(xm, 2) * torch.norm(ym, 2)
39 | r_val = r_num / r_den
40 | return r_val
41 |
42 | def corrcoef(x):
43 | """
44 | Mimics `np.corrcoef`
45 |
46 | Arguments
47 | ---------
48 | x : 2D torch.Tensor
49 |
50 | Returns
51 | -------
52 | c : torch.Tensor
53 | if x.size() = (5, 100), then return val will be of size (5,5)
54 |
55 | Numpy docs ref:
56 | https://docs.scipy.org/doc/numpy/reference/generated/numpy.corrcoef.html
57 | Numpy code ref:
58 | https://github.com/numpy/numpy/blob/v1.12.0/numpy/lib/function_base.py#L2933-L3013
59 |
60 | Example:
61 | >>> x = np.random.randn(5,120)
62 | # result is a (5,5) matrix of correlations between rows
63 | >>> np_corr = np.corrcoef(x)
64 | >>> th_corr = corrcoef(torch.from_numpy(x))
65 | >>> np.allclose(np_corr, th_corr.numpy())
66 | # [out]: True
67 | """
68 | # calculate covariance matrix of rows
69 | mean_x = torch.mean(x, 1).unsqueeze(1)
70 | xm = x.sub(mean_x.expand_as(x))
71 | c = xm.mm(xm.t())
72 | c = c / (x.size(1) - 1)
73 |
74 | # normalize covariance matrix
75 | d = torch.diag(c)
76 | stddev = torch.pow(d, 0.5)
77 | c = c.div(stddev.expand_as(c))
78 | c = c.div(stddev.expand_as(c).t())
79 |
80 | # clamp between -1 and 1
81 | # probably not necessary but numpy does it
82 | c = torch.clamp(c, -1.0, 1.0)
83 |
84 | return c
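85 |
86 |
87 | if __name__ == "__main__":
88 |     # Sanity-check sketch: corrcoef should agree with np.corrcoef on random
89 |     # double-precision data, as the docstrings above claim.
90 |     import numpy as np
91 |
92 |     x = torch.randn(5, 120, dtype=torch.float64)
93 |     print(np.allclose(corrcoef(x).numpy(), np.corrcoef(x.numpy())))  # True
94 |     print(pearsonr(x[0], x[1]))  # a 0-dim tensor in [-1, 1]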
--------------------------------------------------------------------------------
/integrated_cell/data_providers/DataProviderABC.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | class DataProviderABC(ABC):
4 |
5 | @abstractmethod
6 | def get_n_dat(self, train_or_test):
7 | pass
8 |
9 | @abstractmethod
10 | def get_n_classes(self):
11 | pass
12 |
13 | @abstractmethod
14 | def get_image_paths(self, inds, train_or_test):
15 | pass
16 |
17 | @abstractmethod
18 | def get_images(self, inds, train_or_test):
19 | pass
20 |
21 | @abstractmethod
22 | def get_classes(self, inds, train_or_test):
23 | pass
24 |
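25 | if __name__ == "__main__":
26 |     # Sketch of the contract (not one of the package's providers): a concrete
27 |     # provider must implement all five abstract methods before it can be
28 |     # instantiated.
29 |     class ListDataProvider(DataProviderABC):
30 |         def __init__(self, images, classes):
31 |             self.images, self.classes = images, classes
32 |
33 |         def get_n_dat(self, train_or_test):
34 |             return len(self.images)
35 |
36 |         def get_n_classes(self):
37 |             return len(set(self.classes))
38 |
39 |         def get_image_paths(self, inds, train_or_test):
40 |             return [None for _ in inds]  # in-memory example, no paths
41 |
42 |         def get_images(self, inds, train_or_test):
43 |             return [self.images[i] for i in inds]
44 |
45 |         def get_classes(self, inds, train_or_test):
46 |             return [self.classes[i] for i in inds]
47 |
48 |     dp = ListDataProvider(["im0", "im1"], [0, 1])
49 |     print(dp.get_n_dat("train"), dp.get_n_classes())  # 2 2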
--------------------------------------------------------------------------------
/integrated_cell/data_providers/GraphDataProvider.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from integrated_cell.data_providers.DataProvider import (
3 | DataProvider as ParentDataProvider,
4 | ) # ugh im sorry
5 |
6 | from .. import utils
7 |
8 |
9 | class DataProvider(ParentDataProvider):
10 |     # Same as DataProvider but also returns a random DAG over the channels
11 | def __init__(self, masked_channels=[], **kwargs):
12 |
13 | super().__init__(**kwargs)
14 |
15 | self.masked_channels = masked_channels
16 |
17 | def get_sample(self, train_or_test="train", inds=None):
18 | # returns
19 | # x is b by c by y by x
20 | # x_class is b by c by #classes
21 | # graph is b by c by c - a random dag over channels
22 |
23 | x, classes, _ = super().get_sample(train_or_test=train_or_test, inds=inds)
24 |
25 | n_b = x.shape[0]
26 | n_c = x.shape[1]
27 |
28 | classes = classes.type_as(x).long()
29 | classes = utils.index_to_onehot(classes, self.get_n_classes())
30 |
31 | classes_mem = torch.ones(n_b) * self.get_n_classes() - 2
32 | classes_mem = utils.index_to_onehot(classes_mem, self.get_n_classes())
33 |
34 | classes_dna = torch.ones(n_b) * self.get_n_classes() - 1
35 | classes_dna = utils.index_to_onehot(classes_dna, self.get_n_classes())
36 |
37 | classes = [torch.unsqueeze(c, 1) for c in [classes_mem, classes, classes_dna]]
38 | classes = torch.cat(classes, 1)
39 |
40 |         # build one random DAG per batch element
41 |
42 | graphs = list()
43 | for i in range(n_b):
44 | graph = (
45 | torch.ones([n_c, n_c]).triu(1) * torch.zeros([n_c, n_c]).bernoulli_()
46 | )
47 | rand_inds = torch.randperm(n_c)
48 |
49 | graph = graph[:, rand_inds][rand_inds, :]
50 | graphs.append(graph)
51 |
52 | graph = torch.stack(graphs, 0).long()
53 |
54 | return x, classes, graph
55 |
56 | def get_n_classes(self):
57 | return super().get_n_classes() + 2
58 |
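59 |
60 | if __name__ == "__main__":
61 |     # Sketch of the random-DAG trick used in get_sample above: a strictly
62 |     # upper-triangular Bernoulli matrix is acyclic by construction, and a
63 |     # simultaneous row/column permutation preserves acyclicity.
64 |     n_c = 3
65 |     graph = torch.ones([n_c, n_c]).triu(1) * torch.zeros([n_c, n_c]).bernoulli_()
66 |     rand_inds = torch.randperm(n_c)
67 |     print(graph[:, rand_inds][rand_inds, :].long())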
--------------------------------------------------------------------------------
/integrated_cell/data_providers/MaskedChannelsDataProvider.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from integrated_cell.data_providers.DataProvider import (
3 | DataProvider as ParentDataProvider,
4 | ) # ugh im sorry
5 |
6 |
7 | class DataProvider(ParentDataProvider):
8 | # Same as DataProvider but zeros out channels indicated by the variable 'masked_channels'
9 | def __init__(self, masked_channels=[], **kwargs):
10 |
11 | super().__init__(**kwargs)
12 |
13 | self.masked_channels = masked_channels
14 |
15 | def get_sample(self, train_or_test="train", inds=None):
16 |
17 | x, classes, ref = super().get_sample(train_or_test=train_or_test, inds=inds)
18 |
19 | # build a mask for channels
20 | n_channels = x.shape[1]
21 |         channel_mask = torch.zeros(n_channels, dtype=torch.bool)
22 | channel_mask[self.masked_channels] = 1
23 |
24 | x[:, channel_mask] = 0
25 |
26 | return x, classes, ref
27 |
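28 |
29 | if __name__ == "__main__":
30 |     # Sketch of the masking above: a boolean mask over the channel axis
31 |     # zeros the selected channels for every item in the batch.
32 |     x = torch.randn(4, 3, 8, 8)
33 |     mask = torch.zeros(3, dtype=torch.bool)
34 |     mask[1] = True
35 |     x[:, mask] = 0
36 |     print(x[:, 1].abs().sum().item())  # 0.0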
--------------------------------------------------------------------------------
/integrated_cell/data_providers/MultiScaleDataProvider.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from integrated_cell.data_providers.DataProvider import (
3 | DataProvider as ParentDataProvider,
4 | ) # ugh im sorry
5 |
6 |
7 | class DataProvider(ParentDataProvider):
8 |     # Same as DataProvider but also returns the image at n_x_sub successively halved scales
9 | def __init__(self, n_x_sub=4, **kwargs):
10 |
11 | super().__init__(**kwargs)
12 |
13 | self.n_x_sub = n_x_sub
14 |
15 | def get_sample(self, train_or_test="train", inds=None, patched=True):
16 |         # returns
17 |         # x - a list of n_x_sub tensors, each downsampled by a further factor of 2
18 |         # classes - class labels from the parent provider
19 |         # ref - the reference channels from the parent provider
20 |
21 | x, classes, ref = super().get_sample(train_or_test=train_or_test, inds=inds)
22 |
23 | # this will be faster
24 | x = x.cuda()
25 |
26 | scales = 1 / (2 ** torch.arange(0, self.n_x_sub).float())
27 |
28 | x = [x] + [
29 | torch.nn.functional.interpolate(x, scale_factor=scale.item())
30 | for scale in scales[1:]
31 | ]
32 |
33 | return x, classes, ref
34 |
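35 |
36 | if __name__ == "__main__":
37 |     # Sketch of the pyramid built in get_sample (on CPU, skipping .cuda()):
38 |     x = torch.randn(2, 3, 64, 64)
39 |     scales = 1 / (2 ** torch.arange(0, 4).float())
40 |     pyramid = [x] + [
41 |         torch.nn.functional.interpolate(x, scale_factor=s.item()) for s in scales[1:]
42 |     ]
43 |     print([p.shape[-1] for p in pyramid])  # [64, 32, 16, 8]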
--------------------------------------------------------------------------------
/integrated_cell/data_providers/PatchDataProvider.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from integrated_cell.data_providers.DataProvider import (
3 | DataProvider as ParentDataProvider,
4 | ) # ugh im sorry
5 |
6 |
7 | def get_patch_slice(shape, patch_size):
8 | patch_size = torch.tensor(patch_size)
9 |
10 | shape_dims = torch.tensor(shape) - patch_size
11 |
12 |     starts = torch.cat(
13 |         [torch.randint(int(i), [1]) if i > 0 else torch.tensor([0]) for i in shape_dims], 0
14 |     )
15 |     ends = starts + patch_size
16 |
17 |     slices = tuple(slice(int(s), int(e)) for s, e in zip(starts, ends))
18 |
19 | # x = x[slices]
20 |
21 | return slices
22 |
23 |
24 | class DataProvider(ParentDataProvider):
25 |     # Same as DataProvider but returns a random patch of size 'patch_size' (and optionally a coordinate mesh)
26 | def __init__(
27 | self,
28 | patch_size=[3, 64, 64, 32],
29 | default_return_mesh=False,
30 | default_return_patch=True,
31 | **kwargs
32 | ):
33 |
34 | super().__init__(**kwargs)
35 |
36 | self.patch_size = patch_size
37 | self.default_return_mesh = default_return_mesh
38 | self.default_return_patch = default_return_patch
39 |
40 | def get_sample(
41 | self, train_or_test="train", inds=None, patched=None, return_mesh=None
42 | ):
43 |         # returns
44 |         # x - image batch, cropped to 'patch_size' when patched is True
45 |         # classes - class labels from the parent provider
46 |         # ref - a coordinate mesh (cropped per patch when patched) if return_mesh, else the parent's ref
47 |
48 | if return_mesh is None:
49 | return_mesh = self.default_return_mesh
50 |
51 | if patched is None:
52 | patched = self.default_return_patch
53 |
54 | x, classes, ref = super().get_sample(train_or_test=train_or_test, inds=inds)
55 |
56 | shape = x[0].shape
57 |
58 | if return_mesh:
59 |
60 | meshes = [torch.arange(0, i) - (i // 2) for i in list(shape[1:])]
61 | mesh = torch.meshgrid(meshes)
62 | mesh = torch.stack([m.float() for m in mesh], 0)
63 |
64 | if not patched:
65 | ref = torch.stack([mesh for i in range(x.shape[0])], 0)
66 |
67 | if patched:
68 |
69 | slices = [
70 | get_patch_slice(shape, self.patch_size) for i in range(x.shape[0])
71 | ]
72 |
73 | x = torch.stack(
74 | [x_sub[slices_sub] for x_sub, slices_sub in zip(x, slices)], 0
75 | )
76 |
77 | if return_mesh:
78 | ref = torch.stack(
79 | [
80 |                         mesh[[slice(0, mesh.shape[0])] + list(slices_sub[1:])]
81 | for slices_sub in slices
82 | ],
83 | 0,
84 | )
85 |
86 | return x, classes, ref
87 |
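88 |
89 | if __name__ == "__main__":
90 |     # Sketch: get_patch_slice draws a random window of 'patch_size' from a
91 |     # tensor of the given shape; indexing with the returned slices crops it.
92 |     x = torch.randn(3, 64, 64, 32)
93 |     slices = get_patch_slice(x.shape, [3, 32, 32, 16])
94 |     print(x[slices].shape)  # torch.Size([3, 32, 32, 16])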
--------------------------------------------------------------------------------
/integrated_cell/data_providers/ProjectedDataProvider.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from integrated_cell.data_providers.DataProvider import (
4 | DataProvider as ParentDataProvider,
5 | ) # ugh im sorry
6 |
7 |
8 | class DataProvider(ParentDataProvider):
9 |     # Same as DataProvider but reduces each 3D image to three 2D views, by central slice or max projection
10 | def __init__(self, channel_intensity_values=None, slice_or_proj="slice", **kwargs):
11 |
12 | super().__init__(**kwargs)
13 |
14 | self.channel_intensity_values = channel_intensity_values
15 | self.slice_or_proj = slice_or_proj
16 |
17 | def get_sample(self, train_or_test="train", inds=None):
18 |         # returns
19 |         # x - a list [zx, zy, xy] of 2D views of each (optionally intensity-rescaled) image
20 |         # classes - class labels from the parent provider
21 |         # ref - the reference channels from the parent provider
22 |
23 | x, classes, ref = super().get_sample(train_or_test=train_or_test, inds=inds)
24 |
25 | if self.channel_intensity_values:
26 | for i in range(x.shape[0]):
27 | for j in range(x.shape[1]):
28 | if torch.sum(x[i, j]) > 0:
29 | x[i, j] = x[i, j] * (
30 | self.channel_intensity_values[j] / torch.sum(x[i, j])
31 | )
32 |
33 | if self.slice_or_proj == "slice":
34 |             center_of_image = torch.tensor(x.shape[2:]) // 2
35 |
36 | zx = x[:, :, center_of_image[0], :, :]
37 | zy = x[:, :, :, center_of_image[1], :]
38 | xy = x[:, :, :, :, center_of_image[2]]
39 |
40 | x = [zx, zy, xy]
41 |
42 | elif self.slice_or_proj == "proj":
43 | zx = torch.max(x, 2)[0]
44 | zy = torch.max(x, 3)[0]
45 | xy = torch.max(x, 4)[0]
46 |
47 | x = [zx, zy, xy]
48 |
49 | return x, classes, ref
50 |
--------------------------------------------------------------------------------
/integrated_cell/data_providers/RefDataProvider.py:
--------------------------------------------------------------------------------
1 | from integrated_cell.data_providers.DataProvider import (
2 | DataProvider as ParentDataProvider,
3 | ) # ugh im sorry
4 |
5 |
6 | class DataProvider(ParentDataProvider):
7 |     # Same as DataProvider but returns only the reference channels (indices 0 and 2)
8 | def __init__(self, **kwargs):
9 |
10 | super().__init__(**kwargs)
11 |
12 | def get_sample(self, train_or_test="train", inds=None):
13 |         # returns
14 |         # x - only the reference channels (indices 0 and 2) of the image
15 |         # batch; the class labels and ref output of the parent provider
16 |         # are discarded
17 |
18 | x, classes, ref = super().get_sample(train_or_test=train_or_test, inds=inds)
19 |
20 | x = x[:, [0, 2]]
21 |
22 | return x
23 |
--------------------------------------------------------------------------------
/integrated_cell/data_providers/RescaledIntensityDataProvider.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from integrated_cell.data_providers.DataProvider import (
4 | DataProvider as ParentDataProvider,
5 | ) # ugh im sorry
6 |
7 |
8 | class DataProvider(ParentDataProvider):
9 |     # Same as DataProvider but rescales each channel to a fixed total intensity ('channel_intensity_values')
10 | def __init__(
11 | self, channel_intensity_values=[2596.9521, 2596.9521, 2596.9521], **kwargs
12 | ):
13 |
14 | super().__init__(**kwargs)
15 |
16 | self.channel_intensity_values = channel_intensity_values
17 |
18 | def get_sample(self, train_or_test="train", inds=None):
19 |         # returns
20 |         # x - image batch with each nonzero channel rescaled to sum to channel_intensity_values[j]
21 |         # classes - class labels from the parent provider
22 |         # ref - the reference channels from the parent provider
23 |
24 | x, classes, ref = super().get_sample(train_or_test=train_or_test, inds=inds)
25 |
26 | for i in range(x.shape[0]):
27 | for j in range(x.shape[1]):
28 | if torch.sum(x[i, j]) > 0:
29 | x[i, j] = x[i, j] * (
30 | self.channel_intensity_values[j] / torch.sum(x[i, j])
31 | )
32 |
33 | return x, classes, ref
34 |
--------------------------------------------------------------------------------
/integrated_cell/data_providers/RescaledIntensityRefDataProvider.py:
--------------------------------------------------------------------------------
1 | from integrated_cell.data_providers.RescaledIntensityDataProvider import (
2 | DataProvider as ParentDataProvider,
3 | ) # ugh im sorry
4 |
5 |
6 | class DataProvider(ParentDataProvider):
7 |     # Same as RescaledIntensityDataProvider but returns only the reference channels (indices 0 and 2)
8 | def __init__(self, **kwargs):
9 |
10 | super().__init__(**kwargs)
11 |
12 | def get_sample(self, train_or_test="train", inds=None):
13 |         # returns
14 |         # x - only the reference channels (indices 0 and 2) of the
15 |         # intensity-rescaled image batch; the parent's class labels and
16 |         # ref output are discarded
17 |
18 | x, classes, ref = super().get_sample(train_or_test=train_or_test, inds=inds)
19 |
20 | x = x[:, [0, 2]]
21 |
22 | return x
23 |
--------------------------------------------------------------------------------
/integrated_cell/data_providers/RescaledIntensityTargetDataProvider.py:
--------------------------------------------------------------------------------
1 | from integrated_cell.data_providers.RescaledIntensityDataProvider import (
2 | DataProvider as ParentDataProvider,
3 | ) # ugh im sorry
4 |
5 |
6 | class DataProvider(ParentDataProvider):
7 |     # Same as RescaledIntensityDataProvider but splits channels into target ('x') and reference ('ref') sets
8 |     def __init__(self, target_inds=[1], ref_inds=[0, 2], **kwargs):
9 |         super().__init__(**kwargs)
10 |         self.target_inds = target_inds
11 |         self.ref_inds = ref_inds
12 | def get_sample(self, train_or_test="train", inds=None):
13 |         # returns
14 |         # x - the target channels (self.target_inds)
15 |         # classes - class labels from the parent provider
16 |         # ref - the reference channels (self.ref_inds)
17 |
18 | x, classes, ref = super().get_sample(train_or_test=train_or_test, inds=inds)
19 |
20 |         ref = x[:, self.ref_inds]
21 |         x = x[:, self.target_inds]
22 |
23 | return x, classes, ref
24 |
--------------------------------------------------------------------------------
/integrated_cell/data_providers/TargetDataProvider.py:
--------------------------------------------------------------------------------
1 | from integrated_cell.data_providers.DataProvider import (
2 | DataProvider as ParentDataProvider,
3 | ) # ugh im sorry
4 |
5 |
6 | class DataProvider(ParentDataProvider):
7 |     # Same as DataProvider but splits channels into a target channel ('x', index 1) and reference channels ('ref', indices 0 and 2)
8 | def __init__(self, **kwargs):
9 |
10 | super().__init__(**kwargs)
11 |
12 | def get_sample(self, train_or_test="train", inds=None):
13 |         # returns
14 |         # x - the target channel (index 1)
15 |         # classes - class labels from the parent provider
16 |         # ref - the reference channels (indices 0 and 2)
17 |
18 | x, classes, ref = super().get_sample(train_or_test=train_or_test, inds=inds)
19 |
20 | ref = x[:, [0, 2]]
21 | x = x[:, [1]]
22 |
23 | return x, classes, ref
24 |
--------------------------------------------------------------------------------
/integrated_cell/data_providers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AllenCellModeling/pytorch_integrated_cell/8a83fc6f8dc79037f4b681d9d7ef0bc5b91e9948/integrated_cell/data_providers/__init__.py
--------------------------------------------------------------------------------
/integrated_cell/external/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AllenCellModeling/pytorch_integrated_cell/8a83fc6f8dc79037f4b681d9d7ef0bc5b91e9948/integrated_cell/external/__init__.py
--------------------------------------------------------------------------------
/integrated_cell/external/pytorch_fid/README.md:
--------------------------------------------------------------------------------
1 | Original Repository: [pytorch-fid](https://github.com/mseitzer/pytorch-fid)
2 |
3 | # Fréchet Inception Distance (FID score) in PyTorch
4 |
5 | This is a port of the official implementation of [Fréchet Inception Distance](https://arxiv.org/abs/1706.08500) to PyTorch.
6 | See [https://github.com/bioinf-jku/TTUR](https://github.com/bioinf-jku/TTUR) for the original implementation using Tensorflow.
7 |
8 | FID is a measure of similarity between two datasets of images.
9 | It was shown to correlate well with human judgement of visual quality and is most often used to evaluate the quality of samples of Generative Adversarial Networks.
10 | FID is calculated by computing the [Fréchet distance](https://en.wikipedia.org/wiki/Fr%C3%A9chet_distance) between two Gaussians fitted to feature representations of the Inception network.
11 |
12 | Further insights and an independent evaluation of the FID score can be found in [Are GANs Created Equal? A Large-Scale Study](https://arxiv.org/abs/1711.10337).
13 |
14 | **Note that the official implementation gives slightly different scores.** If you report FID scores in your paper, and you want them to be exactly comparable to FID scores reported in other papers, you should use [the official Tensorflow implementation](https://github.com/bioinf-jku/TTUR).
15 | You can still use this version if you want a quick FID estimate without installing Tensorflow.
16 |
17 | **Update:** The weights and the model are now exactly the same as in the official Tensorflow implementation, and I verified them to give the same results (around `1e-8` mean absolute error) on single inputs on my platform. However, due to differences in the image interpolation implementation and library backends, FID results might still differ slightly from the original implementation. A test I ran (details are to come) resulted in `.08` absolute error and `0.0009` relative error.
18 |
19 | ## Usage
20 |
21 | Requirements:
22 | - python3
23 | - pytorch
24 | - torchvision
25 | - numpy
26 | - scipy
27 |
28 | To compute the FID score between two datasets, where images of each dataset are contained in an individual folder:
29 | ```
30 | ./fid_score.py path/to/dataset1 path/to/dataset2
31 | ```
32 |
33 | To run the evaluation on GPU, use the flag `--gpu N`, where `N` is the index of the GPU to use.
34 |
35 | ### Using different layers for feature maps
36 |
37 | In contrast to the official implementation, you can choose to use a different feature layer of the Inception network instead of the default `pool3` layer.
38 | As the lower layer features still have spatial extent, the features are first global average pooled to a vector before estimating mean and covariance.
39 |
40 | This might be useful if the datasets you want to compare have less than the otherwise required 2048 images.
41 | Note that this changes the magnitude of the FID score, so you cannot compare it against scores calculated at another dimensionality.
42 | The resulting scores might also no longer correlate with visual quality.
43 |
44 | You can select the dimensionality of features to use with the flag `--dims N`, where N is the dimensionality of features.
45 | The choices are:
46 | - 64: first max pooling features
47 | - 192: second max pooling features
48 | - 768: pre-aux classifier features
49 | - 2048: final average pooling features (this is the default)
50 |
51 | ## License
52 |
53 | This implementation is licensed under the Apache License 2.0.
54 |
55 | FID was introduced by Martin Heusel, Hubert Ramsauer, Thomas Unterthiner, Bernhard Nessler and Sepp Hochreiter in "GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium", see [https://arxiv.org/abs/1706.08500](https://arxiv.org/abs/1706.08500)
56 |
57 | The original implementation is by the Institute of Bioinformatics, JKU Linz, licensed under the Apache License 2.0.
58 | See [https://github.com/bioinf-jku/TTUR](https://github.com/bioinf-jku/TTUR).
59 |
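60 | ## Programmatic use
61 |
62 | The score can also be computed from Python. A minimal sketch, assuming this vendored copy keeps the upstream `calculate_fid_given_paths` signature (the exact argument list may differ between versions):
63 |
64 | ```
65 | from integrated_cell.external.pytorch_fid import fid_score
66 |
67 | # two folders of .jpg/.png images
68 | fid = fid_score.calculate_fid_given_paths(
69 |     ["path/to/dataset1", "path/to/dataset2"], batch_size=50, cuda=True, dims=2048
70 | )
71 | print(fid)
72 | ```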
--------------------------------------------------------------------------------
/integrated_cell/imgToProjection.py:
--------------------------------------------------------------------------------
1 | # Author: Evan Wiederspan
2 |
3 | import numpy as np
4 | import matplotlib as mpl
5 |
6 | mpl.use("Agg") # noqa
7 | import matplotlib.pyplot as pplot
8 |
9 |
10 | def matproj(im, dim, method="max", slice_index=0):
11 | if method == "max":
12 | im = np.max(im, dim)
13 | elif method == "mean":
14 | im = np.mean(im, dim)
15 | elif method == "sum":
16 | im = np.sum(im, dim)
17 | elif method == "slice":
18 | im = im[slice_index]
19 | else:
20 | raise ValueError("Invalid projection method")
21 | return im
22 |
23 |
24 | def imgtoprojection(
25 | im1,
26 | proj_all=False,
27 | proj_method="max",
28 | colors=lambda i: [1, 1, 1],
29 | global_adjust=False,
30 | local_adjust=False,
31 | ):
32 | """
33 | Outputs projections of a 4d CZYX numpy array into a CYX numpy array, allowing for color masks for each input channel
34 | as well as adjustment options
35 | :param im1: Either a 4d numpy array or a list of 3D or 2D numpy arrays. The input that will be projected
36 | :param proj_all: boolean. True outputs XY, YZ, and XZ projections in a grid, False just outputs XY. False by default
37 | :param proj_method: string. Method by which to do projections. 'Max' by default
38 | :param colors: Can be either a string which corresponds to a cmap function in matplotlib, a function that
39 | takes in the channel index and returns a list of numbers, or a list of lists containing the color multipliers.
40 |     :param global_adjust: boolean. If true, scales the combined image so that its
41 |     maximum value is 1 after combining all channels. False by default
42 | :param local_adjust: boolean. If true, performs contrast adjustment on each channel individually. False by default
43 | :return: a CYX numpy array containing the requested projections
44 | """
45 |
46 | # turn list of 2d or 3d arrays into single 4d array if needed
47 | try:
48 | if isinstance(im1, (type([]), type(()))):
49 | # if only YX, add a single Z dimen
50 | if im1[0].ndim == 2:
51 | im1 = [np.expand_dims(c, axis=0) for c in im1]
52 | elif im1[0].ndim != 3:
53 | raise ValueError("im1 must be a list of 2d or 3d arrays")
54 | # combine list into 4d array
55 | im = np.stack(im1)
56 | else:
57 | if im1.ndim != 4:
58 | raise ValueError("Invalid dimensions for im1")
59 | im = im1
60 |
61 | except (AttributeError, IndexError):
62 |         # it's not a list of np arrays
63 | raise ValueError(
64 | "im1 must be either a 4d numpy array or a list of numpy arrays"
65 | )
66 |
67 | # color processing code
68 | if isinstance(colors, str):
69 | # pass it in to matplotlib
70 | try:
71 | colors = pplot.get_cmap(colors)(np.linspace(0, 1, im.shape[0]))
72 | except ValueError:
73 | # thrown when string is not valid function
74 | raise ValueError("Invalid cmap string")
75 | elif callable(colors):
76 | # if its a function
77 | try:
78 | colors = [colors(i) for i in range(im.shape[0])]
79 | except: # noqa
80 | raise ValueError("Invalid color function")
81 |
82 |     # else, we're assuming it's a list
83 | # scale colors down to 0-1 range if they're bigger than 1
84 | if any(v > 1 for v in np.array(colors).flatten()):
85 | colors = [[v / 255 for v in c] for c in colors]
86 |
87 | # create final image
88 | if not proj_all:
89 | img_final = np.zeros((3, im.shape[2], im.shape[3]))
90 | else:
91 | # y + z, x + z
92 | img_final = np.zeros((3, im.shape[2] + im.shape[1], im.shape[3] + im.shape[1]))
93 | img_piece = np.zeros(img_final.shape)
94 | # loop through all channels
95 | for i, img_c in enumerate(im):
96 | try:
97 | proj_z = matproj(img_c, 0, proj_method, img_c.shape[0] // 2)
98 | if proj_all:
99 | proj_y, proj_x = (
100 | matproj(img_c, axis, proj_method, img_c.shape[axis] // 2)
101 | for axis in range(1, 3)
102 | )
103 | # flipping to get them facing the right way
104 | # proj_x = np.fliplr(np.transpose(proj_x, (1, 0)))
105 | # proj_y = np.flipud(proj_y)
106 | proj_x = np.transpose(proj_x, (1, 0))
107 | proj_y = np.flipud(proj_y)
108 |
109 | _, sy, sz = proj_z.shape[1], proj_z.shape[0], proj_y.shape[0] # noqa
110 | img_piece[:, :sy, :sz] = proj_x
111 | img_piece[:, :sy, sz:] = proj_z
112 | img_piece[:, sy:, sz:] = proj_y
113 | else:
114 | img_piece[:] = proj_z
115 | except ValueError:
116 | raise ValueError("Invalid projection function")
117 |
118 | for c in range(3):
119 | img_piece[c] *= colors[i][c]
120 |
121 | # local contrast adjustment, minus the min, divide the max
122 | if local_adjust:
123 | img_piece -= np.min(img_piece)
124 | img_piece /= np.max(img_piece)
125 | img_final += img_piece
126 |
127 |     # color range adjustment, ensure that the max value is 1
128 |     if global_adjust:
129 |         # scale the combined image by its global max
130 | img_final /= np.max(img_final)
131 |
132 | return img_final
133 |
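134 |
135 | if __name__ == "__main__":
136 |     # Usage sketch: project a random 2-channel CZYX stack to a colored CYX
137 |     # image, mapping the channels through the "cool" colormap.
138 |     im = np.random.rand(2, 8, 64, 64)
139 |     out = imgtoprojection(im, proj_all=True, colors="cool", global_adjust=True)
140 |     print(out.shape)  # (3, 72, 72): XY plus the YZ/XZ strips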
--------------------------------------------------------------------------------
/integrated_cell/layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class ChannelSoftmax(nn.Module):
6 | # normalizes each channel to sum to one
7 | def __init__(self):
8 | super(ChannelSoftmax, self).__init__()
9 |
10 | def forward(self, x):
11 |
12 | # https://stackoverflow.com/questions/44081007/logsoftmax-stability
13 | # or even better
14 | # https://timvieira.github.io/blog/post/2014/02/11/exp-normalize-trick/
15 |
16 | n_dims = len(x.shape) - 2
17 |
18 | b = torch.max(x.view(x.shape[0], x.shape[1], -1), dim=2)[0]
19 |
20 | for i in range(n_dims):
21 | b = torch.unsqueeze(b, -1)
22 |
23 | x_exp = torch.exp(x - b)
24 |
25 | normalize = x_exp.clone()
26 | for i in range(n_dims):
27 | normalize = torch.sum(normalize, -1)
28 |
29 | for i in range(n_dims):
30 | normalize = torch.unsqueeze(normalize, -1)
31 |
32 | softmaxed = x_exp / normalize
33 |
34 | return softmaxed
35 |
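36 |
37 | if __name__ == "__main__":
38 |     # Sanity-check sketch: after ChannelSoftmax, every channel of every
39 |     # batch item sums to one over its spatial dimensions.
40 |     x = torch.randn(2, 3, 8, 8)
41 |     y = ChannelSoftmax()(x)
42 |     print(y.view(2, 3, -1).sum(-1))  # all entries ~1.0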
--------------------------------------------------------------------------------
/integrated_cell/metrics/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AllenCellModeling/pytorch_integrated_cell/8a83fc6f8dc79037f4b681d9d7ef0bc5b91e9948/integrated_cell/metrics/__init__.py
--------------------------------------------------------------------------------
/integrated_cell/metrics/embeddings.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | import torch
3 |
4 | from ..utils import utils, reparameterize
5 |
6 | from ..losses import KLDLoss
7 |
8 |
9 | def get_latent_embeddings(
10 | enc,
11 | dec,
12 | dp,
13 | recon_loss,
14 | modes=["test"],
15 | batch_size=256,
16 | n_recon_samples=10,
17 | sampler=None,
18 | beta=1,
19 | channels_ref=[0, 2],
20 | channels_target=[1],
21 | ):
22 |
23 | if sampler is None:
24 |
25 | def sampler(mode, inds):
26 | return dp.get_sample(mode, inds)
27 |
28 | enc.eval()
29 | dec.eval()
30 |
31 | embedding = dict()
32 |
33 | for mode in modes:
34 | ndat = dp.get_n_dat(mode)
35 |
36 | x, classes, ref = sampler(mode, [0])
37 |
38 | x = x.cuda()
39 | classes_onehot = utils.index_to_onehot(classes, dp.get_n_classes()).cuda()
40 |
41 | with torch.no_grad():
42 | zAll = enc(x, classes_onehot)
43 |
44 | ndims = torch.prod(torch.tensor(zAll[0][0].shape[1:]))
45 |
46 | embeddings_ref_mu = torch.zeros(ndat, ndims)
47 | embeddings_ref_sigma = torch.zeros(ndat, ndims)
48 |
49 | embeddings_ref_recons = torch.zeros(ndat, n_recon_samples)
50 | embeddings_ref_kld = torch.zeros(ndat)
51 |
52 | embeddings_target_mu = torch.zeros(ndat, ndims)
53 | embeddings_target_sigma = torch.zeros(ndat, ndims)
54 |
55 | embeddings_target_recons = torch.zeros(ndat, n_recon_samples)
56 | embeddings_target_kld = torch.zeros(ndat)
57 |
58 | embeddings_classes = torch.zeros(ndat).long()
59 |
60 | inds = list(range(0, ndat))
61 | data_iter = [
62 | torch.LongTensor(inds[i : i + batch_size]) # noqa
63 | for i in range(0, len(inds), batch_size) # noqa
64 | ]
65 |
66 | for i in tqdm(range(0, len(data_iter))):
67 | batch_size = len(data_iter[i])
68 |
69 | x, classes, ref = sampler(mode, data_iter[i])
70 |
71 | x = x.cuda()
72 | classes_onehot = utils.index_to_onehot(classes, dp.get_n_classes()).cuda()
73 |
74 | with torch.no_grad():
75 | zAll = enc(x, classes_onehot)
76 |
77 | embeddings_ref_mu.index_copy_(
78 | 0, data_iter[i], zAll[0][0].cpu().view([batch_size, -1])
79 | )
80 | embeddings_ref_sigma.index_copy_(
81 | 0, data_iter[i], zAll[0][1].cpu().view([batch_size, -1])
82 | )
83 |
84 | embeddings_target_mu.index_copy_(
85 | 0, data_iter[i], zAll[1][0].cpu().view([batch_size, -1])
86 | )
87 | embeddings_target_sigma.index_copy_(
88 | 0, data_iter[i], zAll[1][1].cpu().view([batch_size, -1])
89 | )
90 |
91 | embeddings_classes.index_copy_(0, data_iter[i], classes)
92 |
93 | recons_ref, recons_target = get_recons(
94 | x,
95 | classes_onehot,
96 | dec,
97 | zAll,
98 | n_recon_samples,
99 | recon_loss=recon_loss,
100 | channels_ref=channels_ref,
101 | channels_target=channels_target,
102 | )
103 |
104 | embeddings_ref_kld.index_copy_(
105 | 0,
106 | torch.LongTensor(data_iter[i]),
107 | get_klds(zAll[0][0], zAll[0][1]).data[:].cpu(),
108 | )
109 | embeddings_ref_recons.index_copy_(
110 | 0, torch.LongTensor(data_iter[i]), recons_ref
111 | )
112 |
113 | embeddings_target_kld.index_copy_(
114 | 0, torch.LongTensor(data_iter[i]), get_klds(zAll[1][0], zAll[1][1])
115 | )
116 | embeddings_target_recons.index_copy_(
117 | 0, torch.LongTensor(data_iter[i]), recons_target
118 | )
119 |
120 | embedding[mode] = {}
121 | embedding[mode]["ref"] = {}
122 | embedding[mode]["ref"]["mu"] = embeddings_ref_mu
123 | embedding[mode]["ref"]["sigma"] = embeddings_ref_sigma
124 |
125 | embedding[mode]["ref"]["kld"] = embeddings_ref_kld
126 | embedding[mode]["ref"]["recon"] = embeddings_ref_recons
127 | embedding[mode]["ref"]["elbo"] = -(
128 | torch.mean(embeddings_ref_recons, 1) + beta * embeddings_ref_kld
129 | )
130 |
131 | embedding[mode]["target"] = {}
132 | embedding[mode]["target"]["mu"] = embeddings_target_mu
133 | embedding[mode]["target"]["sigma"] = embeddings_target_sigma
134 | embedding[mode]["target"]["class"] = embeddings_classes
135 |
136 | embedding[mode]["target"]["kld"] = embeddings_target_kld
137 | embedding[mode]["target"]["recon"] = embeddings_target_recons
138 |
139 | embedding[mode]["target"]["elbo"] = -(
140 | torch.mean(embeddings_target_recons, 1) + beta * embeddings_target_kld
141 | )
142 |
143 | for mode in embedding:
144 | for component in embedding[mode]:
145 | for thing in embedding[mode][component]:
146 | embedding[mode][component][thing] = (
147 | embedding[mode][component][thing].cpu().detach()
148 | )
149 |
150 | return embedding
151 |
152 |
153 | def get_klds(mus, sigmas):
154 |
155 | kld_loss = KLDLoss(reduction="sum")
156 |
157 | klds = torch.zeros(mus.shape[0])
158 |
159 | for i, (mu, sigma) in enumerate(zip(mus, sigmas)):
160 |
161 | kld, _, _ = kld_loss(mu.unsqueeze(0), sigma.unsqueeze(0))
162 |
163 | klds[i] = kld[0]
164 |
165 | return klds
166 |
167 |
168 | def get_recons(
169 | x,
170 | classes_onehot,
171 | dec,
172 | zAll,
173 | n_recon_samples,
174 | recon_loss,
175 | channels_ref=[0, 2],
176 | channels_target=[1],
177 | ):
178 |
179 | recons_ref = torch.zeros(x.shape[0], n_recon_samples)
180 | recons_target = torch.zeros(x.shape[0], n_recon_samples)
181 |
182 | for i in range(n_recon_samples):
183 | zOut = [reparameterize(z[0], z[1]) for z in zAll]
184 |
185 | with torch.no_grad():
186 | xHat = dec([classes_onehot] + zOut)
187 |
188 | recons_ref[:, i] = torch.stack(
189 | [
190 | recon_loss(xHat[[ind], [channels_ref]], x[[ind], [channels_ref]])
191 | for ind in range(len(x))
192 | ]
193 | )
194 |
195 | recons_target[:, i] = torch.stack(
196 | [
197 | recon_loss(xHat[[ind], [channels_target]], x[[ind], [channels_target]])
198 | for ind in range(len(x))
199 | ]
200 | )
201 |
202 | return recons_ref, recons_target
203 |
--------------------------------------------------------------------------------
/integrated_cell/metrics/embeddings_reference.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | import torch
3 |
4 | from ..utils import reparameterize
5 |
6 | from ..losses import KLDLoss
7 |
8 |
9 | def get_latent_embeddings(
10 | enc,
11 | dec,
12 | dp,
13 | recon_loss,
14 | modes=["test"],
15 | batch_size=256,
16 | n_recon_samples=10,
17 | sampler=None,
18 | beta=1,
19 | channels_ref=[0, 2],
20 | channels_target=[1],
21 | ):
22 |
23 | kld_loss = KLDLoss(reduction="sum")
24 |
25 | if sampler is None:
26 |
27 | def sampler(mode, inds):
28 | return dp.get_sample(mode, inds)
29 |
30 | enc.eval()
31 | dec.eval()
32 |
33 | embedding = dict()
34 |
35 | for mode in modes:
36 | ndat = dp.get_n_dat(mode)
37 |
38 | x = sampler(mode, [0])
39 | x = x.cuda()
40 |
41 | with torch.no_grad():
42 | zAll = enc(x)
43 |
44 | ndims = torch.prod(torch.tensor(zAll[0].shape[1:]))
45 |
46 | embeddings_ref_mu = torch.zeros(ndat, ndims)
47 | embeddings_ref_sigma = torch.zeros(ndat, ndims)
48 |
49 | embeddings_ref_recons = torch.zeros(ndat, n_recon_samples)
50 | embeddings_ref_kld = torch.zeros(ndat)
51 |
52 | inds = list(range(0, ndat))
53 | data_iter = [
54 | torch.LongTensor(inds[i : i + batch_size]) # noqa
55 | for i in range(0, len(inds), batch_size) # noqa
56 | ]
57 |
58 | for i in tqdm(range(0, len(data_iter))):
59 | batch_size = len(data_iter[i])
60 |
61 | x = sampler(mode, data_iter[i])
62 | x = x.cuda()
63 |
64 | with torch.no_grad():
65 | zAll = enc(x)
66 |
67 | recons_ref = get_recons(
68 | x, dec, zAll, n_recon_samples, recon_loss=recon_loss
69 | )
70 |
71 |             klds = torch.stack(  # keep only the summed KLD, as in metrics/embeddings.py
72 |                 [kld_loss(mu.unsqueeze(0), sigma.unsqueeze(0))[0][0] for mu, sigma in zip(zAll[0], zAll[1])]
73 |             )
74 |
75 | embeddings_ref_mu.index_copy_(
76 | 0, torch.LongTensor(data_iter[i]), zAll[0].cpu()
77 | )
78 |
79 | embeddings_ref_sigma.index_copy_(
80 | 0, torch.LongTensor(data_iter[i]), zAll[1].cpu()
81 | )
82 |
83 | embeddings_ref_kld.index_copy_(
84 | 0, torch.LongTensor(data_iter[i]), klds.cpu()
85 | )
86 |
87 | embeddings_ref_recons.index_copy_(
88 | 0, torch.LongTensor(data_iter[i]), recons_ref
89 | )
90 |
91 | embedding[mode] = {}
92 | embedding[mode]["ref"] = {}
93 | embedding[mode]["ref"]["mu"] = embeddings_ref_mu
94 | embedding[mode]["ref"]["sigma"] = embeddings_ref_sigma
95 |
96 | embedding[mode]["ref"]["kld"] = embeddings_ref_kld
97 | embedding[mode]["ref"]["recon"] = embeddings_ref_recons
98 |
99 | embedding[mode]["ref"]["elbo"] = -(
100 | torch.mean(embeddings_ref_recons, 1) + beta * embeddings_ref_kld
101 | )
102 |
103 | for mode in embedding:
104 | for component in embedding[mode]:
105 | for thing in embedding[mode][component]:
106 | embedding[mode][component][thing] = (
107 | embedding[mode][component][thing].cpu().detach()
108 | )
109 |
110 | return embedding
111 |
112 |
113 | def get_recons(x, dec, zAll, n_recon_samples, recon_loss):
114 |
115 | recons = torch.zeros(x.shape[0], n_recon_samples)
116 |
117 | for i in range(n_recon_samples):
118 | zOut = reparameterize(zAll[0], zAll[1])
119 |
120 | with torch.no_grad():
121 | xHat = dec(zOut)
122 |
123 | recons[:, i] = torch.stack(
124 | [recon_loss(xHat[[ind]], x[[ind]]) for ind in range(len(x))]
125 | )
126 |
127 | return recons
128 |
--------------------------------------------------------------------------------
/integrated_cell/metrics/embeddings_target.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | import torch
3 |
4 | from ..utils import utils, reparameterize
5 |
6 | from ..losses import KLDLoss
7 |
8 |
9 | def get_latent_embeddings(
10 | enc,
11 | dec,
12 | dp,
13 | recon_loss,
14 | modes=["test"],
15 | batch_size=256,
16 | n_recon_samples=10,
17 | sampler=None,
18 | beta=1,
19 | channels_ref=[0, 2],
20 | channels_target=[1],
21 | ):
22 |
23 | kld_loss = KLDLoss(reduction="sum")
24 |
25 | if sampler is None:
26 |
27 | def sampler(mode, inds):
28 | return dp.get_sample(mode, inds)
29 |
30 | enc.eval()
31 | dec.eval()
32 |
33 | embedding = dict()
34 |
35 | for mode in modes:
36 | ndat = dp.get_n_dat(mode)
37 |
38 | x, classes, ref = sampler(mode, [0])
39 |
40 | x = x.cuda()
41 | ref = ref.cuda()
42 | classes_onehot = utils.index_to_onehot(classes, dp.get_n_classes()).cuda()
43 |
44 | with torch.no_grad():
45 | zAll = enc(x, ref, classes_onehot)
46 |
47 | ndims = torch.prod(torch.tensor(zAll[0].shape[1:]))
48 |
49 | embeddings_target_mu = torch.zeros(ndat, ndims)
50 | embeddings_target_sigma = torch.zeros(ndat, ndims)
51 |
52 | embeddings_target_recons = torch.zeros(ndat, n_recon_samples)
53 | embeddings_target_kld = torch.zeros(ndat)
54 |
55 | embeddings_classes = torch.zeros(ndat).long()
56 |
57 | inds = list(range(0, ndat))
58 | data_iter = [
59 | torch.LongTensor(inds[i : i + batch_size]) # noqa
60 | for i in range(0, len(inds), batch_size) # noqa
61 | ]
62 |
63 | for i in tqdm(range(0, len(data_iter))):
64 | batch_size = len(data_iter[i])
65 |
66 | x, classes, ref = sampler(mode, data_iter[i])
67 |
68 | x = x.cuda()
69 | ref = ref.cuda()
70 | classes_onehot = utils.index_to_onehot(classes, dp.get_n_classes()).cuda()
71 |
72 | with torch.no_grad():
73 | zAll = enc(x, ref, classes_onehot)
74 |
75 | recons_target = get_recons(
76 | x,
77 | ref,
78 | classes_onehot,
79 | dec,
80 | zAll,
81 | n_recon_samples,
82 | recon_loss=recon_loss,
83 | )
84 |
85 |             klds = torch.stack(  # keep only the summed KLD, as in metrics/embeddings.py
86 |                 [kld_loss(mu.unsqueeze(0), sigma.unsqueeze(0))[0][0] for mu, sigma in zip(zAll[0], zAll[1])]
87 |             )
88 |
89 | embeddings_target_mu.index_copy_(
90 | 0, torch.LongTensor(data_iter[i]), zAll[0].cpu()
91 | )
92 |
93 | embeddings_target_sigma.index_copy_(
94 | 0, torch.LongTensor(data_iter[i]), zAll[1].cpu()
95 | )
96 |
97 | embeddings_classes.index_copy_(0, data_iter[i], classes)
98 |
99 | embeddings_target_kld.index_copy_(
100 | 0, torch.LongTensor(data_iter[i]), klds.cpu()
101 | )
102 |
103 | embeddings_target_recons.index_copy_(
104 | 0, torch.LongTensor(data_iter[i]), recons_target
105 | )
106 |
107 | embedding[mode] = {}
108 | embedding[mode]["target"] = {}
109 | embedding[mode]["target"]["mu"] = embeddings_target_mu
110 | embedding[mode]["target"]["sigma"] = embeddings_target_sigma
111 | embedding[mode]["target"]["class"] = embeddings_classes
112 |
113 | embedding[mode]["target"]["kld"] = embeddings_target_kld
114 | embedding[mode]["target"]["recon"] = embeddings_target_recons
115 |
116 | embedding[mode]["target"]["elbo"] = -(
117 | torch.mean(embeddings_target_recons, 1) + beta * embeddings_target_kld
118 | )
119 |
120 | for mode in embedding:
121 | for component in embedding[mode]:
122 | for thing in embedding[mode][component]:
123 | embedding[mode][component][thing] = (
124 | embedding[mode][component][thing].cpu().detach()
125 | )
126 |
127 | return embedding
128 |
129 |
130 | def get_recons(x, ref, classes_onehot, dec, zAll, n_recon_samples, recon_loss):
131 |
132 | recons = torch.zeros(x.shape[0], n_recon_samples)
133 |
134 | for i in range(n_recon_samples):
135 | zOut = reparameterize(zAll[0], zAll[1])
136 |
137 | with torch.no_grad():
138 | xHat = dec(zOut, ref, classes_onehot)
139 |
140 | recons[:, i] = torch.stack(
141 | [recon_loss(xHat[[ind]], x[[ind]]) for ind in range(len(x))]
142 | )
143 |
144 | return recons
145 |
--------------------------------------------------------------------------------
/integrated_cell/metrics/features2D.py:
--------------------------------------------------------------------------------
1 | import skimage.measure as measure
2 | import skimage.filters as filters
3 | import numpy as np
4 |
5 |
6 | def props2summary(props):
7 |
8 | summary_props = [
9 | "area",
10 | "bbox_area",
11 | "convex_area",
12 | "centroid",
13 | "equivalent_diameter",
14 | "euler_number",
15 | "extent",
16 | "filled_area",
17 | "major_axis_length",
18 | "max_intensity",
19 | "mean_intensity",
20 | "min_intensity",
21 | "minor_axis_length",
22 | "moments",
23 | "moments_central",
24 | "moments_hu",
25 | "moments_normalized",
26 | "orientation",
27 | "perimeter",
28 | "solidity",
29 | "weighted_centroid",
30 | "weighted_moments_central",
31 | "weighted_moments_hu",
32 | "weighted_moments_normalized",
33 | ]
34 |
35 | prop_new = {}
36 |
37 | for k in summary_props:
38 | prop_list = list()
39 |
40 | for prop in props:
41 | p = np.array(prop[k])
42 | prop_list.append(p)
43 |
44 | if len(prop_list) > 1:
45 | prop_stack = np.stack(prop_list, -1)
46 |
47 | prop_mean = np.mean(prop_stack, -1)
48 | prop_std = np.std(prop_stack, -1)
49 |
50 | prop_total = np.sum(prop_stack, -1)
51 |
52 | else:
53 | prop_stack = prop_list[0]
54 |
55 | prop_total = prop_stack
56 | prop_mean = prop_stack
57 | prop_std = np.zeros(prop_stack.shape)
58 |
59 | prop_new[k + "_total"] = prop_total
60 | prop_new[k + "_mean"] = prop_mean
61 | prop_new[k + "_std"] = prop_std
62 |
63 | prop_new["num_objs"] = len(props)
64 |
65 | return prop_new
66 |
67 |
68 | def find_main_obj(im_bw):
69 | im_label = measure.label(im_bw > 0)
70 |
71 | ulabels = np.unique(im_label[im_label > 0])
72 |
73 | label_counts = np.zeros(len(ulabels))
74 |
75 | for i, label in enumerate(ulabels):
76 | label_counts[i] = np.sum(im_label == label)
77 |
78 | return im_label == ulabels[np.argmax(label_counts)]
79 |
80 |
81 | def ch_feats(im_bw, bg_thresh=1e-2):
82 |
83 | im_bg_sub = im_bw > bg_thresh
84 |
85 | props = {}
86 |
87 | im_bg_sub = find_main_obj(im_bg_sub)
88 |
89 | main_props = measure.regionprops(im_bg_sub.astype("uint8"), im_bw)
90 | main_props = props2summary(main_props)
91 |
92 | for k in main_props:
93 | props["main_" + k] = main_props[k]
94 |
95 | thresh = filters.threshold_otsu(im_bw[im_bg_sub])
96 | im_obj = measure.label(im_bw > thresh)
97 |
98 | thresh_props = measure.regionprops(im_obj, im_bw)
99 | thresh_props = props2summary(thresh_props)
100 |
101 | for k in thresh_props:
102 | props["obj_" + k] = thresh_props[k]
103 |
104 | return props
105 |
106 |
107 | def im2feats(im, ch_names, bg_thresh=1e-2):
108 |
109 | feats = {}
110 | for i, ch_name in enumerate(ch_names):
111 | im_ch = im[i]
112 |
113 | feats[ch_name] = ch_feats(im_ch, bg_thresh)
114 |
115 | return feats
116 |
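117 |
118 | if __name__ == "__main__":
119 |     # Usage sketch on synthetic data: two channels of a YX image, giving one
120 |     # feature dict per channel, keyed by the names supplied here.
121 |     im = np.random.rand(2, 64, 64)
122 |     feats = im2feats(im, ["memb", "dna"])
123 |     print(sorted(feats["dna"].keys())[:3])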
--------------------------------------------------------------------------------
/integrated_cell/metrics/precision_recall.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from ..external.pytorch_fid.inception import InceptionV3
3 | from ..external.pytorch_fid import fid_score
4 | import os
5 | import pathlib
6 |
7 | """
8 | Implements Improved Precision and Recall Metric for Assessing Generative Models
9 | https://arxiv.org/abs/1904.06991
10 | """
11 |
12 |
13 | def my_cdist(x1, x2):
14 | x1_norm = x1.pow(2).sum(dim=-1, keepdim=True)
15 | x2_norm = x2.pow(2).sum(dim=-1, keepdim=True)
16 | res = torch.addmm(
17 | x2_norm.transpose(-2, -1), x1, x2.transpose(-2, -1), alpha=-2
18 | ).add_(x1_norm)
19 | res = res.clamp_min_(1e-30).sqrt_()
20 | return res
21 |
22 |
23 | def manifold_estimate(phi_a, phi_b, k=[3], batch_size=None):
24 | """
25 | phi_a and phi_b are N_a x D and N_b x D sets of features,
26 | k is the nearest neighbor setting for knn manifold estimation,
27 | batch_size is the number of pairwise distance computations done at once (None = do them all at once)
28 | """
29 | if batch_size is None:
30 | d_aa = my_cdist(phi_a, phi_a)
31 | r_phi = {kappa: torch.kthvalue(d_aa, kappa + 1, dim=1)[0] for kappa in k}
32 | d_ba = my_cdist(phi_b, phi_a)
33 | b_in_a = {
34 | kappa: torch.any(d_ba <= rp, dim=1).float() for kappa, rp in r_phi.items()
35 | }
36 | else:
37 | r_phi = {kappa: torch.empty_like(phi_a[:, 0]) for kappa in k}
38 | for batch_inds in torch.split(torch.arange(len(phi_a)), batch_size):
39 | d_aa = my_cdist(phi_a[batch_inds], phi_a)
40 | for kappa in k:
41 | r_phi[kappa][batch_inds] = torch.kthvalue(d_aa, kappa + 1, dim=1)[0]
42 | b_in_a = {kappa: torch.empty_like(phi_b[:, 0]) for kappa in k}
43 | for batch_inds in torch.split(torch.arange(len(phi_b)), batch_size):
44 | d_ba = my_cdist(phi_b[batch_inds], phi_a)
45 | for kappa in k:
46 | b_in_a[kappa][batch_inds] = torch.any(
47 | d_ba <= r_phi[kappa], dim=1
48 | ).float()
49 | return {key: value.mean().item() for key, value in b_in_a.items()}
50 |
51 |
52 | def precision_recall(phi_r, phi_g, k=[3], batch_size=None):
53 | """
54 | phi_r and phi_s are N_a x D and N_b x D sets of features
55 | from real and generated images respectively,
56 | k is a list of integers for the nearest neighbor setting for knn manifold estimation.
57 | returns a dict {k_1:{'precision':float, 'recall':float}, k_2:{'precision':float, 'recall':float}, ...}
58 | """
59 | d = {
60 | "precision": manifold_estimate(phi_r, phi_g, k=k, batch_size=batch_size),
61 | "recall": manifold_estimate(phi_g, phi_r, k=k, batch_size=batch_size),
62 | }
63 |
64 | result = {}
65 | for k1, subdict in d.items():
66 | for k2, v in subdict.items():
67 | result.setdefault(k2, {})[k1] = v
68 |
69 | return result
70 |
71 |
72 | def calculate_inception_pr_given_paths(
73 | paths, k=[3], batch_size=1, cuda=True, dims=2048, verbose=False
74 | ):
75 |     # this is constructed to imitate `calculate_fid_given_paths` from integrated_cell.external.pytorch_fid.fid_score
76 |
77 | for p in paths:
78 | if not os.path.exists(p):
79 | raise RuntimeError("Invalid path: %s" % p)
80 |
81 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
82 |
83 | model = InceptionV3([block_idx])
84 | if cuda:
85 | model.cuda()
86 |
87 | path = pathlib.Path(paths[0])
88 | files = list(path.glob("*.jpg")) + list(path.glob("*.png"))
89 | act_real = fid_score.get_activations(files, model, batch_size, dims, cuda, verbose)
90 |
91 | path = pathlib.Path(paths[1])
92 | files = list(path.glob("*.jpg")) + list(path.glob("*.png"))
93 | act_gen = fid_score.get_activations(files, model, batch_size, dims, cuda, verbose)
94 |
95 | act_real = torch.tensor(act_real)
96 | act_gen = torch.tensor(act_gen)
97 |
98 | return precision_recall(act_real, act_gen, k, batch_size=None)
99 |
100 |
101 | def calculate_inception_pr_given_paired_paths(
102 | paths_real, paths_gen, k=[3], batch_size=1, cuda=True, dims=2048, verbose=False
103 | ):
104 |     # this is constructed to imitate `calculate_fid_given_paths` from integrated_cell.external.pytorch_fid.fid_score
105 |
106 | for p in paths_real + paths_gen:
107 | if not os.path.exists(p):
108 | raise RuntimeError("Invalid path: %s" % p)
109 |
110 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
111 |
112 | model = InceptionV3([block_idx])
113 | if cuda:
114 | model.cuda()
115 |
116 | prs = list()
117 | for i, [path_real, path_gen] in enumerate(zip(paths_real, paths_gen)):
118 | path = pathlib.Path(path_real)
119 | files = list(path.glob("*.jpg")) + list(path.glob("*.png"))
120 | act_real = fid_score.get_activations(
121 | files, model, batch_size, dims, cuda, verbose
122 | )
123 |
124 | path = pathlib.Path(path_gen)
125 | files = list(path.glob("*.jpg")) + list(path.glob("*.png"))
126 | act_gen = fid_score.get_activations(
127 | files, model, batch_size, dims, cuda, verbose
128 | )
129 |
130 | act_real = torch.tensor(act_real)
131 | act_gen = torch.tensor(act_gen)
132 |
133 | prs.append(precision_recall(act_real, act_gen, k, batch_size=None))
134 |
135 | return prs
136 |
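137 |
138 | if __name__ == "__main__":
139 |     # Sketch: precision/recall between two random feature sets. With both
140 |     # sets drawn from the same distribution, both numbers should be high.
141 |     phi_r = torch.randn(200, 64)
142 |     phi_g = torch.randn(200, 64)
143 |     print(precision_recall(phi_r, phi_g, k=[3]))
144 |     # e.g. {3: {'precision': 0.97, 'recall': 0.96}}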
--------------------------------------------------------------------------------
/integrated_cell/model_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 |
4 | # from .utils import plots
5 | import matplotlib as mpl
6 |
7 | mpl.use("Agg") # noqa
8 |
9 | import warnings
10 |
11 |
12 | def init_opts(opt, opt_default):
13 | vars_default = vars(opt_default)
14 | for var in vars_default:
15 | if not hasattr(opt, var):
16 | setattr(opt, var, getattr(opt_default, var))
17 | return opt
18 |
19 |
20 | def set_gpu_recursive(var, gpu_id):
21 | for key in var:
22 | if isinstance(var[key], dict):
23 | var[key] = set_gpu_recursive(var[key], gpu_id)
24 | else:
25 | try:
26 | if gpu_id != -1:
27 | var[key] = var[key].cuda(gpu_id)
28 | else:
29 | var[key] = var[key].cpu()
30 | except AttributeError:
31 | pass
32 | return var
33 |
34 |
35 | def sampleUniform(batsize, nlatentdim):
36 | return torch.Tensor(batsize, nlatentdim).uniform_(-1, 1)
37 |
38 |
39 | def sampleGaussian(batsize, nlatentdim):
40 | return torch.Tensor(batsize, nlatentdim).normal_()
41 |
42 |
43 | def tensor2img(img):
44 | warnings.warn(
45 | "integrated_cell.model_utils.tensor2img is depricated. Please use integrated_cell.utils.plots.tensor2im instead."
46 | )
47 | # return plots.tensor2im(img)
48 |
49 |
50 | def load_embeddings(embeddings_path, enc=None, dp=None):
51 |
52 | if os.path.exists(embeddings_path):
53 |
54 | embeddings = torch.load(embeddings_path)
55 | else:
56 | embeddings = get_latent_embeddings(enc, dp)
57 | torch.save(embeddings, embeddings_path)
58 |
59 | return embeddings
60 |
61 |
62 | def get_latent_embeddings(enc, dp):
63 | enc.eval()
64 | gpu_id = enc.gpu_ids[0]
65 |
66 | modes = ("test", "train")
67 |
68 | embedding = dict()
69 |
70 | for mode in modes:
71 | ndat = dp.get_n_dat(mode)
72 | embeddings = torch.zeros(ndat, enc.n_latent_dim)
73 |
74 | inds = list(range(0, ndat))
75 | data_iter = [
76 | inds[i : i + dp.batch_size] # noqa
77 | for i in range(0, len(inds), dp.batch_size)
78 | ]
79 |
80 | for i in range(0, len(data_iter)):
81 | print(str(i) + "/" + str(len(data_iter)))
82 | x = dp.get_images(data_iter[i], mode).cuda(gpu_id)
83 |
84 | with torch.no_grad():
85 | zAll = enc(x)
86 |
87 | embeddings.index_copy_(
88 | 0, torch.LongTensor(data_iter[i]), zAll[-1].data[:].cpu()
89 | )
90 |
91 | embedding[mode] = embeddings
92 |
93 | return embedding
94 |
95 |
96 | def load_state(model, optimizer, path, gpu_id):
97 | # device = torch.device('cpu')
98 |
99 | checkpoint = torch.load(path)
100 |
101 | model.load_state_dict(checkpoint["model"])
102 | optimizer.load_state_dict(checkpoint["optimizer"])
103 |
104 | # model.cuda(gpu_id)
105 |
106 | # optimizer.state = set_gpu_recursive(optimizer.state, gpu_id)
107 |
108 |
109 | def save_state(model, optimizer, path, gpu_id):
110 |
111 | # model = model.cpu()
112 | # optimizer.state = set_gpu_recursive(optimizer.state, -1)
113 |
114 | checkpoint = {"model": model.state_dict(), "optimizer": optimizer.state_dict()}
115 | torch.save(checkpoint, path)
116 |
117 | # model = model.cuda(gpu_id)
118 | # optimizer.state = set_gpu_recursive(optimizer.state, gpu_id)
119 |
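120 |
121 | if __name__ == "__main__":
122 |     # Sketch: init_opts backfills attributes missing from `opt` with the
123 |     # values from `opt_default`, leaving existing attributes untouched.
124 |     from argparse import Namespace
125 |
126 |     opt = init_opts(Namespace(lr=1e-4), Namespace(lr=2e-4, beta=1))
127 |     print(opt.lr, opt.beta)  # 0.0001 1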
--------------------------------------------------------------------------------
/integrated_cell/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AllenCellModeling/pytorch_integrated_cell/8a83fc6f8dc79037f4b681d9d7ef0bc5b91e9948/integrated_cell/models/__init__.py
--------------------------------------------------------------------------------
/integrated_cell/models/ae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from . import base_model
4 | from .. import SimpleLogger
5 |
6 | import scipy.misc  # needed for scipy.misc.imsave below
7 |
8 | from ..utils.plots import tensor2im
9 | from integrated_cell.utils import plots as plots
10 | from .. import model_utils
11 |
12 | import os
13 | import pickle
14 |
15 | import shutil
16 |
17 |
18 | class Model(base_model.Model):
19 | def __init__(
20 | self, enc, dec, opt_enc, opt_dec, crit_recon, n_display_imgs=10, **kwargs
21 | ):
22 | # Train an autoencoder
23 |
24 | super(Model, self).__init__(**kwargs)
25 |
26 | self.enc = enc
27 | self.dec = dec
28 | self.opt_enc = opt_enc
29 | self.opt_dec = opt_dec
30 |
31 | self.crit_recon = crit_recon
32 |
33 | self.n_display_imgs = n_display_imgs
34 |
35 | logger_path = "{}/logger.pkl".format(self.save_dir)
36 |
37 | if os.path.exists(logger_path):
38 | self.logger = pickle.load(open(logger_path, "rb"))
39 | else:
40 | columns = ("epoch", "iter", "reconLoss", "time")
41 | print_str = "[%d][%d] reconLoss: %.6f time: %.2f"
42 | self.logger = SimpleLogger(columns, print_str)
43 |
44 | def iteration(self):
45 |
46 | torch.cuda.empty_cache()
47 |
48 | enc, dec = self.enc, self.dec
49 | opt_enc, opt_dec = self.opt_enc, self.opt_dec
50 | crit_recon = self.crit_recon
51 |
52 | enc.train(True)
53 | dec.train(True)
54 |
55 | for p in enc.parameters():
56 | p.requires_grad = True
57 |
58 | for p in dec.parameters():
59 | p.requires_grad = True
60 |
61 | # Ignore class labels and reference information
62 | x, _, _ = self.data_provider.get_sample()
63 | x = x.cuda()
64 |
65 | opt_enc.zero_grad()
66 | opt_dec.zero_grad()
67 |
68 | #####################
69 | # train autoencoder
70 | #####################
71 |
72 | # Forward passes
73 | z = enc(x)
74 |
75 | xHat = dec(z)
76 |
77 | # Update the image reconstruction
78 | recon_loss = crit_recon(xHat, x)
79 |
80 | recon_loss.backward()
81 |
82 | opt_enc.step()
83 | opt_dec.step()
84 |
85 | zLatent = z.data.cpu()
86 | recon_loss = recon_loss.item()
87 |
88 | errors = [recon_loss]
89 |
90 | return errors, zLatent
91 |
92 | def save_progress(self):
93 | gpu_id = self.gpu_ids[0]
94 | epoch = self.get_current_epoch()
95 |
96 | data_provider = self.data_provider
97 | enc = self.enc
98 | dec = self.dec
99 |
100 | enc.train(False)
101 | dec.train(False)
102 |
103 | ###############
104 | # TRAINING DATA
105 | ###############
106 | img_inds = np.arange(self.n_display_imgs)
107 |
108 | x, _, _ = data_provider.get_sample("train", img_inds)
109 | x = x.cuda(gpu_id)
110 |
111 | with torch.no_grad():
112 | z = enc(x)
113 | xHat = dec(z)
114 |
115 | imgX = tensor2im(x.data.cpu())
116 | imgXHat = tensor2im(xHat.data.cpu())
117 | imgTrainOut = np.concatenate((imgX, imgXHat), 0)
118 |
119 | ###############
120 | # TESTING DATA
121 | ###############
122 | x, _, _ = data_provider.get_sample("validate", img_inds)
123 | x = x.cuda(gpu_id)
124 |
125 | with torch.no_grad():
126 | z = enc(x)
127 | xHat = dec(z)
128 |
129 | imgX = tensor2im(x.data.cpu())
130 | imgXHat = tensor2im(xHat.data.cpu())
131 |
132 | imgTestOut = np.concatenate((imgX, imgXHat), 0)
133 |
134 | imgOut = np.concatenate((imgTrainOut, imgTestOut))
135 |
136 | scipy.misc.imsave(
137 | "{0}/progress_{1}.png".format(self.save_dir, int(epoch - 1)), imgOut
138 | )
139 |
140 | embeddings_train = np.concatenate(self.zAll, 0)
141 |
142 | pickle.dump(
143 | embeddings_train, open("{0}/embedding.pth".format(self.save_dir), "wb")
144 | )
145 | pickle.dump(
146 | embeddings_train,
147 | open(
148 | "{0}/embedding_{1}.pth".format(self.save_dir, self.get_current_iter()),
149 | "wb",
150 | ),
151 | )
152 |
153 | pickle.dump(self.logger, open("{0}/logger_tmp.pkl".format(self.save_dir), "wb"))
154 |
155 | # History
156 | plots.history(self.logger, "{0}/history.png".format(self.save_dir))
157 |
158 | # Short History
159 | plots.short_history(self.logger, "{0}/history_short.png".format(self.save_dir))
160 |
161 | # Embedding figure
162 | plots.embeddings(embeddings_train, "{0}/embedding.png".format(self.save_dir))
163 |
164 | enc.train(True)
165 | dec.train(True)
166 |
167 | def save(self, save_dir):
168 | # for saving and loading see:
169 | # https://discuss.pytorch.org/t/how-to-save-load-torch-models/718
170 |
171 | gpu_id = self.gpu_ids[0]
172 |
173 | n_iters = self.get_current_iter()
174 |
175 | img_embeddings = np.concatenate(self.zAll, 0)
176 | pickle.dump(img_embeddings, open("{0}/embedding.pth".format(save_dir), "wb"))
177 | pickle.dump(
178 | img_embeddings,
179 | open("{0}/embedding_{1}.pth".format(save_dir, n_iters), "wb"),
180 | )
181 |
182 | enc_save_path_tmp = "{0}/enc.pth".format(save_dir)
183 | enc_save_path_final = "{0}/enc_{1}.pth".format(save_dir, n_iters)
184 | dec_save_path_tmp = "{0}/dec.pth".format(save_dir)
185 | dec_save_path_final = "{0}/dec_{1}.pth".format(save_dir, n_iters)
186 |
187 | model_utils.save_state(self.enc, self.opt_enc, enc_save_path_tmp, gpu_id)
188 | shutil.copyfile(enc_save_path_tmp, enc_save_path_final)
189 |
190 | model_utils.save_state(self.dec, self.opt_dec, dec_save_path_tmp, gpu_id)
191 | shutil.copyfile(dec_save_path_tmp, dec_save_path_final)
192 |
193 | pickle.dump(self.logger, open("{0}/logger.pkl".format(save_dir), "wb"))
194 | pickle.dump(
195 | self.logger, open("{0}/logger_{1}.pkl".format(save_dir, n_iters), "wb")
196 | )
197 |
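scipy.misc.imsave, used in save_progress above, was removed in SciPy >= 1.2. A minimal replacement sketch using imageio; the helper name and the [0, 1] float convention are assumptions, not part of this module:

import numpy as np
import imageio

def save_montage(img_out, path):
    # imageio expects integer pixels; rescale floats assumed to lie in [0, 1]
    if img_out.dtype != np.uint8:
        img_out = (np.clip(img_out, 0, 1) * 255).astype(np.uint8)
    imageio.imwrite(path, img_out)
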
--------------------------------------------------------------------------------
/integrated_cell/models/base_model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pickle
3 | import time
4 |
5 | # This is the base class for trainers
6 | class Model(object):
7 | def __init__(
8 | self,
9 | data_provider,
10 | n_epochs,
11 | gpu_ids,
12 | save_dir,
13 | save_state_iter=1,
14 | save_progress_iter=1,
15 | ):
16 |
17 | self.data_provider = data_provider
18 | self.n_epochs = n_epochs
19 |
20 | self.gpu_ids = gpu_ids
21 |
22 | self.save_dir = save_dir
23 |
24 | self.save_state_iter = save_state_iter
25 | self.save_progress_iter = save_progress_iter
26 |
27 | self.iters_per_epoch = np.ceil(len(data_provider) / data_provider.batch_size)
28 |
29 | self.zAll = list()
30 |
31 | def iteration(self):
32 | raise NotImplementedError
33 |
34 | def get_current_iter(self):
35 | return len(self.logger)
36 |
37 | def get_current_epoch(self, iteration=-1):
38 |
39 | if iteration == -1:
40 | iteration = self.get_current_iter()
41 |
42 | return np.floor(iteration / self.iters_per_epoch)
43 |
44 | def load(self):
45 | # This is where we load the model
46 | raise NotImplementedError
47 |
48 | def save(self, save_dir):
49 | # This is where we save the model
50 | raise NotImplementedError
51 |
52 | def maybe_save(self):
53 |
54 | epoch = self.get_current_epoch(self.get_current_iter() - 1)
55 | epoch_next = self.get_current_epoch(self.get_current_iter())
56 |
57 | saved = False
58 | if epoch != epoch_next:
59 | # save the logger every epoch
60 | pickle.dump(
61 | self.logger, open("{0}/logger_tmp.pkl".format(self.save_dir), "wb")
62 | )
63 |
64 | if (epoch_next % self.save_progress_iter) == 0:
65 | print("saving progress")
66 | self.save_progress()
67 |
68 | if (epoch_next % self.save_state_iter) == 0:
69 | print("saving state")
70 | self.save(self.save_dir)
71 |
72 | saved = True
73 |
74 | return saved
75 |
76 | def save_progress(self):
77 | raise NotImplementedError
78 |
79 | def total_iters(self):
80 | return int(np.ceil(self.iters_per_epoch) * self.n_epochs)
81 |
82 | def train(self):
83 | start_iter = self.get_current_iter()
84 |
85 | for this_iter in range(int(start_iter), self.total_iters()):
86 |
87 | start = time.time()
88 |
89 | errors, zLatent = self.iteration()
90 |
91 | stop = time.time()
92 | deltaT = stop - start
93 |
94 | self.logger.add(
95 | [self.get_current_epoch(), self.get_current_iter()] + errors + [deltaT]
96 | )
97 |             self.zAll.append(zLatent.detach().cpu().numpy())
98 |
99 | if self.maybe_save():
100 | self.zAll = list()
101 |
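The save cadence in maybe_save is pure epoch arithmetic: a save fires whenever floor(iter / iters_per_epoch) changes between consecutive iterations. A toy illustration with made-up numbers:

import numpy as np

iters_per_epoch = np.ceil(100 / 32)  # 100 samples, batch_size 32 -> 4 iterations per epoch
for it in range(1, 13):
    epoch_prev = np.floor((it - 1) / iters_per_epoch)
    epoch_next = np.floor(it / iters_per_epoch)
    if epoch_prev != epoch_next:
        print("iteration", it, "crosses into epoch", int(epoch_next))
# prints at iterations 4, 8, and 12; save_progress/save then run when the new
# epoch number is divisible by save_progress_iter/save_state_iter
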
--------------------------------------------------------------------------------
/integrated_cell/models/cbvaae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import os
4 | import pickle
5 | import shutil
6 |
7 | from .. import utils
8 | from . import cbvae2
9 | from .. import SimpleLogger
10 | from .. import model_utils
11 |
12 |
13 | # conditional beta variational adversarial autoencoder
14 | #
15 | # Conditioning can sometimes pull the latent-space embeddings apart by class;
16 | # the adversary on the latent space is meant to correct for that.
17 |
18 |
19 | class Model(cbvae2.Model):
20 | def __init__(
21 | self, encD, opt_encD, crit_encD, lambda_encD_loss=1, n_encD_steps=1, **kwargs
22 | ):
23 |
24 | super(Model, self).__init__(**kwargs)
25 |
26 | self.encD = encD
27 | self.opt_encD = opt_encD
28 |
29 | self.crit_encD = crit_encD
30 |
31 | self.lambda_encD_loss = lambda_encD_loss
32 | self.n_encD_steps = n_encD_steps
33 |
34 | logger_path = "{}/logger.pkl".format(self.save_dir)
35 | if os.path.exists(logger_path):
36 | self.logger = pickle.load(open(logger_path, "rb"))
37 | else:
38 | columns = ("epoch", "iter", "reconLoss")
39 | print_str = "[%d][%d] reconLoss: %.6f"
40 |
41 | columns += (
42 | "kldLossRef",
43 | "kldLossStruct",
44 | "minimaxencDLoss",
45 | "encDLoss",
46 | "time",
47 | )
48 | print_str += (
49 | " kld ref: %.6f kld struct: %.6f mmEncD: %.6f encD: %.6f time: %.2f"
50 | )
51 | self.logger = SimpleLogger(columns, print_str)
52 |
53 | def iteration(self):
54 | gpu_id = self.gpu_ids[0]
55 |
56 | enc, dec, encD = self.enc, self.dec, self.encD
57 | opt_enc, opt_dec, opt_encD = self.opt_enc, self.opt_dec, self.opt_encD
58 | crit_recon, crit_encD = self.crit_recon, self.crit_encD
59 |
60 |         # do this just in case anything upstream changes these values
61 | enc.train(True)
62 | dec.train(True)
63 | encD.train(True)
64 |
65 | # update the discriminator
66 | # maximize log(AdvZ(z)) + log(1 - AdvZ(Enc(x)))
67 |
68 | for p in encD.parameters():
69 | p.requires_grad = True
70 |
71 | for p in enc.parameters():
72 | p.requires_grad = False
73 |
74 | for p in dec.parameters():
75 | p.requires_grad = False
76 |
77 | encD_loss = 0
78 | for step_num in range(self.n_encD_steps):
79 | x, classes, ref = self.data_provider.get_sample()
80 | x = x.cuda(gpu_id)
81 |
82 | classes = classes.type_as(x).long()
83 | classes_onehot = utils.index_to_onehot(
84 | classes, self.data_provider.get_n_classes()
85 | )
86 |
87 | y_xReal = classes
88 | y_xFake = (
89 | torch.zeros(classes.shape)
90 | .fill_(self.data_provider.get_n_classes())
91 | .type_as(x)
92 | .long()
93 | )
94 |
95 | y_zReal = y_xFake
96 | y_zFake = y_xReal
97 |
98 | with torch.no_grad():
99 | zAll = enc(x, classes_onehot)
100 |
101 | zFake = self.reparameterize(zAll[-1][0], zAll[-1][1])
102 | zReal = torch.zeros(zFake.shape).type_as(zFake).normal_()
103 |
104 | opt_enc.zero_grad()
105 | opt_dec.zero_grad()
106 | opt_encD.zero_grad()
107 |
108 | ###############
109 | # train encD
110 | ###############
111 |
112 | # train with real
113 | yHat_zReal = encD(zReal)
114 | errEncD_real = crit_encD(yHat_zReal, y_zReal)
115 |
116 | # train with fake
117 | yHat_zFake = encD(zFake)
118 | errEncD_fake = crit_encD(yHat_zFake, y_zFake)
119 |
120 | encD_loss_tmp = (errEncD_real + errEncD_fake) / 2
121 | encD_loss_tmp.backward()
122 |
123 | opt_encD.step()
124 |
125 | encD_loss += encD_loss_tmp.item()
126 |
127 | for p in enc.parameters():
128 | p.requires_grad = True
129 |
130 | for p in dec.parameters():
131 | p.requires_grad = True
132 |
133 | for p in encD.parameters():
134 | p.requires_grad = False
135 |
136 | opt_enc.zero_grad()
137 | opt_dec.zero_grad()
138 | opt_encD.zero_grad()
139 |
140 | #####################
141 | # train autoencoder
142 | #####################
143 |
144 | # Forward passes
145 | z_ref, z_struct = enc(x, classes_onehot)
146 |
147 | kld_ref = self.kld_loss(z_ref[0], z_ref[1])
148 | kld_struct = self.kld_loss(z_struct[0], z_struct[1])
149 |
150 | kld_loss = kld_ref + kld_struct
151 |
152 | zAll = [z_ref, z_struct]
153 | for i in range(len(zAll)):
154 | zAll[i] = self.reparameterize(zAll[i][0], zAll[i][1])
155 |
156 | xHat = dec([classes_onehot] + zAll)
157 |
158 | recon_loss = crit_recon(xHat, x)
159 |
160 | if self.lambda_encD_loss > 0:
161 | yHat_zFake = encD(zAll[-1])
162 | minimax_encD_loss = crit_encD(yHat_zFake, y_xReal)
163 | else:
164 | minimax_encD_loss = 0
165 |
166 | beta_vae_loss = (
167 | self.vae_loss(recon_loss, kld_loss)
168 | + self.lambda_encD_loss * minimax_encD_loss
169 | )
170 |
171 | beta_vae_loss.backward()
172 | opt_enc.step()
173 | opt_dec.step()
174 |
175 | # Log a bunch of stuff
176 | recon_loss = recon_loss.item()
177 |         minimax_encD_loss = minimax_encD_loss.item() if self.lambda_encD_loss > 0 else 0
178 |
179 | kld_loss_ref = kld_ref.item()
180 | kld_loss_struct = kld_struct.item()
181 |
182 | zLatent = z_struct[0].data.cpu()
183 |
184 | errors = [
185 | recon_loss,
186 | kld_loss_ref,
187 | kld_loss_struct,
188 | minimax_encD_loss,
189 | encD_loss,
190 | ]
191 |
192 | return errors, zLatent
193 |
194 | def save(self, save_dir):
195 | # for saving and loading see:
196 | # https://discuss.pytorch.org/t/how-to-save-load-torch-models/718
197 |
198 | gpu_id = self.gpu_ids[0]
199 |
200 | n_iters = self.get_current_iter()
201 |
202 | embeddings = np.concatenate(self.zAll, 0)
203 | pickle.dump(embeddings, open("{0}/embedding.pth".format(save_dir), "wb"))
204 | pickle.dump(
205 | embeddings, open("{0}/embedding_{1}.pth".format(save_dir, n_iters), "wb")
206 | )
207 |
208 | enc_save_path_tmp = "{0}/enc.pth".format(save_dir)
209 | enc_save_path_final = "{0}/enc_{1}.pth".format(save_dir, n_iters)
210 | dec_save_path_tmp = "{0}/dec.pth".format(save_dir)
211 | dec_save_path_final = "{0}/dec_{1}.pth".format(save_dir, n_iters)
212 |
213 | encD_save_path_tmp = "{0}/encD.pth".format(save_dir)
214 | encD_save_path_final = "{0}/encD_{1}.pth".format(save_dir, n_iters)
215 |
216 | model_utils.save_state(self.enc, self.opt_enc, enc_save_path_tmp, gpu_id)
217 | shutil.copyfile(enc_save_path_tmp, enc_save_path_final)
218 |
219 | model_utils.save_state(self.dec, self.opt_dec, dec_save_path_tmp, gpu_id)
220 | shutil.copyfile(dec_save_path_tmp, dec_save_path_final)
221 |
222 | model_utils.save_state(self.encD, self.opt_encD, encD_save_path_tmp, gpu_id)
223 | shutil.copyfile(encD_save_path_tmp, encD_save_path_final)
224 |
225 | pickle.dump(self.logger, open("{0}/logger.pkl".format(save_dir), "wb"))
226 | pickle.dump(
227 | self.logger, open("{0}/logger_{1}.pkl".format(save_dir, n_iters), "wb")
228 | )
229 |
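self.reparameterize and self.kld_loss are inherited from the parent cbvae2 model and not shown in this dump. A standard sketch of what such helpers compute for a diagonal Gaussian posterior; the package's exact reduction may differ:

import torch

def reparameterize(mu, log_var):
    # z = mu + sigma * eps, eps ~ N(0, I); keeps the sampling step differentiable
    std = torch.exp(0.5 * log_var)
    return mu + torch.randn_like(std) * std

def kld_loss(mu, log_var):
    # closed-form KL( N(mu, diag(sigma^2)) || N(0, I) ), summed over the batch
    return -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
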
--------------------------------------------------------------------------------
/integrated_cell/models/cbvae2_gan_model_parallel.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import pickle
4 | import shutil
5 |
6 | from .. import utils
7 | from . import cbvae2_gan
8 | from .. import model_utils
9 |
10 |
11 | class Model(cbvae2_gan.Model):
12 | def __init__(self, **kwargs):
13 |
14 | super(Model, self).__init__(**kwargs)
15 |
16 | def iteration(self):
17 |
18 | torch.cuda.empty_cache()
19 |
20 | enc, dec, decD = self.enc, self.dec, self.decD
21 | opt_enc, opt_dec, opt_decD = self.opt_enc, self.opt_dec, self.opt_decD
22 | crit_recon, crit_decD = (self.crit_recon, self.crit_decD)
23 |
24 |         # do this just in case anything upstream changes these values
25 | enc.train(True)
26 | dec.train(True)
27 | decD.train(True)
28 |
29 | gpu_enc = self.enc.gpu_ids[0]
30 | gpu_dec = self.dec.gpu_ids[0]
31 | gpu_decD = self.decD.gpu_ids[0]
32 |
33 | # update the discriminator
34 | # maximize log(AdvZ(z)) + log(1 - AdvZ(Enc(x)))
35 |
36 | for p in decD.parameters():
37 | p.requires_grad = True
38 |
39 | for p in enc.parameters():
40 | p.requires_grad = False
41 |
42 | for p in dec.parameters():
43 | p.requires_grad = False
44 |
45 | decDLoss = 0
46 | for step_num in range(self.n_decD_steps):
47 | x, classes, ref = self.data_provider.get_sample()
48 | x = x.cuda(gpu_enc)
49 |
50 | classes = classes.type_as(x).long()
51 | classes_onehot = utils.index_to_onehot(
52 | classes, self.data_provider.get_n_classes()
53 | )
54 |
55 | y_xReal = classes.clone().cuda(gpu_decD)
56 | y_xFake = (
57 | torch.zeros(classes.shape)
58 | .fill_(self.data_provider.get_n_classes())
59 | .type_as(x)
60 | .long()
61 | .cuda(gpu_decD)
62 | )
63 | with torch.no_grad():
64 | zAll = enc(x, classes_onehot)
65 |
66 | for i in range(len(zAll)):
67 | zAll[i] = self.reparameterize(zAll[i][0], zAll[i][1]).cuda(gpu_dec)
68 | zAll[i].detach_()
69 |
70 | with torch.no_grad():
71 | xHat = dec([classes_onehot.cuda(gpu_dec)] + zAll)
72 |
73 | opt_enc.zero_grad()
74 | opt_dec.zero_grad()
75 | opt_decD.zero_grad()
76 |
77 | ##############
78 | # Train decD
79 | ##############
80 |
81 | # train with real
82 | yHat_xReal = decD(x.cuda(gpu_decD))
83 | errDecD_real = crit_decD(yHat_xReal, y_xReal)
84 |
85 | # train with fake, reconstructed
86 | yHat_xFake = decD(xHat.cuda(gpu_decD))
87 | errDecD_fake = crit_decD(yHat_xFake, y_xFake)
88 |
89 | # train with fake, sampled and decoded
90 | for z in zAll:
91 | z.normal_()
92 |
93 | with torch.no_grad():
94 | xHat = dec([classes_onehot.cuda(gpu_dec)] + zAll)
95 |
96 | yHat_xFake2 = decD(xHat.cuda(gpu_decD))
97 | errDecD_fake2 = crit_decD(yHat_xFake2, y_xFake)
98 |
99 | decDLoss_tmp = (errDecD_real + (errDecD_fake + errDecD_fake2) / 2) / 2
100 | decDLoss_tmp.backward()
101 |
102 | opt_decD.step()
103 |
104 | decDLoss += decDLoss_tmp.item()
105 |
106 | errDecD_real = None
107 | errDecD_fake = None
108 | errDecD_fake2 = None
109 |
110 | for p in enc.parameters():
111 | p.requires_grad = True
112 |
113 | for p in dec.parameters():
114 | p.requires_grad = True
115 |
116 | for p in decD.parameters():
117 | p.requires_grad = False
118 |
119 | opt_enc.zero_grad()
120 | opt_dec.zero_grad()
121 | opt_decD.zero_grad()
122 |
123 | #####################
124 | # train autoencoder
125 | #####################
126 | torch.cuda.empty_cache()
127 |
128 | # Forward passes
129 | z_ref, z_struct = enc(x, classes_onehot)
130 |
131 | kld_ref = self.kld_loss(z_ref[0], z_ref[1])
132 | kld_struct = self.kld_loss(z_struct[0], z_struct[1])
133 |
134 | kld_loss = kld_ref + kld_struct
135 |
136 | kld_loss_ref = kld_ref.item()
137 | kld_loss_struct = kld_struct.item()
138 |
139 | zLatent = z_struct[0].data.cpu()
140 |
141 | zAll = [z_ref, z_struct]
142 | for i in range(len(zAll)):
143 | zAll[i] = self.reparameterize(zAll[i][0], zAll[i][1]).cuda(gpu_dec)
144 |
145 | xHat = dec([classes_onehot.cuda(gpu_dec)] + zAll)
146 |
147 | # resample from the structure space and make sure that the reference channel stays the same
148 | # shuffle_inds = torch.randperm(x.shape[0])
149 | # zAll[-1].normal_()
150 | # xHat2 = dec([classes_onehot[shuffle_inds]] + zAll)
151 |
152 | # Update the image reconstruction
153 | recon_loss = crit_recon(xHat.cuda(gpu_enc), x)
154 |
155 | beta_vae_loss = self.vae_loss(recon_loss, kld_loss)
156 |
157 | beta_vae_loss.backward(retain_graph=True)
158 |
159 | recon_loss = recon_loss.item()
160 |
161 | opt_enc.step()
162 |
163 | if self.lambda_decD_loss > 0:
164 | for p in enc.parameters():
165 | p.requires_grad = False
166 |
167 | # update wrt decD(dec(enc(X)))
168 | yHat_xFake = decD(xHat.cuda(gpu_decD))
169 | minimaxDecDLoss = crit_decD(yHat_xFake, y_xReal)
170 |
171 | shuffle_inds = torch.randperm(x.shape[0])
172 | xHat = dec([classes_onehot[shuffle_inds].cuda(gpu_dec)] + zAll)
173 |
174 | yHat_xFake2 = decD(xHat.cuda(gpu_decD))
175 | minimaxDecDLoss2 = crit_decD(yHat_xFake2, y_xReal[shuffle_inds])
176 |
177 | minimaxDecLoss = (minimaxDecDLoss + minimaxDecDLoss2) / 2
178 | minimaxDecLoss.mul(self.lambda_decD_loss).backward()
179 | minimaxDecLoss = minimaxDecLoss.item()
180 | else:
181 | minimaxDecLoss = 0
182 |
183 | opt_dec.step()
184 |
185 | errors = [recon_loss, kld_loss_ref, kld_loss_struct, minimaxDecLoss, decDLoss]
186 |
187 | return errors, zLatent
188 |
189 | def save(self, save_dir):
190 | # for saving and loading see:
191 | # https://discuss.pytorch.org/t/how-to-save-load-torch-models/718
192 |
193 | gpu_id = self.gpu_ids[0]
194 |
195 | n_iters = self.get_current_iter()
196 |
197 | embeddings = np.concatenate(self.zAll, 0)
198 | pickle.dump(embeddings, open("{0}/embedding.pth".format(save_dir), "wb"))
199 | pickle.dump(
200 | embeddings, open("{0}/embedding_{1}.pth".format(save_dir, n_iters), "wb")
201 | )
202 |
203 | enc_save_path_tmp = "{0}/enc.pth".format(save_dir)
204 | enc_save_path_final = "{0}/enc_{1}.pth".format(save_dir, n_iters)
205 | dec_save_path_tmp = "{0}/dec.pth".format(save_dir)
206 | dec_save_path_final = "{0}/dec_{1}.pth".format(save_dir, n_iters)
207 |
208 | decD_save_path_tmp = "{0}/decD.pth".format(save_dir)
209 | decD_save_path_final = "{0}/decD_{1}.pth".format(save_dir, n_iters)
210 |
211 | model_utils.save_state(self.enc, self.opt_enc, enc_save_path_tmp, gpu_id)
212 | shutil.copyfile(enc_save_path_tmp, enc_save_path_final)
213 |
214 | model_utils.save_state(self.dec, self.opt_dec, dec_save_path_tmp, gpu_id)
215 | shutil.copyfile(dec_save_path_tmp, dec_save_path_final)
216 |
217 | model_utils.save_state(self.decD, self.opt_decD, decD_save_path_tmp, gpu_id)
218 | shutil.copyfile(decD_save_path_tmp, decD_save_path_final)
219 |
220 | pickle.dump(self.logger, open("{0}/logger.pkl".format(save_dir), "wb"))
221 | pickle.dump(
222 | self.logger, open("{0}/logger_{1}.pkl".format(save_dir, n_iters), "wb")
223 | )
224 |
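The model-parallel variant above keeps enc, dec, and decD on separate devices and moves activations between them with explicit .cuda(gpu_X) calls; autograd routes the gradients back across devices on its own. A distilled sketch of that pattern (assumes at least two visible GPUs; layer sizes are illustrative):

import torch
from torch import nn

enc = nn.Linear(32, 16).cuda(0)
dec = nn.Linear(16, 32).cuda(1)

x = torch.randn(4, 32).cuda(0)
z = enc(x)                                # computed on GPU 0
x_hat = dec(z.cuda(1))                    # activation shipped to GPU 1
loss = ((x_hat.cuda(0) - x) ** 2).mean()  # loss evaluated back on GPU 0
loss.backward()                           # gradients flow across both devices
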
--------------------------------------------------------------------------------
/integrated_cell/models/cbvae_apex.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .. import utils
3 | from . import cbvae
4 | from .bvae import reparameterize, kl_divergence
5 |
6 | from apex import amp
7 |
8 |
9 | class Model(cbvae.Model):
10 | def __init__(self, opt_level="O1", **kwargs):
11 |
12 | super(Model, self).__init__(**kwargs)
13 |
14 | is_data_parallel = False
15 | # do some hack if set up for DataParallel
16 |         if (
17 |             str(self.enc.__class__)
18 |             == "<class 'torch.nn.parallel.data_parallel.DataParallel'>"
19 |         ):
20 | is_data_parallel = True
21 | device_ids_enc = self.enc.device_ids
22 | device_ids_dec = self.dec.device_ids
23 |
24 | self.enc = self.enc.module
25 | self.dec = self.dec.module
26 |
27 | [self.enc, self.dec], [self.opt_enc, self.opt_dec] = amp.initialize(
28 | [self.enc, self.dec], [self.opt_enc, self.opt_dec], opt_level=opt_level
29 | )
30 |
31 | if is_data_parallel:
32 | self.enc = torch.nn.DataParallel(self.enc, device_ids_enc)
33 | self.dec = torch.nn.DataParallel(self.dec, device_ids_dec)
34 |
35 | def iteration(self):
36 |
37 | torch.cuda.empty_cache()
38 |
39 | gpu_id = self.gpu_ids[0]
40 |
41 | enc, dec = self.enc, self.dec
42 | opt_enc, opt_dec = self.opt_enc, self.opt_dec
43 | crit_recon = self.crit_recon
44 |
45 |         # do this just in case anything upstream changes these values
46 | enc.train(True)
47 | dec.train(True)
48 |
49 | x, classes, ref = self.data_provider.get_sample()
50 | x = x.cuda(gpu_id)
51 |
52 | classes = classes.type_as(x).long()
53 | classes_onehot = utils.index_to_onehot(
54 | classes, self.data_provider.get_n_classes()
55 | )
56 |
57 | for p in enc.parameters():
58 | p.requires_grad = True
59 |
60 | for p in dec.parameters():
61 | p.requires_grad = True
62 |
63 | opt_enc.zero_grad()
64 | opt_dec.zero_grad()
65 |
66 | #####################
67 | # train autoencoder
68 | #####################
69 |
70 | # Forward passes
71 | z_ref, z_struct = enc(x, classes_onehot)
72 |
73 | kld_ref, _, _ = kl_divergence(z_ref[0], z_ref[1])
74 | kld_struct, _, _ = kl_divergence(z_struct[0], z_struct[1])
75 |
76 | kld = kld_ref + kld_struct
77 |
78 | kld_loss_ref = kld_ref.item()
79 | kld_loss_struct = kld_struct.item()
80 |
81 | zLatent = z_struct[0].data.cpu()
82 |
83 | zAll = [z_ref, z_struct]
84 | for i in range(len(zAll)):
85 | zAll[i] = reparameterize(zAll[i][0], zAll[i][1])
86 |
87 | xHat = dec([classes_onehot] + zAll)
88 |
89 | # Update the image reconstruction
90 | recon_loss = crit_recon(xHat, x)
91 |
92 | if self.objective == "H":
93 | beta_vae_loss = recon_loss + self.beta * kld
94 | elif self.objective == "H_eps":
95 | beta_vae_loss = (
96 | recon_loss
97 | + torch.abs((self.beta * kld_ref) - x.shape[0] * 0.1)
98 | + torch.abs((self.beta * kld_struct) - x.shape[0] * 0.1)
99 | )
100 | elif self.objective == "B":
101 | C = torch.clamp(
102 | torch.Tensor(
103 | [self.c_max / self.c_iters_max * len(self.logger)]
104 | ).type_as(x),
105 | 0,
106 | self.c_max,
107 | )
108 | beta_vae_loss = recon_loss + self.gamma * (kld - C).abs()
109 |
110 |         elif self.objective == "B_eps":  # NB: currently identical to the "B" branch
111 | C = torch.clamp(
112 | torch.Tensor(
113 | [self.c_max / self.c_iters_max * len(self.logger)]
114 | ).type_as(x),
115 | 0,
116 | self.c_max,
117 | )
118 | beta_vae_loss = recon_loss + self.gamma * (kld - C).abs()
119 |
120 | elif self.objective == "A":
121 | # warmup mode
122 | beta_mult = self.beta_start + self.beta_step * self.get_current_iter()
123 | if beta_mult > self.beta_max:
124 | beta_mult = self.beta_max
125 |
126 | if beta_mult < self.beta_min:
127 | beta_mult = self.beta_min
128 |
129 | beta_vae_loss = recon_loss + beta_mult * kld
130 |
131 | with amp.scale_loss(beta_vae_loss, [opt_enc, opt_dec]) as scaled_loss:
132 | scaled_loss.backward()
133 |
134 | recon_loss = recon_loss.item()
135 |
136 | opt_enc.step()
137 | opt_dec.step()
138 |
139 | errors = [recon_loss, kld_loss_ref, kld_loss_struct]
140 |
141 | return errors, zLatent
142 |
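amp.initialize wraps the models and optimizers for mixed precision, and amp.scale_loss scales the loss before backward so fp16 gradients do not underflow; NVIDIA has since deprecated apex.amp in favor of torch.cuda.amp. A minimal single-model sketch of the same pattern:

import torch
from apex import amp

model = torch.nn.Linear(8, 8).cuda()
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
model, opt = amp.initialize(model, opt, opt_level="O1")

x = torch.randn(4, 8).cuda()
loss = model(x).pow(2).mean()
with amp.scale_loss(loss, opt) as scaled_loss:
    scaled_loss.backward()  # gradients are unscaled again before stepping
opt.step()
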
--------------------------------------------------------------------------------
/integrated_cell/models/cbvaegan_target.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import os
4 | import pickle
5 | import shutil
6 |
7 | from .. import utils
8 | from . import cbvae2_target
9 | from .. import SimpleLogger
10 | from .. import model_utils
11 |
12 |
13 | class Model(cbvae2_target.Model):
14 | def __init__(
15 | self, decD, opt_decD, crit_decD, lambda_decD_loss=1, n_decD_steps=1, **kwargs
16 | ):
17 |
18 | super(Model, self).__init__(**kwargs)
19 |
20 | self.decD = decD
21 | self.opt_decD = opt_decD
22 |
23 | self.crit_decD = crit_decD
24 |
25 | self.lambda_decD_loss = lambda_decD_loss
26 | self.n_decD_steps = n_decD_steps
27 |
28 | logger_path = "{}/logger.pkl".format(self.save_dir)
29 | if os.path.exists(logger_path):
30 | self.logger = pickle.load(open(logger_path, "rb"))
31 | else:
32 | columns = ("epoch", "iter", "reconLoss")
33 | print_str = "[%d][%d] reconLoss: %.6f"
34 |
35 | columns += ("kldLoss", "minimaxDecDLoss", "decDLoss", "time")
36 | print_str += " kld: %.6f mmDecD: %.6f decD: %.6f time: %.2f"
37 | self.logger = SimpleLogger(columns, print_str)
38 |
39 | def iteration(self):
40 |
41 | torch.cuda.empty_cache()
42 |
43 | gpu_id = self.gpu_ids[0]
44 |
45 | enc, dec, decD = self.enc, self.dec, self.decD
46 | opt_enc, opt_dec, opt_decD = self.opt_enc, self.opt_dec, self.opt_decD
47 | crit_recon, crit_decD = (self.crit_recon, self.crit_decD)
48 |
49 |         # do this just in case anything upstream changes these values
50 | enc.train(True)
51 | dec.train(True)
52 | decD.train(True)
53 |
54 | # update the discriminator
55 | # maximize log(AdvZ(z)) + log(1 - AdvZ(Enc(x)))
56 |
57 | for p in decD.parameters():
58 | p.requires_grad = True
59 |
60 | for p in enc.parameters():
61 | p.requires_grad = False
62 |
63 | for p in dec.parameters():
64 | p.requires_grad = False
65 |
66 | decDLoss = 0
67 | for step_num in range(self.n_decD_steps):
68 | x, classes, ref = self.data_provider.get_sample()
69 | x = x.cuda(gpu_id)
70 |
71 | classes = classes.type_as(x).long()
72 | classes_onehot = utils.index_to_onehot(
73 | classes, self.data_provider.get_n_classes()
74 | )
75 |
76 | ref = ref.cuda(gpu_id)
77 |
78 | y_xReal = torch.ones(classes.shape[0], 1).type_as(x)
79 | y_xFake = torch.zeros(classes.shape[0], 1).type_as(x)
80 |
81 | with torch.no_grad():
82 | mu, logsigma = enc(x, ref, classes_onehot)
83 | z = self.reparameterize(mu, logsigma)
84 | xHat = dec(z, ref, classes_onehot)
85 |
86 | opt_enc.zero_grad()
87 | opt_dec.zero_grad()
88 | opt_decD.zero_grad()
89 |
90 | ##############
91 | # Train decD
92 | ##############
93 |
94 | # train with real
95 | yHat_xReal = decD(x, ref, classes_onehot)
96 | errDecD_real = crit_decD(yHat_xReal, y_xReal)
97 |
98 | # train with fake, reconstructed
99 | yHat_xFake = decD(xHat, ref, classes_onehot)
100 | errDecD_fake = crit_decD(yHat_xFake, y_xFake)
101 |
102 | # train with fake, sampled and decoded
103 | z.normal_()
104 | shuffle_inds = torch.randperm(x.shape[0])
105 |
106 | with torch.no_grad():
107 | xHat = dec(z, ref, classes_onehot[shuffle_inds])
108 |
109 |             yHat_xFake2 = decD(xHat, ref, classes_onehot[shuffle_inds])
110 | errDecD_fake2 = crit_decD(yHat_xFake2, y_xFake)
111 |
112 | decDLoss_tmp = (errDecD_real + (errDecD_fake + errDecD_fake2) / 2) / 2
113 | decDLoss_tmp.backward()
114 |
115 | opt_decD.step()
116 |
117 | decDLoss += decDLoss_tmp.item()
118 |
119 | errDecD_real = None
120 | errDecD_fake = None
121 | errDecD_fake2 = None
122 |
123 | for p in enc.parameters():
124 | p.requires_grad = True
125 |
126 | for p in dec.parameters():
127 | p.requires_grad = True
128 |
129 | for p in decD.parameters():
130 | p.requires_grad = False
131 |
132 | opt_enc.zero_grad()
133 | opt_dec.zero_grad()
134 | opt_decD.zero_grad()
135 |
136 | #####################
137 | # train autoencoder
138 | #####################
139 | torch.cuda.empty_cache()
140 |
141 | # Forward passes
142 | mu, logsigma = enc(x, ref, classes_onehot)
143 |
144 | z = self.reparameterize(mu, logsigma)
145 |
146 | kld_loss = self.kld_loss(mu, logsigma)
147 |
148 | xHat = dec(z, ref, classes_onehot)
149 |
150 | recon_loss = crit_recon(xHat, x)
151 | beta_vae_loss = self.vae_loss(recon_loss, kld_loss)
152 | beta_vae_loss.backward(retain_graph=True)
153 |
154 | recon_loss = recon_loss.item()
155 | kld_loss = kld_loss.item()
156 | zLatent = mu.data.cpu()
157 |
158 | opt_enc.step()
159 |
160 | if self.lambda_decD_loss > 0:
161 | for p in enc.parameters():
162 | p.requires_grad = False
163 |
164 | # update wrt reconstructed images
165 | yHat_xFake = decD(xHat, ref, classes_onehot)
166 | minimaxDecDLoss = crit_decD(yHat_xFake, y_xReal)
167 |
168 | shuffle_inds = torch.randperm(x.shape[0])
169 |
170 |             # update wrt generated images
171 | z = torch.Tensor(z.shape).normal_().type_as(x)
172 | xHat = dec(z, ref, classes_onehot[shuffle_inds])
173 |
174 | yHat_xFake2 = decD(xHat, ref, classes_onehot[shuffle_inds])
175 | minimaxDecDLoss2 = crit_decD(yHat_xFake2, y_xReal)
176 |
177 | minimaxDecDLoss = (minimaxDecDLoss + minimaxDecDLoss2) / 2
178 | minimaxDecDLoss.mul(self.lambda_decD_loss).backward()
179 | minimaxDecDLoss = minimaxDecDLoss.item()
180 | else:
181 | minimaxDecDLoss = 0
182 |
183 | opt_dec.step()
184 |
185 | errors = [recon_loss, kld_loss, minimaxDecDLoss, decDLoss]
186 |
187 | return errors, zLatent
188 |
189 | def save(self, save_dir):
190 | # for saving and loading see:
191 | # https://discuss.pytorch.org/t/how-to-save-load-torch-models/718
192 |
193 | gpu_id = self.gpu_ids[0]
194 |
195 | n_iters = self.get_current_iter()
196 |
197 | embeddings = np.concatenate(self.zAll, 0)
198 | pickle.dump(embeddings, open("{0}/embedding.pth".format(save_dir), "wb"))
199 | pickle.dump(
200 | embeddings, open("{0}/embedding_{1}.pth".format(save_dir, n_iters), "wb")
201 | )
202 |
203 | enc_save_path_tmp = "{0}/enc.pth".format(save_dir)
204 | enc_save_path_final = "{0}/enc_{1}.pth".format(save_dir, n_iters)
205 | dec_save_path_tmp = "{0}/dec.pth".format(save_dir)
206 | dec_save_path_final = "{0}/dec_{1}.pth".format(save_dir, n_iters)
207 |
208 | decD_save_path_tmp = "{0}/decD.pth".format(save_dir)
209 | decD_save_path_final = "{0}/decD_{1}.pth".format(save_dir, n_iters)
210 |
211 | model_utils.save_state(self.enc, self.opt_enc, enc_save_path_tmp, gpu_id)
212 | shutil.copyfile(enc_save_path_tmp, enc_save_path_final)
213 |
214 | model_utils.save_state(self.dec, self.opt_dec, dec_save_path_tmp, gpu_id)
215 | shutil.copyfile(dec_save_path_tmp, dec_save_path_final)
216 |
217 | model_utils.save_state(self.decD, self.opt_decD, decD_save_path_tmp, gpu_id)
218 | shutil.copyfile(decD_save_path_tmp, decD_save_path_final)
219 |
220 | pickle.dump(self.logger, open("{0}/logger.pkl".format(save_dir), "wb"))
221 | pickle.dump(
222 | self.logger, open("{0}/logger_{1}.pkl".format(save_dir, n_iters), "wb")
223 | )
224 |
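The decD update above weights its three terms so that reconstructions and prior samples together count as much as the real batch: (real + (fake_recon + fake_sampled) / 2) / 2. A sketch with stand-in logits, assuming a BCE-style crit_decD:

import torch
from torch import nn

crit_decD = nn.BCEWithLogitsLoss()  # stand-in for the configured criterion
y_real, y_fake = torch.ones(8, 1), torch.zeros(8, 1)

logits_real = torch.randn(8, 1)     # decD(x, ref, onehot)
logits_recon = torch.randn(8, 1)    # decD(xHat from enc -> dec, ...)
logits_sampled = torch.randn(8, 1)  # decD(xHat from z ~ N(0, I), ...)

err_real = crit_decD(logits_real, y_real)
err_fake = (crit_decD(logits_recon, y_fake) + crit_decD(logits_sampled, y_fake)) / 2
decD_loss = (err_real + err_fake) / 2
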
--------------------------------------------------------------------------------
/integrated_cell/networks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AllenCellModeling/pytorch_integrated_cell/8a83fc6f8dc79037f4b681d9d7ef0bc5b91e9948/integrated_cell/networks/__init__.py
--------------------------------------------------------------------------------
/integrated_cell/networks/ae2D_residual.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import numpy as np
3 |
4 | from ..utils import spectral_norm
5 | from ..utils import get_activation
6 |
7 | # An autoencoder for 2D images using modified residual layers
8 | # This is effectively a 2D simplified version of vaegan3D_cgan_target2.py
9 |
10 |
11 | class PadLayer(nn.Module):
12 | def __init__(self, pad_dims):
13 | super(PadLayer, self).__init__()
14 |
15 | self.pad_dims = pad_dims
16 |
17 | def forward(self, x):
18 | if np.sum(self.pad_dims) == 0:
19 | return x
20 | else:
21 | return nn.functional.pad(
22 | x,
23 |                 [0, self.pad_dims[2], 0, self.pad_dims[1], 0, self.pad_dims[0]],  # 3 pad values, inherited from the 3D version
24 | "constant",
25 | 0,
26 | )
27 |
28 |
29 | class DownLayerResidual(nn.Module):
30 | def __init__(self, ch_in, ch_out, activation="relu", activation_last=None):
31 | super(DownLayerResidual, self).__init__()
32 |
33 | if activation_last is None:
34 | activation_last = activation
35 |
36 | self.bypass = nn.Sequential(
37 | nn.AvgPool2d(2, stride=2, padding=0),
38 | spectral_norm(nn.Conv2d(ch_in, ch_out, 1, 1, padding=0, bias=True)),
39 | )
40 |
41 | self.resid = nn.Sequential(
42 | spectral_norm(nn.Conv2d(ch_in, ch_in, 4, 2, padding=1, bias=True)),
43 | nn.BatchNorm2d(ch_in),
44 | get_activation(activation),
45 | spectral_norm(nn.Conv2d(ch_in, ch_out, 3, 1, padding=1, bias=True)),
46 | nn.BatchNorm2d(ch_out),
47 | )
48 |
49 | self.activation = get_activation(activation_last)
50 |
51 | def forward(self, x, x_proj=None, x_class=None):
52 |
53 | x = self.bypass(x) + self.resid(x)
54 |
55 | x = self.activation(x)
56 |
57 | return x
58 |
59 |
60 | class UpLayerResidual(nn.Module):
61 | def __init__(
62 | self, ch_in, ch_out, activation="relu", output_padding=0, activation_last=None
63 | ):
64 | super(UpLayerResidual, self).__init__()
65 |
66 | if activation_last is None:
67 | activation_last = activation
68 |
69 | self.bypass = nn.Sequential(
70 | spectral_norm(nn.Conv2d(ch_in, ch_out, 1, 1, padding=0, bias=True)),
71 | nn.Upsample(scale_factor=2),
72 | PadLayer(output_padding),
73 | )
74 |
75 | self.resid = nn.Sequential(
76 | spectral_norm(
77 | nn.ConvTranspose2d(
78 | ch_in,
79 | ch_in,
80 | 4,
81 | 2,
82 | padding=1,
83 | output_padding=output_padding,
84 | bias=True,
85 | )
86 | ),
87 | nn.BatchNorm2d(ch_in),
88 | get_activation(activation),
89 | spectral_norm(nn.Conv2d(ch_in, ch_out, 3, 1, padding=1, bias=True)),
90 | nn.BatchNorm2d(ch_out),
91 | )
92 |
93 | self.activation = get_activation(activation_last)
94 |
95 | def forward(self, x):
96 | x = self.bypass(x) + self.resid(x)
97 |
98 | x = self.activation(x)
99 |
100 | return x
101 |
102 |
103 | class Enc(nn.Module):
104 | def __init__(
105 | self,
106 | n_latent_dim,
107 | gpu_ids,
108 | n_ch=3,
109 | conv_channels_list=[32, 64, 128, 256, 512],
110 | imsize_compressed=[5, 3],
111 | ):
112 | super(Enc, self).__init__()
113 |
114 | self.gpu_ids = gpu_ids
115 | self.n_latent_dim = n_latent_dim
116 |
117 | self.target_path = nn.ModuleList(
118 | [DownLayerResidual(n_ch, conv_channels_list[0])]
119 | )
120 |
121 | for ch_in, ch_out in zip(conv_channels_list[0:-1], conv_channels_list[1:]):
122 | self.target_path.append(DownLayerResidual(ch_in, ch_out))
123 |
124 | ch_in = ch_out
125 |
126 | self.latent_out = spectral_norm(
127 | nn.Linear(
128 | ch_in * int(np.prod(imsize_compressed)), self.n_latent_dim, bias=True
129 | )
130 | )
131 |
132 | def forward(self, x_target):
133 |
134 | for target_path in self.target_path:
135 | x_target = target_path(x_target)
136 |
137 | x_target = x_target.view(x_target.size()[0], -1)
138 |
139 | z = self.latent_out(x_target)
140 |
141 | return z
142 |
143 |
144 | class Dec(nn.Module):
145 | def __init__(
146 | self,
147 | n_latent_dim,
148 | gpu_ids,
149 | padding_latent=[0, 0],
150 | imsize_compressed=[5, 3],
151 | n_ch=3,
152 | conv_channels_list=[512, 256, 128, 64, 32],
153 | activation_last="sigmoid",
154 | ):
155 |
156 | super(Dec, self).__init__()
157 |
158 | self.gpu_ids = gpu_ids
159 | self.padding_latent = padding_latent
160 | self.imsize_compressed = imsize_compressed
161 |
162 | self.ch_first = conv_channels_list[0]
163 |
164 | self.n_latent_dim = n_latent_dim
165 |
166 | self.n_channels = n_ch
167 |
168 | self.target_fc = spectral_norm(
169 | nn.Linear(
170 | self.n_latent_dim,
171 | conv_channels_list[0] * int(np.prod(self.imsize_compressed)),
172 | bias=True,
173 | )
174 | )
175 |
176 | self.target_bn_relu = nn.Sequential(
177 | nn.BatchNorm2d(conv_channels_list[0]), nn.ReLU(inplace=True)
178 | )
179 |
180 | self.target_path = nn.ModuleList([])
181 |
182 | l_sizes = conv_channels_list
183 | for i in range(len(l_sizes) - 1):
184 | if i == 0:
185 | padding = padding_latent
186 | else:
187 | padding = 0
188 |
189 | self.target_path.append(
190 | UpLayerResidual(l_sizes[i], l_sizes[i + 1], output_padding=padding)
191 | )
192 |
193 | self.target_path.append(
194 | UpLayerResidual(l_sizes[i + 1], n_ch, activation_last=activation_last)
195 | )
196 |
197 | def forward(self, z_target):
198 | # gpu_ids = None
199 | # if isinstance(x.data, torch.cuda.FloatTensor) and len(self.gpu_ids) > 1:
200 |
201 | x_target = self.target_fc(z_target).view(
202 | z_target.size()[0],
203 | self.ch_first,
204 | self.imsize_compressed[0],
205 | self.imsize_compressed[1],
206 | )
207 | x_target = self.target_bn_relu(x_target)
208 |
209 | for target_path in self.target_path:
210 |
211 | x_target = target_path(x_target)
212 |
213 | return x_target
214 |
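Both residual blocks change resolution by a factor of two on every spatial axis, with the 1x1-conv bypass matched to the strided main path. A quick shape check (channel sizes illustrative; assumes the package is importable):

import torch
from integrated_cell.networks.ae2D_residual import DownLayerResidual, UpLayerResidual

x = torch.randn(2, 32, 64, 48)
down = DownLayerResidual(32, 64)
up = UpLayerResidual(64, 32)

y = down(x)
print(y.shape)      # torch.Size([2, 64, 32, 24]): both paths halve H and W
print(up(y).shape)  # torch.Size([2, 32, 64, 48]): upsampling restores them
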
--------------------------------------------------------------------------------
/integrated_cell/networks/ae3D_residual.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import numpy as np
3 |
4 | from ..utils import spectral_norm
5 | from ..utils import get_activation
6 |
7 | # An autoencoder for 3D images using modified residual layers
8 | # This is effectively a simplified version of vaegan3D_cgan_target2.py
9 |
10 |
11 | class PadLayer(nn.Module):
12 | def __init__(self, pad_dims):
13 | super(PadLayer, self).__init__()
14 |
15 | self.pad_dims = pad_dims
16 |
17 | def forward(self, x):
18 | if np.sum(self.pad_dims) == 0:
19 | return x
20 | else:
21 | return nn.functional.pad(
22 | x,
23 | [0, self.pad_dims[2], 0, self.pad_dims[1], 0, self.pad_dims[0]],
24 | "constant",
25 | 0,
26 | )
27 |
28 |
29 | class DownLayerResidual(nn.Module):
30 | def __init__(self, ch_in, ch_out, activation="relu", activation_last=None):
31 | super(DownLayerResidual, self).__init__()
32 |
33 | if activation_last is None:
34 | activation_last = activation
35 |
36 | self.bypass = nn.Sequential(
37 | nn.AvgPool3d(2, stride=2, padding=0),
38 | spectral_norm(nn.Conv3d(ch_in, ch_out, 1, 1, padding=0, bias=True)),
39 | )
40 |
41 | self.resid = nn.Sequential(
42 | spectral_norm(nn.Conv3d(ch_in, ch_in, 4, 2, padding=1, bias=True)),
43 | nn.BatchNorm3d(ch_in),
44 | get_activation(activation),
45 | spectral_norm(nn.Conv3d(ch_in, ch_out, 3, 1, padding=1, bias=True)),
46 | nn.BatchNorm3d(ch_out),
47 | )
48 |
49 | self.activation = get_activation(activation_last)
50 |
51 | def forward(self, x, x_proj=None, x_class=None):
52 |
53 | x = self.bypass(x) + self.resid(x)
54 |
55 | x = self.activation(x)
56 |
57 | return x
58 |
59 |
60 | class UpLayerResidual(nn.Module):
61 | def __init__(
62 | self, ch_in, ch_out, activation="relu", output_padding=0, activation_last=None
63 | ):
64 | super(UpLayerResidual, self).__init__()
65 |
66 | if activation_last is None:
67 | activation_last = activation
68 |
69 | self.bypass = nn.Sequential(
70 | spectral_norm(nn.Conv3d(ch_in, ch_out, 1, 1, padding=0, bias=True)),
71 | nn.Upsample(scale_factor=2),
72 | PadLayer(output_padding),
73 | )
74 |
75 | self.resid = nn.Sequential(
76 | spectral_norm(
77 | nn.ConvTranspose3d(
78 | ch_in,
79 | ch_in,
80 | 4,
81 | 2,
82 | padding=1,
83 | output_padding=output_padding,
84 | bias=True,
85 | )
86 | ),
87 | nn.BatchNorm3d(ch_in),
88 | get_activation(activation),
89 | spectral_norm(nn.Conv3d(ch_in, ch_out, 3, 1, padding=1, bias=True)),
90 | nn.BatchNorm3d(ch_out),
91 | )
92 |
93 | self.activation = get_activation(activation_last)
94 |
95 | def forward(self, x):
96 | x = self.bypass(x) + self.resid(x)
97 |
98 | x = self.activation(x)
99 |
100 | return x
101 |
102 |
103 | class Enc(nn.Module):
104 | def __init__(
105 | self,
106 | n_latent_dim,
107 | gpu_ids,
108 | n_ch=3,
109 | conv_channels_list=[32, 64, 128, 256, 512],
110 | imsize_compressed=[5, 3, 2],
111 | ):
112 | super(Enc, self).__init__()
113 |
114 | self.gpu_ids = gpu_ids
115 | self.n_latent_dim = n_latent_dim
116 |
117 | self.target_path = nn.ModuleList(
118 | [DownLayerResidual(n_ch, conv_channels_list[0])]
119 | )
120 |
121 | for ch_in, ch_out in zip(conv_channels_list[0:-1], conv_channels_list[1:]):
122 | self.target_path.append(DownLayerResidual(ch_in, ch_out))
123 |
124 | ch_in = ch_out
125 |
126 | self.latent_out = spectral_norm(
127 | nn.Linear(
128 | ch_in * int(np.prod(imsize_compressed)), self.n_latent_dim, bias=True
129 | )
130 | )
131 |
132 | def forward(self, x_target):
133 |
134 | for target_path in self.target_path:
135 | x_target = target_path(x_target)
136 |
137 | x_target = x_target.view(x_target.size()[0], -1)
138 |
139 | z = self.latent_out(x_target)
140 |
141 | return z
142 |
143 |
144 | class Dec(nn.Module):
145 | def __init__(
146 | self,
147 | n_latent_dim,
148 | gpu_ids,
149 | padding_latent=[0, 0, 0],
150 | imsize_compressed=[5, 3, 2],
151 | n_ch=3,
152 | conv_channels_list=[512, 256, 128, 64, 32],
153 | activation_last="sigmoid",
154 | ):
155 |
156 | super(Dec, self).__init__()
157 |
158 | self.gpu_ids = gpu_ids
159 | self.padding_latent = padding_latent
160 | self.imsize_compressed = imsize_compressed
161 |
162 | self.ch_first = conv_channels_list[0]
163 |
164 | self.n_latent_dim = n_latent_dim
165 |
166 | self.n_channels = n_ch
167 |
168 | self.target_fc = spectral_norm(
169 | nn.Linear(
170 | self.n_latent_dim,
171 | conv_channels_list[0] * int(np.prod(self.imsize_compressed)),
172 | bias=True,
173 | )
174 | )
175 |
176 | self.target_bn_relu = nn.Sequential(
177 | nn.BatchNorm3d(conv_channels_list[0]), nn.ReLU(inplace=True)
178 | )
179 |
180 | self.target_path = nn.ModuleList([])
181 |
182 | l_sizes = conv_channels_list
183 | for i in range(len(l_sizes) - 1):
184 | if i == 0:
185 | padding = padding_latent
186 | else:
187 | padding = 0
188 |
189 | self.target_path.append(
190 | UpLayerResidual(l_sizes[i], l_sizes[i + 1], output_padding=padding)
191 | )
192 |
193 | self.target_path.append(
194 | UpLayerResidual(l_sizes[i + 1], n_ch, activation_last=activation_last)
195 | )
196 |
197 | def forward(self, z_target):
198 | # gpu_ids = None
199 | # if isinstance(x.data, torch.cuda.FloatTensor) and len(self.gpu_ids) > 1:
200 |
201 | x_target = self.target_fc(z_target).view(
202 | z_target.size()[0],
203 | self.ch_first,
204 | self.imsize_compressed[0],
205 | self.imsize_compressed[1],
206 | self.imsize_compressed[2],
207 | )
208 | x_target = self.target_bn_relu(x_target)
209 |
210 | for target_path in self.target_path:
211 |
212 | x_target = target_path(x_target)
213 |
214 | return x_target
215 |
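Dec enters the convolutional path by mapping the latent vector through one linear layer and reshaping it into the smallest 3D feature map (ch_first x imsize_compressed), which the residual up-layers then repeatedly double. The reshape step in isolation (dimensions illustrative):

import torch
from torch import nn

n_latent, ch_first, imsize_compressed = 512, 512, (5, 3, 2)
target_fc = nn.Linear(n_latent, ch_first * 5 * 3 * 2)

z = torch.randn(4, n_latent)
x = target_fc(z).view(4, ch_first, *imsize_compressed)
print(x.shape)  # torch.Size([4, 512, 5, 3, 2])
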
--------------------------------------------------------------------------------
/integrated_cell/networks/old/aaegan_compact.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch
3 | import pdb
4 | from model_utils import init_opts
5 |
6 | ksize = 4
7 | dstep = 2
8 |
9 | class Enc(nn.Module):
10 | def __init__(self, nLatentDim, nClasses, nRef, insize, nch, gpu_ids, opt=None):
11 | super(Enc, self).__init__()
12 |
13 | self.gpu_ids = gpu_ids
14 | self.fcsize = insize/64
15 |
16 | self.nLatentDim = nLatentDim
17 | self.nClasses = nClasses
18 | self.nRef = nRef
19 |
20 | self.main = nn.Sequential(
21 | nn.Conv2d(nch, 64, ksize, dstep, 1),
22 | nn.BatchNorm2d(64),
23 |
24 | nn.PReLU(),
25 | nn.Conv2d(64, 128, ksize, dstep, 1),
26 | nn.BatchNorm2d(128),
27 |
28 | nn.PReLU(),
29 | nn.Conv2d(128, 256, ksize, dstep, 1),
30 | nn.BatchNorm2d(256),
31 |
32 | nn.PReLU(),
33 | nn.Conv2d(256, 512, ksize, dstep, 1),
34 | nn.BatchNorm2d(512),
35 |
36 | nn.PReLU(),
37 | nn.Conv2d(512, 1024, ksize, dstep, 1),
38 | nn.BatchNorm2d(1024),
39 |
40 | nn.PReLU(),
41 | nn.Conv2d(1024, 1024, ksize, dstep, 1),
42 | nn.BatchNorm2d(1024),
43 |
44 | nn.PReLU()
45 | )
46 |
47 | # discriminator output
48 | self.discrim = nn.Sequential(
49 | nn.Linear(1024*int(self.fcsize**2), self.nClasses+1),
50 | # nn.BatchNorm1d(self.nClasses+1),
51 | )
52 |
53 | if self.nClasses > 0:
54 | self.discrim.add_module('end', nn.LogSoftmax())
55 | else:
56 | self.discrim.add_module('end', nn.Sigmoid())
57 |
58 | if self.nClasses > 0:
59 | self.classOut = nn.Sequential(
60 | nn.Linear(1024*int(self.fcsize**2), self.nClasses),
61 | nn.BatchNorm1d(self.nClasses),
62 | nn.LogSoftmax()
63 | )
64 |
65 | if self.nRef > 0:
66 | self.refOut = nn.Sequential(
67 | nn.Linear(1024*int(self.fcsize**2), self.nRef),
68 | nn.BatchNorm1d(self.nRef)
69 | )
70 |
71 | if self.nLatentDim > 0:
72 | self.latentOut = nn.Sequential(
73 | nn.Linear(1024*int(self.fcsize**2), self.nLatentDim),
74 | nn.BatchNorm1d(self.nLatentDim)
75 | )
76 |
77 | def forward(self, x):
78 | # gpu_ids = None
79 | # if isinstance(x.data, torch.cuda.FloatTensor) and len(self.gpu_ids) > 1:
80 | gpu_ids = self.gpu_ids
81 |
82 | x = nn.parallel.data_parallel(self.main, x, gpu_ids)
83 | x = x.view(x.size()[0], 1024*int(self.fcsize**2))
84 |
85 | xOut = list()
86 |
87 | xDiscrim = nn.parallel.data_parallel(self.discrim, x, gpu_ids)
88 | # xOut.append(xDiscrim)
89 |
90 | if self.nClasses > 0:
91 | xClasses = nn.parallel.data_parallel(self.classOut, x, gpu_ids)
92 | xOut.append(xClasses)
93 |
94 | if self.nRef > 0:
95 | xRef = nn.parallel.data_parallel(self.refOut, x, gpu_ids)
96 | xOut.append(xRef)
97 |
98 | if self.nLatentDim > 0:
99 | xLatent = nn.parallel.data_parallel(self.latentOut, x, gpu_ids)
100 | xOut.append(xLatent)
101 |
102 | return xOut, xDiscrim
103 |
104 | class Dec(nn.Module):
105 | def __init__(self, nLatentDim, nClasses, nRef, insize, nch, gpu_ids, opt=None):
106 | super(Dec, self).__init__()
107 |
108 | self.gpu_ids = gpu_ids
109 | self.fcsize = int(insize/64)
110 |
111 | self.nLatentDim = nLatentDim
112 | self.nClasses = nClasses
113 | self.nRef = nRef
114 |
115 | self.fc = nn.Linear(self.nLatentDim + self.nClasses + self.nRef, 1024*int(self.fcsize**2))
116 |
117 | self.main = nn.Sequential(
118 | # nn.BatchNorm2d(1024),
119 |
120 | nn.PReLU(),
121 | nn.ConvTranspose2d(1024, 1024, ksize, dstep, 1),
122 | nn.BatchNorm2d(1024),
123 |
124 | nn.PReLU(),
125 | nn.ConvTranspose2d(1024, 512, ksize, dstep, 1),
126 | nn.BatchNorm2d(512),
127 |
128 | nn.PReLU(),
129 | nn.ConvTranspose2d(512, 256, ksize, dstep, 1),
130 | nn.BatchNorm2d(256),
131 |
132 | nn.PReLU(),
133 | nn.ConvTranspose2d(256, 128, ksize, dstep, 1),
134 | nn.BatchNorm2d(128),
135 |
136 | nn.PReLU(),
137 | nn.ConvTranspose2d(128, 64, ksize, dstep, 1),
138 | nn.BatchNorm2d(64),
139 |
140 | nn.PReLU(),
141 | nn.ConvTranspose2d(64, nch, ksize, dstep, 1),
142 | # nn.BatchNorm2d(nch),
143 | nn.Sigmoid()
144 | )
145 |
146 | def forward(self, xIn):
147 | # gpu_ids = None
148 | # if isinstance(x.data, torch.cuda.FloatTensor) and len(self.gpu_ids) > 1:
149 | gpu_ids = self.gpu_ids
150 |
151 | x = torch.cat(xIn, 1)
152 |
153 | x = self.fc(x)
154 | x = x.view(x.size()[0], 1024, self.fcsize, self.fcsize)
155 | x = nn.parallel.data_parallel(self.main, x, gpu_ids)
156 |
157 | return x
158 |
159 | class EncD(nn.Module):
160 | def __init__(self, nlatentdim, gpu_ids, opt=None):
161 | super(EncD, self).__init__()
162 |
163 | nfc = 1024
164 |
165 | self.gpu_ids = gpu_ids
166 |
167 | self.main = nn.Sequential(
168 | nn.Linear(nlatentdim, nfc),
169 | nn.BatchNorm1d(nfc),
170 | nn.LeakyReLU(0.2, inplace=True),
171 |
172 | nn.Linear(nfc, nfc),
173 | nn.BatchNorm1d(nfc),
174 | nn.LeakyReLU(0.2, inplace=True),
175 |
176 | nn.Linear(nfc, 512),
177 | nn.BatchNorm1d(512),
178 | nn.LeakyReLU(0.2, inplace=True),
179 |
180 | nn.Linear(512, 1),
181 | nn.Sigmoid()
182 | )
183 |
184 | def forward(self, x):
185 | # gpu_ids = None
186 | # if isinstance(x.data, torch.cuda.FloatTensor) and len(self.gpu_ids) > 1:
187 | gpu_ids = self.gpu_ids
188 |
189 | x = nn.parallel.data_parallel(self.main, x, gpu_ids)
190 |
191 | return x
192 |
193 | class DecD(nn.Module):
194 | def __init__(self, nout, insize, nch, gpu_ids, opt=None):
195 | super(DecD, self).__init__()
196 |
197 | self.linear = nn.Linear(1,1)
198 |
199 | def forward(self, x):
200 | return x
201 |
202 |
203 |
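These archived aaegan encoders share one convolutional trunk and emit separate heads for class, reference, and latent codes, plus discriminator logits. A distilled CPU-runnable sketch of that multi-head pattern (dimensions illustrative; the originals run the trunk through nn.parallel.data_parallel):

import torch
from torch import nn

class MultiHeadEnc(nn.Module):
    def __init__(self, n_feat=32, n_classes=5, n_ref=3, n_latent=16):
        super().__init__()
        self.trunk = nn.Sequential(nn.Linear(64, n_feat), nn.ReLU())
        self.class_head = nn.Sequential(nn.Linear(n_feat, n_classes), nn.LogSoftmax(dim=1))
        self.ref_head = nn.Linear(n_feat, n_ref)
        self.latent_head = nn.Linear(n_feat, n_latent)
        self.discrim_head = nn.Linear(n_feat, n_classes + 1)

    def forward(self, x):
        h = self.trunk(x)
        heads = [self.class_head(h), self.ref_head(h), self.latent_head(h)]
        return heads, self.discrim_head(h)

outs, discrim = MultiHeadEnc()(torch.randn(2, 64))
print([tuple(o.shape) for o in outs], tuple(discrim.shape))
# [(2, 5), (2, 3), (2, 16)] (2, 6)
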
--------------------------------------------------------------------------------
/integrated_cell/networks/old/aaegan_compact_lrelu.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch
3 | import pdb
4 | from model_utils import init_opts
5 |
6 | ksize = 4
7 | dstep = 2
8 |
9 | class Enc(nn.Module):
10 | def __init__(self, nLatentDim, nClasses, nRef, insize, nch, gpu_ids, opt=None):
11 | super(Enc, self).__init__()
12 |
13 | self.gpu_ids = gpu_ids
14 | self.fcsize = insize/64
15 |
16 | self.nLatentDim = nLatentDim
17 | self.nClasses = nClasses
18 | self.nRef = nRef
19 |
20 | self.main = nn.Sequential(
21 | nn.Conv2d(nch, 64, ksize, dstep, 1),
22 | nn.BatchNorm2d(64),
23 |
24 | nn.LeakyReLU(0.2, inplace=True),
25 | nn.Conv2d(64, 128, ksize, dstep, 1),
26 | nn.BatchNorm2d(128),
27 |
28 | nn.LeakyReLU(0.2, inplace=True),
29 | nn.Conv2d(128, 256, ksize, dstep, 1),
30 | nn.BatchNorm2d(256),
31 |
32 | nn.LeakyReLU(0.2, inplace=True),
33 | nn.Conv2d(256, 512, ksize, dstep, 1),
34 | nn.BatchNorm2d(512),
35 |
36 | nn.LeakyReLU(0.2, inplace=True),
37 | nn.Conv2d(512, 1024, ksize, dstep, 1),
38 | nn.BatchNorm2d(1024),
39 |
40 | nn.LeakyReLU(0.2, inplace=True),
41 | nn.Conv2d(1024, 1024, ksize, dstep, 1),
42 | nn.BatchNorm2d(1024),
43 |
44 | nn.LeakyReLU(0.2, inplace=True)
45 | )
46 |
47 | # discriminator output
48 | self.discrim = nn.Sequential(
49 | nn.Linear(1024*int(self.fcsize**2), self.nClasses+1),
50 | # nn.BatchNorm1d(self.nClasses+1),
51 | )
52 |
53 | if self.nClasses > 0:
54 | self.discrim.add_module('end', nn.LogSoftmax())
55 | else:
56 | self.discrim.add_module('end', nn.Sigmoid())
57 |
58 | if self.nClasses > 0:
59 | self.classOut = nn.Sequential(
60 | nn.Linear(1024*int(self.fcsize**2), self.nClasses),
61 | nn.BatchNorm1d(self.nClasses),
62 | nn.LogSoftmax()
63 | )
64 |
65 | if self.nRef > 0:
66 | self.refOut = nn.Sequential(
67 | nn.Linear(1024*int(self.fcsize**2), self.nRef),
68 | nn.BatchNorm1d(self.nRef)
69 | )
70 |
71 | if self.nLatentDim > 0:
72 | self.latentOut = nn.Sequential(
73 | nn.Linear(1024*int(self.fcsize**2), self.nLatentDim),
74 | nn.BatchNorm1d(self.nLatentDim)
75 | )
76 |
77 | def forward(self, x):
78 | # gpu_ids = None
79 | # if isinstance(x.data, torch.cuda.FloatTensor) and len(self.gpu_ids) > 1:
80 | gpu_ids = self.gpu_ids
81 |
82 | x = nn.parallel.data_parallel(self.main, x, gpu_ids)
83 | x = x.view(x.size()[0], 1024*int(self.fcsize**2))
84 |
85 | xOut = list()
86 |
87 | xDiscrim = nn.parallel.data_parallel(self.discrim, x, gpu_ids)
88 | # xOut.append(xDiscrim)
89 |
90 | if self.nClasses > 0:
91 | xClasses = nn.parallel.data_parallel(self.classOut, x, gpu_ids)
92 | xOut.append(xClasses)
93 |
94 | if self.nRef > 0:
95 | xRef = nn.parallel.data_parallel(self.refOut, x, gpu_ids)
96 | xOut.append(xRef)
97 |
98 | if self.nLatentDim > 0:
99 | xLatent = nn.parallel.data_parallel(self.latentOut, x, gpu_ids)
100 | xOut.append(xLatent)
101 |
102 | return xOut, xDiscrim
103 |
104 | class Dec(nn.Module):
105 | def __init__(self, nLatentDim, nClasses, nRef, insize, nch, gpu_ids, opt=None):
106 | super(Dec, self).__init__()
107 |
108 | self.gpu_ids = gpu_ids
109 | self.fcsize = int(insize/64)
110 |
111 | self.nLatentDim = nLatentDim
112 | self.nClasses = nClasses
113 | self.nRef = nRef
114 |
115 | self.fc = nn.Linear(self.nLatentDim + self.nClasses + self.nRef, 1024*int(self.fcsize**2))
116 |
117 | self.main = nn.Sequential(
118 | # nn.BatchNorm2d(1024),
119 |
120 | nn.LeakyReLU(0.2, inplace=True),
121 | nn.ConvTranspose2d(1024, 1024, ksize, dstep, 1),
122 | nn.BatchNorm2d(1024),
123 |
124 | nn.LeakyReLU(0.2, inplace=True),
125 | nn.ConvTranspose2d(1024, 512, ksize, dstep, 1),
126 | nn.BatchNorm2d(512),
127 |
128 | nn.LeakyReLU(0.2, inplace=True),
129 | nn.ConvTranspose2d(512, 256, ksize, dstep, 1),
130 | nn.BatchNorm2d(256),
131 |
132 | nn.LeakyReLU(0.2, inplace=True),
133 | nn.ConvTranspose2d(256, 128, ksize, dstep, 1),
134 | nn.BatchNorm2d(128),
135 |
136 | nn.LeakyReLU(0.2, inplace=True),
137 | nn.ConvTranspose2d(128, 64, ksize, dstep, 1),
138 | nn.BatchNorm2d(64),
139 |
140 | nn.LeakyReLU(0.2, inplace=True),
141 | nn.ConvTranspose2d(64, nch, ksize, dstep, 1),
142 | # nn.BatchNorm2d(nch),
143 | nn.Sigmoid()
144 | )
145 |
146 | def forward(self, xIn):
147 | # gpu_ids = None
148 | # if isinstance(x.data, torch.cuda.FloatTensor) and len(self.gpu_ids) > 1:
149 | gpu_ids = self.gpu_ids
150 |
151 | x = torch.cat(xIn, 1)
152 |
153 | x = self.fc(x)
154 | x = x.view(x.size()[0], 1024, self.fcsize, self.fcsize)
155 | x = nn.parallel.data_parallel(self.main, x, gpu_ids)
156 |
157 | return x
158 |
159 | class EncD(nn.Module):
160 | def __init__(self, nlatentdim, gpu_ids, opt=None):
161 | super(EncD, self).__init__()
162 |
163 | nfc = 1024
164 |
165 | self.gpu_ids = gpu_ids
166 |
167 | self.main = nn.Sequential(
168 | nn.Linear(nlatentdim, nfc),
169 | nn.BatchNorm1d(nfc),
170 | nn.LeakyReLU(0.2, inplace=True),
171 |
172 | nn.Linear(nfc, nfc),
173 | nn.BatchNorm1d(nfc),
174 | nn.LeakyReLU(0.2, inplace=True),
175 |
176 | nn.Linear(nfc, 512),
177 | nn.BatchNorm1d(512),
178 | nn.LeakyReLU(0.2, inplace=True),
179 |
180 | nn.Linear(512, 1),
181 | nn.Sigmoid()
182 | )
183 |
184 | def forward(self, x):
185 | # gpu_ids = None
186 | # if isinstance(x.data, torch.cuda.FloatTensor) and len(self.gpu_ids) > 1:
187 | gpu_ids = self.gpu_ids
188 |
189 | x = nn.parallel.data_parallel(self.main, x, gpu_ids)
190 |
191 | return x
192 |
193 | class DecD(nn.Module):
194 | def __init__(self, nout, insize, nch, gpu_ids, opt=None):
195 | super(DecD, self).__init__()
196 |
197 | self.linear = nn.Linear(1,1)
198 |
199 | def forward(self, x):
200 | return x
201 |
202 |
203 |
--------------------------------------------------------------------------------
/integrated_cell/networks/old/aaegan_compact_lrelu2.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch
3 | import pdb
4 | from model_utils import init_opts
5 |
6 | ksize = 4
7 | dstep = 2
8 |
9 | class Enc(nn.Module):
10 | def __init__(self, nLatentDim, nClasses, nRef, insize, nch, gpu_ids, opt=None):
11 | super(Enc, self).__init__()
12 |
13 | self.gpu_ids = gpu_ids
14 | self.fcsize = insize/64
15 |
16 | self.nLatentDim = nLatentDim
17 | self.nClasses = nClasses
18 | self.nRef = nRef
19 |
20 | self.main = nn.Sequential(
21 | nn.Conv2d(nch, 64, ksize, dstep, 1),
22 | nn.BatchNorm2d(64),
23 |
24 | nn.LeakyReLU(0.2, inplace=True),
25 | nn.Conv2d(64, 128, ksize, dstep, 1),
26 | nn.BatchNorm2d(128),
27 |
28 | nn.LeakyReLU(0.2, inplace=True),
29 | nn.Conv2d(128, 256, ksize, dstep, 1),
30 | nn.BatchNorm2d(256),
31 |
32 | nn.LeakyReLU(0.2, inplace=True),
33 | nn.Conv2d(256, 512, ksize, dstep, 1),
34 | nn.BatchNorm2d(512),
35 |
36 | nn.LeakyReLU(0.2, inplace=True),
37 | nn.Conv2d(512, 1024, ksize, dstep, 1),
38 | nn.BatchNorm2d(1024),
39 |
40 | nn.LeakyReLU(0.2, inplace=True),
41 | nn.Conv2d(1024, 1024, ksize, dstep, 1),
42 | nn.BatchNorm2d(1024),
43 |
44 | nn.LeakyReLU(0.2, inplace=True)
45 | )
46 |
47 | # discriminator output
48 | self.discrim = nn.Sequential(
49 | nn.Linear(1024*int(self.fcsize**2), self.nClasses+1),
50 | # nn.BatchNorm1d(self.nClasses+1),
51 | )
52 |
53 | if self.nClasses > 0:
54 | self.discrim.add_module('end', nn.LogSoftmax())
55 | else:
56 | self.discrim.add_module('end', nn.Sigmoid())
57 |
58 | if self.nClasses > 0:
59 | self.classOut = nn.Sequential(
60 | nn.Linear(1024*int(self.fcsize**2), self.nClasses),
61 | nn.BatchNorm1d(self.nClasses),
62 | nn.LogSoftmax()
63 | )
64 |
65 | if self.nRef > 0:
66 | self.refOut = nn.Sequential(
67 | nn.Linear(1024*int(self.fcsize**2), self.nRef),
68 | nn.BatchNorm1d(self.nRef)
69 | )
70 |
71 | if self.nLatentDim > 0:
72 | self.latentOut = nn.Sequential(
73 | nn.Linear(1024*int(self.fcsize**2), self.nLatentDim),
74 | nn.BatchNorm1d(self.nLatentDim)
75 | )
76 |
77 | def forward(self, x):
78 | # gpu_ids = None
79 | # if isinstance(x.data, torch.cuda.FloatTensor) and len(self.gpu_ids) > 1:
80 | gpu_ids = self.gpu_ids
81 |
82 | x = nn.parallel.data_parallel(self.main, x, gpu_ids)
83 | x = x.view(x.size()[0], 1024*int(self.fcsize**2))
84 |
85 | xOut = list()
86 |
87 | xDiscrim = nn.parallel.data_parallel(self.discrim, x, gpu_ids)
88 | # xOut.append(xDiscrim)
89 |
90 | if self.nClasses > 0:
91 | xClasses = nn.parallel.data_parallel(self.classOut, x, gpu_ids)
92 | xOut.append(xClasses)
93 |
94 | if self.nRef > 0:
95 | xRef = nn.parallel.data_parallel(self.refOut, x, gpu_ids)
96 | xOut.append(xRef)
97 |
98 | if self.nLatentDim > 0:
99 | xLatent = nn.parallel.data_parallel(self.latentOut, x, gpu_ids)
100 | xOut.append(xLatent)
101 |
102 | return xOut, xDiscrim
103 |
104 | class Dec(nn.Module):
105 | def __init__(self, nLatentDim, nClasses, nRef, insize, nch, gpu_ids, opt=None):
106 | super(Dec, self).__init__()
107 |
108 | self.gpu_ids = gpu_ids
109 | self.fcsize = int(insize/64)
110 |
111 | self.nLatentDim = nLatentDim
112 | self.nClasses = nClasses
113 | self.nRef = nRef
114 |
115 | self.fc = nn.Linear(self.nLatentDim + self.nClasses + self.nRef, 1024*int(self.fcsize**2))
116 |
117 | self.main = nn.Sequential(
118 | # nn.BatchNorm2d(1024),
119 |
120 | nn.ReLU(inplace=True),
121 | nn.ConvTranspose2d(1024, 1024, ksize, dstep, 1),
122 | nn.BatchNorm2d(1024),
123 |
124 | nn.ReLU(inplace=True),
125 | nn.ConvTranspose2d(1024, 512, ksize, dstep, 1),
126 | nn.BatchNorm2d(512),
127 |
128 | nn.ReLU(inplace=True),
129 | nn.ConvTranspose2d(512, 256, ksize, dstep, 1),
130 | nn.BatchNorm2d(256),
131 |
132 | nn.ReLU(inplace=True),
133 | nn.ConvTranspose2d(256, 128, ksize, dstep, 1),
134 | nn.BatchNorm2d(128),
135 |
136 | nn.ReLU(inplace=True),
137 | nn.ConvTranspose2d(128, 64, ksize, dstep, 1),
138 | nn.BatchNorm2d(64),
139 |
140 | nn.ReLU(inplace=True),
141 | nn.ConvTranspose2d(64, nch, ksize, dstep, 1),
142 | # nn.BatchNorm2d(nch),
143 | nn.Sigmoid()
144 | )
145 |
146 | def forward(self, xIn):
147 | # gpu_ids = None
148 | # if isinstance(x.data, torch.cuda.FloatTensor) and len(self.gpu_ids) > 1:
149 | gpu_ids = self.gpu_ids
150 |
151 | x = torch.cat(xIn, 1)
152 |
153 | x = self.fc(x)
154 | x = x.view(x.size()[0], 1024, self.fcsize, self.fcsize)
155 | x = nn.parallel.data_parallel(self.main, x, gpu_ids)
156 |
157 | return x
158 |
159 | class EncD(nn.Module):
160 | def __init__(self, nlatentdim, gpu_ids, opt=None):
161 | super(EncD, self).__init__()
162 |
163 | nfc = 1024
164 |
165 | self.gpu_ids = gpu_ids
166 |
167 | self.main = nn.Sequential(
168 | nn.Linear(nlatentdim, nfc),
169 | nn.BatchNorm1d(nfc),
170 | nn.LeakyReLU(0.2, inplace=True),
171 |
172 | nn.Linear(nfc, nfc),
173 | nn.BatchNorm1d(nfc),
174 | nn.LeakyReLU(0.2, inplace=True),
175 |
176 | nn.Linear(nfc, 512),
177 | nn.BatchNorm1d(512),
178 | nn.LeakyReLU(0.2, inplace=True),
179 |
180 | nn.Linear(512, 1),
181 | nn.Sigmoid()
182 | )
183 |
184 | def forward(self, x):
185 | # gpu_ids = None
186 | # if isinstance(x.data, torch.cuda.FloatTensor) and len(self.gpu_ids) > 1:
187 | gpu_ids = self.gpu_ids
188 |
189 | x = nn.parallel.data_parallel(self.main, x, gpu_ids)
190 |
191 | return x
192 |
193 | class DecD(nn.Module):
194 | def __init__(self, nout, insize, nch, gpu_ids, opt=None):
195 | super(DecD, self).__init__()
196 |
197 | self.linear = nn.Linear(1,1)
198 |
199 | def forward(self, x):
200 | return x
201 |
202 |
203 |
--------------------------------------------------------------------------------
/integrated_cell/networks/proto/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AllenCellModeling/pytorch_integrated_cell/8a83fc6f8dc79037f4b681d9d7ef0bc5b91e9948/integrated_cell/networks/proto/__init__.py
--------------------------------------------------------------------------------
/integrated_cell/networks/ref_target_autoencoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from .. import utils
4 |
5 |
6 | class Autoencoder(nn.Module):
7 | def __init__(self, ref_enc, ref_dec, target_enc, target_dec):
8 | super(Autoencoder, self).__init__()
9 |
10 | self.ref_enc = ref_enc
11 | self.ref_dec = ref_dec
12 | self.ref_n_latent_dim = ref_enc.n_latent_dim
13 |
14 | self.target_enc = target_enc
15 | self.target_dec = target_dec
16 | self.target_n_latent_dim = target_enc.n_latent_dim
17 |
18 | def encode(self, target, ref, labels):
19 | batch_size = labels.shape[0]
20 |
21 | if ref is None:
22 | # generate ref image
23 | self.z_ref = (
24 | torch.zeros([batch_size, self.ref_n_latent_dim])
25 | .type_as(labels)
26 | .normal_()
27 | )
28 |
29 | else:
30 | self.z_ref = self.ref_enc(ref)
31 |
32 | if target is None:
33 |
34 | self.z_target = (
35 | torch.zeros([batch_size, self.target_n_latent_dim])
36 | .type_as(labels)
37 | .normal_()
38 | )
39 |
40 | else:
41 | self.z_target = self.target_enc(target, ref, labels)
42 |
43 | return self.z_ref, self.z_target
44 |
45 | def forward(self, target, ref, labels):
46 |
47 | batch_size = labels.shape[0]
48 |
49 | if ref is None:
50 | # generate ref image
51 | self.z_ref = (
52 | torch.zeros([batch_size, self.ref_n_latent_dim])
53 | .type_as(labels)
54 | .normal_()
55 | )
56 |
57 | ref_hat = self.ref_dec(self.z_ref)
58 | ref = ref_hat
59 | else:
60 |
61 | self.z_ref = self.ref_enc(ref)
62 | self.z_ref_sampled = utils.reparameterize(self.z_ref[0], self.z_ref[1])
63 |
64 | ref_hat = self.ref_dec(self.z_ref_sampled)
65 |
66 | if target is None:
67 |
68 | self.z_target = (
69 | torch.zeros([batch_size, self.target_n_latent_dim])
70 | .type_as(labels)
71 | .normal_()
72 | )
73 |
74 | target_hat = self.target_dec(self.z_target, ref, labels)
75 | else:
76 | self.z_target = self.target_enc(target, ref, labels)
77 | self.z_target_sampled = utils.reparameterize(
78 | self.z_target[0], self.z_target[1]
79 | )
80 | target_hat = self.target_dec(self.z_target_sampled, ref, labels)
81 |
82 | return target_hat, ref_hat
83 |
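Note: the utils.reparameterize calls above apply the standard VAE reparameterization trick to the (mu, log_var) pairs returned by the encoders. A minimal sketch of what such a function computes (not the repository's exact implementation):

    import torch

    def reparameterize_sketch(mu, log_var):
        # z = mu + sigma * eps with eps ~ N(0, I); keeps sampling differentiable
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std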
--------------------------------------------------------------------------------
/integrated_cell/param_search.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import subprocess
4 | import time
5 | import tqdm
6 | import itertools
7 | import pandas as pd
8 |
9 | import pdb
10 |
11 | parent_dir = '/root/results/integrated_cell/param_search'
12 | if not os.path.exists(parent_dir):
13 | os.makedirs(parent_dir)
14 |
15 | scripts_dir = parent_dir + os.sep + 'scripts'
16 | if not os.path.exists(scripts_dir):
17 | os.makedirs(scripts_dir)
18 |
19 |
20 | job_list_path = parent_dir + os.sep + 'job_list.csv'
21 |
22 | job_template = 'cd .. \n' \
23 | 'python /root/projects/pytorch_integrated_cell/train_model.py \\\n' \
24 | '\t--gpu_ids {0} \\\n' \
25 | '\t--save_dir {1} \\\n' \
26 | '\t--model_name aaegan3Dv6-exp-half \\\n' \
27 | '\t--lrEnc {2} --lrDec {2} --lrEncD {3} --lrDecD {2} \\\n' \
28 | '\t--encDRatio {4} --decDRatio {5} \\\n' \
29 | '\t--noise {6} \\\n' \
30 | '\t--nlatentdim 128 \\\n' \
31 | '\t--batch_size 30 \\\n' \
32 | '\t--nepochs 25 --nepochs_pt2 0 \\\n' \
33 | '\t--train_module aaegan_trainv6 \\\n' \
34 | '\t--imdir /root/data/ipp/ipp_17_10_25 \\\n' \
35 | '\t--dataProvider DataProvider3Dh5-half \\\n' \
36 | '\t--saveStateIter 1 --saveProgressIter 1 \\\n' \
37 | '\t--channels_pt1 0 2 --channels_pt2 0 1 2 \\\n'
38 |
39 | def format_job(string, gpu_id, row):
40 | return string.format(gpu_id, row['job_dir'], row['lr'], row['lrDec'], row['encDRatio'], row['decDRatio'], row['noise'])
41 |
42 |
43 | # job_template = 'sleep {0!s}'
44 | # def format_job(string):
45 | # return string.format(np.random.choice(10, 1)[0])
46 | job_path_template = scripts_dir + os.sep + 'param_job_{0!s}.sh'
47 | job_dir_template = parent_dir + os.sep + 'param_job_{0!s}'
48 |
49 | if os.path.exists(job_list_path):
50 | job_list = pd.read_csv(job_list_path)
51 | else:
52 | lr_range = [5E-4, 2E-4, 1E-4, 5E-5, 1E-5, 5E-6]
53 | lrDec_range = [1E-1, 5E-2, 1E-2, 5E-3, 1E-3, 1E-4, 5E-5, 1E-5]
54 | encDRatio_range = [1E-1, 5E-2, 1E-2, 5E-3, 1E-3, 5E-4, 1E-4, 5E-5, 1E-5, 5E-6, 1E-6]
55 | decDRatio_range = encDRatio_range
56 | noise_range = [0.1, 0.01, 0.001]
57 |
58 | param_opts = itertools.product(lr_range, lrDec_range, encDRatio_range, decDRatio_range, noise_range)
59 |
60 | opts_list = [list(i) for i in param_opts]
61 |
62 | job_scripts = [job_path_template.format(i) for i in range(0, len(opts_list))]
63 | job_dirs = [job_dir_template.format(i) for i in range(0, len(opts_list))]
64 |
65 | all_opts = [opt + [job_dir] + [job_script] for opt, job_dir, job_script in zip(opts_list, job_dirs, job_scripts)]
66 |
67 | job_list = pd.DataFrame(all_opts, columns = ['lr', 'lrDec', 'encDRatio', 'decDRatio', 'noise', 'job_dir', 'job_script'])
68 |
69 | job_list.to_csv(job_list_path)
70 |
71 |
72 | # randomly shuffle
73 | job_list = job_list.sample(frac=1)
74 |
75 |
76 |
77 |
78 | processes = range(0, 8)
79 | job_process_list = [-1] * len(processes)
80 |
81 | #for every job
82 | for index, row in job_list.iterrows():
83 |
84 | #check to see if the job exists
85 | job_script = row['job_script']
86 | if os.path.exists(job_script):
87 | continue
88 |
89 | # if it doesn't exist, wait for a free process slot
90 | job_status = None
91 |
92 | #look for a process that isn't running
93 | while job_status is None:
94 | for job_process, process_id in zip(job_process_list, processes):
95 |
96 | #check process status
97 | if job_process == -1:
98 | job_status = True
99 | break
100 | else:
101 | job_status = job_process.poll()
102 |
103 | if job_status is not None:
104 | break
105 |
106 | if job_status is not None:
107 | break
108 |
109 | print('Waiting for next job')
110 | time.sleep(3)
111 |
112 | #start a job
113 | #get the job string
114 | job_str = format_job(job_template, process_id, row)
115 |
116 | #write the bash file to disk
117 |
118 | print('starting ' + job_script + ' on process ' + str(process_id))
119 | with open(job_script, 'w') as text_file:
120 | text_file.write(job_str)
121 |
122 | # pdb.set_trace()  # debug breakpoint; disabled so jobs launch unattended
123 | job_process_list[process_id] = subprocess.Popen('bash ' + job_script, stdout=subprocess.PIPE, shell=True)
124 |
125 | print(job_template)
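Note: format_job only substitutes positional fields into job_template; {2} feeds --lrEnc/--lrDec/--lrDecD, while {3}, filled from the 'lrDec' column, feeds --lrEncD. A small illustration with hypothetical values:

    row = {
        'job_dir': '/root/results/integrated_cell/param_search/param_job_0',
        'lr': 5e-4, 'lrDec': 1e-2,
        'encDRatio': 1e-3, 'decDRatio': 1e-3,
        'noise': 0.1,
    }
    print(format_job(job_template, 0, row))
    # prints a bash snippet calling train_model.py with --gpu_ids 0,
    # --save_dir .../param_job_0, --lrEnc 0.0005, --lrEncD 0.01, etc.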
--------------------------------------------------------------------------------
/integrated_cell/simplelogger.py:
--------------------------------------------------------------------------------
1 | class SimpleLogger:
2 | def __init__(self, fields, print_format=""):
3 | if isinstance(print_format, str) and not print_format:
4 | printstr = ""
5 | for field in fields:
6 | printstr = printstr + field + ": %f "
7 | print_format = printstr
8 |
9 | self.print_format = print_format
10 |
11 | self.fields = fields
12 |
13 | self.log = dict()
14 | for field in fields:
15 | self.log[field] = []
16 |
17 | def add(self, inputs):
18 |
19 | assert len(inputs) == len(self.fields)
20 |
21 | for i in range(0, len(self.fields)):
22 | self.log[self.fields[i]].append(inputs[i])
23 |
24 | if isinstance(self.print_format, str):
25 | print(self.print_format % tuple(inputs))
26 |
27 | def __len__(self):
28 | return len(self.log[self.fields[0]])
29 |
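Note: a short usage sketch for SimpleLogger with hypothetical fields; the default print format renders every value with %f:

    logger = SimpleLogger(["epoch", "loss"])
    logger.add([1, 0.523])  # prints "epoch: 1.000000 loss: 0.523000 "
    logger.add([2, 0.401])
    print(len(logger))      # 2 entries logged per field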
--------------------------------------------------------------------------------
/integrated_cell/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AllenCellModeling/pytorch_integrated_cell/8a83fc6f8dc79037f4b681d9d7ef0bc5b91e9948/integrated_cell/tests/__init__.py
--------------------------------------------------------------------------------
/integrated_cell/tests/conftest.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | from pathlib import Path
5 | import pytest
6 |
7 |
8 | @pytest.fixture
9 | def data_dir() -> Path:
10 | return Path(__file__).parent / "resources"
11 |
--------------------------------------------------------------------------------
/integrated_cell/tests/resources/img2D.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AllenCellModeling/pytorch_integrated_cell/8a83fc6f8dc79037f4b681d9d7ef0bc5b91e9948/integrated_cell/tests/resources/img2D.png
--------------------------------------------------------------------------------
/integrated_cell/tests/resources/img3D.tiff:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AllenCellModeling/pytorch_integrated_cell/8a83fc6f8dc79037f4b681d9d7ef0bc5b91e9948/integrated_cell/tests/resources/img3D.tiff
--------------------------------------------------------------------------------
/integrated_cell/tests/resources/test_data.csv:
--------------------------------------------------------------------------------
1 | CellId,StructureId/Name,CellLineId/Name,save_reg_path,save_reg_path_flat,ch_dna,ch_memb,ch_seg_cell,ch_seg_nuc,ch_struct,ch_trans
2 | 0,Lysotubules,AICS-Zero,./img3D.tiff,././img2D.png,2,3,1,0,4,5
3 | 1,Lysotubules,AICS-Zero,./img3D.tiff,././img2D.png,2,3,1,0,4,5
4 | 2,Lysotubules,AICS-Zero,./img3D.tiff,././img2D.png,2,3,1,0,4,5
5 | 3,Lysotubules,AICS-Zero,./img3D.tiff,././img2D.png,2,3,1,0,4,5
6 | 4,Lysotubules,AICS-Zero,./img3D.tiff,././img2D.png,2,3,1,0,4,5
7 | 5,Mitotubules,AICS-One,./img3D.tiff,././img2D.png,2,3,1,0,4,5
8 | 6,Mitotubules,AICS-One,./img3D.tiff,././img2D.png,2,3,1,0,4,5
9 | 7,Mitotubules,AICS-One,./img3D.tiff,././img2D.png,2,3,1,0,4,5
10 | 8,Mitotubules,AICS-One,./img3D.tiff,././img2D.png,2,3,1,0,4,5
11 | 9,Mitotubules,AICS-One,./img3D.tiff,././img2D.png,2,3,1,0,4,5
12 | 10,Riboplasmictubules,AICS-Three,./img3D.tiff,././img2D.png,2,3,1,0,4,5
13 | 11,Riboplasmictubules,AICS-Three,./img3D.tiff,././img2D.png,2,3,1,0,4,5
14 | 12,Riboplasmictubules,AICS-Three,./img3D.tiff,././img2D.png,2,3,1,0,4,5
15 | 13,Riboplasmictubules,AICS-Three,./img3D.tiff,././img2D.png,2,3,1,0,4,5
16 | 14,Riboplasmictubules,AICS-Three,./img3D.tiff,././img2D.png,2,3,1,0,4,5
17 | 15,FocalAdhesionTubules,AICS-Eleven,./img3D.tiff,././img2D.png,2,3,1,0,4,5
18 | 16,FocalAdhesionTubules,AICS-Eleven,./img3D.tiff,././img2D.png,2,3,1,0,4,5
19 | 17,FocalAdhesionTubules,AICS-Eleven,./img3D.tiff,././img2D.png,2,3,1,0,4,5
20 | 18,FocalAdhesionTubules,AICS-Eleven,./img3D.tiff,././img2D.png,2,3,1,0,4,5
21 | 19,FocalAdhesionTubules,AICS-Eleven,./img3D.tiff,././img2D.png,2,3,1,0,4,5
--------------------------------------------------------------------------------
/integrated_cell/tests/test_dataprovider.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from ..data_providers import DataProvider
4 |
5 | import pandas as pd
6 |
7 | # TODO: Add cropping, rescaling, normalization tests
8 |
9 |
10 | @pytest.mark.parametrize(
11 | "batch_size, n_dat, hold_out, channelInds, return2D, rescale_to, crop_to, normalize_intensity",
12 | [
13 | [1, -1, 0.1, [0, 1, 2], False, None, None, False],
14 | [2, -1, 0.1, [0, 1, 2], False, None, None, False],
15 | [2, -1, 0.1, [0, 1, 2], True, None, None, False],
16 | ],
17 | )
18 | def test_dataprovider(
19 | data_dir,
20 | batch_size,
21 | n_dat,
22 | hold_out,
23 | channelInds,
24 | return2D,
25 | rescale_to,
26 | crop_to,
27 | normalize_intensity,
28 | ):
29 | csv_name = "test_data.csv"
30 | test_data_csv = "{}/{}".format(data_dir, csv_name)
31 |
32 | df = pd.read_csv(test_data_csv)
33 |
34 | dp = DataProvider.DataProvider(
35 | image_parent=data_dir,
36 | csv_name=csv_name,
37 | batch_size=batch_size,
38 | n_dat=n_dat,
39 | hold_out=hold_out,
40 | channelInds=channelInds,
41 | return2D=return2D,
42 | rescale_to=rescale_to,
43 | crop_to=crop_to,
44 | normalize_intensity=normalize_intensity,
45 | )
46 |
47 | n_dat = 0
48 | for group in ["train", "test", "validate"]:
49 | n_dat += dp.get_n_dat(group)
50 |
51 | assert n_dat == len(df)
52 |
53 | x, classes, ref = dp.get_sample()
54 |
55 | assert x.shape[0] == batch_size
56 | assert x.shape[1] == len(channelInds)
57 |
58 | if return2D:
59 | assert len(x.shape) == 4
60 | else:
61 | assert len(x.shape) == 5
62 |
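Note: these parametrized cases can also be run programmatically (equivalent to invoking the pytest CLI on this module):

    import pytest

    # run only this module's tests; assumes the package is importable,
    # e.g. after an editable install of the repository
    pytest.main(["-q", "integrated_cell/tests/test_dataprovider.py"])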
--------------------------------------------------------------------------------
/integrated_cell/tests/test_dummy.py:
--------------------------------------------------------------------------------
1 | def test_dummy():
2 | assert True
3 |
--------------------------------------------------------------------------------
/integrated_cell/tests/test_kld.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 |
4 | def test_kld():
5 | with pytest.raises(NotImplementedError):
6 | raise NotImplementedError
7 |
--------------------------------------------------------------------------------
/integrated_cell/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .spectral_norm import spectral_norm, remove_spectral_norm # noqa
2 | from .utils import * # noqa
3 |
--------------------------------------------------------------------------------
/integrated_cell/utils/build_control_images.py:
--------------------------------------------------------------------------------
1 | import tifffile
2 | import pandas as pd
3 | import os
4 | from tqdm import tqdm
5 | import numpy as np
6 | import scipy.misc  # scipy.misc.imsave is used below (available in SciPy < 1.2)
7 | from scipy.stats import norm
8 |
9 | from .utils import str2rand
10 | from .imgToProjection import imgtoprojection
11 |
12 |
13 | # Most of this is copied from the IPP to replicate its processing when making these control images.
14 | # This should be moved into the IPP later - GRJ 05/2019
15 |
16 |
17 | def build_control_images(
18 | save_dir,
19 | csv_path,
20 | image_parent,
21 | split_seed=615,
22 | verbose=True,
23 | target_col="StructureId/Name",
24 | ):
25 | # constructs "control" images that contain the following content as the "structure" channel:
26 | # DNA label
27 | # Membrane label
28 | # Random structure from a different cell
29 | # Blank (all zeros)
30 | # Gaussian noise (ch_noise below samples a normal distribution clipped to [0, 255])
31 |
32 | # make a dataframe out of the csv log file
33 | if verbose:
34 | print("reading csv manifest")
35 | csv_df = pd.read_csv(csv_path)
36 |
37 | image_col = "save_reg_path"
38 |
39 | csv_df["ch_memb"] = 3
40 | csv_df["ch_struct"] = 4
41 | csv_df["ch_dna"] = 2
42 | csv_df["ch_seg_cell"] = 1
43 | csv_df["ch_seg_nuc"] = 0
44 | csv_df["ch_trans"] = 5
45 |
46 | ch_names = ["nuc", "cell", "dna", "memb", "struct", "trans"]
47 |
48 | im_paths = list()
49 |
50 | for i, im_path in enumerate(csv_df[image_col]):
51 | splits = np.array(im_path.split("/"))
52 | lens = np.array([len(s) for s in splits])
53 | splits = splits[lens > 0]
54 | im_paths += ["/".join(splits[-2::])]
55 |
56 | csv_df[image_col] = im_paths
57 |
58 | # csv_df = check_files(csv_df, image_parent, image_col, verbose)
59 |
60 | # pick some cells at random
61 | image_classes = list(csv_df[target_col])
62 | [label_names, labels] = np.unique(image_classes, return_inverse=True)
63 |
64 | n_classes = len(label_names)
65 | fract_sample_labels = 1 / n_classes
66 |
67 | rand = str2rand(csv_df["CellId"], split_seed)
68 | rand_struct = np.argsort(str2rand(csv_df["CellId"], split_seed + 1))
69 |
70 | duplicate_inds = rand <= fract_sample_labels
71 |
72 | df_controls = csv_df[duplicate_inds].copy().reset_index()
73 |
74 | random_scale = norm.ppf(0.999)
75 |
76 | channels_to_make = ["DNA", "Memb", "Blank", "Noise", "Random"]
77 |
78 | if not os.path.exists(save_dir):
79 | os.makedirs(save_dir)
80 |
81 | im = load_image(csv_df.iloc[0], image_parent, image_col)
82 | n_ch = im.shape[0]
83 |
84 | for index in tqdm(range(len(df_controls))):
85 | row = df_controls.iloc[index]
86 |
87 | save_path = "{}/control_{}.tiff".format(save_dir, index)
88 |
89 | save_path_flats = [
90 | "{}/control_{}_{}_flat.png".format(save_dir, index, ch_name)
91 | for ch_name in channels_to_make
92 | ]
93 | save_path_flat_projs = [
94 | "{}/control_{}_{}_flat_proj.png".format(save_dir, index, ch_name)
95 | for ch_name in channels_to_make
96 | ]
97 |
98 | try:
99 | im = load_image(row, image_parent, image_col)
100 | except: # noqa
101 | print("Skipping image {}".format(row[image_col]))
102 | continue
103 |
104 | n_ch = im.shape[0]
105 |
106 | ch_dna = im[row["ch_dna"]] * (im[row["ch_seg_nuc"]] > 0)
107 | ch_memb = im[row["ch_memb"]]
108 |
109 | ch_blank = np.zeros(im[0].shape)
110 |
111 | ch_noise = (np.random.normal(0, 1, im[0].shape) / random_scale) * 255
112 | ch_noise[ch_noise < 0] = 0
113 | ch_noise[ch_noise > 255] = 255
114 |
115 | random_image_row = csv_df.iloc[rand_struct[index]]
116 | ch_random = load_image(random_image_row, image_parent, image_col)[
117 | random_image_row["ch_struct"]
118 | ]
119 |
120 | im_out = np.concatenate(
121 | [im]
122 | + [
123 | np.expand_dims(ch, 0)
124 | for ch in [ch_dna, ch_memb, ch_blank, ch_noise, ch_random]
125 | ],
126 | 0,
127 | )
128 |
129 | im_tmp = im_out.astype("uint8")
130 | tifffile.imsave(save_path, im_tmp)
131 |
132 | im_crop = crop_cell_nuc(im_out, ch_names + channels_to_make)
133 |
134 | row_list = list()
135 | for i, channel in enumerate(channels_to_make):
136 | name = "Control - {}".format(channel)
137 |
138 | row = row.copy()
139 |
140 | row["StructureDisplayName"] = name
141 | row["StructureId"] = -i
142 | row["StructureShortName"] = name
143 | row["ProteinId/Name"] = name
144 | row["StructureId/Name"] = name
145 |
146 | row["save_reg_path"] = save_path
147 | row["save_reg_path_flat"] = save_path_flats[i]
148 | row["save_reg_path_flat_proj"] = save_path_flat_projs[i]
149 |
150 | row["ch_struct"] = n_ch + i
151 |
152 | row_list.append(row)
153 |
154 | im_tmp = im_crop[[2, 3, n_ch + i]]
155 | colors = [[0, 0, 1], [1, 0, 0], [0, 1, 0]]
156 |
157 | im_flat = imgtoprojection(
158 | im_tmp / 255,
159 | proj_all=True,
160 | proj_method="max",
161 | local_adjust=False,
162 | global_adjust=True,
163 | colors=colors,
164 | )
165 | scipy.misc.imsave(
166 | row["save_reg_path_flat_proj"], im_flat.transpose(1, 2, 0)
167 | )
168 |
169 | im_flat = imgtoprojection(
170 | im_tmp / 255,
171 | proj_all=False,
172 | proj_method="max",
173 | local_adjust=False,
174 | global_adjust=True,
175 | colors=colors,
176 | )
177 | scipy.misc.imsave(row["save_reg_path_flat"], im_flat.transpose(1, 2, 0))
178 |
179 | csv_df = csv_df.append(row_list)
180 |
181 | csv_df.to_csv("{}/data_plus_controls.csv".format(save_dir))
182 |
183 |
184 | def check_files(csv_df, image_parent, image_col="save_reg_path", verbose=True):
185 | # TODO add checking to make sure number of keys in h5 file matches number of lines in csv file
186 | if verbose:
187 | print("Checking the existence of files")
188 |
189 | for index in tqdm(range(0, csv_df.shape[0]), desc="Checking files", ascii=True):
190 | is_good_row = True
191 |
192 | row = csv_df.loc[index]
193 |
194 | image_path = os.sep + row[image_col]
195 |
196 | try:
197 | load_image(row, image_parent, image_col)
198 | except: # noqa
199 | print("Could not load from image. " + image_path)
200 | is_good_row = False
201 |
202 | csv_df.loc[index, "valid_row"] = is_good_row
203 |
204 | # only work with valid rows
205 | n_old_rows = len(csv_df)
206 | csv_df = csv_df.loc[csv_df["valid_row"] == True] # noqa
207 | csv_df = csv_df.drop("valid_row", 1)
208 | n_new_rows = len(csv_df)
209 | if verbose:
210 | print("{0}/{1} samples have all files present".format(n_new_rows, n_old_rows))
211 |
212 | return csv_df
213 |
214 |
215 | def load_image(df_row, image_parent, image_col):
216 |
217 | im_path = image_parent + os.sep + df_row[image_col]
218 |
219 | im = tifffile.imread(im_path)
220 |
221 | return im
222 |
223 |
224 | def crop_cell_nuc(im, channel_names):
225 |
226 | nuc_ind = np.array(channel_names) == "nuc"
227 | cell_ind = np.array(channel_names) == "cell"
228 |
229 | dna_ind = np.array(channel_names) == "dna"
230 | # trans_ind = np.array(channel_names) == "trans"
231 |
232 | other_channel_inds = np.ones(len(channel_names))
233 | other_channel_inds[nuc_ind | cell_ind | dna_ind] = 0
234 |
235 | im[dna_ind] = im[dna_ind] * im[nuc_ind]
236 |
237 | for i in np.where(other_channel_inds)[0]:
238 | im[i] = im[i] * im[cell_ind]
239 |
240 | return im
241 |
--------------------------------------------------------------------------------
/integrated_cell/utils/features.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from skimage.measure import label
4 | from skimage.filters import gaussian, threshold_otsu
5 | from scipy.ndimage.morphology import binary_fill_holes
6 | from aicsfeature.extractor import dna, cell
7 |
8 |
9 | def im2feats(
10 | im_cell,
11 | im_nuc,
12 | im_structures,
13 | extra_features=["io_intensity", "bright_spots", "intensity", "skeleton", "texture"],
14 | ):
15 | # im_cell = Y x X x Z binary numpy array
16 | # im_nuc = Y x X x Z binary numpy array
17 | # im_structures = C x Y x X x Z numpy array
18 |
19 | # im_cell, im_nuc, im_structures, im_structures_seg, seg_cell, seg_nuc = im_process(
20 | # im_cell, im_nuc, im_structures
21 | # )
22 |
23 | nuc_feats = dna.get_features(im_nuc, extra_features=extra_features)
24 | cell_feats = cell.get_features(im_cell, extra_features=extra_features)
25 |
26 | # feats_out = aicsfeature.kitchen_sink.kitchen_sink(
27 | # im_cell=im_cell,
28 | # im_nuc=im_nuc,
29 | # im_structures=im_structures,
30 | # seg_cell=seg_cell,
31 | # seg_nuc=seg_nuc,
32 | # extra_features=extra_features,
33 | # )
34 | # feats_out_seg = aicsfeature.kitchen_sink.kitchen_sink(
35 | # im_cell=im_cell,
36 | # im_nuc=im_nuc,
37 | # im_structures=im_structures_seg,
38 | # seg_cell=seg_cell,
39 | # seg_nuc=seg_nuc,
40 | # extra_features=extra_features,
41 | # )
42 |
43 | return [nuc_feats, cell_feats]
44 |
45 |
46 | def im_process(im_cell, im_nuc, im_structures):
47 | seg_nuc = binary_fill_holes(im_nuc)
48 | seg_cell = binary_fill_holes(im_cell)
49 | seg_cell[seg_nuc] = 1
50 |
51 | im_nuc = (im_nuc * (255)).astype("uint16") * seg_nuc
52 | im_cell = (im_cell * (255)).astype("uint16") * seg_cell
53 |
54 | im_structures = [(im_ch * (255)).astype("uint16") for im_ch in im_structures]
55 |
56 | im_structures_seg = list()
57 |
58 | for i, im_structure in enumerate(im_structures):
59 | im_blur = gaussian(im_structure, 1)
60 |
61 | im_pix = im_structure[im_cell > 0]
62 | if np.all(im_pix == 0):
63 | im_structures_seg.append(im_structure)
64 | continue
65 |
66 | im_structures_seg.append(
67 | im_structure * (im_blur > threshold_otsu(im_blur[im_cell > 0]))
68 | )
69 |
70 | return im_cell, im_nuc, im_structures, im_structures_seg, seg_cell, seg_nuc
71 |
72 |
73 | def find_main_obj(im_seg):
74 | im_label = label(im_seg)
75 |
76 | obj_index = -1
77 | max_cell_obj_size = -1
78 | for i in range(1, np.max(im_label) + 1):
79 | obj_size = np.sum(im_label == i)
80 | if obj_size > max_cell_obj_size:
81 | max_cell_obj_size = obj_size
82 | obj_index = i
83 |
84 | main_obj = im_label == obj_index
85 |
86 | return main_obj
87 |
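Note: a small sketch of find_main_obj above on a toy binary mask; it keeps only the largest connected component:

    import numpy as np

    seg = np.zeros((8, 8), dtype=bool)
    seg[0:2, 0:2] = True   # small object, 4 px
    seg[4:8, 4:8] = True   # large object, 16 px

    main_obj = find_main_obj(seg)
    assert main_obj.sum() == 16          # only the large object survives
    assert not main_obj[0, 0] and main_obj[5, 5]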
--------------------------------------------------------------------------------
/integrated_cell/utils/image_transforms.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | def crop_to(images, crop_to):
6 | crop = (np.array(images.shape[2:]) - np.array(crop_to)) / 2
7 | crop_pre = np.floor(crop).astype(int)
8 | crop_post = np.ceil(crop).astype(int)
9 |
10 | pad_pre = -crop_pre
11 | pad_pre[pad_pre < 0] = 0
12 |
13 | pad_post = -crop_post
14 | pad_post[pad_post < 0] = 0
15 |
16 | crop_pre[crop_pre < 0] = 0
17 |
18 | crop_post[crop_post < 0] = 0
19 | crop_post[crop_post == 0] = -np.array(images.shape[2:])[crop_post == 0]
20 |
21 | if len(crop_pre) == 2:
22 | images = images[
23 | :,
24 | :,
25 | crop_pre[0] : -crop_post[0], # noqa
26 | crop_pre[1] : -crop_post[1], # noqa
27 | ]
28 |
29 | elif len(crop_pre) == 3:
30 | images = images[
31 | :,
32 | :,
33 | crop_pre[0] : -crop_post[0], # noqa
34 | crop_pre[1] : -crop_post[1], # noqa
35 | crop_pre[2] : -crop_post[2], # noqa
36 | ]
37 |
38 | pad_pre = np.hstack([np.zeros(2), pad_pre])
39 | pad_post = np.hstack([np.zeros(2), pad_post])
40 |
41 | padding = np.vstack([pad_pre, pad_post]).transpose().astype("int")
42 |
43 | images = np.pad(images, padding, mode="constant", constant_values=0)
44 |
45 | images = torch.tensor(images)
46 |
47 | return images
48 |
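Note: a quick sketch of the center crop-or-pad behavior of crop_to above (with the missing return added): target sizes smaller than the input are cropped, larger ones zero-padded, and the result comes back as a torch tensor:

    import numpy as np

    images = np.ones((2, 3, 10, 12))  # N x C x Y x X
    print(tuple(crop_to(images, (8, 8)).shape))    # (2, 3, 8, 8), cropped
    print(tuple(crop_to(images, (12, 12)).shape))  # (2, 3, 12, 12), Y padded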
--------------------------------------------------------------------------------
/integrated_cell/utils/imgToProjection.py:
--------------------------------------------------------------------------------
1 | # Author: Evan Wiederspan
2 |
3 | import numpy as np
4 | import matplotlib
5 |
6 | matplotlib.use("agg")
7 | import matplotlib.pyplot as pplot # noqa
8 |
9 |
10 | def matproj(im, dim, method="max", slice_index=0):
11 | if method == "max":
12 | im = np.max(im, dim)
13 | elif method == "mean":
14 | im = np.mean(im, dim)
15 | elif method == "sum":
16 | im = np.sum(im, dim)
17 | elif method == "slice":
18 | im = im.take(slice_index, axis=dim)  # slice along the projection axis, not just axis 0
19 | else:
20 | raise ValueError("Invalid projection method")
21 | return im
22 |
23 |
24 | def imgtoprojection(
25 | im1,
26 | proj_all=False,
27 | proj_method="max",
28 | colors=lambda i: [1, 1, 1],
29 | global_adjust=False,
30 | local_adjust=False,
31 | ):
32 | """
33 | Outputs projections of a 4d CZYX numpy array into a CYX numpy array, allowing for color masks for each input channel
34 | as well as adjustment options
35 | :param im1: Either a 4d numpy array or a list of 3D or 2D numpy arrays. The input that will be projected
36 | :param proj_all: boolean. True outputs XY, YZ, and XZ projections in a grid, False just outputs XY. False by default
37 | :param proj_method: string. Method by which to do projections. 'max' by default
38 | :param colors: Can be either a string which corresponds to a cmap function in matplotlib, a function that
39 | takes in the channel index and returns a list of numbers, or a list of lists containing the color multipliers.
40 | :param global_adjust: boolean. If true, scales each color channel to set its max to be 255
41 | after combining all channels. False by default
42 | :param local_adjust: boolean. If true, performs contrast adjustment on each channel individually. False by default
43 | :return: a CYX numpy array containing the requested projections
44 | """
45 |
46 | # turn list of 2d or 3d arrays into single 4d array if needed
47 | try:
48 | if isinstance(im1, (list, tuple)):
49 | # if only YX, add a single Z dimen
50 | if im1[0].ndim == 2:
51 | im1 = [np.expand_dims(c, axis=0) for c in im1]
52 | elif im1[0].ndim != 3:
53 | raise ValueError("im1 must be a list of 2d or 3d arrays")
54 | # combine list into 4d array
55 | im = np.stack(im1)
56 | else:
57 | if im1.ndim != 4:
58 | raise ValueError("Invalid dimensions for im1")
59 | im = im1
60 |
61 | except (AttributeError, IndexError):
62 | # it's not a list of np arrays
63 | raise ValueError(
64 | "im1 must be either a 4d numpy array or a list of numpy arrays"
65 | )
66 |
67 | # color processing code
68 | if isinstance(colors, str):
69 | # pass it in to matplotlib
70 | try:
71 | colors = pplot.get_cmap(colors)(np.linspace(0, 1, im.shape[0]))
72 | except ValueError:
73 | # thrown when string is not valid function
74 | raise ValueError("Invalid cmap string")
75 | elif callable(colors):
76 | # if its a function
77 | try:
78 | colors = [colors(i) for i in range(im.shape[0])]
79 | except ValueError:
80 | raise ValueError("Invalid color function")
81 |
82 | # else, we're assuming it's a list
83 | # scale colors down to 0-1 range if they're bigger than 1
84 | if any(v > 1 for v in np.array(colors).flatten()):
85 | colors = [[v / 255.0 for v in c] for c in colors]
86 |
87 | # create final image
88 | if not proj_all:
89 | img_final = np.zeros((3, im.shape[2], im.shape[3]))
90 | else:
91 | # y + z, x + z
92 | img_final = np.zeros((3, im.shape[2] + im.shape[1], im.shape[3] + im.shape[1]))
93 | img_piece = np.zeros(img_final.shape)
94 | # loop through all channels
95 | for i, img_c in enumerate(im):
96 | try:
97 | proj_z = matproj(img_c, 0, proj_method, img_c.shape[0] // 2)
98 | if proj_all:
99 | proj_y, proj_x = (
100 | matproj(img_c, axis, proj_method, img_c.shape[axis] // 2)
101 | for axis in range(1, 3)
102 | )
103 | # flipping to get them facing the right way
104 | proj_x = np.transpose(proj_x, (1, 0))
105 | proj_y = np.flipud(proj_y)
106 | _, sy, sz = proj_z.shape[1], proj_z.shape[0], proj_y.shape[0] # noqa
107 | img_piece[:, :sy, :sz] = proj_x
108 | img_piece[:, :sy, sz:] = proj_z
109 | img_piece[:, sy:, sz:] = proj_y
110 | else:
111 | img_piece[:] = proj_z
112 | except ValueError:
113 | raise ValueError("Invalid projection function")
114 |
115 | for c in range(3):
116 | img_piece[c] *= colors[i][c]
117 |
118 | # local contrast adjustment, minus the min, divide the max
119 | if local_adjust:
120 | img_piece -= np.min(img_piece)
121 | img_max = np.max(img_piece)
122 | if img_max > 0:
123 | img_piece /= img_max
124 | # img_final += img_piece
125 | img_final += (1 - img_final) * img_piece
126 |
127 | # color range adjustment, ensure that max value is 255
128 | if global_adjust:
129 | # scale color channels independently
130 | for c in range(3):
131 | max_val = np.max(img_final[c].flatten())
132 | if max_val > 0:
133 | img_final[c] *= 255.0 / max_val
134 |
135 | return img_final
136 |
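Note: a minimal sketch of imgtoprojection on random CZYX data, mapping two channels to blue and green; with proj_all=True the output grid is (3, Y+Z, X+Z):

    import numpy as np

    im = np.random.rand(2, 5, 64, 64)  # C x Z x Y x X
    proj = imgtoprojection(
        im,
        proj_all=True,
        proj_method="max",
        colors=[[0, 0, 1], [0, 1, 0]],
        global_adjust=True,
    )
    print(proj.shape)  # (3, 69, 69)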
--------------------------------------------------------------------------------
/integrated_cell/utils/reference/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AllenCellModeling/pytorch_integrated_cell/8a83fc6f8dc79037f4b681d9d7ef0bc5b91e9948/integrated_cell/utils/reference/__init__.py
--------------------------------------------------------------------------------
/integrated_cell/utils/spectral_norm.py:
--------------------------------------------------------------------------------
1 | # stolen from pytorch 0.5
2 | # https://pytorch.org/docs/master/_modules/torch/nn/utils/spectral_norm.html#spectral_norm
3 |
4 | """
5 | Spectral Normalization from https://arxiv.org/abs/1802.05957
6 | """
7 | import torch
8 | from torch.nn.functional import normalize
9 |
10 |
11 | class SpectralNorm(object):
12 | def __init__(self, name="weight", n_power_iterations=1, eps=1e-12):
13 | self.name = name
14 | if n_power_iterations <= 0:
15 | raise ValueError(
16 | "Expected n_power_iterations to be positive, but "
17 | "got n_power_iterations={}".format(n_power_iterations)
18 | )
19 | self.n_power_iterations = n_power_iterations
20 | self.eps = eps
21 |
22 | def compute_weight(self, module):
23 | weight = getattr(module, self.name + "_org")
24 | u = getattr(module, self.name + "_u")
25 | height = weight.size(0)
26 | weight_mat = weight.view(height, -1)
27 | with torch.no_grad():
28 | for _ in range(self.n_power_iterations):
29 | # Spectral norm of weight equals to `u^T W v`, where `u` and `v`
30 | # are the first left and right singular vectors.
31 | # This power iteration produces approximations of `u` and `v`.
32 | v = normalize(torch.matmul(weight_mat.t(), u), dim=0, eps=self.eps)
33 | u = normalize(torch.matmul(weight_mat, v), dim=0, eps=self.eps)
34 |
35 | # import pdb
36 | # pdb.set_trace()
37 |
38 | sigma = torch.dot(
39 | u.float(), torch.matmul(weight_mat.float(), v.float()).float()
40 | )
41 | weight = weight / sigma
42 | return weight, u
43 |
44 | def remove(self, module):
45 | weight = module._parameters[self.name + "_org"]
46 | delattr(module, self.name)
47 | delattr(module, self.name + "_u")
48 | delattr(module, self.name + "_org")
49 | module.register_parameter(self.name, weight)
50 |
51 | def __call__(self, module, inputs):
52 | weight, u = self.compute_weight(module)
53 | setattr(module, self.name, weight)
54 | setattr(module, self.name + "_u", u)
55 |
56 | @staticmethod
57 | def apply(module, name, n_power_iterations, eps):
58 | fn = SpectralNorm(name, n_power_iterations, eps)
59 | weight = module._parameters[name]
60 | height = weight.size(0)
61 |
62 | u = normalize(weight.new_empty(height).normal_(0, 1), dim=0, eps=fn.eps)
63 | delattr(module, fn.name)
64 | module.register_parameter(fn.name + "_org", weight)
65 | module.register_buffer(fn.name, weight)
66 | module.register_buffer(fn.name + "_u", u)
67 |
68 | module.register_forward_pre_hook(fn)
69 | return fn
70 |
71 |
72 | def spectral_norm(module, name="weight", n_power_iterations=1, eps=1e-12):
73 | r"""Applies spectral normalization to a parameter in the given module.
74 |
75 | .. math::
76 | \mathbf{W} &= \dfrac{\mathbf{W}}{\sigma(\mathbf{W})} \\
77 | \sigma(\mathbf{W}) &= \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}
78 |
79 | Spectral normalization stabilizes the training of discriminators (critics)
80 | in Generative Adversarial Networks (GANs) by rescaling the weight tensor
81 | with spectral norm :math:`\sigma` of the weight matrix calculated using
82 | power iteration method. If the dimension of the weight tensor is greater
83 | than 2, it is reshaped to 2D in power iteration method to get spectral
84 | norm. This is implemented via a hook that calculates spectral norm and
85 | rescales weight before every :meth:`~Module.forward` call.
86 |
87 | See `Spectral Normalization for Generative Adversarial Networks`_ .
88 |
89 | .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957
90 |
91 | Args:
92 | module (nn.Module): containing module
93 | name (str, optional): name of weight parameter
94 | n_power_iterations (int, optional): number of power iterations to
95 | calculate spectral norm
96 | eps (float, optional): epsilon for numerical stability in
97 | calculating norms
98 |
99 | Returns:
100 | The original module with the spectral norm hook
101 |
102 | Example::
103 |
104 | >>> m = spectral_norm(nn.Linear(20, 40))
105 | Linear (20 -> 40)
106 | >>> m.weight_u.size()
107 | torch.Size([20])
108 |
109 | """
110 | SpectralNorm.apply(module, name, n_power_iterations, eps)
111 | return module
112 |
113 |
114 | def remove_spectral_norm(module, name="weight"):
115 | r"""Removes the spectral normalization reparameterization from a module.
116 |
117 | Args:
118 | module (nn.Module): containing module
119 | name (str, optional): name of weight parameter
120 |
121 | Example:
122 | >>> m = spectral_norm(nn.Linear(40, 10))
123 | >>> remove_spectral_norm(m)
124 | """
125 | for k, hook in module._forward_pre_hooks.items():
126 | if isinstance(hook, SpectralNorm) and hook.name == name:
127 | hook.remove(module)
128 | del module._forward_pre_hooks[k]
129 | return module
130 |
131 | raise ValueError("spectral_norm of '{}' not found in {}".format(name, module))
132 |
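Note: the docstring example in brief; this version registers the original parameter as "weight_org" (PyTorch's later built-in uses "weight_orig"):

    import torch.nn as nn

    conv = spectral_norm(nn.Conv2d(16, 32, 3))
    # forward passes now divide `weight` by its estimated spectral norm
    assert hasattr(conv, "weight_u") and hasattr(conv, "weight_org")

    conv = remove_spectral_norm(conv)  # restores the plain parameter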
--------------------------------------------------------------------------------
/integrated_cell/utils/target/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import * # noqa
2 |
--------------------------------------------------------------------------------
/integrated_cell/utils/target/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def sample2im(x, ref):
5 | n_channels = x.shape[1] + ref.shape[1]
6 |
7 | im = torch.zeros([x.shape[0], n_channels, x.shape[2], x.shape[3], x.shape[4]])
8 |
9 | im[:, [0, 2]] = ref
10 | im[:, 1] = x
11 |
12 | return im
13 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AllenCellModeling/pytorch_integrated_cell/8a83fc6f8dc79037f4b681d9d7ef0bc5b91e9948/requirements.txt
--------------------------------------------------------------------------------
/scripts/aaegan_short.sh:
--------------------------------------------------------------------------------
1 | cd ..
2 | python /root/projects/pytorch_integrated_cell/train_model.py \
3 | --gpu_ids $1 \
4 | --save_dir ./results/aaegan_short \
5 | --data_save_path ./results/data.pyt \
6 | --lrEnc 2E-4 --lrDec 2E-4 \
7 | --lrEncD 2E-2 --lrDecD 1E-4 \
8 | --model_name aaegan3Dv6-relu-exp \
9 | --train_module aaegan_trainv8 \
10 | --kwargs_encD '{"noise_std": 0}' \
11 | --kwargs_decD '{"noise_std": 2E-1}' \
12 | --kwargs_optim '{"betas": [0, 0.9]}' \
13 | --kwargs_model '{"lambda_encD_loss": 10, "lambda_decD_loss": 1, "lambda_class_loss": 1000, "lambda_ref_loss": 1000, "provide_decoder_vars": 1}' \
14 | --imdir /root/results/ipp/ipp_17_10_25 \
15 | --dataProvider DataProvider3Dh5 \
16 | --saveStateIter 1 --saveProgressIter 1 \
17 | --channels_pt1 0 2 --channels_pt2 0 1 2 \
18 | --batch_size 16 \
19 | --nlatentdim 128 \
20 | --nepochs 1 \
21 | --nepochs_pt2 1 \
22 | --ndat 32 \
23 | --overwrite_opts True \
24 |
--------------------------------------------------------------------------------
/scripts/bvae.sh:
--------------------------------------------------------------------------------
1 | cd ..
2 | python ./train_model.py \
3 | --gpu_ids $1 \
4 | --save_parent ./results/test_bvae/ \
5 | --lrEnc 1E-4 --lrDec 1E-4 \
6 | --data_save_path ./results/data.pyt \
7 | --critRecon BCELoss \
8 | --model_name vaaegan3D \
9 | --kwargs_model '{"beta": 1}' \
10 | --train_module bvae \
11 | --kwargs_optim '{"betas": [0.9, 0.999]}' \
12 | --imdir /root/results/ipp/ipp_17_10_25 \
13 | --dataProvider DataProvider3Dh5 \
14 | --saveStateIter 1 --saveProgressIter 1 \
15 | --channels_pt1 0 2 --channels_pt2 0 1 2 \
16 | --batch_size 16 \
17 | --nlatentdim 128 \
18 | --nepochs 50 \
19 | --nepochs_pt2 50 \
20 |
--------------------------------------------------------------------------------
/scripts/bvae_short.sh:
--------------------------------------------------------------------------------
1 | cd ..
2 |
3 |
4 | python /root/projects/pytorch_integrated_cell/train_model.py \
5 | --gpu_ids $1 \
6 | --save_dir ./results/bvae_short \
7 | --data_save_path ./results/data.pyt \
8 | --lrEnc 2E-4 --lrDec 2E-4 \
9 | --model_name vaaegan3D \
10 | --train_module bvae \
11 | --kwargs_optim '{"betas": [0.9, 0.999]}' \
12 | --imdir /root/results/ipp/ipp_17_10_25 \
13 | --dataProvider DataProvider3Dh5 \
14 | --saveStateIter 1 --saveProgressIter 1 \
15 | --channels_pt1 0 2 --channels_pt2 0 1 2 \
16 | --batch_size 16 \
17 | --nlatentdim 128 \
18 | --nepochs 5 \
19 | --nepochs_pt2 5 \
20 | --ndat 32 \
21 | --overwrite_opts True \
22 |
--------------------------------------------------------------------------------
/scripts/bvaegan_short.sh:
--------------------------------------------------------------------------------
1 | cd ..
2 | python /root/projects/pytorch_integrated_cell/train_model.py \
3 | --gpu_ids $1 \
4 | --save_dir ./results/bvaegan_short \
5 | --data_save_path ./results/data.pyt \
6 | --lrEnc 2E-4 --lrDec 2E-4 \
7 | --lrDecD 2E-4 \
8 | --model_name vaaegan3D \
9 | --train_module bvaegan \
10 | --kwargs_decD '{"noise_std": 1E-1}' \
11 | --kwargs_optim '{"betas": [0, 0.9]}' \
12 | --imdir /root/results/ipp/ipp_17_10_25 \
13 | --dataProvider DataProvider3Dh5 \
14 | --saveStateIter 1 --saveProgressIter 1 \
15 | --channels_pt1 0 2 --channels_pt2 0 1 2 \
16 | --batch_size 16 \
17 | --nlatentdim 128 \
18 | --nepochs 1 \
19 | --nepochs_pt2 1 \
20 | --ndat 32 \
21 | --overwrite_opts True \
22 |
--------------------------------------------------------------------------------
/scripts/multi_pred.sh:
--------------------------------------------------------------------------------
1 | cd ..
2 |
3 |
4 | python /root/projects/pytorch_integrated_cell/train_model.py \
5 | --gpu_ids $1 \
6 | --save_dir ./results/multi_pred \
7 | --data_save_path ./results/data.pyt \
8 | --lrEnc 2E-4 --lrDec 2E-4 \
9 | --model_name multi_pred \
10 | --train_module multi_pred \
11 | --kwargs_optim '{"betas": [0.9, 0.999]}' \
12 | --imdir /root/results/ipp/ipp_17_10_25 \
13 | --dataProvider DataProvider3Dh5 \
14 | --saveStateIter 1 --saveProgressIter 1 \
15 | --channels_pt1 0 2 3 4 5 1 --channels_pt2 0 \
16 | --batch_size 16 \
17 | --nlatentdim 128 \
18 | --nepochs 100 \
19 | --nepochs_pt2 0 \
20 | --overwrite_opts True \
21 |
--------------------------------------------------------------------------------
/scripts/multi_pred_short.sh:
--------------------------------------------------------------------------------
1 | cd ..
2 |
3 |
4 | python /root/projects/pytorch_integrated_cell/train_model.py \
5 | --gpu_ids $1 \
6 | --save_dir ./results/multi_pred_short \
7 | --data_save_path ./results/data.pyt \
8 | --lrEnc 2E-4 --lrDec 2E-4 \
9 | --model_name multi_pred \
10 | --train_module multi_pred \
11 | --kwargs_optim '{"betas": [0.9, 0.999]}' \
12 | --imdir /root/results/ipp/ipp_17_10_25 \
13 | --dataProvider DataProvider3Dh5 \
14 | --saveStateIter 1 --saveProgressIter 1 \
15 | --channels_pt1 0 2 3 4 5 1 --channels_pt2 0 \
16 | --batch_size 16 \
17 | --nlatentdim 128 \
18 | --nepochs 5 \
19 | --nepochs_pt2 0 \
20 | --ndat 32 \
21 | --overwrite_opts True \
22 |
--------------------------------------------------------------------------------
/scripts/test_build_control_images.py:
--------------------------------------------------------------------------------
1 | from integrated_cell.utils.build_control_images import build_control_images
2 |
3 | save_dir = "/allen/aics/modeling/gregj/results/ipp/scp_19_04_10/controls"
4 | csv_path = "/allen/aics/modeling/gregj/results/ipp/scp_19_04_10/data_jobs_out.csv"
5 | image_parent = "/allen/aics/modeling/gregj/results/ipp/scp_19_04_10/"
6 |
7 | build_control_images(save_dir, csv_path, image_parent)
8 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import find_packages, setup
2 |
3 | exclude_dirs = ["examples", "doc", "scripts"]
4 |
5 | PACKAGES = find_packages(exclude=exclude_dirs)
6 |
7 | setup(
8 | name="pytorch_integrated_cell",
9 | version="0.1",
10 | packages=PACKAGES,
11 | entry_points={
12 | "console_scripts": ["ic_train_model=integrated_cell.bin.train_model:main"]
13 | },
14 | install_requires=[
15 | "torch==1.2.0",
16 | "torchvision==0.2.1",
17 | "matplotlib==2.2.2",
18 | "numpy>=1.15.0",
19 | "pandas>=0.23.4",
20 | "pip",
21 | "pillow==5.2.0",
22 | "scikit-image==0.15.0",
23 | "scipy==1.1.0",
24 | "tqdm>=4.28.1",
25 | "natsort==5.3.3",
26 | "ipykernel",
27 | "aicsimageio==3.0.7",
28 | "msgpack<0.6.0,>=0.5.6",
29 | "imageio==2.6.0",
30 | "quilt3==3.1.1",
31 | "seaborn",
32 | "brokenaxes",
33 | "lkaccess",
34 | "aicsfeature==0.2.1",
35 | ],
36 | )
37 |
--------------------------------------------------------------------------------