├── voxelium ├── vae_volume │ ├── __init__.py │ ├── svr_linear │ │ ├── __init__.py │ │ ├── volume_extraction.h │ │ ├── trilinear_projection.h │ │ ├── volume_extraction_cuda_kernels.h │ │ ├── volume_extraction_cpu_kernels.h │ │ ├── trilinear_projection_cuda_kernels.h │ │ ├── trilinear_projection_cpu_kernels.h │ │ ├── pybind.cpp │ │ ├── base.h │ │ ├── base_cuda.cuh │ │ ├── grids.py │ │ ├── volume_extraction.cpp │ │ ├── volume_extraction_cpu_kernels.cpp │ │ ├── trilinear_projection.cpp │ │ ├── test.py │ │ ├── svr_linear.py │ │ └── volume_extraction_cuda_kernels.cu │ ├── distributed_processing.py │ ├── cache.py │ ├── region_of_interest.py │ ├── vtk_utils.py │ ├── optim.py │ ├── train_arguments.py │ ├── volume_explorer_neurips.py │ ├── volume_renderer.py │ ├── tensorboard_summary.py │ ├── utils.py │ └── volume_explorer.py ├── relion │ ├── __init__.py │ ├── so3.py │ └── dataset.py ├── __init__.py ├── base │ ├── __init__.py │ ├── io_logger.py │ ├── test.py │ ├── star_file.py │ ├── image_transforms.py │ ├── torch_utils.py │ ├── so3.py │ ├── ctf.py │ └── particle_image_preprocessor.py └── __main__.py ├── requirements.txt ├── environment.yml ├── LICENSE ├── README.md ├── setup.py └── .gitignore /voxelium/vae_volume/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """ 4 | VAE volume reconstruction module 5 | """ 6 | -------------------------------------------------------------------------------- /voxelium/relion/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | RELION module 4 | """ 5 | 6 | from .dataset import * 7 | from .so3 import * 8 | -------------------------------------------------------------------------------- /voxelium/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """ 4 | Voxelium - Cryo-EM data analysis framework 5 | """ 6 | 7 | import os 8 | 9 | __version__ = '0.0.1' -------------------------------------------------------------------------------- /voxelium/vae_volume/svr_linear/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """ 4 | Sparse volume reconstruction module 5 | """ 6 | 7 | from .svr_linear import * 8 | from .grids import * -------------------------------------------------------------------------------- /voxelium/base/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """ 4 | Base module 5 | """ 6 | 7 | from .so3 import * 8 | from .ctf import * 9 | from .grid import * 10 | from .model_container import * -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | biopython==1.81 2 | matplotlib==3.7.1 3 | mrcfile==1.4.3 4 | numpy==1.24.3 5 | pandas==2.0.2 6 | scikit-learn==1.2.2 7 | scipy==1.10.1 8 | starfile==0.4.12 9 | tensorboard==2.13.0 10 | torch==2.0.1 11 | torchvision==0.15.2 12 | tqdm==4.65.0 13 | umap-learn==0.5.3 -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: sbackprop 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - pip 7 | - setuptools=68.0.0 8 | - vtk=9.0.3 9 | - loguru=0.7.0 10 | - pip: 11 | - biopython==1.81 12 | - matplotlib==3.7.1 13 | - mrcfile==1.4.3 14 | - numpy==1.24.3 15 | - pandas==2.0.2 16 | - scikit-learn==1.2.2 17 | - scipy==1.10.1 18 | - starfile==0.4.12 19 | - tensorboard==2.13.0 20 | - torch==2.0.1 21 | - torchvision==0.15.2 22 | - tqdm==4.65.0 23 | - umap-learn==0.5.3 -------------------------------------------------------------------------------- /voxelium/base/io_logger.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """ 4 | Simple I/O pipe logging to file 5 | Can replace stdout to log all output 6 | """ 7 | 8 | import sys 9 | 10 | class IOLogger(object): 11 | def __init__(self, filename="Default.log", quiet=False): 12 | self.terminal = sys.stdout 13 | self.log = open(filename, "a") 14 | self.quiet = quiet 15 | 16 | def write(self, message): 17 | if not self.quiet: 18 | self.terminal.write(message) 19 | self.log.write(message) 20 | self.log.flush() 21 | 22 | def flush(self): 23 | if not self.quiet: 24 | self.terminal.flush() 25 | self.log.flush() -------------------------------------------------------------------------------- /voxelium/vae_volume/svr_linear/volume_extraction.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #ifndef SVR_LINEAR_VOLUME_EXTRACTION_H 3 | #define SVR_LINEAR_VOLUME_EXTRACTION_H 4 | 5 | #include 6 | #include 7 | 8 | #include 9 | #include 10 | #include 11 | 12 | #include "vae_volume/svr_linear/volume_extraction_cpu_kernels.h" 13 | #include "vae_volume/svr_linear/volume_extraction_cuda_kernels.h" 14 | 15 | torch::Tensor volume_extraction_forward( 16 | torch::Tensor input, 17 | torch::Tensor weight, 18 | torch::Tensor bias, 19 | torch::Tensor grid3d_index, 20 | const int max_r 21 | ); 22 | 23 | std::vector volume_extraction_backward( 24 | torch::Tensor input_spectral_weight, 25 | torch::Tensor input, 26 | torch::Tensor weight, 27 | torch::Tensor bias, 28 | torch::Tensor grad_output, 29 | torch::Tensor grid3d_index 30 | ); 31 | 32 | #endif // SVR_LINEAR_VOLUME_EXTRACTION_H -------------------------------------------------------------------------------- /voxelium/vae_volume/svr_linear/trilinear_projection.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include "vae_volume/svr_linear/trilinear_projection_cpu_kernels.h" 9 | #include "vae_volume/svr_linear/trilinear_projection_cuda_kernels.h" 10 | 11 | torch::Tensor trilinear_projection_forward( 12 | torch::Tensor input, 13 | torch::Tensor weight, 14 | torch::Tensor bias, 15 | torch::Tensor rot_matrix, 16 | torch::Tensor grid2d_coord, 17 | torch::Tensor grid3d_index, 18 | const int max_r 19 | ); 20 | 21 | std::vector trilinear_projection_backward( 22 | torch::Tensor input, 23 | torch::Tensor weight, 24 | torch::Tensor bias, 25 | torch::Tensor rot_matrix, 26 | torch::Tensor input_spectral_weight, 27 | torch::Tensor grad_output, 28 | torch::Tensor grid2d_coord, 29 | torch::Tensor grid3d_index, 30 | bool sparse_grad, 31 | const int max_r 32 | ); -------------------------------------------------------------------------------- /voxelium/__main__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """ 4 | Voxelium - Cryo-EM data analysis framework 5 | """ 6 | 7 | 8 | def main(): 9 | import os 10 | import argparse 11 | import voxelium 12 | parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter) 13 | parser.add_argument('--version', action='version', version=f'Voxelium {voxelium.__version__}') 14 | 15 | import voxelium.vae_volume.train 16 | 17 | modules = { 18 | "volume_train": voxelium.vae_volume.train, 19 | } 20 | 21 | subparsers = parser.add_subparsers(title='Choose a module') 22 | subparsers.required = 'True' 23 | 24 | for key in modules: 25 | module_parser = subparsers.add_parser(key, description=modules[key].__doc__) 26 | modules[key].append_args(module_parser) 27 | module_parser.set_defaults(func=modules[key].main) 28 | 29 | try: 30 | args = parser.parse_args() 31 | args.func(args) 32 | except TypeError: 33 | parser.print_help() 34 | 35 | 36 | if __name__ == '__main__': 37 | main() -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Kimanius 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /voxelium/vae_volume/svr_linear/volume_extraction_cuda_kernels.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #ifndef SVR_LINEAR_VOLUME_EXTRACTION_CUDA_KERNELS_H 3 | #define SVR_LINEAR_VOLUME_EXTRACTION_CUDA_KERNELS_H 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | void volume_extraction_forward_cuda( 14 | const torch::Tensor grid3d_index, 15 | const torch::Tensor weight, 16 | const torch::Tensor bias, 17 | const torch::Tensor input, 18 | torch::Tensor output, 19 | const int max_r2, 20 | const bool do_bias 21 | ); 22 | 23 | void volume_extraction_backward_cuda( 24 | const torch::Tensor grid3d_index, 25 | const torch::Tensor weight, 26 | const torch::Tensor bias, 27 | const torch::Tensor input_spectral_weight, 28 | const torch::Tensor input, 29 | const torch::Tensor grad_output, 30 | torch::Tensor grad_weight, 31 | torch::Tensor grad_bias, 32 | torch::Tensor grad_input, 33 | const int max_r2, 34 | const bool do_bias, 35 | const bool do_input_grad, 36 | const bool do_spectral_weighting 37 | ); 38 | 39 | #endif // SVR_LINEAR_VOLUME_EXTRACTION_CUDA_KERNELS_H -------------------------------------------------------------------------------- /voxelium/vae_volume/svr_linear/volume_extraction_cpu_kernels.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #ifndef SVR_LINEAR_VOLUME_EXTRACTION_CPU_KERNELS_H 3 | #define SVR_LINEAR_VOLUME_EXTRACTION_CPU_KERNELS_H 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | #include "vae_volume/svr_linear/base.h" 14 | 15 | void volume_extraction_forward_cpu( 16 | const torch::Tensor grid3d_index, 17 | const torch::Tensor weight, 18 | const torch::Tensor bias, 19 | const torch::Tensor input, 20 | torch::Tensor output, 21 | const int max_r2, 22 | const bool do_bias 23 | ); 24 | 25 | void volume_extraction_backward_cpu( 26 | const torch::Tensor grid3d_index, 27 | const torch::Tensor weight, 28 | const torch::Tensor bias, 29 | const torch::Tensor input_spectral_weight, 30 | const torch::Tensor input, 31 | const torch::Tensor grad_output, 32 | torch::Tensor grad_weight, 33 | torch::Tensor grad_bias, 34 | torch::Tensor grad_input, 35 | const int max_r2, 36 | const bool do_bias, 37 | const bool do_input_grad, 38 | const bool do_spectral_weighting 39 | ); 40 | 41 | #endif // SVR_LINEAR_VOLUME_EXTRACTION_CPU_KERNELS_H -------------------------------------------------------------------------------- /voxelium/vae_volume/svr_linear/trilinear_projection_cuda_kernels.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #ifndef SVR_LINEAR_TRILINEAR_PROJECTION_CUDA_KERNELS_H 3 | #define SVR_LINEAR_TRILINEAR_PROJECTION_CUDA_KERNELS_H 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | void trilinear_projection_forward_cuda( 14 | const torch::Tensor grid2d_coord, 15 | const torch::Tensor grid3d_index, 16 | const torch::Tensor weight, 17 | const torch::Tensor bias, 18 | const torch::Tensor rot_matrix, 19 | const torch::Tensor input, 20 | torch::Tensor output, 21 | const int max_r2, 22 | const int init_offset, 23 | const bool do_bias 24 | ); 25 | 26 | void trilinear_projection_backward_cuda( 27 | const torch::Tensor grid2d_coord, 28 | const torch::Tensor grid3d_index, 29 | const torch::Tensor weight, 30 | const torch::Tensor bias, 31 | torch::Tensor grad_weight_index, 32 | const torch::Tensor rot_matrix, 33 | const torch::Tensor input_spectral_weight, 34 | const torch::Tensor input, 35 | const torch::Tensor grad_output, 36 | torch::Tensor grad_weight, 37 | torch::Tensor grad_bias, 38 | torch::Tensor grad_input, 39 | torch::Tensor grad_rot_matrix, 40 | const int max_r2, 41 | const int init_offset, 42 | const bool do_bias, 43 | const bool do_grad_rot_matrix, 44 | const bool sparse_grad, 45 | const bool do_spectral_weighting 46 | ); 47 | 48 | #endif // SVR_LINEAR_TRILINEAR_PROJECTION_CUDA_KERNELS_H -------------------------------------------------------------------------------- /voxelium/vae_volume/svr_linear/trilinear_projection_cpu_kernels.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #ifndef SVR_LINEAR_TRILINEAR_PROJECTION_CPU_KERNELS_H 3 | #define SVR_LINEAR_TRILINEAR_PROJECTION_CPU_KERNELS_H 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | #include "vae_volume/svr_linear/base.h" 14 | 15 | void trilinear_projection_forward_cpu( 16 | const torch::Tensor grid2d_coord, 17 | const torch::Tensor grid3d_index, 18 | const torch::Tensor weight, 19 | const torch::Tensor bias, 20 | const torch::Tensor rot_matrix, 21 | const torch::Tensor input, 22 | torch::Tensor output, 23 | const int max_r2, 24 | const int init_offset, 25 | const bool do_bias 26 | ); 27 | 28 | void trilinear_projection_backward_cpu( 29 | const torch::Tensor grid2d_coord, 30 | const torch::Tensor grid3d_index, 31 | const torch::Tensor weight, 32 | const torch::Tensor bias, 33 | torch::Tensor grad_weight_index, 34 | const torch::Tensor rot_matrix, 35 | const torch::Tensor input_spectral_weight, 36 | const torch::Tensor input, 37 | const torch::Tensor grad_output, 38 | torch::Tensor grad_weight, 39 | torch::Tensor grad_bias, 40 | torch::Tensor grad_input, 41 | torch::Tensor grad_rot_matrix, 42 | const int max_r2, 43 | const int init_offset, 44 | const bool do_bias, 45 | const bool do_grad_rot_matrix, 46 | const bool sparse_grad, 47 | const bool do_spectral_weighting 48 | ); 49 | 50 | #endif // SVR_LINEAR_TRILINEAR_PROJECTION_CPU_KERNELS_H -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Sparse Fourier Backpropagation in Cryo-EM Reconstruction 2 | This repository contains code for the paper: [Sparse Fourier Backpropagation in Cryo-EM Reconstruction](https://proceedings.neurips.cc/paper_files/paper/2022/hash/50729453d56ecf6a8b7be78998776472-Abstract-Conference.html). 3 | Part of Advances in Neural Information Processing Systems 35 (NeurIPS 2022) Main Conference Track. 4 | 5 | # Setup a conda environment 6 | You need to setup a Python environment with dependencies. We recommend installing via Miniconda3. 7 | 8 | Once you have conda setup, you can install all the Python dependencies into a new environment by running: 9 | 10 | ```conda env create -f environment.yml``` 11 | 12 | You can then activate the conda environment by running: 13 | 14 | ```conda activate sbackprop``` 15 | 16 | # Compile and Install CUDA code 17 | Once inside the correct environment you can compile and install the CUDA dependencies by running: 18 | 19 | ```python setup.py install``` 20 | 21 | # Running Training 22 | You can then run training by running 23 | 24 | ```python voxelium/vae_volume/train.py --gpu 0``` 25 | 26 | Use ```-h``` for more options. 27 | 28 | # Visualizing Results 29 | You can then visualize the results using 30 | 31 | ```python voxelium/vae_volume/volume_explorer.py ``` 32 | 33 | # Citation 34 | ``` 35 | @article{kimanius2022sparse, 36 | title={Sparse fourier backpropagation in cryo-em reconstruction}, 37 | author={Kimanius, Dari and Jamali, Kiarash and Scheres, Sjors}, 38 | journal={Advances in Neural Information Processing Systems}, 39 | volume={35}, 40 | pages={12395--12408}, 41 | year={2022} 42 | } 43 | ``` -------------------------------------------------------------------------------- /voxelium/base/test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """ 4 | Test module for the base module 5 | """ 6 | 7 | import unittest 8 | 9 | from voxelium.base import matrix_to_quaternion, quaternion_to_matrix, is_rotation_matrix, \ 10 | euler_to_matrix, taitbryan_to_matrix 11 | from voxelium.sparse_refine.svr_linear.sparse_linear import * 12 | 13 | ERROR_EPS = 1e-5 14 | 15 | 16 | class TestSparseLinear(unittest.TestCase): 17 | def test_euler_to_matrix(self): 18 | a = self._get_random_angles(1000).double() 19 | R = euler_to_matrix(a) 20 | self.assertTrue(torch.all(is_rotation_matrix(R))) 21 | 22 | def test_taitbyant_to_matrix(self): 23 | a = self._get_random_angles(1000).double() 24 | R = taitbryan_to_matrix(a) 25 | self.assertTrue(torch.all(is_rotation_matrix(R))) 26 | 27 | def test_quaternions(self): 28 | a = self._get_random_angles(1000).double() 29 | R1 = euler_to_matrix(a) 30 | self.assertTrue(torch.all(is_rotation_matrix(R1))) 31 | Q1 = matrix_to_quaternion(R1) 32 | R2 = quaternion_to_matrix(Q1) 33 | self.assertTrue(torch.all(is_rotation_matrix(R2))) 34 | D1 = torch.abs(R1 - R2) 35 | self.assertTrue(torch.all(D1 < ERROR_EPS)) 36 | Q2 = matrix_to_quaternion(R2) 37 | D2 = torch.abs(Q1 - Q2) 38 | self.assertTrue(torch.all(D2 < ERROR_EPS)) 39 | 40 | @staticmethod 41 | def _get_random_angles(count): 42 | a12 = 2 * np.pi * torch.rand(count, 2) 43 | a3 = torch.rand((count, 1)).mul(2).sub(1).acos() 44 | return torch.cat([a12, a3], 1) 45 | 46 | if __name__ == "__main__": 47 | test = TestSparseLinear() 48 | test.test_quaternions() 49 | print("All good!") 50 | -------------------------------------------------------------------------------- /voxelium/vae_volume/svr_linear/pybind.cpp: -------------------------------------------------------------------------------- 1 | #include "vae_volume/svr_linear/trilinear_projection.h" 2 | #include "vae_volume/svr_linear/volume_extraction.h" 3 | 4 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 5 | { 6 | m.def( 7 | "trilinear_projection_forward", 8 | &trilinear_projection_forward, 9 | "Trilinear projector forward", 10 | py::arg("input"), 11 | py::arg("weight"), 12 | py::arg("bias"), 13 | py::arg("rot_matrix"), 14 | py::arg("grid2d_coord"), 15 | py::arg("grid3d_index"), 16 | py::arg("max_r") 17 | ); 18 | 19 | m.def( 20 | "trilinear_projection_backward", 21 | &trilinear_projection_backward, 22 | "Trilinear projector backward", 23 | py::arg("input"), 24 | py::arg("weight"), 25 | py::arg("bias"), 26 | py::arg("rot_matrix"), 27 | py::arg("input_spectral_weight"), 28 | py::arg("grid2d_grad"), 29 | py::arg("grid2d_coord"), 30 | py::arg("grid3d_index"), 31 | py::arg("sparse_grad"), 32 | py::arg("max_r") 33 | ); 34 | 35 | m.def( 36 | "volume_extraction_forward", 37 | &volume_extraction_forward, 38 | "Volume extraction forward", 39 | py::arg("input"), 40 | py::arg("weight"), 41 | py::arg("bias"), 42 | py::arg("grid3d_index"), 43 | py::arg("max_r") 44 | ); 45 | 46 | m.def( 47 | "volume_extraction_backward", 48 | &volume_extraction_backward, 49 | "Volume extraction backward", 50 | py::arg("input_spectral_weight"), 51 | py::arg("input"), 52 | py::arg("weight"), 53 | py::arg("bias"), 54 | py::arg("grad_output"), 55 | py::arg("grid3d_index") 56 | ); 57 | } 58 | -------------------------------------------------------------------------------- /voxelium/base/star_file.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """ 4 | Module for star-file I/O 5 | """ 6 | 7 | from collections import OrderedDict 8 | 9 | 10 | def load_star(filename): 11 | datasets = OrderedDict() 12 | current_data = None 13 | current_colnames = None 14 | 15 | BASE = 0 # Not in a block 16 | COLNAME = 1 # Parsing column name 17 | DATA = 2 # Parsing data 18 | mode = BASE 19 | 20 | for line in open(filename): 21 | line = line.strip() 22 | 23 | # remove comments 24 | comment_pos = line.find('#') 25 | if comment_pos > 0: 26 | line = line[:comment_pos] 27 | 28 | if line == "": 29 | if mode == DATA: 30 | mode = BASE 31 | continue 32 | 33 | if line.startswith("data_"): 34 | mode = BASE 35 | data_name = line[5:] 36 | current_data = OrderedDict() 37 | datasets[data_name] = current_data 38 | 39 | elif line.startswith("loop_"): 40 | current_colnames = [] 41 | mode = COLNAME 42 | 43 | elif line.startswith("_"): 44 | if mode == DATA: 45 | mode = BASE 46 | token = line[1:].split() 47 | if mode == COLNAME: 48 | current_colnames.append(token[0]) 49 | current_data[token[0]] = [] 50 | else: 51 | current_data[token[0]] = token[1] 52 | 53 | elif mode != BASE: 54 | mode = DATA 55 | token = line.split() 56 | if len(token) != len(current_colnames): 57 | raise RuntimeError( 58 | f"Error in STAR file {filename}, number of elements in {token} " 59 | f"does not match number of column names {current_colnames}" 60 | ) 61 | for idx, e in enumerate(token): 62 | current_data[current_colnames[idx]].append(e) 63 | 64 | return datasets 65 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """ 4 | Setup module for Voxelium 5 | """ 6 | 7 | import os 8 | import sys 9 | import sysconfig 10 | 11 | from setuptools import setup, find_packages 12 | from torch.utils import cpp_extension 13 | 14 | 15 | def print_debug_msg(): 16 | print("-------------------------------------- ") 17 | print("------------- DEBUG MODE ------------- ") 18 | print("-------------------------------------- ") 19 | 20 | 21 | sys.path.insert(0, f'{os.path.dirname(__file__)}/voxelium') 22 | import voxelium 23 | 24 | _DEBUG = False 25 | _DEBUG_LEVEL = 0 26 | 27 | project_root = os.path.join(os.path.realpath(os.path.dirname(__file__)), "voxelium") 28 | 29 | include_dirs = [project_root] 30 | 31 | cxx_extra_compile_args = [] 32 | nvcc_extra_compile_args = [] 33 | if _DEBUG: 34 | print_debug_msg() 35 | cxx_extra_compile_args += ["-g", "-O0", "-DDEBUG=%s" % _DEBUG_LEVEL, "-UNDEBUG"] 36 | nvcc_extra_compile_args += ["-G", "-lineinfo"] 37 | else: 38 | cxx_extra_compile_args += ["-DNDEBUG", "-O3"] 39 | nvcc_extra_compile_args += cxx_extra_compile_args 40 | 41 | ext_modules = [ 42 | cpp_extension.CUDAExtension( 43 | name='voxelium_svr_linear', 44 | sources=[ 45 | 'voxelium/vae_volume/svr_linear/pybind.cpp', 46 | 'voxelium/vae_volume/svr_linear/trilinear_projection.cpp', 47 | 'voxelium/vae_volume/svr_linear/trilinear_projection_cpu_kernels.cpp', 48 | 'voxelium/vae_volume/svr_linear/trilinear_projection_cuda_kernels.cu', 49 | 'voxelium/vae_volume/svr_linear/volume_extraction.cpp', 50 | 'voxelium/vae_volume/svr_linear/volume_extraction_cpu_kernels.cpp', 51 | 'voxelium/vae_volume/svr_linear/volume_extraction_cuda_kernels.cu', 52 | ], 53 | include_dirs=include_dirs, 54 | extra_compile_args={'cxx': cxx_extra_compile_args, 'nvcc': nvcc_extra_compile_args}, 55 | ) 56 | ] 57 | setup( 58 | name='voxelium', 59 | ext_modules=ext_modules, 60 | cmdclass={'build_ext': cpp_extension.BuildExtension}, 61 | packages=find_packages(), 62 | entry_points={ 63 | "console_scripts": [ 64 | "voxelium = voxelium.__main__:main", 65 | ], 66 | }, 67 | version=voxelium.__version__ 68 | ) 69 | 70 | if _DEBUG: 71 | print_debug_msg() 72 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | utils/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | #IDE related 132 | .idea 133 | 134 | # No MRC files 135 | *.mrc 136 | -------------------------------------------------------------------------------- /voxelium/vae_volume/svr_linear/base.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #ifndef SVR_LINEAR_BASE_H 3 | #define SVR_LINEAR_BASE_H 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #define SPECTRAL_WEIGHT_EPS 1e-6 11 | 12 | #define CHECK_SIZE_DIM0(x, SIZE) TORCH_CHECK(x.size(0) == SIZE, \ 13 | #x ".size(0) (", x.size(0), ") != " #SIZE " (", SIZE, "). In ", __FILE__, ":", __LINE__) 14 | #define CHECK_SIZE_DIM1(x, SIZE) TORCH_CHECK(x.size(1) == SIZE, \ 15 | #x ".size(1) (", x.size(1), ") != " #SIZE " (", SIZE, "). In ", __FILE__, ":", __LINE__) 16 | #define CHECK_SIZE_DIM2(x, SIZE) TORCH_CHECK(x.size(2) == SIZE, \ 17 | #x ".size(2) (", x.size(2), ") != " #SIZE " (", SIZE, "). In ", __FILE__, ":", __LINE__) 18 | #define CHECK_SIZE_DIM3(x, SIZE) TORCH_CHECK(x.size(3) == SIZE, \ 19 | #x ".size(3) (", x.size(3), ") != " #SIZE " (", SIZE, "). In ", __FILE__, ":", __LINE__) 20 | 21 | #define CHECK_DTYPE(x, DTYPE) TORCH_CHECK(x.dtype() == DTYPE, \ 22 | #x " has the wrong data type (", x.dtype(), "), expecting " \ 23 | #DTYPE, " (", DTYPE, "). In ", __FILE__, ":", __LINE__) 24 | 25 | #define CHECK_CPU(x) TORCH_CHECK(!x.is_cuda(), \ 26 | #x " must be a CPU tensor. In ", __FILE__, ":", __LINE__) 27 | 28 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), \ 29 | #x " must be contiguous. In ", __FILE__, ":", __LINE__) 30 | 31 | #define CHECK_DIM(x, DIM) TORCH_CHECK(x.dim() == DIM, \ 32 | #x " has wrong number of dimensions. In ", __FILE__, ":", __LINE__) 33 | 34 | #define CHECK_CPU_INPUT(x) CHECK_CPU(x); CHECK_CONTIGUOUS(x) 35 | 36 | /** Linear interpolation 37 | * 38 | * From low (when a=0) to high (when a=1). The following value is returned 39 | * (equal to (a*h)+((1-a)*l) 40 | */ 41 | #ifndef LIN_INTERP 42 | #define LIN_INTERP(a, l, h) ((l) + ((h) - (l)) * (a)) 43 | #endif 44 | 45 | template 46 | struct dispatch_bools 47 | { 48 | template 49 | void operator()( std::array const& input, F&& continuation, Bools... ) 50 | { 51 | if (input[max-1]) 52 | dispatch_bools{}( input, continuation, std::integral_constant{}, Bools{}... ); 53 | else 54 | dispatch_bools{}( input, continuation, std::integral_constant{}, Bools{}... ); 55 | } 56 | }; 57 | 58 | template<> 59 | struct dispatch_bools<0> 60 | { 61 | template 62 | void operator()( std::array const& input, F&& continuation, Bools... ) 63 | { 64 | continuation( Bools{}... ); 65 | } 66 | }; 67 | 68 | #endif // SVR_LINEAR_BASE_H -------------------------------------------------------------------------------- /voxelium/vae_volume/svr_linear/base_cuda.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #ifndef SVR_LINEAR_BASE_CUDA_H 3 | #define SVR_LINEAR_BASE_CUDA_H 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include "vae_volume/svr_linear/base.h" 11 | 12 | #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor. In ", __FILE__, ":", __LINE__) 13 | 14 | #define CHECK_CUDA_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 15 | 16 | #define CUDA_ERRCHK(ans) { gpuAssert((ans), __FILE__, __LINE__); } 17 | inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort=true) 18 | { 19 | if (code != cudaSuccess) 20 | { 21 | fprintf(stderr,"GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line); 22 | if (abort) exit(code); 23 | } 24 | } 25 | 26 | __device__ 27 | inline 28 | bool thread_index_expand(const size_t dim0, const size_t dim1, 29 | size_t &idx0, size_t &idx1) 30 | { 31 | const size_t index = blockIdx.x * blockDim.x + threadIdx.x; 32 | idx0 = index / dim1; 33 | idx1 = index % dim1; 34 | return idx0 < dim0; 35 | } 36 | 37 | __device__ 38 | inline 39 | bool thread_index_expand(const size_t dim0, const size_t dim1, const size_t dim2, 40 | size_t &idx0, size_t &idx1, size_t &idx2) 41 | { 42 | const size_t index = blockIdx.x * blockDim.x + threadIdx.x; 43 | idx0 = index / (dim1 * dim2); 44 | idx1 = (index / dim2) % dim1; 45 | idx2 = index % dim2; 46 | return idx0 < dim0; 47 | } 48 | 49 | __device__ 50 | inline 51 | bool thread_index_expand(const size_t dim0, const size_t dim1, const size_t dim2, const size_t dim3, 52 | size_t &idx0, size_t &idx1, size_t &idx2, size_t &idx3) 53 | { 54 | const size_t index = blockIdx.x * blockDim.x + threadIdx.x; 55 | idx0 = index / (dim1 * dim2 * dim3); 56 | idx1 = (index / (dim2 * dim3)) % dim1; 57 | idx2 = (index / dim3) % dim2; 58 | idx3 = index % dim3; 59 | return idx0 < dim0; 60 | } 61 | 62 | template 63 | __device__ 64 | inline 65 | size_t accessor_index_collapse(const accessor_t &a, size_t i0, size_t i1) 66 | { 67 | return i0 * a.stride(0) + i1; 68 | } 69 | 70 | template 71 | __device__ 72 | inline 73 | size_t accessor_index_collapse(const accessor_t &a, size_t i0, size_t i1, size_t i2) 74 | { 75 | return i0 * a.stride(0) + i1 * a.stride(1) + i2; 76 | } 77 | 78 | template 79 | __device__ 80 | inline 81 | size_t accessor_index_collapse(const accessor_t &a, size_t i0, size_t i1, size_t i2, size_t i3) 82 | { 83 | return i0 * a.stride(0) + i1 * a.stride(1) + i2 * a.stride(2) + i3; 84 | } 85 | 86 | #endif // SVR_LINEAR_BASE_CUDA_H -------------------------------------------------------------------------------- /voxelium/relion/so3.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """ 4 | Module for calculations related to rotation matrices and euler angles as defined in RELION. 5 | """ 6 | import sys 7 | import numpy as np 8 | import torch 9 | 10 | from typing import Tuple, Union, TypeVar 11 | 12 | Tensor = TypeVar('torch.tensor') 13 | 14 | 15 | def eulerToMatrix( 16 | angles: Union[Tensor, np.ndarray] 17 | ) -> Union[Tensor, np.ndarray]: 18 | """ 19 | Takes a batch of the three Euler angles as defined in RELION and 20 | returns a batch of the corresponding rotation matrices 21 | 22 | Supports both numpy arrays and torch tensor input 23 | 24 | :param angles: an array (B, 3) of the Euler angels, alpha, beta and gamma (rot, tilt, psi) 25 | :return: a 3x3 rotation matrix 26 | """ 27 | if torch.is_tensor(angles): 28 | R = torch.zeros(len(angles), 3, 3, dtype=angles.dtype).to(angles.device) 29 | ca = torch.cos(angles[:, 0]) 30 | cb = torch.cos(angles[:, 1]) 31 | cg = torch.cos(angles[:, 2]) 32 | sa = torch.sin(angles[:, 0]) 33 | sb = torch.sin(angles[:, 1]) 34 | sg = torch.sin(angles[:, 2]) 35 | else: 36 | R = np.zeros((len(angles), 3, 3), dtype=angles.dtype) 37 | ca = np.cos(angles[:, 0]) 38 | cb = np.cos(angles[:, 1]) 39 | cg = np.cos(angles[:, 2]) 40 | sa = np.sin(angles[:, 0]) 41 | sb = np.sin(angles[:, 1]) 42 | sg = np.sin(angles[:, 2]) 43 | 44 | cc = cb * ca 45 | cs = cb * sa 46 | sc = sb * ca 47 | ss = sb * sa 48 | 49 | R[:, 0, 0] = cg * cc - sg * sa 50 | R[:, 0, 1] = cg * cs + sg * ca 51 | R[:, 0, 2] = -cg * sb 52 | R[:, 1, 0] = -sg * cc - cg * sa 53 | R[:, 1, 1] = -sg * cs + cg * ca 54 | R[:, 1, 2] = sg * sb 55 | R[:, 2, 0] = sc 56 | R[:, 2, 1] = ss 57 | R[:, 2, 2] = cb 58 | 59 | return R 60 | 61 | 62 | def matrixToEuler(R: np.ndarray): 63 | """ 64 | Takes a rotation matrix and returns the 65 | three Euler angles as defined in RELION 66 | 67 | TODO: Add support for batches 68 | 69 | :param R: a 3x3 rotation matrix 70 | :return: the three Euler angles, alpha, beta and gamma (rot, tilt, psi) 71 | """ 72 | abs_sb = np.sqrt(R[0, 2] * R[0, 2] + R[1, 2] * R[1, 2]) 73 | if abs_sb > 16*sys.float_info.epsilon: 74 | gamma = np.atan2(R[1, 2], -R[0, 2]) 75 | alpha = np.atan2(R[2, 1], R[2, 0]) 76 | if np.abs(np.sin(gamma)) < sys.float_info.epsilon: 77 | sign_sb = np.sgn(-R(0, 2) / np.cos(gamma)) 78 | else: 79 | sign_sb = np.sgn(R[1, 2]) if np.sin(gamma) > 0 else -np.sgn(R[1, 2]) 80 | beta = np.atan2(sign_sb * abs_sb, R[2, 2]) 81 | 82 | else: 83 | if R[2, 2] > 0: 84 | alpha = 0 85 | beta = 0 86 | gamma = np.atan2(-R[1, 0], R[0, 0]) 87 | else: 88 | alpha = 0 89 | beta = np.pi 90 | gamma = np.atan2(R[1, 0], -R[0, 0]) 91 | 92 | return alpha, beta, gamma 93 | -------------------------------------------------------------------------------- /voxelium/vae_volume/svr_linear/grids.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """ 4 | Module for the sparse linear layer 5 | """ 6 | import copy 7 | import time 8 | from typing import TypeVar, Union 9 | import numpy as np 10 | import torch 11 | 12 | Tensor = TypeVar('torch.tensor') 13 | 14 | 15 | def make_compact_grid2d(size: int = None, max_r: int = None): 16 | """ 17 | Makes image grid coordinates and indices. 18 | Used by sparse project to output projection images out of 3D grids. 19 | Must provide either max_r or size. If both are given max_r will be ignored. 20 | For even box size: img_size = max_r * 2 - 2 <=> max_r = floor(img_size / 2) + 1 21 | For odd box size: img_size = max_r * 2 - 1 <=> max_r = floor(img_size / 2) + 1 22 | :param size: Size of the grid containing a max_r circle (not including max_r) 23 | :param max_r: Max radius of circle contained by image grid. 24 | """ 25 | if max_r is None: 26 | if size is None: 27 | raise RuntimeError("Either max_r or size must be given.") 28 | if size % 2 == 0: 29 | size += 1 30 | max_r = (size - 1) // 2 31 | if size is None: 32 | if max_r is None: 33 | raise RuntimeError("Either max_r or size must be given.") 34 | size = max_r * 2 + 1 35 | 36 | size_2 = size // 2 37 | 38 | # Make xy-plane grid coordinates 39 | ls = torch.linspace(-size_2, size_2, size) 40 | lsx = torch.linspace(0, size_2, size_2 + 1) 41 | y, x = torch.meshgrid(ls, lsx, indexing='ij') 42 | coord = torch.stack([x, y], 2).view(-1, 2) 43 | 44 | # We need to work with explicit indices, flatten coordinate grid 45 | radius = torch.sqrt(torch.sum(torch.square(coord), -1)) 46 | 47 | # Mask out beyond Nyqvist in 2D grid 48 | mask = radius <= max_r 49 | mask = mask.flatten() 50 | 51 | coord = coord[mask].contiguous() 52 | 53 | # import matplotlib 54 | # import matplotlib.pylab as plt 55 | # matplotlib.use('TkAgg') 56 | # plt.plot(coord.data[:, 0].numpy(), coord.data[:, 1].numpy(), '.', alpha=0.3) 57 | # plt.show() 58 | 59 | coord.require_grad = False 60 | 61 | return coord, mask 62 | 63 | 64 | def make_grid3d(size: int = None, max_r: int = None): 65 | """ 66 | Makes volume grid coordinates and indices. 67 | Used by sparse project to output projection images out of 3D grids. 68 | Must provide either max_r or size. If both are given max_r will be ignored. 69 | Note: img_size = max_r * 2 + 1 <=> max_r = floor(img_size / 2) 70 | :param size: Size of the grid containing a max_r circle 71 | :param max_r: Max radius of circle contained by image grid. 72 | """ 73 | if max_r is None: 74 | if size is None: 75 | raise RuntimeError("Either max_r or size must be given.") 76 | if size % 2 == 0: 77 | size += 1 78 | max_r = (size - 1) // 2 79 | if size is None: 80 | if max_r is None: 81 | raise RuntimeError("Either max_r or size must be given.") 82 | size = max_r * 2 + 1 83 | 84 | size_2 = size // 2 85 | 86 | # Make xy-plane grid coordinates 87 | ls = torch.linspace(-size_2, size_2, size) 88 | lsx = torch.linspace(0, size_2, size_2 + 1) 89 | coord = torch.stack(torch.meshgrid(ls, ls, lsx, indexing='ij'), 3) 90 | 91 | # Mask out beyond Nyqvist in 2D grid 92 | radius = torch.sqrt(torch.sum(torch.square(coord), -1)) 93 | mask = radius <= max_r 94 | 95 | return coord, mask 96 | -------------------------------------------------------------------------------- /voxelium/vae_volume/distributed_processing.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """ 4 | Test module for a training VAE 5 | """ 6 | import os 7 | import sys 8 | from typing import List, TypeVar, Any, Union 9 | 10 | import numpy as np 11 | import torch 12 | import torch.distributed as dist 13 | import torch.nn as nn 14 | import torch.optim as optim 15 | import torch.multiprocessing as mp 16 | from torch.nn.parallel import DistributedDataParallel as DDP 17 | 18 | 19 | Tensor = TypeVar('torch.tensor') 20 | 21 | 22 | class DistributedProcessing: 23 | _doing_ddp = False 24 | _this_rank = None 25 | _this_device = torch.device('cpu') 26 | _this_device_id = None 27 | 28 | @staticmethod 29 | def global_setup(args, main_fn, verbose=True) -> None: 30 | rank_gpu_map = [] 31 | 32 | if args.gpu is not None: 33 | queried_gpu_ids = args.gpu.split(",") 34 | for i in range(len(queried_gpu_ids)): 35 | gpu_id = int(queried_gpu_ids[i].strip()) 36 | try: 37 | gpu_name = torch.cuda.get_device_name(gpu_id) 38 | except AssertionError: 39 | if verbose: 40 | print(f'WARNING: GPU with the device id "{gpu_id}" not found.', file=sys.stderr) 41 | continue 42 | if verbose: 43 | print(f'Found device "{gpu_name}"') 44 | rank_gpu_map.append(gpu_id) 45 | 46 | if len(rank_gpu_map) > 0: 47 | if verbose: 48 | print("Running on GPU with device id(s)", *rank_gpu_map) 49 | else: 50 | if verbose: 51 | print(f'WARNING: no GPUs were found with the specified ids.', file=sys.stderr) 52 | 53 | if len(rank_gpu_map) == 0: 54 | if verbose: 55 | print("Running on CPU") 56 | 57 | world_size = max(len(rank_gpu_map), 1) 58 | ddp_args = {'world_size': world_size, 'rank_gpu_map': rank_gpu_map} 59 | 60 | if world_size > 1: 61 | mp.spawn(main_fn, args=(args, ddp_args), nprocs=world_size, join=True) 62 | else: 63 | main_fn(rank=0, args=args, ddp_args=ddp_args) 64 | 65 | @staticmethod 66 | def process_setup(rank, args) -> None: 67 | 68 | if len(args['rank_gpu_map']) > 0: 69 | device_id = args['rank_gpu_map'][rank] 70 | DistributedProcessing._this_device_id = device_id 71 | DistributedProcessing._this_device = torch.device('cuda:' + str(device_id)) 72 | else: 73 | DistributedProcessing._this_device = torch.device('cpu') 74 | 75 | if args['world_size'] > 1: 76 | os.environ['MASTER_ADDR'] = 'localhost' 77 | os.environ['MASTER_PORT'] = '12355' 78 | 79 | # initialize the process group 80 | dist.init_process_group("gloo", rank=rank, world_size=args['world_size']) 81 | DistributedProcessing._doing_ddp = True 82 | DistributedProcessing._this_rank = rank 83 | 84 | print(f"Process {rank} initialized") 85 | 86 | dist.barrier() 87 | 88 | @staticmethod 89 | def process_cleanup() -> None: 90 | if DistributedProcessing._doing_ddp: 91 | dist.destroy_process_group() 92 | 93 | @staticmethod 94 | def doing_ddp() -> bool: 95 | return DistributedProcessing._doing_ddp 96 | 97 | @staticmethod 98 | def is_rank_zero() -> bool: 99 | return DistributedProcessing._this_rank == 0 100 | 101 | @staticmethod 102 | def get_rank() -> int: 103 | return DistributedProcessing._this_rank 104 | 105 | @staticmethod 106 | def get_device() -> Any: 107 | return DistributedProcessing._this_device 108 | 109 | @staticmethod 110 | def get_device_id() -> int: 111 | return DistributedProcessing._this_device_id 112 | 113 | @staticmethod 114 | def setup_module(module: torch.nn.Module) -> Union[torch.nn.Module, DDP]: 115 | module = module.to(DistributedProcessing.get_device()) 116 | if DistributedProcessing.doing_ddp(): 117 | module = DDP(module, device_ids=[DistributedProcessing.get_device_id()]) 118 | return module 119 | -------------------------------------------------------------------------------- /voxelium/vae_volume/svr_linear/volume_extraction.cpp: -------------------------------------------------------------------------------- 1 | #include "vae_volume/svr_linear/volume_extraction.h" 2 | 3 | torch::Tensor volume_extraction_forward( 4 | torch::Tensor input, 5 | torch::Tensor weight, 6 | torch::Tensor bias, 7 | torch::Tensor grid3d_index, 8 | const int max_r 9 | ) 10 | { 11 | const int batch_size = input.size(0); 12 | const int img_side = max_r * 2 + 1; 13 | const bool do_bias = bias.size(0) == weight.size(0); 14 | 15 | auto output = torch::zeros( 16 | {batch_size, img_side, img_side, img_side / 2 + 1, 2}, 17 | torch::TensorOptions() 18 | .dtype(input.dtype()) 19 | .layout(torch::kStrided) 20 | .device(input.device()) 21 | .requires_grad(true) 22 | ); 23 | 24 | if (input.device().type() == torch::kCPU) 25 | { 26 | volume_extraction_forward_cpu( 27 | /*grid3d_index*/ grid3d_index, 28 | /*weight*/ weight, 29 | /*bias*/ bias, 30 | /*input*/ input, 31 | /*output*/ output, 32 | /*max_r2*/ (int) max_r * max_r, 33 | /*do_bias*/ do_bias 34 | ); 35 | } 36 | else if (input.device().type() == torch::kCUDA) 37 | { 38 | volume_extraction_forward_cuda( 39 | /*grid3d_index*/ grid3d_index, 40 | /*weight*/ weight, 41 | /*bias*/ bias, 42 | /*input*/ input, 43 | /*output*/ output, 44 | /*max_r2*/ (int) max_r * max_r, 45 | /*do_bias*/ do_bias 46 | ); 47 | } 48 | else 49 | throw std::logic_error("Support for device not implemented"); 50 | 51 | return output; 52 | } 53 | 54 | 55 | std::vector volume_extraction_backward( 56 | torch::Tensor input_spectral_weight, 57 | torch::Tensor input, 58 | torch::Tensor weight, 59 | torch::Tensor bias, 60 | torch::Tensor grad_output, 61 | torch::Tensor grid3d_index 62 | ) 63 | { 64 | const int img_side = grad_output.size(1); 65 | const int max_r = (img_side - 1) / 2; 66 | const bool do_bias = bias.size(0) == weight.size(0); 67 | 68 | CHECK_SIZE_DIM0(grad_output, input.size(0)) 69 | 70 | auto grad_weight = torch::zeros_like(weight); 71 | auto grad_bias = torch::zeros_like(bias); 72 | auto grad_input = torch::zeros_like(input); 73 | 74 | TORCH_CHECK(input_spectral_weight.size(0) == 0 || input_spectral_weight.size(0) > max_r, 75 | "input_spectral_weight.size(0) bad size (", input_spectral_weight.size(0), " <= ", 76 | max_r, "). In ", __FILE__, ":", __LINE__) 77 | 78 | if (input.device().type() == torch::kCPU) 79 | { 80 | volume_extraction_backward_cpu( 81 | /*grid3d_index*/ grid3d_index, 82 | /*weight*/ weight, 83 | /*bias*/ bias, 84 | /*input_spectral_weight*/ input_spectral_weight, 85 | /*input*/ input, 86 | /*grad_output*/ grad_output, 87 | /*grad_weight*/ grad_weight, 88 | /*grad_bias*/ grad_bias, 89 | /*grad_input*/ grad_input, 90 | /*max_r2*/ (int) max_r * max_r, 91 | /*do_bias*/ do_bias, 92 | /*do_input_grad*/ input.requires_grad(), 93 | /*do_spectral_weighting*/ input_spectral_weight.size(0) > 0 94 | ); 95 | } 96 | else if (input.device().type() == torch::kCUDA) 97 | { 98 | volume_extraction_backward_cuda( 99 | /*grid3d_index*/ grid3d_index, 100 | /*weight*/ weight, 101 | /*bias*/ bias, 102 | /*input_spectral_weight*/ input_spectral_weight, 103 | /*input*/ input, 104 | /*grad_output*/ grad_output, 105 | /*grad_weight*/ grad_weight, 106 | /*grad_bias*/ grad_bias, 107 | /*grad_input*/ grad_input, 108 | /*max_r2*/ (int) max_r * max_r, 109 | /*do_bias*/ do_bias, 110 | /*do_input_grad*/ input.requires_grad(), 111 | /*do_spectral_weighting*/ input_spectral_weight.size(0) > 0 112 | ); 113 | } 114 | else 115 | throw std::logic_error("Support for device not implemented"); 116 | 117 | return {grad_input, grad_weight, grad_bias}; 118 | } 119 | -------------------------------------------------------------------------------- /voxelium/vae_volume/cache.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """ 4 | Test module for a training VAE 5 | """ 6 | import sys 7 | from typing import List, TypeVar, Union, Tuple, Any 8 | 9 | import numpy as np 10 | import torch 11 | 12 | Tensor = TypeVar('torch.tensor') 13 | 14 | from voxelium.base import smooth_square_mask, smooth_circular_mask, get_spectral_indices, dt_symmetrize 15 | 16 | 17 | class Cache: 18 | square_masks = {} 19 | circular_masks = {} 20 | spectral_indices = {} 21 | encoder_input_masks = {} 22 | 23 | @staticmethod 24 | def _get_square_mask(image_size: int, thickness: float) -> Tensor: 25 | return torch.Tensor( 26 | smooth_square_mask( 27 | image_size=image_size, 28 | square_side=image_size - thickness * 2, 29 | thickness=thickness 30 | ) 31 | ) 32 | 33 | @staticmethod 34 | def get_square_mask(image_size: int, thickness: float, device: Any = 'cpu') -> Tensor: 35 | tag = str(image_size) + "_" + str(thickness) + "_" + str(device) 36 | if tag not in Cache.square_masks: 37 | Cache.square_masks[tag] = Cache._get_square_mask(image_size, thickness).to(device) 38 | return Cache.square_masks[tag] 39 | 40 | @staticmethod 41 | def apply_square_mask(input: Tensor, thickness: float) -> Tensor: 42 | return input * Cache.get_square_mask(input.shape[-1], thickness, input.device)[None, ...] 43 | 44 | @staticmethod 45 | def _get_circular_mask(image_size: int, radius: float, thickness: float) -> Tensor: 46 | return torch.Tensor( 47 | smooth_circular_mask( 48 | image_size=image_size, 49 | radius=radius, 50 | thickness=thickness 51 | ) 52 | ) 53 | 54 | @staticmethod 55 | def get_circular_mask(image_size: int, radius: float, thickness: float, device: Any = 'cpu') -> Tensor: 56 | tag = str(image_size) + "_" + str(radius) + "_" + str(thickness) + "_" + str(device) 57 | if tag not in Cache.circular_masks: 58 | Cache.circular_masks[tag] = Cache._get_circular_mask(image_size, radius, thickness).to(device) 59 | return Cache.circular_masks[tag] 60 | 61 | @staticmethod 62 | def apply_circular_mask(input: Tensor, radius: float, thickness: float) -> Tensor: 63 | return input * Cache.get_circular_mask(input.shape[-1], radius, thickness, input.device)[None, ...] 64 | 65 | @staticmethod 66 | def _get_spectral_indices( 67 | shape: Union[Tuple[int, int], Tuple[int, int, int]], numpy: bool = False, max_r: int = None 68 | ) -> Union[Tensor, np.ndarray]: 69 | if shape[0] != shape[1]: 70 | out = get_spectral_indices((shape[0], shape[0])) 71 | out = out[:, shape[0]//2:] 72 | else: 73 | out = get_spectral_indices(shape) 74 | if max_r is not None: 75 | out[out > max_r] = max_r 76 | if not numpy: 77 | out = torch.Tensor(out) 78 | return out 79 | 80 | @staticmethod 81 | def get_spectral_indices( 82 | shape: Union[Tuple[int, int], Tuple[int, int, int]], 83 | numpy: bool = False, 84 | device: Any = 'cpu', 85 | max_r: int = None 86 | ) -> Union[Tensor, np.ndarray]: 87 | tag = str(shape) + "_" + str(max_r) 88 | tag += "_np" if numpy else "_" + str(device) 89 | if tag not in Cache.spectral_indices: 90 | Cache.spectral_indices[tag] = Cache._get_spectral_indices(shape, numpy, max_r) 91 | if not numpy: 92 | Cache.spectral_indices[tag] = Cache.spectral_indices[tag].to(device) 93 | return Cache.spectral_indices[tag] 94 | 95 | @staticmethod 96 | def _get_encoder_input_mask(image_size: int, max_r: int = None) -> Tensor: 97 | spectral_indices = Cache._get_spectral_indices((image_size, image_size)) 98 | if max_r is None: 99 | return spectral_indices < image_size // 2 + 1 100 | else: 101 | return spectral_indices < max_r 102 | 103 | @staticmethod 104 | def get_encoder_input_mask(image_size: int, max_r: int = None, device: Any = 'cpu') -> Tensor: 105 | tag = str(image_size) + "_" + str(max_r) + "_" + str(device) 106 | if tag not in Cache.encoder_input_masks: 107 | Cache.encoder_input_masks[tag] = Cache._get_encoder_input_mask(image_size, max_r).to(device) 108 | return Cache.encoder_input_masks[tag] 109 | 110 | @staticmethod 111 | def apply_encoder_input_mask(input: Tensor, max_r: int = None, device: Any = 'cpu') -> Tensor: 112 | return input[None, Cache.get_encoder_input_mask(input.shape[-1], max_r, device)] 113 | -------------------------------------------------------------------------------- /voxelium/vae_volume/region_of_interest.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """ 4 | Container for dataset analysis results 5 | """ 6 | 7 | import os 8 | import shutil 9 | import sys 10 | import numpy as np 11 | from typing import Dict, List, TypeVar, Tuple, Any 12 | 13 | from voxelium.vae_volume.cache import Cache 14 | from voxelium.vae_volume.svr_linear import make_grid3d 15 | 16 | Tensor = TypeVar('torch.tensor') 17 | 18 | import torch 19 | import torch.nn.functional as F 20 | 21 | from voxelium.base import euler_to_matrix, matrix_to_quaternion, ContrastTransferFunction, quaternion_to_matrix, \ 22 | get_spectral_avg, spectrum_to_grid, load_mrc, dt_desymmetrize, idft 23 | 24 | 25 | class RegionOfInterest: 26 | def __init__(self, dac, mask_fn, resolution, latent_fraction, device): 27 | self.vae_container = dac.vae_container 28 | self.latent_size = self.vae_container.latent_size 29 | self.pp_latent_size = self.vae_container.decoder.get_postprocess_input_size() 30 | self.reconst_latent_size = self.latent_size - self.pp_latent_size 31 | 32 | image_size = dac.auxiliaries["image_size"] 33 | pixel_size = dac.auxiliaries["pixel_size"] 34 | self.image_max_r = dac.auxiliaries["image_max_r"] 35 | 36 | roi_res_idx = round(image_size * pixel_size / resolution) # Convert resolution to spectral index 37 | self.roi_roni_switch = True # If true, apply ROI, if false apply RONI 38 | 39 | roi_latent_size = round(self.reconst_latent_size * np.clip(latent_fraction, 0., 1.)) 40 | roi_latent_index = np.arange(self.reconst_latent_size) 41 | self.roi_latent_index, self.roni_latent_index = \ 42 | roi_latent_index[:roi_latent_size], roi_latent_index[roi_latent_size:] 43 | 44 | print(f"{roi_latent_size} latent dimensions assigned to ROI") 45 | print("Using ROI resolution cutoff at Fourier shell index", roi_res_idx) 46 | roi_mask_, roi_voxel_size, _ = load_mrc(mask_fn) 47 | if np.abs(pixel_size - roi_voxel_size) > 1e-1: 48 | print(f"WARNING: ROI mask voxel size ({round(roi_voxel_size, 2)}) " 49 | f"is not equal to data ({round(pixel_size, 2)}).", file=sys.stderr) 50 | 51 | if np.min(roi_mask_) < 0. or 1. < np.max(roi_mask_): 52 | print(f"WARNING: ROI mask has values outside range [0,1].", file=sys.stderr) 53 | 54 | if np.all(roi_mask_.shape == image_size): 55 | raise RuntimeError(f"ROI mask size ({roi_mask_.shape}) not equal to data size ({image_size})") 56 | 57 | grid3d_radius = torch.round(torch.sqrt(torch.sum(torch.square(make_grid3d(size=image_size).to(device)), -1))) 58 | self.roi_res_mask = grid3d_radius < roi_res_idx 59 | roi_mask_ = torch.Tensor(roi_mask_.copy()).to(device) 60 | redund = -4. # Mask redundancy coefficient 61 | if redund != 0: 62 | c = np.exp(redund) / (np.exp(redund) - 1) 63 | self.roi_mask = c * (1 - torch.exp(-redund * roi_mask_)) 64 | self.roni_mask = c * (1 - torch.exp(-redund * (1-roi_mask_))) 65 | else: 66 | self.roi_mask = roi_mask_ 67 | self.roni_mask = 1 - roi_mask_ 68 | 69 | self.roi_slice = None 70 | self.roni_slice = None 71 | 72 | dac.auxiliaries["roi_latent_index"] = self.roi_latent_index 73 | dac.auxiliaries["roni_latent_index"] = self.roni_latent_index 74 | 75 | def get_loss(self, z_selected_): 76 | z_selected = torch.zeros_like(z_selected_) 77 | 78 | if self.roi_roni_switch: 79 | z_selected[self.roi_latent_index] = \ 80 | z_selected_[self.roi_latent_index] # Dedicate to inside mask 81 | else: 82 | z_selected[self.roni_latent_index] = \ 83 | z_selected_[self.roni_latent_index] # Dedicate to outside mask 84 | 85 | z_selected[-self.pp_latent_size:] = z_selected_[-self.pp_latent_size:].detach() # Include postprocess 86 | z_selected = z_selected.unsqueeze(0) 87 | v_ft = self.vae_container.decoder(z_selected, self.image_max_r) 88 | v_ft = torch.view_as_complex(v_ft) 89 | v_ft[0, self.roi_res_mask] = v_ft[0, self.roi_res_mask].detach() 90 | v_ft = dt_desymmetrize(v_ft)[0] 91 | vol = idft(v_ft, dim=3, real_in=True) 92 | mask = self.roni_mask if self.roi_roni_switch else self.roi_mask 93 | masked_mean = torch.sum(vol * mask) / torch.sum(mask) 94 | roi_loss = np.prod(vol.shape[-2:]) * torch.mean(mask * torch.square(vol - masked_mean)) 95 | 96 | if self.roi_roni_switch: 97 | self.roi_slice = vol[vol.shape[0] // 2].detach().cpu().numpy() 98 | else: 99 | self.roni_slice = vol[vol.shape[0] // 2].detach().cpu().numpy() 100 | 101 | self.roi_roni_switch = not self.roi_roni_switch # Alternate between inside and outside mask 102 | 103 | return roi_loss -------------------------------------------------------------------------------- /voxelium/base/image_transforms.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """ 4 | Module for pytorch transformations of particle images 5 | """ 6 | 7 | from typing import Tuple, TypeVar, Dict, Any, Union 8 | 9 | import numpy as np 10 | 11 | import torch 12 | 13 | from voxelium.base.grid import dft, idft, dht, idht, spectrum_to_grid, get_spectral_indices 14 | 15 | Tensor = TypeVar('torch.tensor') 16 | 17 | 18 | class SquareMask(object): 19 | """Applies a square mask to the image 20 | """ 21 | 22 | def __init__(self, image_size, square_side, thickness): 23 | square_side_2 = square_side / 2. 24 | y, x = np.meshgrid( 25 | np.linspace(-image_size // 2, image_size // 2 - 1, image_size), 26 | np.linspace(-image_size // 2, image_size // 2 - 1, image_size) 27 | ) 28 | p = np.max([np.abs(x), np.abs(y)], axis=0) 29 | band_mask = (square_side_2 <= p) & (p <= square_side_2 + thickness) 30 | p_band_mask = p[band_mask] 31 | self.mask = np.zeros((image_size, image_size)) 32 | self.mask[p < square_side_2] = 1 33 | self.mask[band_mask] = np.cos(np.pi * (p_band_mask - square_side_2) / thickness) / 2 + .5 34 | self.mask[square_side_2 + thickness < p] = 0 35 | self.mask = torch.Tensor(self.mask) 36 | 37 | def __call__(self, image): 38 | self.mask = self.mask.to(image.device) 39 | return image * self.mask 40 | 41 | def __repr__(self): 42 | return self.__class__.__name__ + '()' 43 | 44 | 45 | class CircularMask(object): 46 | """Applies a circular smooth mask to the image 47 | """ 48 | 49 | def __init__(self, image_size, radius, thickness): 50 | y, x = np.meshgrid( 51 | np.linspace(-image_size // 2, image_size // 2 - 1, image_size), 52 | np.linspace(-image_size // 2, image_size // 2 - 1, image_size) 53 | ) 54 | r = np.sqrt(x ** 2 + y ** 2) 55 | band_mask = (radius <= r) & (r <= radius + thickness) 56 | r_band_mask = r[band_mask] 57 | self.mask = np.zeros((image_size, image_size)) 58 | self.mask[r < radius] = 1 59 | self.mask[band_mask] = np.cos(np.pi * (r_band_mask - radius) / thickness) / 2 + .5 60 | self.mask[radius + thickness < r] = 0 61 | self.mask = torch.Tensor(self.mask) 62 | 63 | def __call__(self, image): 64 | self.mask = self.mask.to(image.device) 65 | return image * self.mask 66 | 67 | def __repr__(self): 68 | return self.__class__.__name__ + '()' 69 | 70 | 71 | class CenterDFT(object): 72 | """Applies a centered DFT to the image 73 | """ 74 | 75 | def __call__(self, image): 76 | return dft(image) 77 | 78 | def __repr__(self): 79 | return self.__class__.__name__ + '()' 80 | 81 | 82 | class CenterIDFT(object): 83 | """Applies a centered inverse DFT to the image 84 | """ 85 | 86 | def __call__(self, image): 87 | return idft(image) 88 | 89 | def __repr__(self): 90 | return self.__class__.__name__ + '()' 91 | 92 | 93 | class CenterDHT(object): 94 | """Applies a centered DHT to the image 95 | """ 96 | 97 | def __call__(self, image): 98 | return dht(image) 99 | 100 | def __repr__(self): 101 | return self.__class__.__name__ + '()' 102 | 103 | 104 | class CenterIDHT(object): 105 | """Applies a centered inverse DHT to the image 106 | """ 107 | 108 | def __call__(self, image): 109 | return idht(image) 110 | 111 | def __repr__(self): 112 | return self.__class__.__name__ + '()' 113 | 114 | 115 | def _spectral_standardization_stats(image_size, mean, std): 116 | device = mean.device 117 | cutoff_idx = image_size // 2 + 1 118 | spectral_indices = get_spectral_indices([image_size] * 2) 119 | 120 | mean = spectrum_to_grid(mean.cpu().numpy(), spectral_indices) 121 | std = spectrum_to_grid(std.cpu().numpy(), spectral_indices) 122 | mask = (spectral_indices < cutoff_idx) & (std > 1e-5) 123 | 124 | mean = torch.Tensor(mean).to(device) 125 | std = torch.Tensor(std).to(device) 126 | mask = torch.Tensor(mask).to(device) 127 | 128 | return mean, std, mask 129 | 130 | 131 | class SpectralStandardize(object): 132 | """Applies a spectral standardization 133 | """ 134 | 135 | def __init__(self, image_size, mean, std): 136 | self.mean, self.std, self.mask = _spectral_standardization_stats(image_size, mean, std) 137 | 138 | def __call__(self, image): 139 | return ((image - self.mean[None, ...]) / self.std[None, ...]) * self.mask[None, ...] 140 | 141 | def __repr__(self): 142 | return self.__class__.__name__ + '()' 143 | 144 | 145 | class SpectralDestandardize(object): 146 | """Inverses an applied spectral standardization 147 | """ 148 | 149 | def __init__(self, image_size, mean, std): 150 | self.mean, self.std, self.mask = _spectral_standardization_stats(image_size, mean, std) 151 | 152 | def __call__(self, image): 153 | return (image * self.std[None, ...] + self.mean[None, ...]) * self.mask[None, ...] 154 | 155 | def __repr__(self): 156 | return self.__class__.__name__ + '()' 157 | -------------------------------------------------------------------------------- /voxelium/vae_volume/vtk_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """ 4 | Test module for a training VAE 5 | """ 6 | 7 | import numpy as np 8 | import torch 9 | import vtk 10 | # noinspection PyUnresolvedReferences 11 | import vtkmodules.vtkInteractionStyle 12 | # noinspection PyUnresolvedReferences 13 | import vtkmodules.vtkRenderingOpenGL2 14 | from sklearn.manifold import TSNE 15 | from vtkmodules.vtkCommonColor import vtkNamedColors 16 | # noinspection PyUnresolvedReferences 17 | from vtkmodules.vtkCommonCore import vtkVersion 18 | from vtkmodules.vtkCommonDataModel import vtkImageData 19 | from vtkmodules.vtkFiltersCore import ( 20 | vtkFlyingEdges3D, 21 | vtkMarchingCubes 22 | ) 23 | from vtkmodules.vtkRenderingCore import ( 24 | vtkActor, 25 | vtkPolyDataMapper, 26 | vtkRenderWindow, 27 | vtkRenderWindowInteractor, 28 | vtkRenderer 29 | ) 30 | from vtkmodules.vtkFiltersSources import vtkCylinderSource 31 | 32 | # noinspection PyUnresolvedReferences 33 | from vtk.util import numpy_support 34 | 35 | DEFAULT_VOL_COLOR = (200/255., 200/255., 200/255.) 36 | 37 | def make_cylinder_actor(): 38 | colors = vtkNamedColors() 39 | # Set the background color. 40 | bkg = map(lambda x: x / 255.0, [26, 51, 102, 255]) 41 | colors.SetColor("BkgColor", *bkg) 42 | 43 | # This creates a polygonal cylinder model with eight circumferential 44 | # facets. 45 | cylinder = vtkCylinderSource() 46 | cylinder.SetResolution(8) 47 | 48 | # The mapper is responsible for pushing the geometry into the graphics 49 | # library. It may also do color mapping, if scalars or other 50 | # attributes are defined. 51 | cylinderMapper = vtkPolyDataMapper() 52 | cylinderMapper.SetInputConnection(cylinder.GetOutputPort()) 53 | 54 | # The actor is a grouping mechanism: besides the geometry (mapper), it 55 | # also has a property, transformation matrix, and/or texture map. 56 | # Here we set its color and rotate it -22.5 degrees. 57 | cylinderActor = vtkActor() 58 | cylinderActor.SetMapper(cylinderMapper) 59 | cylinderActor.GetProperty().SetColor(colors.GetColor3d("Tomato")) 60 | cylinderActor.RotateX(30.0) 61 | cylinderActor.RotateY(-45.0) 62 | return cylinderActor 63 | 64 | def numpy_volume_as_vtk_image_data(source_numpy_array): 65 | output_vtk_image = vtkImageData() 66 | output_vtk_image.SetDimensions(source_numpy_array.shape[1], source_numpy_array.shape[0], source_numpy_array.shape[2]) 67 | 68 | vtk_type_by_numpy_type = { 69 | np.uint8: vtk.VTK_UNSIGNED_CHAR, 70 | np.uint16: vtk.VTK_UNSIGNED_SHORT, 71 | np.uint32: vtk.VTK_UNSIGNED_INT, 72 | np.uint64: vtk.VTK_UNSIGNED_LONG if vtk.VTK_SIZEOF_LONG == 64 else vtk.VTK_UNSIGNED_LONG_LONG, 73 | np.int8: vtk.VTK_CHAR, 74 | np.int16: vtk.VTK_SHORT, 75 | np.int32: vtk.VTK_INT, 76 | np.int64: vtk.VTK_LONG if vtk.VTK_SIZEOF_LONG == 64 else vtk.VTK_LONG_LONG, 77 | np.float32: vtk.VTK_FLOAT, 78 | np.float64: vtk.VTK_DOUBLE 79 | } 80 | vtk_datatype = vtk_type_by_numpy_type[source_numpy_array.dtype.type] 81 | depth_array = numpy_support.numpy_to_vtk( 82 | source_numpy_array.ravel(), deep=True, array_type=vtk_datatype) 83 | depth_array.SetNumberOfComponents(1) 84 | output_vtk_image.GetPointData().SetScalars(depth_array) 85 | 86 | output_vtk_image.Modified() 87 | return output_vtk_image 88 | 89 | 90 | def vtk_version_ok(major, minor, build): 91 | """ 92 | Check the VTK version. 93 | 94 | :param major: Major version. 95 | :param minor: Minor version. 96 | :param build: Build version. 97 | :return: True if the requested VTK version is greater or equal to the actual VTK version. 98 | """ 99 | needed_version = 10000000000 * int(major) + 100000000 * int(minor) + int(build) 100 | ver = vtkVersion() 101 | vtk_version_number = 10000000000 * ver.GetVTKMajorVersion() + 100000000 * ver.GetVTKMinorVersion() \ 102 | + ver.GetVTKBuildVersion() 103 | if vtk_version_number >= needed_version: 104 | return True 105 | else: 106 | return False 107 | 108 | 109 | def get_vtk_volume_to_surface(): 110 | use_flying_edges = vtk_version_ok(8, 2, 0) 111 | if use_flying_edges: 112 | try: 113 | return vtkFlyingEdges3D() 114 | except AttributeError: 115 | return vtkMarchingCubes() 116 | else: 117 | return vtkMarchingCubes() 118 | 119 | 120 | def make_volume_actor(volume, iso_value, color=DEFAULT_VOL_COLOR): 121 | surface = get_vtk_volume_to_surface() 122 | surface.SetInputData(volume) 123 | surface.ComputeNormalsOn() 124 | surface.SetValue(0, iso_value) 125 | 126 | mapper = vtkPolyDataMapper() 127 | mapper.SetInputConnection(surface.GetOutputPort()) 128 | mapper.ScalarVisibilityOff() 129 | 130 | actor = vtkActor() 131 | actor.SetMapper(mapper) 132 | actor.GetProperty().SetColor(color) 133 | 134 | return actor 135 | 136 | 137 | def make_all_volume_actors(volumes, iso_value, color=DEFAULT_VOL_COLOR): 138 | actors = [] 139 | for vol in volumes: 140 | actors.append(make_volume_actor(vol, iso_value, color=color)) 141 | return actors 142 | 143 | 144 | def initialize_vtk_resourses(windowName=None, background_color=(1., 1., 1.)): 145 | render_window = vtkRenderWindow() 146 | if windowName is not None: 147 | render_window.SetWindowName(windowName) 148 | 149 | renderer = vtkRenderer() 150 | renderer.SetBackground(background_color) 151 | render_window.AddRenderer(renderer) 152 | 153 | interactor = vtkRenderWindowInteractor() 154 | interactor.SetRenderWindow(render_window) 155 | interactor.SetInteractorStyle(vtk.vtkInteractorStyleTrackballCamera()) 156 | 157 | return render_window, renderer, interactor 158 | 159 | def rgb_hex_to_dec(hex): 160 | r = int(f"0x{hex[0] + hex[1]}", 16) / 255. 161 | g = int(f"0x{hex[2] + hex[3]}", 16) / 255. 162 | b = int(f"0x{hex[4] + hex[5]}", 16) / 255. 163 | 164 | return r, g, b -------------------------------------------------------------------------------- /voxelium/vae_volume/optim.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import List, TypeVar 3 | 4 | import torch 5 | from torch.optim.optimizer import Optimizer 6 | 7 | 8 | Tensor = TypeVar('torch.tensor') 9 | 10 | 11 | class ExtendedAdam(Optimizer): 12 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): 13 | if not 0.0 < lr: 14 | raise ValueError("Invalid learning rate: {}".format(lr)) 15 | if not 0.0 < eps: 16 | raise ValueError("Invalid epsilon value: {}".format(eps)) 17 | if not 0.0 <= betas[0] < 1.0: 18 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 19 | if not 0.0 <= betas[1] < 1.0: 20 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 21 | 22 | params = list(params) 23 | 24 | sparse_params = [] 25 | for index, param in enumerate(params): 26 | if isinstance(param, dict): 27 | for d_index, d_param in enumerate(param.get("params", [])): 28 | if d_param.is_sparse: 29 | sparse_params.append([index, d_index]) 30 | elif param.is_sparse: 31 | sparse_params.append(index) 32 | if sparse_params: 33 | raise ValueError( 34 | f"Sparse params at indices {sparse_params}: SparseAdam requires dense parameter tensors" 35 | ) 36 | 37 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 38 | super(ExtendedAdam, self).__init__(params, defaults) 39 | 40 | @torch.no_grad() 41 | def step(self, closure=None, regularize=0): 42 | loss = None 43 | if closure is not None: 44 | with torch.enable_grad(): 45 | loss = closure() 46 | 47 | for group in self.param_groups: 48 | beta1, beta2 = group['betas'] 49 | 50 | for p in group['params']: 51 | if p.grad is None: 52 | continue 53 | grad = p.grad 54 | 55 | state = self.state[p] 56 | 57 | # State initialization 58 | if len(state) == 0: 59 | state['step'] = 0 60 | # Exponential moving average of gradient values 61 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) 62 | # Exponential moving average of squared gradient values 63 | state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 64 | 65 | state['step'] += 1 66 | exp_avg = state['exp_avg'] 67 | exp_avg_sq = state['exp_avg_sq'] 68 | step = state['step'] 69 | 70 | bias_correction1 = 1 - beta1 ** step 71 | bias_correction2 = 1 - beta2 ** step 72 | 73 | if grad.is_sparse: 74 | if regularize > 0: 75 | raise NotImplementedError("no_moment not supported for sparse gradient.") 76 | 77 | grad = grad.coalesce() # the update is non-linear so indices must be unique 78 | grad_indices = grad._indices() 79 | grad_values = grad._values() 80 | size = grad.size() 81 | 82 | def make_sparse(values): 83 | constructor = grad.new 84 | if grad_indices.dim() == 0 or values.dim() == 0: 85 | return constructor().resize_as_(grad) 86 | return constructor(grad_indices, values, size) 87 | 88 | # Decay the first and second moment running average coefficient 89 | # old <- b * old + (1 - b) * new 90 | # <==> old += (1 - b) * (new - old) 91 | old_exp_avg_values = exp_avg.sparse_mask(grad)._values() 92 | exp_avg_update_values = grad_values.sub(old_exp_avg_values).mul_(1 - beta1) 93 | exp_avg.add_(make_sparse(exp_avg_update_values)) 94 | old_exp_avg_sq_values = exp_avg_sq.sparse_mask(grad)._values() 95 | exp_avg_sq_update_values = grad_values.pow(2).sub_(old_exp_avg_sq_values).mul_(1 - beta2) 96 | exp_avg_sq.add_(make_sparse(exp_avg_sq_update_values)) 97 | 98 | # Dense addition again is intended, avoiding another sparse_mask 99 | numer = exp_avg_update_values.add_(old_exp_avg_values) 100 | exp_avg_sq_update_values.add_(old_exp_avg_sq_values) 101 | denom = exp_avg_sq_update_values.sqrt_().add_(group['eps']) 102 | del exp_avg_update_values, exp_avg_sq_update_values 103 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 104 | 105 | if group['weight_decay'] != 0: 106 | p.add_( 107 | -step_size * ( 108 | make_sparse(numer.div_(denom)) + 109 | group['weight_decay'] * p.sparse_mask(grad) 110 | ) 111 | ) 112 | else: 113 | p.add_(-step_size * make_sparse(numer.div_(denom))) 114 | else: 115 | if regularize > 0: 116 | p.add_(grad, alpha=-regularize) 117 | else: 118 | if group['weight_decay'] != 0: 119 | grad = grad.add(p, alpha=group['weight_decay']) 120 | 121 | # Decay the first and second moment running average coefficient 122 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 123 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) 124 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) 125 | 126 | step_size = group['lr'] / bias_correction1 127 | 128 | p.addcdiv_(exp_avg, denom, value=-step_size) 129 | 130 | return loss 131 | -------------------------------------------------------------------------------- /voxelium/vae_volume/train_arguments.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def parse_training_args(): 5 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 6 | 7 | parser.add_argument('input', help='input job (job directory or optimizer-file)', type=str) 8 | parser.add_argument('log_dir', type=str, metavar='log_dir', help='path to load a model') 9 | parser.add_argument('--particle_diameter', help='size of circular mask (ang)', type=int, default=None) 10 | parser.add_argument('--circular_mask_thickness', help='thickness of mask (ang)', type=int, default=20) 11 | parser.add_argument('--batch_size', type=int, default=256) 12 | parser.add_argument('--overwrite', action='store_true') 13 | parser.add_argument('--gpu', dest='gpu', type=str, default=None, help='gpu to use') 14 | parser.add_argument('--checkpoint_time', help='Minimum time in minutes between checkpoint saves', type=int, 15 | default=10) 16 | parser.add_argument("--image_steps", type=int, default=500, help="Generate a tensorboard image every n steps") 17 | parser.add_argument('--max_steps', dest='max_steps', type=int, default=int(1e9), help='number of steps to train') 18 | parser.add_argument('--max_epochs', dest='max_epochs', type=int, default=int(1e9), help='number of epochs to train') 19 | parser.add_argument('--pytorch_threads', type=int, default=6) 20 | parser.add_argument('--preload', action='store_true') 21 | parser.add_argument('--dataloader_threads', type=int, default=4) 22 | parser.add_argument( 23 | '--do_align', 24 | help='Do optimize pose and translation', 25 | action="store_true" 26 | ) 27 | parser.add_argument( 28 | '--do_ctf_optimization', 29 | help='Do optimize CTF defocuse and angle', 30 | action="store_true" 31 | ) 32 | parser.add_argument( 33 | '--roi_mask', 34 | help='An MRC-file containing values between 0 and 1. Ones (1) for region of interest (ROI).', 35 | type=str, default=None 36 | ) 37 | parser.add_argument( 38 | '--roi_latent_fraction', 39 | help='Fraction of structural latent dimensions assigned to ROI.', 40 | type=float, default=.8 41 | ) 42 | parser.add_argument( 43 | '--roi_res', 44 | help='Lowest resolution at which ROI is considered.', 45 | type=float, default=80. 46 | ) 47 | parser.add_argument( 48 | '--solvent_mask', 49 | help='MRC file with ones in the region that is not solvent (region of interest)', 50 | type=str, default=None 51 | ) 52 | parser.add_argument( 53 | '--solvent_mask_res', 54 | help='Lowest resolution at which solvent mask is considered.', 55 | type=float, default=None 56 | ) 57 | parser.add_argument( 58 | '--spectral_weight_ll_res', 59 | help='Spectral weighting factor resolution for the log-likelihood', 60 | type=float, default=None 61 | ) 62 | parser.add_argument( 63 | '--spectral_weight_grad_res', 64 | help='Spectral weighting factor resolution for the gradient', 65 | type=float, default=None 66 | ) 67 | parser.add_argument('--profile_runtime', action='store_true') 68 | parser.add_argument( 69 | '--latent_regularization', 70 | help='Latent space global KL divergence regularization parameter', 71 | type=float, default=1 72 | ) 73 | parser.add_argument( 74 | '--encoder_mask_resolution', 75 | help='The highest resolution frequency component shown to the encoder. Set to zero to use all image.', 76 | type=float, default=0. 77 | ) 78 | parser.add_argument( 79 | '--encoder_embedding_size', 80 | help='Encoder image group index embedding dimensionality. Set to zero to disable.', 81 | type=int, default=8 82 | ) 83 | 84 | parser.add_argument( 85 | '--sb_latent_size', 86 | help='Structure basis network input size.', 87 | type=int, default=16 88 | ) 89 | parser.add_argument( 90 | '--pp_latent_size', 91 | help='Postprocessing network input size.', 92 | type=int, default=2 93 | ) 94 | parser.add_argument( 95 | '--sb_basis_size', 96 | help='Structure basis network output size.', 97 | type=int, default=16 98 | ) 99 | parser.add_argument( 100 | '--pp_basis_size', 101 | help='Postprocessing network input size.', 102 | type=int, default=2 103 | ) 104 | 105 | parser.add_argument( 106 | '--basis_decoder_depth', 107 | help='Basis decoder network depth (nr layers).', 108 | type=int, default=0 109 | ) 110 | parser.add_argument( 111 | '--basis_decoder_width', 112 | help='Basis decoder network width.', 113 | type=int, default=2 114 | ) 115 | parser.add_argument( 116 | '--encoder_depth', 117 | help='Encoder network depth.', 118 | type=int, default=5 119 | ) 120 | parser.add_argument( 121 | '--structure_decoder_lr', 122 | help='Learning rate of the structure decoder', 123 | type=float, default=0.0001 124 | ) 125 | parser.add_argument( 126 | "--gradient-penalty", 127 | "--gradient_penalty", 128 | help="Use gradient penalty for the basis decoder", 129 | action="store_true" 130 | ) 131 | parser.add_argument('--do_sigma_weighting', action='store_true') 132 | parser.add_argument( 133 | "--dtype", 134 | type=str, 135 | default="float32", 136 | help="Data type used for storing images in data set" 137 | ) 138 | parser.add_argument('--dont_postprocess', action='store_true') 139 | parser.add_argument('--only_update_representation', action='store_true') 140 | parser.add_argument( 141 | "--random-subset", 142 | "--random_subset", 143 | type=int, 144 | default=None, 145 | help="Which Relion random subset to use. Defaults to all." 146 | ) 147 | parser.add_argument( 148 | "--cache", 149 | type=str, 150 | default=None, 151 | help="Cache directory" 152 | ) 153 | 154 | return parser.parse_args() 155 | -------------------------------------------------------------------------------- /voxelium/base/torch_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """ 4 | Simple Pytorch tools 5 | """ 6 | 7 | import argparse 8 | import glob 9 | import importlib.util 10 | import os 11 | import pickle 12 | import sys 13 | import time 14 | 15 | import matplotlib 16 | import numpy as np 17 | import torch 18 | import matplotlib.pyplot as plt 19 | from matplotlib.lines import Line2D 20 | 21 | 22 | def standardize(np_input): 23 | mean = np.mean(np_input, axis=(1, 2, 3, 4)) 24 | mean = np.resize(mean, (np_input.shape[0], 1, 1, 1, 1)) 25 | std = np.std(np_input, axis=(1, 2, 3, 4)) + 1e-12 26 | std = np.resize(std, (np_input.shape[0], 1, 1, 1, 1)) 27 | return mean, std 28 | 29 | 30 | def torch_standardize(torch_input): 31 | mean = torch.mean(torch_input, dim=(1, 2, 3, 4)) 32 | mean = torch.reshape(mean, (torch_input.shape[0], 1, 1, 1, 1)) 33 | std = torch.std(torch_input, dim=(1, 2, 3, 4)) + 1e-12 34 | std = torch.reshape(std, (torch_input.shape[0], 1, 1, 1, 1)) 35 | return mean, std 36 | 37 | 38 | def normalize(np_input): 39 | norm = np.sqrt(np.sum(np.square(np_input), axis=(1, 2, 3, 4))) + 1e-12 40 | norm = np.resize(norm, (np_input.shape[0], 1, 1, 1, 1)) 41 | return norm 42 | 43 | 44 | def torch_normalize(torch_input): 45 | norm = torch.sqrt(torch.sum((torch_input) ** 2, dim=(1, 2, 3, 4))) + 1e-12 46 | norm = torch.reshape(norm, (torch_input.shape[0], 1, 1, 1, 1)) 47 | return norm 48 | 49 | 50 | def make_imshow_fig(data): 51 | if len(data.shape) == 3: 52 | data = data[data.shape[0] // 2] 53 | 54 | if type(data).__module__ == 'torch': 55 | data = data.detach().data.cpu().numpy() 56 | 57 | backend = matplotlib.rcParams['backend'] 58 | matplotlib.use('pdf') # To avoid issues with disconnected X-server over ssh 59 | 60 | fig, ax = plt.subplots(figsize=(7, 7)) 61 | ax.imshow(data) 62 | plt.axis("off") 63 | ax.set_axis_off() 64 | plt.subplots_adjust(top=1, bottom=0, right=1, left=0, 65 | hspace=0, wspace=0) 66 | plt.margins(0, 0) 67 | 68 | matplotlib.use(backend) 69 | 70 | return fig 71 | 72 | 73 | def make_scatter_fig(x, y): 74 | if type(x).__module__ == 'torch': 75 | x = x.detach().data.cpu().numpy() 76 | if type(y).__module__ == 'torch': 77 | y = y.detach().data.cpu().numpy() 78 | 79 | backend = matplotlib.rcParams['backend'] 80 | matplotlib.use('pdf') # To avoid issues with disconnected X-server over ssh 81 | 82 | fig, ax = plt.subplots(figsize=(7, 7)) 83 | alpha = min(10./np.sqrt(len(x)), 1.) 84 | ax.scatter(x, y, edgecolors=None, marker='.', c=np.arange(len(x)), cmap="summer", alpha=alpha) 85 | 86 | mx = np.mean(x) 87 | sx = np.std(x)*3 88 | my = np.mean(y) 89 | sy = np.std(y)*3 90 | 91 | ax.set_xlim([mx-sx, mx+sx]) 92 | ax.set_ylim([my-sy, my+sy]) 93 | 94 | plt.subplots_adjust(top=0.99, bottom=0.05, right=0.99, left=0.1, 95 | hspace=0, wspace=0) 96 | plt.margins(0, 0) 97 | # plt.show() 98 | 99 | matplotlib.use(backend) 100 | return fig 101 | 102 | 103 | def make_line_fig(x, y, y_log=False): 104 | if type(x).__module__ == 'torch': 105 | x = x.detach().data.cpu().numpy() 106 | if type(y).__module__ == 'torch': 107 | y = y.detach().data.cpu().numpy() 108 | 109 | backend = matplotlib.rcParams['backend'] 110 | matplotlib.use('pdf') # To avoid issues with disconnected X-server over ssh 111 | 112 | fig, ax = plt.subplots(figsize=(7, 5)) 113 | ax.plot(x, y) 114 | 115 | if y_log: 116 | ax.set_yscale('log') 117 | 118 | plt.subplots_adjust(top=0.99, bottom=0.05, right=0.99, left=0.1, 119 | hspace=0, wspace=0) 120 | plt.margins(0, 0) 121 | # plt.show() 122 | 123 | matplotlib.use(backend) 124 | return fig 125 | 126 | 127 | def make_series_line_fig(data, y_log=False): 128 | backend = matplotlib.rcParams['backend'] 129 | matplotlib.use('pdf') # To avoid issues with disconnected X-server over ssh 130 | 131 | fig, ax = plt.subplots(figsize=(7, 5)) 132 | for d in data: 133 | x = d['x'] 134 | y = d['y'] 135 | if type(x).__module__ == 'torch': 136 | x = x.detach().data.cpu().numpy() 137 | if type(y).__module__ == 'torch': 138 | y = y.detach().data.cpu().numpy() 139 | ax.plot(x, y, 140 | label=d['label'], 141 | color=d['color'] if 'color' in d else None, 142 | linestyle=d['linestyle'] if 'linestyle' in d else None 143 | ) 144 | 145 | ax.legend() 146 | ax.grid() 147 | 148 | if y_log: 149 | ax.set_yscale('log') 150 | 151 | plt.subplots_adjust(top=0.99, bottom=0.05, right=0.99, left=0.1, 152 | hspace=0, wspace=0) 153 | plt.margins(0, 0) 154 | # plt.show() 155 | 156 | matplotlib.use(backend) 157 | return fig 158 | 159 | 160 | def optimizer_to(optim, device): 161 | for param in optim.state.values(): 162 | # Not sure there are any global tensors in the state dict 163 | if isinstance(param, torch.Tensor): 164 | param.data = param.data.to(device) 165 | if param._grad is not None: 166 | param._grad.data = param._grad.data.to(device) 167 | elif isinstance(param, dict): 168 | for subparam in param.values(): 169 | if isinstance(subparam, torch.Tensor): 170 | subparam.data = subparam.data.to(device) 171 | if subparam._grad is not None: 172 | subparam._grad.data = subparam._grad.data.to(device) 173 | 174 | 175 | def optimizer_set_learning_rate(optimizer, learning_rate): 176 | for param_group in optimizer.param_groups: 177 | param_group['lr'] = learning_rate 178 | 179 | 180 | def plot_grad_flow(named_parameters): 181 | """ 182 | Plots the gradients flowing through different layers in the net during training. 183 | Can be used for checking for possible gradient vanishing / exploding problems. 184 | 185 | Usage: Plug this function in Trainer class after loss.backwards() as 186 | "plot_grad_flow(self.model.named_parameters())" to visualize the gradient flow 187 | """ 188 | ave_grads = [] 189 | max_grads = [] 190 | layers = [] 191 | for n, p in named_parameters: 192 | if (p.requires_grad) and ("bias" not in n): 193 | layers.append(n) 194 | ave_grads.append(p.grad.abs().mean().cpu()) 195 | max_grads.append(p.grad.abs().max().cpu()) 196 | # plt.barh(np.arange(len(max_grads)), max_grads, alpha=0.1, lw=1, color="c") 197 | plt.barh(np.arange(len(max_grads)), ave_grads, alpha=0.1, lw=1, color="b") 198 | plt.yticks(range(0, len(ave_grads), 1), layers) 199 | plt.xlabel("average gradient") 200 | plt.ylabel("Layers") 201 | plt.title("Gradient flow") 202 | plt.grid(True) 203 | plt.legend([Line2D([0], [0], color="c", lw=4), 204 | Line2D([0], [0], color="b", lw=4)], 205 | ['max-gradient', 'mean-gradient']) 206 | plt.show() -------------------------------------------------------------------------------- /voxelium/vae_volume/volume_explorer_neurips.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import sys 3 | import time 4 | 5 | import umap 6 | import numpy as np 7 | import argparse 8 | 9 | import scipy.ndimage 10 | import matplotlib.pylab as plt 11 | 12 | import torch 13 | 14 | import multiprocessing as mp 15 | 16 | from voxelium.base.grid import idft, dt_desymmetrize, load_mrc, save_mrc, get_fsc_torch, spectral_correlation_torch, \ 17 | get_spectral_indices, dft 18 | from voxelium.vae_volume.data_analysis_container import DatasetAnalysisContainer 19 | from voxelium.vae_volume.utils import setup_device 20 | 21 | if __name__ == "__main__": 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('logdir', help='input checkpoint file', type=str) 24 | parser.add_argument('ordered_ground_truth_files', 25 | help='comma separated paths to ground truth files, in order', type=str) 26 | parser.add_argument('--gpu', type=str, default=None, help='gpu to use') 27 | parser.add_argument('--dont_cache_embed', action="store_true") 28 | parser.add_argument('--ignore_cached_embed', action="store_true") 29 | parser.add_argument('--mask', help='input mask file', type=str, default=None) 30 | args = parser.parse_args() 31 | 32 | torch.no_grad() 33 | device, _ = setup_device(args) 34 | for epoch in range(1, 50): 35 | dac = DatasetAnalysisContainer.load_from_logdir(args.logdir + f"chkpt_{epoch}.pt") 36 | latent = dac.hidden_variable_container.latent_space.numpy().astype(np.float32) 37 | # embed = dac.hidden_variable_container.structure_basis.softmax(dim=-1).numpy().astype(np.float32) 38 | embed = dac.hidden_variable_container.latent_space.numpy().astype(np.float32) 39 | 40 | if latent.shape[-1] > 2: 41 | if dac.auxiliaries is None or "embed" not in dac.auxiliaries or args.ignore_cached_embed: 42 | print("Creating 2D embedding...") 43 | 44 | basis = embed 45 | # basis = embed[:, :16] 46 | 47 | embed = umap.UMAP(local_connectivity=1, repulsion_strength=2).fit_transform( 48 | basis.astype(np.float32) 49 | ) 50 | 51 | # import tsnecuda 52 | # embed = tsnecuda.TSNE(n_components=2, num_neighbors=256, perplexity=100).fit_transform( 53 | # basis.astype(np.float32) 54 | # ) 55 | 56 | # from sklearn.decomposition import PCA 57 | # embed = PCA(n_components=2).fit_transform( 58 | # basis.astype(np.float32) 59 | # ) 60 | 61 | if dac.auxiliaries is None: 62 | dac.auxiliaries = {"embed": embed} 63 | else: 64 | dac.auxiliaries["embed"] = embed 65 | if not args.dont_cache_embed: 66 | print("Saving embedding to checkpoint file...") 67 | dac.save_to_checkpoint(args.logdir) 68 | else: 69 | embed = dac.auxiliaries["embed"] 70 | 71 | # Else, import what the GUI needs 72 | 73 | import vtkmodules.vtkInteractionStyle 74 | # noinspection PyUnresolvedReferences 75 | import vtkmodules.vtkRenderingOpenGL2 76 | 77 | N = 3000 78 | margin = 50 79 | marker_size = 10 80 | 81 | outlier_mask = np.abs(embed) > 5 82 | if np.sum(outlier_mask) < len(embed) * 0.3: 83 | embed[outlier_mask] = np.sign(embed[outlier_mask]) * 5 84 | 85 | x_min = np.min(embed[:, 0]) 86 | x_max = np.max(embed[:, 0]) 87 | y_min = np.min(embed[:, 1]) 88 | y_max = np.max(embed[:, 1]) 89 | 90 | x = (embed[:, 0] - x_min) / (x_max - x_min) 91 | y = (embed[:, 1] - y_min) / (y_max - y_min) 92 | 93 | x = margin + x * (N - 2. * margin) 94 | y = margin + y * (N - 2. * margin) 95 | 96 | heat_map = np.zeros((N, N)) 97 | heat_map[y.astype(int), x.astype(int)] += 1 98 | heat_map_smooth = scipy.ndimage.gaussian_filter(heat_map, 3) 99 | 100 | coord = np.zeros((len(x), 2)) 101 | coord[:, 0] = x 102 | coord[:, 1] = y 103 | 104 | vaec = dac.vae_container 105 | vaec.set_device(device) 106 | vaec.set_eval() 107 | 108 | data_spectra, data_ctf_spectra = dac.hidden_variable_container.get_data_stats(0) 109 | data_ctf_spectra = data_ctf_spectra.to(device) 110 | 111 | nn_time = 0 112 | ft_time = 0 113 | 114 | 115 | def get_volume(z): 116 | global nn_time, ft_time 117 | z = torch.Tensor(z).unsqueeze(0).to(device) 118 | 119 | t = time.time() 120 | sb, pp = vaec.basis_decoder(*vaec.split_latent_space(z)) 121 | v_ft = vaec.structure_decoder(sb, pp, data_spectra=data_ctf_spectra, do_postprocess=True) 122 | nn_time = nn_time * 0.9 + (time.time() - t) * 0.1 123 | 124 | v_ht = torch.view_as_complex(v_ft) 125 | v_ht = dt_desymmetrize(v_ht) 126 | 127 | t = time.time() 128 | vol = idft(v_ht, dim=3, real_in=True) 129 | ft_time = ft_time * 0.9 + (time.time() - t) * 0.1 130 | 131 | vol /= torch.std(vol) 132 | 133 | print("NN time", nn_time, "FT time", ft_time) 134 | 135 | return vol[0] 136 | 137 | mask = None 138 | if args.mask is not None: 139 | mask, _, _ = load_mrc(args.mask) 140 | mask = torch.from_numpy(mask.copy()) 141 | 142 | 143 | ground_truth_path = args.ordered_ground_truth_files.split(",") 144 | ground_truth_count = len(ground_truth_path) 145 | ground_truth_grids = [] 146 | print(f"Number of ground truth files {len(ground_truth_path)}") 147 | for i in range(ground_truth_count): 148 | grid, _, _ = load_mrc(ground_truth_path[i].strip()) 149 | ground_truth_grids.append(torch.from_numpy(grid.copy())) 150 | 151 | print("Calculate FSCs") 152 | count_per_ground_truth = 3 153 | selected_idx = (np.linspace(0, 0.999, ground_truth_count * count_per_ground_truth) * len(latent)).astype(int) 154 | fscs = [] 155 | sidx = torch.Tensor(get_spectral_indices(ground_truth_grids[0].shape)) 156 | for i in range(len(selected_idx)): 157 | j = i // count_per_ground_truth 158 | idx = selected_idx[i] 159 | print(i, j, idx) 160 | vol = get_volume(latent[idx]) 161 | if mask is not None: 162 | vol *= mask 163 | gt = dft(ground_truth_grids[j], dim=3) 164 | save_mrc(vol, 1, [0, 0, 0], f"first_dumps/idx_{i}.mrc") 165 | vol = dft(vol, dim=3) 166 | fsc = spectral_correlation_torch(vol, gt, sidx, normalize=True) 167 | fscs.append(fsc) 168 | 169 | fscs = torch.stack(list(fscs), 0) 170 | means = torch.mean(fscs, 0).detach().cpu().numpy().round(2) 171 | stds = torch.std(fscs, 0).detach().cpu().numpy().round(3) 172 | 173 | np.save(f"sb_fsc_{epoch}", fscs.numpy()) 174 | 175 | 176 | 177 | 178 | -------------------------------------------------------------------------------- /voxelium/vae_volume/volume_renderer.py: -------------------------------------------------------------------------------- 1 | import multiprocessing as mp 2 | import time 3 | 4 | import numpy as np 5 | 6 | from voxelium.vae_volume.vtk_utils import initialize_vtk_resourses, make_volume_actor, numpy_volume_as_vtk_image_data, rgb_hex_to_dec 7 | 8 | 9 | class VolumeRenderer: 10 | def __init__(self, queue, windowName=None, timer=100): 11 | self.queue = queue 12 | self.timer = timer 13 | self.nr_vols = 0 14 | self.volumes = None 15 | self.rock_ascend = True 16 | self.iso_min = None 17 | self.iso_max = None 18 | self.iso_steps = None 19 | self.iso_value = None 20 | self.actors = None 21 | self.current_actor = None 22 | self.current_actor_idx = 0 23 | 24 | self.render_window, self.renderer, self.interactor = initialize_vtk_resourses( 25 | windowName=windowName 26 | ) 27 | 28 | def setVolumes(self, volumes): 29 | if volumes is None: 30 | return 31 | self.nr_vols = len(volumes) 32 | if self.nr_vols == 0: 33 | self.volumes = None 34 | return 35 | 36 | self.volumes = [] 37 | for v in volumes: 38 | self.volumes.append(numpy_volume_as_vtk_image_data(v)) 39 | 40 | self.rock_ascend = True 41 | self.iso_min = float(np.min(volumes[0])) 42 | self.iso_max = float(np.max(volumes[0])) 43 | self.iso_steps = float(np.std(volumes[0]) / 2.) 44 | if self.iso_value is None: 45 | self.iso_value = float(np.mean(volumes[0]) * 4.) * 4. 46 | 47 | def updateActors(self): 48 | self.removeCurrentActor() 49 | if self.nr_vols == 0: 50 | self.actors = None 51 | else: 52 | self.actors = [] 53 | for vol in self.volumes: 54 | self.actors.append(make_volume_actor(vol, self.iso_value, color=rgb_hex_to_dec("c596fb"))) 55 | if self.current_actor_idx is not None: 56 | self.setCurrentActor(min(self.current_actor_idx, self.nr_vols-1)) 57 | else: 58 | self.setCurrentActor(0) 59 | 60 | def removeCurrentActor(self): 61 | if self.current_actor is not None: 62 | self.renderer.RemoveActor(self.current_actor) 63 | self.current_actor = None 64 | 65 | def setCurrentActor(self, new_actor_idx=None): 66 | new_actor_idx = 0 if new_actor_idx is None else new_actor_idx 67 | self.current_actor_idx = new_actor_idx 68 | 69 | if self.actors is not None \ 70 | and self.current_actor is not None and \ 71 | self.actors[new_actor_idx] == self.current_actor: 72 | return 73 | 74 | self.removeCurrentActor() 75 | 76 | if self.actors is not None: 77 | self.current_actor = self.actors[new_actor_idx] 78 | self.renderer.AddActor(self.current_actor) 79 | 80 | self.render_window.Render() 81 | 82 | def updateCurrentActorIndex(self): 83 | if self.nr_vols == 1: 84 | self.current_actor_idx = 0 85 | self.rock_ascend = True 86 | return 87 | 88 | self.current_actor_idx += 1 if self.rock_ascend else -1 89 | 90 | if self.current_actor_idx < 0: 91 | self.rock_ascend = True 92 | self.current_actor_idx = 1 93 | 94 | if self.current_actor_idx >= self.nr_vols: 95 | self.rock_ascend = False 96 | self.current_actor_idx = self.nr_vols - 2 97 | 98 | def KeyPressEvent(self, obj, _): 99 | if self.nr_vols == 0: 100 | return 101 | key = obj.GetKeySym() 102 | if key == "Up" or key == "Down": 103 | if key == "Up": 104 | self.iso_value = max(self.iso_min, self.iso_value - self.iso_steps) 105 | elif key == "Down": 106 | self.iso_value = min(self.iso_max, self.iso_value + self.iso_steps) 107 | 108 | self.updateCurrentActorIndex() 109 | self.updateActors() 110 | 111 | if key == "Return": 112 | images = [] 113 | import matplotlib 114 | import vtk 115 | from vtk.util.numpy_support import vtk_to_numpy 116 | import imageio 117 | for i in range(self.nr_vols): 118 | self.setCurrentActor(i) 119 | vtk_win_im = vtk.vtkWindowToImageFilter() 120 | vtk_win_im.SetInput(self.render_window) 121 | vtk_win_im.Update() 122 | 123 | vtk_image = vtk_win_im.GetOutput() 124 | 125 | width, height, _ = vtk_image.GetDimensions() 126 | vtk_array = vtk_image.GetPointData().GetScalars() 127 | components = vtk_array.GetNumberOfComponents() 128 | 129 | arr = vtk_to_numpy(vtk_array).reshape(height, width, components) 130 | 131 | fn = f'dump_{i}.png' 132 | images.append(fn) 133 | arr = np.flip(arr, 0) 134 | matplotlib.image.imsave(fn, arr) 135 | 136 | with imageio.get_writer(f'dump.gif', mode='I') as writer: 137 | for i in range(len(images)): 138 | image = imageio.imread(images[i]) 139 | writer.append_data(image) 140 | for i in range(len(images)-1): 141 | image = imageio.imread(images[len(images)-i-1]) 142 | writer.append_data(image) 143 | 144 | def TimerEvent(self, obj, _): 145 | if not self.queue.empty(): 146 | task = self.queue.get() 147 | if task is None: 148 | self.render_window.Finalize() 149 | self.interactor.TerminateApp() 150 | return 151 | self.setVolumes(task) 152 | self.updateActors() 153 | elif self.nr_vols == 0: 154 | self.removeCurrentActor() 155 | return 156 | 157 | if self.actors is None: 158 | return 159 | 160 | self.setCurrentActor(self.current_actor_idx) 161 | self.updateCurrentActorIndex() 162 | 163 | def start(self): 164 | # Wait for first volume 165 | while True: 166 | if not self.queue.empty(): 167 | self.setVolumes(self.queue.get()) 168 | self.updateActors() 169 | break 170 | else: 171 | time.sleep(0.1) 172 | 173 | self.interactor.Initialize() 174 | self.interactor.AddObserver('TimerEvent', self.TimerEvent) 175 | self.interactor.AddObserver("KeyPressEvent", self.KeyPressEvent) 176 | self.interactor.CreateRepeatingTimer(self.timer) 177 | 178 | self.render_window.Render() 179 | self.interactor.Start() 180 | self.render_window.Finalize() 181 | self.interactor.TerminateApp() 182 | 183 | @staticmethod 184 | def startNewProcess(queue): 185 | vr = VolumeRenderer(queue) 186 | vr.start() 187 | 188 | 189 | def volumeRendererProcessLoop(volume_queue, message_queue): 190 | while True: 191 | if not message_queue.empty(): 192 | return 193 | if not volume_queue.empty(): 194 | p = mp.Process( 195 | target=VolumeRenderer.startNewProcess, 196 | args=(volume_queue,) 197 | ) 198 | p.start() 199 | p.join() 200 | p.terminate() 201 | else: 202 | try: 203 | time.sleep(0.1) 204 | except KeyboardInterrupt: 205 | return 206 | -------------------------------------------------------------------------------- /voxelium/vae_volume/svr_linear/volume_extraction_cpu_kernels.cpp: -------------------------------------------------------------------------------- 1 | #include "vae_volume/svr_linear/volume_extraction_cpu_kernels.h" 2 | 3 | template 4 | void volume_extraction_forward_cpu_kernel( 5 | const torch::TensorAccessor grid3d_index, 6 | const torch::TensorAccessor weight, 7 | const torch::TensorAccessor bias, 8 | const torch::TensorAccessor input, 9 | torch::TensorAccessor output, 10 | const int offset, 11 | const int max_r2 12 | ) 13 | { 14 | for (long b = 0; b < output.size(0); b ++) 15 | for (long z = 0; z < output.size(1); z ++) 16 | for (long y = 0; y < output.size(2); y ++) 17 | for (long x = 0; x < output.size(3); x ++) 18 | { 19 | const long zp = z - (output.size(1) - 1) / 2; 20 | const long yp = y - (output.size(2) - 1) / 2; 21 | const long xp = x; 22 | if (xp*xp + yp*yp + zp*zp <= max_r2) 23 | { 24 | const long i = grid3d_index[z+offset][y+offset][x]; 25 | 26 | for (int c = 0; c < 2; c++) // Over real and imaginary 27 | { 28 | for (int j = 0; j < input.size(1); j ++) 29 | output[b][z][y][x][c] += weight[i][j][c] * input[b][j]; 30 | if (do_bias) 31 | output[b][z][y][x][c] += bias[i][c]; 32 | } 33 | } 34 | } 35 | } 36 | 37 | void volume_extraction_forward_cpu( 38 | const torch::Tensor grid3d_index, 39 | const torch::Tensor weight, 40 | const torch::Tensor bias, 41 | const torch::Tensor input, 42 | torch::Tensor output, 43 | const int max_r2, 44 | const bool do_bias 45 | ) 46 | { 47 | CHECK_CPU_INPUT(grid3d_index) 48 | CHECK_CPU_INPUT(weight) 49 | CHECK_CPU_INPUT(bias) 50 | CHECK_CPU_INPUT(input) 51 | CHECK_CPU_INPUT(output) 52 | 53 | const int offset = (grid3d_index.size(0) - output.size(1)) / 2; 54 | 55 | std::array bargs={{do_bias}}; 56 | dispatch_bools<1>{}( 57 | bargs, 58 | [&](auto...Bargs) { 59 | AT_DISPATCH_FLOATING_TYPES( 60 | input.scalar_type(), 61 | "volume_extraction_forward_cpu_kernel", 62 | [&] { 63 | volume_extraction_forward_cpu_kernel( 64 | /*grid3d_index*/ grid3d_index.accessor(), 65 | /*weight*/ weight.accessor(), 66 | /*bias*/ bias.accessor(), 67 | /*input*/ input.accessor(), 68 | /*output*/ output.accessor(), 69 | /*offset*/ offset, 70 | /*max_r2*/ max_r2 71 | ); 72 | } 73 | ); 74 | } 75 | ); 76 | } 77 | 78 | template 79 | void volume_extraction_backward_cpu_kernel( 80 | const torch::TensorAccessor grid3d_index, 81 | const torch::TensorAccessor weight, 82 | const torch::TensorAccessor bias, 83 | const torch::TensorAccessor input_spectral_weight, 84 | const torch::TensorAccessor input, 85 | const torch::TensorAccessor grad_output, 86 | torch::TensorAccessor grad_weight, 87 | torch::TensorAccessor grad_bias, 88 | torch::TensorAccessor grad_input, 89 | const int offset, 90 | const int max_r2 91 | ) 92 | { 93 | for (long b = 0; b < grad_output.size(0); b ++) 94 | for (long z = 0; z < grad_output.size(1); z ++) 95 | for (long y = 0; y < grad_output.size(2); y ++) 96 | for (long x = 0; x < grad_output.size(3); x ++) 97 | { 98 | const long zp = z - (grad_output.size(1) - 1) / 2; 99 | const long yp = y - (grad_output.size(2) - 1) / 2; 100 | const long xp = x; 101 | 102 | scalar_t r = xp*xp + yp*yp + zp*zp; 103 | if (r <= max_r2) 104 | { 105 | if (do_spectral_weighting) 106 | r = input_spectral_weight[(int) std::sqrt(r)]; 107 | 108 | const long i = grid3d_index[z+offset][y+offset][x]; 109 | for (int c = 0; c < 2; c++) // Over real and imaginary 110 | { 111 | for (int j = 0; j < input.size(1); j ++) 112 | { 113 | grad_weight[i][j][c] += grad_output[b][z][y][x][c] * input[b][j]; 114 | if (do_input_grad) 115 | grad_input[b][j] += 116 | do_spectral_weighting ? 117 | grad_output[b][z][y][x][c] * weight[i][j][c] * r: 118 | grad_output[b][z][y][x][c] * weight[i][j][c]; 119 | } 120 | 121 | if (do_bias) 122 | grad_bias[i][c] += grad_output[b][z][y][x][c]; 123 | } 124 | } 125 | } 126 | } 127 | 128 | 129 | void volume_extraction_backward_cpu( 130 | const torch::Tensor grid3d_index, 131 | const torch::Tensor weight, 132 | const torch::Tensor bias, 133 | const torch::Tensor input_spectral_weight, 134 | const torch::Tensor input, 135 | const torch::Tensor grad_output, 136 | torch::Tensor grad_weight, 137 | torch::Tensor grad_bias, 138 | torch::Tensor grad_input, 139 | const int max_r2, 140 | const bool do_bias, 141 | const bool do_input_grad, 142 | const bool do_spectral_weighting 143 | ) 144 | { 145 | CHECK_CPU_INPUT(grid3d_index) 146 | CHECK_CPU_INPUT(weight) 147 | CHECK_CPU_INPUT(bias) 148 | CHECK_CPU_INPUT(input_spectral_weight) 149 | CHECK_CPU_INPUT(input) 150 | CHECK_CPU_INPUT(grad_output) 151 | CHECK_CPU_INPUT(grad_weight) 152 | CHECK_CPU_INPUT(grad_bias) 153 | CHECK_CPU_INPUT(grad_input) 154 | 155 | const int offset = (grid3d_index.size(0) - grad_output.size(1)) / 2; 156 | 157 | std::array bargs={{do_bias, do_input_grad, do_spectral_weighting}}; 158 | dispatch_bools<3>{}( 159 | bargs, 160 | [&](auto...Bargs) { 161 | AT_DISPATCH_FLOATING_TYPES( 162 | input.scalar_type(), 163 | "volume_extraction_backward_cpu_kernel", 164 | [&] { 165 | volume_extraction_backward_cpu_kernel( 166 | /*grid3d_index*/ grid3d_index.accessor(), 167 | /*weight*/ weight.accessor(), 168 | /*bias*/ bias.accessor(), 169 | /*input_spectral_weight*/ input_spectral_weight.accessor(), 170 | /*input*/ input.accessor(), 171 | /*grad*/ grad_output.accessor(), 172 | /*grad_weight*/ grad_weight.accessor(), 173 | /*grad_bias*/ grad_bias.accessor(), 174 | /*grad_input*/ grad_input.accessor(), 175 | /*offset*/ offset, 176 | /*max_r2*/ max_r2 177 | ); 178 | } 179 | ); 180 | } 181 | ); 182 | } 183 | -------------------------------------------------------------------------------- /voxelium/base/so3.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """ 4 | Module for calculations related to the SO(3) group 5 | """ 6 | import sys 7 | import numpy as np 8 | import torch 9 | 10 | from typing import Tuple, Union, TypeVar 11 | 12 | Tensor = TypeVar('torch.tensor') 13 | 14 | 15 | def taitbryan_to_matrix( 16 | angles: Union[Tensor, np.ndarray] 17 | ) -> Union[Tensor, np.ndarray]: 18 | """ 19 | Takes a batch of the three Tait-Bryan angles (rotation axis: xyz) 20 | in radians and returns a batch of the corresponding rotation matrices 21 | 22 | Supports both numpy arrays and torch tensor input 23 | 24 | :param angles: an array (B, 3) of the angels in radians 25 | :return: a 3x3 rotation matrix 26 | """ 27 | 28 | if torch.is_tensor(angles): 29 | R = torch.zeros(len(angles), 3, 3, dtype=angles.dtype).to(angles.device) 30 | c0 = torch.cos(angles[:, 0]) 31 | s0 = torch.sin(angles[:, 0]) 32 | c1 = torch.cos(angles[:, 1]) 33 | s1 = torch.sin(angles[:, 1]) 34 | c2 = torch.cos(angles[:, 2]) 35 | s2 = torch.sin(angles[:, 2]) 36 | else: 37 | R = np.zeros((len(angles), 3, 3), dtype=angles.dtype) 38 | c0 = np.cos(angles[:, 0]) 39 | s0 = np.sin(angles[:, 0]) 40 | c1 = np.cos(angles[:, 1]) 41 | s1 = np.sin(angles[:, 1]) 42 | c2 = np.cos(angles[:, 2]) 43 | s2 = np.sin(angles[:, 2]) 44 | 45 | """ 46 | Matrix multiplication of Rz * Ry * Rx 47 | |c2, -s2, 0| |c1, 0, s1| |1, 0, 0| 48 | |s2, c2, 0| * |0, 1, 0 | * |0, c0, -s0| 49 | |0, 0, 1| |-s1, 0, c1| |0, s0, c0| 50 | """ 51 | 52 | R[:, 0, 0] = c1 * c2 53 | R[:, 0, 1] = s0 * s1 * c2 - c0 * s2 54 | R[:, 0, 2] = c0 * s1 * c2 + s0 * s2 55 | R[:, 1, 0] = c1 * s2 56 | R[:, 1, 1] = c0 * c2 + s0 * s1 * s2 57 | R[:, 1, 2] = c0 * s1 * s2 - s0 * c2 58 | R[:, 2, 0] = -s1 59 | R[:, 2, 1] = s0 * c1 60 | R[:, 2, 2] = c0 * c1 61 | 62 | return R 63 | 64 | 65 | def euler_to_matrix( 66 | angles: Union[Tensor, np.ndarray] 67 | ) -> Union[Tensor, np.ndarray]: 68 | """ 69 | Takes a batch of the three Euler angles as defined in RELION and 70 | returns a batch of the corresponding rotation matrices 71 | 72 | Supports both numpy arrays and torch tensor input 73 | 74 | :param angles: an array (B, 3) of the Euler angels, alpha, beta and gamma (rot, tilt, psi) 75 | :return: a 3x3 rotation matrix 76 | """ 77 | if torch.is_tensor(angles): 78 | R = torch.zeros(len(angles), 3, 3, dtype=angles.dtype).to(angles.device) 79 | ca = torch.cos(angles[:, 0]) 80 | cb = torch.cos(angles[:, 1]) 81 | cg = torch.cos(angles[:, 2]) 82 | sa = torch.sin(angles[:, 0]) 83 | sb = torch.sin(angles[:, 1]) 84 | sg = torch.sin(angles[:, 2]) 85 | else: 86 | R = np.zeros((len(angles), 3, 3), dtype=angles.dtype) 87 | ca = np.cos(angles[:, 0]) 88 | cb = np.cos(angles[:, 1]) 89 | cg = np.cos(angles[:, 2]) 90 | sa = np.sin(angles[:, 0]) 91 | sb = np.sin(angles[:, 1]) 92 | sg = np.sin(angles[:, 2]) 93 | 94 | cc = cb * ca 95 | cs = cb * sa 96 | sc = sb * ca 97 | ss = sb * sa 98 | 99 | R[:, 0, 0] = cg * cc - sg * sa 100 | R[:, 0, 1] = cg * cs + sg * ca 101 | R[:, 0, 2] = -cg * sb 102 | R[:, 1, 0] = -sg * cc - cg * sa 103 | R[:, 1, 1] = -sg * cs + cg * ca 104 | R[:, 1, 2] = sg * sb 105 | R[:, 2, 0] = sc 106 | R[:, 2, 1] = ss 107 | R[:, 2, 2] = cb 108 | 109 | return R 110 | 111 | 112 | def quaternion_to_matrix(Q: Tensor) -> Tensor: 113 | """ 114 | Covert quaternions into 3D rotation matrices. 115 | :param Q: a Bx4 quaternion (Bx4) 116 | :return: rotation matrix (Bx3x3) 117 | """ 118 | 119 | # Extract the values from Q 120 | q0 = Q[:, 0] 121 | q1 = Q[:, 1] 122 | q2 = Q[:, 2] 123 | q3 = Q[:, 3] 124 | 125 | r = torch.empty(Q.shape[0], 3, 3).to(Q.device) 126 | r[:, 0, 0] = 2 * (q0 * q0 + q1 * q1) - 1 127 | r[:, 0, 1] = 2 * (q1 * q2 - q0 * q3) 128 | r[:, 0, 2] = 2 * (q1 * q3 + q0 * q2) 129 | r[:, 1, 0] = 2 * (q1 * q2 + q0 * q3) 130 | r[:, 1, 1] = 2 * (q0 * q0 + q2 * q2) - 1 131 | r[:, 1, 2] = 2 * (q2 * q3 - q0 * q1) 132 | r[:, 2, 0] = 2 * (q1 * q3 - q0 * q2) 133 | r[:, 2, 1] = 2 * (q2 * q3 + q0 * q1) 134 | r[:, 2, 2] = 2 * (q0 * q0 + q3 * q3) - 1 135 | 136 | return r 137 | 138 | 139 | def matrix_to_quaternion(R: Tensor) -> Tensor: 140 | """ 141 | Covert 3D rotation matrices to quaternions. 142 | :param R: rotation matrices (Bx3x3) 143 | :return: quaternion (Bx4) 144 | """ 145 | # From paper: 146 | # Sarabandi, Soheil, and Federico Thomas. 147 | # "Accurate computation of quaternions from rotation matrices." 148 | # International Symposium on Advances in Robot Kinematics. Springer, Cham, 2018. 149 | Q = torch.empty(R.shape[0], 4).type(R.dtype).to(R.device) 150 | 151 | # Q0 152 | t = R[:, 0, 0] + R[:, 1, 1] + R[:, 2, 2] 153 | m = t > 0 154 | Q[m, 0] = .5 * torch.sqrt(1 + t[m]) 155 | m = ~m 156 | Q[m, 0] = .5 * torch.sqrt( 157 | ( 158 | torch.square(R[m, 2, 1] - R[m, 1, 2]) + 159 | torch.square(R[m, 0, 2] - R[m, 2, 0]) + 160 | torch.square(R[m, 1, 0] - R[m, 0, 1]) 161 | ) / (3 - t[m]) 162 | ) 163 | 164 | # Q1 165 | t = R[:, 0, 0] - R[:, 1, 1] - R[:, 2, 2] 166 | m = t > 0 167 | Q[m, 1] = .5 * torch.sqrt(1 + t[m]) 168 | m = ~m 169 | Q[m, 1] = .5 * torch.sqrt( 170 | ( 171 | torch.square(R[m, 2, 1] - R[m, 1, 2]) + 172 | torch.square(R[m, 0, 1] + R[m, 1, 0]) + 173 | torch.square(R[m, 2, 0] + R[m, 0, 2]) 174 | ) / (3 - t[m]) 175 | ) 176 | Q[:, 1] *= torch.sign(R[:, 2, 1] - R[:, 1, 2]) 177 | 178 | # Q2 179 | t = -R[:, 0, 0] + R[:, 1, 1] - R[:, 2, 2] 180 | m = t > 0 181 | Q[m, 2] = .5 * torch.sqrt(1 + t[m]) 182 | m = ~m 183 | Q[m, 2] = .5 * torch.sqrt( 184 | ( 185 | torch.square(R[m, 0, 2] - R[m, 2, 0]) + 186 | torch.square(R[m, 0, 1] + R[m, 1, 0]) + 187 | torch.square(R[m, 1, 2] + R[m, 2, 1]) 188 | ) / (3 - t[m]) 189 | ) 190 | Q[:, 2] *= torch.sign(R[:, 0, 2] - R[:, 2, 0]) 191 | 192 | # Q3 193 | t = -R[:, 0, 0] - R[:, 1, 1] + R[:, 2, 2] 194 | m = t > 0 195 | Q[m, 3] = .5 * torch.sqrt(1 + t[m]) 196 | m = ~m 197 | Q[m, 3] = .5 * torch.sqrt( 198 | ( 199 | torch.square(R[m, 1, 0] - R[m, 0, 1]) + 200 | torch.square(R[m, 2, 0] + R[m, 0, 2]) + 201 | torch.square(R[m, 2, 1] + R[m, 1, 2]) 202 | ) / (3 - t[m]) 203 | ) 204 | Q[:, 3] *= torch.sign(R[:, 1, 0] - R[:, 0, 1]) 205 | 206 | return Q 207 | 208 | 209 | def normalize_quaternions(Q: Tensor) -> Tensor: 210 | norm = torch.sqrt(torch.sum(torch.square(Q), dim=1)) 211 | assert torch.all(norm > 0) 212 | return Q / norm[:, None] 213 | 214 | 215 | def random_rotation_matrix(count): 216 | """ 217 | Returns random rotation matrices. 218 | :param count: number of rotation matrices to return (B) 219 | :return: rotation matrices(Bx3x3) 220 | """ 221 | a12 = 2 * np.pi * torch.rand(count, 2) 222 | a3 = torch.rand((count, 1)).mul(2).sub(1).acos() 223 | return euler_to_matrix(torch.cat([a12, a3], 1)) 224 | 225 | 226 | def is_rotation_matrix(R, eps: float = 1e-6): 227 | """ 228 | Test if R is a rotation matrix. 229 | :param R: rotational matrix to test (Bx3x3) 230 | :param eps: numerical error margin (default: 1e-6) 231 | """ 232 | eye = torch.eye(R.shape[-1]).type(R.dtype).to(R.device) 233 | RRt = torch.matmul(R, torch.transpose(R, 1, 2)) 234 | eye_ae = torch.abs(RRt - eye[None, ...]) 235 | det_ae = torch.abs(torch.linalg.det(R) - 1) 236 | return torch.all(torch.all((eye_ae < eps), dim=-1), dim=-1) & (det_ae < eps) 237 | 238 | 239 | if __name__ == "__main__": 240 | angles = torch.Tensor([[0.1, 0.2, 0.1], [1.1, 0.2, 1.2]]) 241 | 242 | R1 = euler_to_matrix(angles) 243 | Q1 = matrix_to_quaternion(R1) 244 | R2 = quaternion_to_matrix(Q1) 245 | -------------------------------------------------------------------------------- /voxelium/vae_volume/tensorboard_summary.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """ 4 | Test module for a training VAE 5 | """ 6 | import os 7 | import sys 8 | from typing import List, TypeVar, Dict 9 | 10 | import numpy as np 11 | import matplotlib 12 | import matplotlib.pylab as plt 13 | import torch 14 | import torch.nn.functional as F 15 | from sklearn.decomposition import PCA 16 | from torch.utils.tensorboard import SummaryWriter 17 | 18 | from voxelium.base import dt_desymmetrize, idft, idht, dt_symmetrize, grid_spectral_average_torch 19 | from voxelium.base.torch_utils import make_imshow_fig, make_scatter_fig, make_line_fig, make_series_line_fig 20 | from voxelium.vae_volume.cache import Cache 21 | from voxelium.vae_volume.data_analysis_container import DatasetAnalysisContainer 22 | from voxelium.vae_volume.hidden_variable_container import HiddenVariableContainer 23 | 24 | Tensor = TypeVar('torch.tensor') 25 | 26 | 27 | def tesnor_to_np(tensor): 28 | return tensor.detach().cpu().numpy() 29 | 30 | 31 | class TensorboardSummary: 32 | def __init__(self, logdir, pixel_size, spectral_size, step=0): 33 | self.summary_fn = os.path.join(logdir, "summary") 34 | self.summary = SummaryWriter(self.summary_fn) 35 | self.spectral_size = spectral_size 36 | self.step = step 37 | self.pixel_size = pixel_size 38 | 39 | def set_step(self, step: int = None): 40 | self.step = self.step + 1 if step is None else step 41 | 42 | def write_losses( 43 | self, 44 | train_mse, 45 | kl_weight, kld_loss, data_loss, 46 | train_loss, roi_weight, roi_loss 47 | ): 48 | kld_loss = torch.mean(kld_loss) 49 | data_loss = torch.mean(data_loss) 50 | train_loss = torch.mean(train_loss) 51 | self.summary.add_scalar(f"Loss/KLD weight", kl_weight, self.step) 52 | self.summary.add_scalar(f"Loss/KLD", kld_loss.detach().cpu().numpy(), self.step) 53 | self.summary.add_scalar(f"Loss/Data", data_loss.detach().cpu().numpy(), self.step) 54 | self.summary.add_scalar(f"Loss/Train", train_loss.detach().cpu().numpy(), self.step) 55 | self.summary.add_scalar(f"Loss/Train MSE", train_mse.detach().cpu().numpy(), self.step) 56 | if roi_weight > 0: 57 | self.summary.add_scalar(f"Loss/ROI", roi_loss.detach().cpu().numpy(), self.step) 58 | self.summary.add_scalar(f"Loss/ROI wight", roi_weight, self.step) 59 | 60 | def write_stats(self, x_ft, y_ft, data_amp, data_ctf_amp): 61 | if x_ft.shape[-1] != 2: 62 | x_ft = torch.view_as_real(x_ft) 63 | if y_ft.shape[-1] != 2: 64 | y_ft = torch.view_as_real(y_ft) 65 | self.summary.add_scalar(f"Stats/X mean", torch.mean(x_ft).detach().cpu().numpy(), self.step) 66 | self.summary.add_scalar(f"Stats/X std", torch.std(x_ft).detach().cpu().numpy(), self.step) 67 | self.summary.add_scalar(f"Stats/Y mean", torch.mean(y_ft).detach().cpu().numpy(), self.step) 68 | self.summary.add_scalar(f"Stats/Y std", torch.std(y_ft).detach().cpu().numpy(), self.step) 69 | self.summary.add_scalar(f"Stats/data amp mean", torch.mean(data_amp).detach().cpu().numpy(), self.step) 70 | self.summary.add_scalar(f"Stats/data amp std", torch.std(data_amp).detach().cpu().numpy(), self.step) 71 | self.summary.add_scalar(f"Stats/data ctf amp mean", torch.mean(data_ctf_amp).detach().cpu().numpy(), self.step) 72 | self.summary.add_scalar(f"Stats/data ctf amp std", torch.std(data_ctf_amp).detach().cpu().numpy(), self.step) 73 | 74 | def write_images(self, x_ft, y_ft, ctf, roi=None): 75 | x_ft = x_ft.detach() 76 | y_ft = y_ft.detach() 77 | ctf = ctf.detach() 78 | 79 | y_ft_ = torch.abs(torch.view_as_complex(y_ft[0])).detach().cpu().numpy() 80 | c_ = torch.abs(ctf[0].detach()).data.cpu().numpy() 81 | c_std = np.std(c_) 82 | y_ft_std = np.std(y_ft_) 83 | if c_std != 0 and y_ft_std != 0: 84 | c_ /= c_std 85 | y_ft_ /= y_ft_std 86 | y_ft_[:c_.shape[0] // 2] = c_[:c_.shape[0] // 2] 87 | self.summary.add_figure(f"Data/HT", make_imshow_fig(y_ft_), self.step) 88 | 89 | x_ft_ = dt_desymmetrize(torch.view_as_complex(x_ft[0]), dim=2).detach().cpu().numpy() 90 | self.summary.add_figure(f"Output/FT", make_imshow_fig(np.abs(x_ft_)), self.step) 91 | 92 | x_ = idft(x_ft_, dim=2, real_in=True) 93 | y_ = idft(y_ft[0].detach(), dim=2).real.cpu().numpy() 94 | self.summary.add_figure(f"Output/Image", make_imshow_fig(x_), self.step) 95 | self.summary.add_figure(f"Data/Image", make_imshow_fig(y_), self.step) 96 | 97 | if roi is not None: 98 | if roi.roi_slice is not None: 99 | self.summary.add_figure(f"ROI", make_imshow_fig(roi.roi_slice), self.step) 100 | if roi.roni_slice is not None: 101 | self.summary.add_figure(f"RONI", make_imshow_fig(roi.roni_slice), self.step) 102 | 103 | def write_hidden_variable(self, hvc: HiddenVariableContainer): 104 | vars = hvc.vars 105 | 106 | v = torch.stack([vars['pose_alpha'].vars, vars['pose_beta'].vars, vars['pose_gamma'].vars], 1) 107 | o = torch.stack([vars['pose_alpha'].orig, vars['pose_beta'].orig, vars['pose_gamma'].orig], 1) 108 | e = torch.sqrt(F.mse_loss(v, o.to(v.device))).cpu().detach().item() 109 | self.summary.add_scalar(f"Hidden variables/pose error", e, self.step) 110 | 111 | v = torch.stack([vars['shift_x'].vars, vars['shift_y'].vars], 1) 112 | o = torch.stack([vars['shift_x'].orig, vars['shift_y'].orig], 1) 113 | e = torch.sqrt(F.mse_loss(v, o.to(v.device))).cpu().detach().item() 114 | self.summary.add_scalar(f"Hidden variables/shift error", e, self.step) 115 | 116 | if hvc.do_ctf: 117 | v = torch.stack([vars['ctf_defocus_u'].vars, vars['ctf_defocus_v'].vars], 1) 118 | o = torch.stack([vars['ctf_defocus_u'].orig, vars['ctf_defocus_v'].orig], 1) 119 | e = torch.sqrt(F.mse_loss(v, o.to(v.device))).cpu().detach().item() 120 | self.summary.add_scalar(f"Hidden variables/ctf defocus error", e, self.step) 121 | 122 | v = vars['ctf_angle'].vars 123 | o = vars['ctf_angle'].orig 124 | e = torch.sqrt(F.mse_loss(v, o.to(v.device))).cpu().detach().item() 125 | self.summary.add_scalar(f"Hidden variables/ctf angle error", e, self.step) 126 | 127 | if hvc.latent_space is not None and torch.any(hvc.latent_space != 0): 128 | latent_space = hvc.latent_space.detach().cpu().numpy() 129 | if latent_space.shape[-1] == 2: 130 | embed = latent_space 131 | else: 132 | embed = PCA(n_components=2).fit_transform( 133 | latent_space.astype(np.float32) 134 | ) 135 | 136 | # import tsnecuda 137 | # device_id = hvc.latent_space.device.index 138 | # embed = tsnecuda.TSNE(n_components=2, device=device_id).fit_transform( 139 | # latent_space.astype(np.float32) 140 | # ) 141 | fig = make_scatter_fig(embed[:, 0], embed[:, 1]) 142 | self.summary.add_figure( 143 | f"Latent/Train", 144 | fig, 145 | self.step 146 | ) 147 | #fig.savefig(f"{self.summary_fn}/latent_fig_{self.step}.svg") 148 | 149 | 150 | # if hvc.structure_basis is not None and torch.any(hvc.structure_basis != 0): 151 | # structure_basis = hvc.structure_basis.detach().cpu().numpy() 152 | # if structure_basis.shape[-1] == 2: 153 | # embed = structure_basis 154 | # else: 155 | # embed = PCA(n_components=2).fit_transform( 156 | # structure_basis.astype(np.float32) 157 | # ) 158 | # 159 | # self.summary.add_figure( 160 | # f"Structure basis/Train", 161 | # make_scatter_fig( 162 | # embed[:, 0], 163 | # embed[:, 1] 164 | # ), 165 | # self.step 166 | # ) -------------------------------------------------------------------------------- /voxelium/vae_volume/svr_linear/trilinear_projection.cpp: -------------------------------------------------------------------------------- 1 | #include "vae_volume/svr_linear/trilinear_projection.h" 2 | 3 | torch::Tensor trilinear_projection_forward( 4 | torch::Tensor input, 5 | torch::Tensor weight, 6 | torch::Tensor bias, 7 | torch::Tensor rot_matrix, 8 | torch::Tensor grid2d_coord, 9 | torch::Tensor grid3d_index, 10 | const int max_r 11 | ) 12 | { 13 | const int batch_size = input.size(0); 14 | const bool do_bias = bias.size(0) == weight.size(0); 15 | 16 | CHECK_SIZE_DIM1(grid2d_coord, 2) 17 | CHECK_SIZE_DIM0(rot_matrix, batch_size) 18 | CHECK_SIZE_DIM1(rot_matrix, 3) 19 | CHECK_SIZE_DIM2(rot_matrix, 3) 20 | 21 | auto output = torch::zeros( 22 | {batch_size, grid2d_coord.size(0), 2}, 23 | torch::TensorOptions() 24 | .dtype(input.dtype()) 25 | .layout(torch::kStrided) 26 | .device(input.device()) 27 | .requires_grad(true) 28 | ); 29 | 30 | if (input.device().type() == torch::kCPU) 31 | { 32 | trilinear_projection_forward_cpu( 33 | /*grid2d_coord*/ grid2d_coord, 34 | /*grid3d_index*/ grid3d_index, 35 | /*weight*/ weight, 36 | /*bias*/ bias, 37 | /*rot_matrix*/ rot_matrix, 38 | /*input*/ input, 39 | /*output*/ output, 40 | /*max_r2*/ (int) max_r * max_r, 41 | /*init_offset*/ (int) grid3d_index.size(0)/2, 42 | /*do_bias*/ do_bias 43 | ); 44 | } 45 | else if (input.device().type() == torch::kCUDA) 46 | { 47 | trilinear_projection_forward_cuda( 48 | /*grid2d_coord*/ grid2d_coord, 49 | /*grid3d_index*/ grid3d_index, 50 | /*weight*/ weight, 51 | /*bias*/ bias, 52 | /*rot_matrix*/ rot_matrix, 53 | /*input*/ input, 54 | /*output*/ output, 55 | /*max_r2*/ (int) max_r * max_r, 56 | /*init_offset*/ (int) grid3d_index.size(0)/2, 57 | /*do_bias*/ do_bias 58 | ); 59 | } 60 | else 61 | throw std::logic_error("Support for device not implemented"); 62 | 63 | return output; 64 | } 65 | 66 | 67 | std::vector trilinear_projection_backward( 68 | torch::Tensor input, 69 | torch::Tensor weight, 70 | torch::Tensor bias, 71 | torch::Tensor rot_matrix, 72 | torch::Tensor input_spectral_weight, 73 | torch::Tensor grad_output, 74 | torch::Tensor grid2d_coord, 75 | torch::Tensor grid3d_index, 76 | bool sparse_grad, 77 | const int max_r 78 | ) 79 | { 80 | const int batch_size = input.size(0); 81 | const int input_size = input.size(1); 82 | const int points_count = grid2d_coord.size(0); 83 | const bool do_bias = bias.size(0) == weight.size(0); 84 | 85 | CHECK_SIZE_DIM0(grid2d_coord, points_count) 86 | CHECK_SIZE_DIM1(grid2d_coord, 2) 87 | CHECK_SIZE_DIM0(rot_matrix, batch_size) 88 | CHECK_SIZE_DIM1(rot_matrix, 3) 89 | CHECK_SIZE_DIM2(rot_matrix, 3) 90 | 91 | CHECK_DTYPE(weight, input.dtype()) 92 | CHECK_DTYPE(bias, input.dtype()) 93 | CHECK_DTYPE(rot_matrix, input.dtype()) 94 | CHECK_DTYPE(grad_output, input.dtype()) 95 | CHECK_DTYPE(grid2d_coord, input.dtype()) 96 | 97 | TORCH_CHECK(input_spectral_weight.size(0) == 0 || input_spectral_weight.size(0) > max_r, 98 | "input_spectral_weight.size(0) bad size (", input_spectral_weight.size(0), " <= ", 99 | max_r, "). In ", __FILE__, ":", __LINE__) 100 | 101 | torch::Tensor grad_weight, grad_bias, grad_weight_index; 102 | 103 | if (sparse_grad) 104 | { 105 | grad_weight = torch::zeros( 106 | {batch_size, points_count * 8, weight.size(1), 2}, 107 | torch::TensorOptions() 108 | .dtype(input.dtype()) 109 | .layout(torch::kStrided) 110 | .device(input.device()) 111 | .requires_grad(false) 112 | ); 113 | 114 | if (do_bias) 115 | grad_bias = torch::zeros( 116 | {batch_size, points_count * 8, 2}, 117 | torch::TensorOptions() 118 | .dtype(input.dtype()) 119 | .layout(torch::kStrided) 120 | .device(input.device()) 121 | .requires_grad(false) 122 | ); 123 | else 124 | grad_bias = torch::empty( 125 | {0, 0, 0}, 126 | torch::TensorOptions() 127 | .dtype(input.dtype()) 128 | .layout(torch::kStrided) 129 | .device(input.device()) 130 | .requires_grad(false) 131 | ); 132 | 133 | grad_weight_index = torch::zeros( 134 | {batch_size, points_count * 8}, 135 | torch::TensorOptions() 136 | .dtype(torch::kInt64) 137 | .layout(torch::kStrided) 138 | .device(input.device()) 139 | .requires_grad(false) 140 | ); 141 | } 142 | else 143 | { 144 | grad_weight = torch::zeros_like(weight); 145 | grad_bias = torch::zeros_like(bias); 146 | grad_weight_index = torch::empty(0, 147 | torch::TensorOptions() 148 | .dtype(input.dtype()) 149 | .layout(torch::kStrided) 150 | .device(input.device()) 151 | .requires_grad(false) 152 | ); 153 | } 154 | 155 | auto grad_input = torch::zeros( 156 | {batch_size, input_size}, 157 | torch::TensorOptions() 158 | .dtype(input.dtype()) 159 | .layout(torch::kStrided) 160 | .device(input.device()) 161 | .requires_grad(false) 162 | ); 163 | 164 | auto grad_rot_matrix = torch::zeros( 165 | {batch_size, 3, 3}, 166 | torch::TensorOptions() 167 | .dtype(input.dtype()) 168 | .layout(torch::kStrided) 169 | .device(input.device()) 170 | .requires_grad(false) 171 | ); 172 | 173 | if (input.device().type() == torch::kCPU) 174 | { 175 | trilinear_projection_backward_cpu( 176 | /*grid2d_coord*/grid2d_coord, 177 | /*grid3d_index*/ grid3d_index, 178 | /*weight*/ weight, 179 | /*bias*/ bias, 180 | /*grad_weight_index*/ grad_weight_index, 181 | /*rot_matrix*/ rot_matrix, 182 | /*input_spectral_weight*/ input_spectral_weight, 183 | /*input*/ input, 184 | /*grad_output*/ grad_output, 185 | /*grad_weight*/ grad_weight, 186 | /*grad_bias*/ grad_bias, 187 | /*grad_input*/ grad_input, 188 | /*grad_rot_matrix*/ grad_rot_matrix, 189 | /*max_r2*/ (int) max_r * max_r, 190 | /*init_offset*/ (int) grid3d_index.size(0)/2, 191 | /*do_bias*/ do_bias, 192 | /*do_rot_matrix_grad*/ rot_matrix.requires_grad(), 193 | /*sparse_grad*/ sparse_grad, 194 | /*do_spectral_weighting*/ input_spectral_weight.size(0) > 0 195 | ); 196 | } 197 | else if (input.device().type() == torch::kCUDA) 198 | { 199 | trilinear_projection_backward_cuda( 200 | /*grid2d_coord*/grid2d_coord, 201 | /*grid3d_index*/ grid3d_index, 202 | /*weight*/ weight, 203 | /*bias*/ bias, 204 | /*grad_weight_index*/ grad_weight_index, 205 | /*rot_matrix*/ rot_matrix, 206 | /*input_spectral_weight*/ input_spectral_weight, 207 | /*input*/ input, 208 | /*grad_output*/ grad_output, 209 | /*grad_weight*/ grad_weight, 210 | /*grad_bias*/ grad_bias, 211 | /*grad_input*/ grad_input, 212 | /*grad_rot_matrix*/ grad_rot_matrix, 213 | /*max_r2*/ (int) max_r * max_r, 214 | /*init_offset*/ (int) grid3d_index.size(0)/2, 215 | /*do_bias*/ do_bias, 216 | /*do_rot_matrix_grad*/ rot_matrix.requires_grad(), 217 | /*sparse_grad*/ sparse_grad, 218 | /*do_spectral_weighting*/ input_spectral_weight.size(0) > 0 219 | ); 220 | } 221 | else 222 | throw std::logic_error("Support for device not implemented"); 223 | 224 | return {grad_input, grad_weight_index, grad_weight, grad_bias, grad_rot_matrix}; 225 | } 226 | -------------------------------------------------------------------------------- /voxelium/vae_volume/svr_linear/test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """ 4 | Test module for the sparse linear layer 5 | """ 6 | 7 | import unittest 8 | 9 | from voxelium.vae_volume.svr_linear import make_compact_grid2d 10 | from voxelium.vae_volume.svr_linear.svr_linear import * 11 | from voxelium.relion import eulerToMatrix 12 | 13 | 14 | ATOL = 1e-6 15 | PERTURB_EPS = 1e-6 16 | 17 | 18 | class TestSparseLinear(unittest.TestCase): 19 | def test_gradcheck_projection_cpu_sparse(self): 20 | self.assertTrue(self._gradcheck_projection("cpu", sparse=True)) 21 | 22 | def test_gradcheck_projection_cuda_sparse(self): 23 | self.assertTrue(self._gradcheck_projection("cuda:0", sparse=True)) 24 | 25 | def test_gradcheck_projection_cpu_dense(self): 26 | self.assertTrue(self._gradcheck_projection("cpu", sparse=False)) 27 | 28 | def test_gradcheck_projection_cuda_dense(self): 29 | self.assertTrue(self._gradcheck_projection("cuda:0", sparse=False)) 30 | 31 | @staticmethod 32 | def _gradcheck_projection(device, sparse): 33 | input_size = 5 34 | img_size = 8 35 | bsize = 2 36 | 37 | max_r = img_size // 2 38 | coord, mask = make_compact_grid2d(size=img_size) 39 | coord = coord.to(device).double() 40 | 41 | p = SparseVolumeReconstructionLinear(img_size, input_size, dtype=torch.double, bias=True).to(device) 42 | input = torch.randn(bsize, p.input_size, dtype=torch.float64, requires_grad=True).to(p.weight.device) 43 | angles = torch.randn(bsize, 3, dtype=torch.float64, requires_grad=True).to(p.weight.device) 44 | rot_matrices = eulerToMatrix(angles) 45 | 46 | return torch.autograd.gradcheck( 47 | TrilinearProjection.apply, 48 | ( 49 | input, # input 50 | p.weight.double(), # weight 51 | p.bias.double(), # bias 52 | p.grid3d_index, # grid3d_index 53 | rot_matrices.detach(), # rot_matrices 54 | coord, # grid2d_coord 55 | max_r, # max_r 56 | None, # input_spectral_weight 57 | sparse, # sparse_grad 58 | True # testing 59 | ), 60 | eps=PERTURB_EPS, 61 | atol=ATOL 62 | ) 63 | 64 | def test_forward_projection_cpu(self): 65 | self.assertTrue(self._forward_projection("cpu")) 66 | 67 | def test_forward_projection_cuda(self): 68 | self.assertTrue(self._forward_projection("cuda")) 69 | 70 | @staticmethod 71 | def _forward_projection(device): 72 | input_size = 5 73 | img_size = 16 74 | 75 | max_r = img_size // 2 76 | coord, mask = make_compact_grid2d(size=img_size) 77 | coord = coord.to(device) 78 | 79 | p = SparseVolumeReconstructionLinear(img_size, input_size) 80 | p.to(device) 81 | 82 | ref = TestSparseLinear._make_random_ref(img_size, input_size) 83 | p.weight.data *= 0 84 | p.bias.data[...] = 0 85 | p.set_reference(ref.to(device)) 86 | 87 | input = torch.ones([1, input_size]).to(device) * 1. / input_size 88 | rot_matrices = torch.eye(3).unsqueeze(0).to(device) 89 | projection_ = p(input, max_r=max_r, grid2d_coord=coord, rot_matrices=rot_matrices) 90 | projection = torch.zeros((img_size + 1) * (img_size//2 + 1), 2).to(device) 91 | projection[mask, :] = projection_[0] 92 | projection = projection.view(img_size + 1, img_size//2 + 1, 2) 93 | projection = projection.cpu().detach().numpy() 94 | ref_projection = ref[img_size//2, :, :, 0, :] 95 | ref_projection = ref_projection.cpu().detach().numpy() 96 | 97 | return np.all(np.abs(projection - ref_projection) < ATOL) 98 | 99 | def test_forward_volume_extraction_cpu(self): 100 | self.assertTrue(self._forward_volume_extraction("cpu")) 101 | 102 | def test_forward_volume_extraction_cuda(self): 103 | self.assertTrue(self._forward_volume_extraction("cuda")) 104 | 105 | @staticmethod 106 | def _forward_volume_extraction(device): 107 | input_size = 5 108 | img_size = 16 109 | 110 | p = SparseVolumeReconstructionLinear(img_size, input_size) 111 | p.to(device) 112 | 113 | ref = TestSparseLinear._make_random_ref(img_size, input_size) 114 | 115 | p.weight.data *= 0 116 | p.bias.data *= 0 117 | p.set_reference(ref) 118 | 119 | input = torch.ones([1, input_size]).to(device) * 1. / input_size 120 | vol = p(input)[0].cpu().detach().numpy() 121 | ref = ref[..., 0, :].cpu().detach().numpy() 122 | 123 | return np.all(np.abs(vol - ref) < ATOL) 124 | 125 | def test_gradcheck_volume_extraction_cpu(self): 126 | self.assertTrue(self._gradcheck_volume_extraction("cpu")) 127 | 128 | def test_gradcheck_volume_extraction_cuda(self): 129 | self.assertTrue(self._gradcheck_volume_extraction("cuda:0")) 130 | 131 | @staticmethod 132 | def _gradcheck_volume_extraction(device): 133 | input_size = 3 134 | img_size = 8 135 | bsize = 2 136 | p = SparseVolumeReconstructionLinear(img_size, input_size, dtype=torch.double).to(device) 137 | input = torch.randn(bsize, p.input_size, dtype=torch.float64, requires_grad=True).to(p.weight.device) 138 | 139 | return torch.autograd.gradcheck( 140 | VolumeExtraction.apply, 141 | ( 142 | input, # input 143 | p.weight.double(), # weight 144 | p.bias.double(), # bias 145 | p.grid3d_index, # grid3d_index 146 | None, # input_spectral_weight 147 | ), 148 | eps=PERTURB_EPS, 149 | atol=ATOL 150 | ) 151 | 152 | @staticmethod 153 | def _make_random_ref(img_size, input_size): 154 | ls = torch.linspace(-img_size // 2, img_size // 2, img_size + 1) 155 | lsx = torch.linspace(0, img_size // 2, img_size // 2 + 1) 156 | z, y, x = np.meshgrid(ls, ls, lsx, indexing='ij') 157 | r = np.sqrt(x ** 2 + y ** 2 + z ** 2) 158 | mask = r < img_size // 2 159 | n = np.sum(mask) 160 | mask = torch.Tensor(mask).bool() 161 | ref = torch.zeros([img_size + 1, img_size + 1, img_size // 2 + 1, 2]) 162 | ref[mask] = torch.empty(n, 2).normal_() 163 | return ref.unsqueeze(-2).expand(img_size + 1, img_size + 1, img_size // 2 + 1, input_size, 2) 164 | 165 | 166 | if __name__ == "__main__": 167 | torch.autograd.set_detect_anomaly(True) 168 | test = TestSparseLinear() 169 | # test.test_gradcheck_projection_cpu_dense() 170 | test.test_gradcheck_volume_extraction_cuda() 171 | print("All good!") 172 | 173 | # device = "cuda:0" 174 | # device = "cpu" 175 | # 176 | # def _make_random_ref(img_size, input_size): 177 | # ls = torch.linspace(-img_size // 2, img_size // 2, img_size + 1) 178 | # lsx = torch.linspace(0, img_size // 2, img_size // 2 + 1) 179 | # z, y, x = np.meshgrid(ls, ls, lsx, indexing='ij') 180 | # r = np.sqrt(x ** 2 + y ** 2 + z ** 2) 181 | # mask = (r <= img_size // 2) & (z == 0) 182 | # mask = torch.Tensor(mask).bool() 183 | # ref = torch.zeros([img_size + 1, img_size + 1, img_size // 2 + 1, 2]) 184 | # ref[mask] = 0.5 185 | # return ref.unsqueeze(-2).expand(img_size + 1, img_size + 1, img_size // 2 + 1, input_size, 2) 186 | # 187 | # 188 | # input_size = 5 189 | # img_size = 16 190 | # 191 | # p = SparseVolumeReconstructionLinear(img_size, input_size) 192 | # p.to(device) 193 | # max_r = img_size // 2 194 | # coord, mask = make_compact_grid2d(size=img_size) 195 | # coord = coord.to(device) 196 | # 197 | # ref = _make_random_ref(img_size, input_size) 198 | # p.weight.data *= 0 199 | # p.bias.data *= 0 200 | # p.set_reference(ref.to(device)) 201 | # 202 | # input = torch.ones([1, input_size]).to(device) * 1. / input_size 203 | # angles = np.zeros([1, 3]) 204 | # angles[0, 1] = 1 205 | # angles = torch.Tensor(angles).to(device) 206 | # angles.requires_grad = True 207 | # angles.retain_grad() 208 | # angles.register_hook(lambda grad: print(grad)) 209 | # optimizer = torch.optim.SGD([angles], lr=0.01, momentum=0.9) 210 | # rot_matrices = eulerToMatrix(angles) 211 | # # rot_matrices.retain_grad() 212 | # projection_ = p(input, max_r=max_r, grid2d_coord=coord, rot_matrices=rot_matrices) 213 | # projection = torch.zeros((img_size + 1) * (img_size // 2 + 1), 2).to(device) 214 | # projection[mask, :] = projection_[0] 215 | # projection = projection.view(img_size + 1, img_size // 2 + 1, 2) 216 | # ref_projection = ref[img_size // 2, :, :, 0, :] 217 | # 218 | # loss = torch.mean(torch.square(projection.cpu() - ref_projection)) 219 | # loss.backward() 220 | # 221 | # optimizer.step() 222 | # print(angles.grad) 223 | -------------------------------------------------------------------------------- /voxelium/vae_volume/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """ 4 | Test module for a training VAE 5 | """ 6 | from glob import glob 7 | import os 8 | import shutil 9 | import sys 10 | from typing import List, TypeVar 11 | 12 | import numpy as np 13 | import torch 14 | import torch.distributed as dist 15 | import torch.nn as nn 16 | import torch.optim as optim 17 | import torch.multiprocessing as mp 18 | from voxelium.base.star_file import load_star 19 | from voxelium.vae_volume.cache import Cache 20 | 21 | import matplotlib.pyplot as plt 22 | 23 | Tensor = TypeVar('torch.tensor') 24 | 25 | from voxelium.base import smooth_square_mask, smooth_circular_mask, ModelContainer, grid_spectral_average_torch, \ 26 | dt_symmetrize, spectra_to_grid_torch 27 | 28 | 29 | def cos_step_ascend(begin_ascend, end_ascend, x): 30 | if x < begin_ascend: 31 | return 0. 32 | if x > end_ascend: 33 | return 1. 34 | a = begin_ascend 35 | b = end_ascend - begin_ascend 36 | return .5 + np.cos(np.pi * (x - a) / b + np.pi) / 2. 37 | 38 | 39 | def cos_step_descend(begin_descend, end_descend, x): 40 | if x < begin_descend: 41 | return 1. 42 | if x > end_descend: 43 | return 0. 44 | a = begin_descend 45 | b = end_descend - begin_descend 46 | return .5 + np.cos(np.pi * (x - a) / b) / 2. 47 | 48 | 49 | def get_kld_loss(mu, logvar): 50 | # see Appendix B from VAE paper: 51 | # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014 52 | # https://arxiv.org/abs/1312.6114 53 | return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1) 54 | 55 | 56 | def setup_device(args, verbose=False): 57 | device = None 58 | gpu_ids = [] 59 | if args.gpu is not None: 60 | queried_gpu_ids = args.gpu.split(",") 61 | for i in range(len(queried_gpu_ids)): 62 | gpu_id = int(queried_gpu_ids[i].strip()) 63 | try: 64 | gpu_name = torch.cuda.get_device_name(gpu_id) 65 | except AssertionError: 66 | if verbose: 67 | print(f'WARNING: GPU with the device id "{gpu_id}" not found.', file=sys.stderr) 68 | continue 69 | if verbose: 70 | print(f'Found device "{gpu_name}"') 71 | gpu_ids.append(gpu_id) 72 | 73 | if len(gpu_ids) > 0: 74 | device = "cuda:" + str(gpu_ids[0]) 75 | if verbose: 76 | print("Running on GPU with device id(s)", *gpu_ids) 77 | else: 78 | if verbose: 79 | print(f'WARNING: no GPUs were found with the specified ids.', file=sys.stderr) 80 | 81 | if len(gpu_ids) == 0: 82 | gpu_ids = None 83 | if verbose: 84 | print("Running on CPU") 85 | device = torch.device("cpu") 86 | 87 | return device, gpu_ids 88 | 89 | def get_gradient_penalty(module: nn.Module, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 90 | alpha = torch.rand(*x.shape[:-1], 1, device=x.device) 91 | interpolates = (alpha * x + (1 - alpha) * y).requires_grad_() 92 | m_interpolates = module(interpolates) 93 | grad_output = torch.ones_like(m_interpolates, requires_grad=False) 94 | gradients, = torch.autograd.grad( 95 | outputs = m_interpolates, 96 | inputs = interpolates, 97 | grad_outputs = grad_output, 98 | create_graph = True, 99 | retain_graph = True, 100 | only_inputs = True 101 | ) 102 | gradient_norm = gradients.view(x.shape[0], -1).norm(p=2, dim=-1) 103 | gradient_penalty = (gradient_norm - 1).square().mean() 104 | return gradient_penalty 105 | 106 | np_dtype_dict = {"float32": np.float32, "float16": np.float16, "float64": np.float64} 107 | 108 | def get_np_dtype(dtype_str: str) -> np.dtype: 109 | return np_dtype_dict[dtype_str] 110 | 111 | def find_star_file_in_path(path: str, type: str = "optimiser") -> str: 112 | if os.path.isfile(os.path.join(path, f"run_{type}.star")): 113 | return os.path.join(path, f"run_{type}.star") 114 | files = glob(os.path.join(path, f"*{type}.star")) 115 | if len(files) > 0: 116 | files = list.sort(files) 117 | return files[-1] 118 | 119 | raise FileNotFoundError(f"Could not find '{type}' star-file in path: {path}") 120 | 121 | def find_project_root(from_path: str, file_relative_path: str) -> str: 122 | """ 123 | Searches for the Relion project root starting at from_path and iterate through parent directories 124 | till file_relative_path is found as a relative sub path or till filesystem root is found, at which 125 | point a RuntimeException is raise. 126 | 127 | :param from_path: starting search from this path 128 | :param file_relative_path: searching to find this relative path as a file 129 | """ 130 | current_path = os.path.abspath(from_path) 131 | while True: 132 | trial_path = os.path.join(current_path, file_relative_path) 133 | if os.path.isfile(trial_path): 134 | return current_path 135 | if current_path == os.path.dirname(current_path): # At filesystem root 136 | raise RuntimeError(f"Relion project directory could not be found from the subdirectory: {from_path}") 137 | current_path = os.path.dirname(current_path) 138 | 139 | def dump_particles_to_dir(input_path, output_path): 140 | """ 141 | Load data from path 142 | :param path: relion job directory or data file 143 | """ 144 | if os.path.isfile(input_path): 145 | data_star_path = input_path 146 | root_search_path = os.path.dirname(os.path.abspath(input_path)) 147 | else: 148 | data_star_path = os.path.abspath(find_star_file_in_path(input_path, "data")) 149 | root_search_path = os.path.abspath(input_path) 150 | 151 | data_star_path = os.path.abspath(data_star_path) 152 | data = load_star(data_star_path) 153 | 154 | if 'optics' not in data: 155 | raise RuntimeError("Optics groups table not found in data star file") 156 | if 'particles' not in data: 157 | raise RuntimeError("Particles table not found in data star file") 158 | 159 | particles = data['particles'] 160 | nr_particles = len(particles['rlnImageName']) 161 | image_file_paths = set() 162 | 163 | for i in range(nr_particles): 164 | img_name = particles['rlnImageName'][i] 165 | img_tokens = img_name.split("@") 166 | if len(img_tokens) == 2: 167 | img_path = img_tokens[1] 168 | elif len(img_tokens) == 1: 169 | img_path = img_tokens[1] 170 | else: 171 | raise RuntimeError(f"Invalid image file name (rlnImageName): {img_name}") 172 | image_file_paths.add(img_path) 173 | 174 | image_file_paths = list(image_file_paths) 175 | 176 | project_root = find_project_root(root_search_path, image_file_paths[0]) 177 | 178 | # Convert image paths to absolute paths 179 | for i in range(len(image_file_paths)): 180 | image_file_paths[i] = os.path.abspath(os.path.join(project_root, image_file_paths[i])) 181 | 182 | new_project_path = os.path.abspath(output_path) 183 | destination_image_file_paths = [p.replace(project_root, new_project_path) for p in image_file_paths] 184 | 185 | [os.makedirs(os.path.dirname(p), exist_ok=True) for p in destination_image_file_paths] 186 | 187 | for src, dst in zip(image_file_paths, destination_image_file_paths): 188 | shutil.copy(src, dst) 189 | 190 | new_star_path = data_star_path.replace(os.path.dirname(data_star_path), new_project_path) 191 | shutil.copy(data_star_path, new_star_path) 192 | 193 | def plot_fscs(output_file, **fscs): 194 | fig, main_ax = plt.subplots() 195 | for plot_name in fscs: 196 | fsc_file = load_star(fscs[plot_name]) 197 | fsc_values = [float(x) for x in fsc_file["fsc"]["rlnFourierShellCorrelation"]] 198 | main_ax.plot(fsc_values, label=plot_name) 199 | main_ax.legend() 200 | main_ax.set_xlabel("1/Angstroms (1/Å)") 201 | main_ax.set_ylabel("FSC") 202 | plt.plot() 203 | plt.savefig(output_file) 204 | 205 | 206 | # class SpectralStandardMapping(torch.nn.Module): 207 | # def __init__(self, image_size): 208 | # super().__init__() 209 | # self.image_size = image_size 210 | # self.std_spectra = torch.nn.Parameter(torch.zeros(image_size // 2 + 1)) 211 | # self.std_grid = torch.nn.Parameter(torch.zeros(image_size, image_size)) 212 | # 213 | # def forward(self, input): 214 | # 215 | # 216 | # def track(self, input): 217 | # with torch.no_grad(): 218 | # s_idx = Cache.get_spectral_indices( 219 | # (self.image_size + 1, self.image_size + 1), 220 | # max_r=self.image_size // 2 + 1, 221 | # device=input.device 222 | # )[:, self.image_size // 2:] 223 | # 224 | # if torch.is_complex(input): 225 | # input = torch.view_as_real(input) 226 | # 227 | # power = torch.mean(torch.sum(torch.square(input), -1), 0) 228 | # power = grid_spectral_average_torch(power, s_idx) 229 | # power = power[:, :-1] 230 | # power[:, -1] = power[:, -2] 231 | # 232 | # stds = torch.sqrt(power) 233 | # 234 | # g_idx = Cache.get_spectral_indices( 235 | # (self.image_size, self.image_size), 236 | # max_r=self.image_size // 2, 237 | # device=input.device 238 | # ) 239 | # g_idx = dt_symmetrize(g_idx, dim=2)[..., self.image_size // 2:] 240 | # self.std_grid.data = spectra_to_grid_torch(spectra=stds, indices=g_idx) 241 | # self.std_spectra = stds 242 | -------------------------------------------------------------------------------- /voxelium/base/ctf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """ 4 | Module for calculations related to the contrast transfer function (CTF) 5 | in cryo-EM single particle analysis. 6 | """ 7 | 8 | from typing import Tuple, Union, TypeVar, Dict 9 | 10 | import numpy as np 11 | import torch 12 | 13 | Tensor = TypeVar('torch.tensor') 14 | 15 | 16 | class ContrastTransferFunction: 17 | def __init__( 18 | self, 19 | voltage: float, 20 | spherical_aberration: float = 0., 21 | amplitude_contrast: float = 0., 22 | phase_shift: float = 0., 23 | b_factor: float = 0., 24 | ) -> None: 25 | """ 26 | Initialization of the CTF parameter for an optics group. 27 | :param voltage: Voltage 28 | :param spherical_aberration: Spherical aberration 29 | :param amplitude_contrast: Amplitude contrast 30 | :param phase_shift: Phase shift 31 | :param b_factor: B-factor 32 | """ 33 | 34 | if voltage <= 0: 35 | raise RuntimeError( 36 | f"Invalid value ({voltage}) for voltage of optics group {id}." 37 | ) 38 | 39 | self.voltage = voltage 40 | self.spherical_aberration = spherical_aberration 41 | self.amplitude_contrast = amplitude_contrast 42 | self.phase_shift = phase_shift 43 | self.b_factor = b_factor 44 | 45 | # Adjust units 46 | spherical_aberration = spherical_aberration * 1e7 47 | voltage = voltage * 1e3 48 | 49 | # Relativistic wave length 50 | # See http://en.wikipedia.org/wiki/Electron_diffraction 51 | # lambda = h/sqrt(2*m*e) * 1/sqrt(V*(1+V*e/(2*m*c^2))) 52 | # h/sqrt(2*m*e) = 12.2642598 * 10^-10 meters -> 12.2642598 Angstrom 53 | # e/(2*m*c^2) = 9.78475598 * 10^-7 coulombs/joules 54 | lam = 12.2642598 / np.sqrt(voltage * (1. + voltage * 9.78475598e-7)) 55 | 56 | # Some constants 57 | self.c1 = -np.pi * lam 58 | self.c2 = np.pi / 2. * spherical_aberration * lam ** 3 59 | self.c3 = phase_shift * np.pi / 180. 60 | self.c4 = -b_factor/4. 61 | self.c5 = \ 62 | np.arctan( 63 | amplitude_contrast / np.sqrt(1-amplitude_contrast**2) 64 | ) 65 | 66 | self.xx = {} 67 | self.yy = {} 68 | self.xy = {} 69 | self.n4 = {} 70 | 71 | self.device = torch.device("cpu") 72 | 73 | def __call__( 74 | self, 75 | grid_size: int, 76 | pixel_size: float, 77 | u: Tensor, 78 | v: Tensor, 79 | angle: Tensor, 80 | h_sym: bool = False, 81 | antialiasing: int = 0 82 | ) -> Tensor: 83 | """ 84 | Get the CTF in an numpy array, the size of freq_x or freq_y. 85 | Generates a Numpy array or a Torch tensor depending on the object type 86 | on freq_x and freq_y passed to the constructor. 87 | :param u: the U defocus 88 | :param v: the V defocus 89 | :param angle: the azimuthal angle defocus (degrees) 90 | :param antialiasing: Antialiasing oversampling factor (0 = no antialiasing) 91 | :param grid_size: the side of the box 92 | :param pixel_size: pixel size 93 | :param h_sym: Only consider the hermitian half 94 | :return: Numpy array or Torch tensor containing the CTF 95 | """ 96 | 97 | # Use cache 98 | tag = f"{grid_size}_{round(pixel_size, 3)}_{h_sym}_{antialiasing}" 99 | if tag not in self.xx: 100 | freq_x, freq_y = self._get_freq(grid_size, pixel_size, h_sym, antialiasing) 101 | xx = freq_x**2 102 | yy = freq_y**2 103 | xy = freq_x * freq_y 104 | n4 = (xx + yy)**2 # Norms squared^2 105 | self.xx[tag] = xx.to(self.device) 106 | self.yy[tag] = yy.to(self.device) 107 | self.xy[tag] = xy.to(self.device) 108 | self.n4[tag] = n4.to(self.device) 109 | 110 | xx = self.xx[tag] 111 | yy = self.yy[tag] 112 | xy = self.xy[tag] 113 | n4 = self.n4[tag] 114 | 115 | angle = angle * np.pi / 180 116 | acos = torch.cos(angle) 117 | asin = torch.sin(angle) 118 | acos2 = torch.square(acos) 119 | asin2 = torch.square(asin) 120 | 121 | """ 122 | Out line of math for following three lines of code 123 | Q = [[sin cos] [-sin cos]] sin/cos of the angle 124 | D = [[u 0] [0 v]] 125 | A = Q^T.D.Q = [[Axx Axy] [Ayx Ayy]] 126 | Axx = cos^2 * u + sin^2 * v 127 | Ayy = sin^2 * u + cos^2 * v 128 | Axy = Ayx = cos * sin * (u - v) 129 | defocus = A.k.k^2 = Axx*x^2 + 2*Axy*x*y + Ayy*y^2 130 | """ 131 | 132 | xx_ = (acos2 * u + asin2 * v)[:, None, None] * xx[None, :, :] 133 | yy_ = (asin2 * u + acos2 * v)[:, None, None] * yy[None, :, :] 134 | xy_ = (acos * asin * (u - v))[:, None, None] * xy[None, :, :] 135 | 136 | gamma = self.c1 * (xx_ + 2. * xy_ + yy_) + self.c2 * n4[None, :, :] - self.c3 - self.c5 137 | ctf = -torch.sin(gamma) 138 | if self.c4 > 0: 139 | ctf *= torch.exp(self.c4 * n4) 140 | 141 | if antialiasing > 0: 142 | o = 2**antialiasing 143 | ctf = ctf.unsqueeze(1) # Add singleton channel 144 | ctf = torch.nn.functional.avg_pool2d(ctf, kernel_size=o+o//2, stride=o) 145 | ctf = ctf.squeeze(1) # Remove singleton channel 146 | 147 | return ctf 148 | 149 | def to(self, device): 150 | if self.device == device: 151 | return 152 | self.device = device 153 | for tag in self.xx: 154 | self.xx[tag] = self.xx[tag].to(device) 155 | self.yy[tag] = self.yy[tag].to(device) 156 | self.xy[tag] = self.xy[tag].to(device) 157 | self.n4[tag] = self.n4[tag].to(device) 158 | 159 | @staticmethod 160 | def _get_freq( 161 | grid_size: int, 162 | pixel_size: float, 163 | h_sym: bool = False, 164 | antialiasing: int = 0 165 | ) -> Union[ 166 | Union[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor, Tensor]], 167 | Union[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray, np.ndarray]] 168 | ]: 169 | """ 170 | Get the inverted frequencies of the Fourier transform of a square or cuboid grid. 171 | Can generate both Torch tensors and Numpy arrays. 172 | TODO Add 3D 173 | :param antialiasing: Antialiasing oversampling factor (0 = no antialiasing) 174 | :param grid_size: the side of the box 175 | :param pixel_size: pixel size 176 | :param h_sym: Only consider the hermitian half 177 | :return: two or three numpy arrays or tensors, 178 | containing frequencies along the different axes 179 | """ 180 | if antialiasing > 0: 181 | o = 2**antialiasing 182 | grid_size *= o 183 | y_ls = np.linspace( 184 | -(grid_size + o) // 2, 185 | (grid_size - o) // 2, 186 | grid_size + o//2 187 | ) 188 | x_ls = y_ls if not h_sym else torch.linspace(0, grid_size // 2, grid_size // 2 + o + 1) 189 | else: 190 | y_ls = np.linspace(-grid_size // 2, grid_size // 2 - 1, grid_size) 191 | x_ls = y_ls if not h_sym else torch.linspace(0, grid_size // 2, grid_size // 2 + 1) 192 | 193 | y, x = torch.meshgrid(torch.Tensor(y_ls), torch.Tensor(x_ls), indexing='ij') 194 | freq_x = x / (grid_size * pixel_size) 195 | freq_y = y / (grid_size * pixel_size) 196 | 197 | return freq_x, freq_y 198 | 199 | def get_state_dict(self) -> Dict: 200 | return { 201 | "type": "ContrastTransferFunction", 202 | "version": "0.0.1", 203 | "voltage": self.voltage, 204 | "spherical_aberration": self.spherical_aberration, 205 | "amplitude_contrast": self.amplitude_contrast, 206 | "phase_shift": self.phase_shift, 207 | "b_factor": self.b_factor 208 | } 209 | 210 | @staticmethod 211 | def load_from_state_dict(state_dict): 212 | if "type" not in state_dict or state_dict["type"] != "ContrastTransferFunction": 213 | raise TypeError("Input is not an 'ContrastTransferFunction' instance.") 214 | 215 | if "version" not in state_dict: 216 | raise RuntimeError("ContrastTransferFunction instance lacks version information.") 217 | 218 | if state_dict["version"] == "0.0.1": 219 | return ContrastTransferFunction( 220 | voltage=state_dict['voltage'], 221 | spherical_aberration=state_dict['spherical_aberration'], 222 | amplitude_contrast=state_dict['amplitude_contrast'], 223 | phase_shift=state_dict['phase_shift'], 224 | b_factor=state_dict['b_factor'], 225 | ) 226 | else: 227 | raise RuntimeError(f"Version '{state_dict['version']}' not supported.") 228 | 229 | 230 | if __name__ == "__main__": 231 | os1 = 0 232 | os2 = 2 233 | box = 200 234 | df = torch.Tensor([[20000]]) 235 | pixA = 1. 236 | 237 | freq1_x, freq1_y = ContrastTransferFunction._get_freq(box, pixA, antialiasing=os1) 238 | ctf1 = ContrastTransferFunction(300, 2.7, 0.1, 0, 0, os1) 239 | 240 | freq2_x, freq2_y = ContrastTransferFunction._get_freq(box, pixA, antialiasing=os2) 241 | ctf2 = ContrastTransferFunction(300, 2.7, 0.1, 0, 0, os2) 242 | 243 | test1 = ctf1(df, df, torch.zeros([1, 1])).cpu().numpy() 244 | test2 = ctf2(df, df, torch.zeros([1, 1])).cpu().numpy() 245 | diff = test1 - test2 246 | print("%.10f" % np.mean(np.abs(diff))) 247 | 248 | import matplotlib.pylab as plt 249 | _, [ax1, ax2, ax3] = plt.subplots(1, 3) 250 | ax1.imshow(test1) 251 | ax2.imshow(test2) 252 | ax3.imshow(diff) 253 | plt.show() 254 | -------------------------------------------------------------------------------- /voxelium/base/particle_image_preprocessor.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """ 4 | Module for particle image preprocessing 5 | """ 6 | 7 | from typing import Tuple, TypeVar, Dict, Any, Union 8 | 9 | import numpy as np 10 | 11 | import torch 12 | 13 | from voxelium.base.grid import dht, smooth_circular_mask, smooth_square_mask, \ 14 | get_spectral_indices, get_spectral_avg, spectrum_to_grid 15 | 16 | Tensor = TypeVar('torch.tensor') 17 | 18 | 19 | def get_spectral_stats(img_stack: Union[Tensor, np.ndarray]) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 20 | if len(img_stack.shape) != 3: 21 | raise RuntimeError(f"Input is not a stack") 22 | if len(img_stack) < 2: 23 | raise RuntimeError(f"Image stack too small") 24 | 25 | spectral_indices = get_spectral_indices(img_stack.shape[1:]) 26 | 27 | if torch.is_tensor(img_stack): 28 | stack_mean = torch.mean(img_stack, 0).cpu().numpy() 29 | stack_square_mean = torch.mean(torch.square(img_stack), 0).cpu().numpy() 30 | stack_std = torch.std(img_stack, 0).cpu().numpy() 31 | else: 32 | stack_mean = np.mean(img_stack, axis=0) 33 | stack_square = np.mean(np.sqaure(img_stack), axis=0) 34 | stack_std = np.std(img_stack, axis=0) 35 | 36 | mean_spectrum = get_spectral_avg(stack_mean, spectral_indices=spectral_indices) 37 | square_mean_spectrum = get_spectral_avg(stack_square_mean, spectral_indices=spectral_indices) 38 | std_spectrum = get_spectral_avg(stack_std, spectral_indices=spectral_indices) 39 | 40 | cutoff_idx = img_stack.shape[-1] // 2 + 1 41 | 42 | # No standardization beyond Nyqvist 43 | mean_spectrum[cutoff_idx:] = 0 44 | std_spectrum[cutoff_idx:] = 1 45 | 46 | return mean_spectrum, square_mean_spectrum, std_spectrum 47 | 48 | 49 | class ParticleImagePreprocessor: 50 | def __init__(self) -> None: 51 | self.image_size = None 52 | self.make_stats = None 53 | self.spectral_mean = None 54 | self.spectral_std = None 55 | self.circular_mask_radius = None 56 | self.circular_mask_thickness = None 57 | 58 | self.spectral_square_mean = None 59 | self.spectral_sigma2 = None 60 | 61 | self.precompute_spectral_mean = None 62 | self.precompute_spectral_std = None 63 | self.precompute_spectral_sigma2 = None 64 | self.precompute_spectral_mask = None 65 | self.precompute_circular_mask = None 66 | self.precompute_square_mask = None 67 | 68 | def initialize( 69 | self, 70 | image_size: int, 71 | circular_mask_radius: float, 72 | circular_mask_thickness: float, 73 | spectral_mean=None, 74 | spectral_std=None, 75 | spectral_square_mean=None, 76 | ) -> None: 77 | self.image_size = image_size 78 | self.circular_mask_radius = circular_mask_radius 79 | self.circular_mask_thickness = circular_mask_thickness 80 | self.spectral_mean = spectral_mean 81 | self.spectral_std = spectral_std 82 | self.spectral_square_mean = spectral_square_mean 83 | 84 | self._precompute() 85 | 86 | def initialize_from_stack( 87 | self, 88 | stack: np.ndarray, 89 | circular_mask_radius: float, 90 | circular_mask_thickness: float, 91 | ) -> None: 92 | self.image_size = stack.shape[-1] 93 | self.circular_mask_radius = circular_mask_radius 94 | self.circular_mask_thickness = circular_mask_thickness 95 | 96 | stack_ht = dht(stack, dim=2) 97 | self.spectral_mean, self.spectral_square_mean, self.spectral_std = get_spectral_stats(stack_ht) 98 | 99 | self.spectral_sigma2 = self.spectral_square_mean - np.square(self.spectral_mean) 100 | 101 | # import matplotlib.pylab as plt 102 | # x = np.arange(len(self.spectral_sigma2)) 103 | # plt.plot(x, self.spectral_sigma2, "r") 104 | # plt.plot(x, self.spectral_square_mean, "b") 105 | # plt.plot(x, self.spectral_sigma2, "k") 106 | # plt.show() 107 | 108 | self._precompute() 109 | 110 | def _precompute(self) -> None: 111 | # Square mask 112 | self.precompute_square_mask = torch.Tensor( 113 | smooth_square_mask( 114 | image_size=self.image_size, 115 | square_side=self.image_size - self.circular_mask_thickness * 2, 116 | thickness=self.circular_mask_thickness 117 | ) 118 | ) 119 | 120 | # Circular mask 121 | self.precompute_circular_mask = torch.Tensor( 122 | smooth_circular_mask( 123 | image_size=self.image_size, 124 | radius=self.circular_mask_radius, 125 | thickness=self.circular_mask_thickness 126 | ) 127 | ) 128 | 129 | # Calculate standardization coefficients 130 | cutoff_idx = self.image_size // 2 + 1 131 | spectral_indices = get_spectral_indices([self.image_size] * 2) 132 | 133 | if self.spectral_mean is not None and \ 134 | self.spectral_std is not None and \ 135 | self.spectral_square_mean is not None: 136 | self.spectral_sigma2 = self.spectral_square_mean - np.square(self.spectral_mean) 137 | # If we have statistics 138 | self.precompute_spectral_mean = spectrum_to_grid(self.spectral_mean, spectral_indices) 139 | self.precompute_spectral_std = spectrum_to_grid(self.spectral_std, spectral_indices) 140 | self.precompute_spectral_sigma2 = spectrum_to_grid(self.spectral_sigma2, spectral_indices) 141 | self.precompute_spectral_mask = \ 142 | ((spectral_indices < cutoff_idx) & (self.precompute_spectral_std > 1e-5)) 143 | 144 | self.precompute_spectral_mean = torch.Tensor(self.precompute_spectral_mean) 145 | self.precompute_spectral_std = torch.Tensor(self.precompute_spectral_std) 146 | self.precompute_spectral_sigma2 = torch.Tensor(self.precompute_spectral_sigma2) 147 | self.precompute_spectral_mask = torch.Tensor(self.precompute_spectral_mask) 148 | else: 149 | # If we don't have statistics, make grids of ones rather than raising an exception 150 | self.precompute_spectral_mean = torch.ones_like(self.precompute_square_mask) 151 | self.precompute_spectral_std = torch.ones_like(self.precompute_square_mask) 152 | self.precompute_spectral_sigma2 = torch.ones_like(self.precompute_spectral_sigma2) 153 | self.precompute_spectral_mask = torch.ones_like(self.precompute_square_mask) 154 | 155 | def set_device(self, device: Any) -> None: 156 | self.precompute_square_mask = self.precompute_square_mask.to(device) 157 | self.precompute_circular_mask = self.precompute_circular_mask.to(device) 158 | self.precompute_spectral_mean = self.precompute_spectral_mean.to(device) 159 | self.precompute_spectral_std = self.precompute_spectral_std.to(device) 160 | self.precompute_spectral_sigma2 = self.precompute_spectral_sigma2.to(device) 161 | self.precompute_spectral_mask = self.precompute_spectral_mask.to(device) 162 | 163 | def apply_square_mask(self, img_stack: Tensor) -> Tensor: 164 | return img_stack * self.precompute_square_mask[None, ...] 165 | 166 | def apply_circular_mask(self, img_stack: Tensor) -> Tensor: 167 | return img_stack * self.precompute_circular_mask[None, ...] 168 | 169 | def apply_translation( 170 | self, grids: Tensor, 171 | shift: Union[np.ndarray, Tensor], 172 | shift_y: Union[np.ndarray, Tensor] = None 173 | ) -> Tensor: 174 | # Generate the 3x4 matrix for affine grid RR = (E|S) = (eye|shift) 175 | if torch.is_tensor(shift): 176 | if shift_y is None: 177 | S = shift 178 | else: 179 | S = torch.stack([shift, shift_y], 1) 180 | else: 181 | if shift_y is None: 182 | S = torch.tensor(shift) 183 | else: 184 | S = torch.tensor(np.stack([shift, shift_y], 1)) 185 | 186 | S = -S.unsqueeze(2) * 2 / self.image_size 187 | B = S.shape[0] 188 | I = torch.eye(2).unsqueeze(0).to(S.device) 189 | 190 | RR = torch.cat([torch.tile(I, (B, 1, 1)), S], 2) 191 | 192 | # generate affine Grid 193 | grid = torch.nn.functional.affine_grid(RR, (B, 1, self.image_size, self.image_size), align_corners=False) 194 | grid = grid.to(grids.device) 195 | 196 | # apply shift 197 | img_stack_out = torch.nn.functional.grid_sample( 198 | input=grids.reshape(B, 1, self.image_size, self.image_size).float(), 199 | grid=grid.float(), 200 | mode='bilinear', 201 | align_corners=False 202 | ) 203 | img_stack_out = torch.squeeze(img_stack_out, 1) 204 | return img_stack_out 205 | 206 | def get_state_dict(self) -> Dict: 207 | return { 208 | "type": "ParticleImagePreprocessor", 209 | "version": "0.0.1", 210 | "image_size": self.image_size, 211 | "spectral_mean": self.spectral_mean, 212 | "spectral_std": self.spectral_std, 213 | "spectral_square_mean": self.spectral_square_mean, 214 | "circular_mask_radius": self.circular_mask_radius, 215 | "circular_mask_thickness": self.circular_mask_thickness 216 | } 217 | 218 | def set_state_dict(self, state_dict) -> None: 219 | if "type" not in state_dict or state_dict["type"] != "ParticleImagePreprocessor": 220 | raise TypeError("Input is not an 'ParticleImagePreprocessor' instance.") 221 | 222 | if "version" not in state_dict: 223 | raise RuntimeError("ParticleImagePreprocessor instance lacks version information.") 224 | 225 | if state_dict["version"] == "0.0.1": 226 | self.initialize( 227 | image_size=state_dict["image_size"], 228 | spectral_mean=state_dict["spectral_mean"], 229 | spectral_std=state_dict["spectral_std"], 230 | spectral_square_mean=state_dict["spectral_square_mean"], 231 | circular_mask_radius=state_dict["circular_mask_radius"], 232 | circular_mask_thickness=state_dict["circular_mask_thickness"], 233 | ) 234 | else: 235 | raise RuntimeError(f"Version '{state_dict['version']}' not supported.") 236 | -------------------------------------------------------------------------------- /voxelium/vae_volume/volume_explorer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import sys 3 | import time 4 | 5 | import umap 6 | import numpy as np 7 | import argparse 8 | 9 | import scipy.ndimage 10 | import matplotlib.pylab as plt 11 | 12 | import torch 13 | 14 | import multiprocessing as mp 15 | 16 | from voxelium.base.grid import idft, dt_desymmetrize, load_mrc, save_mrc 17 | from voxelium.vae_volume.data_analysis_container import DatasetAnalysisContainer 18 | from voxelium.vae_volume.utils import setup_device 19 | 20 | 21 | if __name__ == "__main__": 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('logdir', help='input checkpoint file', type=str) 24 | parser.add_argument('--gpu', type=str, default=None, help='gpu to use') 25 | parser.add_argument('--roi', type=str, default=None, help='Region of interest') 26 | parser.add_argument('--dont_cache_embed', action="store_true") 27 | parser.add_argument('--ignore_cached_embed', action="store_true") 28 | parser.add_argument('--nogui', action="store_true") 29 | args = parser.parse_args() 30 | 31 | torch.no_grad() 32 | device, _ = setup_device(args) 33 | dac = DatasetAnalysisContainer.load_from_logdir(args.logdir) 34 | latent = dac.hidden_variable_container.latent_space.numpy().astype(np.float32) 35 | # embed = dac.hidden_variable_container.structure_basis.softmax(dim=-1).numpy().astype(np.float32) 36 | embed = dac.hidden_variable_container.latent_space.numpy().astype(np.float32) 37 | 38 | 39 | selection = np.arange(len(latent)) 40 | # np.random.shuffle(selection) 41 | # selection = selection[:min(len(latent), 10000)] 42 | 43 | latent = latent[selection] 44 | embed = embed[selection] 45 | if latent.shape[-1] > 2: 46 | if dac.auxiliaries is None or "embed" not in dac.auxiliaries or args.ignore_cached_embed: 47 | print("Creating 2D embedding...") 48 | # if "roi_basis_index" in dac.auxiliaries: 49 | # basis = embed[:, dac.auxiliaries["roi_basis_index"]] 50 | # else: 51 | # basis = embed 52 | 53 | basis = embed 54 | # basis = embed[:, :16] 55 | 56 | embed = umap.UMAP(local_connectivity=1, repulsion_strength=2).fit_transform( 57 | basis.astype(np.float32) 58 | ) 59 | 60 | # import tsnecuda 61 | # embed = tsnecuda.TSNE(n_components=2, num_neighbors=256, perplexity=100).fit_transform( 62 | # basis.astype(np.float32) 63 | # ) 64 | 65 | # from sklearn.decomposition import PCA 66 | # embed = PCA(n_components=2).fit_transform( 67 | # basis.astype(np.float32) 68 | # ) 69 | 70 | if dac.auxiliaries is None: 71 | dac.auxiliaries = {"embed": embed} 72 | else: 73 | dac.auxiliaries["embed"] = embed 74 | if not args.dont_cache_embed: 75 | print("Saving embedding to checkpoint file...") 76 | dac.save_to_checkpoint(args.logdir) 77 | else: 78 | embed = dac.auxiliaries["embed"] 79 | 80 | if args.nogui: 81 | print("No GUI... Exiting!") 82 | exit(0) 83 | 84 | # Else, import what the GUI needs 85 | 86 | import vtkmodules.vtkInteractionStyle 87 | # noinspection PyUnresolvedReferences 88 | import vtkmodules.vtkRenderingOpenGL2 89 | 90 | from voxelium.vae_volume.volume_renderer import volumeRendererProcessLoop 91 | 92 | roi = None 93 | if args.roi is not None: 94 | roi, _, _ = load_mrc(args.roi) 95 | roi = torch.Tensor(roi.copy()).to(device) 96 | 97 | N = 3000 98 | margin = 50 99 | marker_size = 10 100 | 101 | outlier_mask = np.abs(embed) > 5 102 | if np.sum(outlier_mask) < len(embed) * 0.3: 103 | embed[outlier_mask] = np.sign(embed[outlier_mask]) * 5 104 | 105 | x_min = np.min(embed[:, 0]) 106 | x_max = np.max(embed[:, 0]) 107 | y_min = np.min(embed[:, 1]) 108 | y_max = np.max(embed[:, 1]) 109 | 110 | x = (embed[:, 0] - x_min) / (x_max - x_min) 111 | y = (embed[:, 1] - y_min) / (y_max - y_min) 112 | 113 | x = margin + x * (N - 2. * margin) 114 | y = margin + y * (N - 2. * margin) 115 | 116 | heat_map = np.zeros((N, N)) 117 | heat_map[y.astype(int), x.astype(int)] += 1 118 | heat_map_smooth = scipy.ndimage.gaussian_filter(heat_map, 3) 119 | 120 | coord = np.zeros((len(x), 2)) 121 | coord[:, 0] = x 122 | coord[:, 1] = y 123 | 124 | vaec = dac.vae_container 125 | vaec.set_device(device) 126 | vaec.set_eval() 127 | 128 | data_spectra, data_ctf_spectra = dac.hidden_variable_container.get_data_stats(0) 129 | data_ctf_spectra = data_ctf_spectra.to(device) 130 | 131 | nn_time = 0 132 | ft_time = 0 133 | 134 | def get_volume(z): 135 | global nn_time, ft_time 136 | z = torch.Tensor(z).unsqueeze(0).to(device) 137 | 138 | t = time.time() 139 | sb, pp = vaec.basis_decoder(*vaec.split_latent_space(z)) 140 | v_ft = vaec.structure_decoder(sb, pp, data_spectra=data_ctf_spectra, do_postprocess=True) 141 | nn_time = nn_time * 0.9 + (time.time() - t) * 0.1 142 | 143 | v_ht = torch.view_as_complex(v_ft) 144 | v_ht = dt_desymmetrize(v_ht) 145 | 146 | t = time.time() 147 | vol = idft(v_ht, dim=3, real_in=True) 148 | ft_time = ft_time * 0.9 + (time.time() - t) * 0.1 149 | 150 | if roi is not None: 151 | vol *= roi 152 | vol /= torch.std(vol) 153 | 154 | print("NN time", nn_time, "FT time", ft_time) 155 | 156 | return vol[0].detach().cpu().numpy() 157 | 158 | # print("dumping first 800...") 159 | # count = 50 160 | # selected_idx = np.arange(800) 161 | # np.random.shuffle(selected_idx) 162 | # selected_idx = selected_idx[:count] 163 | # avg = None 164 | # for i in range(count): 165 | # vol = get_volume(latent[selected_idx[i]]) 166 | # save_mrc(vol, 1, [0, 0, 0], f"first_dumps/idx_{selected_idx[i]}.mrc") 167 | # if avg is None: 168 | # avg = vol 169 | # else: 170 | # avg += vol 171 | # save_mrc(avg / count, 1, [0, 0, 0], f"first_dumps/avg.mrc") 172 | # exit(0) 173 | 174 | fig_hm, ax_hm = plt.subplots(figsize=(7, 7)) # Heat map 175 | 176 | ax_hm.axis('off') 177 | plt.tight_layout() 178 | 179 | # MAKE HEAT MAP ------------------------------------------------------------------------------------- 180 | alpha = min(10./np.sqrt(len(x)), 1.) 181 | # ax_hm.imshow(heat_map_smooth, cmap='RdPu', zorder=0) 182 | ax_hm.plot(x, y, 'k.', markersize=1, alpha=0.1, zorder=1) 183 | # ax_hm.scatter(x, y, edgecolors=None, marker='.', c=np.arange(len(x)), cmap="summer", alpha=alpha) 184 | 185 | mx = np.mean(x) 186 | sx = np.std(x)*3 187 | my = np.mean(y) 188 | sy = np.std(y)*3 189 | 190 | ax_hm.set_xlim([mx - sx, mx + sx]) 191 | ax_hm.set_ylim([my - sy, my + sy]) 192 | 193 | # VOLUME RENDERER PROCESS --------------------------------------------------------------------------- 194 | process_loop_queue = mp.Queue() 195 | volume_render_dispatch_queue = mp.Queue() 196 | volume_render_process = mp.Process( 197 | target=volumeRendererProcessLoop, 198 | args=( 199 | volume_render_dispatch_queue, 200 | process_loop_queue 201 | ) 202 | ) 203 | volume_render_process.start() 204 | 205 | # SELECTION UPDATES -------------------------------------------------------------------------------- 206 | circles = [] 207 | circles_coord = [] 208 | selected_ids = [] 209 | volumes = [] 210 | selected_iso_value = None 211 | 212 | get_volume(np.mean(latent, axis=0)) # Warm up 213 | 214 | def onClickHm(event): 215 | if event.xdata is None or event.ydata is None: 216 | return 217 | 218 | xy = np.array([event.xdata, event.ydata]) 219 | 220 | state_change = False 221 | 222 | if event.button == 1: 223 | c = np.sum((coord - xy) ** 2, axis=1) 224 | idx = np.argmin(c) 225 | dis2 = c[idx] 226 | xy = [coord[idx, 0], coord[idx, 1]] 227 | 228 | if dis2 < (N/50)**2 and idx not in selected_ids: 229 | circle = plt.Circle(xy, marker_size, color='black', alpha=1, zorder=2) 230 | vol = get_volume(latent[idx]) 231 | circles.append(circle) 232 | circles_coord.append(xy) 233 | selected_ids.append(idx) 234 | volumes.append(vol) 235 | print("Selected point index", idx) 236 | ax_hm.add_patch(circle) 237 | state_change = True 238 | 239 | elif event.button == 3: 240 | if len(circles) > 0: 241 | c = np.sum((np.array(circles_coord) - xy) ** 2, axis=1) 242 | selected_idx = np.argmin(c) 243 | dis2 = c[selected_idx] 244 | 245 | if dis2 < (N/50)**2: 246 | circles[selected_idx].remove() 247 | del (circles[selected_idx]) 248 | del (circles_coord[selected_idx]) 249 | del (selected_ids[selected_idx]) 250 | del (volumes[selected_idx]) 251 | state_change = True 252 | 253 | if state_change: 254 | volume_render_dispatch_queue.put(volumes) 255 | fig_hm.canvas.draw() 256 | 257 | 258 | def onKeyHm(event): 259 | sys.stdout.flush() 260 | if event.key == 'escape': 261 | for i in range(len(circles)): 262 | circles[i].remove() 263 | 264 | circles.clear() 265 | circles_coord.clear() 266 | selected_ids.clear() 267 | volumes.clear() 268 | 269 | volume_render_dispatch_queue.put(volumes) 270 | fig_hm.canvas.draw() 271 | elif event.key == 'enter': 272 | print('Saving selected structures to MRC-files') 273 | for i, v in enumerate(volumes): 274 | save_mrc(v, 1, [0, 0, 0], "particle_id_" + str(selected_ids[i]) + ".mrc") 275 | 276 | # -------------------------------------------------------------------------------- 277 | 278 | click_connect = fig_hm.canvas.mpl_connect('button_press_event', onClickHm) 279 | key_connect = fig_hm.canvas.mpl_connect('key_press_event', onKeyHm) 280 | 281 | try: 282 | plt.show() 283 | except KeyboardInterrupt: 284 | print("Exiting!") 285 | 286 | volume_render_dispatch_queue.put(None) 287 | process_loop_queue.put(None) 288 | volume_render_process.join() 289 | volume_render_process.terminate() 290 | -------------------------------------------------------------------------------- /voxelium/vae_volume/svr_linear/svr_linear.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """ 4 | Python API for the sparse volume reconstruction linear layer 5 | """ 6 | import time 7 | from typing import TypeVar, Union 8 | 9 | import numpy as np 10 | import torch 11 | import voxelium_svr_linear 12 | 13 | from voxelium.base import grid_iterator, sample_gaussian_function 14 | from voxelium.relion import eulerToMatrix 15 | 16 | Tensor = TypeVar('torch.tensor') 17 | 18 | VOXEL_SPREAD_MARGIN = 3 19 | 20 | 21 | class SparseVolumeReconstructionLinear(torch.nn.Module): 22 | def __init__( 23 | self, 24 | size, 25 | input_size, 26 | dtype=torch.float32, 27 | bias=True, 28 | input_spectral_weight=None 29 | ): 30 | super(SparseVolumeReconstructionLinear, self).__init__() 31 | 32 | if size % 2 == 0: 33 | size += 1 34 | 35 | self.size = size 36 | self.size_x = size // 2 + 1 37 | self.input_size = input_size 38 | 39 | bz = size 40 | bz_2 = bz // 2 41 | grid_indices = np.zeros((bz, bz, bz // 2 + 1), dtype=int) - 1 42 | max_r2 = (size // 2) ** 2 43 | i = 0 44 | # for z, y, x in grid_iterator(bz-1, bz-1, bz_2+1): 45 | for z, y, x in grid_iterator(bz, bz, bz_2 + 1): 46 | if (z - bz_2) ** 2 + (y - bz_2) ** 2 + x ** 2 <= max_r2: 47 | grid_indices[z, y, x] = i 48 | i += 1 49 | 50 | self.weight_count = i 51 | 52 | # Add margin for pixel spread into voxels 53 | m = VOXEL_SPREAD_MARGIN 54 | bz += m * 2 55 | bz_2 += m 56 | grid_indices_margin = np.zeros((bz, bz, bz_2 + 1), dtype=int) - 1 57 | grid_indices_margin[m:-m, m:-m, :-m] = grid_indices 58 | grid_indices = grid_indices_margin 59 | 60 | for i in range(5): 61 | self.radial_expansion(grid_indices) 62 | 63 | self.grid3d_index = torch.nn.Parameter( 64 | torch.tensor(grid_indices, dtype=torch.long), requires_grad=False) 65 | 66 | data_tensor = torch.empty((self.weight_count, input_size, 2), dtype=dtype).normal_() 67 | self.weight = torch.nn.Parameter(data=data_tensor, requires_grad=True) 68 | 69 | if bias: 70 | data_tensor = torch.empty((self.weight_count, 2), dtype=dtype).normal_() 71 | self.bias = torch.nn.Parameter(data=data_tensor, requires_grad=True) 72 | else: 73 | self.bias = None 74 | 75 | self.weight.data *= 1e-6 76 | if self.bias is not None: 77 | self.bias.data *= 0 78 | 79 | self.input_spectral_weight = torch.nn.Parameter( 80 | torch.empty(0, dtype=self.weight.dtype) if input_spectral_weight is None else input_spectral_weight 81 | ) 82 | self.input_spectral_weight.requires_grad = False 83 | 84 | def forward(self, input, max_r=None, grid2d_coord=None, rot_matrices=None, sparse_grad=True): 85 | if rot_matrices is not None and grid2d_coord is not None: 86 | return TrilinearProjection.apply( 87 | input, # input 88 | self.weight, # weight 89 | self.bias, # bias 90 | self.grid3d_index, # grid3d_index 91 | rot_matrices, # rot_matrices 92 | grid2d_coord, # grid2d_coord 93 | max_r, # max_r 94 | self.input_spectral_weight, # input_spectral_weight 95 | sparse_grad, # sparse_grad 96 | False # testing 97 | ) 98 | else: 99 | return VolumeExtraction.apply( 100 | input, # input 101 | self.weight, # weight 102 | self.bias, # bias 103 | self.grid3d_index, # grid3d_index 104 | self.input_spectral_weight, # input_spectral_weight 105 | max_r # max_r 106 | ) 107 | 108 | def set_reference(self, grid_ht: Union[Tensor, np.ndarray]): 109 | self.weight.data = torch.zeros_like(self.weight.data) 110 | m = VOXEL_SPREAD_MARGIN 111 | max_r2 = (self.size // 2) ** 2 112 | s = (self.size - 1) // 2 113 | for z, y, x in grid_iterator(self.size, self.size, self.size_x): 114 | i = self.grid3d_index[z + m, y + m, x] 115 | if i >= 0 and (z - s) ** 2 + (y - s) ** 2 + x ** 2 <= max_r2: 116 | self.weight.data[i] = grid_ht[z, y, x] 117 | 118 | @staticmethod 119 | def radial_expansion(grid): 120 | assert grid.shape[0] == grid.shape[1] == grid.shape[2] * 2 - 1 121 | bz = grid.shape[0] 122 | bz2 = bz // 2 123 | mask1 = grid == -1 124 | 125 | ls = np.linspace(-bz2, bz2, bz) 126 | lsx = np.linspace(0, bz2, bz // 2 + 1) 127 | z, y, x = np.meshgrid(ls, ls, lsx) 128 | c = np.zeros((int(np.sum(mask1)), 3)) 129 | c[:, 0] = x[mask1] 130 | c[:, 1] = y[mask1] 131 | c[:, 2] = z[mask1] 132 | 133 | norm = np.sqrt(np.sum(np.square(c), axis=1)) 134 | c_ = np.round(c / norm[:, None]).astype(int) 135 | 136 | c = c.astype(int) 137 | c[:, 1:] += bz2 138 | c_ = c - c_ 139 | 140 | g = grid[c_[:, 2], c_[:, 1], c_[:, 0]] 141 | mask2 = g >= 0 142 | 143 | grid[c[mask2, 2], c[mask2, 1], c[mask2, 0]] = g[mask2] 144 | 145 | 146 | class TrilinearProjection(torch.autograd.Function): 147 | @staticmethod 148 | def forward( 149 | ctx, input, weight, bias, grid3d_index, 150 | rot_matrices, grid2d_coord, max_r, 151 | input_spectral_weight=None, 152 | sparse_grad=True, testing=False 153 | ): 154 | assert grid3d_index.shape[0] == grid3d_index.shape[1] == grid3d_index.shape[2] * 2 - 1 155 | if max_r is None: 156 | max_r = (grid3d_index.shape[0] - 2 * VOXEL_SPREAD_MARGIN) // 2 157 | else: 158 | max_r = min(max_r, (grid3d_index.shape[0] - 2 * VOXEL_SPREAD_MARGIN) // 2) 159 | 160 | input_spectral_weight = torch.empty(0, dtype=input.dtype).to(input.device) \ 161 | if input_spectral_weight is None else input_spectral_weight 162 | 163 | output = voxelium_svr_linear.trilinear_projection_forward( 164 | input=input, 165 | weight=weight, 166 | bias=torch.empty([0, 0], dtype=weight.dtype).to(weight.device) if bias is None else bias, 167 | rot_matrix=rot_matrices, 168 | grid2d_coord=grid2d_coord, 169 | grid3d_index=grid3d_index, 170 | max_r=max_r 171 | ) 172 | 173 | ctx.save_for_backward( 174 | input, 175 | weight, 176 | bias, 177 | grid3d_index, 178 | rot_matrices, 179 | grid2d_coord, 180 | input_spectral_weight, 181 | torch.Tensor([max_r]), 182 | torch.Tensor([sparse_grad]), 183 | torch.Tensor([testing]) 184 | ) 185 | 186 | return output 187 | 188 | @staticmethod 189 | def backward(ctx, grad_output): 190 | input, weight, bias, grid3d_index, rot_matrices, grid2d_coord, \ 191 | input_spectral_weight, max_r, sparse_grad, testing \ 192 | = ctx.saved_tensors 193 | sparse_grad = bool(sparse_grad[0]) 194 | 195 | grad_input, grad_weight_index, grad_weight_values, grad_bias_values, grad_rot_matrix = \ 196 | voxelium_svr_linear.trilinear_projection_backward( 197 | input=input, 198 | grid2d_grad=grad_output.contiguous(), 199 | weight=weight, 200 | bias=torch.empty([0, 0], dtype=weight.dtype).to(weight.device) if bias is None else bias, 201 | grid3d_index=grid3d_index, 202 | rot_matrix=rot_matrices, 203 | grid2d_coord=grid2d_coord, 204 | input_spectral_weight=input_spectral_weight, 205 | max_r=max_r[0], 206 | sparse_grad=sparse_grad 207 | ) 208 | 209 | if sparse_grad: 210 | grad_weight = torch.sparse_coo_tensor( 211 | grad_weight_index.contiguous().view(1, -1), 212 | grad_weight_values.contiguous().view(-1, weight.shape[-2], 2), 213 | weight.shape 214 | ) 215 | else: 216 | grad_weight = grad_weight_values 217 | 218 | if bias is None: 219 | grad_bias = None 220 | else: 221 | if sparse_grad: 222 | grad_bias = torch.sparse_coo_tensor( 223 | grad_weight_index.contiguous().view(1, -1), 224 | grad_bias_values.contiguous().view(-1, 2), 225 | bias.shape 226 | ) 227 | else: 228 | grad_bias = grad_bias_values 229 | 230 | if testing[0] and sparse_grad: 231 | grad_weight = grad_weight.to_dense() 232 | if bias is not None: 233 | grad_bias = grad_bias.to_dense() 234 | 235 | return grad_input, grad_weight, grad_bias, None, grad_rot_matrix, None, None, None, None, None 236 | 237 | 238 | class VolumeExtraction(torch.autograd.Function): 239 | @staticmethod 240 | def forward(ctx, input, weight, bias, grid3d_index, input_spectral_weight=None, max_r=None): 241 | 242 | assert grid3d_index.shape[0] == grid3d_index.shape[1] == grid3d_index.shape[2] * 2 - 1 243 | if max_r is None: 244 | max_r = (grid3d_index.shape[0] - 2 * VOXEL_SPREAD_MARGIN) // 2 245 | else: 246 | max_r = min(max_r, (grid3d_index.shape[0] - 2 * VOXEL_SPREAD_MARGIN) // 2) 247 | 248 | input_spectral_weight = torch.empty(0, dtype=input.dtype).to(input.device) \ 249 | if input_spectral_weight is None else input_spectral_weight 250 | 251 | output = voxelium_svr_linear.volume_extraction_forward( 252 | input=input, 253 | weight=weight, 254 | bias=torch.empty(0, dtype=weight.dtype).to(weight.device) if bias is None else bias, 255 | grid3d_index=grid3d_index, 256 | max_r=max_r 257 | ) 258 | ctx.save_for_backward(input, weight, bias, input_spectral_weight, grid3d_index) 259 | 260 | return output 261 | 262 | @staticmethod 263 | def backward(ctx, grad_output): 264 | input, weight, bias, input_spectral_weight, grid3d_index = ctx.saved_tensors 265 | 266 | input_spectral_weight = torch.empty(0, dtype=weight.dtype).to(weight.device) \ 267 | if input_spectral_weight is None else input_spectral_weight 268 | 269 | grad_input, grad_weight, grad_bias = \ 270 | voxelium_svr_linear.volume_extraction_backward( 271 | input=input, 272 | weight=weight, 273 | bias=torch.empty(0, dtype=weight.dtype).to(weight.device) if bias is None else bias, 274 | grad_output=grad_output, 275 | grid3d_index=grid3d_index, 276 | input_spectral_weight=input_spectral_weight 277 | ) 278 | 279 | return grad_input, grad_weight, grad_bias, None, None, None 280 | -------------------------------------------------------------------------------- /voxelium/vae_volume/svr_linear/volume_extraction_cuda_kernels.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | #include "vae_volume/svr_linear/base_cuda.cuh" 12 | #include "vae_volume/svr_linear/volume_extraction_cuda_kernels.h" 13 | 14 | 15 | template 16 | __global__ void volume_extraction_forward_cuda_kernel( 17 | const torch::PackedTensorAccessor64 grid3d_index, 18 | const torch::PackedTensorAccessor64 weight, 19 | const torch::PackedTensorAccessor64 bias, 20 | const torch::PackedTensorAccessor32 input, 21 | torch::PackedTensorAccessor64 output, 22 | const int offset, 23 | const int max_r2 24 | ) 25 | { 26 | size_t b, z, y, x; 27 | if (thread_index_expand(output.size(0), output.size(1), output.size(2), output.size(3), b, z, y, x)) 28 | { 29 | const long zp = z - (output.size(1) - 1) / 2; 30 | const long yp = y - (output.size(2) - 1) / 2; 31 | const long xp = x; 32 | 33 | if (xp*xp + yp*yp + zp*zp <= max_r2) 34 | { 35 | const long i = grid3d_index[z+offset][y+offset][x]; 36 | 37 | for (int c = 0; c < 2; c++) // Over real and imaginary 38 | { 39 | accscaler_t v(0); 40 | for (int j = 0; j < input.size(1); j ++) 41 | v += weight[i][j][c] * input[b][j]; 42 | if (do_bias) 43 | v += bias[i][c]; 44 | output[b][z][y][x][c] = (scalar_t) v; 45 | } 46 | } 47 | } 48 | } 49 | 50 | void volume_extraction_forward_cuda( 51 | const torch::Tensor grid3d_index, 52 | const torch::Tensor weight, 53 | const torch::Tensor bias, 54 | const torch::Tensor input, 55 | torch::Tensor output, 56 | const int max_r2, 57 | const bool do_bias 58 | ) 59 | { 60 | CHECK_CUDA_INPUT(grid3d_index) 61 | CHECK_CUDA_INPUT(weight) 62 | CHECK_CUDA_INPUT(bias) 63 | CHECK_CUDA_INPUT(input) 64 | CHECK_CUDA_INPUT(output) 65 | 66 | const int offset = (grid3d_index.size(0) - output.size(1)) / 2; 67 | 68 | const int deviceId = input.device().index(); 69 | const cudaStream_t stream = at::cuda::getCurrentCUDAStream(deviceId); 70 | CUDA_ERRCHK(cudaSetDevice(deviceId)); 71 | 72 | const dim3 threads(512); 73 | const dim3 blocks( 74 | ( 75 | output.size(0) * 76 | output.size(1) * 77 | output.size(2) * 78 | output.size(3) + 79 | threads.x - 1 80 | ) / threads.x 81 | ); 82 | 83 | std::array bargs={{do_bias}}; 84 | dispatch_bools<1>{}( 85 | bargs, 86 | [&](auto...Bargs) { 87 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 88 | input.scalar_type(), 89 | "volume_extraction_forward_cuda_kernel", 90 | [&] { 91 | using accscalar_t = at::acc_type; 92 | volume_extraction_forward_cuda_kernel 93 | 94 | <<>>( 95 | /*grid3d_index*/ grid3d_index.packed_accessor64(), 96 | /*weight*/ weight.packed_accessor64(), 97 | /*bias*/ bias.packed_accessor64(), 98 | /*input*/ input.packed_accessor32(), 99 | /*output*/ output.packed_accessor64(), 100 | /*offset*/ offset, 101 | /*max_r2*/ max_r2 102 | ); 103 | } 104 | ); 105 | } 106 | ); 107 | 108 | #ifdef DEBUG 109 | CUDA_ERRCHK(cudaPeekAtLastError()); 110 | CUDA_ERRCHK(cudaDeviceSynchronize()); 111 | #endif 112 | } 113 | 114 | template 115 | __global__ void volume_extraction_backward_cuda_kernel( 116 | const torch::PackedTensorAccessor64 grid3d_index, 117 | const torch::PackedTensorAccessor64 weight, 118 | const torch::PackedTensorAccessor64 bias, 119 | const torch::PackedTensorAccessor32 input_spectral_weight, 120 | const torch::PackedTensorAccessor32 input, 121 | const torch::PackedTensorAccessor64 grad_output, 122 | torch::PackedTensorAccessor64 grad_weight, 123 | torch::PackedTensorAccessor64 grad_bias, 124 | torch::PackedTensorAccessor32 grad_input, 125 | const int offset, 126 | const int max_r2, 127 | const size_t grad_weight_numel, 128 | const size_t grad_bias_numel, 129 | const size_t grad_input_numel 130 | ) 131 | { 132 | size_t b, z, y, x; 133 | if ( 134 | thread_index_expand( 135 | grad_output.size(0), 136 | grad_output.size(1), 137 | grad_output.size(2), 138 | grad_output.size(3), 139 | b, z, y, x 140 | ) 141 | ) 142 | { 143 | const long zp = z - (grad_output.size(1) - 1) / 2; 144 | const long yp = y - (grad_output.size(2) - 1) / 2; 145 | const long xp = x; 146 | 147 | scalar_t r = xp*xp + yp*yp + zp*zp; 148 | if (r <= max_r2) 149 | { 150 | if (do_spectral_weighting) 151 | r = input_spectral_weight[(int) sqrt((float) r)]; 152 | 153 | const long i = grid3d_index[z+offset][y+offset][x]; 154 | for (int c = 0; c < 2; c++) // Over real and imaginary 155 | { 156 | for (int j = 0; j < input.size(1); j ++) 157 | { 158 | if (do_input_grad) 159 | at::native::fastAtomicAdd( 160 | grad_input.data(), 161 | accessor_index_collapse(grad_input, b, j), 162 | grad_input_numel, 163 | do_spectral_weighting ? 164 | grad_output[b][z][y][x][c] * weight[i][j][c] * r : 165 | grad_output[b][z][y][x][c] * weight[i][j][c], 166 | true 167 | ); 168 | 169 | at::native::fastAtomicAdd( 170 | grad_weight.data(), 171 | accessor_index_collapse(grad_weight, i, j, c), 172 | grad_weight_numel, 173 | grad_output[b][z][y][x][c] * input[b][j], 174 | true 175 | ); 176 | } 177 | 178 | if (do_bias) 179 | at::native::fastAtomicAdd( 180 | grad_bias.data(), 181 | accessor_index_collapse(grad_bias, i, c), 182 | grad_bias_numel, 183 | grad_output[b][z][y][x][c], 184 | true 185 | ); 186 | } 187 | } 188 | } 189 | } 190 | 191 | 192 | void volume_extraction_backward_cuda( 193 | const torch::Tensor grid3d_index, 194 | const torch::Tensor weight, 195 | const torch::Tensor bias, 196 | const torch::Tensor input_spectral_weight, 197 | const torch::Tensor input, 198 | const torch::Tensor grad_output, 199 | torch::Tensor grad_weight, 200 | torch::Tensor grad_bias, 201 | torch::Tensor grad_input, 202 | const int max_r2, 203 | const bool do_bias, 204 | const bool do_input_grad, 205 | const bool do_spectral_weighting 206 | ) 207 | { 208 | CHECK_CUDA_INPUT(grid3d_index) 209 | CHECK_CUDA_INPUT(weight) 210 | CHECK_CUDA_INPUT(bias) 211 | CHECK_CUDA_INPUT(input_spectral_weight) 212 | CHECK_CUDA_INPUT(input) 213 | CHECK_CUDA_INPUT(grad_output) 214 | CHECK_CUDA_INPUT(grad_weight) 215 | CHECK_CUDA_INPUT(grad_bias) 216 | CHECK_CUDA_INPUT(grad_input) 217 | 218 | const int offset = (grid3d_index.size(0) - grad_output.size(1)) / 2; 219 | 220 | const int deviceId = input.device().index(); 221 | const cudaStream_t stream = at::cuda::getCurrentCUDAStream(deviceId); 222 | CUDA_ERRCHK(cudaSetDevice(deviceId)); 223 | 224 | const dim3 threads(512); 225 | const dim3 blocks( 226 | ( 227 | grad_output.size(0) * 228 | grad_output.size(1) * 229 | grad_output.size(2) * 230 | grad_output.size(3) + 231 | threads.x - 1 232 | ) / threads.x 233 | ); 234 | 235 | std::array bargs={{do_bias, do_input_grad, do_spectral_weighting}}; 236 | dispatch_bools<3>{}( 237 | bargs, 238 | [&](auto...Bargs) { 239 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 240 | input.scalar_type(), 241 | "volume_extraction_backward_cuda_kernel", 242 | [&] { 243 | volume_extraction_backward_cuda_kernel 244 | <<>>( 245 | /*grid3d_index*/ grid3d_index.packed_accessor64(), 246 | /*weight*/ weight.packed_accessor64(), 247 | /*bias*/ bias.packed_accessor64(), 248 | /*input_spectral_weight*/ input_spectral_weight.packed_accessor32 249 | (), 250 | /*input*/ input.packed_accessor32(), 251 | /*grad*/ grad_output.packed_accessor64(), 252 | /*grad_weight*/ grad_weight.packed_accessor64(), 253 | /*grad_bias*/ grad_bias.packed_accessor64(), 254 | /*grad_input*/ grad_input.packed_accessor32(), 255 | /*offset*/ offset, 256 | /*max_r2*/ max_r2, 257 | /*grad_weight_numel*/ grad_weight.numel(), 258 | /*grad_bias_numel*/ grad_bias.numel(), 259 | /*grad_input_numel*/ grad_input.numel() 260 | ); 261 | } 262 | ); 263 | } 264 | ); 265 | 266 | #ifdef DEBUG 267 | CUDA_ERRCHK(cudaPeekAtLastError()); 268 | CUDA_ERRCHK(cudaDeviceSynchronize()); 269 | #endif 270 | } 271 | -------------------------------------------------------------------------------- /voxelium/relion/dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """ 4 | Module for loading RELION particle datasets 5 | """ 6 | 7 | import os 8 | import warnings 9 | from glob import glob 10 | import numpy as np 11 | import mrcfile 12 | 13 | from typing import List 14 | 15 | from voxelium.base.ctf import ContrastTransferFunction 16 | from voxelium.base.particle_dataset import ParticleDataset 17 | from voxelium.base.star_file import load_star 18 | 19 | 20 | class RelionDataset: 21 | def __init__(self, path: str = None, dtype: np.dtype = np.float32, random_subset: str = None): 22 | self.dtype = dtype 23 | self.random_subset = random_subset 24 | self.project_root = None 25 | self.data_star_path = None 26 | self.preload = None 27 | self.image_file_paths = [] 28 | 29 | # In data star file 30 | self.part_random_subset = [] 31 | self.part_rotation = [] 32 | self.part_translation = [] 33 | self.part_defocus = [] 34 | self.part_og_idx = [] 35 | self.part_stack_idx = [] 36 | self.part_image_file_path_idx = [] 37 | self.part_norm_correction = [] 38 | self.part_noise_group_id = [] 39 | self.nr_particles = None 40 | 41 | # In data star file 42 | self.optics_groups = [] 43 | self.optics_groups_ids = [] 44 | 45 | if path is not None: 46 | self.load(path) 47 | 48 | def load(self, path: str) -> None: 49 | """ 50 | Load data from path 51 | :param path: relion job directory or data file 52 | """ 53 | if os.path.isfile(path): 54 | data_star_path = path 55 | root_search_path = os.path.dirname(os.path.abspath(path)) 56 | else: 57 | data_star_path = os.path.abspath(self._find_star_file_in_path(path, "data")) 58 | root_search_path = os.path.abspath(path) 59 | 60 | self.data_star_path = os.path.abspath(data_star_path) 61 | data = load_star(self.data_star_path) 62 | 63 | if 'optics' not in data: 64 | raise RuntimeError("Optics groups table not found in data star file") 65 | if 'particles' not in data: 66 | raise RuntimeError("Particles table not found in data star file") 67 | 68 | self._load_optics_group(data['optics']) 69 | self._load_particles(data['particles']) 70 | 71 | self.project_root = self._find_project_root(root_search_path, self.image_file_paths[0]) 72 | 73 | # Convert image paths to absolute paths 74 | for i in range(len(self.image_file_paths)): 75 | self.image_file_paths[i] = os.path.abspath(os.path.join(self.project_root, self.image_file_paths[i])) 76 | 77 | # TODO check cross reference integrity, e.g. all part_noise_group_id exist in noise_group_id 78 | 79 | def make_particle_dataset(self): 80 | dataset = ParticleDataset() 81 | dataset.initialize( 82 | image_file_paths = self.image_file_paths, 83 | part_random_subset = self.part_random_subset, 84 | part_rotation = self.part_rotation, 85 | part_translation = self.part_translation, 86 | part_defocus = self.part_defocus, 87 | part_og_idx = self.part_og_idx, 88 | part_stack_idx = self.part_stack_idx, 89 | part_image_file_path_idx = self.part_image_file_path_idx, 90 | part_norm_correction = self.part_norm_correction, 91 | part_noise_group_id = self.part_noise_group_id, 92 | optics_group_stats = self.optics_groups, 93 | dtype = self.dtype, 94 | ) 95 | return dataset 96 | 97 | def _load_optics_group(self, optics: dict) -> None: 98 | if 'rlnOpticsGroup' not in optics: 99 | raise RuntimeError( 100 | "Optics group id (rlnOpticsGroup) is required, " 101 | "but was not found in optics group table." 102 | ) 103 | 104 | if 'rlnImageSize' not in optics: 105 | raise RuntimeError( 106 | "Image size (rlnImageSize) is required, " 107 | "but was not found in optics group table." 108 | ) 109 | 110 | if 'rlnImagePixelSize' not in optics: 111 | raise RuntimeError( 112 | "Image pixel size (rlnImagePixelSize) is required, " 113 | "but was not found in optics group table." 114 | ) 115 | 116 | nr_optics = len(optics['rlnOpticsGroup']) 117 | 118 | for i in range(nr_optics): 119 | id = int(optics['rlnOpticsGroup'][i]) 120 | image_size = int(optics['rlnImageSize'][i]) 121 | pixel_size = float(optics['rlnImagePixelSize'][i]) 122 | 123 | if image_size <= 0 or image_size % 2 != 0: 124 | raise RuntimeError( 125 | f"Invalid value ({image_size}) for image size of optics group {id}.\n" 126 | f"Image size must be even and larger than 0." 127 | ) 128 | if pixel_size <= 0: 129 | raise RuntimeError( 130 | f"Invalid value ({pixel_size}) for pixel size of optics group {id}." 131 | ) 132 | 133 | voltage = float(optics['rlnVoltage'][i]) \ 134 | if 'rlnVoltage' in optics else None 135 | spherical_aberration = float(optics['rlnSphericalAberration'][i]) \ 136 | if 'rlnSphericalAberration' in optics else None 137 | amplitude_contrast = float(optics['rlnAmplitudeContrast'][i]) \ 138 | if 'rlnAmplitudeContrast' in optics else None 139 | 140 | self.optics_groups_ids.append(id) 141 | self.optics_groups.append({ 142 | "id": id, 143 | "image_size": image_size, 144 | "pixel_size": pixel_size, 145 | "voltage": voltage, 146 | "spherical_aberration": spherical_aberration, 147 | "amplitude_contrast": amplitude_contrast 148 | }) 149 | 150 | def _load_particles(self, particles: dict) -> None: 151 | if 'rlnImageName' not in particles: 152 | raise RuntimeError( 153 | "Image name (rlnImageName) is required, " 154 | "but was not found in particles table." 155 | ) 156 | 157 | if 'rlnOpticsGroup' not in particles: 158 | raise RuntimeError( 159 | "Optics group id (rlnOpticsGroup) is required, " 160 | "but was not found in particles table." 161 | ) 162 | 163 | nr_particles = len(particles['rlnImageName']) 164 | 165 | for i in range(nr_particles): 166 | if self.random_subset is not None: 167 | if particles["rlnRandomSubset"][i] != self.random_subset: 168 | continue 169 | 170 | # Optics group --------------------------------------- 171 | og_id = int(particles['rlnOpticsGroup'][i]) 172 | og_idx = self.optics_groups_ids.index(og_id) 173 | self.part_og_idx.append(og_idx) 174 | og = self.optics_groups[og_idx] 175 | 176 | # Norm correction ------------------------------------- 177 | if 'rlnNormCorrection' in particles: 178 | nc = float(particles['rlnNormCorrection'][i]) 179 | self.part_norm_correction.append(nc) 180 | else: 181 | self.part_norm_correction.append(1.) 182 | 183 | # Noise group ----------------------------------------- 184 | if 'rlnGroupNumber' in particles: 185 | ng = int(particles['rlnGroupNumber'][i]) 186 | self.part_noise_group_id.append(ng) 187 | else: 188 | self.part_noise_group_id.append(None) 189 | 190 | # CTF parameters ------------------------------------- 191 | if 'rlnDefocusU' in particles and \ 192 | 'rlnDefocusV' in particles and \ 193 | 'rlnDefocusAngle' in particles: 194 | ctf_u = float(particles['rlnDefocusU'][i]) 195 | ctf_v = float(particles['rlnDefocusV'][i]) 196 | ctf_a = float(particles['rlnDefocusAngle'][i]) 197 | self.part_defocus.append([ctf_u, ctf_v, ctf_a]) 198 | else: 199 | self.part_defocus.append(None) 200 | 201 | # Rotation parameters -------------------------------- 202 | if 'rlnAngleRot' in particles and \ 203 | 'rlnAngleTilt' in particles and \ 204 | 'rlnAnglePsi' in particles: 205 | a = np.array([ 206 | float(particles['rlnAngleRot'][i]), 207 | float(particles['rlnAngleTilt'][i]), 208 | float(particles['rlnAnglePsi'][i]) 209 | ]) 210 | a *= np.pi / 180. 211 | self.part_rotation.append(a) 212 | elif 'rlnAnglePsi' in particles: 213 | a = np.array([0., 0., float(particles['rlnAnglePsi'][i])]) 214 | a *= np.pi / 180. 215 | self.part_rotation.append(a) 216 | else: 217 | self.part_rotation.append(np.zeros([3])) 218 | 219 | # Translation parameters ------------------------------ 220 | if 'rlnOriginXAngst' in particles and 'rlnOriginYAngst' in particles: 221 | trans_x = float(particles['rlnOriginXAngst'][i]) / og['pixel_size'] 222 | trans_y = float(particles['rlnOriginYAngst'][i]) / og['pixel_size'] 223 | else: 224 | trans_x = 0. 225 | trans_y = 0. 226 | self.part_translation.append([trans_x, trans_y]) 227 | 228 | # Image data ------------------------------------------ 229 | img_name = particles['rlnImageName'][i] 230 | img_tokens = img_name.split("@") 231 | if len(img_tokens) == 2: 232 | image_stack_id = int(img_tokens[0]) - 1 233 | img_path = img_tokens[1] 234 | elif len(img_tokens) == 1: 235 | image_stack_id = 0 236 | img_path = img_tokens[1] 237 | else: 238 | raise RuntimeError(f"Invalid image file name (rlnImageName): {img_name}") 239 | 240 | self.part_random_subset.append(particles["rlnRandomSubset"][i]) 241 | self.part_stack_idx.append(image_stack_id) 242 | 243 | try: # Assume image file path has been added to list 244 | img_path_idx = self.image_file_paths.index(img_path) 245 | self.part_image_file_path_idx.append(img_path_idx) 246 | except ValueError: # If image file path not found in existing list 247 | img_path_idx = len(self.image_file_paths) 248 | self.part_image_file_path_idx.append(img_path_idx) 249 | self.image_file_paths.append(img_path) 250 | 251 | self.part_og_idx = np.array(self.part_og_idx) 252 | self.part_defocus = np.array(self.part_defocus, dtype=np.float32) 253 | self.part_rotation = np.array(self.part_rotation, dtype=np.float32) 254 | self.part_translation = np.array(self.part_translation, dtype=np.float32) 255 | self.part_noise_group_id = np.array(self.part_noise_group_id) 256 | self.part_stack_idx = np.array(self.part_stack_idx) 257 | self.part_image_file_path_idx = np.array(self.part_image_file_path_idx) 258 | self.nr_particles = len(self.part_image_file_path_idx) 259 | 260 | @staticmethod 261 | def _find_star_file_in_path(path: str, type: str = "optimiser") -> str: 262 | if os.path.isfile(os.path.join(path, f"run_{type}.star")): 263 | return os.path.join(path, f"run_{type}.star") 264 | files = glob(os.path.join(path, f"*{type}.star")) 265 | if len(files) > 0: 266 | files = list.sort(files) 267 | return files[-1] 268 | 269 | raise FileNotFoundError(f"Could not find '{type}' star-file in path: {path}") 270 | 271 | @staticmethod 272 | def _find_project_root(from_path: str, file_relative_path: str) -> str: 273 | """ 274 | Searches for the Relion project root starting at from_path and iterate through parent directories 275 | till file_relative_path is found as a relative sub path or till filesystem root is found, at which 276 | point a RuntimeException is raise. 277 | 278 | :param from_path: starting search from this path 279 | :param file_relative_path: searching to find this relative path as a file 280 | """ 281 | current_path = os.path.abspath(from_path) 282 | while True: 283 | trial_path = os.path.join(current_path, file_relative_path) 284 | if os.path.isfile(trial_path): 285 | return current_path 286 | if current_path == os.path.dirname(current_path): # At filesystem root 287 | raise RuntimeError(f"Relion project directory could not be found from the subdirectory: {from_path}") 288 | current_path = os.path.dirname(current_path) 289 | --------------------------------------------------------------------------------