├── .flake8 ├── .github └── workflows │ └── python-package.yml ├── .gitignore ├── README.rst ├── pyproject.toml ├── src └── cryo_sbi │ ├── __init__.py │ ├── inference │ ├── __init__.py │ ├── command_line_tools.py │ ├── models │ │ ├── __init__.py │ │ ├── build_models.py │ │ ├── embedding_nets.py │ │ └── estimator_models.py │ ├── priors.py │ ├── train_npe_model.py │ └── validate_train_config.py │ ├── utils │ ├── __init__.py │ ├── command_line_tools.py │ ├── estimator_utils.py │ ├── generate_models.py │ ├── image_utils.py │ ├── micrograph_utils.py │ └── visualize_models.py │ └── wpa_simulator │ ├── __init__.py │ ├── cryo_em_simulator.py │ ├── ctf.py │ ├── image_generation.py │ ├── noise.py │ ├── normalization.py │ └── validate_image_config.py ├── tests ├── config_files │ ├── image_params_testing.json │ └── training_params_npe_testing.json ├── data │ └── test.mrc ├── models │ └── hsp90_models.pt ├── test_embeddings.py ├── test_estimator_utils.py ├── test_image_utils.py ├── test_micrograph_utils.py ├── test_posterior_models.py ├── test_simulator.py └── test_visualize_models.py └── tutorials ├── .gitignore ├── simulation_parameters.json ├── training_parameters.json └── tutorial.ipynb /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 88 3 | extend-ignore = E203 -------------------------------------------------------------------------------- /.github/workflows/python-package.yml: -------------------------------------------------------------------------------- 1 | # GHA workflow for running tests. 2 | # 3 | # Largely taken from 4 | # https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python 5 | # Please check the link for more detailed instructions 6 | 7 | name: Run tests 8 | 9 | on: [push] 10 | 11 | jobs: 12 | build: 13 | 14 | runs-on: ubuntu-latest 15 | strategy: 16 | matrix: 17 | python-version: ["3.9", "3.10"] 18 | 19 | steps: 20 | - uses: actions/checkout@v3 21 | - name: Set up Python ${{ matrix.python-version }} 22 | uses: actions/setup-python@v4 23 | with: 24 | python-version: ${{ matrix.python-version }} 25 | - name: Install dependencies 26 | run: | 27 | python -m pip install --upgrade pip 28 | pip install . 29 | pip install pytest 30 | - name: Test with pytest 31 | run: | 32 | pytest tests/ -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Project specific files 2 | messing_around/ 3 | tests/test_simulator.ipynb 4 | *.estimator 5 | *.h5 6 | .vscode/ 7 | .DS_Store 8 | #SBI related folders 9 | sbi-logs/ 10 | *.png 11 | *.pdf 12 | *.png 13 | 14 | # Training data, models and other datafiles 15 | results/ 16 | production/ 17 | *.estimator 18 | *epoch=* 19 | *.loss 20 | 21 | # Byte-compiled / optimized / DLL files 22 | __pycache__/ 23 | *.py[cod] 24 | *$py.class 25 | 26 | # C extensions 27 | *.so 28 | 29 | # Distribution / packaging 30 | .Python 31 | build/ 32 | develop-eggs/ 33 | dist/ 34 | downloads/ 35 | eggs/ 36 | .eggs/ 37 | lib/ 38 | lib64/ 39 | parts/ 40 | sdist/ 41 | var/ 42 | wheels/ 43 | pip-wheel-metadata/ 44 | share/python-wheels/ 45 | *.egg-info/ 46 | .installed.cfg 47 | *.egg 48 | MANIFEST 49 | 50 | # PyInstaller 51 | # Usually these files are written by a python script from a template 52 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
--------------------------------------------------------------------------------
/README.rst:
--------------------------------------------------------------------------------
================================================
cryoSBI - Simulation-based Inference for Cryo-EM
================================================

.. start-badges

.. list-table::
    :stub-columns: 1

    * - tests
      - | |githubactions|


.. |githubactions| image:: https://github.com/DSilva27/cryo_em_SBI/actions/workflows/python-package.yml/badge.svg?branch=main
    :alt: Testing Status
    :target: https://github.com/DSilva27/cryo_em_SBI/actions

Summary
-------
cryoSBI is a Python module for simulation-based inference in cryo-electron microscopy. The module provides tools for simulating cryo-EM particles, training an amortized posterior model, and sampling from the posterior distribution.
The code is based on the SBI library `Lampe `_, which uses PyTorch.

Installing
----------
To install the module, you will have to download the repository and create a virtual environment with the required dependencies.
You can create an environment, for example with conda, using the following command:

.. code:: bash

    conda create -n cryoSBI python=3.10

After creating the virtual environment, you should install the required dependencies and the module.

Dependencies
------------

1. `Lampe `_.
2. `SciPy `_.
3. `Numpy `_.
4. `PyTorch `_.
5. json (part of the Python standard library)
6. `mrcfile `_.

Download this repository
------------------------
.. code:: bash

    git clone https://github.com/flatironinstitute/cryoSBI.git

Navigate to the cloned repository and install the module
--------------------------------------------------------

.. code:: bash

    cd cryo_em_SBI

.. code:: bash

    pip install .

Tutorial
--------
An introductory tutorial can be found at `tutorials/tutorial.ipynb`. In this tutorial, we go through the whole process of making models for cryoSBI, training an amortized posterior, and analyzing the results.
The following sections highlight cryoSBI's key features.

Generate model file to simulate cryo-EM particles
-------------------------------------------------
To generate a model file for simulating cryo-EM particles with the simulator provided in this module, you can use the command line tool `models_to_tensor`.
You will need either a set of indexed pdb files or a .trr trajectory file which contains all models. The tool will generate a model file that can be used to simulate cryo-EM particles.

.. code:: bash

    models_to_tensor \
        --model_files path_to_models/pdb_{}.pdb \
        --output_file path_to_output_file/output.pt \
        --n_pdbs 100

The output file will be a PyTorch tensor with the shape (number of models, 3, number of pseudo atoms).

Simulating cryo-EM particles
-----------------------------
To simulate cryo-EM particles, you can use the CryoEmSimulator class. The class takes in a simulation config file and simulates cryo-EM particles based on the parameters specified in the config file.

.. code:: python

    from cryo_sbi import CryoEmSimulator
    simulator = CryoEmSimulator("path_to_simulation_config_file.json")
    images, parameters = simulator.simulate(num_sim=10, return_parameters=True)

The simulation config file should be a json file with the following structure:

.. code:: json

    {
        "N_PIXELS": 128,
        "PIXEL_SIZE": 1.5,
        "SIGMA": [0.5, 5.0],
        "MODEL_FILE": "path_to_models/models.pt",
        "SHIFT": 25.0,
        "DEFOCUS": [0.5, 2.0],
        "SNR": [0.001, 0.5],
        "AMP": 0.1,
        "B_FACTOR": [1.0, 100.0]
    }

The pixel size is defined in Ångström (Å). The atom sigma defines the size of the Gaussians used to approximate the protein's electron density. Here, each Gaussian represents one amino acid; all Gaussians share the same sigma within an image, but its value is varied across the simulations. The shift is the offset of the protein from the image centre and is given in Ångström (Å). The defocus of the microscope is given in units of micrometres (μm). The SNR (signal-to-noise ratio) is unitless and defines the amount of noise in the simulated images. The amplitude is a unitless parameter which ranges between 0 and 1. The B-factor is given in units of Ångström squared (Å^2) and defines the decay rate of the CTF envelope function.
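As a quick sanity check, you can display one of the simulated particles with `matplotlib`. This is a minimal sketch reusing the `simulator` and `images` from the example above; it assumes the images are returned as 2D tensors on the CPU.

.. code:: python

    import matplotlib.pyplot as plt

    plt.imshow(images[0], cmap="gray")  # first simulated particle
    plt.colorbar()
    plt.show()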
Training an amortized posterior model
--------------------------------------
Training of an amortized posterior can be done using the train_npe_model command line utility. The utility takes in an image config file, a train config file, and other training parameters. It trains a neural network to approximate the posterior distribution of the parameters given the images.

.. code:: bash

    train_npe_model \
        --image_config_file path_to_simulation_config_file.json \
        --train_config_file path_to_train_config_file.json \
        --epochs 150 \
        --estimator_file posterior.estimator \
        --loss_file posterior.loss \
        --n_workers 4 \
        --simulation_batch_size 5120 \
        --train_device cuda

The training config file should be a json file with the following structure:

.. code:: json

    {
        "EMBEDDING": "RESNET18",
        "OUT_DIM": 256,
        "NUM_TRANSFORM": 5,
        "NUM_HIDDEN_FLOW": 10,
        "HIDDEN_DIM_FLOW": 256,
        "MODEL": "NSF",
        "LEARNING_RATE": 0.0003,
        "CLIP_GRADIENT": 5.0,
        "THETA_SHIFT": 25,
        "THETA_SCALE": 25,
        "BATCH_SIZE": 256
    }

When training a posterior for your own system, it is important to change THETA_SHIFT and THETA_SCALE. These two parameters normalize the conformational variable in cryoSBI and need to be adjusted according to the number of structures used in the prior. A good option is to set both THETA_SHIFT and THETA_SCALE to the number of structures in the prior divided by two.

Loading the posterior after training
------------------------------------
After training the estimator, you can load it in Python with the `load_estimator` function in the `estimator_utils` module.

.. code:: python

    import cryo_sbi.utils.estimator_utils as est_utils
    posterior = est_utils.load_estimator(
        config_file_path="path_to_config_file",
        estimator_path="path_to_estimator_file",
        device="cuda"
    )

Inference
---------
Sampling from the posterior distribution can be done using the sample_posterior function in the estimator_utils module. The function takes in an estimator, images, and other parameters and returns samples from the posterior distribution.

.. code:: python

    import cryo_sbi.utils.estimator_utils as est_utils
    samples = est_utils.sample_posterior(
        estimator=posterior,
        images=images,
        num_samples=20000,
        batch_size=100,
        device="cuda",
    )

The PyTorch tensor containing the samples will have the shape (number of samples, number of images). To visualize the posterior for each image, you can use `matplotlib`.
We can quickly generate a histogram of the samples with the following piece of code.

.. code:: python

    import numpy as np
    import matplotlib.pyplot as plt
    idx_image = 0  # posterior for image with index 0
    plt.hist(samples[:, idx_image].flatten(), np.linspace(0, simulator.max_index, 50))

In this case, the x-axis is just the index of the structures in increasing order.
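Beyond histograms, simple summary statistics of the posterior can be read off the samples tensor directly, for example a per-image posterior mean and width. This is a minimal sketch using the `samples` tensor from `sample_posterior` above.

.. code:: python

    posterior_mean = samples.mean(dim=0)  # shape: (number of images,)
    posterior_std = samples.std(dim=0)    # posterior width per image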
Latent space
------------

Computing the latent features for simulated or experimental particles can be done using the compute_latent_repr function in the estimator_utils module. The function takes a trained posterior estimator and images and computes the latent representation for each image.

.. code:: python

    import cryo_sbi.utils.estimator_utils as est_utils
    latent_vecs = est_utils.compute_latent_repr(
        estimator=posterior,
        images=images,
        batch_size=100,
        device="cuda",
    )

After computing the latent representation of the images, one possible way to visualize the latent space is to use `UMAP `_.
UMAP generates a two-dimensional representation of the latent space, which should allow us to analyze its important features.

.. code:: python

    import umap
    reducer = umap.UMAP(metric="euclidean", n_components=2, n_neighbors=50)
    embedding = reducer.fit_transform(latent_vecs.numpy())

We can quickly visualize the 2D latent space with matplotlib.

.. code:: python

    import matplotlib.pyplot as plt
    plt.scatter(
        embedding[:, 0],
        embedding[:, 1],
    )
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
[build-system]
requires = ["setuptools", "setuptools-scm"]
build-backend = "setuptools.build_meta"


[project]
name = "cryosbi"
authors = [
    { name = "David Silva-Sanchez", email = "david.silva@yale.edu"},
    { name = "Lars Dingeldein"},
    { name = "Pilar Cossio"},
    { name = "Roberto Covino"}
]


version = "0.2"
dependencies = [
    "lampe",
    "zuko",
    "torch",
    "numpy",
    "matplotlib",
    "scipy",
    "torchvision",
    "mrcfile"
]


[project.scripts]
train_npe_model = "cryo_sbi.inference.command_line_tools:cl_npe_train_no_saving"
models_to_tensor = "cryo_sbi.utils.command_line_tools:cl_models_to_tensor"
--------------------------------------------------------------------------------
/src/cryo_sbi/__init__.py:
--------------------------------------------------------------------------------
from cryo_sbi.wpa_simulator.cryo_em_simulator import CryoEmSimulator
--------------------------------------------------------------------------------
/src/cryo_sbi/inference/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/src/cryo_sbi/inference/command_line_tools.py:
--------------------------------------------------------------------------------
import argparse
from cryo_sbi.inference.train_npe_model import (
    npe_train_no_saving,
)


def cl_npe_train_no_saving():
    cl_parser = argparse.ArgumentParser()

    cl_parser.add_argument(
        "--image_config_file", action="store", type=str, required=True
    )
    cl_parser.add_argument(
        "--train_config_file", action="store", type=str, required=True
    )
    cl_parser.add_argument("--epochs", action="store", type=int, required=True)
    cl_parser.add_argument("--estimator_file", action="store", type=str, required=True)
    cl_parser.add_argument("--loss_file", action="store", type=str, required=True)
    cl_parser.add_argument(
        "--train_from_checkpoint",
        action="store",
        type=bool,
        nargs="?",
        required=False,
        const=True,
        default=False,
    )
    cl_parser.add_argument(
        "--state_dict_file", action="store", type=str, required=False, default=False
    )
    cl_parser.add_argument(
        "--n_workers", action="store", type=int, required=False, default=1
    )
    cl_parser.add_argument(
        "--train_device", action="store", type=str, required=False, default="cpu"
    )
    cl_parser.add_argument(
        "--saving_freq", action="store", type=int, required=False, default=20
    )
    cl_parser.add_argument(
        "--simulation_batch_size",
action="store", 43 | type=int, 44 | required=False, 45 | default=1024, 46 | ) 47 | 48 | args = cl_parser.parse_args() 49 | 50 | npe_train_no_saving( 51 | image_config=args.image_config_file, 52 | train_config=args.train_config_file, 53 | epochs=args.epochs, 54 | estimator_file=args.estimator_file, 55 | loss_file=args.loss_file, 56 | train_from_checkpoint=args.train_from_checkpoint, 57 | model_state_dict=args.state_dict_file, 58 | n_workers=args.n_workers, 59 | device=args.train_device, 60 | saving_frequency=args.saving_freq, 61 | simulation_batch_size=args.simulation_batch_size, 62 | ) 63 | -------------------------------------------------------------------------------- /src/cryo_sbi/inference/models/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/cryo_sbi/inference/models/build_models.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from functools import partial 3 | import zuko 4 | import lampe 5 | import cryo_sbi.inference.models.estimator_models as estimator_models 6 | from cryo_sbi.inference.models.embedding_nets import EMBEDDING_NETS 7 | 8 | 9 | def build_npe_flow_model(config: dict, **embedding_kwargs) -> nn.Module: 10 | """ 11 | Function to build NPE estimator with embedding net 12 | from config_file 13 | 14 | Args: 15 | config (dict): config file 16 | embedding_kwargs (dict): kwargs for embedding net 17 | 18 | Returns: 19 | estimator (nn.Module): NPE estimator 20 | """ 21 | 22 | if config["MODEL"] == "MAF": 23 | model = zuko.flows.MAF 24 | elif config["MODEL"] == "NSF": 25 | model = zuko.flows.NSF 26 | elif config["MODEL"] == "SOSPF": 27 | model = zuko.flows.SOSPF 28 | else: 29 | raise NotImplementedError( 30 | f"Model : {config['MODEL']} has not been implemented yet!" 31 | ) 32 | 33 | try: 34 | embedding = partial( 35 | EMBEDDING_NETS[config["EMBEDDING"]], config["OUT_DIM"], **embedding_kwargs 36 | ) 37 | except KeyError: 38 | raise NotImplementedError( 39 | f"Model : {config['EMBEDDING']} has not been implemented yet! 
\ 40 | The following embeddings are implemented : {[key for key in EMBEDDING_NETS.keys()]}" 41 | ) 42 | 43 | estimator = estimator_models.NPEWithEmbedding( 44 | embedding_net=embedding, 45 | output_embedding_dim=config["OUT_DIM"], 46 | num_transforms=config["NUM_TRANSFORM"], 47 | num_hidden_flow=config["NUM_HIDDEN_FLOW"], 48 | hidden_flow_dim=config["HIDDEN_DIM_FLOW"], 49 | flow=model, 50 | theta_shift=config["THETA_SHIFT"], 51 | theta_scale=config["THETA_SCALE"], 52 | **{"activation": partial(nn.LeakyReLU, 0.1)}, 53 | ) 54 | 55 | return estimator 56 | 57 | 58 | def build_nre_classifier_model(config: dict, **embedding_kwargs) -> nn.Module: 59 | raise NotImplementedError("NRE classifier model has not been implemented yet!") 60 | -------------------------------------------------------------------------------- /src/cryo_sbi/inference/models/embedding_nets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models as models 4 | import torchvision.transforms as transforms 5 | 6 | from cryo_sbi.utils.image_utils import LowPassFilter, Mask 7 | 8 | 9 | EMBEDDING_NETS = {} 10 | 11 | 12 | def add_embedding(name): 13 | """ 14 | Add embedding net to EMBEDDING_NETS dict 15 | 16 | Args: 17 | name (str): name of embedding net 18 | 19 | Returns: 20 | add (function): function to add embedding net to EMBEDDING_NETS dict 21 | """ 22 | 23 | def add(class_): 24 | EMBEDDING_NETS[name] = class_ 25 | return class_ 26 | 27 | return add 28 | 29 | 30 | @add_embedding("RESNET18") 31 | class ResNet18_Encoder(nn.Module): 32 | def __init__(self, output_dimension: int): 33 | super(ResNet18_Encoder, self).__init__() 34 | self.resnet = models.resnet18() 35 | self.resnet.conv1 = nn.Conv2d( 36 | 1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False 37 | ) 38 | self.resnet.fc = nn.Linear( 39 | in_features=512, out_features=output_dimension, bias=True 40 | ) 41 | 42 | def forward(self, x): 43 | x = x.unsqueeze(1) 44 | x = self.resnet(x) 45 | return x 46 | 47 | 48 | @add_embedding("RESNET50") 49 | class ResNet50_Encoder(nn.Module): 50 | def __init__(self, output_dimension: int): 51 | super(ResNet50_Encoder, self).__init__() 52 | 53 | self.resnet = models.resnet50() 54 | self.resnet.conv1 = nn.Conv2d( 55 | 1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False 56 | ) 57 | self.linear = nn.Linear(1000, output_dimension) 58 | 59 | def forward(self, x): 60 | x = x.unsqueeze(1) 61 | x = self.resnet(x) 62 | x = self.linear(nn.functional.relu(x)) 63 | return x 64 | 65 | 66 | @add_embedding("RESNET101") 67 | class ResNet101_Encoder(nn.Module): 68 | def __init__(self, output_dimension: int): 69 | super(ResNet101_Encoder, self).__init__() 70 | 71 | self.resnet = models.resnet101() 72 | self.resnet.conv1 = nn.Conv2d( 73 | 1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False 74 | ) 75 | self.linear = nn.Linear(1000, output_dimension) 76 | 77 | def forward(self, x): 78 | x = x.unsqueeze(1) 79 | x = self.resnet(x) 80 | x = self.linear(nn.functional.relu(x)) 81 | return x 82 | 83 | 84 | @add_embedding("CONVNET") 85 | class ConvNet_Encoder(nn.Module): 86 | def __init__(self, output_dimension: int): 87 | super(ConvNet_Encoder, self).__init__() 88 | 89 | self.convnet = models.convnext_tiny() 90 | self.convnet.features[0][0] = nn.Conv2d( 91 | 1, 96, kernel_size=(4, 4), stride=(4, 4) 92 | ) 93 | self.convnet.classifier[2] = nn.Linear( 94 | in_features=768, out_features=output_dimension, bias=True 95 | ) 96 

    def forward(self, x):
        x = x.unsqueeze(1)
        x = self.convnet(x)
        return x


@add_embedding("REGNETX")
class RegNetX_Encoder(nn.Module):
    def __init__(self, output_dimension: int):
        super(RegNetX_Encoder, self).__init__()

        self.regnetx = models.regnet_x_3_2gf()
        self.regnetx.stem[0] = nn.Conv2d(
            1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False
        )
        self.regnetx.fc = nn.Linear(
            in_features=1008, out_features=output_dimension, bias=True
        )

    def forward(self, x):
        x = x.unsqueeze(1)
        x = self.regnetx(x)
        return x


@add_embedding("EFFICIENT")
class EfficientNet_Encoder(nn.Module):
    def __init__(self, output_dimension: int):
        super(EfficientNet_Encoder, self).__init__()

        self.efficient_net = models.efficientnet_b3().features
        self.efficient_net[0][0] = nn.Conv2d(
            1, 40, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False
        )
        self.avg_pool = nn.AdaptiveAvgPool2d(output_size=1)
        self.leakyrelu = nn.LeakyReLU()
        self.linear = nn.Linear(1536, output_dimension)

    def forward(self, x):
        x = x.unsqueeze(1)
        x = self.efficient_net(x)
        x = self.avg_pool(x).flatten(start_dim=1)
        x = self.leakyrelu(self.linear(x))
        return x


@add_embedding("SWINS")
class SwinTransformerS_Encoder(nn.Module):
    def __init__(self, output_dimension: int):
        super(SwinTransformerS_Encoder, self).__init__()

        self.swin_transformer = models.swin_t()
        self.swin_transformer.features[0][0] = nn.Conv2d(
            1, 96, kernel_size=(4, 4), stride=(4, 4)
        )
        self.swin_transformer.head = nn.Linear(
            in_features=768, out_features=output_dimension, bias=True
        )

    def forward(self, x):
        x = x.unsqueeze(1)
        x = self.swin_transformer(x)
        return x


@add_embedding("WIDERES50")
class WideResnet50_Encoder(nn.Module):
    def __init__(self, output_dimension: int):
        super(WideResnet50_Encoder, self).__init__()

        self.wideresnet = models.wide_resnet50_2()
        self.wideresnet.conv1 = nn.Conv2d(
            1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
        )
        self.linear = nn.Linear(1000, output_dimension)

    def forward(self, x):
        x = x.unsqueeze(1)
        x = self.wideresnet(x)
        x = self.linear(nn.functional.relu(x))
        return x


@add_embedding("WIDERES101")
class WideResnet101_Encoder(nn.Module):
    def __init__(self, output_dimension: int):
        super(WideResnet101_Encoder, self).__init__()

        self.wideresnet = models.wide_resnet101_2()
        self.wideresnet.conv1 = nn.Conv2d(
            1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
        )
        self.linear = nn.Linear(1000, output_dimension)

    def forward(self, x):
        x = x.unsqueeze(1)
        x = self.wideresnet(x)
        x = self.linear(nn.functional.relu(x))
        return x


@add_embedding("REGNETY")
class RegNetY_Encoder(nn.Module):
    def __init__(self, output_dimension: int):
        super(RegNetY_Encoder, self).__init__()

        self.regnety = models.regnet_y_1_6gf()
        self.regnety.stem[0] = nn.Conv2d(
            1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False
        )
        self.regnety.fc = nn.Linear(
            in_features=888,
out_features=output_dimension, bias=True 209 | ) 210 | 211 | def forward(self, x): 212 | x = x.unsqueeze(1) 213 | x = self.regnety(x) 214 | return x 215 | 216 | 217 | @add_embedding("SHUFFLENET") 218 | class ShuffleNet_Encoder(nn.Module): 219 | def __init__(self, output_dimension: int): 220 | super(ShuffleNet_Encoder, self).__init__() 221 | 222 | self.shuffle_net = models.shufflenet_v2_x0_5() 223 | self.shuffle_net.conv1[0] = nn.Conv2d( 224 | 1, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False 225 | ) 226 | self.shuffle_net.fc = nn.Linear( 227 | in_features=1024, out_features=output_dimension, bias=True 228 | ) 229 | 230 | def forward(self, x): 231 | x = x.unsqueeze(1) 232 | x = self.shuffle_net(x) 233 | return x 234 | 235 | 236 | @add_embedding("RESNET18_FFT_FILTER") 237 | class ResNet18_FFT_Encoder(nn.Module): 238 | def __init__(self, output_dimension: int): 239 | super(ResNet18_FFT_Encoder, self).__init__() 240 | self.resnet = models.resnet18() 241 | self.resnet.conv1 = nn.Conv2d( 242 | 1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False 243 | ) 244 | self.resnet.fc = nn.Linear( 245 | in_features=512, out_features=output_dimension, bias=True 246 | ) 247 | 248 | self._fft_filter = LowPassFilter(128, 25) 249 | 250 | def forward(self, x): 251 | # Low pass filter images 252 | x = self._fft_filter(x) 253 | # Proceed as normal 254 | x = x.unsqueeze(1) 255 | x = self.resnet(x) 256 | return x 257 | 258 | 259 | @add_embedding("RESNET18_FFT_FILTER_132") 260 | class ResNet18_FFT_Encoder_132(nn.Module): 261 | def __init__(self, output_dimension: int): 262 | super(ResNet18_FFT_Encoder_132, self).__init__() 263 | self.resnet = models.resnet18() 264 | self.resnet.conv1 = nn.Conv2d( 265 | 1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False 266 | ) 267 | self.resnet.fc = nn.Linear( 268 | in_features=512, out_features=output_dimension, bias=True 269 | ) 270 | 271 | self._fft_filter = LowPassFilter(132, 25) 272 | 273 | def forward(self, x): 274 | # Low pass filter images 275 | x = self._fft_filter(x) 276 | # Proceed as normal 277 | x = x.unsqueeze(1) 278 | x = self.resnet(x) 279 | return x 280 | 281 | 282 | @add_embedding("RESNET18_FFT_FILTER_224") 283 | class ResNet18_FFT_Encoder_224(nn.Module): 284 | def __init__(self, output_dimension: int): 285 | super(ResNet18_FFT_Encoder_224, self).__init__() 286 | self.resnet = models.resnet18() 287 | self.resnet.conv1 = nn.Conv2d( 288 | 1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False 289 | ) 290 | self.resnet.fc = nn.Linear( 291 | in_features=512, out_features=output_dimension, bias=True 292 | ) 293 | 294 | self._fft_filter = LowPassFilter(224, 25) 295 | 296 | def forward(self, x): 297 | # Low pass filter images 298 | x = self._fft_filter(x) 299 | # Proceed as normal 300 | x = x.unsqueeze(1) 301 | x = self.resnet(x) 302 | return x 303 | 304 | 305 | @add_embedding("RESNET18_FFT_FILTER_256") 306 | class ResNet18_FFT_Encoder_256(nn.Module): 307 | def __init__(self, output_dimension: int): 308 | super(ResNet18_FFT_Encoder_256, self).__init__() 309 | self.resnet = models.resnet18() 310 | self.resnet.conv1 = nn.Conv2d( 311 | 1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False 312 | ) 313 | self.resnet.fc = nn.Linear( 314 | in_features=512, out_features=output_dimension, bias=True 315 | ) 316 | 317 | self._fft_filter = LowPassFilter(256, 10) 318 | 319 | def forward(self, x): 320 | # Low pass filter images 321 | x = self._fft_filter(x) 322 | # Proceed as normal 323 | x = 
x.unsqueeze(1) 324 | x = self.resnet(x) 325 | return x 326 | 327 | 328 | @add_embedding("RESNET34") 329 | class ResNet34_Encoder(nn.Module): 330 | def __init__(self, output_dimension: int): 331 | super(ResNet34_Encoder, self).__init__() 332 | self.resnet = models.resnet34() 333 | self.resnet.conv1 = nn.Conv2d( 334 | 1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False 335 | ) 336 | self.resnet.fc = nn.Linear( 337 | in_features=512, out_features=output_dimension, bias=True 338 | ) 339 | 340 | def forward(self, x): 341 | x = x.unsqueeze(1) 342 | x = self.resnet(x) 343 | return x 344 | 345 | 346 | @add_embedding("RESNET34_FFT_FILTER_256") 347 | class ResNet34_Encoder_FFT_FILTER_256(nn.Module): 348 | def __init__(self, output_dimension: int): 349 | super(ResNet34_Encoder_FFT_FILTER_256, self).__init__() 350 | self.resnet = models.resnet34() 351 | self.resnet.conv1 = nn.Conv2d( 352 | 1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False 353 | ) 354 | self.resnet.fc = nn.Linear( 355 | in_features=512, out_features=output_dimension, bias=True 356 | ) 357 | self._fft_filter = LowPassFilter(256, 50) 358 | 359 | def forward(self, x): 360 | # Low pass filter images 361 | x = self._fft_filter(x) 362 | # Proceed as normal 363 | x = x.unsqueeze(1) 364 | x = self.resnet(x) 365 | return x 366 | 367 | 368 | @add_embedding("VGG19") 369 | class VGG19_Encoder(nn.Module): 370 | def __init__(self, output_dimension: int): 371 | super(VGG19_Encoder, self).__init__() 372 | 373 | self.vgg19 = models.vgg19_bn().features 374 | self.vgg19[0] = nn.Conv2d( 375 | 1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1) 376 | ) 377 | 378 | self.avgpool = nn.AdaptiveAvgPool2d(output_size=(7, 7)) 379 | 380 | self.feedforward = nn.Sequential( 381 | *[ 382 | nn.Linear(in_features=25088, out_features=4096), 383 | nn.ReLU(inplace=True), 384 | nn.Linear(in_features=4096, out_features=output_dimension, bias=True), 385 | nn.ReLU(inplace=True), 386 | ] 387 | ) 388 | 389 | def forward(self, x): 390 | # Low pass filter images 391 | # x = self._fft_filter(x) 392 | # Proceed as normal 393 | x = x.unsqueeze(1) 394 | x = self.vgg19(x) 395 | x = self.avgpool(x).flatten(start_dim=1) 396 | x = self.feedforward(x) 397 | return x 398 | 399 | 400 | @add_embedding("ConvEncoder_Tutorial") 401 | class ConvEncoder(nn.Module): 402 | def __init__(self, output_dimension: int): 403 | super(ConvEncoder, self).__init__() 404 | ndf = 16 # fixed for the tutorial 405 | self.main = nn.Sequential( 406 | # input is 1 x 64 x 64 407 | nn.Conv2d(1, ndf, 4, 2, 1, bias=False), 408 | nn.LeakyReLU(0.2, inplace=True), 409 | # state size. (ndf) x 32 x 32 410 | nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False), 411 | # nn.BatchNorm2d(ndf * 2), 412 | nn.LeakyReLU(0.2, inplace=True), 413 | # state size. (ndf*2) x 16 x 16 414 | nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False), 415 | # nn.BatchNorm2d(ndf * 4), 416 | nn.LeakyReLU(0.2, inplace=True), 417 | # state size. (ndf*4) x 8 x 8 418 | nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False), 419 | # nn.BatchNorm2d(ndf * 8), 420 | nn.LeakyReLU(0.2, inplace=True), 421 | # state size. (ndf*8) x 4 x 4 422 | nn.Conv2d(ndf * 8, output_dimension, 4, 1, 0, bias=False), 423 | # state size. 
out_dims x 1 x 1 424 | ) 425 | 426 | def forward(self, x): 427 | x = x.view(-1, 1, 64, 64) 428 | x = self.main(x) 429 | return x.view(x.size(0), -1) # flatten 430 | 431 | 432 | if __name__ == "__main__": 433 | pass 434 | -------------------------------------------------------------------------------- /src/cryo_sbi/inference/models/estimator_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import zuko 4 | from lampe.inference import NPE, NRE 5 | 6 | 7 | class Standardize(nn.Module): 8 | """ 9 | Module to standardize inputs and retransform them to the original space 10 | 11 | Args: 12 | mean (torch.Tensor): mean of the data 13 | std (torch.Tensor): standard deviation of the data 14 | 15 | Returns: 16 | standardized (torch.Tensor): standardized data 17 | """ 18 | 19 | # Code adapted from :https://github.com/mackelab/sbi/blob/main/sbi/utils/sbiutils.py 20 | def __init__(self, mean: float, std: float) -> None: 21 | super(Standardize, self).__init__() 22 | mean, std = map(torch.as_tensor, (mean, std)) 23 | self.mean = mean 24 | self.std = std 25 | self.register_buffer("_mean", mean) 26 | self.register_buffer("_std", std) 27 | 28 | def forward(self, tensor: torch.Tensor) -> torch.Tensor: 29 | """ 30 | Standardize the input tensor 31 | 32 | Args: 33 | tensor (torch.Tensor): input tensor 34 | 35 | Returns: 36 | standardized (torch.Tensor): standardized tensor 37 | """ 38 | 39 | return (tensor - self._mean) / self._std 40 | 41 | def transform(self, tensor: torch.Tensor) -> torch.Tensor: 42 | """ 43 | Transform the standardized tensor back to the original space 44 | 45 | Args: 46 | tensor (torch.Tensor): input tensor 47 | 48 | Returns: 49 | retransformed (torch.Tensor): retransformed tensor 50 | """ 51 | 52 | return (tensor * self._std) + self._mean 53 | 54 | 55 | class NPEWithEmbedding(nn.Module): 56 | """Neural Posterior Estimation with embedding net 57 | 58 | Attributes: 59 | npe (NPE): NPE model 60 | embedding (nn.Module): embedding net 61 | standardize (Standardize): standardization module 62 | """ 63 | 64 | def __init__( 65 | self, 66 | embedding_net: nn.Module, 67 | output_embedding_dim: int, 68 | num_transforms: int = 4, 69 | num_hidden_flow: int = 2, 70 | hidden_flow_dim: int = 128, 71 | flow: nn.Module = zuko.flows.MAF, 72 | theta_shift: float = 0.0, 73 | theta_scale: float = 1.0, 74 | **kwargs, 75 | ) -> None: 76 | """ 77 | Neural Posterior Estimation with embedding net. 78 | 79 | Args: 80 | embedding_net (nn.Module): embedding net 81 | output_embedding_dim (int): output embedding dimension 82 | num_transforms (int, optional): number of transforms. Defaults to 4. 83 | num_hidden_flow (int, optional): number of hidden layers in flow. Defaults to 2. 84 | hidden_flow_dim (int, optional): hidden dimension in flow. Defaults to 128. 85 | flow (nn.Module, optional): flow. Defaults to zuko.flows.MAF. 86 | theta_shift (float, optional): Shift of the theta for standardization. Defaults to 0.0. 87 | theta_scale (float, optional): Scale of the theta for standardization. Defaults to 1.0. 
88 | kwargs: additional arguments for the flow 89 | 90 | Returns: 91 | None 92 | """ 93 | 94 | super().__init__() 95 | 96 | self.npe = NPE( 97 | 1, 98 | output_embedding_dim, 99 | transforms=num_transforms, 100 | build=flow, 101 | hidden_features=[*[hidden_flow_dim] * num_hidden_flow, 128, 64], 102 | **kwargs, 103 | ) 104 | 105 | self.embedding = embedding_net() 106 | self.standardize = Standardize(theta_shift, theta_scale) 107 | 108 | def forward(self, theta: torch.Tensor, x: torch.Tensor) -> torch.Tensor: 109 | """ 110 | Forward pass of the NPE model 111 | 112 | Args: 113 | theta (torch.Tensor): Conformational parameters. 114 | x (torch.Tensor): Image to condition the posterior on. 115 | 116 | Returns: 117 | torch.Tensor: Log probability of the posterior. 118 | """ 119 | 120 | return self.npe(self.standardize(theta), self.embedding(x)) 121 | 122 | def flow(self, x: torch.Tensor): 123 | """ 124 | Conditions the posterior on an image. 125 | 126 | Args: 127 | x (torch.Tensor): Image to condition the posterior on. 128 | 129 | Returns: 130 | zuko.flows.Flow: The posterior distribution. 131 | """ 132 | return self.npe.flow(self.embedding(x)) 133 | 134 | def sample(self, x: torch.Tensor, shape=(1,)) -> torch.Tensor: 135 | """ 136 | Generate samples from the posterior distribution. 137 | 138 | Args: 139 | x (torch.Tensor): Image to condition the posterior on. 140 | shape (tuple, optional): Shape of the samples. Defaults to (1,). 141 | 142 | Returns: 143 | torch.Tensor: Samples from the posterior distribution. 144 | """ 145 | 146 | samples_standardized = self.flow(x).sample(shape) 147 | return self.standardize.transform(samples_standardized) 148 | -------------------------------------------------------------------------------- /src/cryo_sbi/inference/priors.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import zuko 3 | from torch.distributions.distribution import Distribution 4 | from torch.utils.data import DataLoader, Dataset, IterableDataset 5 | 6 | 7 | def gen_quat() -> torch.Tensor: 8 | """ 9 | Generate a random quaternion. 10 | 11 | Returns: 12 | quat (np.ndarray): Random quaternion 13 | 14 | """ 15 | count = 0 16 | while count < 1: 17 | quat = 2 * torch.rand(size=(4,)) - 1 18 | norm = torch.sqrt(torch.sum(quat**2)) 19 | if 0.2 <= norm <= 1.0: 20 | quat /= norm 21 | count += 1 22 | 23 | return quat 24 | 25 | 26 | def get_image_priors( 27 | max_index, image_config: dict, device="cuda" 28 | ) -> zuko.distributions.BoxUniform: 29 | """ 30 | Return uniform prior in 1d from 0 to 19 31 | 32 | Args: 33 | max_index (int): max index of the 1d prior 34 | 35 | Returns: 36 | zuko.distributions.BoxUniform: prior 37 | """ 38 | if isinstance(image_config["SIGMA"], list) and len(image_config["SIGMA"]) == 2: 39 | lower = torch.tensor( 40 | [[image_config["SIGMA"][0]]], dtype=torch.float32, device=device 41 | ) 42 | upper = torch.tensor( 43 | [[image_config["SIGMA"][1]]], dtype=torch.float32, device=device 44 | ) 45 | 46 | assert lower <= upper, "Lower bound must be smaller or equal than upper bound." 
47 | 48 | sigma_prior = zuko.distributions.BoxUniform(lower=lower, upper=upper, ndims=1) 49 | 50 | shift_prior = zuko.distributions.BoxUniform( 51 | lower=torch.tensor( 52 | [-image_config["SHIFT"], -image_config["SHIFT"]], 53 | dtype=torch.float32, 54 | device=device, 55 | ), 56 | upper=torch.tensor( 57 | [image_config["SHIFT"], image_config["SHIFT"]], 58 | dtype=torch.float32, 59 | device=device, 60 | ), 61 | ndims=1, 62 | ) 63 | 64 | if isinstance(image_config["DEFOCUS"], list) and len(image_config["DEFOCUS"]) == 2: 65 | lower = torch.tensor( 66 | [[image_config["DEFOCUS"][0]]], dtype=torch.float32, device=device 67 | ) 68 | upper = torch.tensor( 69 | [[image_config["DEFOCUS"][1]]], dtype=torch.float32, device=device 70 | ) 71 | 72 | assert lower > 0.0, "The lower bound for DEFOCUS must be positive." 73 | assert lower <= upper, "Lower bound must be smaller or equal than upper bound." 74 | 75 | defocus_prior = zuko.distributions.BoxUniform(lower=lower, upper=upper, ndims=1) 76 | 77 | if ( 78 | isinstance(image_config["B_FACTOR"], list) 79 | and len(image_config["B_FACTOR"]) == 2 80 | ): 81 | lower = torch.tensor( 82 | [[image_config["B_FACTOR"][0]]], dtype=torch.float32, device=device 83 | ) 84 | upper = torch.tensor( 85 | [[image_config["B_FACTOR"][1]]], dtype=torch.float32, device=device 86 | ) 87 | 88 | assert lower > 0.0, "The lower bound for B_FACTOR must be positive." 89 | assert lower <= upper, "Lower bound must be smaller or equal than upper bound." 90 | 91 | b_factor_prior = zuko.distributions.BoxUniform( 92 | lower=lower, upper=upper, ndims=1 93 | ) 94 | 95 | if isinstance(image_config["SNR"], list) and len(image_config["SNR"]) == 2: 96 | lower = torch.tensor( 97 | [[image_config["SNR"][0]]], dtype=torch.float32, device=device 98 | ).log10() 99 | upper = torch.tensor( 100 | [[image_config["SNR"][1]]], dtype=torch.float32, device=device 101 | ).log10() 102 | 103 | assert lower <= upper, "Lower bound must be smaller or equal than upper bound." 
104 | 105 | snr_prior = zuko.distributions.BoxUniform(lower=lower, upper=upper, ndims=1) 106 | 107 | amp_prior = zuko.distributions.BoxUniform( 108 | lower=torch.tensor([[image_config["AMP"]]], dtype=torch.float32, device=device), 109 | upper=torch.tensor([[image_config["AMP"]]], dtype=torch.float32, device=device), 110 | ndims=1, 111 | ) 112 | 113 | index_prior = zuko.distributions.BoxUniform( 114 | lower=torch.tensor([0], dtype=torch.float32, device=device), 115 | upper=torch.tensor([max_index], dtype=torch.float32, device=device), 116 | ) 117 | quaternion_prior = QuaternionPrior(device) 118 | if ( 119 | image_config.get("ROTATIONS") 120 | and isinstance(image_config["ROTATIONS"], list) 121 | and len(image_config["ROTATIONS"]) == 4 122 | ): 123 | test_quat = image_config["ROTATIONS"] 124 | quaternion_prior = QuaternionTestPrior(test_quat, device) 125 | 126 | return ImagePrior( 127 | index_prior, 128 | quaternion_prior, 129 | sigma_prior, 130 | shift_prior, 131 | defocus_prior, 132 | b_factor_prior, 133 | amp_prior, 134 | snr_prior, 135 | device=device, 136 | ) 137 | 138 | 139 | class QuaternionPrior: 140 | def __init__(self, device) -> None: 141 | self.device = device 142 | 143 | def sample(self, shape) -> torch.Tensor: 144 | quats = torch.stack( 145 | [gen_quat().to(self.device) for _ in range(shape[0])], dim=0 146 | ) 147 | return quats 148 | 149 | 150 | class QuaternionTestPrior: 151 | def __init__(self, quat, device) -> None: 152 | self.device = device 153 | self.quat = torch.tensor(quat, device=device) 154 | 155 | def sample(self, shape) -> torch.Tensor: 156 | quats = torch.stack([self.quat for _ in range(shape[0])], dim=0) 157 | return quats 158 | 159 | 160 | class ImagePrior: 161 | def __init__( 162 | self, 163 | index_prior, 164 | quaternion_prior, 165 | sigma_prior, 166 | shift_prior, 167 | defocus_prior, 168 | b_factor_prior, 169 | amp_prior, 170 | snr_prior, 171 | device, 172 | ) -> None: 173 | self.priors = [ 174 | index_prior, 175 | quaternion_prior, 176 | sigma_prior, 177 | shift_prior, 178 | defocus_prior, 179 | b_factor_prior, 180 | amp_prior, 181 | snr_prior, 182 | ] 183 | 184 | def sample(self, shape) -> torch.Tensor: 185 | samples = [prior.sample(shape) for prior in self.priors] 186 | return samples 187 | 188 | 189 | class PriorDataset(IterableDataset): 190 | def __init__( 191 | self, 192 | prior: Distribution, 193 | batch_shape: torch.Size = (), 194 | ): 195 | super().__init__() 196 | 197 | self.prior = prior 198 | self.batch_shape = batch_shape 199 | 200 | def __iter__(self): 201 | while True: 202 | theta = self.prior.sample(self.batch_shape) 203 | yield theta 204 | 205 | 206 | class PriorLoader(DataLoader): 207 | def __init__( 208 | self, 209 | prior: Distribution, 210 | batch_size: int = 2**8, # 256 211 | **kwargs, 212 | ): 213 | super().__init__( 214 | PriorDataset(prior, batch_shape=(batch_size,)), 215 | batch_size=None, 216 | **kwargs, 217 | ) 218 | -------------------------------------------------------------------------------- /src/cryo_sbi/inference/train_npe_model.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | import json 3 | import torch 4 | import numpy as np 5 | import torch.optim as optim 6 | from torch.utils.data import TensorDataset 7 | from torchvision import transforms 8 | from tqdm import tqdm 9 | from lampe.data import JointLoader, H5Dataset 10 | from lampe.inference import NPELoss 11 | from lampe.utils import GDStep 12 | from itertools import islice 13 | 14 | from 
cryo_sbi.inference.priors import get_image_priors, PriorLoader
from cryo_sbi.inference.models.build_models import build_npe_flow_model
from cryo_sbi.inference.validate_train_config import check_train_params
from cryo_sbi.wpa_simulator.cryo_em_simulator import cryo_em_simulator
from cryo_sbi.wpa_simulator.validate_image_config import check_image_params
import cryo_sbi.utils.image_utils as img_utils


def load_model(
    train_config: str, model_state_dict: str, device: str, train_from_checkpoint: bool
) -> torch.nn.Module:
    """
    Load model from checkpoint or from scratch.

    Args:
        train_config (str): path to train config file
        model_state_dict (str): path to model state dict
        device (str): device to load model to
        train_from_checkpoint (bool): whether to load model from checkpoint or from scratch
    """

    check_train_params(train_config)
    estimator = build_npe_flow_model(train_config)
    if train_from_checkpoint:
        if not isinstance(model_state_dict, str):
            raise Warning("No model state dict specified! --model_state_dict is empty")
        print(f"Loading model parameters from {model_state_dict}")
        estimator.load_state_dict(torch.load(model_state_dict))
    estimator.to(device=device)
    return estimator


def npe_train_no_saving(
    image_config: str,
    train_config: str,
    epochs: int,
    estimator_file: str,
    loss_file: str,
    train_from_checkpoint: bool = False,
    model_state_dict: Union[str, None] = None,
    n_workers: int = 1,
    device: str = "cpu",
    saving_frequency: int = 20,
    simulation_batch_size: int = 1024,
) -> None:
    """
    Train NPE model by simulating training data on the fly.
    Saves model and loss to disk.

    Args:
        image_config (str): path to image config file
        train_config (str): path to train config file
        epochs (int): number of epochs
        estimator_file (str): path to estimator file
        loss_file (str): path to loss file
        train_from_checkpoint (bool, optional): train from checkpoint. Defaults to False.
        model_state_dict (str, optional): path to pretrained model state dict. Defaults to None.
        n_workers (int, optional): number of workers. Defaults to 1.
        device (str, optional): training device. Defaults to "cpu".
        saving_frequency (int, optional): frequency of saving model. Defaults to 20.
        simulation_batch_size (int, optional): batch size for the on-the-fly simulations. Defaults to 1024.

    Raises:
        Warning: No model state dict specified! --model_state_dict is empty

    Returns:
        None
    """

    train_config = json.load(open(train_config))
    check_train_params(train_config)
    image_config = json.load(open(image_config))

    assert simulation_batch_size >= train_config["BATCH_SIZE"]
    assert simulation_batch_size % train_config["BATCH_SIZE"] == 0

    if image_config["MODEL_FILE"].endswith("npy"):
        models = (
            torch.from_numpy(
                np.load(image_config["MODEL_FILE"]),
            )
            .to(device)
            .to(torch.float32)
        )
    else:
        models = torch.load(image_config["MODEL_FILE"]).to(device).to(torch.float32)

    image_prior = get_image_priors(len(models) - 1, image_config, device="cpu")
    prior_loader = PriorLoader(
        image_prior, batch_size=simulation_batch_size, num_workers=n_workers
    )

    num_pixels = torch.tensor(
        image_config["N_PIXELS"], dtype=torch.float32, device=device
    )
    pixel_size = torch.tensor(
        image_config["PIXEL_SIZE"], dtype=torch.float32, device=device
    )

    estimator = load_model(
        train_config, model_state_dict, device, train_from_checkpoint
    )

    loss = NPELoss(estimator)
    optimizer = optim.AdamW(
        estimator.parameters(), lr=train_config["LEARNING_RATE"], weight_decay=0.001
    )
    step = GDStep(optimizer, clip=train_config["CLIP_GRADIENT"])
    mean_loss = []

    print("Training neural network:")
    estimator.train()
    with tqdm(range(epochs), unit="epoch") as tq:
        for epoch in tq:
            losses = []
            for parameters in islice(prior_loader, 100):
                (
                    indices,
                    quaternions,
                    res,
                    shift,
                    defocus,
                    b_factor,
                    amp,
                    snr,
                ) = parameters
                images = cryo_em_simulator(
                    models,
                    indices.to(device, non_blocking=True),
                    quaternions.to(device, non_blocking=True),
                    res.to(device, non_blocking=True),
                    shift.to(device, non_blocking=True),
                    defocus.to(device, non_blocking=True),
                    b_factor.to(device, non_blocking=True),
                    amp.to(device, non_blocking=True),
                    snr.to(device, non_blocking=True),
                    num_pixels,
                    pixel_size,
                )
                for _indices, _images in zip(
                    indices.split(train_config["BATCH_SIZE"]),
                    images.split(train_config["BATCH_SIZE"]),
                ):
                    losses.append(
                        step(
                            loss(
                                _indices.to(device, non_blocking=True),
                                _images.to(device, non_blocking=True),
                            )
                        )
                    )
            losses = torch.stack(losses)

            tq.set_postfix(loss=losses.mean().item())
            mean_loss.append(losses.mean().item())
            if epoch % saving_frequency == 0:
                torch.save(estimator.state_dict(), estimator_file + f"_epoch={epoch}")

    torch.save(estimator.state_dict(), estimator_file)
    torch.save(torch.tensor(mean_loss), loss_file)
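A usage sketch for calling the training entry point above directly from Python instead of through the `train_npe_model` command line tool; the config paths and output file names here are hypothetical.

    from cryo_sbi.inference.train_npe_model import npe_train_no_saving

    npe_train_no_saving(
        image_config="simulation_parameters.json",  # hypothetical path
        train_config="training_parameters.json",    # hypothetical path
        epochs=150,
        estimator_file="posterior.estimator",
        loss_file="posterior.loss",
        device="cuda",
        simulation_batch_size=5120,
    )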
--------------------------------------------------------------------------------
/src/cryo_sbi/inference/validate_train_config.py:
--------------------------------------------------------------------------------
def check_train_params(config: dict) -> None:
    """
    Checks if all necessary parameters are provided.

    Args:
        config (dict): Dictionary containing training parameters.

    Returns:
        None
    """

    needed_keys = [
        "EMBEDDING",
        "OUT_DIM",
        "NUM_TRANSFORM",
        "NUM_HIDDEN_FLOW",
        "HIDDEN_DIM_FLOW",
        "MODEL",
        "LEARNING_RATE",
        "CLIP_GRADIENT",
        "BATCH_SIZE",
        "THETA_SHIFT",
        "THETA_SCALE",
    ]

    for key in needed_keys:
        assert key in config.keys(), f"Please provide a value for {key}"

    return
--------------------------------------------------------------------------------
/src/cryo_sbi/utils/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/src/cryo_sbi/utils/command_line_tools.py:
--------------------------------------------------------------------------------
import argparse
from cryo_sbi.utils.generate_models import models_to_tensor


def cl_models_to_tensor():
    cl_parser = argparse.ArgumentParser(
        description="Convert models to tensor for cryoSBI",
        epilog="pdb-files: The name for the pdbs must contain a {} to be replaced by the index of the pdb file. The index starts at 0. \
                For example protein_{}.pdb. trr-files: For .trr files you must provide a topology file."
    )
    cl_parser.add_argument(
        "--model_files", action="store", type=str, required=True
    )
    cl_parser.add_argument(
        "--output_file", action="store", type=str, required=True
    )
    cl_parser.add_argument(
        "--n_pdbs", action="store", type=int, required=False, default=None
    )
    cl_parser.add_argument(
        "--top_file", action="store", type=str, required=False, default=None
    )
    args = cl_parser.parse_args()
    models_to_tensor(
        model_files=args.model_files,
        output_file=args.output_file,
        n_pdbs=args.n_pdbs,
        top_file=args.top_file
    )
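Two usage sketches for the command line tool defined above, assuming the console script is installed as `models_to_tensor` (see pyproject.toml); all file paths are hypothetical.

    # a set of indexed pdb files matching the pattern protein_{}.pdb
    models_to_tensor --model_files protein_{}.pdb --output_file models.pt --n_pdbs 100

    # a .trr trajectory, which additionally requires a topology file
    models_to_tensor --model_files trajectory.trr --top_file topology.pdb --output_file models.pt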
26 | """ 27 | 28 | # batching images if necessary 29 | if images.shape[0] > batch_size and batch_size > 0: 30 | images = torch.split(images, split_size_or_sections=batch_size, dim=0) 31 | else: 32 | batch_size = images.shape[0] 33 | images = [images] 34 | 35 | # theta dimensions [num_eval, num_images, 1] 36 | if theta.ndim == 3: 37 | num_eval = theta.shape[0] 38 | num_images = images.shape[0] 39 | assert theta.shape == torch.Size([num_eval, num_images, 1]) 40 | 41 | elif theta.ndim == 2: 42 | raise IndexError("theta must have 3 dimensions [num_eval, num_images, 1]") 43 | 44 | elif theta.ndim == 1: 45 | theta = theta.reshape(-1, 1, 1).repeat(1, batch_size, 1) 46 | 47 | log_probs = [] 48 | for image_batch in images: 49 | posterior = estimator.flow(image_batch.to(device)) 50 | log_probs.append(posterior.log_prob(estimator.standardize(theta.to(device)))) 51 | 52 | log_probs = torch.cat(log_probs, dim=1) 53 | return log_probs 54 | 55 | 56 | @torch.no_grad() 57 | def sample_posterior( 58 | estimator: torch.nn.Module, 59 | images: torch.Tensor, 60 | num_samples: int, 61 | batch_size: int = 100, 62 | device: str = "cpu", 63 | ) -> torch.Tensor: 64 | """ 65 | Samples from the posterior distribution 66 | 67 | Args: 68 | estimator (torch.nn.Module): The posterior to use for sampling. 69 | images (torch.Tensor): The images used to condition the posterior. 70 | num_samples (int): The number of samples to draw 71 | batch_size (int, optional): The batch size for sampling. Defaults to 100. 72 | device (str, optional): The device to use. Defaults to "cpu". 73 | 74 | Returns: 75 | torch.Tensor: The posterior samples 76 | """ 77 | 78 | theta_samples = [] 79 | 80 | if images.shape[0] > batch_size and batch_size > 0: 81 | images = torch.split(images, split_size_or_sections=batch_size, dim=0) 82 | else: 83 | batch_size = images.shape[0] 84 | images = [images] 85 | 86 | for image_batch in images: 87 | samples = estimator.sample( 88 | image_batch.to(device, non_blocking=True), shape=(num_samples,) 89 | ).cpu() 90 | theta_samples.append(samples.reshape(-1, image_batch.shape[0])) 91 | 92 | return torch.cat(theta_samples, dim=1) 93 | 94 | 95 | @torch.no_grad() 96 | def compute_latent_repr( 97 | estimator: torch.nn.Module, 98 | images: torch.Tensor, 99 | batch_size: int = 100, 100 | device: str = "cpu", 101 | ) -> torch.Tensor: 102 | """ 103 | Computes the latent representation of images. 104 | 105 | Args: 106 | estimator (torch.nn.Module): Posterior model for which to compute the latent representation. 107 | images (torch.Tensor): The images to compute the latent representation for. 108 | batch_size (int, optional): The batch size to use. Defaults to 100. 109 | device (str, optional): The device to use. Defaults to "cpu". 110 | 111 | Returns: 112 | torch.Tensor: The latent representation of the images. 113 | """ 114 | 115 | latent_space_samples = [] 116 | 117 | if images.shape[0] > batch_size and batch_size > 0: 118 | images = torch.split(images, split_size_or_sections=batch_size, dim=0) 119 | else: 120 | batch_size = images.shape[0] 121 | images = [images] 122 | 123 | for image_batch in images: 124 | samples = estimator.embedding(image_batch.to(device, non_blocking=True)).cpu() 125 | latent_space_samples.append(samples.reshape(image_batch.shape[0], -1)) 126 | 127 | return torch.cat(latent_space_samples, dim=0) 128 | 129 | 130 | def load_estimator( 131 | config_file_path: str, estimator_path: str, device: str = "cpu" 132 | ) -> torch.nn.Module: 133 | """ 134 | Loads a trained estimator. 
135 | 
136 |     Args:
137 |         config_file_path (str): Path to the config file used to train the estimator.
138 |         estimator_path (str): Path to the estimator.
139 |         device (str, optional): The device to use. Defaults to "cpu".
140 | 
141 |     Returns:
142 |         torch.nn.Module: The loaded estimator.
143 |     """
144 | 
145 |     train_config = json.load(open(config_file_path))
146 |     estimator = build_models.build_npe_flow_model(train_config)
147 |     estimator.load_state_dict(
148 |         torch.load(estimator_path, map_location=torch.device(device))
149 |     )
150 |     estimator.to(device)
151 |     estimator.eval()
152 | 
153 |     return estimator
154 | 
--------------------------------------------------------------------------------
/src/cryo_sbi/utils/generate_models.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 | import MDAnalysis as mda
3 | from MDAnalysis.analysis import align
4 | import torch
5 | 
6 | 
7 | def pdb_parser_(fname: str, atom_selection: str = "name CA") -> torch.tensor:
8 |     """
9 |     Parses a pdb file and returns a coarse-grained atomic model of the protein.
10 |     The atomic model is a 3xN tensor, where N is the number of selected atoms
11 |     (by default the alpha carbons); the rows are their x, y, z coordinates.
12 | 
13 |     Parameters
14 |     ----------
15 |     fname : str
16 |         The path to the pdb file.
17 | 
18 |     Returns
19 |     -------
20 |     atomic_model : torch.tensor
21 |         The coarse-grained atomic model of the protein.
22 |     """
23 | 
24 |     univ = mda.Universe(fname)
25 |     univ.atoms.translate(-univ.atoms.center_of_mass())
26 | 
27 |     model = torch.from_numpy(univ.select_atoms(atom_selection).positions.T)
28 | 
29 |     return model
30 | 
31 | 
32 | def pdb_parser(file_formatter, n_pdbs, output_file, start_index=1, **kwargs):
33 |     """
34 |     Parses multiple pdb files and stacks the resulting coarse-grained models into a single tensor of shape (n_pdbs, 3, N), where N is the number of selected atoms per model and the rows of axis 1 are the x, y, z coordinates.
35 | 
36 |     Parameters
37 |     ----------
38 |     file_formatter : str
39 |         The path to the pdb files. The path must contain the placeholder {} for the pdb index. For example, if the path is "data/pdb/{}.pdb", then the placeholder is {}.
40 |     n_pdbs : int
41 |         The number of pdb files to parse.
42 |     output_file : str
43 |         The path to the output file. The output file must be a .pt file.
44 |     start_index : int
45 |         The index of the first pdb file. Defaults to 1. Additional keyword arguments are passed on to pdb_parser_.
46 |     """
47 | 
48 |     models = pdb_parser_(file_formatter.format(start_index), **kwargs)
49 |     models = torch.zeros((n_pdbs, *models.shape))
50 | 
51 |     for i in range(0, n_pdbs):
52 |         models[i] = pdb_parser_(file_formatter.format(start_index + i), **kwargs)
53 | 
54 |     if output_file.endswith("pt"):
55 |         torch.save(models, output_file)
56 | 
57 |     else:
58 |         raise ValueError("Model file format not supported. Please use .pt.")
59 | 
60 |     return
61 | 
62 | 
63 | def traj_parser_(top_file: str, traj_file: str) -> torch.tensor:
64 |     """
65 |     Parses a traj file and returns a coarse-grained atomic model of the protein.
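    All frames are first aligned to the topology reference structure on the alpha carbons before the coordinates are extracted.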
66 |     The atomic model is a Mx3xN tensor, where M is the number of frames in the trajectory,
67 |     and N is the number of residues in the protein. The rows in axis 1 are the x, y, z coordinates of the alpha carbons.
68 | 
69 |     Parameters
70 |     ----------
71 |     top_file, traj_file : str
72 |         The paths to the topology and trajectory files.
73 | 
74 |     Returns
75 |     -------
76 |     atomic_model : torch.tensor
77 |         The coarse-grained atomic model of the protein.
78 |     """
79 | 
80 |     ref = mda.Universe(top_file)
81 |     ref.atoms.translate(-ref.atoms.center_of_mass())
82 | 
83 |     mobile = mda.Universe(top_file, traj_file)
84 |     align.AlignTraj(mobile, ref, select="name CA", in_memory=True).run()
85 | 
86 |     atomic_models = torch.zeros(
87 |         (mobile.trajectory.n_frames, 3, mobile.select_atoms("name CA").n_atoms)
88 |     )
89 | 
90 |     for i in range(mobile.trajectory.n_frames):
91 |         mobile.trajectory[i]
92 | 
93 |         atomic_models[i, 0:3, :] = torch.from_numpy(
94 |             mobile.select_atoms("name CA").positions.T
95 |         )
96 | 
97 |     return atomic_models
98 | 
99 | 
100 | def traj_parser(top_file: str, traj_file: str, output_file: str) -> None:
101 |     """
102 |     Parses a traj file and saves a coarse-grained atomic model of the protein. The atomic model is a Mx3xN tensor, where M is the number of frames in the trajectory and N is the number of alpha carbons in the protein; the rows in axis 1 are the x, y, z coordinates.
103 | 
104 |     Parameters
105 |     ----------
106 |     top_file : str
107 |         The path to the topology file.
108 |     traj_file : str
109 |         The path to the trajectory file.
110 |     output_file : str
111 |         The path to the output file. Must be a .pt file.
112 | 
113 |     Returns
114 |     -------
115 |     None
116 |     """
117 | 
118 |     atomic_models = traj_parser_(top_file, traj_file)
119 | 
120 |     if output_file.endswith("pt"):
121 |         torch.save(atomic_models, output_file)
122 | 
123 |     else:
124 |         raise ValueError("Model file format not supported. Please use .pt.")
125 | 
126 |     return
127 | 
128 | 
129 | def models_to_tensor(
130 |     model_files,
131 |     output_file,
132 |     n_pdbs: Union[int, None] = None,
133 |     top_file: Union[str, None] = None,
134 | ):
135 |     """
136 |     Converts different model files to a torch tensor.
137 | 
138 |     Parameters
139 |     ----------
140 |     model_files : str
141 |         Path to the model file(s) to convert: either a .trr trajectory or a formatter path for .pdb files.
142 | 
143 |     output_file : str
144 |         The path to the output file. Must be a .pt file.
145 | 
146 |     n_pdbs : int
147 |         The number of pdb files to convert. Only needed for models in pdb files.
148 | 
149 |     top_file : str
150 |         The path to the topology file. Only needed for models in trr files.
151 | 
152 |     Returns
153 |     -------
154 |     None
155 |     """
156 |     assert output_file.endswith("pt"), "The output file must be a .pt file."
157 |     if model_files.endswith("trr"):
158 |         assert top_file is not None, "Please provide a topology file."
159 |         assert n_pdbs is None, "The number of pdb files is not needed for trr files."
160 |         traj_parser(top_file, model_files, output_file)
161 |     elif model_files.endswith("pdb"):
162 |         assert n_pdbs is not None, "Please provide the number of pdb files."
163 |         assert top_file is None, "The topology file is not needed for pdb files."
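        # model_files is a formatter path such as "protein_{}.pdb"; pdb_parser fills in
        # consecutive indices starting at its start_index argument (default 1)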
164 |         pdb_parser(model_files, n_pdbs, output_file)
165 | 
166 | 
--------------------------------------------------------------------------------
/src/cryo_sbi/utils/image_utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import List, Union
3 | import numpy as np
4 | import torch
5 | import torchvision.transforms as transforms
6 | import torch.distributions as d
7 | import mrcfile
8 | from tqdm import tqdm
9 | 
10 | 
11 | def circular_mask(n_pixels: int, radius: int, inside: bool = True) -> torch.Tensor:
12 |     """
13 |     Create a circular mask for a given image size and radius.
14 | 
15 |     Args:
16 |         n_pixels (int): Side length of the image in pixels.
17 |         radius (int): Radius of the circle.
18 |         inside (bool, optional): If True, the mask will be True inside the circle. Defaults to True.
19 | 
20 |     Returns:
21 |         mask (torch.Tensor): Mask of shape (n_pixels, n_pixels).
22 |     """
23 | 
24 |     grid = torch.linspace(-0.5 * (n_pixels - 1), 0.5 * (n_pixels - 1), n_pixels)
25 |     r_2d = grid[None, :] ** 2 + grid[:, None] ** 2
26 | 
27 |     if inside is True:
28 |         mask = r_2d < radius**2
29 |     else:
30 |         mask = r_2d > radius**2
31 | 
32 |     return mask
33 | 
34 | 
35 | class Mask:
36 |     """
37 |     Mask a circular region in an image.
38 | 
39 |     Args:
40 |         image_size (int): Number of pixels in the image.
41 |         radius (int): Radius of the circle.
42 |         inside (bool, optional): If True, the mask will be True inside the circle. Defaults to False.
43 |     """
44 | 
45 |     def __init__(self, image_size: int, radius: int, inside: bool = False) -> None:
46 |         self.image_size = image_size
47 |         self.radius = radius
48 |         self.mask = circular_mask(image_size, radius, inside=inside)
49 | 
50 |     def __call__(self, image: torch.Tensor) -> torch.Tensor:
51 |         """Mask a circular region in an image.
52 | 
53 |         Args:
54 |             image (torch.Tensor): Image of shape (n_pixels, n_pixels) or (n_channels, n_pixels, n_pixels).
55 | 
56 |         Returns:
57 |             image (torch.Tensor): Image with masked region equal to zero.
58 |         """
59 | 
60 |         if len(image.shape) == 2:
61 |             image[self.mask] = 0
62 |         elif len(image.shape) == 3:
63 |             image[:, self.mask] = 0
64 |         else:
65 |             raise NotImplementedError
66 | 
67 |         return image
68 | 
69 | 
70 | def fourier_down_sample(
71 |     image: torch.Tensor, image_size: int, n_pixels: int
72 | ) -> torch.Tensor:
73 |     """
74 |     Downsample an image by removing the outer frequencies.
75 | 
76 |     Args:
77 |         image (torch.Tensor): Image of shape (n_pixels, n_pixels) or (n_channels, n_pixels, n_pixels).
78 |         image_size (int): Side length of the image in pixels.
79 |         n_pixels (int): Number of pixels to remove from each side.
80 | 
81 |     Returns:
82 |         reconstructed (torch.Tensor): Downsampled image.
83 |     """
84 | 
85 |     fft_image = torch.fft.fft2(image)
86 |     fft_image = torch.fft.fftshift(fft_image)
87 | 
88 |     if len(image.shape) == 2:
89 |         fft_image = fft_image[
90 |             n_pixels : image_size - n_pixels,
91 |             n_pixels : image_size - n_pixels,
92 |         ]
93 |     elif len(image.shape) == 3:
94 |         fft_image = fft_image[
95 |             :,
96 |             n_pixels : image_size - n_pixels,
97 |             n_pixels : image_size - n_pixels,
98 |         ]
99 |     else:
100 |         raise NotImplementedError
101 | 
102 |     fft_image = torch.fft.fftshift(fft_image)
103 |     reconstructed = torch.fft.ifft2(fft_image).real
104 |     return reconstructed
105 | 
106 | 
107 | class FourierDownSample:
108 |     """
109 |     Downsample an image by removing the outer frequencies.
110 | 
111 |     Args:
112 |         image_size (int): Size of image in pixels.
113 | down_sampled_size (int): Size of downsampled image in pixels. 114 | """ 115 | 116 | def __init__(self, image_size: int, down_sampled_size: int) -> None: 117 | self._image_size = image_size 118 | self._n_pixels = (image_size - down_sampled_size) // 2 119 | 120 | def __call__(self, image: torch.Tensor) -> torch.Tensor: 121 | """Downsample an image by removing the outer frequencies. 122 | 123 | Args: 124 | image (torch.Tensor): Image of shape (n_pixels, n_pixels) or (n_channels, n_pixels, n_pixels). 125 | 126 | Returns: 127 | down_sampled (torch.Tensor): Downsampled image. 128 | """ 129 | 130 | down_sampled = fourier_down_sample( 131 | image, image_size=self._image_size, n_pixels=self._n_pixels 132 | ) 133 | 134 | return down_sampled 135 | 136 | 137 | class LowPassFilter: 138 | """ 139 | Low pass filter an image by removing the outer frequencies. 140 | 141 | Args: 142 | image_size (int): Side length of the image in pixels. 143 | frequency_cutoff (int): Frequency cutoff. 144 | """ 145 | 146 | def __init__(self, image_size: int, frequency_cutoff: int): 147 | self.mask = circular_mask(image_size, frequency_cutoff, inside=False) 148 | 149 | def __call__(self, image: torch.Tensor) -> torch.Tensor: 150 | """ 151 | Low pass filter an image by removing the outer frequencies. 152 | 153 | Args: 154 | image (torch.Tensor): Image of shape (n_pixels, n_pixels) or (n_channels, n_pixels, n_pixels). 155 | 156 | Returns: 157 | reconstructed (torch.Tensor): Low pass filtered image. 158 | """ 159 | fft_image = torch.fft.fft2(image) 160 | fft_image = torch.fft.fftshift(fft_image) 161 | 162 | if len(image.shape) == 2: 163 | fft_image[self.mask] = 0 + 0j 164 | elif len(image.shape) == 3: 165 | fft_image[:, self.mask] = 0 + 0j 166 | else: 167 | raise NotImplementedError 168 | 169 | fft_image = torch.fft.fftshift(fft_image) 170 | reconstructed = torch.fft.ifft2(fft_image).real 171 | return reconstructed 172 | 173 | 174 | class GaussianLowPassFilter: 175 | """ 176 | Low pass filter by dampening the outer frequencies with a Gaussian. 177 | """ 178 | 179 | def __init__(self, image_size: int, sigma: int): 180 | self._image_size = image_size 181 | self._sigma = sigma 182 | self._grid = torch.linspace( 183 | -0.5 * (image_size - 1), 0.5 * (image_size - 1), image_size 184 | ) 185 | self._r_2d = self._grid[None, :] ** 2 + self._grid[:, None] ** 2 186 | self._mask = torch.exp(-self._r_2d / (2 * sigma**2)) 187 | 188 | def __call__(self, image: torch.Tensor) -> torch.Tensor: 189 | """ 190 | Low pass filter an image by dampening the outer frequencies with a Gaussian. 191 | 192 | Args: 193 | image (torch.Tensor): Image of shape (n_pixels, n_pixels) or (n_channels, n_pixels, n_pixels). 194 | 195 | Returns: 196 | reconstructed (torch.Tensor): Low pass filtered image. 197 | """ 198 | 199 | fft_image = torch.fft.fft2(image) 200 | fft_image = torch.fft.fftshift(fft_image) 201 | 202 | if len(image.shape) == 2: 203 | fft_image = fft_image * self._mask 204 | elif len(image.shape) == 3: 205 | fft_image = fft_image * self._mask.unsqueeze(0) 206 | else: 207 | raise NotImplementedError 208 | 209 | fft_image = torch.fft.fftshift(fft_image) 210 | reconstructed = torch.fft.ifft2(fft_image).real 211 | return reconstructed 212 | 213 | 214 | class NormalizeIndividual: 215 | """ 216 | Normalize an image by subtracting the mean and dividing by the standard deviation. 
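    Example (illustrative sketch):
        >>> images = NormalizeIndividual()(torch.randn(8, 64, 64))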
217 | """ 218 | 219 | def __init__(self) -> None: 220 | pass 221 | 222 | def __call__(self, images: torch.Tensor) -> torch.Tensor: 223 | """ 224 | Normalize an image by subtracting the mean and dividing by the standard deviation. 225 | 226 | Args: 227 | images (torch.Tensor): Image of shape (n_channels, n_pixels, n_pixels). 228 | 229 | Returns: 230 | normalized (torch.Tensor): Normalized image. 231 | """ 232 | if len(images.shape) == 2: 233 | mean = images.mean() 234 | std = images.std() 235 | images = images.unsqueeze(0) 236 | elif len(images.shape) == 3: 237 | mean = images.mean(dim=[1, 2]) 238 | std = images.std(dim=[1, 2]) 239 | else: 240 | raise NotImplementedError 241 | 242 | return transforms.functional.normalize(images, mean=mean, std=std) 243 | 244 | 245 | def mrc_to_tensor(image_path: str) -> torch.Tensor: 246 | """ 247 | Convert an MRC file to a tensor. 248 | 249 | Args: 250 | image_path (str): Path to the MRC file. 251 | 252 | Returns: 253 | image (torch.Tensor): Image of shape (n_pixels, n_pixels). 254 | """ 255 | 256 | assert isinstance(image_path, str), "image path needs to be a string" 257 | with mrcfile.open(image_path) as mrc: 258 | image = mrc.data 259 | return torch.from_numpy(image) 260 | 261 | 262 | class MRCtoTensor: 263 | """ 264 | Convert an MRC file to a tensor. 265 | """ 266 | 267 | def __init__(self) -> None: 268 | pass 269 | 270 | def __call__(self, image_path: str) -> torch.Tensor: 271 | """ 272 | Convert an MRC file to a tensor. 273 | 274 | Args: 275 | image_path (str): Path to the MRC file. 276 | 277 | Returns: 278 | image (torch.Tensor): Image of shape (n_pixels, n_pixels). 279 | """ 280 | 281 | return mrc_to_tensor(image_path) 282 | 283 | 284 | def estimate_noise_psd(images: torch.Tensor, image_size: int, mask_radius : Union[int, None] = None) -> torch.Tensor: 285 | """ 286 | Estimates the power spectral density (PSD) of the noise in a set of images. 287 | 288 | Args: 289 | images (torch.Tensor): A tensor containing the input images. The shape of the tensor should be (N, H, W), 290 | where N is the number of images, H is the height, and W is the width. 291 | 292 | Returns: 293 | torch.Tensor: A tensor containing the estimated PSD of the noise. The shape of the tensor is (H, W), where H is the height 294 | and W is the width of the images. 295 | 296 | """ 297 | if mask_radius is None: 298 | mask_radius = image_size // 2 299 | mask = circular_mask(image_size, mask_radius, inside=False) 300 | denominator = mask.sum() * images.shape[0] 301 | images_masked = images * mask 302 | mean_est = images_masked.sum() / denominator 303 | image_masked_fft = torch.fft.fft2(images_masked) 304 | noise_psd_est = torch.sum(torch.abs(image_masked_fft)**2, dim=[0]) / denominator 305 | noise_psd_est[image_size // 2, image_size // 2] -= mean_est 306 | 307 | return noise_psd_est 308 | 309 | 310 | class WhitenImage: 311 | """ 312 | Whiten an image by dividing by the square root of the noise PSD. 313 | 314 | Args: 315 | image_size (int): Size of image in pixels. 316 | mask_radius (int, optional): Radius of the mask. Defaults to None. 317 | 318 | """ 319 | 320 | def __init__(self, image_size: int, mask_radius: Union[int, None] = None) -> None: 321 | self.image_size = image_size 322 | self.mask_radius = mask_radius 323 | 324 | def _estimate_noise_psd(self, images: torch.Tensor) -> torch.Tensor: 325 | """ 326 | Estimates the power spectral density (PSD) of the noise in a set of images. 
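        The statistics are estimated from pixels outside a circular mask, which are assumed to contain only background noise.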
327 |         """
328 |         noise_psd = estimate_noise_psd(images, self.image_size, self.mask_radius)
329 |         return noise_psd
330 | 
331 |     def __call__(self, images: torch.Tensor) -> torch.Tensor:
332 |         """
333 |         Whiten an image by dividing by the square root of the noise PSD.
334 | 
335 |         Args:
336 |             images (torch.Tensor): Images of shape (num_images, n_pixels, n_pixels).
337 | 
338 |         Returns:
339 |             images (torch.Tensor): Whitened images.
340 |         """
341 | 
342 |         assert images.ndim == 3, "Image should have shape (num_images, n_pixels, n_pixels)"
343 |         noise_psd = self._estimate_noise_psd(images) ** -0.5
344 |         images_fft = torch.fft.fft2(images)
345 |         images_fft = images_fft * noise_psd
346 |         images = torch.fft.ifft2(images_fft).real
347 |         return images
348 | 
349 | 
350 | class MRCdataset:
351 |     """
352 |     Creates a dataset of MRC files.
353 |     Each MRC file is converted to a tensor and has a unique index.
354 | 
355 |     Args:
356 |         image_paths (list[str]): List of paths to MRC files.
357 | 
358 |     Methods:
359 |         build_index_map: Builds a map of indices to file paths and file indices.
360 |         get_image: Returns the image at the given global index.
361 |         __getitem__: Returns the index and the tensor of the MRC file at the given index.
362 |     """
363 | 
364 |     def __init__(self, image_paths: List[str]):
365 |         super().__init__()
366 |         self.paths = image_paths
367 |         self._num_paths = len(image_paths)
368 |         self._index_map = None
369 | 
370 |     def __len__(self):
371 |         return self._num_paths
372 | 
373 |     def __getitem__(self, idx):
374 |         return idx, mrc_to_tensor(self.paths[idx])
375 | 
376 |     def _extract_num_particles(self, path):
377 |         future_mrc = mrcfile.open_async(path)
378 |         mrc = future_mrc.result()
379 |         data_shape = mrc.data.shape
380 |         # a stack of particle images has three dimensions, a single image only two
381 |         num_images = data_shape[0] if len(data_shape) > 2 else 1
382 |         return num_images
383 | 
384 |     def build_index_map(self):
385 |         """
386 |         Builds a map of image indices to file paths and file indices.
387 |         """
388 |         if self._index_map is not None:
389 |             print("Index map already built.")
390 |             return
391 | 
392 |         self._path_index = []
393 |         self._file_index = []
394 |         print("Initializing indexing...")
395 |         for idx, path in tqdm(enumerate(self.paths), total=self._num_paths):
396 |             num_images = self._extract_num_particles(path)
397 |             self._path_index += [idx] * num_images
398 |             self._file_index += list(range(num_images))
399 |         self._index_map = True
400 | 
401 |     def save_index_map(self, path: str):
402 |         """
403 |         Saves the index map to a file.
404 | 
405 |         Args:
406 |             path (str): Path to save the index map.
407 |         """
408 |         assert (
409 |             self._index_map is not None
410 |         ), "Index map not built. First call build_index_map()"
411 |         np.savez(
412 |             path,
413 |             path_index=self._path_index,
414 |             file_index=self._file_index,
415 |             paths=self.paths,
416 |         )
417 | 
418 |     def load_index_map(self, path: str):
419 |         """
420 |         Loads the index map from a file.
421 | 
422 |         Args:
423 |             path (str): Path to load the index map from.
424 |         """
425 |         index_map = np.load(path)
426 |         assert len(self.paths) == len(index_map["paths"]), "Number of paths does not match the index map."
427 |         for path1, path2 in zip(self.paths, index_map["paths"]):
428 |             assert path1 == path2, "Paths do not match the index map."
429 |         self._path_index = index_map["path_index"]
430 |         self._file_index = index_map["file_index"]
431 |         self._index_map = True
432 | 
433 |     def get_image(self, idx: Union[int, list]):
434 |         """
435 |         Returns the image at the given global index.
436 | 
437 |         Args:
438 |             idx (int, List): Global index of the image.
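        Returns:
            torch.Tensor or list: The image (or list of images) at the given index (or indices).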
439 |         """
440 |         assert (
441 |             self._index_map is not None
442 |         ), "Index map not built. First call build_index_map() or load_index_map()"
443 |         if isinstance(idx, int):
444 |             image = mrc_to_tensor(self.paths[self._path_index[idx]])
445 |             if image.ndim > 2:
446 |                 return image[self._file_index[idx]]
447 |             return image
448 |         if isinstance(idx, (list, np.ndarray, torch.Tensor)):
449 |             return [
450 |                 mrc_to_tensor(self.paths[self._path_index[i]])[self._file_index[i]]
451 |                 for i in idx
452 |             ]
453 | 
454 | 
455 | class MRCloader(torch.utils.data.DataLoader):
456 |     """
457 |     Creates a dataloader of MRC files.
458 | 
459 |     Args:
460 |         image_paths (list[str]): List of paths to MRC files.
461 |         **kwargs: Keyword arguments passed to torch.utils.data.DataLoader.
462 |     """
463 | 
464 |     def __init__(self, image_paths: List[str], **kwargs):
465 |         super().__init__(MRCdataset(image_paths), batch_size=None, **kwargs)
--------------------------------------------------------------------------------
/src/cryo_sbi/utils/micrograph_utils.py:
--------------------------------------------------------------------------------
1 | import random
2 | from typing import Optional, Union, List
3 | from cryo_sbi.utils.image_utils import mrc_to_tensor
4 | import torch
5 | import torchvision.transforms as transforms
6 | import torchvision.transforms.functional as TF
7 | 
8 | 
9 | class RandomMicrographPatches:
10 |     """
11 |     Iterator that returns random patches from a list of micrographs.
12 | 
13 |     Args:
14 |         micro_graphs (List[Union[str, torch.Tensor]]): List of micrographs.
15 |         transform (Union[None, transforms.Compose]): Transform to apply to the patches.
16 |         patch_size (int): Size of the patches.
17 |         max_iter (int, optional): Number of patches to draw before the iterator is exhausted. Defaults to 1000.
18 |     """
19 | 
20 |     def __init__(
21 |         self,
22 |         micro_graphs: List[Union[str, torch.Tensor]],
23 |         transform: Union[None, transforms.Compose],
24 |         patch_size: int,
25 |         max_iter: Optional[int] = 1000,
26 |     ) -> None:
27 |         if all(map(isinstance, micro_graphs, [str] * len(micro_graphs))):
28 |             self._micro_graphs = [mrc_to_tensor(path) for path in micro_graphs]
29 |         else:
30 |             self._micro_graphs = micro_graphs
31 | 
32 |         self._transform = transform
33 |         self._patch_size = patch_size
34 |         self._max_iter = max_iter
35 |         self._current_iter = 0
36 | 
37 |     def __iter__(self) -> "RandomMicrographPatches":
38 |         return self
39 | 
40 |     def __next__(self) -> torch.Tensor:
41 |         if self._current_iter == self._max_iter:
42 |             self._current_iter = 0
43 |             raise StopIteration
44 |         random_micrograph = random.choice(self._micro_graphs)
45 |         assert random_micrograph.ndim == 2, "Micrograph should be 2D"
46 |         x = random.randint(0, random_micrograph.shape[1] - self._patch_size)  # left offset along width
47 |         y = random.randint(0, random_micrograph.shape[0] - self._patch_size)  # top offset along height
48 |         patch = TF.crop(
49 |             random_micrograph,
50 |             top=y,
51 |             left=x,
52 |             height=self._patch_size,
53 |             width=self._patch_size,
54 |         )
55 |         if self._transform is not None:
56 |             patch = self._transform(patch)
57 |         else:
58 |             patch = patch.unsqueeze(0)
59 |         self._current_iter += 1
60 |         return patch
61 | 
62 |     def __len__(self) -> int:
63 |         return self._max_iter
64 | 
65 |     @property
66 |     def shape(self) -> torch.Size:
67 |         """
68 |         Shape of the transformed patches.
69 | 
70 |         Returns:
71 |             torch.Size: Shape of the transformed patches.
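        Note: accessing this property draws (and consumes) one patch from the iterator to infer the shape.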
72 |         """
73 |         return self.__next__().shape
74 | 
75 | 
76 | def compute_average_psd(
77 |     images: Union[torch.Tensor, RandomMicrographPatches],
78 |     device: str = "cpu",
79 | ) -> torch.Tensor:
80 |     """
81 |     Compute the average PSD of a set of images.
82 | 
83 |     Args:
84 |         images (Union[torch.Tensor, RandomMicrographPatches]): Images to compute the average PSD of.
85 |         device (str, optional): Device to compute the PSD on. Defaults to "cpu".
86 | 
87 |     Returns:
88 |         avg_psd (torch.Tensor): Average PSD of the images.
89 |     """
90 | 
91 |     if isinstance(images, RandomMicrographPatches):
92 |         avg_psd = torch.zeros(images.shape[1:], device=device)
93 |         for image in images:  # TODO add progress bar with tqdm
94 |             fft_image = torch.fft.fft2(image[0].to(device, non_blocking=True))
95 |             psd = torch.abs(fft_image) ** 2
96 |             avg_psd += psd / len(images)
97 |         # add convergence check
98 |     elif isinstance(images, torch.Tensor):
99 |         fft_images = torch.fft.fft2(images.to(device=device), dim=(-2, -1))
100 |         avg_psd = torch.mean(torch.abs(fft_images) ** 2, dim=0)
101 |     return avg_psd.cpu()
--------------------------------------------------------------------------------
/src/cryo_sbi/utils/visualize_models.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | import torch
4 | 
5 | 
6 | def _scatter_plot_models(
7 |     model: torch.Tensor, view_angles: tuple = (30, 45), **plot_kwargs: dict
8 | ) -> None:
9 |     fig = plt.figure()
10 |     ax = fig.add_subplot(111, projection="3d")
11 |     ax.view_init(*view_angles)
12 | 
13 |     ax.scatter(*model, **plot_kwargs)
14 | 
15 |     ax.set_xlabel("X")
16 |     ax.set_ylabel("Y")
17 |     ax.set_zlabel("Z")
18 | 
19 | 
20 | def _sphere_plot_models(
21 |     model: torch.Tensor,
22 |     radius: float = 4,
23 |     view_angles: tuple = (30, 45),
24 |     **plot_kwargs: dict,
25 | ) -> None:
26 |     fig = plt.figure()
27 |     ax = fig.add_subplot(111, projection="3d")
28 |     ax.view_init(*view_angles)
29 | 
30 |     spheres = []
31 |     for x, y, z in zip(model[0], model[1], model[2]):
32 |         spheres.append((x.item(), y.item(), z.item(), radius))
33 | 
34 |     for sphere in spheres:
35 |         x, y, z, r = sphere
36 | 
37 |         u = np.linspace(0, 2 * np.pi, 100)
38 |         v = np.linspace(0, np.pi, 100)
39 |         x = r * np.outer(np.cos(u), np.sin(v)) + x
40 |         y = r * np.outer(np.sin(u), np.sin(v)) + y
41 |         z = r * np.outer(np.ones(np.size(u)), np.cos(v)) + z
42 | 
43 |         ax.plot_surface(x, y, z, **plot_kwargs)
44 | 
45 |     ax.set_xlabel("X")
46 |     ax.set_ylabel("Y")
47 |     ax.set_zlabel("Z")
48 | 
49 | 
50 | def plot_model(model: torch.Tensor, method: str = "scatter", **kwargs) -> None:
51 |     """
52 |     Plot a model from the tensor.
53 | 
54 |     Args:
55 |         model (torch.Tensor): Model to plot, should be a 2D tensor with shape (3, num_atoms).
56 |         method (str, optional): Method to use for plotting. Defaults to "scatter". Can be "scatter" or "sphere".
57 |             "scatter" is fast and simple, "sphere" is a proper 3D representation (takes long to render).
58 |         **kwargs: Additional keyword arguments to pass to the plotting function.
59 | 
60 |     Returns:
61 |         None
62 | 
63 |     Raises:
64 |         AssertionError: If the model is not a 2D tensor with shape (3, num_atoms).
65 |         ValueError: If the method is not "scatter" or "sphere".
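    Example (illustrative sketch; assumes models is a (num_models, 3, num_atoms) tensor, e.g. loaded from tests/models/hsp90_models.pt):
        >>> plot_model(models[0], method="scatter", s=2)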
66 | 
67 |     """
68 | 
69 |     assert model.ndim == 2, "Model should be 2D tensor"
70 |     assert model.shape[0] == 3, "Model should have 3 rows"
71 | 
72 |     if method == "scatter":
73 |         _scatter_plot_models(model, **kwargs)
74 | 
75 |     elif method == "sphere":
76 |         _sphere_plot_models(model, **kwargs)
77 | 
78 |     else:
79 |         raise ValueError(f"Unknown method {method}. Use 'scatter' or 'sphere'.")
80 | 
--------------------------------------------------------------------------------
/src/cryo_sbi/wpa_simulator/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/src/cryo_sbi/wpa_simulator/cryo_em_simulator.py:
--------------------------------------------------------------------------------
1 | from typing import Union, Callable
2 | import json
3 | import numpy as np
4 | import torch
5 | 
6 | from cryo_sbi.wpa_simulator.ctf import apply_ctf
7 | from cryo_sbi.wpa_simulator.image_generation import project_density
8 | from cryo_sbi.wpa_simulator.noise import add_noise
9 | from cryo_sbi.wpa_simulator.normalization import gaussian_normalize_image
10 | from cryo_sbi.inference.priors import get_image_priors
11 | from cryo_sbi.wpa_simulator.validate_image_config import check_image_params
12 | 
13 | 
14 | def cryo_em_simulator(
15 |     models,
16 |     index,
17 |     quaternion,
18 |     sigma,
19 |     shift,
20 |     defocus,
21 |     b_factor,
22 |     amp,
23 |     snr,
24 |     num_pixels,
25 |     pixel_size,
26 | ):
27 |     """
28 |     Simulates a batch of cryo-electron microscopy (cryo-EM) images for a set of given coarse-grained models.
29 | 
30 |     Args:
31 |         models (torch.Tensor): A tensor of coarse-grained models (num_models, 3, num_beads).
32 |         index (torch.Tensor): A tensor of indices to select the models to simulate.
33 |         quaternion (torch.Tensor): A tensor of quaternions to rotate the models.
34 |         sigma (torch.Tensor): The standard deviation of the Gaussian kernel used to project the density.
35 |         shift (torch.Tensor): A tensor of shifts to apply to the models.
36 |         defocus (torch.Tensor): The defocus value of the contrast transfer function (CTF).
37 |         b_factor (torch.Tensor): The B-factor of the CTF.
38 |         amp (torch.Tensor): The amplitude contrast of the CTF.
39 |         snr (torch.Tensor): The signal-to-noise ratio of the simulated image.
40 |         num_pixels (int): The number of pixels in the simulated image.
41 |         pixel_size (float): The size of each pixel in the simulated image.
42 | 
43 |     Returns:
44 |         torch.Tensor: A tensor of the simulated cryo-EM images.
45 |     """
46 |     models_selected = models[index.round().long().flatten()]
47 |     image = project_density(
48 |         models_selected,
49 |         quaternion,
50 |         sigma,
51 |         shift,
52 |         num_pixels,
53 |         pixel_size,
54 |     )
55 |     image = apply_ctf(image, defocus, b_factor, amp, pixel_size)
56 |     image = add_noise(image, snr)
57 |     image = gaussian_normalize_image(image)
58 |     return image
59 | 
60 | 
61 | class CryoEmSimulator:
62 |     def __init__(self, config_fname: str, device: str = "cpu"):
63 |         self._device = device
64 |         self._load_params(config_fname)
65 |         self._load_models()
66 |         self._priors = get_image_priors(self.max_index, self._config, device=device)
67 |         self._num_pixels = torch.tensor(
68 |             self._config["N_PIXELS"], dtype=torch.float32, device=device
69 |         )
70 |         self._pixel_size = torch.tensor(
71 |             self._config["PIXEL_SIZE"], dtype=torch.float32, device=device
72 |         )
73 | 
74 |     def _load_params(self, config_fname: str) -> None:
75 |         """
76 |         Loads the parameters from the config file into a dictionary.
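        The parameters are validated with check_image_params before being stored on the simulator.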
77 | 
78 |         Args:
79 |             config_fname (str): Path to the configuration file.
80 | 
81 |         Returns:
82 |             None
83 |         """
84 | 
85 |         config = json.load(open(config_fname))
86 |         check_image_params(config)
87 |         self._config = config
88 | 
89 |     def _load_models(self) -> None:
90 |         """
91 |         Loads the models from the model file specified in the config file.
92 | 
93 |         Returns:
94 |             None
95 | 
96 |         """
97 |         if self._config["MODEL_FILE"].endswith("npy"):
98 |             models = (
99 |                 torch.from_numpy(
100 |                     np.load(self._config["MODEL_FILE"]),
101 |                 )
102 |                 .to(self._device)
103 |                 .to(torch.float32)
104 |             )
105 |         elif self._config["MODEL_FILE"].endswith("pt"):
106 |             models = (
107 |                 torch.load(self._config["MODEL_FILE"])
108 |                 .to(self._device)
109 |                 .to(torch.float32)
110 |             )
111 | 
112 |         else:
113 |             raise NotImplementedError(
114 |                 "Model file format not supported. Please use .npy or .pt."
115 |             )
116 | 
117 |         self._models = models
118 | 
119 |         assert self._models.ndim == 3, "Models are not of shape (models, 3, atoms)."
120 |         assert self._models.shape[1] == 3, "Models are not of shape (models, 3, atoms)."
121 | 
122 |     @property
123 |     def max_index(self) -> int:
124 |         """
125 |         Returns the maximum index of the model file.
126 | 
127 |         Returns:
128 |             int: Maximum index of the model file.
129 |         """
130 |         return len(self._models) - 1
131 | 
132 |     def simulate(self, num_sim, indices=None, return_parameters=False, batch_size=None):
133 |         """
134 |         Simulate cryo-EM images using the specified models and prior distributions.
135 | 
136 |         Args:
137 |             num_sim (int): The number of images to simulate.
138 |             indices (torch.Tensor, optional): The model indices to use for each image. If None, indices are sampled from the prior.
139 |             return_parameters (bool, optional): Whether to return the sampled parameters used for simulation.
140 |             batch_size (int, optional): The batch size to use for simulation. If None, all images are simulated in a single batch.
141 | 
142 |         Returns:
143 |             torch.Tensor or tuple: The simulated images as a tensor of shape (num_sim, num_pixels, num_pixels),
144 |                 and optionally the sampled parameters as a tuple of tensors.
145 |         """
146 | 
147 |         parameters = self._priors.sample((num_sim,))
148 |         indices = parameters[0] if indices is None else indices
149 |         if indices is not None:
150 |             assert isinstance(
151 |                 indices, torch.Tensor
152 |             ), "Indices must be a torch.Tensor."
153 |             assert (
154 |                 indices.dtype == torch.float32
155 |             ), "Indices must have dtype torch.float32."
156 |             assert (
157 |                 indices.ndim == 2
158 |             ), "Indices must be a 2D tensor with shape (batch_size, 1)."
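            # overwrite the prior-sampled indices with the user-provided ones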
159 |             parameters[0] = indices
160 | 
161 |         images = []
162 |         if batch_size is None:
163 |             batch_size = num_sim
164 |         for i in range(0, num_sim, batch_size):
165 |             batch_indices = indices[i : i + batch_size]
166 |             batch_parameters = [param[i : i + batch_size] for param in parameters[1:]]
167 |             batch_images = cryo_em_simulator(
168 |                 self._models,
169 |                 batch_indices,
170 |                 *batch_parameters,
171 |                 self._num_pixels,
172 |                 self._pixel_size,
173 |             )
174 |             images.append(batch_images.cpu())
175 | 
176 |         images = torch.cat(images, dim=0)
177 | 
178 |         if return_parameters:
179 |             return images.cpu(), parameters
180 |         else:
181 |             return images.cpu()
182 | 
--------------------------------------------------------------------------------
/src/cryo_sbi/wpa_simulator/ctf.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | 
4 | 
5 | def apply_ctf(image: torch.Tensor, defocus, b_factor, amp, pixel_size) -> torch.Tensor:
6 |     """
7 |     Applies the CTF to the image.
8 | 
9 |     Args:
10 |         image (torch.Tensor): The image to apply the CTF to.
11 |         defocus (torch.Tensor): The defocus value.
12 |         b_factor (torch.Tensor): The B-factor value.
13 |         amp (torch.Tensor): The amplitude contrast.
14 |         pixel_size (torch.Tensor): The pixel size in Angstrom.
15 | 
16 |     Returns:
17 |         torch.Tensor: The image with the CTF applied.
18 |     """
19 | 
20 |     num_batch, num_pixels, _ = image.shape
21 |     freq_pix_1d = torch.fft.fftfreq(num_pixels, d=pixel_size, device=image.device)
22 |     x, y = torch.meshgrid(freq_pix_1d, freq_pix_1d, indexing="ij")
23 | 
24 |     freq2_2d = x**2 + y**2
25 |     freq2_2d = freq2_2d.expand(num_batch, -1, -1)
26 |     imag = torch.zeros_like(freq2_2d, device=image.device) * 1j
27 | 
28 |     env = torch.exp(-b_factor * freq2_2d * 0.5)
29 |     phase = defocus * torch.pi * 2.0 * 10000 * 0.019866  # hardcoded 0.019866: electron wavelength in Angstrom at 300kV; 10000 converts defocus from micrometer to Angstrom
30 | 
31 |     ctf = (
32 |         -amp * torch.cos(phase * freq2_2d * 0.5)
33 |         - torch.sqrt(1 - amp**2) * torch.sin(phase * freq2_2d * 0.5)
34 |         + imag
35 |     )
36 |     ctf = ctf * env / amp
37 | 
38 |     conv_image_ctf = torch.fft.fft2(image) * ctf
39 |     image_ctf = torch.fft.ifft2(conv_image_ctf).real
40 | 
41 |     return image_ctf
42 | 
--------------------------------------------------------------------------------
/src/cryo_sbi/wpa_simulator/image_generation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | 
4 | 
5 | def gen_quat() -> torch.Tensor:
6 |     """
7 |     Generate a random quaternion.
8 | 
9 |     Returns:
10 |         quat (torch.Tensor): Random unit quaternion.
11 | 
12 |     """
13 |     count = 0
14 |     while count < 1:
15 |         quat = 2 * torch.rand(size=(4,)) - 1
16 |         norm = torch.sqrt(torch.sum(quat**2))
17 |         if 0.2 <= norm <= 1.0:
18 |             quat /= norm
19 |             count += 1
20 | 
21 |     return quat
22 | 
23 | 
24 | def gen_rot_matrix(quats: torch.Tensor) -> torch.Tensor:
25 |     """
26 |     Generate a rotation matrix from a quaternion.
27 |     Quaternions are expected in (q_r, q_x, q_y, q_z) order, i.e. with the scalar part first.
28 | 
29 |     Args:
30 |         quats (torch.Tensor): Quaternions of shape (n_batch, 4).
31 | 
32 |     Returns:
33 |         rot_matrix (torch.Tensor): Rotation matrices of shape (n_batch, 3, 3).
34 |     """
35 | 
36 |     rot_matrix = torch.zeros((quats.shape[0], 3, 3), device=quats.device)
37 | 
38 |     rot_matrix[:, 0, 0] = 1 - 2 * (quats[:, 2] ** 2 + quats[:, 3] ** 2)
39 |     rot_matrix[:, 0, 1] = 2 * (quats[:, 1] * quats[:, 2] - quats[:, 3] * quats[:, 0])
40 |     rot_matrix[:, 0, 2] = 2 * (quats[:, 1] * quats[:, 3] + quats[:, 2] * quats[:, 0])
41 | 
42 |     rot_matrix[:, 1, 0] = 2 * (quats[:, 1] * quats[:, 2] + quats[:, 3] * quats[:, 0])
43 |     rot_matrix[:, 1, 1] = 1 - 2 * (quats[:, 1] ** 2 + quats[:, 3] ** 2)
44 |     rot_matrix[:, 1, 2] = 2 * (quats[:, 2] * quats[:, 3] - quats[:, 1] * quats[:, 0])
45 | 
46 |     rot_matrix[:, 2, 0] = 2 * (quats[:, 1] * quats[:, 3] - quats[:, 2] * quats[:, 0])
47 |     rot_matrix[:, 2, 1] = 2 * (quats[:, 2] * quats[:, 3] + quats[:, 1] * quats[:, 0])
48 |     rot_matrix[:, 2, 2] = 1 - 2 * (quats[:, 1] ** 2 + quats[:, 2] ** 2)
49 | 
50 |     return rot_matrix
51 | 
52 | 
53 | def project_density(
54 |     coords: torch.Tensor,
55 |     quats: torch.Tensor,
56 |     sigma: torch.Tensor,
57 |     shift: torch.Tensor,
58 |     num_pixels: int,
59 |     pixel_size: float,
60 | ) -> torch.Tensor:
61 |     """
62 |     Generate 2D projections from a set of coordinates.
63 | 
64 |     Args:
65 |         coords (torch.Tensor): Coordinates of the atoms in the images.
66 |         quats (torch.Tensor): Quaternions used to rotate the coordinates.
67 |         sigma (torch.Tensor): Standard deviation of the Gaussian function used to model electron density.
68 |         shift (torch.Tensor): In-plane shifts applied to the rotated coordinates.
69 |         num_pixels (int): Number of pixels along one image side.
70 |         pixel_size (float): Pixel size in Angstrom.
71 | 
72 |     Returns:
73 |         image (torch.Tensor): Images generated from the coordinates.
74 |     """
75 | 
76 |     num_batch, _, num_atoms = coords.shape
77 |     norm = 1 / (2 * torch.pi * sigma**2 * num_atoms)
78 | 
79 |     grid_min = -pixel_size * num_pixels * 0.5
80 |     grid_max = pixel_size * num_pixels * 0.5
81 | 
82 |     rot_matrix = gen_rot_matrix(quats)
83 |     grid = torch.arange(grid_min, grid_max, pixel_size, device=coords.device)[
84 |         0 : num_pixels.long()
85 |     ].repeat(
86 |         num_batch, 1
87 |     )  # [0: num_pixels.long()] is needed due to single precision error in some cases
88 | 
89 |     coords_rot = torch.bmm(rot_matrix, coords)
90 |     coords_rot[:, :2, :] += shift.unsqueeze(-1)
91 | 
92 |     gauss_x = torch.exp_(
93 |         -0.5 * (((grid.unsqueeze(-1) - coords_rot[:, 0, :].unsqueeze(1)) / sigma) ** 2)
94 |     )
95 |     gauss_y = torch.exp_(
96 |         -0.5 * (((grid.unsqueeze(-1) - coords_rot[:, 1, :].unsqueeze(1)) / sigma) ** 2)
97 |     ).transpose(1, 2)
98 | 
99 |     image = torch.bmm(gauss_x, gauss_y) * norm.reshape(-1, 1, 1)
100 | 
101 |     return image
102 | 
103 | 
104 | '''def project_density(
105 |     atomic_model: torch.Tensor,
106 |     quats: torch.Tensor,
107 |     delta_sigma: torch.Tensor,
108 |     shift: torch.Tensor,
109 |     num_pixels: int,
110 |     pixel_size: float,
111 | ) -> torch.Tensor:
112 |     """
113 |     Generate a 2D projections from a set of coordinates.
114 | 
115 |     Args:
116 |         atomic_model (torch.Tensor): Coordinates of the atoms in the images
117 |         res (float): resolution of the images in Angstrom
118 |         num_pixels (int): Number of pixels along one image size.
119 |         pixel_size (float): Pixel size in Angstrom
120 | 
121 |     Returns:
122 |         image (torch.Tensor): Images generated from the coordinates
123 |     """
124 | 
125 |     num_batch, _, num_atoms = atomic_model.shape
126 | 
127 |     variances = atomic_model[:, 4, :] * delta_sigma[:, 0]
128 |     amplitudes = atomic_model[:, 3, :] / torch.sqrt((2 * torch.pi * variances))
129 | 
130 |     grid_min = -pixel_size * num_pixels * 0.5
131 |     grid_max = pixel_size * num_pixels * 0.5
132 | 
133 |     rot_matrix = gen_rot_matrix(quats)
134 |     grid = torch.arange(grid_min, grid_max, pixel_size, device=atomic_model.device)[
135 |         0 : num_pixels.long()
136 |     ].repeat(
137 |         num_batch, 1
138 |     )  # [0: num_pixels.long()] is needed due to single precision error in some cases
139 | 
140 |     coords_rot = torch.bmm(rot_matrix, atomic_model[:, :3, :])
141 |     coords_rot[:, :2, :] += shift.unsqueeze(-1)
142 | 
143 |     gauss_x = torch.exp_(
144 |         -((grid.unsqueeze(-1) - coords_rot[:, 0, :].unsqueeze(1)) ** 2)
145 |         / variances.unsqueeze(1)
146 |     ) * amplitudes.unsqueeze(1)
147 | 
148 |     gauss_y = torch.exp(
149 |         -((grid.unsqueeze(-1) - coords_rot[:, 1, :].unsqueeze(1)) ** 2)
150 |         / variances.unsqueeze(1)
151 |     ) * amplitudes.unsqueeze(1)
152 | 
153 |     image = torch.bmm(gauss_x, gauss_y.transpose(1, 2))  # * norms
154 |     image /= torch.norm(image, dim=[-2, -1]).reshape(-1, 1, 1)  # do we need this normalization?
155 | 
156 |     return image'''
157 | 
--------------------------------------------------------------------------------
/src/cryo_sbi/wpa_simulator/noise.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 | import numpy as np
3 | import torch
4 | 
5 | 
6 | def circular_mask(n_pixels: int, radius: int, device: str = "cpu") -> torch.Tensor:
7 |     """
8 |     Creates a circular mask of the given radius centered in the image.
9 | 
10 |     Args:
11 |         n_pixels (int): Number of pixels along image side.
12 |         radius (int): Radius of the mask.
13 | 
14 |     Returns:
15 |         mask (torch.Tensor): Mask of shape (n_pixels, n_pixels).
16 |     """
17 | 
18 |     grid = torch.linspace(
19 |         -0.5 * (n_pixels - 1), 0.5 * (n_pixels - 1), n_pixels, device=device
20 |     )
21 |     r_2d = grid[None, :] ** 2 + grid[:, None] ** 2
22 |     mask = r_2d < radius**2
23 | 
24 |     return mask
25 | 
26 | 
27 | def get_snr(images, snr):
28 |     """
29 |     Computes the noise standard deviation that corresponds to the given log10 SNR for each image.
30 |     """
31 |     mask = circular_mask(
32 |         n_pixels=images.shape[-1],
33 |         radius=images.shape[-1] // 2,  # TODO: make this a parameter
34 |         device=images.device,
35 |     )
36 |     signal_power = torch.std(
37 |         images[:, mask], dim=[-1]
38 |     )  # images are not centered at 0, so std is not the same as power
39 |     assert signal_power.shape[0] == images.shape[0]
40 |     noise_power = signal_power.reshape(-1, 1, 1) / torch.sqrt(
41 |         torch.pow(torch.tensor(10), snr)
42 |     )
43 | 
44 |     return noise_power
45 | 
46 | 
47 | def add_noise(image: torch.Tensor, snr, seed=None) -> torch.Tensor:
48 |     """
49 |     Adds Gaussian noise to a batch of images.
50 | 
51 |     Args:
52 |         image (torch.Tensor): Images of shape (num_images, n_pixels, n_pixels).
53 |         snr (torch.Tensor): Log10 signal-to-noise ratio used to set the noise level (see get_snr).
54 |         seed (int, optional): Seed for random number generator. Defaults to None.
55 | 
56 |     Returns:
57 |         image_noise (torch.Tensor): Image with noise of shape (n_pixels, n_pixels) or (n_channels, n_pixels, n_pixels).
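    Example (illustrative sketch; per get_snr, snr is interpreted on a log10 scale, so 0.0 corresponds to SNR = 1):
        >>> images = torch.randn(4, 64, 64)
        >>> noisy = add_noise(images, snr=torch.tensor([0.0]), seed=0)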
58 |     """
59 | 
60 |     if seed is not None:
61 |         torch.manual_seed(seed)
62 | 
63 |     noise_power = get_snr(image, snr)
64 |     noise = torch.randn_like(image, device=image.device)
65 | 
66 |     noise = noise * noise_power.reshape(-1, 1, 1)
67 | 
68 |     image_noise = image + noise
69 | 
70 |     return image_noise
71 | 
--------------------------------------------------------------------------------
/src/cryo_sbi/wpa_simulator/normalization.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.transforms as transforms
3 | 
4 | 
5 | def gaussian_normalize_image(images: torch.Tensor) -> torch.Tensor:
6 |     """
7 |     Normalize images by subtracting the mean and dividing by the standard deviation.
8 | 
9 |     Args:
10 |         images (torch.Tensor): Images of shape (num_images, n_pixels, n_pixels).
11 | 
12 |     Returns:
13 |         normalized (torch.Tensor): Normalized images.
14 |     """
15 | 
16 |     mean = images.mean(dim=[1, 2])
17 |     std = images.std(dim=[1, 2])
18 | 
19 |     return transforms.functional.normalize(images, mean=mean, std=std)
20 | 
--------------------------------------------------------------------------------
/src/cryo_sbi/wpa_simulator/validate_image_config.py:
--------------------------------------------------------------------------------
1 | def check_image_params(config: dict) -> None:
2 |     """
3 |     Checks if all necessary parameters are provided.
4 | 
5 |     Args:
6 |         config (dict): Dictionary containing image parameters.
7 | 
8 |     Returns:
9 |         None
10 |     """
11 | 
12 |     needed_keys = [
13 |         "N_PIXELS",
14 |         "PIXEL_SIZE",
15 |         "SIGMA",
16 |         "SHIFT",
17 |         "DEFOCUS",
18 |         "SNR",
19 |         "MODEL_FILE",
20 |         "AMP",
21 |         "B_FACTOR",
22 |     ]
23 | 
24 |     for key in needed_keys:
25 |         assert key in config.keys(), f"Please provide a value for {key}"
26 | 
27 |     return None
28 | 
--------------------------------------------------------------------------------
/tests/config_files/image_params_testing.json:
--------------------------------------------------------------------------------
1 | {
2 |     "N_PIXELS": 64,
3 |     "PIXEL_SIZE": 2.06,
4 |     "SIGMA": [0.5, 5.0],
5 |     "MODEL_FILE": "tests/models/hsp90_models.pt",
6 |     "SHIFT": 20.0,
7 |     "DEFOCUS": [1.5, 3.5],
8 |     "SNR": [0.05, 0.05],
9 |     "AMP": 0.1,
10 |     "B_FACTOR": [1.0, 100.0]
11 | }
12 | 
--------------------------------------------------------------------------------
/tests/config_files/training_params_npe_testing.json:
--------------------------------------------------------------------------------
1 | {
2 |     "EMBEDDING": "RESNET18",
3 |     "OUT_DIM": 256,
4 |     "NUM_TRANSFORM": 5,
5 |     "NUM_HIDDEN_FLOW": 10,
6 |     "HIDDEN_DIM_FLOW": 256,
7 |     "MODEL": "NSF",
8 |     "LEARNING_RATE": 0.0003,
9 |     "CLIP_GRADIENT": 5.0,
10 |     "THETA_SHIFT": 25,
11 |     "THETA_SCALE": 25,
12 |     "BATCH_SIZE": 256
13 | }
14 | 
--------------------------------------------------------------------------------
/tests/data/test.mrc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/flatironinstitute/cryoSBI/d145cf21c1d5991361f028db698b3bc8931e0fdf/tests/data/test.mrc
--------------------------------------------------------------------------------
/tests/models/hsp90_models.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/flatironinstitute/cryoSBI/d145cf21c1d5991361f028db698b3bc8931e0fdf/tests/models/hsp90_models.pt
--------------------------------------------------------------------------------
/tests/test_embeddings.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | from itertools import product
4 | 
5 | from cryo_sbi.inference.models.embedding_nets import EMBEDDING_NETS
6 | 
7 | embedding_networks = list(EMBEDDING_NETS.keys())
8 | num_images_to_test = [1, 5]
9 | out_dims_to_test = [10, 100]
10 | cases_to_test = list(product(embedding_networks, num_images_to_test, out_dims_to_test))
11 | 
12 | 
13 | @pytest.mark.parametrize(("embedding_name", "num_images", "out_dim"), cases_to_test)
14 | def test_embedding(embedding_name, num_images, out_dim):
15 | 
16 |     if "FFT_FILTER_" in embedding_name:
17 |         size = embedding_name.split("FFT_FILTER_")[1]
18 |         test_images = torch.randn(num_images, int(size), int(size))
19 |     elif "Tutorial" in embedding_name:
20 |         test_images = torch.randn(num_images, 64, 64)
21 |     else:
22 |         test_images = torch.randn(num_images, 128, 128)
23 | 
24 |     embedding = EMBEDDING_NETS[embedding_name](out_dim)
25 |     out = embedding(test_images).shape
26 |     assert out == torch.Size([num_images, out_dim]), embedding_name
27 | 
--------------------------------------------------------------------------------
/tests/test_estimator_utils.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import os
3 | import torch
4 | import numpy as np
5 | import json
6 | 
7 | from cryo_sbi.inference.models import build_models
8 | from cryo_sbi.inference.models.estimator_models import NPEWithEmbedding
9 | from cryo_sbi.inference.validate_train_config import check_train_params
10 | from cryo_sbi.utils.estimator_utils import (
11 |     sample_posterior,
12 |     compute_latent_repr,
13 |     evaluate_log_prob,
14 |     load_estimator,
15 | )
16 | 
17 | 
18 | @pytest.fixture
19 | def train_params():
20 |     config = json.load(open("tests/config_files/training_params_npe_testing.json"))
21 |     check_train_params(config)
22 |     return config
23 | 
24 | 
25 | @pytest.fixture
26 | def train_config_path():
27 |     return "tests/config_files/training_params_npe_testing.json"
28 | 
29 | 
30 | @pytest.mark.parametrize(
31 |     ("num_images", "num_samples", "batch_size"),
32 |     [(1, 1, 1), (2, 10, 2), (5, 1000, 5), (100, 2, 100)],
33 | )
34 | def test_sampling(train_params, num_images, num_samples, batch_size):
35 |     estimator = build_models.build_npe_flow_model(train_params)
36 |     estimator.eval()
37 |     images = torch.randn((num_images, 128, 128))
38 |     samples = sample_posterior(
39 |         estimator, images, num_samples=num_samples, batch_size=batch_size
40 |     )
41 |     assert samples.shape == torch.Size(
42 |         [num_samples, num_images]
43 |     ), f"Failed with: num_images: {num_images}, num_samples:{num_samples}, batch_size:{batch_size}"
44 | 
45 | 
46 | @pytest.mark.parametrize(
47 |     ("num_images", "batch_size"), [(1, 1), (2, 2), (1, 5), (100, 10)]
48 | )
49 | def test_latent_extraction(train_params, num_images, batch_size):
50 |     estimator = build_models.build_npe_flow_model(train_params)
51 |     estimator.eval()
52 | 
53 |     latent_dim = train_params["OUT_DIM"]
54 |     images = torch.randn((num_images, 128, 128))
55 |     samples = compute_latent_repr(estimator, images, batch_size=batch_size)
56 |     assert samples.shape == torch.Size(
57 |         [num_images, latent_dim]
58 |     ), f"Failed with: num_images: {num_images}, batch_size:{batch_size}"
59 | 
60 | 
61 | @pytest.mark.parametrize(
62 |     ("num_images", "num_eval", "batch_size"),
63 |     [(1, 1, 1), (2, 10, 2), (5, 1000, 5), (100, 2, 100)],
64 | )
65 | def test_logprob_eval(train_params, num_images, num_eval, batch_size):
66 |     estimator =
build_models.build_npe_flow_model(train_params) 67 | estimator.eval() 68 | images = torch.randn((num_images, 128, 128)) 69 | theta = torch.linspace(0, 25, num_eval) 70 | samples = evaluate_log_prob(estimator, images, theta, batch_size=batch_size) 71 | assert samples.shape == torch.Size( 72 | [num_eval, num_images] 73 | ), f"Failed with: num_images: {num_images}, num_eval:{num_eval}, batch_size:{batch_size}" 74 | 75 | 76 | def test_load_estimator(train_params, train_config_path): 77 | estimator = build_models.build_npe_flow_model(train_params) 78 | torch.save(estimator.state_dict(), "tests/config_files/test_estimator.estimator") 79 | estimator = load_estimator( 80 | train_config_path, "tests/config_files/test_estimator.estimator" 81 | ) 82 | assert isinstance(estimator, NPEWithEmbedding) 83 | os.remove("tests/config_files/test_estimator.estimator") 84 | -------------------------------------------------------------------------------- /tests/test_image_utils.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from cryo_sbi.utils import image_utils as iu 4 | 5 | 6 | def test_circular_mask(): 7 | n_pixels = 100 8 | radius = 30 9 | inside_mask = iu.circular_mask(n_pixels, radius, inside=True) 10 | outside_mask = iu.circular_mask(n_pixels, radius, inside=False) 11 | 12 | assert inside_mask.shape == (n_pixels, n_pixels) 13 | assert outside_mask.shape == (n_pixels, n_pixels) 14 | assert inside_mask.sum().item() == pytest.approx(radius**2 * 3.14159, abs=10) 15 | assert outside_mask.sum().item() == pytest.approx( 16 | n_pixels**2 - radius**2 * 3.14159, abs=10 17 | ) 18 | 19 | 20 | def test_mask_class(): 21 | image_size = 100 22 | radius = 30 23 | inside = True 24 | mask = iu.Mask(image_size, radius, inside=inside) 25 | image = torch.ones((image_size, image_size)) 26 | 27 | masked_image = mask(image) 28 | assert masked_image.shape == (image_size, image_size) 29 | assert masked_image[inside].sum().item() == pytest.approx( 30 | image_size**2 - radius**2 * 3.14159, abs=10 31 | ) 32 | 33 | 34 | def test_fourier_down_sample(): 35 | image_size = 100 36 | n_pixels = 30 37 | image = torch.ones((image_size, image_size)) 38 | 39 | downsampled_image = iu.fourier_down_sample(image, image_size, n_pixels) 40 | assert downsampled_image.shape == ( 41 | image_size - 2 * n_pixels, 42 | image_size - 2 * n_pixels, 43 | ) 44 | 45 | 46 | def test_fourier_down_sample_class(): 47 | image_size = 100 48 | down_sampled_size = 40 49 | down_sampler = iu.FourierDownSample(image_size, down_sampled_size) 50 | image = torch.ones((image_size, image_size)) 51 | 52 | down_sampled_image = down_sampler(image) 53 | assert down_sampled_image.shape == ( 54 | image_size - 2 * down_sampler._n_pixels, 55 | image_size - 2 * down_sampler._n_pixels, 56 | ) 57 | 58 | 59 | def test_low_pass_filter(): 60 | image_size = 100 61 | frequency_cutoff = 30 62 | low_pass_filter = iu.LowPassFilter(image_size, frequency_cutoff) 63 | image = torch.ones((image_size, image_size)) 64 | 65 | filtered_image = low_pass_filter(image) 66 | assert filtered_image.shape == (image_size, image_size) 67 | 68 | 69 | def test_gaussian_low_pass_filter(): 70 | image_size = 100 71 | frequency_cutoff = 30 72 | low_pass_filter = iu.GaussianLowPassFilter(image_size, frequency_cutoff) 73 | image = torch.ones((image_size, image_size)) 74 | 75 | filtered_image = low_pass_filter(image) 76 | assert filtered_image.shape == (image_size, image_size) 77 | 78 | 79 | def test_normalize_individual(): 80 | 
normalize_individual = iu.NormalizeIndividual()
81 |     image = torch.randn((3, 100, 100))
82 | 
83 |     normalized_image = normalize_individual(image)
84 |     assert normalized_image.shape == (3, 100, 100)
85 |     assert normalized_image.mean().item() == pytest.approx(0.0, abs=1e-1)
86 |     assert normalized_image.std().item() == pytest.approx(1.0, abs=1e-1)
87 | 
88 | 
89 | def test_mrc_to_tensor():
90 |     image_path = "tests/data/test.mrc"
91 |     image = iu.mrc_to_tensor(image_path)
92 | 
93 |     assert isinstance(image, torch.Tensor)
94 |     assert image.shape == (5, 5)
95 | 
96 | 
97 | def test_image_whitening():
98 |     whitening_transform = iu.WhitenImage(100)
99 |     images = torch.randn((1, 100, 100))
100 |     images_whitened = whitening_transform(images)
101 |     assert images_whitened.shape == (1, 100, 100)
102 | 
103 | 
104 | def test_image_whitening_batched():
105 |     whitening_transform = iu.WhitenImage(100)
106 |     images = torch.randn((10, 100, 100))
107 |     images_whitened = whitening_transform(images)
108 |     assert images_whitened.shape == (10, 100, 100)
109 | 
--------------------------------------------------------------------------------
/tests/test_micrograph_utils.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | import torchvision.transforms as transforms
4 | from cryo_sbi.utils.image_utils import NormalizeIndividual
5 | from cryo_sbi.utils import micrograph_utils as mu
6 | 
7 | 
8 | def test_random_micrograph_patches_fail():
9 |     micrograph = torch.randn(2, 128, 128)
10 |     random_patches = mu.RandomMicrographPatches(
11 |         micro_graphs=[micrograph], patch_size=10, transform=None
12 |     )
13 |     with pytest.raises(AssertionError):
14 |         patch = next(random_patches)
15 | 
16 | 
17 | @pytest.mark.parametrize(
18 |     ("micrograph_size", "patch_size", "max_iter"), [(128, 12, 100), (128, 90, 1000)]
19 | )
20 | def test_random_micrograph_patches(micrograph_size, patch_size, max_iter):
21 |     micrograph = torch.randn(micrograph_size, micrograph_size)
22 |     random_patches = mu.RandomMicrographPatches(
23 |         micro_graphs=[micrograph],
24 |         patch_size=patch_size,
25 |         transform=None,
26 |         max_iter=max_iter,
27 |     )
28 |     patch = next(random_patches)
29 |     assert patch.shape == torch.Size([1, patch_size, patch_size])
30 |     assert len(random_patches) == max_iter
31 | 
32 | 
33 | def test_compute_average_psd():
34 |     micrograph = torch.randn(128, 128)
35 |     transform = transforms.Compose([NormalizeIndividual()])
36 |     random_patches = mu.RandomMicrographPatches(
37 |         micro_graphs=[micrograph], patch_size=10, transform=transform, max_iter=100
38 |     )
39 |     avg_psd = mu.compute_average_psd(random_patches)
40 |     assert avg_psd.shape == torch.Size([10, 10])
41 | 
--------------------------------------------------------------------------------
/tests/test_posterior_models.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import json
3 | import torch
4 | from cryo_sbi.inference.models import build_models
5 | from cryo_sbi.inference.models import estimator_models
6 | from cryo_sbi.inference.validate_train_config import check_train_params
7 | 
8 | 
9 | @pytest.fixture
10 | def train_params():
11 |     config = json.load(open("tests/config_files/training_params_npe_testing.json"))
12 |     check_train_params(config)
13 |     return config
14 | 
15 | 
16 | def test_build_npe_model(train_params):
17 |     posterior_model = build_models.build_npe_flow_model(train_params)
18 |     assert isinstance(posterior_model, estimator_models.NPEWithEmbedding)
19 | 
20 | 
21 | 
@pytest.mark.parametrize( 22 | ("batch_size", "sample_size"), [(1, 1), (2, 10), (5, 1000), (100, 2)] 23 | ) 24 | def test_sample_npe_model(train_params, batch_size, sample_size): 25 | posterior_model = build_models.build_npe_flow_model(train_params) 26 | test_image = torch.randn((batch_size, 128, 128)) 27 | samples = posterior_model.sample(test_image, shape=(sample_size,)) 28 | assert samples.shape == torch.Size([sample_size, batch_size, 1]) 29 | -------------------------------------------------------------------------------- /tests/test_simulator.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import numpy as np 4 | import json 5 | 6 | from cryo_sbi.wpa_simulator.cryo_em_simulator import cryo_em_simulator, CryoEmSimulator 7 | from cryo_sbi.wpa_simulator.ctf import apply_ctf 8 | from cryo_sbi.wpa_simulator.image_generation import ( 9 | project_density, 10 | gen_quat, 11 | gen_rot_matrix, 12 | ) 13 | from cryo_sbi.wpa_simulator.noise import add_noise, circular_mask, get_snr 14 | from cryo_sbi.wpa_simulator.normalization import gaussian_normalize_image 15 | from cryo_sbi.inference.priors import get_image_priors 16 | 17 | 18 | def test_apply_ctf(): 19 | # Create a test image 20 | image = torch.randn(1, 64, 64) 21 | 22 | # Set test parameters 23 | defocus = torch.tensor([1.0]) 24 | b_factor = torch.tensor([100.0]) 25 | amp = torch.tensor([0.5]) 26 | pixel_size = torch.tensor(1.0) 27 | 28 | # Apply CTF to the test image 29 | image_ctf = apply_ctf(image, defocus, b_factor, amp, pixel_size) 30 | 31 | assert image_ctf.shape == image.shape 32 | assert isinstance(image_ctf, torch.Tensor) 33 | assert not torch.allclose(image_ctf, image) 34 | 35 | 36 | def test_gen_rot_matrix(): 37 | # Create a test quaternion 38 | quat = torch.tensor([[1.0, 0.0, 0.0, 0.0]]) 39 | 40 | # Generate a rotation matrix from the quaternion 41 | rot_matrix = gen_rot_matrix(quat) 42 | 43 | assert rot_matrix.shape == torch.Size([1, 3, 3]) 44 | assert isinstance(rot_matrix, torch.Tensor) 45 | assert torch.allclose(rot_matrix, torch.eye(3).unsqueeze(0)) 46 | 47 | 48 | def test_gen_rot_matrix_batched(): 49 | # Create test quaternions with batch size 3 50 | quat = torch.tensor( 51 | [[1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0]] 52 | ) 53 | 54 | # Generate rotation matrices from the quaternions 55 | rot_matrix = gen_rot_matrix(quat) 56 | 57 | assert rot_matrix.shape == torch.Size([3, 3, 3]) 58 | assert isinstance(rot_matrix, torch.Tensor) 59 | assert torch.allclose(rot_matrix, torch.eye(3).repeat(3, 1, 1)) 60 | 61 | 62 | @pytest.mark.parametrize( 63 | ("noise_std", "num_images"), 64 | [ 65 | (torch.tensor([1.5, 1]), 2), 66 | (torch.tensor([1.0, 2.0, 3.0]), 3), 67 | (torch.tensor([0.1]), 10), 68 | ], 69 | ) 70 | def test_get_snr(noise_std, num_images): 71 | # Create test images with per-image noise levels 72 | images = noise_std.reshape(-1, 1, 1) * torch.randn(num_images, 128, 128) 73 | 74 | # Compute the SNR of the test images 75 | snr = get_snr(images, torch.tensor([0.0])) 76 | 77 | assert snr.shape == torch.Size([images.shape[0], 1, 1]), "SNR has wrong shape" 78 | assert isinstance(snr, torch.Tensor) 79 | assert torch.allclose( 80 | snr.flatten(), noise_std * torch.ones(images.shape[0]), atol=1e-01 81 | ), "SNR is not correct" 82 | 83 | 84 | @pytest.mark.parametrize(("num_images"), [1, 5]) 85 | def test_simulator_default_settings(num_images): 86 | sim = CryoEmSimulator("tests/config_files/image_params_testing.json") 87 | images = sim.simulate(num_images) 88 | 
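    # image_params_testing.json evidently renders 64x64 images, matching the
    # shape asserted below (one image per requested simulation).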
assert images.shape == torch.Size([num_images, 64, 64]) 89 | 90 | 91 | @pytest.mark.parametrize(("num_images"), [1, 5]) 92 | def test_simulator_custom_indices(num_images): 93 | sim = CryoEmSimulator("tests/config_files/image_params_testing.json") 94 | test_indices = torch.arange(num_images, dtype=torch.float32).reshape(-1, 1) 95 | images, parameters = sim.simulate( 96 | num_images, indices=test_indices, return_parameters=True 97 | ) 98 | 99 | assert (parameters[0] == test_indices).all().item() 100 | assert images.shape == torch.Size([num_images, 64, 64]) 101 | -------------------------------------------------------------------------------- /tests/test_visualize_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytest 3 | from cryo_sbi.utils.visualize_models import plot_model 4 | 5 | 6 | def test_plot_model_scatter(): 7 | model = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]) 8 | plot_model( 9 | model, method="scatter" 10 | ) # No assertion, just checking if it runs without errors 11 | 12 | 13 | def test_plot_model_sphere(): 14 | model = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]) 15 | plot_model( 16 | model, method="sphere" 17 | ) # No assertion, just checking if it runs without errors 18 | 19 | 20 | def test_plot_model_invalid_model(): 21 | model = torch.tensor( 22 | [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] 23 | ) # Invalid shape, should have 3 rows 24 | with pytest.raises(AssertionError): 25 | plot_model(model, method="scatter") 26 | 27 | 28 | def test_plot_model_invalid_method(): 29 | model = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]) 30 | with pytest.raises(ValueError): 31 | plot_model(model, method="invalid_method") 32 | -------------------------------------------------------------------------------- /tutorials/.gitignore: -------------------------------------------------------------------------------- 1 | *.h5 2 | *.tp 3 | *.pt -------------------------------------------------------------------------------- /tutorials/simulation_parameters.json: -------------------------------------------------------------------------------- 1 | { 2 | "N_PIXELS": 64, 3 | "PIXEL_SIZE": 2.5, 4 | "SIGMA": [2.0, 2.0], 5 | "MODEL_FILE": "models.pt", 6 | "SHIFT": 0.0, 7 | "DEFOCUS": [2.0, 2.0], 8 | "SNR": [0.01, 0.5], 9 | "AMP": 0.1, 10 | "B_FACTOR": [1.0, 1.0] 11 | } -------------------------------------------------------------------------------- /tutorials/training_parameters.json: -------------------------------------------------------------------------------- 1 | { 2 | "EMBEDDING": "ConvEncoder_Tutorial", 3 | "OUT_DIM": 128, 4 | "NUM_TRANSFORM": 3, 5 | "NUM_HIDDEN_FLOW": 3, 6 | "HIDDEN_DIM_FLOW": 128, 7 | "MODEL": "NSF", 8 | "LEARNING_RATE": 0.0003, 9 | "CLIP_GRADIENT": 5.0, 10 | "THETA_SHIFT": 50, 11 | "THETA_SCALE": 50, 12 | "BATCH_SIZE": 32 13 | } 14 | 15 | -------------------------------------------------------------------------------- /tutorials/tutorial.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import torch\n", 11 | "import numpy as np\n", 12 | "import matplotlib.pyplot as plt\n", 13 | "from scipy.spatial.transform import Rotation\n", 14 | "\n", 15 | "from cryo_sbi import CryoEmSimulator\n", 16 | "import cryo_sbi.inference.train_npe_model as train_npe_model\n", 17 | 
"import cryo_sbi.utils.estimator_utils as est_utils\n", 18 | "import cryo_sbi.utils.image_utils as img_utils\n", 19 | "from cryo_sbi.utils.visualize_models import plot_model" 20 | ] 21 | }, 22 | { 23 | "cell_type": "markdown", 24 | "metadata": {}, 25 | "source": [ 26 | "#### If you want to run the latent space analysis with UMAP, you need to install the following package:\n", 27 | "You can find the installation instructions [here](https://umap-learn.readthedocs.io/en/latest/). \n" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "# If you installed umap you can import it here \n", 37 | "import umap" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": {}, 43 | "source": [ 44 | "### Make the models \n", 45 | "\n", 46 | "Here we are creating a simple molecular model. \n", 47 | "In our model we will have four pseudo atoms arranged in a rectangle. The model differs between the sidelength we are using. The atoms are placed at the corners of the rectangle. The distance between the atoms is the same for all atoms.\n", 48 | "The goal is then to simulate cryo-EM images with these models and infer the distance between the two atoms from the images.\n", 49 | "\n", 50 | "The first step is to create the models. We start by crating an array with the side length `side_length` between the atoms. We will use this array to create the models.\n", 51 | "The models are created by placing pseudo atoms at the corners of the rectangle.\n", 52 | "\n", 53 | "The models are saved into teh file `models.pt`.\n", 54 | "\n" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": null, 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "side_lengths = torch.linspace(0, 50, 100)\n", 64 | "\n", 65 | "models = []\n", 66 | "for side_length in side_lengths:\n", 67 | " model = [\n", 68 | " [side_length, -side_length, side_length, -side_length],\n", 69 | " [side_length, side_length, -side_length, -side_length],\n", 70 | " [0.0, 0.0, 0.0, 0.0],\n", 71 | " ]\n", 72 | " models.append(model)\n", 73 | "models = torch.tensor(models)" 74 | ] 75 | }, 76 | { 77 | "cell_type": "markdown", 78 | "metadata": {}, 79 | "source": [ 80 | "##### Visualize the models" 81 | ] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "metadata": {}, 86 | "source": [ 87 | "To visulaize the model in the x-y plane, as we do not have a z-dimension in our model." 
88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [ 96 | "fig = plt.figure()\n", 97 | "ax = fig.add_subplot()\n", 98 | "for i, c in zip([0, 25, 50, 75, 99], [\"red\", \"orange\", \"green\", \"blue\", \"purple\"]):\n", 99 | " ax.scatter(\n", 100 | " models[i, 0, 0],\n", 101 | " models[i, 1, 0],\n", 102 | " s=60,\n", 103 | " color=c,\n", 104 | " label=f\"Model with side length: {side_lengths[i]:.2f}\",\n", 105 | " )\n", 106 | " ax.scatter(models[i, 0, 1], models[i, 1, 1], s=60, color=c)\n", 107 | " ax.scatter(models[i, 0, 2], models[i, 1, 2], s=60, color=c)\n", 108 | " ax.scatter(models[i, 0, 3], models[i, 1, 3], s=60, color=c)\n", 109 | "\n", 110 | "ax.set_xlabel(\"X\")\n", 111 | "ax.set_ylabel(\"Y\")\n", 112 | "plt.legend()\n", 113 | "\n" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": null, 119 | "metadata": {}, 120 | "outputs": [], 121 | "source": [ 122 | "# subtract the center of mass\n", 123 | "for i in range(100):\n", 124 | " models[i] = models[i] - models[i].mean(dim=1, keepdim=True)\n", 125 | "\n", 126 | "torch.save(models, \"models.pt\")" 127 | ] 128 | }, 129 | { 130 | "cell_type": "markdown", 131 | "metadata": {}, 132 | "source": [ 133 | "### Run first simulation\n", 134 | "\n", 135 | "We will now simulate the cryo-EM images with our generated models.\n", 136 | "The simulation is done by the class `CryoEmSimulator`, and it is run with the `simulate` method.\n", 137 | "The class `CryoEmSimulator` takes as input a config file with the simulation parameters. The config file used here is `simulation_parameters.json`.\n", 138 | "\n", 139 | "The following parameters are used in the simulation:" 140 | ] 141 | }, 142 | { 143 | "cell_type": "markdown", 144 | "metadata": {}, 145 | "source": [ 146 | "```\n", 147 | "simulation_parameters.json\n", 148 | "\n", 149 | "{\n", 150 | " \"N_PIXELS\": 64, --> size of the image\n", 151 | " \"PIXEL_SIZE\": 2.5, --> pixel size in angstroms\n", 152 | " \"SIGMA\": [2.0, 2.0], --> standard deviation of the Gaussian\n", 153 | " \"MODEL_FILE\": \"models.pt\", --> file which contains the models\n", 154 | " \"SHIFT\": 0.0, --> shift of model center \n", 155 | " \"DEFOCUS\": [2.0, 2.0], --> defocus range for the simulation\n", 156 | " \"SNR\": [0.01, 0.5], --> signal-to-noise ratio for the simulation\n", 157 | " \"AMP\": 0.1, --> amplitude for the CTF \n", 158 | " \"B_FACTOR\": [1.0, 1.0] --> B factor for the CTF\n", 159 | "} \n", 160 | "```" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": null, 166 | "metadata": {}, 167 | "outputs": [], 168 | "source": [ 169 | "simulator = CryoEmSimulator(\n", 170 | " \"simulation_parameters.json\"\n", 171 | ") # creating the simulator with the simulation parameters" 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": null, 177 | "metadata": {}, 178 | "outputs": [], 179 | "source": [ 180 | "images, parameters = simulator.simulate(\n", 181 | " num_sim=5000, return_parameters=True\n", 182 | ") # simulate images and return their parameters" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": null, 188 | "metadata": {}, 189 | "outputs": [], 190 | "source": [ 191 | "side_length = parameters[0] # extracting side_length from parameters\n", 192 | "snr = parameters[-1] # extracting snr from parameters" 193 | ] 194 | }, 195 | { 196 | "cell_type": "markdown", 197 | "metadata": {}, 198 | "source": [ 199 | "#### Visualize the simulated images"
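,
        "\n",
        "\n",
        "Note that `side_length` stores continuous model indices; the plotting cell below maps each index back to a physical side length via `side_lengths[side_length[idx].round().long()]`."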
200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": null, 205 | "metadata": {}, 206 | "outputs": [], 207 | "source": [ 208 | "fig, axes = plt.subplots(4, 4, figsize=(6, 6))\n", 209 | "for idx, ax in enumerate(axes.flatten()):\n", 210 | " ax.imshow(images[idx], vmin=-3, vmax=3, cmap=\"gray\")\n", 211 | " ax.set_title(\n", 212 | " f\"Side: {side_lengths[side_length[idx].round().long()].item():.2f}\", fontsize=10\n", 213 | " )\n", 214 | " ax.axis(\"off\")" 215 | ] 216 | }, 217 | { 218 | "cell_type": "markdown", 219 | "metadata": {}, 220 | "source": [ 221 | "### Train cryoSBI posterior\n", 222 | "\n", 223 | "We will now train the cryoSBI posterior to infer the distance between the atoms from the simulated images.\n", 224 | "The training is done with the function `npe_train_no_saving`, which simulates images and simultaneously trains the posterior.\n", 225 | "The function takes as input the config file `training_parameters.json`, which contains the training and neural network parameters.\n", 226 | "It also takes the config file `simulation_parameters.json`, which contains the simulation parameters used to simulate the images.\n" 227 | ] 228 | }, 229 | { 230 | "cell_type": "markdown", 231 | "metadata": {}, 232 | "source": [ 233 | "```\n", 234 | "training_parameters.json\n", 235 | "```\n", 236 | "\n", 237 | "```\n", 238 | "{\n", 239 | " \"EMBEDDING\": \"ConvEncoder_Tutorial\", --> embedding network for the images\n", 240 | " \"OUT_DIM\": 128, --> dimension of the embedding\n", 241 | " \"NUM_TRANSFORM\": 3, --> number of transformations\n", 242 | " \"NUM_HIDDEN_FLOW\": 3, --> number of hidden layers in the flow\n", 243 | " \"HIDDEN_DIM_FLOW\": 128, --> dimension of the hidden layers in the flow\n", 244 | " \"MODEL\": \"NSF\", --> type of flow\n", 245 | " \"LEARNING_RATE\": 0.0003, --> learning rate\n", 246 | " \"CLIP_GRADIENT\": 5.0, --> gradient clipping\n", 247 | " \"THETA_SHIFT\": 50, --> shift used to standardize the inferred parameter\n", 248 | " \"THETA_SCALE\": 50, --> scale used to standardize the inferred parameter\n", 249 | " \"BATCH_SIZE\": 32 --> batch size\n", 250 | "}\n", 251 | "```\n" 252 | ] 253 | }, 254 | { 255 | "cell_type": "code", 256 | "execution_count": null, 257 | "metadata": {}, 258 | "outputs": [], 259 | "source": [ 260 | "train_npe_model.npe_train_no_saving(\n", 261 | " \"simulation_parameters.json\",\n", 262 | " \"training_parameters.json\",\n", 263 | " 150, # number of training epochs\n", 264 | " \"tutorial_estimator.pt\", # name of the estimator file\n", 265 | " \"tutorial.loss\", # name of the loss file\n", 266 | " n_workers=4, # number of workers for data loading\n", 267 | " device=\"cuda\", # device to use for training and simulation\n", 268 | " saving_frequency=100, # frequency of saving the model\n", 269 | " simulation_batch_size=160, # batch size for simulation\n", 270 | ")" 271 | ] 272 | }, 273 | { 274 | "cell_type": "markdown", 275 | "metadata": {}, 276 | "source": [ 277 | "##### Visualize the loss after training" 278 | ] 279 | }, 280 | { 281 | "cell_type": "code", 282 | "execution_count": null, 283 | "metadata": {}, 284 | "outputs": [], 285 | "source": [ 286 | "plt.plot(torch.load(\"tutorial.loss\"))\n", 287 | "plt.xlabel(\"Epoch\")\n", 288 | "plt.ylabel(\"Loss\")" 289 | ] 290 | }, 291 | { 292 | "cell_type": "markdown", 293 | "metadata": {}, 294 | "source": [ 295 | "### Evaluate the posterior on our simulated images\n", 296 | "\n", 297 | "We will now evaluate the trained posterior on our simulated images.\n", 298 | "For each simulated image we will infer the distance between the atoms and compare it
to the true distance, by sampling from the posterior." 299 | ] 300 | }, 301 | { 302 | "cell_type": "code", 303 | "execution_count": null, 304 | "metadata": {}, 305 | "outputs": [], 306 | "source": [ 307 | "posterior = est_utils.load_estimator(\n", 308 | " \"training_parameters.json\",\n", 309 | " \"tutorial_estimator.pt\",\n", 310 | " device=\"cuda\",\n", 311 | ")" 312 | ] 313 | }, 314 | { 315 | "cell_type": "code", 316 | "execution_count": null, 317 | "metadata": {}, 318 | "outputs": [], 319 | "source": [ 320 | "samples = est_utils.sample_posterior(\n", 321 | " estimator=posterior,\n", 322 | " images=images,\n", 323 | " num_samples=15000,\n", 324 | " batch_size=1000,\n", 325 | " device=\"cuda\",\n", 326 | ")" 327 | ] 328 | }, 329 | { 330 | "cell_type": "code", 331 | "execution_count": null, 332 | "metadata": {}, 333 | "outputs": [], 334 | "source": [ 335 | "fig, axes = plt.subplots(4, 4, figsize=(10, 10))\n", 336 | "for idx, ax in enumerate(axes.flatten()):\n", 337 | " ax.hist(samples[:, idx].flatten(), bins=np.linspace(0, simulator.max_index, 60))\n", 338 | " ax.axvline(side_length[idx], ymax=1, ymin=0, color=\"red\")\n", 339 | " ax.set_yticks([])\n", 340 | " ax.set_xticks([])" 341 | ] 342 | }, 343 | { 344 | "cell_type": "markdown", 345 | "metadata": {}, 346 | "source": [ 347 | "### Plot latent space" 348 | ] 349 | }, 350 | { 351 | "cell_type": "code", 352 | "execution_count": null, 353 | "metadata": {}, 354 | "outputs": [], 355 | "source": [ 356 | "latent_representations = est_utils.compute_latent_repr(\n", 357 | " estimator=posterior,\n", 358 | " images=images,\n", 359 | " batch_size=1000,\n", 360 | " device=\"cuda\",\n", 361 | ")" 362 | ] 363 | }, 364 | { 365 | "cell_type": "code", 366 | "execution_count": null, 367 | "metadata": {}, 368 | "outputs": [], 369 | "source": [ 370 | "reducer = umap.UMAP(metric=\"euclidean\", n_components=2, n_neighbors=50)\n", 371 | "latent_vecs_transformed = reducer.fit_transform(latent_representations.numpy())" 372 | ] 373 | }, 374 | { 375 | "cell_type": "code", 376 | "execution_count": null, 377 | "metadata": {}, 378 | "outputs": [], 379 | "source": [ 380 | "plt.scatter(\n", 381 | " latent_vecs_transformed[:, 0],\n", 382 | " latent_vecs_transformed[:, 1],\n", 383 | " c=side_length,\n", 384 | " cmap=\"viridis\",\n", 385 | " s=10,\n", 386 | ")\n", 387 | "plt.colorbar(label=\"Side length\")\n", 388 | "plt.xlabel(\"UMAP 1\")\n", 389 | "plt.ylabel(\"UMAP 2\")" 390 | ] 391 | }, 392 | { 393 | "cell_type": "code", 394 | "execution_count": null, 395 | "metadata": {}, 396 | "outputs": [], 397 | "source": [ 398 | "plt.scatter(\n", 399 | " latent_vecs_transformed[:, 0],\n", 400 | " latent_vecs_transformed[:, 1],\n", 401 | " c=snr,\n", 402 | " cmap=\"viridis\",\n", 403 | " s=10,\n", 404 | ")\n", 405 | "plt.colorbar(label=\"SNR\")\n", 406 | "plt.xlabel(\"UMAP 1\")\n", 407 | "plt.ylabel(\"UMAP 2\")" 408 | ] 409 | } 410 | ], 411 | "metadata": { 412 | "kernelspec": { 413 | "display_name": "cryo_sbi", 414 | "language": "python", 415 | "name": "cryo_sbi" 416 | }, 417 | "language_info": { 418 | "codemirror_mode": { 419 | "name": "ipython", 420 | "version": 3 421 | }, 422 | "file_extension": ".py", 423 | "mimetype": "text/x-python", 424 | "name": "python", 425 | "nbconvert_exporter": "python", 426 | "pygments_lexer": "ipython3", 427 | "version": "3.10.0" 428 | } 429 | }, 430 | "nbformat": 4, 431 | "nbformat_minor": 2 432 | } 433 | --------------------------------------------------------------------------------