├── .flake8
├── .github
│   └── workflows
│       └── python-package.yml
├── .gitignore
├── README.rst
├── pyproject.toml
├── src
│   └── cryo_sbi
│       ├── __init__.py
│       ├── inference
│       │   ├── __init__.py
│       │   ├── command_line_tools.py
│       │   ├── models
│       │   │   ├── __init__.py
│       │   │   ├── build_models.py
│       │   │   ├── embedding_nets.py
│       │   │   └── estimator_models.py
│       │   ├── priors.py
│       │   ├── train_npe_model.py
│       │   └── validate_train_config.py
│       ├── utils
│       │   ├── __init__.py
│       │   ├── command_line_tools.py
│       │   ├── estimator_utils.py
│       │   ├── generate_models.py
│       │   ├── image_utils.py
│       │   ├── micrograph_utils.py
│       │   └── visualize_models.py
│       └── wpa_simulator
│           ├── __init__.py
│           ├── cryo_em_simulator.py
│           ├── ctf.py
│           ├── image_generation.py
│           ├── noise.py
│           ├── normalization.py
│           └── validate_image_config.py
├── tests
│   ├── config_files
│   │   ├── image_params_testing.json
│   │   └── training_params_npe_testing.json
│   ├── data
│   │   └── test.mrc
│   ├── models
│   │   └── hsp90_models.pt
│   ├── test_embeddings.py
│   ├── test_estimator_utils.py
│   ├── test_image_utils.py
│   ├── test_micrograph_utils.py
│   ├── test_posterior_models.py
│   ├── test_simulator.py
│   └── test_visualize_models.py
└── tutorials
    ├── .gitignore
    ├── simulation_parameters.json
    ├── training_parameters.json
    └── tutorial.ipynb
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 88
3 | extend-ignore = E203
--------------------------------------------------------------------------------
/.github/workflows/python-package.yml:
--------------------------------------------------------------------------------
1 | # GHA workflow for running tests.
2 | #
3 | # Largely taken from
4 | # https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
5 | # Please check the link for more detailed instructions
6 |
7 | name: Run tests
8 |
9 | on: [push]
10 |
11 | jobs:
12 | build:
13 |
14 | runs-on: ubuntu-latest
15 | strategy:
16 | matrix:
17 | python-version: ["3.9", "3.10"]
18 |
19 | steps:
20 | - uses: actions/checkout@v3
21 | - name: Set up Python ${{ matrix.python-version }}
22 | uses: actions/setup-python@v4
23 | with:
24 | python-version: ${{ matrix.python-version }}
25 | - name: Install dependencies
26 | run: |
27 | python -m pip install --upgrade pip
28 | pip install .
29 | pip install pytest
30 | - name: Test with pytest
31 | run: |
32 | pytest tests/
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Project specific files
2 | messing_around/
3 | tests/test_simulator.ipynb
4 | *.estimator
5 | *.h5
6 | .vscode/
7 | .DS_Store
8 | #SBI related folders
9 | sbi-logs/
10 | *.png
11 | *.pdf
12 | *.png
13 |
14 | # Training data, models and other datafiles
15 | results/
16 | production/
17 | *.estimator
18 | *epoch=*
19 | *.loss
20 |
21 | # Byte-compiled / optimized / DLL files
22 | __pycache__/
23 | *.py[cod]
24 | *$py.class
25 |
26 | # C extensions
27 | *.so
28 |
29 | # Distribution / packaging
30 | .Python
31 | build/
32 | develop-eggs/
33 | dist/
34 | downloads/
35 | eggs/
36 | .eggs/
37 | lib/
38 | lib64/
39 | parts/
40 | sdist/
41 | var/
42 | wheels/
43 | pip-wheel-metadata/
44 | share/python-wheels/
45 | *.egg-info/
46 | .installed.cfg
47 | *.egg
48 | MANIFEST
49 |
50 | # PyInstaller
51 | # Usually these files are written by a python script from a template
52 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
53 | *.manifest
54 | *.spec
55 |
56 | # Installer logs
57 | pip-log.txt
58 | pip-delete-this-directory.txt
59 |
60 | # Unit test / coverage reports
61 | htmlcov/
62 | .tox/
63 | .nox/
64 | .coverage
65 | .coverage.*
66 | .cache
67 | nosetests.xml
68 | coverage.xml
69 | *.cover
70 | *.py,cover
71 | .hypothesis/
72 | .pytest_cache/
73 |
74 | # Translations
75 | *.mo
76 | *.pot
77 |
78 | # Django stuff:
79 | *.log
80 | local_settings.py
81 | db.sqlite3
82 | db.sqlite3-journal
83 |
84 | # Flask stuff:
85 | instance/
86 | .webassets-cache
87 |
88 | # Scrapy stuff:
89 | .scrapy
90 |
91 | # Sphinx documentation
92 | docs/_build/
93 |
94 | # PyBuilder
95 | target/
96 |
97 | # Jupyter Notebook
98 | .ipynb_checkpoints
99 |
100 | # IPython
101 | profile_default/
102 | ipython_config.py
103 |
104 | # pyenv
105 | .python-version
106 |
107 | # pipenv
108 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
109 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
110 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
111 | # install all needed dependencies.
112 | #Pipfile.lock
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
--------------------------------------------------------------------------------
/README.rst:
--------------------------------------------------------------------------------
1 | ================================================
2 | cryoSBI - Simulation-based Inference for Cryo-EM
3 | ================================================
4 |
5 | .. start-badges
6 |
7 | .. list-table::
8 | :stub-columns: 1
9 |
10 | * - tests
11 | - | |githubactions|
12 |
13 |
14 | .. |githubactions| image:: https://github.com/DSilva27/cryo_em_SBI/actions/workflows/python-package.yml/badge.svg?branch=main
15 | :alt: Testing Status
16 | :target: https://github.com/DSilva27/cryo_em_SBI/actions
17 |
18 | Summary
19 | -------
20 | cryoSBI is a Python module for simulation-based inference in cryo-electron microscopy. The module provides tools for simulating cryo-EM particles, training an amortized posterior model, and sampling from the posterior distribution.
21 | The code is built on the SBI library `Lampe <https://github.com/probabilists/lampe>`_, which uses PyTorch.
22 |
23 | Installing
24 | ----------
25 | To install the module you will have to download the repository and create a virtual environment with the required dependencies.
26 | You can create such an environment, for example with conda, using the following command:
27 |
28 | .. code:: bash
29 |
30 | conda create -n cryoSBI python=3.10
31 |
32 | After creating the virtual environment, you should install the required dependencies and the module.
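
Remember to activate the environment (named `cryoSBI` in the command above) before installing anything into it:

.. code:: bash

    conda activate cryoSBI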
33 |
34 | Dependencies
35 | ------------
36 |
37 | 1. `Lampe <https://github.com/probabilists/lampe>`_.
38 | 2. `SciPy <https://scipy.org/>`_.
39 | 3. `NumPy <https://numpy.org/>`_.
40 | 4. `PyTorch <https://pytorch.org/>`_.
41 | 5. json (part of the Python standard library).
42 | 6. `mrcfile <https://pypi.org/project/mrcfile/>`_.
43 |
44 | Download this repository
45 | ------------------------
46 | .. code:: bash
47 |
48 | git clone https://github.com/flatironinstitute/cryoSBI.git
49 |
50 | Navigate to the cloned repository and install the module
51 | --------------------------------------------------------
52 | .. code:: bash
53 |
54 |    cd cryoSBI
55 |
56 | .. code:: bash
57 |
58 | pip install .
59 |
60 | Tutorial
61 | --------
62 | An introductory tutorial can be found at `tutorials/tutorial.ipynb`. In this tutorial, we go through the whole process of making models for cryoSBI, training an amortized posterior, and analyzing the results.
63 | The following sections highlight cryoSBI's key features.
64 |
65 | Generate model file to simulate cryo-EM particles
66 | -------------------------------------------------
67 | To generate a model file for simulating cryo-EM particles with the simulator provided in this module, you can use the command line tool `model_to_tensor`.
68 | You will need either a set of indexed pdb files or a trr trajectory file which contains all models. The tool will generate a model file that can be used to simulate cryo-EM particles.
69 |
70 | .. code:: bash
71 |
72 |    model_to_tensor \
73 |    --model_files path_to_models/pdb_{}.pdb \
74 |    --output_file path_to_output_file/output.pt \
75 |    --n_pdbs 100
76 |
77 | The output file will be a PyTorch tensor with the shape (number of models, 3, number of pseudo atoms).
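
As a quick sanity check, you can load the generated file with PyTorch and inspect its shape (using the output path from the example above):

.. code:: python

    import torch

    models = torch.load("path_to_output_file/output.pt")
    print(models.shape)  # (number of models, 3, number of pseudo atoms)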
78 |
79 | Simulating cryo-EM particles
80 | -----------------------------
81 | To simulate cryo-EM particles, you can use the CryoEmSimulator class. The class takes in a simulation config file and simulates cryo-EM particles based on the parameters specified in the config file.
82 |
83 | .. code:: python
84 |
85 | from cryo_sbi import CryoEmSimulator
86 | simulator = CryoEmSimulator("path_to_simulation_config_file.json")
87 | images, parameters = simulator.simulate(num_sim=10, return_parameters=True)
88 |
89 | The simulation config file should be a json file with the following structure:
90 |
91 | .. code:: json
92 |
93 | {
94 | "N_PIXELS": 128,
95 | "PIXEL_SIZE": 1.5,
96 | "SIGMA": [0.5, 5.0],
97 | "MODEL_FILE": "path_to_models/models.pt",
98 | "SHIFT": 25.0,
99 | "DEFOCUS": [0.5, 2.0],
100 | "SNR": [0.001, 0.5],
101 | "AMP": 0.1,
102 | "B_FACTOR": [1.0, 100.0]
103 | }
104 |
105 | The pixel size (PIXEL_SIZE) is given in Ångström (Å). SIGMA defines the width of the Gaussians used to approximate the protein's electron density; each Gaussian represents one amino acid, and while all Gaussians share the same sigma, its value varies across simulations. SHIFT is the offset of the protein from the image centre, in Ångström (Å). DEFOCUS is the defocus of the microscope, in micrometres (μm). SNR (signal-to-noise ratio) is unitless and controls the amount of noise in the simulated images. AMP is a unitless amplitude parameter between 0 and 1. B_FACTOR, in Ångström squared (Å²), sets the decay rate of the CTF envelope function.
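
To get a feel for these parameters, you can simulate a single particle and display it with `matplotlib`. This is a minimal sketch; it assumes that `simulate` returns only the images when `return_parameters` is not set, and that the result can be indexed like a torch tensor:

.. code:: python

    import matplotlib.pyplot as plt

    images = simulator.simulate(num_sim=1)  # assumed to return images only
    plt.imshow(images[0], cmap="gray")
    plt.show()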
106 |
107 | Training an amortized posterior model
108 | --------------------------------------
109 | Training of an amortized posterior can be done using the train_npe_model command line utility. The utility takes in an image config file, a train config file, and other training parameters. The utility trains a neural network to approximate the posterior distribution of the parameters given the images.
110 |
111 | .. code:: bash
112 |
113 | train_npe_model \
114 | --image_config_file path_to_simulation_config_file.json \
115 | --train_config_file path_to_train_config_file.json\
116 | --epochs 150 \
117 | --estimator_file posterior.estimator \
118 | --loss_file posterior.loss \
119 | --n_workers 4 \
120 | --simulation_batch_size 5120 \
121 | --train_device cuda
122 |
123 | The training config file should be a json file with the following structure:
124 |
125 | .. code:: json
126 |
127 | {
128 | "EMBEDDING": "RESNET18",
129 | "OUT_DIM": 256,
130 | "NUM_TRANSFORM": 5,
131 | "NUM_HIDDEN_FLOW": 10,
132 | "HIDDEN_DIM_FLOW": 256,
133 | "MODEL": "NSF",
134 | "LEARNING_RATE": 0.0003,
135 | "CLIP_GRADIENT": 5.0,
136 | "THETA_SHIFT": 25,
137 | "THETA_SCALE": 25,
138 | "BATCH_SIZE": 256
139 | }
140 |
141 | When training a posterior for your own system, it is important to change THETA_SHIFT and THETA_SCALE. These two parameters normalize the conformational variable in cryoSBI.
142 | They need to be adjusted to the number of structures used in the prior. A good option is to set both to the number of structures divided by two, as in the example below.
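
For example, with 100 structures the conformational index lies in [0, 99]; setting both parameters to 50 maps it to roughly [-1, 1], since the estimator standardizes theta as (theta - THETA_SHIFT) / THETA_SCALE:

.. code:: json

    {
        "THETA_SHIFT": 50,
        "THETA_SCALE": 50
    }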
143 |
144 | Loading the posterior after training
145 | ------------------------------------
146 | After training the estimator, you can load it in Python with the `load_estimator` function from the `estimator_utils` module.
147 |
148 | .. code:: python
149 |
150 | import cryo_sbi.utils.estimator_utils as est_utils
151 | posterior = est_utils.load_estimator(
152 | config_file_path="path_to_config_file",
153 | estimator_path="path_to_estimator_file",
154 | device="cuda"
155 | )
156 |
157 | Inference
158 | ---------
159 | Sampling from the posterior distribution can be done using the sample_posterior function in the estimator_utils module. The function takes in an estimator, images, and other parameters and returns samples from the posterior distribution.
160 |
161 | .. code:: python
162 |
163 | import cryo_sbi.utils.estimator_utils as est_utils
164 | samples = est_utils.sample_posterior(
165 | estimator=posterior,
166 | images=images,
167 | num_samples=20000,
168 | batch_size=100,
169 | device="cuda",
170 | )
171 |
172 | The PyTorch tensor containing the samples will have the shape (number of samples, number of images). To visualize the posterior for each image you can use `matplotlib`.
173 | We can quickly generate a histogram with 50 bins with the following piece of code.
174 |
175 | .. code:: python
176 |
177 | import matplotlib.pyplot as plt
178 | idx_image = 0 # posterior for image with index 0
179 |    plt.hist(samples[:, idx_image].flatten(), bins=50, range=(0, simulator.max_index))
180 |
181 | In this case the x-axis is just the index of the structures in increasing order.
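
If you need a single point estimate per image, one simple option (a sketch, not a function provided by the package) is the rounded mode of the posterior samples:

.. code:: python

    # most frequent (rounded) structure index among the posterior samples
    map_index = samples[:, idx_image].round().mode().values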
182 |
183 | Latent space
184 | ------------
185 |
186 | Computing the latent features of simulated or experimental particles can be done with the `compute_latent_repr` function in the `estimator_utils` module. The function takes a trained posterior estimator and a set of images and computes the latent representation of each image.
187 |
188 | .. code:: python
189 |
190 |    import cryo_sbi.utils.estimator_utils as est_utils
191 |
192 |    latent_vecs = est_utils.compute_latent_repr(
193 |        estimator=posterior,
194 |        images=images,
195 |        batch_size=100,
196 |        device="cuda",
197 |    )
198 |
199 | After computing the latent representation of the images, one way to visualize the latent space is `UMAP <https://umap-learn.readthedocs.io/>`_. UMAP generates a two-dimensional representation of the latent space, which lets us analyze its important features.
200 |
201 | .. code:: python
202 |
203 | import umap
204 |    reducer = umap.UMAP(metric="euclidean", n_components=2, n_neighbors=50)
205 | embedding = reducer.fit_transform(latent_vecs.numpy())
206 |
207 | We can quickly visualize the 2D latent space with `matplotlib`.
208 |
209 | .. code:: python
210 |
211 | import matplotlib.pyplot as plt
212 | plt.scatter(
213 | embedding[:, 0],
214 | embedding[:, 1],
215 | )
216 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools", "setuptools-scm"]
3 | build-backend = "setuptools.build_meta"
4 |
5 |
6 | [project]
7 | name = "cryosbi"
8 | authors = [
9 | { name = "David Silva-Sanchez", email = "david.silva@yale.edu"},
10 | { name = "Lars Dingeldein"},
11 | { name = "Pilar Cossio"},
12 | { name = "Roberto Covino"}
13 | ]
14 |
15 |
16 | version = "0.2"
17 | dependencies = [
18 | "lampe",
19 | "zuko",
20 | "torch",
21 | "numpy",
22 | "matplotlib",
23 | "scipy",
24 | "torchvision",
25 | "mrcfile"
26 | ]
27 |
28 |
29 | [project.scripts]
30 | train_npe_model = "cryo_sbi.inference.command_line_tools:cl_npe_train_no_saving"
31 | model_to_tensor = "cryo_sbi.utils.command_line_tools:cl_models_to_tensor"
32 |
--------------------------------------------------------------------------------
/src/cryo_sbi/__init__.py:
--------------------------------------------------------------------------------
1 | from cryo_sbi.wpa_simulator.cryo_em_simulator import CryoEmSimulator
2 |
--------------------------------------------------------------------------------
/src/cryo_sbi/inference/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/src/cryo_sbi/inference/command_line_tools.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from cryo_sbi.inference.train_npe_model import (
3 | npe_train_no_saving,
4 | )
5 |
6 |
7 | def cl_npe_train_no_saving():
8 | cl_parser = argparse.ArgumentParser()
9 |
10 | cl_parser.add_argument(
11 | "--image_config_file", action="store", type=str, required=True
12 | )
13 | cl_parser.add_argument(
14 | "--train_config_file", action="store", type=str, required=True
15 | )
16 | cl_parser.add_argument("--epochs", action="store", type=int, required=True)
17 | cl_parser.add_argument("--estimator_file", action="store", type=str, required=True)
18 | cl_parser.add_argument("--loss_file", action="store", type=str, required=True)
19 | cl_parser.add_argument(
20 | "--train_from_checkpoint",
21 | action="store",
22 | type=bool,
23 | nargs="?",
24 | required=False,
25 | const=True,
26 | default=False,
27 | )
28 | cl_parser.add_argument(
29 | "--state_dict_file", action="store", type=str, required=False, default=False
30 | )
31 | cl_parser.add_argument(
32 | "--n_workers", action="store", type=int, required=False, default=1
33 | )
34 | cl_parser.add_argument(
35 | "--train_device", action="store", type=str, required=False, default="cpu"
36 | )
37 | cl_parser.add_argument(
38 | "--saving_freq", action="store", type=int, required=False, default=20
39 | )
40 | cl_parser.add_argument(
41 | "--simulation_batch_size",
42 | action="store",
43 | type=int,
44 | required=False,
45 | default=1024,
46 | )
47 |
48 | args = cl_parser.parse_args()
49 |
50 | npe_train_no_saving(
51 | image_config=args.image_config_file,
52 | train_config=args.train_config_file,
53 | epochs=args.epochs,
54 | estimator_file=args.estimator_file,
55 | loss_file=args.loss_file,
56 | train_from_checkpoint=args.train_from_checkpoint,
57 | model_state_dict=args.state_dict_file,
58 | n_workers=args.n_workers,
59 | device=args.train_device,
60 | saving_frequency=args.saving_freq,
61 | simulation_batch_size=args.simulation_batch_size,
62 | )
63 |
--------------------------------------------------------------------------------
/src/cryo_sbi/inference/models/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/src/cryo_sbi/inference/models/build_models.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from functools import partial
3 | import zuko
4 | import lampe
5 | import cryo_sbi.inference.models.estimator_models as estimator_models
6 | from cryo_sbi.inference.models.embedding_nets import EMBEDDING_NETS
7 |
8 |
9 | def build_npe_flow_model(config: dict, **embedding_kwargs) -> nn.Module:
10 | """
11 | Function to build NPE estimator with embedding net
12 | from config_file
13 |
14 | Args:
15 | config (dict): config file
16 | embedding_kwargs (dict): kwargs for embedding net
17 |
18 | Returns:
19 | estimator (nn.Module): NPE estimator
20 | """
21 |
22 | if config["MODEL"] == "MAF":
23 | model = zuko.flows.MAF
24 | elif config["MODEL"] == "NSF":
25 | model = zuko.flows.NSF
26 | elif config["MODEL"] == "SOSPF":
27 | model = zuko.flows.SOSPF
28 | else:
29 | raise NotImplementedError(
30 | f"Model : {config['MODEL']} has not been implemented yet!"
31 | )
32 |
33 | try:
34 | embedding = partial(
35 | EMBEDDING_NETS[config["EMBEDDING"]], config["OUT_DIM"], **embedding_kwargs
36 | )
37 | except KeyError:
38 | raise NotImplementedError(
39 | f"Model : {config['EMBEDDING']} has not been implemented yet! \
40 | The following embeddings are implemented : {[key for key in EMBEDDING_NETS.keys()]}"
41 | )
42 |
43 | estimator = estimator_models.NPEWithEmbedding(
44 | embedding_net=embedding,
45 | output_embedding_dim=config["OUT_DIM"],
46 | num_transforms=config["NUM_TRANSFORM"],
47 | num_hidden_flow=config["NUM_HIDDEN_FLOW"],
48 | hidden_flow_dim=config["HIDDEN_DIM_FLOW"],
49 | flow=model,
50 | theta_shift=config["THETA_SHIFT"],
51 | theta_scale=config["THETA_SCALE"],
52 | **{"activation": partial(nn.LeakyReLU, 0.1)},
53 | )
54 |
55 | return estimator
56 |
57 |
58 | def build_nre_classifier_model(config: dict, **embedding_kwargs) -> nn.Module:
59 | raise NotImplementedError("NRE classifier model has not been implemented yet!")
60 |
--------------------------------------------------------------------------------
/src/cryo_sbi/inference/models/embedding_nets.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision.models as models
4 | import torchvision.transforms as transforms
5 |
6 | from cryo_sbi.utils.image_utils import LowPassFilter, Mask
7 |
8 |
9 | EMBEDDING_NETS = {}
10 |
11 |
12 | def add_embedding(name):
13 | """
14 | Add embedding net to EMBEDDING_NETS dict
15 |
16 | Args:
17 | name (str): name of embedding net
18 |
19 | Returns:
20 | add (function): function to add embedding net to EMBEDDING_NETS dict
21 | """
22 |
23 | def add(class_):
24 | EMBEDDING_NETS[name] = class_
25 | return class_
26 |
27 | return add
28 |
29 |
30 | @add_embedding("RESNET18")
31 | class ResNet18_Encoder(nn.Module):
32 | def __init__(self, output_dimension: int):
33 | super(ResNet18_Encoder, self).__init__()
34 | self.resnet = models.resnet18()
35 | self.resnet.conv1 = nn.Conv2d(
36 | 1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
37 | )
38 | self.resnet.fc = nn.Linear(
39 | in_features=512, out_features=output_dimension, bias=True
40 | )
41 |
42 | def forward(self, x):
43 | x = x.unsqueeze(1)
44 | x = self.resnet(x)
45 | return x
46 |
47 |
48 | @add_embedding("RESNET50")
49 | class ResNet50_Encoder(nn.Module):
50 | def __init__(self, output_dimension: int):
51 | super(ResNet50_Encoder, self).__init__()
52 |
53 | self.resnet = models.resnet50()
54 | self.resnet.conv1 = nn.Conv2d(
55 | 1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
56 | )
57 | self.linear = nn.Linear(1000, output_dimension)
58 |
59 | def forward(self, x):
60 | x = x.unsqueeze(1)
61 | x = self.resnet(x)
62 | x = self.linear(nn.functional.relu(x))
63 | return x
64 |
65 |
66 | @add_embedding("RESNET101")
67 | class ResNet101_Encoder(nn.Module):
68 | def __init__(self, output_dimension: int):
69 | super(ResNet101_Encoder, self).__init__()
70 |
71 | self.resnet = models.resnet101()
72 | self.resnet.conv1 = nn.Conv2d(
73 | 1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
74 | )
75 | self.linear = nn.Linear(1000, output_dimension)
76 |
77 | def forward(self, x):
78 | x = x.unsqueeze(1)
79 | x = self.resnet(x)
80 | x = self.linear(nn.functional.relu(x))
81 | return x
82 |
83 |
84 | @add_embedding("CONVNET")
85 | class ConvNet_Encoder(nn.Module):
86 | def __init__(self, output_dimension: int):
87 | super(ConvNet_Encoder, self).__init__()
88 |
89 | self.convnet = models.convnext_tiny()
90 | self.convnet.features[0][0] = nn.Conv2d(
91 | 1, 96, kernel_size=(4, 4), stride=(4, 4)
92 | )
93 | self.convnet.classifier[2] = nn.Linear(
94 | in_features=768, out_features=output_dimension, bias=True
95 | )
96 |
97 | def forward(self, x):
98 | x = x.unsqueeze(1)
99 | x = self.convnet(x)
100 | return x
101 |
102 |
103 | @add_embedding("CONVNET")
104 | class RegNetX_Encoder(nn.Module):
105 | def __init__(self, output_dimension: int):
106 | super(RegNetX_Encoder, self).__init__()
107 |
108 | self.regnetx = models.regnet_x_3_2gf()
109 | self.regnetx.stem[0] = nn.Conv2d(
110 | 1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False
111 | )
112 | self.regnetx.fc = nn.Linear(
113 | in_features=1008, out_features=output_dimension, bias=True
114 | )
115 |
116 | def forward(self, x):
117 | x = x.unsqueeze(1)
118 | x = self.regnetx(x)
119 | return x
120 |
121 |
122 | @add_embedding("EFFICIENT")
123 | class EfficientNet_Encoder(nn.Module):
124 | def __init__(self, output_dimension: int):
125 | super(EfficientNet_Encoder, self).__init__()
126 |
127 | self.efficient_net = models.efficientnet_b3().features
128 | self.efficient_net[0][0] = nn.Conv2d(
129 | 1, 40, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False
130 | )
131 | self.avg_pool = nn.AdaptiveAvgPool2d(output_size=1)
132 | self.leakyrelu = nn.LeakyReLU()
133 | self.linear = nn.Linear(1536, output_dimension)
134 |
135 | def forward(self, x):
136 | x = x.unsqueeze(1)
137 | x = self.efficient_net(x)
138 | x = self.avg_pool(x).flatten(start_dim=1)
139 | x = self.leakyrelu(self.linear(x))
140 | return x
141 |
142 |
143 | @add_embedding("SWINS")
144 | class SwinTransformerS_Encoder(nn.Module):
145 | def __init__(self, output_dimension: int):
146 | super(SwinTransformerS_Encoder, self).__init__()
147 |
148 | self.swin_transformer = models.swin_t()
149 | self.swin_transformer.features[0][0] = nn.Conv2d(
150 | 1, 96, kernel_size=(4, 4), stride=(4, 4)
151 | )
152 | self.swin_transformer.head = nn.Linear(
153 | in_features=768, out_features=output_dimension, bias=True
154 | )
155 |
156 | def forward(self, x):
157 | x = x.unsqueeze(1)
158 | x = self.swin_transformer(x)
159 | return x
160 |
161 |
162 | @add_embedding("WIDERES50")
163 | class WideResnet50_Encoder(nn.Module):
164 | def __init__(self, output_dimension: int):
165 | super(WideResnet50_Encoder, self).__init__()
166 |
167 | self.wideresnet = models.wide_resnet50_2()
168 | self.wideresnet.conv1 = nn.Conv2d(
169 | 1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
170 | )
171 | self.linear = nn.Linear(1000, output_dimension)
172 |
173 | def forward(self, x):
174 | x = x.unsqueeze(1)
175 | x = self.wideresnet(x)
176 | x = self.linear(nn.functional.relu(x))
177 | return x
178 |
179 |
180 | @add_embedding("WIDERES101")
181 | class WideResnet101_Encoder(nn.Module):
182 | def __init__(self, output_dimension: int):
183 | super(WideResnet101_Encoder, self).__init__()
184 |
185 | self.wideresnet = models.wide_resnet101_2()
186 | self.wideresnet.conv1 = nn.Conv2d(
187 | 1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
188 | )
189 | self.linear = nn.Linear(1000, output_dimension)
190 |
191 | def forward(self, x):
192 | x = x.unsqueeze(1)
193 | x = self.wideresnet(x)
194 | x = self.linear(nn.functional.relu(x))
195 | return x
196 |
197 |
198 | @add_embedding("REGNETY")
199 | class RegNetY_Encoder(nn.Module):
200 | def __init__(self, output_dimension: int):
201 | super(RegNetY_Encoder, self).__init__()
202 |
203 | self.regnety = models.regnet_y_1_6gf()
204 | self.regnety.stem[0] = nn.Conv2d(
205 | 1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False
206 | )
207 | self.regnety.fc = nn.Linear(
208 | in_features=888, out_features=output_dimension, bias=True
209 | )
210 |
211 | def forward(self, x):
212 | x = x.unsqueeze(1)
213 | x = self.regnety(x)
214 | return x
215 |
216 |
217 | @add_embedding("SHUFFLENET")
218 | class ShuffleNet_Encoder(nn.Module):
219 | def __init__(self, output_dimension: int):
220 | super(ShuffleNet_Encoder, self).__init__()
221 |
222 | self.shuffle_net = models.shufflenet_v2_x0_5()
223 | self.shuffle_net.conv1[0] = nn.Conv2d(
224 | 1, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False
225 | )
226 | self.shuffle_net.fc = nn.Linear(
227 | in_features=1024, out_features=output_dimension, bias=True
228 | )
229 |
230 | def forward(self, x):
231 | x = x.unsqueeze(1)
232 | x = self.shuffle_net(x)
233 | return x
234 |
235 |
236 | @add_embedding("RESNET18_FFT_FILTER")
237 | class ResNet18_FFT_Encoder(nn.Module):
238 | def __init__(self, output_dimension: int):
239 | super(ResNet18_FFT_Encoder, self).__init__()
240 | self.resnet = models.resnet18()
241 | self.resnet.conv1 = nn.Conv2d(
242 | 1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
243 | )
244 | self.resnet.fc = nn.Linear(
245 | in_features=512, out_features=output_dimension, bias=True
246 | )
247 |
248 | self._fft_filter = LowPassFilter(128, 25)
249 |
250 | def forward(self, x):
251 | # Low pass filter images
252 | x = self._fft_filter(x)
253 | # Proceed as normal
254 | x = x.unsqueeze(1)
255 | x = self.resnet(x)
256 | return x
257 |
258 |
259 | @add_embedding("RESNET18_FFT_FILTER_132")
260 | class ResNet18_FFT_Encoder_132(nn.Module):
261 | def __init__(self, output_dimension: int):
262 | super(ResNet18_FFT_Encoder_132, self).__init__()
263 | self.resnet = models.resnet18()
264 | self.resnet.conv1 = nn.Conv2d(
265 | 1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
266 | )
267 | self.resnet.fc = nn.Linear(
268 | in_features=512, out_features=output_dimension, bias=True
269 | )
270 |
271 | self._fft_filter = LowPassFilter(132, 25)
272 |
273 | def forward(self, x):
274 | # Low pass filter images
275 | x = self._fft_filter(x)
276 | # Proceed as normal
277 | x = x.unsqueeze(1)
278 | x = self.resnet(x)
279 | return x
280 |
281 |
282 | @add_embedding("RESNET18_FFT_FILTER_224")
283 | class ResNet18_FFT_Encoder_224(nn.Module):
284 | def __init__(self, output_dimension: int):
285 | super(ResNet18_FFT_Encoder_224, self).__init__()
286 | self.resnet = models.resnet18()
287 | self.resnet.conv1 = nn.Conv2d(
288 | 1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
289 | )
290 | self.resnet.fc = nn.Linear(
291 | in_features=512, out_features=output_dimension, bias=True
292 | )
293 |
294 | self._fft_filter = LowPassFilter(224, 25)
295 |
296 | def forward(self, x):
297 | # Low pass filter images
298 | x = self._fft_filter(x)
299 | # Proceed as normal
300 | x = x.unsqueeze(1)
301 | x = self.resnet(x)
302 | return x
303 |
304 |
305 | @add_embedding("RESNET18_FFT_FILTER_256")
306 | class ResNet18_FFT_Encoder_256(nn.Module):
307 | def __init__(self, output_dimension: int):
308 | super(ResNet18_FFT_Encoder_256, self).__init__()
309 | self.resnet = models.resnet18()
310 | self.resnet.conv1 = nn.Conv2d(
311 | 1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
312 | )
313 | self.resnet.fc = nn.Linear(
314 | in_features=512, out_features=output_dimension, bias=True
315 | )
316 |
317 | self._fft_filter = LowPassFilter(256, 10)
318 |
319 | def forward(self, x):
320 | # Low pass filter images
321 | x = self._fft_filter(x)
322 | # Proceed as normal
323 | x = x.unsqueeze(1)
324 | x = self.resnet(x)
325 | return x
326 |
327 |
328 | @add_embedding("RESNET34")
329 | class ResNet34_Encoder(nn.Module):
330 | def __init__(self, output_dimension: int):
331 | super(ResNet34_Encoder, self).__init__()
332 | self.resnet = models.resnet34()
333 | self.resnet.conv1 = nn.Conv2d(
334 | 1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
335 | )
336 | self.resnet.fc = nn.Linear(
337 | in_features=512, out_features=output_dimension, bias=True
338 | )
339 |
340 | def forward(self, x):
341 | x = x.unsqueeze(1)
342 | x = self.resnet(x)
343 | return x
344 |
345 |
346 | @add_embedding("RESNET34_FFT_FILTER_256")
347 | class ResNet34_Encoder_FFT_FILTER_256(nn.Module):
348 | def __init__(self, output_dimension: int):
349 | super(ResNet34_Encoder_FFT_FILTER_256, self).__init__()
350 | self.resnet = models.resnet34()
351 | self.resnet.conv1 = nn.Conv2d(
352 | 1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
353 | )
354 | self.resnet.fc = nn.Linear(
355 | in_features=512, out_features=output_dimension, bias=True
356 | )
357 | self._fft_filter = LowPassFilter(256, 50)
358 |
359 | def forward(self, x):
360 | # Low pass filter images
361 | x = self._fft_filter(x)
362 | # Proceed as normal
363 | x = x.unsqueeze(1)
364 | x = self.resnet(x)
365 | return x
366 |
367 |
368 | @add_embedding("VGG19")
369 | class VGG19_Encoder(nn.Module):
370 | def __init__(self, output_dimension: int):
371 | super(VGG19_Encoder, self).__init__()
372 |
373 | self.vgg19 = models.vgg19_bn().features
374 | self.vgg19[0] = nn.Conv2d(
375 | 1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
376 | )
377 |
378 | self.avgpool = nn.AdaptiveAvgPool2d(output_size=(7, 7))
379 |
380 | self.feedforward = nn.Sequential(
381 | *[
382 | nn.Linear(in_features=25088, out_features=4096),
383 | nn.ReLU(inplace=True),
384 | nn.Linear(in_features=4096, out_features=output_dimension, bias=True),
385 | nn.ReLU(inplace=True),
386 | ]
387 | )
388 |
389 | def forward(self, x):
390 | # Low pass filter images
391 | # x = self._fft_filter(x)
392 | # Proceed as normal
393 | x = x.unsqueeze(1)
394 | x = self.vgg19(x)
395 | x = self.avgpool(x).flatten(start_dim=1)
396 | x = self.feedforward(x)
397 | return x
398 |
399 |
400 | @add_embedding("ConvEncoder_Tutorial")
401 | class ConvEncoder(nn.Module):
402 | def __init__(self, output_dimension: int):
403 | super(ConvEncoder, self).__init__()
404 | ndf = 16 # fixed for the tutorial
405 | self.main = nn.Sequential(
406 | # input is 1 x 64 x 64
407 | nn.Conv2d(1, ndf, 4, 2, 1, bias=False),
408 | nn.LeakyReLU(0.2, inplace=True),
409 | # state size. (ndf) x 32 x 32
410 | nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
411 | # nn.BatchNorm2d(ndf * 2),
412 | nn.LeakyReLU(0.2, inplace=True),
413 | # state size. (ndf*2) x 16 x 16
414 | nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
415 | # nn.BatchNorm2d(ndf * 4),
416 | nn.LeakyReLU(0.2, inplace=True),
417 | # state size. (ndf*4) x 8 x 8
418 | nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
419 | # nn.BatchNorm2d(ndf * 8),
420 | nn.LeakyReLU(0.2, inplace=True),
421 | # state size. (ndf*8) x 4 x 4
422 | nn.Conv2d(ndf * 8, output_dimension, 4, 1, 0, bias=False),
423 | # state size. out_dims x 1 x 1
424 | )
425 |
426 | def forward(self, x):
427 | x = x.view(-1, 1, 64, 64)
428 | x = self.main(x)
429 | return x.view(x.size(0), -1) # flatten
430 |
431 |
432 | if __name__ == "__main__":
433 | pass
434 |
--------------------------------------------------------------------------------
/src/cryo_sbi/inference/models/estimator_models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import zuko
4 | from lampe.inference import NPE, NRE
5 |
6 |
7 | class Standardize(nn.Module):
8 | """
9 | Module to standardize inputs and retransform them to the original space
10 |
11 | Args:
12 | mean (torch.Tensor): mean of the data
13 | std (torch.Tensor): standard deviation of the data
14 |
15 | Returns:
16 | standardized (torch.Tensor): standardized data
17 | """
18 |
19 | # Code adapted from :https://github.com/mackelab/sbi/blob/main/sbi/utils/sbiutils.py
20 | def __init__(self, mean: float, std: float) -> None:
21 | super(Standardize, self).__init__()
22 | mean, std = map(torch.as_tensor, (mean, std))
23 | self.mean = mean
24 | self.std = std
25 | self.register_buffer("_mean", mean)
26 | self.register_buffer("_std", std)
27 |
28 | def forward(self, tensor: torch.Tensor) -> torch.Tensor:
29 | """
30 | Standardize the input tensor
31 |
32 | Args:
33 | tensor (torch.Tensor): input tensor
34 |
35 | Returns:
36 | standardized (torch.Tensor): standardized tensor
37 | """
38 |
39 | return (tensor - self._mean) / self._std
40 |
41 | def transform(self, tensor: torch.Tensor) -> torch.Tensor:
42 | """
43 | Transform the standardized tensor back to the original space
44 |
45 | Args:
46 | tensor (torch.Tensor): input tensor
47 |
48 | Returns:
49 | retransformed (torch.Tensor): retransformed tensor
50 | """
51 |
52 | return (tensor * self._std) + self._mean
53 |
54 |
55 | class NPEWithEmbedding(nn.Module):
56 | """Neural Posterior Estimation with embedding net
57 |
58 | Attributes:
59 | npe (NPE): NPE model
60 | embedding (nn.Module): embedding net
61 | standardize (Standardize): standardization module
62 | """
63 |
64 | def __init__(
65 | self,
66 | embedding_net: nn.Module,
67 | output_embedding_dim: int,
68 | num_transforms: int = 4,
69 | num_hidden_flow: int = 2,
70 | hidden_flow_dim: int = 128,
71 | flow: nn.Module = zuko.flows.MAF,
72 | theta_shift: float = 0.0,
73 | theta_scale: float = 1.0,
74 | **kwargs,
75 | ) -> None:
76 | """
77 | Neural Posterior Estimation with embedding net.
78 |
79 | Args:
80 | embedding_net (nn.Module): embedding net
81 | output_embedding_dim (int): output embedding dimension
82 | num_transforms (int, optional): number of transforms. Defaults to 4.
83 | num_hidden_flow (int, optional): number of hidden layers in flow. Defaults to 2.
84 | hidden_flow_dim (int, optional): hidden dimension in flow. Defaults to 128.
85 | flow (nn.Module, optional): flow. Defaults to zuko.flows.MAF.
86 | theta_shift (float, optional): Shift of the theta for standardization. Defaults to 0.0.
87 | theta_scale (float, optional): Scale of the theta for standardization. Defaults to 1.0.
88 | kwargs: additional arguments for the flow
89 |
90 | Returns:
91 | None
92 | """
93 |
94 | super().__init__()
95 |
96 | self.npe = NPE(
97 | 1,
98 | output_embedding_dim,
99 | transforms=num_transforms,
100 | build=flow,
101 | hidden_features=[*[hidden_flow_dim] * num_hidden_flow, 128, 64],
102 | **kwargs,
103 | )
104 |
105 | self.embedding = embedding_net()
106 | self.standardize = Standardize(theta_shift, theta_scale)
107 |
108 | def forward(self, theta: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
109 | """
110 | Forward pass of the NPE model
111 |
112 | Args:
113 | theta (torch.Tensor): Conformational parameters.
114 | x (torch.Tensor): Image to condition the posterior on.
115 |
116 | Returns:
117 | torch.Tensor: Log probability of the posterior.
118 | """
119 |
120 | return self.npe(self.standardize(theta), self.embedding(x))
121 |
122 | def flow(self, x: torch.Tensor):
123 | """
124 | Conditions the posterior on an image.
125 |
126 | Args:
127 | x (torch.Tensor): Image to condition the posterior on.
128 |
129 | Returns:
130 | zuko.flows.Flow: The posterior distribution.
131 | """
132 | return self.npe.flow(self.embedding(x))
133 |
134 | def sample(self, x: torch.Tensor, shape=(1,)) -> torch.Tensor:
135 | """
136 | Generate samples from the posterior distribution.
137 |
138 | Args:
139 | x (torch.Tensor): Image to condition the posterior on.
140 | shape (tuple, optional): Shape of the samples. Defaults to (1,).
141 |
142 | Returns:
143 | torch.Tensor: Samples from the posterior distribution.
144 | """
145 |
146 | samples_standardized = self.flow(x).sample(shape)
147 | return self.standardize.transform(samples_standardized)
148 |
--------------------------------------------------------------------------------
/src/cryo_sbi/inference/priors.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import zuko
3 | from torch.distributions.distribution import Distribution
4 | from torch.utils.data import DataLoader, Dataset, IterableDataset
5 |
6 |
7 | def gen_quat() -> torch.Tensor:
8 | """
9 | Generate a random quaternion.
10 |
11 | Returns:
12 |         quat (torch.Tensor): Random quaternion
13 |
14 | """
15 | count = 0
16 | while count < 1:
17 | quat = 2 * torch.rand(size=(4,)) - 1
18 | norm = torch.sqrt(torch.sum(quat**2))
19 | if 0.2 <= norm <= 1.0:
20 | quat /= norm
21 | count += 1
22 |
23 | return quat
24 |
25 |
26 | def get_image_priors(
27 | max_index, image_config: dict, device="cuda"
28 | ) -> "ImagePrior":
29 |     """
30 |     Build the joint prior over all image parameters.
31 |
32 |     Args:
33 |         max_index (int): maximum index of the structure prior
34 |         image_config (dict): image parameter config dictionary
35 |     Returns:
36 |         ImagePrior: joint prior over index, quaternion, sigma, shift, defocus, B-factor, amplitude and SNR
37 | """
38 | if isinstance(image_config["SIGMA"], list) and len(image_config["SIGMA"]) == 2:
39 | lower = torch.tensor(
40 | [[image_config["SIGMA"][0]]], dtype=torch.float32, device=device
41 | )
42 | upper = torch.tensor(
43 | [[image_config["SIGMA"][1]]], dtype=torch.float32, device=device
44 | )
45 |
46 |         assert lower <= upper, "Lower bound must be smaller than or equal to the upper bound."
47 |
48 | sigma_prior = zuko.distributions.BoxUniform(lower=lower, upper=upper, ndims=1)
49 |
50 | shift_prior = zuko.distributions.BoxUniform(
51 | lower=torch.tensor(
52 | [-image_config["SHIFT"], -image_config["SHIFT"]],
53 | dtype=torch.float32,
54 | device=device,
55 | ),
56 | upper=torch.tensor(
57 | [image_config["SHIFT"], image_config["SHIFT"]],
58 | dtype=torch.float32,
59 | device=device,
60 | ),
61 | ndims=1,
62 | )
63 |
64 | if isinstance(image_config["DEFOCUS"], list) and len(image_config["DEFOCUS"]) == 2:
65 | lower = torch.tensor(
66 | [[image_config["DEFOCUS"][0]]], dtype=torch.float32, device=device
67 | )
68 | upper = torch.tensor(
69 | [[image_config["DEFOCUS"][1]]], dtype=torch.float32, device=device
70 | )
71 |
72 | assert lower > 0.0, "The lower bound for DEFOCUS must be positive."
73 |         assert lower <= upper, "Lower bound must be smaller than or equal to the upper bound."
74 |
75 | defocus_prior = zuko.distributions.BoxUniform(lower=lower, upper=upper, ndims=1)
76 |
77 | if (
78 | isinstance(image_config["B_FACTOR"], list)
79 | and len(image_config["B_FACTOR"]) == 2
80 | ):
81 | lower = torch.tensor(
82 | [[image_config["B_FACTOR"][0]]], dtype=torch.float32, device=device
83 | )
84 | upper = torch.tensor(
85 | [[image_config["B_FACTOR"][1]]], dtype=torch.float32, device=device
86 | )
87 |
88 | assert lower > 0.0, "The lower bound for B_FACTOR must be positive."
89 |         assert lower <= upper, "Lower bound must be smaller than or equal to the upper bound."
90 |
91 | b_factor_prior = zuko.distributions.BoxUniform(
92 | lower=lower, upper=upper, ndims=1
93 | )
94 |
95 | if isinstance(image_config["SNR"], list) and len(image_config["SNR"]) == 2:
96 | lower = torch.tensor(
97 | [[image_config["SNR"][0]]], dtype=torch.float32, device=device
98 | ).log10()
99 | upper = torch.tensor(
100 | [[image_config["SNR"][1]]], dtype=torch.float32, device=device
101 | ).log10()
102 |
103 |         assert lower <= upper, "Lower bound must be smaller than or equal to the upper bound."
104 |
105 | snr_prior = zuko.distributions.BoxUniform(lower=lower, upper=upper, ndims=1)
106 |
107 | amp_prior = zuko.distributions.BoxUniform(
108 | lower=torch.tensor([[image_config["AMP"]]], dtype=torch.float32, device=device),
109 | upper=torch.tensor([[image_config["AMP"]]], dtype=torch.float32, device=device),
110 | ndims=1,
111 | )
112 |
113 | index_prior = zuko.distributions.BoxUniform(
114 | lower=torch.tensor([0], dtype=torch.float32, device=device),
115 | upper=torch.tensor([max_index], dtype=torch.float32, device=device),
116 | )
117 | quaternion_prior = QuaternionPrior(device)
118 | if (
119 | image_config.get("ROTATIONS")
120 | and isinstance(image_config["ROTATIONS"], list)
121 | and len(image_config["ROTATIONS"]) == 4
122 | ):
123 | test_quat = image_config["ROTATIONS"]
124 | quaternion_prior = QuaternionTestPrior(test_quat, device)
125 |
126 | return ImagePrior(
127 | index_prior,
128 | quaternion_prior,
129 | sigma_prior,
130 | shift_prior,
131 | defocus_prior,
132 | b_factor_prior,
133 | amp_prior,
134 | snr_prior,
135 | device=device,
136 | )
137 |
138 |
139 | class QuaternionPrior:
140 | def __init__(self, device) -> None:
141 | self.device = device
142 |
143 | def sample(self, shape) -> torch.Tensor:
144 | quats = torch.stack(
145 | [gen_quat().to(self.device) for _ in range(shape[0])], dim=0
146 | )
147 | return quats
148 |
149 |
150 | class QuaternionTestPrior:
151 | def __init__(self, quat, device) -> None:
152 | self.device = device
153 | self.quat = torch.tensor(quat, device=device)
154 |
155 | def sample(self, shape) -> torch.Tensor:
156 | quats = torch.stack([self.quat for _ in range(shape[0])], dim=0)
157 | return quats
158 |
159 |
160 | class ImagePrior:
161 | def __init__(
162 | self,
163 | index_prior,
164 | quaternion_prior,
165 | sigma_prior,
166 | shift_prior,
167 | defocus_prior,
168 | b_factor_prior,
169 | amp_prior,
170 | snr_prior,
171 | device,
172 | ) -> None:
173 | self.priors = [
174 | index_prior,
175 | quaternion_prior,
176 | sigma_prior,
177 | shift_prior,
178 | defocus_prior,
179 | b_factor_prior,
180 | amp_prior,
181 | snr_prior,
182 | ]
183 |
184 | def sample(self, shape) -> torch.Tensor:
185 | samples = [prior.sample(shape) for prior in self.priors]
186 | return samples
187 |
188 |
189 | class PriorDataset(IterableDataset):
190 | def __init__(
191 | self,
192 | prior: Distribution,
193 | batch_shape: torch.Size = (),
194 | ):
195 | super().__init__()
196 |
197 | self.prior = prior
198 | self.batch_shape = batch_shape
199 |
200 | def __iter__(self):
201 | while True:
202 | theta = self.prior.sample(self.batch_shape)
203 | yield theta
204 |
205 |
206 | class PriorLoader(DataLoader):
207 | def __init__(
208 | self,
209 | prior: Distribution,
210 | batch_size: int = 2**8, # 256
211 | **kwargs,
212 | ):
213 | super().__init__(
214 | PriorDataset(prior, batch_shape=(batch_size,)),
215 | batch_size=None,
216 | **kwargs,
217 | )
218 |
--------------------------------------------------------------------------------
/src/cryo_sbi/inference/train_npe_model.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 | import json
3 | import torch
4 | import numpy as np
5 | import torch.optim as optim
6 | from torch.utils.data import TensorDataset
7 | from torchvision import transforms
8 | from tqdm import tqdm
9 | from lampe.data import JointLoader, H5Dataset
10 | from lampe.inference import NPELoss
11 | from lampe.utils import GDStep
12 | from itertools import islice
13 |
14 | from cryo_sbi.inference.priors import get_image_priors, PriorLoader
15 | from cryo_sbi.inference.models.build_models import build_npe_flow_model
16 | from cryo_sbi.inference.validate_train_config import check_train_params
17 | from cryo_sbi.wpa_simulator.cryo_em_simulator import cryo_em_simulator
18 | from cryo_sbi.wpa_simulator.validate_image_config import check_image_params
19 |
20 | import cryo_sbi.utils.image_utils as img_utils
21 |
22 |
23 | def load_model(
24 | train_config: str, model_state_dict: str, device: str, train_from_checkpoint: bool
25 | ) -> torch.nn.Module:
26 | """
27 | Load model from checkpoint or from scratch.
28 |
29 | Args:
30 | train_config (str): path to train config file
31 | model_state_dict (str): path to model state dict
32 | device (str): device to load model to
33 | train_from_checkpoint (bool): whether to load model from checkpoint or from scratch
34 | """
35 |
36 | check_train_params(train_config)
37 | estimator = build_npe_flow_model(train_config)
38 | if train_from_checkpoint:
39 | if not isinstance(model_state_dict, str):
40 |             raise Warning("No model state dict specified! --state_dict_file is empty")
41 | print(f"Loading model parameters from {model_state_dict}")
42 | estimator.load_state_dict(torch.load(model_state_dict))
43 | estimator.to(device=device)
44 | return estimator
45 |
46 |
47 | def npe_train_no_saving(
48 | image_config: str,
49 | train_config: str,
50 | epochs: int,
51 | estimator_file: str,
52 | loss_file: str,
53 | train_from_checkpoint: bool = False,
54 | model_state_dict: Union[str, None] = None,
55 | n_workers: int = 1,
56 | device: str = "cpu",
57 | saving_frequency: int = 20,
58 | simulation_batch_size: int = 1024,
59 | ) -> None:
60 | """
61 | Train NPE model by simulating training data on the fly.
62 | Saves model and loss to disk.
63 |
64 | Args:
65 | image_config (str): path to image config file
66 | train_config (str): path to train config file
67 | epochs (int): number of epochs
68 | estimator_file (str): path to estimator file
69 | loss_file (str): path to loss file
70 | train_from_checkpoint (bool, optional): train from checkpoint. Defaults to False.
71 | model_state_dict (str, optional): path to pretrained model state dict. Defaults to None.
72 | n_workers (int, optional): number of workers. Defaults to 1.
73 | device (str, optional): training device. Defaults to "cpu".
74 | saving_frequency (int, optional): frequency of saving model. Defaults to 20.
75 |         simulation_batch_size (int, optional): batch size for simulating images. Defaults to 1024.
76 |
77 | Raises:
78 |         Warning: No model state dict specified! --state_dict_file is empty
79 |
80 | Returns:
81 | None
82 | """
83 |
84 | train_config = json.load(open(train_config))
85 | check_train_params(train_config)
86 | image_config = json.load(open(image_config))
87 |
88 | assert simulation_batch_size >= train_config["BATCH_SIZE"]
89 | assert simulation_batch_size % train_config["BATCH_SIZE"] == 0
90 |
91 | if image_config["MODEL_FILE"].endswith("npy"):
92 | models = (
93 | torch.from_numpy(
94 | np.load(image_config["MODEL_FILE"]),
95 | )
96 | .to(device)
97 | .to(torch.float32)
98 | )
99 | else:
100 | models = torch.load(image_config["MODEL_FILE"]).to(device).to(torch.float32)
101 |
102 | image_prior = get_image_priors(len(models) - 1, image_config, device="cpu")
103 | prior_loader = PriorLoader(
104 | image_prior, batch_size=simulation_batch_size, num_workers=n_workers
105 | )
106 |
107 | num_pixels = torch.tensor(
108 | image_config["N_PIXELS"], dtype=torch.float32, device=device
109 | )
110 | pixel_size = torch.tensor(
111 | image_config["PIXEL_SIZE"], dtype=torch.float32, device=device
112 | )
113 |
114 | estimator = load_model(
115 | train_config, model_state_dict, device, train_from_checkpoint
116 | )
117 |
118 | loss = NPELoss(estimator)
119 | optimizer = optim.AdamW(
120 | estimator.parameters(), lr=train_config["LEARNING_RATE"], weight_decay=0.001
121 | )
122 | step = GDStep(optimizer, clip=train_config["CLIP_GRADIENT"])
123 | mean_loss = []
124 |
125 |     print("Training neural network:")
126 | estimator.train()
127 | with tqdm(range(epochs), unit="epoch") as tq:
128 | for epoch in tq:
129 | losses = []
130 | for parameters in islice(prior_loader, 100):
131 | (
132 | indices,
133 | quaternions,
134 | res,
135 | shift,
136 | defocus,
137 | b_factor,
138 | amp,
139 | snr,
140 | ) = parameters
141 | images = cryo_em_simulator(
142 | models,
143 | indices.to(device, non_blocking=True),
144 | quaternions.to(device, non_blocking=True),
145 | res.to(device, non_blocking=True),
146 | shift.to(device, non_blocking=True),
147 | defocus.to(device, non_blocking=True),
148 | b_factor.to(device, non_blocking=True),
149 | amp.to(device, non_blocking=True),
150 | snr.to(device, non_blocking=True),
151 | num_pixels,
152 | pixel_size,
153 | )
154 | for _indices, _images in zip(
155 | indices.split(train_config["BATCH_SIZE"]),
156 | images.split(train_config["BATCH_SIZE"]),
157 | ):
158 | losses.append(
159 | step(
160 | loss(
161 | _indices.to(device, non_blocking=True),
162 | _images.to(device, non_blocking=True),
163 | )
164 | )
165 | )
166 | losses = torch.stack(losses)
167 |
168 | tq.set_postfix(loss=losses.mean().item())
169 | mean_loss.append(losses.mean().item())
170 | if epoch % saving_frequency == 0:
171 | torch.save(estimator.state_dict(), estimator_file + f"_epoch={epoch}")
172 |
173 | torch.save(estimator.state_dict(), estimator_file)
174 | torch.save(torch.tensor(mean_loss), loss_file)
175 |
--------------------------------------------------------------------------------
/src/cryo_sbi/inference/validate_train_config.py:
--------------------------------------------------------------------------------
1 | def check_train_params(config: dict) -> None:
2 | """
3 | Checks if all necessary parameters are provided.
4 |
5 | Args:
6 | config (dict): Dictionary containing training parameters.
7 |
8 | Returns:
9 | None
10 | """
11 |
12 | needed_keys = [
13 | "EMBEDDING",
14 | "OUT_DIM",
15 | "NUM_TRANSFORM",
16 | "NUM_HIDDEN_FLOW",
17 | "HIDDEN_DIM_FLOW",
18 | "MODEL",
19 | "LEARNING_RATE",
20 | "CLIP_GRADIENT",
21 | "BATCH_SIZE",
22 | "THETA_SHIFT",
23 | "THETA_SCALE",
24 | ]
25 |
26 | for key in needed_keys:
27 | assert key in config.keys(), f"Please provide a value for {key}"
28 |
29 | return
30 |
--------------------------------------------------------------------------------
/src/cryo_sbi/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/flatironinstitute/cryoSBI/d145cf21c1d5991361f028db698b3bc8931e0fdf/src/cryo_sbi/utils/__init__.py
--------------------------------------------------------------------------------
/src/cryo_sbi/utils/command_line_tools.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from cryo_sbi.utils.generate_models import models_to_tensor
3 |
4 |
5 | def cl_models_to_tensor():
6 | cl_parser = argparse.ArgumentParser(
7 | description="Convert models to tensor for cryoSBI",
8 | epilog="pdb-files: The name for the pdbs must contain a {} to be replaced by the index of the pdb file. The index starts at 0. \
9 | For example protein_{}.pdb. trr-files: For .trr files you must provide a topology file."
10 | )
11 | cl_parser.add_argument(
12 | "--model_files", action="store", type=str, required=True
13 | )
14 | cl_parser.add_argument(
15 | "--output_file", action="store", type=str, required=True
16 | )
17 | cl_parser.add_argument(
18 | "--n_pdbs", action="store", type=int, required=False, default=None
19 | )
20 | cl_parser.add_argument(
21 | "--top_file", action="store", type=str, required=False, default=None
22 | )
23 | args = cl_parser.parse_args()
24 | models_to_tensor(
25 | model_files=args.model_files,
26 | output_file=args.output_file,
27 | n_pdbs=args.n_pdbs,
28 | top_file=args.top_file
29 | )
--------------------------------------------------------------------------------
/src/cryo_sbi/utils/estimator_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import json
3 | from cryo_sbi.inference.models import build_models
4 |
5 |
6 | @torch.no_grad()
7 | def evaluate_log_prob(
8 | estimator: torch.nn.Module,
9 | images: torch.Tensor,
10 | theta: torch.Tensor,
11 | batch_size: int = 0,
12 | device: str = "cpu",
13 | ) -> torch.Tensor:
14 | """
15 | Evaluates the log probability of a given set of images under a given estimator.
16 |
17 | Args:
18 | estimator (torch.nn.Module): The posterior model to use for evaluation.
19 | images (torch.Tensor): The input images used to condition the posterior.
20 | theta (torch.Tensor): The parameter values at which to evaluate the log probability.
21 | batch_size (int, optional): The batch size for batching the images. Defaults to 0.
22 | device (str, optional): The device to use for computation. Defaults to "cpu".
23 |
24 | Returns:
25 | torch.Tensor: The log probabilities of the images under the estimator.
26 | """
27 |
28 | # batching images if necessary
29 | if images.shape[0] > batch_size and batch_size > 0:
30 | images = torch.split(images, split_size_or_sections=batch_size, dim=0)
31 | else:
32 | batch_size = images.shape[0]
33 | images = [images]
34 |
35 | # theta dimensions [num_eval, num_images, 1]
36 | if theta.ndim == 3:
37 | num_eval = theta.shape[0]
38 |             num_images = sum(batch.shape[0] for batch in images)  # images is a list of batches here
39 | assert theta.shape == torch.Size([num_eval, num_images, 1])
40 |
41 | elif theta.ndim == 2:
42 | raise IndexError("theta must have 3 dimensions [num_eval, num_images, 1]")
43 |
44 | elif theta.ndim == 1:
45 | theta = theta.reshape(-1, 1, 1).repeat(1, batch_size, 1)
46 |
47 | log_probs = []
48 | for image_batch in images:
49 | posterior = estimator.flow(image_batch.to(device))
50 | log_probs.append(posterior.log_prob(estimator.standardize(theta.to(device))))
51 |
52 | log_probs = torch.cat(log_probs, dim=1)
53 | return log_probs
54 |
55 |
56 | @torch.no_grad()
57 | def sample_posterior(
58 | estimator: torch.nn.Module,
59 | images: torch.Tensor,
60 | num_samples: int,
61 | batch_size: int = 100,
62 | device: str = "cpu",
63 | ) -> torch.Tensor:
64 | """
65 | Samples from the posterior distribution
66 |
67 | Args:
68 | estimator (torch.nn.Module): The posterior to use for sampling.
69 | images (torch.Tensor): The images used to condition the posterior.
70 | num_samples (int): The number of samples to draw
71 | batch_size (int, optional): The batch size for sampling. Defaults to 100.
72 | device (str, optional): The device to use. Defaults to "cpu".
73 |
74 | Returns:
75 | torch.Tensor: The posterior samples
76 | """
77 |
78 | theta_samples = []
79 |
80 | if images.shape[0] > batch_size and batch_size > 0:
81 | images = torch.split(images, split_size_or_sections=batch_size, dim=0)
82 | else:
83 | batch_size = images.shape[0]
84 | images = [images]
85 |
86 | for image_batch in images:
87 | samples = estimator.sample(
88 | image_batch.to(device, non_blocking=True), shape=(num_samples,)
89 | ).cpu()
90 | theta_samples.append(samples.reshape(-1, image_batch.shape[0]))
91 |
92 | return torch.cat(theta_samples, dim=1)
93 |
94 |
95 | @torch.no_grad()
96 | def compute_latent_repr(
97 | estimator: torch.nn.Module,
98 | images: torch.Tensor,
99 | batch_size: int = 100,
100 | device: str = "cpu",
101 | ) -> torch.Tensor:
102 | """
103 | Computes the latent representation of images.
104 |
105 | Args:
106 | estimator (torch.nn.Module): Posterior model for which to compute the latent representation.
107 | images (torch.Tensor): The images to compute the latent representation for.
108 | batch_size (int, optional): The batch size to use. Defaults to 100.
109 | device (str, optional): The device to use. Defaults to "cpu".
110 |
111 | Returns:
112 | torch.Tensor: The latent representation of the images.
113 | """
114 |
115 | latent_space_samples = []
116 |
117 | if images.shape[0] > batch_size and batch_size > 0:
118 | images = torch.split(images, split_size_or_sections=batch_size, dim=0)
119 | else:
120 | batch_size = images.shape[0]
121 | images = [images]
122 |
123 | for image_batch in images:
124 | samples = estimator.embedding(image_batch.to(device, non_blocking=True)).cpu()
125 | latent_space_samples.append(samples.reshape(image_batch.shape[0], -1))
126 |
127 | return torch.cat(latent_space_samples, dim=0)
128 |
129 |
130 | def load_estimator(
131 | config_file_path: str, estimator_path: str, device: str = "cpu"
132 | ) -> torch.nn.Module:
133 | """
134 | Loads a trained estimator.
135 |
136 | Args:
137 | config_file_path (str): Path to the config file used to train the estimator.
138 | estimator_path (str): Path to the estimator.
139 | device (str, optional): The device to use. Defaults to "cpu".
140 |
141 | Returns:
142 | torch.nn.Module: The loaded estimator.
143 | """
144 |
145 |     with open(config_file_path) as config_file:
146 |         train_config = json.load(config_file)
146 | estimator = build_models.build_npe_flow_model(train_config)
147 | estimator.load_state_dict(
148 | torch.load(estimator_path, map_location=torch.device(device))
149 | )
150 | estimator.to(device)
151 | estimator.eval()
152 |
153 | return estimator
154 |
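155 | # Usage sketch (hedged example; the config and estimator paths below are
156 | # hypothetical stand-ins for files produced by training):
157 | #
158 | #     estimator = load_estimator("train_config.json", "posterior.estimator")
159 | #     images = torch.randn(10, 128, 128)  # stand-in for real particle images
160 | #     samples = sample_posterior(estimator, images, num_samples=100, batch_size=5)
161 | #     samples.shape  # torch.Size([100, 10]) -> (num_samples, num_images)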
--------------------------------------------------------------------------------
/src/cryo_sbi/utils/generate_models.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 | import MDAnalysis as mda
3 | from MDAnalysis.analysis import align
4 | import torch
5 |
6 |
7 | def pdb_parser_(fname: str, atom_selection: str = "name CA") -> torch.Tensor:
8 |     """
9 |     Parses a pdb file and returns a coarse-grained atomic model of the protein
10 |     as a 3xN tensor, where N is the number of selected atoms (by default the
11 |     alpha carbons) and the rows are their x, y, z coordinates.
12 |
13 |     Parameters
14 |     ----------
15 |     fname : str
16 |         The path to the pdb file.
17 |     atom_selection : str
18 |         MDAnalysis selection string. Defaults to "name CA".
19 |
20 |     Returns
21 |     -------
22 |     atomic_model : torch.Tensor
23 |         The coarse-grained atomic model of the protein.
24 |     """
23 |
24 | univ = mda.Universe(fname)
25 | univ.atoms.translate(-univ.atoms.center_of_mass())
26 |
27 | model = torch.from_numpy(univ.select_atoms(atom_selection).positions.T)
28 |
29 | return model
30 |
31 |
32 | def pdb_parser(file_formatter, n_pdbs, output_file, start_index=1, **kwargs):
33 |     """
34 |     Parses multiple pdb files and saves the stacked coarse-grained models as a
35 |     torch tensor of shape (n_pdbs, 3, N), where N is the number of selected
36 |     atoms per model. Additional keyword arguments are passed to pdb_parser_.
37 |
38 |     Parameters
39 |     ----------
40 |     file_formatter : str
41 |         Path template for the pdb files. It must contain the placeholder {}
42 |         for the pdb index, e.g. "data/pdb/{}.pdb".
43 |     n_pdbs : int
44 |         The number of pdb files to parse.
45 |     output_file : str
46 |         The path to the output file. Must be a .pt file.
47 |     start_index : int
48 |         Index of the first pdb file. Defaults to 1.
49 |     """
47 |
48 | models = pdb_parser_(file_formatter.format(start_index), **kwargs)
49 | models = torch.zeros((n_pdbs, *models.shape))
50 |
51 | for i in range(0, n_pdbs):
52 | models[i] = pdb_parser_(file_formatter.format(start_index + i), **kwargs)
53 |
54 | if output_file.endswith("pt"):
55 | torch.save(models, output_file)
56 |
57 | else:
58 | raise ValueError("Model file format not supported. Please use .pt.")
59 |
60 | return
61 |
62 |
63 | def traj_parser_(top_file: str, traj_file: str) -> torch.Tensor:
64 |     """
65 |     Parses a trajectory and returns a coarse-grained atomic model of the protein.
66 |     The atomic model is an Mx3xN tensor, where M is the number of frames in the
67 |     trajectory and N is the number of residues. Axis 1 holds the x, y, z
68 |     coordinates of the alpha carbons.
69 |
70 |     Parameters
71 |     ----------
72 |     top_file : str
73 |         The path to the topology file.
74 |     traj_file : str
75 |         The path to the trajectory file.
76 |
77 |     Returns
78 |     -------
79 |     atomic_models : torch.Tensor
80 |         The coarse-grained atomic models of the protein.
81 |     """
79 |
80 | ref = mda.Universe(top_file)
81 | ref.atoms.translate(-ref.atoms.center_of_mass())
82 |
83 | mobile = mda.Universe(top_file, traj_file)
84 | align.AlignTraj(mobile, ref, select="name CA", in_memory=True).run()
85 |
86 | atomic_models = torch.zeros(
87 | (mobile.trajectory.n_frames, 3, mobile.select_atoms("name CA").n_atoms)
88 | )
89 |
90 | for i in range(mobile.trajectory.n_frames):
91 | mobile.trajectory[i]
92 |
93 | atomic_models[i, 0:3, :] = torch.from_numpy(
94 | mobile.select_atoms("name CA").positions.T
95 | )
96 |
97 | return atomic_models
98 |
99 |
100 | def traj_parser(top_file: str, traj_file: str, output_file: str) -> None:
101 |     """
102 |     Parses a trajectory and saves the coarse-grained models as a torch tensor
103 |     of shape (M, 3, N), where M is the number of frames and N is the number of
104 |     residues. Axis 1 holds the x, y, z coordinates of the alpha carbons.
105 |
106 |     Parameters
107 |     ----------
108 |     top_file : str
109 |         The path to the topology file.
110 |     traj_file : str
111 |         The path to the trajectory file.
112 |     output_file : str
113 |         The path to the output file. Must be a .pt file.
114 |
115 |     Returns
116 |     -------
117 |     None
118 |     """
119 |
120 | atomic_models = traj_parser_(top_file, traj_file)
121 |
122 | if output_file.endswith("pt"):
123 | torch.save(atomic_models, output_file)
124 |
125 | else:
126 | raise ValueError("Model file format not supported. Please use .pt.")
127 |
128 | return
129 |
130 |
131 | def models_to_tensor(
132 | model_files,
133 | output_file,
134 | n_pdbs: Union[int, None] = None,
135 | top_file: Union[str, None] = None,
136 | ):
137 | """
138 | Converts different model files to a torch tensor.
139 |
140 | Parameters
141 | ----------
142 | model_files : list
143 | A list of model files to convert to a torch tensor.
144 |
145 | output_file : str
146 | The path to the output file. Must be a .pt file.
147 |
148 |     n_pdbs : int
149 |         The number of models to convert to a torch tensor. Only needed for models in pdb files.
150 |
151 | top_file : str
152 |         The path to the topology file. Only needed for models in trr files.
153 |
154 | Returns
155 | -------
156 | None
157 | """
158 | assert output_file.endswith("pt"), "The output file must be a .pt file."
159 | if model_files.endswith("trr"):
160 | assert top_file is not None, "Please provide a topology file."
161 | assert n_pdbs is None, "The number of pdb files is not needed for trr files."
162 | traj_parser(top_file, model_files, output_file)
163 | elif model_files.endswith("pdb"):
164 | assert n_pdbs is not None, "Please provide the number of pdb files."
165 | assert top_file is None, "The topology file is not needed for pdb files."
166 |         pdb_parser(model_files, n_pdbs, output_file)
167 |     else:
168 |         raise ValueError("Model file format not supported. Please use .pdb or .trr.")
169 |
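170 | # Usage sketch (hypothetical paths; file_formatter must contain the {} placeholder):
171 | #
172 | #     models_to_tensor("data/pdb/{}.pdb", "models.pt", n_pdbs=10)      # pdb ensemble
173 | #     models_to_tensor("traj.trr", "models.pt", top_file="top.pdb")    # trajectory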
--------------------------------------------------------------------------------
/src/cryo_sbi/utils/image_utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import List, Union
3 | import numpy as np
4 | import torch
5 | import torchvision.transforms as transforms
6 | import torch.distributions as d
7 | import mrcfile
8 | from tqdm import tqdm
9 |
10 |
11 | def circular_mask(n_pixels: int, radius: int, inside: bool = True) -> torch.Tensor:
12 | """
13 | Create a circular mask for a given image size and radius.
14 |
15 | Args:
16 | n_pixels (int): Side length of the image in pixels.
17 | radius (int): Radius of the circle.
18 | inside (bool, optional): If True, the mask will be True inside the circle. Defaults to True.
19 |
20 | Returns:
21 | mask (torch.Tensor): Mask of shape (n_pixels, n_pixels).
22 | """
23 |
24 | grid = torch.linspace(-0.5 * (n_pixels - 1), 0.5 * (n_pixels - 1), n_pixels)
25 | r_2d = grid[None, :] ** 2 + grid[:, None] ** 2
26 |
27 | if inside is True:
28 | mask = r_2d < radius**2
29 | else:
30 | mask = r_2d > radius**2
31 |
32 | return mask
33 |
34 |
35 | class Mask:
36 | """
37 | Mask a circular region in an image.
38 |
39 | Args:
40 | image_size (int): Number of pixels in the image.
41 | radius (int): Radius of the circle.
42 |         inside (bool, optional): If True, the mask will be True inside the circle. Defaults to False.
43 | """
44 |
45 | def __init__(self, image_size: int, radius: int, inside: bool = False) -> None:
46 | self.image_size = image_size
47 |         self.radius = radius
48 | self.mask = circular_mask(image_size, radius, inside=inside)
49 |
50 | def __call__(self, image: torch.Tensor) -> torch.Tensor:
51 | """Mask a circular region in an image.
52 |
53 | Args:
54 | image (torch.Tensor): Image of shape (n_pixels, n_pixels) or (n_channels, n_pixels, n_pixels).
55 |
56 | Returns:
57 | image (torch.Tensor): Image with masked region equal to zero.
58 | """
59 |
60 | if len(image.shape) == 2:
61 | image[self.mask] = 0
62 | elif len(image.shape) == 3:
63 | image[:, self.mask] = 0
64 | else:
65 | raise NotImplementedError
66 |
67 | return image
68 |
69 |
70 | def fourier_down_sample(
71 | image: torch.Tensor, image_size: int, n_pixels: int
72 | ) -> torch.Tensor:
73 | """
74 | Downsample an image by removing the outer frequencies.
75 |
76 | Args:
77 | image (torch.Tensor): Image of shape (n_pixels, n_pixels) or (n_channels, n_pixels, n_pixels).
78 | image_size (int): Side length of the image in pixels.
79 | n_pixels (int): Number of pixels to remove from each side.
80 |
81 | Returns:
82 | reconstructed (torch.Tensor): Downsampled image.
83 | """
84 |
85 | fft_image = torch.fft.fft2(image)
86 | fft_image = torch.fft.fftshift(fft_image)
87 |
88 | if len(image.shape) == 2:
89 | fft_image = fft_image[
90 | n_pixels : image_size - n_pixels,
91 | n_pixels : image_size - n_pixels,
92 | ]
93 | elif len(image.shape) == 3:
94 | fft_image = fft_image[
95 | :,
96 | n_pixels : image_size - n_pixels,
97 | n_pixels : image_size - n_pixels,
98 | ]
99 | else:
100 | raise NotImplementedError
101 |
102 | fft_image = torch.fft.fftshift(fft_image)
103 | reconstructed = torch.fft.ifft2(fft_image).real
104 | return reconstructed
105 |
106 |
107 | class FourierDownSample:
108 | """
109 | Downsample an image by removing the outer frequencies.
110 |
111 | Args:
112 | image_size (int): Size of image in pixels.
113 | down_sampled_size (int): Size of downsampled image in pixels.
114 | """
115 |
116 | def __init__(self, image_size: int, down_sampled_size: int) -> None:
117 | self._image_size = image_size
118 | self._n_pixels = (image_size - down_sampled_size) // 2
119 |
120 | def __call__(self, image: torch.Tensor) -> torch.Tensor:
121 | """Downsample an image by removing the outer frequencies.
122 |
123 | Args:
124 | image (torch.Tensor): Image of shape (n_pixels, n_pixels) or (n_channels, n_pixels, n_pixels).
125 |
126 | Returns:
127 | down_sampled (torch.Tensor): Downsampled image.
128 | """
129 |
130 | down_sampled = fourier_down_sample(
131 | image, image_size=self._image_size, n_pixels=self._n_pixels
132 | )
133 |
134 | return down_sampled
135 |
136 |
137 | class LowPassFilter:
138 | """
139 | Low pass filter an image by removing the outer frequencies.
140 |
141 | Args:
142 | image_size (int): Side length of the image in pixels.
143 | frequency_cutoff (int): Frequency cutoff.
144 | """
145 |
146 | def __init__(self, image_size: int, frequency_cutoff: int):
147 | self.mask = circular_mask(image_size, frequency_cutoff, inside=False)
148 |
149 | def __call__(self, image: torch.Tensor) -> torch.Tensor:
150 | """
151 | Low pass filter an image by removing the outer frequencies.
152 |
153 | Args:
154 | image (torch.Tensor): Image of shape (n_pixels, n_pixels) or (n_channels, n_pixels, n_pixels).
155 |
156 | Returns:
157 | reconstructed (torch.Tensor): Low pass filtered image.
158 | """
159 | fft_image = torch.fft.fft2(image)
160 | fft_image = torch.fft.fftshift(fft_image)
161 |
162 | if len(image.shape) == 2:
163 | fft_image[self.mask] = 0 + 0j
164 | elif len(image.shape) == 3:
165 | fft_image[:, self.mask] = 0 + 0j
166 | else:
167 | raise NotImplementedError
168 |
169 | fft_image = torch.fft.fftshift(fft_image)
170 | reconstructed = torch.fft.ifft2(fft_image).real
171 | return reconstructed
172 |
173 |
174 | class GaussianLowPassFilter:
175 | """
176 |     Low pass filter by dampening the outer frequencies with a Gaussian.
177 |
178 |     Args:
179 |         image_size (int): Side length of the image in pixels.
180 |         sigma (int): Standard deviation of the Gaussian mask.
177 | """
178 |
179 | def __init__(self, image_size: int, sigma: int):
180 | self._image_size = image_size
181 | self._sigma = sigma
182 | self._grid = torch.linspace(
183 | -0.5 * (image_size - 1), 0.5 * (image_size - 1), image_size
184 | )
185 | self._r_2d = self._grid[None, :] ** 2 + self._grid[:, None] ** 2
186 | self._mask = torch.exp(-self._r_2d / (2 * sigma**2))
187 |
188 | def __call__(self, image: torch.Tensor) -> torch.Tensor:
189 | """
190 | Low pass filter an image by dampening the outer frequencies with a Gaussian.
191 |
192 | Args:
193 | image (torch.Tensor): Image of shape (n_pixels, n_pixels) or (n_channels, n_pixels, n_pixels).
194 |
195 | Returns:
196 | reconstructed (torch.Tensor): Low pass filtered image.
197 | """
198 |
199 | fft_image = torch.fft.fft2(image)
200 | fft_image = torch.fft.fftshift(fft_image)
201 |
202 | if len(image.shape) == 2:
203 | fft_image = fft_image * self._mask
204 | elif len(image.shape) == 3:
205 | fft_image = fft_image * self._mask.unsqueeze(0)
206 | else:
207 | raise NotImplementedError
208 |
209 | fft_image = torch.fft.fftshift(fft_image)
210 | reconstructed = torch.fft.ifft2(fft_image).real
211 | return reconstructed
212 |
213 |
214 | class NormalizeIndividual:
215 | """
216 | Normalize an image by subtracting the mean and dividing by the standard deviation.
217 | """
218 |
219 | def __init__(self) -> None:
220 | pass
221 |
222 | def __call__(self, images: torch.Tensor) -> torch.Tensor:
223 | """
224 | Normalize an image by subtracting the mean and dividing by the standard deviation.
225 |
226 | Args:
227 | images (torch.Tensor): Image of shape (n_channels, n_pixels, n_pixels).
228 |
229 | Returns:
230 | normalized (torch.Tensor): Normalized image.
231 | """
232 | if len(images.shape) == 2:
233 | mean = images.mean()
234 | std = images.std()
235 | images = images.unsqueeze(0)
236 | elif len(images.shape) == 3:
237 | mean = images.mean(dim=[1, 2])
238 | std = images.std(dim=[1, 2])
239 | else:
240 | raise NotImplementedError
241 |
242 | return transforms.functional.normalize(images, mean=mean, std=std)
243 |
244 |
245 | def mrc_to_tensor(image_path: str) -> torch.Tensor:
246 | """
247 | Convert an MRC file to a tensor.
248 |
249 | Args:
250 | image_path (str): Path to the MRC file.
251 |
252 | Returns:
253 | image (torch.Tensor): Image of shape (n_pixels, n_pixels).
254 | """
255 |
256 | assert isinstance(image_path, str), "image path needs to be a string"
257 | with mrcfile.open(image_path) as mrc:
258 | image = mrc.data
259 | return torch.from_numpy(image)
260 |
261 |
262 | class MRCtoTensor:
263 | """
264 | Convert an MRC file to a tensor.
265 | """
266 |
267 | def __init__(self) -> None:
268 | pass
269 |
270 | def __call__(self, image_path: str) -> torch.Tensor:
271 | """
272 | Convert an MRC file to a tensor.
273 |
274 | Args:
275 | image_path (str): Path to the MRC file.
276 |
277 | Returns:
278 | image (torch.Tensor): Image of shape (n_pixels, n_pixels).
279 | """
280 |
281 | return mrc_to_tensor(image_path)
282 |
283 |
284 | def estimate_noise_psd(
285 |     images: torch.Tensor, image_size: int, mask_radius: Union[int, None] = None
286 | ) -> torch.Tensor:
287 |     """
288 |     Estimates the power spectral density (PSD) of the noise in a set of images.
289 |     The central signal region is masked out before the estimate.
290 |
291 |     Args:
292 |         images (torch.Tensor): A tensor of input images with shape (N, H, W),
293 |             where N is the number of images, H the height, and W the width.
294 |         image_size (int): Side length of the images in pixels.
295 |         mask_radius (int, optional): Radius of the mask covering the signal.
296 |             Defaults to None, in which case image_size // 2 is used.
297 |
298 |     Returns:
299 |         torch.Tensor: The estimated noise PSD with shape (H, W).
300 |     """
297 | if mask_radius is None:
298 | mask_radius = image_size // 2
299 | mask = circular_mask(image_size, mask_radius, inside=False)
300 | denominator = mask.sum() * images.shape[0]
301 | images_masked = images * mask
302 | mean_est = images_masked.sum() / denominator
303 | image_masked_fft = torch.fft.fft2(images_masked)
304 | noise_psd_est = torch.sum(torch.abs(image_masked_fft)**2, dim=[0]) / denominator
305 | noise_psd_est[image_size // 2, image_size // 2] -= mean_est
306 |
307 | return noise_psd_est
308 |
309 |
310 | class WhitenImage:
311 | """
312 | Whiten an image by dividing by the square root of the noise PSD.
313 |
314 | Args:
315 | image_size (int): Size of image in pixels.
316 | mask_radius (int, optional): Radius of the mask. Defaults to None.
317 |
318 | """
319 |
320 | def __init__(self, image_size: int, mask_radius: Union[int, None] = None) -> None:
321 | self.image_size = image_size
322 | self.mask_radius = mask_radius
323 |
324 | def _estimate_noise_psd(self, images: torch.Tensor) -> torch.Tensor:
325 | """
326 | Estimates the power spectral density (PSD) of the noise in a set of images.
327 | """
328 | noise_psd = estimate_noise_psd(images, self.image_size, self.mask_radius)
329 | return noise_psd
330 |
331 | def __call__(self, images: torch.Tensor) -> torch.Tensor:
332 | """
333 | Whiten an image by dividing by the square root of the noise PSD.
334 |
335 |         Args:
336 |             images (torch.Tensor): Images of shape (num_images, n_pixels, n_pixels).
337 |
338 |         Returns:
339 |             images (torch.Tensor): Whitened images.
340 | """
341 |
342 | assert images.ndim == 3, "Image should have shape (num_images , n_pixels, n_pixels)"
343 | noise_psd = self._estimate_noise_psd(images) ** -0.5
344 | images_fft = torch.fft.fft2(images)
345 | images_fft = images_fft * noise_psd
346 | images = torch.fft.ifft2(images_fft).real
347 | return images
348 |
349 |
350 | class MRCdataset:
351 | """
352 | Creates a dataset of MRC files.
353 | Each MRC file is converted to a tensor and has a unique index.
354 |
355 | Args:
356 | image_paths (list[str]): List of paths to MRC files.
357 |
358 | Methods:
359 | build_index_map: Builds a map of indices to file paths and file indices.
360 |         get_image: Returns the image at the given global index.
361 |         __getitem__: Returns the tensor of the MRC file at the given index.
362 | """
363 |
364 | def __init__(self, image_paths: List[str]):
365 | super().__init__()
366 | self.paths = image_paths
367 | self._num_paths = len(image_paths)
368 | self._index_map = None
369 |
370 | def __len__(self):
371 | return self._num_paths
372 |
373 | def __getitem__(self, idx):
374 | return idx, mrc_to_tensor(self.paths[idx])
375 |
376 | def _extract_num_particles(self, path):
377 | future_mrc = mrcfile.open_async(path)
378 | mrc = future_mrc.result()
379 | data_shape = mrc.data.shape
380 | # img_stack = mrc.is_image_stack()
381 | num_images = data_shape[0] if len(data_shape) > 2 else 1
382 | return num_images
383 |
384 | def build_index_map(self):
385 | """
386 | Builds a map of image indices to file paths and file indices.
387 | """
388 | if self._index_map is not None:
389 | print("Index map already built.")
390 | return
391 |
392 | self._path_index = []
393 | self._file_index = []
394 | print("Initalizing indexing...")
395 | for idx, path in tqdm(enumerate(self.paths), total=self._num_paths):
396 | num_images = self._extract_num_particles(path)
397 | self._path_index += [idx] * num_images
398 | self._file_index += list(range(num_images))
399 | self._index_map = True
400 |
401 | def save_index_map(self, path: str):
402 | """
403 | Saves the index map to a file.
404 |
405 | Args:
406 | path (str): Path to save the index map.
407 | """
408 | assert (
409 | self._index_map is not None
410 | ), "Index map not built. First call build_index_map()"
411 | np.savez(
412 | path,
413 | path_index=self._path_index,
414 | file_index=self._file_index,
415 | paths=self.paths,
416 | )
417 |
418 | def load_index_map(self, path: str):
419 | """
420 | Loads the index map from a file.
421 |
422 | Args:
423 | path (str): Path to load the index map.
424 | """
425 | index_map = np.load(path)
426 |         assert len(self.paths) == len(index_map["paths"]), "Number of paths does not match the index map."
427 | for path1, path2 in zip(self.paths, index_map["paths"]):
428 | assert path1 == path2, "Paths do not match the index map."
429 | self._path_index = index_map["path_index"]
430 | self._file_index = index_map["file_index"]
431 | self._index_map = True
432 |
433 | def get_image(self, idx: Union[int, list]):
434 | """
435 | Returns the image at the given global index.
436 |
437 | Args:
438 | idx (int, List): Global index of the image.
439 | """
440 | assert (
441 | self._index_map is not None
442 | ), "Index map not built. First call build_index_map() or load_index_map()"
443 |         if isinstance(idx, int):
444 |             image = mrc_to_tensor(self.paths[self._path_index[idx]])
445 |             if image.ndim > 2:
446 |                 return image[self._file_index[idx]]
447 |             return image  # single-image file, already 2D
447 | if isinstance(idx, (list, np.ndarray, torch.Tensor)):
448 | return [
449 | mrc_to_tensor(self.paths[self._path_index[i]])[self._file_index[i]]
450 | for i in idx
451 | ]
452 |
453 |
454 | class MRCloader(torch.utils.data.DataLoader):
455 | """
456 | Creates a dataloader of MRC files.
457 |
458 | Args:
459 | image_paths (list[str]): List of paths to MRC files.
460 | **kwargs: Keyword arguments passed to torch.utils.data.DataLoader.
461 | """
462 |
463 | def __init__(self, image_paths: List[str], **kwargs):
464 | super().__init__(MRCdataset(image_paths), batch_size=None, **kwargs)
465 |
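466 | # Usage sketch (hypothetical .mrc paths): the transforms above compose with
467 | # torchvision, and MRCdataset gives indexed access to particles across files.
468 | #
469 | #     transform = transforms.Compose(
470 | #         [LowPassFilter(image_size=128, frequency_cutoff=30), NormalizeIndividual()]
471 | #     )
472 | #     dataset = MRCdataset(["stack_a.mrc", "stack_b.mrc"])
473 | #     dataset.build_index_map()
474 | #     image = dataset.get_image(0)  # first particle image across all files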
--------------------------------------------------------------------------------
/src/cryo_sbi/utils/micrograph_utils.py:
--------------------------------------------------------------------------------
1 | import random
2 | from typing import Optional, Union, List
3 | from cryo_sbi.utils.image_utils import mrc_to_tensor
4 | import torch
5 | import torchvision.transforms as transforms
6 | import torchvision.transforms.functional as TF
7 |
8 |
9 | class RandomMicrographPatches:
10 | """
11 | Iterator that returns random patches from a list of micrographs.
12 |
13 | Args:
14 | micro_graphs (List[Union[str, torch.Tensor]]): List of micrographs.
15 | transform (Union[None, transforms.Compose]): Transform to apply to the patches.
16 | patch_size (int): Size of the patches.
17 | batch_size (int, optional): Batch size. Defaults to 1.
18 | """
19 |
20 | def __init__(
21 | self,
22 | micro_graphs: List[Union[str, torch.Tensor]],
23 | transform: Union[None, transforms.Compose],
24 | patch_size: int,
25 | max_iter: Optional[int] = 1000,
26 | ) -> None:
27 |         if all(isinstance(m, str) for m in micro_graphs):
28 | self._micro_graphs = [mrc_to_tensor(path) for path in micro_graphs]
29 | else:
30 | self._micro_graphs = micro_graphs
31 |
32 | self._transform = transform
33 | self._patch_size = patch_size
34 | self._max_iter = max_iter
35 | self._current_iter = 0
36 |
37 | def __iter__(self) -> "RandomMicrographPatches":
38 | return self
39 |
40 | def __next__(self) -> torch.Tensor:
41 | if self._current_iter == self._max_iter:
42 | self._current_iter = 0
43 | raise StopIteration
44 | random_micrograph = random.choice(self._micro_graphs)
45 | assert random_micrograph.ndim == 2, "Micrograph should be 2D"
46 |         top = random.randint(0, random_micrograph.shape[0] - self._patch_size)
47 |         left = random.randint(0, random_micrograph.shape[1] - self._patch_size)
48 |         patch = TF.crop(
49 |             random_micrograph,
50 |             top=top,
51 |             left=left,
52 |             height=self._patch_size,
53 |             width=self._patch_size,
54 |         )
55 | if self._transform is not None:
56 | patch = self._transform(patch)
57 | else:
58 | patch = patch.unsqueeze(0)
59 | self._current_iter += 1
60 | return patch
61 |
62 | def __len__(self) -> int:
63 | return self._max_iter
64 |
65 | @property
66 | def shape(self) -> torch.Size:
67 | """
68 | Shape of the transformed patches.
69 |
70 | Returns:
71 | torch.Size: Shape of the transformed patches.
72 | """
73 | return self.__next__().shape
74 |
75 |
76 | def compute_average_psd(
77 | images: Union[torch.Tensor, RandomMicrographPatches],
78 | device: str = "cpu",
79 | ) -> torch.Tensor:
80 | """
81 | Compute the average PSD of a set of images.
82 |
83 | Args:
84 | images (Union[torch.Tensor, RandomMicrographPatches]): Images to compute the average PSD of.
85 | device (str, optional): Device to compute the PSD on. Defaults to "cpu".
86 |
87 | Returns:
88 | avg_psd (torch.Tensor): Average PSD of the images.
89 | """
90 |
91 | if isinstance(images, RandomMicrographPatches):
92 | avg_psd = torch.zeros(images.shape[1:], device=device)
93 | for image in images: # TODO add progress bar with tqdm
94 | fft_image = torch.fft.fft2(image[0].to(device, non_blocking=True))
95 | psd = torch.abs(fft_image) ** 2
96 | avg_psd += psd / len(images)
97 | # add convergence check
98 |     elif isinstance(images, torch.Tensor):
99 |         fft_images = torch.fft.fft2(images.to(device=device), dim=(-2, -1))
100 |         avg_psd = torch.mean(torch.abs(fft_images) ** 2, dim=0)
101 |     else:
102 |         raise TypeError("images must be a torch.Tensor or RandomMicrographPatches")
103 |     return avg_psd.cpu()
102 |
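104 | # Usage sketch (synthetic micrograph as a stand-in for a real one):
105 | #
106 | #     patches = RandomMicrographPatches(
107 | #         micro_graphs=[torch.randn(512, 512)],
108 | #         transform=None,
109 | #         patch_size=64,
110 | #         max_iter=100,
111 | #     )
112 | #     avg_psd = compute_average_psd(patches)  # shape (64, 64)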
--------------------------------------------------------------------------------
/src/cryo_sbi/utils/visualize_models.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | import torch
4 |
5 |
6 | def _scatter_plot_models(
7 | model: torch.Tensor, view_angles: tuple = (30, 45), **plot_kwargs: dict
8 | ) -> None:
9 | fig = plt.figure()
10 | ax = fig.add_subplot(111, projection="3d")
11 | ax.view_init(*view_angles)
12 |
13 | ax.scatter(*model, **plot_kwargs)
14 |
15 | ax.set_xlabel("X")
16 | ax.set_ylabel("Y")
17 | ax.set_zlabel("Z")
18 |
19 |
20 | def _sphere_plot_models(
21 | model: torch.Tensor,
22 | radius: float = 4,
23 | view_angles: tuple = (30, 45),
24 | **plot_kwargs: dict,
25 | ) -> None:
26 | fig = plt.figure()
27 | ax = fig.add_subplot(111, projection="3d")
28 |     ax.view_init(*view_angles)
29 |
30 | spheres = []
31 | for x, y, z in zip(model[0], model[1], model[2]):
32 | spheres.append((x.item(), y.item(), z.item(), radius))
33 |
34 |     for sphere in spheres:
35 | x, y, z, r = sphere
36 |
37 | u = np.linspace(0, 2 * np.pi, 100)
38 | v = np.linspace(0, np.pi, 100)
39 | x = r * np.outer(np.cos(u), np.sin(v)) + x
40 | y = r * np.outer(np.sin(u), np.sin(v)) + y
41 | z = r * np.outer(np.ones(np.size(u)), np.cos(v)) + z
42 |
43 | ax.plot_surface(x, y, z, **plot_kwargs)
44 |
45 | ax.set_xlabel("X")
46 | ax.set_ylabel("Y")
47 | ax.set_zlabel("Z")
48 |
49 |
50 | def plot_model(model: torch.Tensor, method: str = "scatter", **kwargs) -> None:
51 | """
52 | Plot a model from the tensor.
53 |
54 | Args:
55 | model (torch.Tensor): Model to plot, should be a 2D tensor with shape (3, num_atoms)
56 | method (str, optional): Method to use for plotting. Defaults to "scatter". Can be "scatter" or "sphere".
57 | "scatter" is fast and simple, "sphere" is a proper 3D representation (Take long to render).
58 | **kwargs: Additional keyword arguments to pass to the plotting function.
59 |
60 | Returns:
61 | None
62 |
63 | Raises:
64 | AssertionError: If the model is not a 2D tensor with shape (3, num_atoms).
65 | ValueError: If the method is not "scatter" or "sphere".
66 |
67 | """
68 |
69 | assert model.ndim == 2, "Model should be 2D tensor"
70 | assert model.shape[0] == 3, "Model should have 3 rows"
71 |
72 | if method == "scatter":
73 | _scatter_plot_models(model, **kwargs)
74 |
75 | elif method == "sphere":
76 | _sphere_plot_models(model, **kwargs)
77 |
78 | else:
79 | raise ValueError(f"Unknown method {method}. Use 'scatter' or 'sphere'.")
80 |
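81 | # Usage sketch (random coordinates as a stand-in for a real coarse-grained model):
82 | #
83 | #     model = 20 * torch.randn(3, 100)  # (3, num_atoms)
84 | #     plot_model(model, method="scatter", s=5)
85 | #     plt.show()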
--------------------------------------------------------------------------------
/src/cryo_sbi/wpa_simulator/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/src/cryo_sbi/wpa_simulator/cryo_em_simulator.py:
--------------------------------------------------------------------------------
1 | from typing import Union, Callable
2 | import json
3 | import numpy as np
4 | import torch
5 |
6 | from cryo_sbi.wpa_simulator.ctf import apply_ctf
7 | from cryo_sbi.wpa_simulator.image_generation import project_density
8 | from cryo_sbi.wpa_simulator.noise import add_noise
9 | from cryo_sbi.wpa_simulator.normalization import gaussian_normalize_image
10 | from cryo_sbi.inference.priors import get_image_priors
11 | from cryo_sbi.wpa_simulator.validate_image_config import check_image_params
12 |
13 |
14 | def cryo_em_simulator(
15 | models,
16 | index,
17 | quaternion,
18 | sigma,
19 | shift,
20 | defocus,
21 | b_factor,
22 | amp,
23 | snr,
24 | num_pixels,
25 | pixel_size,
26 | ):
27 | """
28 |     Simulates a batch of cryo-electron microscopy (cryo-EM) images for a set of given coarse-grained models.
29 |
30 |     Args:
31 |         models (torch.Tensor): A tensor of coarse-grained models (num_models, 3, num_beads).
32 | index (torch.Tensor): A tensor of indices to select the models to simulate.
33 | quaternion (torch.Tensor): A tensor of quaternions to rotate the models.
34 | sigma (float): The standard deviation of the Gaussian kernel used to project the density.
35 | shift (torch.Tensor): A tensor of shifts to apply to the models.
36 | defocus (float): The defocus value of the contrast transfer function (CTF).
37 | b_factor (float): The B-factor of the CTF.
38 | amp (float): The amplitude contrast of the CTF.
39 | snr (float): The signal-to-noise ratio of the simulated image.
40 | num_pixels (int): The number of pixels in the simulated image.
41 | pixel_size (float): The size of each pixel in the simulated image.
42 |
43 | Returns:
44 |         torch.Tensor: A tensor of the simulated cryo-EM images.
45 | """
46 | models_selected = models[index.round().long().flatten()]
47 | image = project_density(
48 | models_selected,
49 | quaternion,
50 | sigma,
51 | shift,
52 | num_pixels,
53 | pixel_size,
54 | )
55 | image = apply_ctf(image, defocus, b_factor, amp, pixel_size)
56 | image = add_noise(image, snr)
57 | image = gaussian_normalize_image(image)
58 | return image
59 |
60 |
61 | class CryoEmSimulator:
62 | def __init__(self, config_fname: str, device: str = "cpu"):
63 | self._device = device
64 | self._load_params(config_fname)
65 | self._load_models()
66 | self._priors = get_image_priors(self.max_index, self._config, device=device)
67 | self._num_pixels = torch.tensor(
68 | self._config["N_PIXELS"], dtype=torch.float32, device=device
69 | )
70 | self._pixel_size = torch.tensor(
71 | self._config["PIXEL_SIZE"], dtype=torch.float32, device=device
72 | )
73 |
74 | def _load_params(self, config_fname: str) -> None:
75 | """
76 | Loads the parameters from the config file into a dictionary.
77 |
78 | Args:
79 | config_fname (str): Path to the configuration file.
80 |
81 | Returns:
82 | None
83 | """
84 |
85 |         with open(config_fname) as config_file:
86 |             config = json.load(config_file)
86 | check_image_params(config)
87 | self._config = config
88 |
89 | def _load_models(self) -> None:
90 | """
91 | Loads the models from the model file specified in the config file.
92 |
93 | Returns:
94 | None
95 |
96 | """
97 | if self._config["MODEL_FILE"].endswith("npy"):
98 | models = (
99 | torch.from_numpy(
100 | np.load(self._config["MODEL_FILE"]),
101 | )
102 | .to(self._device)
103 | .to(torch.float32)
104 | )
105 | elif self._config["MODEL_FILE"].endswith("pt"):
106 | models = (
107 | torch.load(self._config["MODEL_FILE"])
108 | .to(self._device)
109 | .to(torch.float32)
110 | )
111 |
112 | else:
113 | raise NotImplementedError(
114 | "Model file format not supported. Please use .npy or .pt."
115 | )
116 |
117 | self._models = models
118 |
119 | assert self._models.ndim == 3, "Models are not of shape (models, 3, atoms)."
120 | assert self._models.shape[1] == 3, "Models are not of shape (models, 3, atoms)."
121 |
122 | @property
123 | def max_index(self) -> int:
124 | """
125 | Returns the maximum index of the model file.
126 |
127 | Returns:
128 | int: Maximum index of the model file.
129 | """
130 | return len(self._models) - 1
131 |
132 | def simulate(self, num_sim, indices=None, return_parameters=False, batch_size=None):
133 | """
134 | Simulate cryo-EM images using the specified models and prior distributions.
135 |
136 | Args:
137 | num_sim (int): The number of images to simulate.
138 |             indices (torch.Tensor, optional): The model indices to use for simulation. If None, indices are sampled from the prior.
139 | return_parameters (bool, optional): Whether to return the sampled parameters used for simulation.
140 | batch_size (int, optional): The batch size to use for simulation. If None, all images are simulated in a single batch.
141 |
142 | Returns:
143 | torch.Tensor or tuple: The simulated images as a tensor of shape (num_sim, num_pixels, num_pixels),
144 | and optionally the sampled parameters as a tuple of tensors.
145 | """
146 |
147 | parameters = self._priors.sample((num_sim,))
148 |         if indices is None:
149 |             indices = parameters[0]
150 |         else:
151 |             assert isinstance(indices, torch.Tensor), "Indices must be a torch.Tensor."
152 |             assert indices.dtype == torch.float32, "Indices must have dtype torch.float32."
153 |             assert indices.ndim == 2, "Indices must be a 2D tensor of shape (batch_size, 1)."
154 |         parameters[0] = indices
160 |
161 | images = []
162 | if batch_size is None:
163 | batch_size = num_sim
164 | for i in range(0, num_sim, batch_size):
165 | batch_indices = indices[i : i + batch_size]
166 | batch_parameters = [param[i : i + batch_size] for param in parameters[1:]]
167 | batch_images = cryo_em_simulator(
168 | self._models,
169 | batch_indices,
170 | *batch_parameters,
171 | self._num_pixels,
172 | self._pixel_size,
173 | )
174 | images.append(batch_images.cpu())
175 |
176 | images = torch.cat(images, dim=0)
177 |
178 | if return_parameters:
179 | return images.cpu(), parameters
180 | else:
181 | return images.cpu()
182 |
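183 | # Usage sketch (assumes a valid simulation config such as
184 | # tutorials/simulation_parameters.json; the path is illustrative):
185 | #
186 | #     simulator = CryoEmSimulator("tutorials/simulation_parameters.json")
187 | #     images, parameters = simulator.simulate(
188 | #         num_sim=100, return_parameters=True, batch_size=50
189 | #     )
190 | #     images.shape  # torch.Size([100, N_PIXELS, N_PIXELS])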
--------------------------------------------------------------------------------
/src/cryo_sbi/wpa_simulator/ctf.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | def apply_ctf(image: torch.Tensor, defocus, b_factor, amp, pixel_size) -> torch.Tensor:
6 | """
7 | Applies the CTF to the image.
8 |
9 | Args:
10 | image (torch.Tensor): The image to apply the CTF to.
11 | defocus (torch.Tensor): The defocus value.
12 | b_factor (torch.Tensor): The B-factor value.
13 | amp (torch.Tensor): The amplitude value.
14 | pixel_size (torch.Tensor): The pixel size value.
15 |
16 | Returns:
17 | torch.Tensor: The image with the CTF applied.
18 | """
19 |
20 | num_batch, num_pixels, _ = image.shape
21 | freq_pix_1d = torch.fft.fftfreq(num_pixels, d=pixel_size, device=image.device)
22 | x, y = torch.meshgrid(freq_pix_1d, freq_pix_1d, indexing="ij")
23 |
24 | freq2_2d = x**2 + y**2
25 | freq2_2d = freq2_2d.expand(num_batch, -1, -1)
26 | imag = torch.zeros_like(freq2_2d, device=image.device) * 1j
27 |
28 | env = torch.exp(-b_factor * freq2_2d * 0.5)
29 | phase = defocus * torch.pi * 2.0 * 10000 * 0.019866 # hardcoded 0.019866 for 300kV
30 |
31 | ctf = (
32 | -amp * torch.cos(phase * freq2_2d * 0.5)
33 | - torch.sqrt(1 - amp**2) * torch.sin(phase * freq2_2d * 0.5)
34 | + imag
35 | )
36 | ctf = ctf * env / amp
37 |
38 | conv_image_ctf = torch.fft.fft2(image) * ctf
39 | image_ctf = torch.fft.ifft2(conv_image_ctf).real
40 |
41 | return image_ctf
42 |
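43 | # Usage sketch (parameter values mirror tests/test_simulator.py; all CTF
44 | # parameters are passed as tensors):
45 | #
46 | #     image = torch.randn(1, 64, 64)
47 | #     image_ctf = apply_ctf(
48 | #         image,
49 | #         defocus=torch.tensor([1.0]),
50 | #         b_factor=torch.tensor([100.0]),
51 | #         amp=torch.tensor([0.5]),
52 | #         pixel_size=torch.tensor(1.0),
53 | #     )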
--------------------------------------------------------------------------------
/src/cryo_sbi/wpa_simulator/image_generation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | def gen_quat() -> torch.Tensor:
6 | """
7 | Generate a random quaternion.
8 |
9 | Returns:
10 |         quat (torch.Tensor): Random unit quaternion.
11 |
12 | """
13 | count = 0
14 | while count < 1:
15 | quat = 2 * torch.rand(size=(4,)) - 1
16 | norm = torch.sqrt(torch.sum(quat**2))
17 | if 0.2 <= norm <= 1.0:
18 | quat /= norm
19 | count += 1
20 |
21 | return quat
22 |
23 |
24 | def gen_rot_matrix(quats: torch.Tensor) -> torch.Tensor:
25 |     """
26 |     Generate rotation matrices from a batch of quaternions.
27 |     Quaternions follow the (qr, qx, qy, qz) convention, with the real part first.
28 |
29 |     Args:
30 |         quats (torch.Tensor): Quaternions of shape (n_batch, 4)
31 |
32 |     Returns:
33 |         rot_matrix (torch.Tensor): Rotation matrices of shape (n_batch, 3, 3)
34 |     """
35 |
36 | rot_matrix = torch.zeros((quats.shape[0], 3, 3), device=quats.device)
37 |
38 | rot_matrix[:, 0, 0] = 1 - 2 * (quats[:, 2] ** 2 + quats[:, 3] ** 2)
39 | rot_matrix[:, 0, 1] = 2 * (quats[:, 1] * quats[:, 2] - quats[:, 3] * quats[:, 0])
40 | rot_matrix[:, 0, 2] = 2 * (quats[:, 1] * quats[:, 3] + quats[:, 2] * quats[:, 0])
41 |
42 | rot_matrix[:, 1, 0] = 2 * (quats[:, 1] * quats[:, 2] + quats[:, 3] * quats[:, 0])
43 | rot_matrix[:, 1, 1] = 1 - 2 * (quats[:, 1] ** 2 + quats[:, 3] ** 2)
44 | rot_matrix[:, 1, 2] = 2 * (quats[:, 2] * quats[:, 3] - quats[:, 1] * quats[:, 0])
45 |
46 | rot_matrix[:, 2, 0] = 2 * (quats[:, 1] * quats[:, 3] - quats[:, 2] * quats[:, 0])
47 | rot_matrix[:, 2, 1] = 2 * (quats[:, 2] * quats[:, 3] + quats[:, 1] * quats[:, 0])
48 | rot_matrix[:, 2, 2] = 1 - 2 * (quats[:, 1] ** 2 + quats[:, 2] ** 2)
49 |
50 | return rot_matrix
51 |
52 |
53 | def project_density(
54 | coords: torch.Tensor,
55 | quats: torch.Tensor,
56 | sigma: torch.Tensor,
57 | shift: torch.Tensor,
58 | num_pixels: int,
59 | pixel_size: float,
60 | ) -> torch.Tensor:
61 | """
62 | Generate a 2D projections from a set of coordinates.
63 |
64 |     Args:
65 |         coords (torch.Tensor): Coordinates of the atoms, shape (n_batch, 3, n_atoms)
66 |         quats (torch.Tensor): Quaternions used to rotate the coordinates, shape (n_batch, 4)
67 |         sigma (torch.Tensor): Standard deviation of the Gaussian used to model the electron density
68 |         shift (torch.Tensor): In-plane shifts applied to the x, y coordinates, shape (n_batch, 2)
69 |         num_pixels (int): Number of pixels along one image side
70 |         pixel_size (float): Pixel size in Angstrom
71 |
72 |     Returns:
73 |         image (torch.Tensor): Images generated from the coordinates
72 | """
73 |
74 | num_batch, _, num_atoms = coords.shape
75 | norm = 1 / (2 * torch.pi * sigma**2 * num_atoms)
76 |
77 | grid_min = -pixel_size * num_pixels * 0.5
78 | grid_max = pixel_size * num_pixels * 0.5
79 |
80 | rot_matrix = gen_rot_matrix(quats)
81 | grid = torch.arange(grid_min, grid_max, pixel_size, device=coords.device)[
82 | 0 : num_pixels.long()
83 | ].repeat(
84 | num_batch, 1
85 | ) # [0: num_pixels.long()] is needed due to single precision error in some cases
86 |
87 | coords_rot = torch.bmm(rot_matrix, coords)
88 | coords_rot[:, :2, :] += shift.unsqueeze(-1)
89 |
90 | gauss_x = torch.exp_(
91 | -0.5 * (((grid.unsqueeze(-1) - coords_rot[:, 0, :].unsqueeze(1)) / sigma) ** 2)
92 | )
93 | gauss_y = torch.exp_(
94 | -0.5 * (((grid.unsqueeze(-1) - coords_rot[:, 1, :].unsqueeze(1)) / sigma) ** 2)
95 | ).transpose(1, 2)
96 |
97 | image = torch.bmm(gauss_x, gauss_y) * norm.reshape(-1, 1, 1)
98 |
99 | return image
100 |
101 |
102 | '''def project_density(
103 | atomic_model: torch.Tensor,
104 | quats: torch.Tensor,
105 | delta_sigma: torch.Tensor,
106 | shift: torch.Tensor,
107 | num_pixels: int,
108 | pixel_size: float,
109 | ) -> torch.Tensor:
110 | """
111 | Generate a 2D projections from a set of coordinates.
112 |
113 | Args:
114 | atomic_model (torch.Tensor): Coordinates of the atoms in the images
115 | res (float): resolution of the images in Angstrom
116 | num_pixels (int): Number of pixels along one image size.
117 | pixel_size (float): Pixel size in Angstrom
118 |
119 | Returns:
120 | image (torch.Tensor): Images generated from the coordinates
121 | """
122 |
123 | num_batch, _, num_atoms = atomic_model.shape
124 |
125 | variances = atomic_model[:, 4, :] * delta_sigma[:, 0]
126 | amplitudes = atomic_model[:, 3, :] / torch.sqrt((2 * torch.pi * variances))
127 |
128 | grid_min = -pixel_size * num_pixels * 0.5
129 | grid_max = pixel_size * num_pixels * 0.5
130 |
131 | rot_matrix = gen_rot_matrix(quats)
132 | grid = torch.arange(grid_min, grid_max, pixel_size, device=atomic_model.device)[
133 | 0 : num_pixels.long()
134 | ].repeat(
135 | num_batch, 1
136 | ) # [0: num_pixels.long()] is needed due to single precision error in some cases
137 |
138 | coords_rot = torch.bmm(rot_matrix, atomic_model[:, :3, :])
139 | coords_rot[:, :2, :] += shift.unsqueeze(-1)
140 |
141 | gauss_x = torch.exp_(
142 | -((grid.unsqueeze(-1) - coords_rot[:, 0, :].unsqueeze(1)) ** 2)
143 | / variances.unsqueeze(1)
144 | ) * amplitudes.unsqueeze(1)
145 |
146 | gauss_y = torch.exp(
147 | -((grid.unsqueeze(-1) - coords_rot[:, 1, :].unsqueeze(1)) ** 2)
148 | / variances.unsqueeze(1)
149 | ) * amplitudes.unsqueeze(1)
150 |
151 | image = torch.bmm(gauss_x, gauss_y.transpose(1, 2)) # * norms
152 | image /= torch.norm(image, dim=[-2, -1]).reshape(-1, 1, 1) # do we need this normalization?
153 |
154 | return image'''
155 |
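156 | # Usage sketch for the active project_density above (num_pixels and pixel_size
157 | # are passed as 0-dim tensors, matching how CryoEmSimulator calls it; the
158 | # coordinates are random stand-ins):
159 | #
160 | #     coords = torch.randn(2, 3, 50)                 # (batch, xyz, atoms)
161 | #     quats = torch.stack([gen_quat(), gen_quat()])  # (batch, 4)
162 | #     images = project_density(
163 | #         coords, quats,
164 | #         sigma=torch.tensor(1.0),
165 | #         shift=torch.zeros(2, 2),
166 | #         num_pixels=torch.tensor(64.0),
167 | #         pixel_size=torch.tensor(2.0),
168 | #     )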
--------------------------------------------------------------------------------
/src/cryo_sbi/wpa_simulator/noise.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 | import numpy as np
3 | import torch
4 |
5 |
6 | def circular_mask(n_pixels: int, radius: int, device: str = "cpu") -> torch.Tensor:
7 | """
8 | Creates a circular mask of radius RADIUS_MASK centered in the image
9 |
10 | Args:
11 | n_pixels (int): Number of pixels along image side.
12 | radius (int): Radius of the mask.
13 |
14 | Returns:
15 | mask (torch.Tensor): Mask of shape (n_pixels, n_pixels).
16 | """
17 |
18 | grid = torch.linspace(
19 | -0.5 * (n_pixels - 1), 0.5 * (n_pixels - 1), n_pixels, device=device
20 | )
21 | r_2d = grid[None, :] ** 2 + grid[:, None] ** 2
22 | mask = r_2d < radius**2
23 |
24 | return mask
25 |
26 |
27 | def get_snr(images, snr):
28 | """
29 |     Computes the noise standard deviation required to reach the target SNR
30 |     for each image; snr is given on a log10 scale.
30 | """
31 | mask = circular_mask(
32 | n_pixels=images.shape[-1],
33 | radius=images.shape[-1] // 2, # TODO: make this a parameter
34 | device=images.device,
35 | )
36 | signal_power = torch.std(
37 | images[:, mask], dim=[-1]
38 | ) # images are not centered at 0, so std is not the same as power
39 | assert signal_power.shape[0] == images.shape[0]
40 | noise_power = signal_power.reshape(-1, 1, 1) / torch.sqrt(
41 | torch.pow(torch.tensor(10), snr)
42 | )
43 |
44 | return noise_power
45 |
46 |
47 | def add_noise(image: torch.Tensor, snr, seed=None) -> torch.Tensor:
48 | """
49 | Adds noise to image.
50 |
51 | Args:
52 |         image (torch.Tensor): Images of shape (n_batch, n_pixels, n_pixels).
53 |         snr (torch.Tensor): Target signal-to-noise ratio on a log10 scale.
54 |         seed (int, optional): Seed for random number generator. Defaults to None.
55 |
56 | Returns:
57 | image_noise (torch.Tensor): Image with noise of shape (n_pixels, n_pixels) or (n_channels, n_pixels, n_pixels).
58 | """
59 |
60 | if seed is not None:
61 | torch.manual_seed(seed)
62 |
63 | noise_power = get_snr(image, snr)
64 | noise = torch.randn_like(image, device=image.device)
65 |
66 | noise = noise * noise_power.reshape(-1, 1, 1)
67 |
68 | image_noise = image + noise
69 |
70 | return image_noise
71 |
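72 | # Usage sketch: snr is on a log10 scale, so snr=0 targets a signal-to-noise
73 | # ratio of 10**0 = 1.
74 | #
75 | #     images = torch.randn(4, 64, 64)
76 | #     noisy = add_noise(images, snr=torch.tensor(0.0), seed=0)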
--------------------------------------------------------------------------------
/src/cryo_sbi/wpa_simulator/normalization.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.transforms as transforms
3 |
4 |
5 | def gaussian_normalize_image(images: torch.Tensor) -> torch.Tensor:
6 | """
7 |     Normalize images by subtracting the mean and dividing by the standard deviation, per image.
8 |
9 | Args:
10 |         images (torch.Tensor): Images of shape (n_batch, n_pixels, n_pixels).
11 |
12 | Returns:
13 | normalized (torch.Tensor): Normalized image.
14 | """
15 |
16 | mean = images.mean(dim=[1, 2])
17 | std = images.std(dim=[1, 2])
18 |
19 | return transforms.functional.normalize(images, mean=mean, std=std)
20 |
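21 | # Usage sketch: every image in the batch ends up with zero mean and unit
22 | # standard deviation.
23 | #
24 | #     images = gaussian_normalize_image(3 * torch.randn(8, 64, 64) + 5)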
--------------------------------------------------------------------------------
/src/cryo_sbi/wpa_simulator/validate_image_config.py:
--------------------------------------------------------------------------------
1 | def check_image_params(config: dict) -> None:
2 | """
3 | Checks if all necessary parameters are provided.
4 |
5 | Args:
6 | config (dict): Dictionary containing image parameters.
7 |
8 | Returns:
9 | None
10 | """
11 |
12 | needed_keys = [
13 | "N_PIXELS",
14 | "PIXEL_SIZE",
15 | "SIGMA",
16 | "SHIFT",
17 | "DEFOCUS",
18 | "SNR",
19 | "MODEL_FILE",
20 | "AMP",
21 | "B_FACTOR",
22 | ]
23 |
24 | for key in needed_keys:
25 | assert key in config.keys(), f"Please provide a value for {key}"
26 |
27 | return None
28 |
--------------------------------------------------------------------------------
/tests/config_files/image_params_testing.json:
--------------------------------------------------------------------------------
1 | {
2 | "N_PIXELS": 64,
3 | "PIXEL_SIZE": 2.06,
4 | "SIGMA": [0.5, 5.0],
5 | "MODEL_FILE": "tests/models/hsp90_models.pt",
6 | "SHIFT": 20.0,
7 | "DEFOCUS": [1.5, 3.5],
8 | "SNR": [0.05, 0.05],
9 | "AMP": 0.1,
10 | "B_FACTOR": [1.0, 100.0]
11 | }
12 |
--------------------------------------------------------------------------------
/tests/config_files/training_params_npe_testing.json:
--------------------------------------------------------------------------------
1 | {
2 | "EMBEDDING": "RESNET18",
3 | "OUT_DIM": 256,
4 | "NUM_TRANSFORM": 5,
5 | "NUM_HIDDEN_FLOW": 10,
6 | "HIDDEN_DIM_FLOW": 256,
7 | "MODEL": "NSF",
8 | "LEARNING_RATE": 0.0003,
9 | "CLIP_GRADIENT": 5.0,
10 | "THETA_SHIFT": 25,
11 | "THETA_SCALE": 25,
12 | "BATCH_SIZE": 256
13 | }
14 |
--------------------------------------------------------------------------------
/tests/data/test.mrc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/flatironinstitute/cryoSBI/d145cf21c1d5991361f028db698b3bc8931e0fdf/tests/data/test.mrc
--------------------------------------------------------------------------------
/tests/models/hsp90_models.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/flatironinstitute/cryoSBI/d145cf21c1d5991361f028db698b3bc8931e0fdf/tests/models/hsp90_models.pt
--------------------------------------------------------------------------------
/tests/test_embeddings.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | from itertools import product
4 |
5 | from cryo_sbi.inference.models.embedding_nets import EMBEDDING_NETS
6 |
7 | embedding_networks = list(EMBEDDING_NETS.keys())
8 | num_images_to_test = [1, 5]
9 | out_dims_to_test = [10, 100]
10 | cases_to_test = list(product(embedding_networks, num_images_to_test, out_dims_to_test))
11 |
12 |
13 | @pytest.mark.parametrize(("embedding_name", "num_images", "out_dim"), cases_to_test)
14 | def test_embedding(embedding_name, num_images, out_dim):
15 |
16 | if "FFT_FILTER_" in embedding_name:
17 | size = embedding_name.split("FFT_FILTER_")[1]
18 | test_images = torch.randn(num_images, int(size), int(size))
19 | elif "Tutorial" in embedding_name:
20 | test_images = torch.randn(num_images, 64, 64)
21 | else:
22 | test_images = torch.randn(num_images, 128, 128)
23 |
24 | embedding = EMBEDDING_NETS[embedding_name](out_dim)
25 | out = embedding(test_images).shape
26 | assert out == torch.Size([num_images, out_dim]), embedding_name
27 |
--------------------------------------------------------------------------------
/tests/test_estimator_utils.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import os
3 | import torch
4 | import numpy as np
5 | import json
6 |
7 | from cryo_sbi.inference.models import build_models
8 | from cryo_sbi.inference.models.estimator_models import NPEWithEmbedding
9 | from cryo_sbi.inference.validate_train_config import check_train_params
10 | from cryo_sbi.utils.estimator_utils import (
11 | sample_posterior,
12 | compute_latent_repr,
13 | evaluate_log_prob,
14 | load_estimator,
15 | )
16 |
17 |
18 | @pytest.fixture
19 | def train_params():
20 | config = json.load(open("tests/config_files/training_params_npe_testing.json"))
21 | check_train_params(config)
22 | return config
23 |
24 |
25 | @pytest.fixture
26 | def train_config_path():
27 | return "tests/config_files/training_params_npe_testing.json"
28 |
29 |
30 | @pytest.mark.parametrize(
31 | ("num_images", "num_samples", "batch_size"),
32 | [(1, 1, 1), (2, 10, 2), (5, 1000, 5), (100, 2, 100)],
33 | )
34 | def test_sampling(train_params, num_images, num_samples, batch_size):
35 | estimator = build_models.build_npe_flow_model(train_params)
36 | estimator.eval()
37 | images = torch.randn((num_images, 128, 128))
38 | samples = sample_posterior(
39 | estimator, images, num_samples=num_samples, batch_size=batch_size
40 | )
41 | assert samples.shape == torch.Size(
42 | [num_samples, num_images]
43 | ), f"Failed with: num_images: {num_images}, num_samles:{num_samples}, batch_size:{batch_size}"
44 |
45 |
46 | @pytest.mark.parametrize(
47 | ("num_images", "batch_size"), [(1, 1), (2, 2), (1, 5), (100, 10)]
48 | )
49 | def test_latent_extraction(train_params, num_images, batch_size):
50 | estimator = build_models.build_npe_flow_model(train_params)
51 | estimator.eval()
52 |
53 | latent_dim = train_params["OUT_DIM"]
54 | images = torch.randn((num_images, 128, 128))
55 | samples = compute_latent_repr(estimator, images, batch_size=batch_size)
56 | assert samples.shape == torch.Size(
57 | [num_images, latent_dim]
58 | ), f"Failed with: num_images: {num_images}, batch_size:{batch_size}"
59 |
60 |
61 | @pytest.mark.parametrize(
62 | ("num_images", "num_eval", "batch_size"),
63 | [(1, 1, 1), (2, 10, 2), (5, 1000, 5), (100, 2, 100)],
64 | )
65 | def test_logprob_eval(train_params, num_images, num_eval, batch_size):
66 | estimator = build_models.build_npe_flow_model(train_params)
67 | estimator.eval()
68 | images = torch.randn((num_images, 128, 128))
69 | theta = torch.linspace(0, 25, num_eval)
70 | samples = evaluate_log_prob(estimator, images, theta, batch_size=batch_size)
71 | assert samples.shape == torch.Size(
72 | [num_eval, num_images]
73 | ), f"Failed with: num_images: {num_images}, num_eval:{num_eval}, batch_size:{batch_size}"
74 |
75 |
76 | def test_load_estimator(train_params, train_config_path):
77 | estimator = build_models.build_npe_flow_model(train_params)
78 | torch.save(estimator.state_dict(), "tests/config_files/test_estimator.estimator")
79 | estimator = load_estimator(
80 | train_config_path, "tests/config_files/test_estimator.estimator"
81 | )
82 | assert isinstance(estimator, NPEWithEmbedding)
83 | os.remove("tests/config_files/test_estimator.estimator")
84 |
--------------------------------------------------------------------------------
/tests/test_image_utils.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | from cryo_sbi.utils import image_utils as iu
4 |
5 |
6 | def test_circular_mask():
7 | n_pixels = 100
8 | radius = 30
9 | inside_mask = iu.circular_mask(n_pixels, radius, inside=True)
10 | outside_mask = iu.circular_mask(n_pixels, radius, inside=False)
11 |
12 | assert inside_mask.shape == (n_pixels, n_pixels)
13 | assert outside_mask.shape == (n_pixels, n_pixels)
14 | assert inside_mask.sum().item() == pytest.approx(radius**2 * 3.14159, abs=10)
15 | assert outside_mask.sum().item() == pytest.approx(
16 | n_pixels**2 - radius**2 * 3.14159, abs=10
17 | )
18 |
19 |
20 | def test_mask_class():
21 | image_size = 100
22 | radius = 30
23 | inside = True
24 | mask = iu.Mask(image_size, radius, inside=inside)
25 | image = torch.ones((image_size, image_size))
26 |
27 | masked_image = mask(image)
28 | assert masked_image.shape == (image_size, image_size)
29 | assert masked_image[inside].sum().item() == pytest.approx(
30 | image_size**2 - radius**2 * 3.14159, abs=10
31 | )
32 |
33 |
34 | def test_fourier_down_sample():
35 | image_size = 100
36 | n_pixels = 30
37 | image = torch.ones((image_size, image_size))
38 |
39 | downsampled_image = iu.fourier_down_sample(image, image_size, n_pixels)
40 | assert downsampled_image.shape == (
41 | image_size - 2 * n_pixels,
42 | image_size - 2 * n_pixels,
43 | )
44 |
45 |
46 | def test_fourier_down_sample_class():
47 | image_size = 100
48 | down_sampled_size = 40
49 | down_sampler = iu.FourierDownSample(image_size, down_sampled_size)
50 | image = torch.ones((image_size, image_size))
51 |
52 | down_sampled_image = down_sampler(image)
53 | assert down_sampled_image.shape == (
54 | image_size - 2 * down_sampler._n_pixels,
55 | image_size - 2 * down_sampler._n_pixels,
56 | )
57 |
58 |
59 | def test_low_pass_filter():
60 | image_size = 100
61 | frequency_cutoff = 30
62 | low_pass_filter = iu.LowPassFilter(image_size, frequency_cutoff)
63 | image = torch.ones((image_size, image_size))
64 |
65 | filtered_image = low_pass_filter(image)
66 | assert filtered_image.shape == (image_size, image_size)
67 |
68 |
69 | def test_gaussian_low_pass_filter():
70 | image_size = 100
71 | frequency_cutoff = 30
72 | low_pass_filter = iu.GaussianLowPassFilter(image_size, frequency_cutoff)
73 | image = torch.ones((image_size, image_size))
74 |
75 | filtered_image = low_pass_filter(image)
76 | assert filtered_image.shape == (image_size, image_size)
77 |
78 |
79 | def test_normalize_individual():
80 | normalize_individual = iu.NormalizeIndividual()
81 | image = torch.randn((3, 100, 100))
82 |
83 | normalized_image = normalize_individual(image)
84 | assert normalized_image.shape == (3, 100, 100)
85 | assert normalized_image.mean().item() == pytest.approx(0.0, abs=1e-1)
86 | assert normalized_image.std().item() == pytest.approx(1.0, abs=1e-1)
87 |
88 |
89 | def test_mrc_to_tensor():
90 | image_path = "tests/data/test.mrc"
91 | image = iu.mrc_to_tensor(image_path)
92 |
93 | assert isinstance(image, torch.Tensor)
94 | assert image.shape == (5, 5)
95 |
96 |
97 | def test_image_whitening():
98 | whitening_transform = iu.WhitenImage(100)
99 | images = torch.randn((1, 100, 100))
100 | images_whitened = whitening_transform(images)
101 | assert images_whitened.shape == (1, 100, 100)
102 |
103 |
104 | def test_image_whitening_batched():
105 | whitening_transform = iu.WhitenImage(100)
106 | images = torch.randn((10, 100, 100))
107 | images_whitened = whitening_transform(images)
108 | assert images_whitened.shape == (10, 100, 100)
109 |
--------------------------------------------------------------------------------
/tests/test_micrograph_utils.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | import torchvision.transforms as transforms
4 | from cryo_sbi.utils.image_utils import NormalizeIndividual
5 | from cryo_sbi.utils import micrograph_utils as mu
6 |
7 |
8 | def test_random_micrograph_patches_fail():
9 | micrograph = torch.randn(2, 128, 128)
10 | random_patches = mu.RandomMicrographPatches(
11 | micro_graphs=[micrograph], patch_size=10, transform=None
12 | )
13 | with pytest.raises(AssertionError):
14 | patch = next(random_patches)
15 |
16 |
17 | @pytest.mark.parametrize(
18 | ("micrograph_size", "patch_size", "max_iter"), [(128, 12, 100), (128, 90, 1000)]
19 | )
20 | def test_random_micrograph_patches(micrograph_size, patch_size, max_iter):
21 | micrograph = torch.randn(micrograph_size, micrograph_size)
22 | random_patches = mu.RandomMicrographPatches(
23 | micro_graphs=[micrograph],
24 | patch_size=patch_size,
25 | transform=None,
26 | max_iter=max_iter,
27 | )
28 | patch = next(random_patches)
29 | assert patch.shape == torch.Size([1, patch_size, patch_size])
30 | assert len(random_patches) == max_iter
31 |
32 |
33 | def test_compute_average_psd():
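   | # Average the power spectral density over 100 random patches; the PSD
   | # of each patch has the same 2D shape as the patch itself.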
34 | micrograph = torch.randn(128, 128)
35 | transform = transforms.Compose([NormalizeIndividual()])
36 | random_patches = mu.RandomMicrographPatches(
37 | micro_graphs=[micrograph], patch_size=10, transform=transform, max_iter=100
38 | )
39 | avg_psd = mu.compute_average_psd(random_patches)
40 | assert avg_psd.shape == torch.Size([10, 10])
41 |
--------------------------------------------------------------------------------
/tests/test_posterior_models.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import json
3 | import torch
4 | from cryo_sbi.inference.models import build_models
5 | from cryo_sbi.inference.models import estimator_models
6 | from cryo_sbi.inference.validate_train_config import check_train_params
7 |
8 |
9 | @pytest.fixture
10 | def train_params():
11 | with open("tests/config_files/training_params_npe_testing.json") as f:
   | config = json.load(f)
12 | check_train_params(config)
13 | return config
14 |
15 |
16 | def test_build_npe_model(train_params):
17 | posterior_model = build_models.build_npe_flow_model(train_params)
18 | assert isinstance(posterior_model, estimator_models.NPEWithEmbedding)
19 |
20 |
21 | @pytest.mark.parametrize(
22 | ("batch_size", "sample_size"), [(1, 1), (2, 10), (5, 1000), (100, 2)]
23 | )
24 | def test_sample_npe_model(train_params, batch_size, sample_size):
25 | posterior_model = build_models.build_npe_flow_model(train_params)
26 | test_image = torch.randn((batch_size, 128, 128))
27 | samples = posterior_model.sample(test_image, shape=(sample_size,))
28 | assert samples.shape == torch.Size([sample_size, batch_size, 1])
29 |
--------------------------------------------------------------------------------
/tests/test_simulator.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | import numpy as np
4 | import json
5 |
6 | from cryo_sbi.wpa_simulator.cryo_em_simulator import cryo_em_simulator, CryoEmSimulator
7 | from cryo_sbi.wpa_simulator.ctf import apply_ctf
8 | from cryo_sbi.wpa_simulator.image_generation import (
9 | project_density,
10 | gen_quat,
11 | gen_rot_matrix,
12 | )
13 | from cryo_sbi.wpa_simulator.noise import add_noise, circular_mask, get_snr
14 | from cryo_sbi.wpa_simulator.normalization import gaussian_normalize_image
15 | from cryo_sbi.inference.priors import get_image_priors
16 |
17 |
18 | def test_apply_ctf():
19 | # Create a test image
20 | image = torch.randn(1, 64, 64)
21 |
22 | # Set test parameters
23 | defocus = torch.tensor([1.0])
24 | b_factor = torch.tensor([100.0])
25 | amp = torch.tensor([0.5])
26 | pixel_size = torch.tensor(1.0)
27 |
28 | # Apply CTF to the test image
29 | image_ctf = apply_ctf(image, defocus, b_factor, amp, pixel_size)
30 |
31 | assert image_ctf.shape == image.shape
32 | assert isinstance(image_ctf, torch.Tensor)
33 | assert not torch.allclose(image_ctf, image)
34 |
35 |
36 | def test_gen_rot_matrix():
37 | # Create a test quaternion
38 | quat = torch.tensor([[1.0, 0.0, 0.0, 0.0]])
39 |
40 | # Generate a rotation matrix from the quaternion
41 | rot_matrix = gen_rot_matrix(quat)
42 |
43 | assert rot_matrix.shape == torch.Size([1, 3, 3])
44 | assert isinstance(rot_matrix, torch.Tensor)
45 | assert torch.allclose(rot_matrix, torch.eye(3).unsqueeze(0))
46 |
47 |
48 | def test_gen_rot_matrix_batched():
49 | # Create test quaternions with batch size 3
50 | quat = torch.tensor(
51 | [[1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0]]
52 | )
53 |
54 | # Generate rotation matrices from the quaternions
55 | rot_matrix = gen_rot_matrix(quat)
56 |
57 | assert rot_matrix.shape == torch.Size([3, 3, 3])
58 | assert isinstance(rot_matrix, torch.Tensor)
59 | assert torch.allclose(rot_matrix, torch.eye(3).repeat(3, 1, 1))
60 |
61 |
62 | @pytest.mark.parametrize(
63 | ("noise_std", "num_images"),
64 | [
65 | (torch.tensor([1.5, 1]), 2),
66 | (torch.tensor([1.0, 2.0, 3.0]), 3),
67 | (torch.tensor([0.1]), 10),
68 | ],
69 | )
70 | def test_get_snr(noise_std, num_images):
71 | # Create a test image
72 | images = noise_std.reshape(-1, 1, 1) * torch.randn(num_images, 128, 128)
73 |
74 | # Compute the SNR of the test image
75 | snr = get_snr(images, torch.tensor([0.0]))
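   | # For pure-noise images scaled by noise_std, the estimate should
   | # recover noise_std itself (checked with a loose tolerance below).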
76 |
77 | assert snr.shape == torch.Size([images.shape[0], 1, 1]), "SNR has wrong shape"
78 | assert isinstance(snr, torch.Tensor)
79 | assert torch.allclose(
80 | snr.flatten(), noise_std * torch.ones(images.shape[0]), atol=1e-01
81 | ), "SNR is not correct"
82 |
83 |
84 | @pytest.mark.parametrize(("num_images"), [1, 5])
85 | def test_simulator_default_settings(num_images):
86 | sim = CryoEmSimulator("tests/config_files/image_params_testing.json")
87 | images = sim.simulate(num_images)
88 | assert images.shape == torch.Size([num_images, 64, 64])
89 |
90 |
91 | @pytest.mark.parametrize(("num_images"), [1, 5])
92 | def test_simulator_custom_indices(num_images):
93 | sim = CryoEmSimulator("tests/config_files/image_params_testing.json")
94 | test_indices = torch.arange(num_images, dtype=torch.float32).reshape(-1, 1)
95 | images, parameters = sim.simulate(
96 | num_images, indices=test_indices, return_parameters=True
97 | )
98 |
99 | assert (parameters[0] == test_indices).all().item()
100 | assert images.shape == torch.Size([num_images, 64, 64])
101 |
--------------------------------------------------------------------------------
/tests/test_visualize_models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pytest
3 | from cryo_sbi.utils.visualize_models import plot_model
4 |
5 |
6 | def test_plot_model_scatter():
7 | model = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])
8 | plot_model(
9 | model, method="scatter"
10 | ) # No assertion, just checking if it runs without errors
11 |
12 |
13 | def test_plot_model_sphere():
14 | model = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])
15 | plot_model(
16 | model, method="sphere"
17 | ) # No assertion, just checking if it runs without errors
18 |
19 |
20 | def test_plot_model_invalid_model():
21 | model = torch.tensor(
22 | [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
23 | ) # Invalid shape, should have 3 rows
24 | with pytest.raises(AssertionError):
25 | plot_model(model, method="scatter")
26 |
27 |
28 | def test_plot_model_invalid_method():
29 | model = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])
30 | with pytest.raises(ValueError):
31 | plot_model(model, method="invalid_method")
32 |
--------------------------------------------------------------------------------
/tutorials/.gitignore:
--------------------------------------------------------------------------------
1 | *.h5
2 | *.tp
3 | *.pt
--------------------------------------------------------------------------------
/tutorials/simulation_parameters.json:
--------------------------------------------------------------------------------
1 | {
2 | "N_PIXELS": 64,
3 | "PIXEL_SIZE": 2.5,
4 | "SIGMA": [2.0, 2.0],
5 | "MODEL_FILE": "models.pt",
6 | "SHIFT": 0.0,
7 | "DEFOCUS": [2.0, 2.0],
8 | "SNR": [0.01, 0.5],
9 | "AMP": 0.1,
10 | "B_FACTOR": [1.0, 1.0]
11 | }
--------------------------------------------------------------------------------
/tutorials/training_parameters.json:
--------------------------------------------------------------------------------
1 | {
2 | "EMBEDDING": "ConvEncoder_Tutorial",
3 | "OUT_DIM": 128,
4 | "NUM_TRANSFORM": 3,
5 | "NUM_HIDDEN_FLOW": 3,
6 | "HIDDEN_DIM_FLOW": 128,
7 | "MODEL": "NSF",
8 | "LEARNING_RATE": 0.0003,
9 | "CLIP_GRADIENT": 5.0,
10 | "THETA_SHIFT": 50,
11 | "THETA_SCALE": 50,
12 | "BATCH_SIZE": 32
13 | }
--------------------------------------------------------------------------------
/tutorials/tutorial.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import os\n",
10 | "import torch\n",
11 | "import numpy as np\n",
12 | "import matplotlib.pyplot as plt\n",
13 | "from scipy.spatial.transform import Rotation\n",
14 | "\n",
15 | "from cryo_sbi import CryoEmSimulator\n",
16 | "import cryo_sbi.inference.train_npe_model as train_npe_model\n",
17 | "import cryo_sbi.utils.estimator_utils as est_utils\n",
18 | "import cryo_sbi.utils.image_utils as img_utils\n",
19 | "from cryo_sbi.utils.visualize_models import plot_model"
20 | ]
21 | },
22 | {
23 | "cell_type": "markdown",
24 | "metadata": {},
25 | "source": [
26 | "#### If you want to run the latent space analysis with UMAP, you need to install the following package:\n",
27 | "You can find the installation instructions [here](https://umap-learn.readthedocs.io/en/latest/). \n"
28 | ]
29 | },
30 | {
31 | "cell_type": "code",
32 | "execution_count": null,
33 | "metadata": {},
34 | "outputs": [],
35 | "source": [
36 | "# If you installed umap you can import it here \n",
37 | "import umap"
38 | ]
39 | },
40 | {
41 | "cell_type": "markdown",
42 | "metadata": {},
43 | "source": [
44 | "### Make the models \n",
45 | "\n",
46 | "Here we are creating a simple molecular model. \n",
47 | "In our model we will have four pseudo atoms arranged in a rectangle. The model differs between the sidelength we are using. The atoms are placed at the corners of the rectangle. The distance between the atoms is the same for all atoms.\n",
48 | "The goal is then to simulate cryo-EM images with these models and infer the distance between the two atoms from the images.\n",
49 | "\n",
50 | "The first step is to create the models. We start by crating an array with the side length `side_length` between the atoms. We will use this array to create the models.\n",
51 | "The models are created by placing pseudo atoms at the corners of the rectangle.\n",
52 | "\n",
53 | "The models are saved into teh file `models.pt`.\n",
54 | "\n"
55 | ]
56 | },
57 | {
58 | "cell_type": "code",
59 | "execution_count": null,
60 | "metadata": {},
61 | "outputs": [],
62 | "source": [
63 | "side_lengths = torch.linspace(0, 50, 100)\n",
64 | "\n",
65 | "models = []\n",
66 | "for side_length in side_lengths:\n",
67 | " model = [\n",
68 | " [side_length, -side_length, side_length, -side_length],\n",
69 | " [side_length, side_length, -side_length, -side_length],\n",
70 | " [0.0, 0.0, 0.0, 0.0],\n",
71 | " ]\n",
72 | " models.append(model)\n",
73 | "models = torch.tensor(models)"
74 | ]
75 | },
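   | {
   | "cell_type": "code",
   | "execution_count": null,
   | "metadata": {},
   | "outputs": [],
   | "source": [
   | "# Sanity check: we expect 100 models, each stored as a 3 x 4 tensor\n",
   | "# (rows are x/y/z coordinates, one column per pseudo-atom).\n",
   | "assert models.shape == (100, 3, 4)\n",
   | "print(models.shape)"
   | ]
   | },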
76 | {
77 | "cell_type": "markdown",
78 | "metadata": {},
79 | "source": [
80 | "##### Visualize the models"
81 | ]
82 | },
83 | {
84 | "cell_type": "markdown",
85 | "metadata": {},
86 | "source": [
87 | "To visulaize the model in the x-y plane, as we do not have a z-dimension in our model."
88 | ]
89 | },
90 | {
91 | "cell_type": "code",
92 | "execution_count": null,
93 | "metadata": {},
94 | "outputs": [],
95 | "source": [
96 | "fig = plt.figure()\n",
97 | "ax = fig.add_subplot()\n",
98 | "for i, c in zip([0, 25, 50, 75, 99], [\"red\", \"orange\", \"green\", \"blue\", \"purple\"]):\n",
99 | " ax.scatter(\n",
100 | " models[i, 0, 0],\n",
101 | " models[i, 1, 0],\n",
102 | " s=60,\n",
103 | " color=c,\n",
104 | " label=f\"Model with side length : {side_lengths[i]:.2f}\",\n",
105 | " )\n",
106 | " ax.scatter(models[i, 0, 1], models[i, 1, 1], s=60, color=c)\n",
107 | " ax.scatter(models[i, 0, 2], models[i, 1, 2], s=60, color=c)\n",
108 | " ax.scatter(models[i, 0, 3], models[i, 1, 3], s=60, color=c)\n",
109 | "\n",
110 | "ax.set_xlabel(\"X\")\n",
111 | "ax.set_ylabel(\"Y\")\n",
112 | "plt.legend()\n",
113 | "\n"
114 | ]
115 | },
116 | {
117 | "cell_type": "code",
118 | "execution_count": null,
119 | "metadata": {},
120 | "outputs": [],
121 | "source": [
122 | "# substract of center of mass\n",
123 | "for i in range(100):\n",
124 | " models[i] = models[i] - models[i].mean(dim=1, keepdim=True)\n",
125 | "\n",
126 | "torch.save(models, \"models.pt\")"
127 | ]
128 | },
129 | {
130 | "cell_type": "markdown",
131 | "metadata": {},
132 | "source": [
133 | "### Run first simulation\n",
134 | "\n",
135 | "We will now simulate the cryo-EM images with our generated models.\n",
136 | "The simulation is done by the class `CryoEmSimulator`. And the simulation is run by the function `simulate` function.\n",
137 | "The class `CryoEmSimulator` takes as input a config file with the simulation parameters. The config file used here is `simulation_parameters.json`.\n",
138 | "\n",
139 | "The following parameters are used in the simulation:"
140 | ]
141 | },
142 | {
143 | "cell_type": "markdown",
144 | "metadata": {},
145 | "source": [
146 | "```\n",
147 | "simulation_parameters.json\n",
148 | "\n",
149 | "{\n",
150 | " \"N_PIXELS\": 64, --> size of the image\n",
151 | " \"PIXEL_SIZE\": 2.0, --> pixel size in angstroms\n",
152 | " \"SIGMA\": [2.0, 2.0], --> standard deviation of the gaussian\n",
153 | " \"MODEL_FILE\": \"models.pt\", --> file which contains the models\n",
154 | " \"SHIFT\": 0.0, --> shift of model center \n",
155 | " \"DEFOCUS\": [2.0, 2.0], --> defocus range for the simulation\n",
156 | " \"SNR\": [0.01, 0.5], --> signal to noise ratio for the simulation\n",
157 | " \"AMP\": 0.1, --> amplitude for the ctf \n",
158 | " \"B_FACTOR\": [1.0, 1.0] --> b factor for the ctf\n",
159 | "} \n",
160 | "```"
161 | ]
162 | },
163 | {
164 | "cell_type": "code",
165 | "execution_count": null,
166 | "metadata": {},
167 | "outputs": [],
168 | "source": [
169 | "simulator = CryoEmSimulator(\n",
170 | " \"simulation_parameters.json\"\n",
171 | ") # creating simulator with simulation parameters"
172 | ]
173 | },
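   | {
   | "cell_type": "code",
   | "execution_count": null,
   | "metadata": {},
   | "outputs": [],
   | "source": [
   | "# The simulator exposes the largest valid model index as `max_index`;\n",
   | "# with the 100 models saved above this should be 99.\n",
   | "print(simulator.max_index)"
   | ]
   | },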
174 | {
175 | "cell_type": "code",
176 | "execution_count": null,
177 | "metadata": {},
178 | "outputs": [],
179 | "source": [
180 | "images, parameters = simulator.simulate(\n",
181 | " num_sim=5000, return_parameters=True\n",
182 | ") # simulating images and save parameters"
183 | ]
184 | },
185 | {
186 | "cell_type": "code",
187 | "execution_count": null,
188 | "metadata": {},
189 | "outputs": [],
190 | "source": [
191 | "side_length = parameters[0] # extracting side_length from parameters\n",
192 | "snr = parameters[-1] # extracting snr from parameters"
193 | ]
194 | },
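   | {
   | "cell_type": "code",
   | "execution_count": null,
   | "metadata": {},
   | "outputs": [],
   | "source": [
   | "# Shape check: 5000 simulated 64 x 64 images (N_PIXELS = 64), with one\n",
   | "# model index and one SNR value per image. The full layout of\n",
   | "# `parameters` depends on the simulator; we only use its first and\n",
   | "# last entries here.\n",
   | "print(images.shape)  # expected: torch.Size([5000, 64, 64])\n",
   | "print(side_length.shape, snr.shape)"
   | ]
   | },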
195 | {
196 | "cell_type": "markdown",
197 | "metadata": {},
198 | "source": [
199 | "#### Visualize the simulated images"
200 | ]
201 | },
202 | {
203 | "cell_type": "code",
204 | "execution_count": null,
205 | "metadata": {},
206 | "outputs": [],
207 | "source": [
208 | "fig, axes = plt.subplots(4, 4, figsize=(6, 6))\n",
209 | "for idx, ax in enumerate(axes.flatten()):\n",
210 | " ax.imshow(images[idx], vmin=-3, vmax=3, cmap=\"gray\")\n",
211 | " ax.set_title(\n",
212 | " f\"Side: {side_lengths[side_length[idx].round().long()].item():.2f}\", fontsize=10\n",
213 | " )\n",
214 | " ax.axis(\"off\")"
215 | ]
216 | },
217 | {
218 | "cell_type": "markdown",
219 | "metadata": {},
220 | "source": [
221 | "### Train cryoSBI posterior\n",
222 | "\n",
223 | "We will now train the cryoSBI posterior to infer the distance between the atoms from the simulated images.\n",
224 | "The training is done with the function `npe_train_no_saving` which simulates images and simultaneously trains the posterior.\n",
225 | "The function takes as input the config file `training_parameters.json` which contains the training and neural network parameters.\n",
226 | "The function also takes as input the config file `simulation_parameters.json` which contains the simulation parameters used to simulate the images.\n"
227 | ]
228 | },
229 | {
230 | "cell_type": "markdown",
231 | "metadata": {},
232 | "source": [
233 | "```\n",
234 | "training_parameters.json\n",
235 | "```\n",
236 | "\n",
237 | "```\n",
238 | "{\n",
239 | " \"EMBEDDING\": \"ConvEncoder_Tutorial\", --> embedding network for the images\n",
240 | " \"OUT_DIM\": 128, --> dimension of the embedding\n",
241 | " \"NUM_TRANSFORM\": 5, --> number of transformations\n",
242 | " \"NUM_HIDDEN_FLOW\": 5, --> number of hidden layers in the flow\n",
243 | " \"HIDDEN_DIM_FLOW\": 128, --> dimension of the hidden layers in the flow\n",
244 | " \"MODEL\": \"NSF\", --> type of flow\n",
245 | " \"LEARNING_RATE\": 0.0003, --> learning rate\n",
246 | " \"CLIP_GRADIENT\": 5.0, --> gradient clipping\n",
247 | " \"THETA_SHIFT\": 50, --> shift of the model center\n",
248 | " \"THETA_SCALE\": 50, --> scale of the model\n",
249 | " \"BATCH_SIZE\": 32 --> batch size\n",
250 | "}\n",
251 | "```\n"
252 | ]
253 | },
254 | {
255 | "cell_type": "code",
256 | "execution_count": null,
257 | "metadata": {},
258 | "outputs": [],
259 | "source": [
260 | "train_npe_model.npe_train_no_saving(\n",
261 | " \"simulation_parameters.json\",\n",
262 | " \"training_parameters.json\",\n",
263 | " 150,\n",
264 | " \"tutorial_estimator.pt\", # name of the estimator file\n",
265 | " \"tutorial.loss\", # name of the loss file\n",
266 | " n_workers=4, # number of workers for data loading\n",
267 | " device=\"cuda\", # device to use for training and simulation\n",
268 | " saving_frequency=100, # frequency of saving the model\n",
269 | " simulation_batch_size=160, # batch size for simulation\n",
270 | ")"
271 | ]
272 | },
273 | {
274 | "cell_type": "markdown",
275 | "metadata": {},
276 | "source": [
277 | "##### Visualize the loss after training"
278 | ]
279 | },
280 | {
281 | "cell_type": "code",
282 | "execution_count": null,
283 | "metadata": {},
284 | "outputs": [],
285 | "source": [
286 | "plt.plot(torch.load(\"tutorial.loss\"))\n",
287 | "plt.xlabel(\"Epoch\")\n",
288 | "plt.ylabel(\"Loss\")"
289 | ]
290 | },
291 | {
292 | "cell_type": "markdown",
293 | "metadata": {},
294 | "source": [
295 | "### Evaluate the posterior on our simulated images\n",
296 | "\n",
297 | "We will now evaluate the trained posterior on our simulated images.\n",
298 | "For each simulated image we will infer the distance between the atoms and compare it to the true distance, by sampling from the posterior."
299 | ]
300 | },
301 | {
302 | "cell_type": "code",
303 | "execution_count": null,
304 | "metadata": {},
305 | "outputs": [],
306 | "source": [
307 | "posterior = est_utils.load_estimator(\n",
308 | " \"training_parameters.json\",\n",
309 | " \"tutorial_estimator.pt\",\n",
310 | " device=\"cuda\",\n",
311 | ")"
312 | ]
313 | },
314 | {
315 | "cell_type": "code",
316 | "execution_count": null,
317 | "metadata": {},
318 | "outputs": [],
319 | "source": [
320 | "samples = est_utils.sample_posterior(\n",
321 | " estimator=posterior,\n",
322 | " images=images,\n",
323 | " num_samples=15000,\n",
324 | " batch_size=1000,\n",
325 | " device=\"cuda\",\n",
326 | ")"
327 | ]
328 | },
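   | {
   | "cell_type": "code",
   | "execution_count": null,
   | "metadata": {},
   | "outputs": [],
   | "source": [
   | "# The repository's tests show that posterior samples are arranged as\n",
   | "# (num_samples, num_images, 1); assuming `sample_posterior` keeps this\n",
   | "# layout, the posterior mean gives a simple point estimate per image.\n",
   | "posterior_mean = samples.mean(0).flatten()\n",
   | "print(samples.shape, posterior_mean.shape)"
   | ]
   | },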
329 | {
330 | "cell_type": "code",
331 | "execution_count": null,
332 | "metadata": {},
333 | "outputs": [],
334 | "source": [
335 | "fig, axes = plt.subplots(4, 4, figsize=(10, 10))\n",
336 | "for idx, ax in enumerate(axes.flatten()):\n",
337 | " ax.hist(samples[:, idx].flatten(), bins=np.linspace(0, simulator.max_index, 60))\n",
338 | " ax.axvline(side_length[idx], ymax=1, ymin=0, color=\"red\")\n",
339 | " ax.set_yticks([])\n",
340 | " ax.set_xticks([])"
341 | ]
342 | },
343 | {
344 | "cell_type": "markdown",
345 | "metadata": {},
346 | "source": [
347 | "### Plot latent space"
348 | ]
349 | },
350 | {
351 | "cell_type": "code",
352 | "execution_count": null,
353 | "metadata": {},
354 | "outputs": [],
355 | "source": [
356 | "latent_representations = est_utils.compute_latent_repr(\n",
357 | " estimator=posterior,\n",
358 | " images=images,\n",
359 | " batch_size=1000,\n",
360 | " device=\"cuda\",\n",
361 | ")"
362 | ]
363 | },
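   | {
   | "cell_type": "code",
   | "execution_count": null,
   | "metadata": {},
   | "outputs": [],
   | "source": [
   | "# The latent representation is the output of the embedding network, so\n",
   | "# we expect one OUT_DIM = 128 dimensional vector per image.\n",
   | "print(latent_representations.shape)  # expected: torch.Size([5000, 128])"
   | ]
   | },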
364 | {
365 | "cell_type": "code",
366 | "execution_count": null,
367 | "metadata": {},
368 | "outputs": [],
369 | "source": [
370 | "reducer = umap.UMAP(metric=\"euclidean\", n_components=2, n_neighbors=50)\n",
371 | "latent_vecs_transformed = reducer.fit_transform(latent_representations.numpy())"
372 | ]
373 | },
374 | {
375 | "cell_type": "code",
376 | "execution_count": null,
377 | "metadata": {},
378 | "outputs": [],
379 | "source": [
380 | "plt.scatter(\n",
381 | " latent_vecs_transformed[:, 0],\n",
382 | " latent_vecs_transformed[:, 1],\n",
383 | " c=side_length,\n",
384 | " cmap=\"viridis\",\n",
385 | " s=10,\n",
386 | ")\n",
387 | "plt.colorbar(label=\"Side length\")\n",
388 | "plt.xlabel(\"UMAP 1\")\n",
389 | "plt.ylabel(\"UMAP 2\")"
390 | ]
391 | },
392 | {
393 | "cell_type": "code",
394 | "execution_count": null,
395 | "metadata": {},
396 | "outputs": [],
397 | "source": [
398 | "plt.scatter(\n",
399 | " latent_vecs_transformed[:, 0],\n",
400 | " latent_vecs_transformed[:, 1],\n",
401 | " c=snr,\n",
402 | " cmap=\"viridis\",\n",
403 | " s=10,\n",
404 | ")\n",
405 | "plt.colorbar(label=\"SNR\")\n",
406 | "plt.xlabel(\"UMAP 1\")\n",
407 | "plt.ylabel(\"UMAP 2\")"
408 | ]
409 | }
410 | ],
411 | "metadata": {
412 | "kernelspec": {
413 | "display_name": "cryo_sbi",
414 | "language": "python",
415 | "name": "cryo_sbi"
416 | },
417 | "language_info": {
418 | "codemirror_mode": {
419 | "name": "ipython",
420 | "version": 3
421 | },
422 | "file_extension": ".py",
423 | "mimetype": "text/x-python",
424 | "name": "python",
425 | "nbconvert_exporter": "python",
426 | "pygments_lexer": "ipython3",
427 | "version": "3.10.0"
428 | }
429 | },
430 | "nbformat": 4,
431 | "nbformat_minor": 2
432 | }
433 |
--------------------------------------------------------------------------------