├── utils
│   ├── __init__.py
│   ├── wavelet_utils.py
│   └── utils.py
├── .gitignore
├── requirements.txt
├── download.sh
├── download_old.sh
├── download_ribosome.sh
├── ribo80.yaml
├── ribo80_wavelet.yaml
├── filter_models
│   ├── fir.py
│   ├── __init__.py
│   ├── polynomial.py
│   ├── cnn.py
│   └── vector.py
├── ribo80_list.yaml
├── models
│   ├── __init__.py
│   ├── model_wrapper.py
│   ├── standardmlp.py
│   ├── adjoint.py
│   ├── sliceSet.py
│   └── slicemlp.py
├── LICENSE
├── super.py
├── super-list.py
├── README.md
└── evaluator.py

/utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | /data/
2 | /results/
3 | __pycache__/
4 | *.pyc
5 | *.pyo
6 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | pyyaml==6.0.1
4 | ml-collections==0.1.1
5 | tqdm==4.66.3
6 | mrcfile==1.4.3
7 | scikit-image==0.21.0
8 | ptwt==0.1.9
9 | numpy==1.26.4
--------------------------------------------------------------------------------
/download.sh:
--------------------------------------------------------------------------------
1 | wget https://drive.switch.ch/index.php/s/rVMAxUubCj3C6C7/download -O trained_models.zip
2 | unzip trained_models.zip -d trained_models
3 | rm trained_models.zip
4 | 
--------------------------------------------------------------------------------
/download_old.sh:
--------------------------------------------------------------------------------
1 | wget https://drive.switch.ch/index.php/s/bLW0W3gqnwbV5qJ/download -O trained_models.zip
2 | unzip trained_models.zip -d trained_models
3 | rm trained_models.zip
4 | 
--------------------------------------------------------------------------------
/download_ribosome.sh:
--------------------------------------------------------------------------------
1 | 
2 | # Make the data directory if it doesn't exist
3 | mkdir -p ./data/
4 | # Download the ribosome data
5 | wget https://drive.switch.ch/index.php/s/fWr8ZcFQiiU0OJv/download -O ribo_data.zip
6 | unzip ribo_data.zip -d ./data/
7 | # Remove the zip file
8 | rm ribo_data.zip
--------------------------------------------------------------------------------
/ribo80.yaml:
--------------------------------------------------------------------------------
1 | # Directory containing the trained model
2 | model_dir: "./trained_models/cryolithe_pixel/"
3 | 
4 | # Path to the projection file
5 | proj_file: "./data/ribosome/projections.mrc"
6 | # Path to the angle file in degrees
7 | angle_file: "./data/ribosome/angles.tlt"
8 | 
9 | # Save location
10 | save_dir: "./results/ribo_slice_set/"
11 | save_name: "./vol_ribo_single_gpu.mrc"
12 | 
13 | # GPU device to be used
14 | device: 0 # Use [0,1] to use gpus with id 0 and 1.
15 | multi_gpu: False
16 | # Pre-process the projections to the desired resolution
17 | downsample_projections: False
18 | downsample_factor: 0.25
19 | anti_alias: True
20 | # Volume size along the z axis (constructs center-N3//2 to center+N3//2)
21 | N3: 200
22 | 
23 | # Model Memory Parameters
24 | batch_size: 100_000
25 | 
26 | # CPU threads
27 | num_workers: 4
28 | 
--------------------------------------------------------------------------------
/ribo80_wavelet.yaml:
--------------------------------------------------------------------------------
1 | # Directory containing the trained model
2 | model_dir: "./trained_models/cryolithe_21/"
3 | 
4 | # Path to the projection file
5 | proj_file: "./data/ribosome/projections.mrc"
6 | # Path to the angle file in degrees
7 | angle_file: "./data/ribosome/angles.tlt"
8 | 
9 | # Save location
10 | save_dir: "./results/ribo_slice_set/"
11 | save_name: "./vol_ribo_single_gpu_wavelet.mrc"
12 | 
13 | # GPU device to be used
14 | device: 0 # Use [0,1] for multiple gpus
15 | multi_gpu: False
16 | # Pre-process the projections to the desired resolution
17 | downsample_projections: False
18 | downsample_factor: 0.25
19 | anti_alias: True
20 | # Volume size along the z axis (constructs center-N3//2 to center+N3//2)
21 | N3: 200
22 | 
23 | # Model Memory Parameters
24 | batch_size: 100_000
25 | 
26 | # CPU threads
27 | num_workers: 4
28 | 
--------------------------------------------------------------------------------
/filter_models/fir.py:
--------------------------------------------------------------------------------
1 | """
2 | This script contains the FIR filter to be applied on the projections
3 | """
4 | 
5 | import torch
6 | import torch.nn as nn
7 | from skimage.transform.radon_transform import _get_fourier_filter
8 | 
9 | class FIRModel(nn.Module):
10 |     def __init__(self, init: str, size: int):
11 |         super(FIRModel, self).__init__()
12 |         self.size = size
13 |         self.init = init
14 | 
15 |         if self.init == 'ones':
16 |             self.fir = nn.Parameter(torch.ones(size,size))
17 |         elif self.init == 'impulse':
18 |             fir = torch.zeros(size,size,dtype=torch.float32)
19 |             fir[size//2,size//2] = 1
20 |             self.fir = nn.Parameter(fir)
21 |         else:
22 |             self.fir = nn.Parameter(torch.randn(size,size))
23 | 
24 |     def forward(self, x: int):
25 |         return self.fir
26 | 
--------------------------------------------------------------------------------
/filter_models/__init__.py:
--------------------------------------------------------------------------------
1 | from .vector import VectorModel, VectorModel_symmetric,VectorModel_real
2 | from .polynomial import PolynomialModel
3 | from .fir import FIRModel
4 | from .cnn import CNN
5 | 
6 | 
7 | 
8 | def get_filter_model(type: str, **kwargs):
9 |     """
10 |     Returns the model with the given name
11 |     Note: For unet the n_projections is not used
12 |     """
13 |     print(type)
14 |     if type == 'vector':
15 |         return VectorModel(**kwargs)
16 |     elif type == 'vector_symmetric':
17 |         return VectorModel_symmetric(**kwargs)
18 |     elif type == 'VectorModel_real':
19 |         return VectorModel_real(**kwargs)
20 |     elif type == 'polynomial':
21 |         return PolynomialModel(**kwargs)
22 |     elif type == 'fir':
23 |         return FIRModel(**kwargs)
24 |     elif type == 'CNN':
25 |         return CNN(**kwargs)
26 |     else:
27 |         raise NotImplementedError(f"Model {type} not implemented")
--------------------------------------------------------------------------------
/filter_models/polynomial.py:
--------------------------------------------------------------------------------
1 | """
2 | Polynomial ramp model for the filters
3 | """
4 | 
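# PolynomialModel below parameterizes the frequency response of the reconstruction filter
# as a learnable polynomial in the normalized frequency x in [0, 1]; the half-response is
# mirrored to make it symmetric and normalized so its peak is 1, so calling the model with
# a resolution `res` returns a response of length 2*res. A minimal usage sketch
# (illustrative values only):
#   model = PolynomialModel(degree=3)
#   response = model(256)  # tensor of shape (512,), maximum value 1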
5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | class PolynomialModel(nn.Module): 10 | def __init__(self, degree: int): 11 | super(PolynomialModel, self).__init__() 12 | self.degree = degree 13 | 14 | self.poly_parameters = torch.nn.Parameter(torch.rand(degree + 1)) 15 | 16 | def forward(self, res: int): 17 | """ 18 | Resolution of the polynomial model 19 | """ 20 | x = torch.linspace(0, 1, res).to(self.poly_parameters.device) 21 | # Flip the x value to get a symmetric polynomial 22 | 23 | ramp = x*0 + self.poly_parameters[0] 24 | 25 | for i in range(1, self.degree + 1): 26 | ramp += self.poly_parameters[i] * x**i 27 | 28 | # Flip and concatenate the ramp 29 | ramp = torch.cat([ramp, ramp.flip(0)]) 30 | 31 | ramp = ramp / ramp.max() 32 | 33 | return ramp -------------------------------------------------------------------------------- /ribo80_list.yaml: -------------------------------------------------------------------------------- 1 | # Directory containing the traned model 2 | model_dir: "./trained_models/cryolithe_21/" 3 | 4 | # Path to the list of projection files we use the same dataset here for example 5 | proj_file: 6 | - "./data/ribosome/projections.mrc" 7 | - "./data/ribosome/projections.mrc" 8 | 9 | # Path to the angle file in degrees 10 | angle_file: 11 | - "./data/ribosome/angles.tlt" 12 | - "./data/ribosome/angles.tlt" 13 | 14 | # Save location 15 | save_dir: "./results/ribo_slice_set_wavelet_list/" 16 | # List of save names for each projection 17 | save_name: 18 | - "./vol_ribo_single_gpu.mrc" 19 | - "./vol_ribo2_single_gpu.mrc" 20 | 21 | # GPU device to be used 22 | device: 0 # Use [0,1] to use gpus with id 0 and 1. 23 | multi_gpu: False 24 | # Pre prcessing Projection to desired resolution 25 | downsample_projections: False 26 | downsample_factor: 0.25 27 | anti_alias: True 28 | # Volume size along the z axis (constructs center-N3//2 to center+N3//2) 29 | N3: 30 | - 200 31 | - 200 32 | 33 | #Model Memory Parameters 34 | batch_size: 100_000 35 | 36 | # CPU threads 37 | num_workers: 4 38 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .slicemlp import sliceMLP, sliceMlp_v2,sliceMlp_mulitnet 3 | from .standardmlp import standardMLP 4 | from .adjoint import adjoint,adjoint_patch 5 | from .model_wrapper import model_wrapper 6 | from .sliceSet import SliceSet 7 | 8 | def get_model(n_projections: int, type: str, **kwargs): 9 | """ 10 | Returns the model with the given name 11 | Note: For unet the n_projections is not used 12 | """ 13 | print(type) 14 | if type == "slicemlp": 15 | return sliceMLP(n_projections = n_projections,**kwargs) 16 | elif type == "slicemlp_v2": 17 | return sliceMlp_v2(n_projections = n_projections,**kwargs) 18 | elif type == "sliceMlp_mulitnet": 19 | return sliceMlp_mulitnet(n_projections = n_projections,**kwargs) 20 | elif type == "adjoint": 21 | return adjoint() 22 | elif type == "SliceSet": 23 | return SliceSet(n_projections = n_projections,**kwargs) 24 | elif type == "adjoint_patch": 25 | return adjoint_patch(**kwargs) 26 | else: 27 | raise NotImplementedError(f"Model {type} not implemented") 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 swing-research 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining 
a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /utils/wavelet_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import ptwt 3 | 4 | 5 | def wavelet_decomposition(volume, wavelet): 6 | # initial decomposition 7 | 8 | if len(volume.shape) == 3: 9 | volume = volume.unsqueeze(0) 10 | 11 | volume_wt = ptwt.wavedec3(volume, wavelet=wavelet, level=1) 12 | 13 | # Convert to batch mode 14 | 15 | volume_wt_set = [volume_wt[0]] 16 | 17 | 18 | 19 | keys = list(volume_wt[1].keys()) 20 | 21 | for key in keys: 22 | volume_wt_set.append(volume_wt[1][key]) 23 | 24 | #convert to batch mode 25 | volume_wt_st = torch.concatenate(volume_wt_set, dim=0) 26 | return volume_wt_st 27 | 28 | 29 | def wavelet_multilevel_decomposition(volume, wavelet ,levels): 30 | 31 | for i in range(levels): 32 | volume = wavelet_decomposition(volume, wavelet) 33 | return volume 34 | 35 | 36 | def wavelet_reconstruction(volume_wt, wavelet): 37 | # Reconstruct a batch of wavelet coefficients 38 | # Bx8xn1xn2xn3 39 | 40 | keys = ['aad', 'ada', 'add', 'daa', 'dad', 'dda', 'ddd'] 41 | 42 | vol_wt_set = [volume_wt[:,0]] 43 | 44 | 45 | vol_wt_dict = {} 46 | for i, key in enumerate(keys): 47 | vol_wt_dict[key] = volume_wt[:,i+1] 48 | 49 | vol_wt_set.append(vol_wt_dict) 50 | 51 | vol_reconstructed = ptwt.waverec3(vol_wt_set, wavelet=wavelet) 52 | 53 | return vol_reconstructed 54 | 55 | 56 | def wavelet_multilevel_reconstruction(volume_wt, wavelet): 57 | 58 | # Reconstruct the volume from the wavelet coefficients 59 | #BXN1XN2XN3 60 | 61 | B,N1,N2,N3 = volume_wt.shape 62 | while B>1: 63 | volume_wt = volume_wt.view(B//8,8,N1,N2,N3) 64 | if volume_wt.shape[0] >1: 65 | volume_wt = volume_wt.permute(1,0,2,3,4) 66 | volume_wt = wavelet_reconstruction(volume_wt, wavelet) 67 | B,N1,N2,N3 = volume_wt.shape 68 | return volume_wt.squeeze(0) -------------------------------------------------------------------------------- /models/model_wrapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from utils.utils import generate_patches_from_volume_location 5 | 6 | 7 | 8 | 9 | class model_wrapper(nn.Module): 10 | def __init__(self, model, 11 | projections, 12 | angles, 13 | volume_dummy, 14 | patch_scale, 15 | scale, 16 | configs): 17 | super(model_wrapper, self).__init__() 18 | 19 | self.model = model 20 | 21 | self.scale = scale 22 
| self.configs = configs 23 | 24 | self.register_buffer("projections", projections) 25 | self.register_buffer("angles", angles) 26 | self.register_buffer("volume_dummy", volume_dummy) 27 | self.register_buffer("patch_scale", patch_scale) 28 | 29 | 30 | def forward(self, points): 31 | """ 32 | Wrapper for the model so that it can be used with the DataParallel 33 | points: Bx3 34 | """ 35 | 36 | 37 | _, projection_patches = generate_patches_from_volume_location(points, self.volume_dummy , 38 | self.projections, 39 | self.angles, 40 | patch_size = self.configs.model.patch_size, 41 | scale=self.scale, 42 | patch_scale= self.patch_scale) 43 | 44 | if self.configs.training.use_angle: 45 | angle_info = self.angles.unsqueeze(0).repeat(projection_patches.shape[0],1) 46 | vol_est = self.model(projection_patches.half(),angle_info.half()) 47 | else: 48 | vol_est = self.model(projection_patches.half()) 49 | return vol_est -------------------------------------------------------------------------------- /models/standardmlp.py: -------------------------------------------------------------------------------- 1 | """ 2 | Standard mlp network 3 | """ 4 | 5 | from typing import Union, List 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | 11 | class standardMLP(nn.Module): 12 | """ 13 | Standard MLP network Note: learn_residual is a dummy variable 14 | TODO: Correct the learn_residual variable 15 | """ 16 | def __init__(self, input_size : int, 17 | output_size : int, 18 | mlp_hidden : Union[int , List[int]], 19 | mlp_layers : int, 20 | batch_norm : bool = False, 21 | dropout=0.0, 22 | learn_residual=False, 23 | skip_connection=False, 24 | bias = True): 25 | super(standardMLP, self).__init__() 26 | 27 | self.input_size = input_size 28 | self.mlp_hidden = mlp_hidden 29 | self.mlp_layers = mlp_layers 30 | self.batch_norm = batch_norm 31 | self.dropout = dropout 32 | self.learn_residual = learn_residual 33 | self.skip_connection = skip_connection 34 | self.non_linearity = nn.ReLU() 35 | 36 | self.layers = nn.ModuleList() 37 | if self.batch_norm: 38 | self.batch_norms = nn.ModuleList() 39 | if self.dropout > 0: 40 | self.dropouts = nn.ModuleList() 41 | if isinstance(mlp_hidden, int): 42 | for i in range(self.mlp_layers): 43 | if i == 0: 44 | self.layers.append(nn.Linear(self.input_size, self.mlp_hidden, bias=bias)) 45 | else: 46 | self.layers.append(nn.Linear(self.mlp_hidden, self.mlp_hidden, bias=bias)) 47 | if self.batch_norm: 48 | self.batch_norms.append(nn.BatchNorm1d(self.mlp_hidden)) 49 | if self.dropout > 0: 50 | self.dropouts.append(nn.Dropout(self.dropout)) 51 | self.last_layer = nn.Linear(self.mlp_hidden, output_size, bias=bias) 52 | else: 53 | for i in range(mlp_layers): 54 | if i == 0: 55 | self.layers.append(nn.Linear(self.input_size, mlp_hidden[0], bias=bias)) 56 | else: 57 | self.layers.append(nn.Linear(mlp_hidden[i-1], mlp_hidden[i], bias=bias)) 58 | 59 | if self.batch_norm: 60 | self.batch_norms.append(nn.BatchNorm1d(mlp_hidden[i])) 61 | if self.dropout > 0: 62 | self.dropouts.append(nn.Dropout(self.dropout)) 63 | self.last_layer = nn.Linear(mlp_hidden[-1], output_size, bias=bias) 64 | 65 | 66 | 67 | 68 | 69 | def forward(self, x): 70 | """ 71 | Forward pass of the network 72 | 73 | """ 74 | #TODO: check skip connection 75 | for i in range(self.mlp_layers): 76 | if self.skip_connection: 77 | skip_input = x.clone() 78 | x = self.layers[i](x) 79 | if self.batch_norm: 80 | x = self.batch_norms[i](x) 81 | x = self.non_linearity(x) 82 | if self.dropout > 0: 83 | x = self.dropouts[i](x) 84 | 85 | 
if self.skip_connection and i>0: 86 | x = x + skip_input 87 | x = self.last_layer(x) 88 | return x -------------------------------------------------------------------------------- /models/adjoint.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script contains a model class which just average the center pixel of the projection patches 3 | Used to model the adjoint operator. 4 | """ 5 | import torch.nn as nn 6 | import torch 7 | from utils.utils import generate_projections_location 8 | 9 | class adjoint(nn.Module): 10 | """ 11 | Using to obtain pixel wise estimate of the adjoint operator 12 | """ 13 | def __init__(self,avg_type = 'mean'): 14 | super(adjoint, self).__init__() 15 | self.avg_type = avg_type 16 | 17 | 18 | def forward(self, x): 19 | """ 20 | Forward pass of the network 21 | """ 22 | patch_size = x.shape[-1] 23 | 24 | mid_pix = x[:,:,patch_size//2,patch_size//2] 25 | if self.avg_type == 'mean': 26 | mid_pix_sum = torch.mean(mid_pix,1) 27 | elif self.avg_type == 'sum': 28 | mid_pix_sum = torch.sum(mid_pix,1) 29 | else: 30 | raise ValueError('Invalid avg_type') 31 | return mid_pix_sum.unsqueeze(1) 32 | 33 | 34 | 35 | class adjoint_patch(nn.Module): 36 | """ 37 | MLP along the slices of the projection and an mlp to combine the slices with a different architecture for 38 | encoding and combining 39 | """ 40 | def __init__(self,output_patch_size = None): 41 | super(adjoint_patch, self).__init__() 42 | self.output_patch_size = output_patch_size 43 | if self.output_patch_size is not None: 44 | assert self.output_patch_size%2 == 1, "patch size should be odd" 45 | 46 | 47 | def forward(self, x,angles): 48 | """ 49 | Forward pass of the network 50 | """ 51 | patch_size = x.shape[-1] 52 | batch_size = x.shape[0] 53 | device = x.device 54 | 55 | x_patch = torch.linspace(-1,1,patch_size,device=device) 56 | y_patch = torch.linspace(-1,1,patch_size,device=device) 57 | z_patch = torch.linspace(-1,1,patch_size,device=device) 58 | xx_pathc, yy_pathc,zz_pathc = torch.meshgrid(y_patch, x_patch, z_patch, indexing='ij') 59 | points_patch = torch.cat([zz_pathc.unsqueeze(-1),yy_pathc.unsqueeze(-1),xx_pathc.unsqueeze(-1)],dim= -1) 60 | points_patch = points_patch.reshape(-1,3) 61 | #print(angles[0]) 62 | points_patch_proj = generate_projections_location(points_patch,angles[0]) 63 | x_fbp = torch.zeros(batch_size,patch_size,patch_size,patch_size,device =device) 64 | # converting batch to channels 65 | #print(x.shape) 66 | x = x.permute(1,0,2,3) 67 | 68 | for i,x_proj in enumerate(x): 69 | x_fbp = x_fbp + torch.nn.functional.grid_sample(x_proj.unsqueeze(0), 70 | points_patch_proj[i].unsqueeze(0).unsqueeze(0), 71 | mode='bilinear', 72 | padding_mode='zeros', 73 | align_corners=True).squeeze().reshape(batch_size,patch_size,patch_size,patch_size) 74 | 75 | 76 | x_fbp = x_fbp/len(angles[0]) 77 | 78 | if self.output_patch_size is not None: 79 | x_fbp = x_fbp[:,patch_size//2-self.output_patch_size//2:patch_size//2+self.output_patch_size//2 + 1, 80 | patch_size//2-self.output_patch_size//2:patch_size//2+self.output_patch_size//2 + 1, 81 | patch_size//2-self.output_patch_size//2:patch_size//2+self.output_patch_size//2 + 1,] 82 | 83 | return x_fbp -------------------------------------------------------------------------------- /filter_models/cnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class CNN(torch.nn.Module): 6 | def __init__(self, 7 | fitler_size=3, 8 | 
hidden_channels = 3, 9 | hidden_layers= 3, 10 | padding_mode = 'reflect', 11 | nonlinearity = nn.LeakyReLU()): 12 | super(CNN, self).__init__() 13 | self.fitler_size = fitler_size 14 | self.hidden_channels = hidden_channels 15 | self.hidden_layers = hidden_layers 16 | self.padding_mode = padding_mode 17 | self.non_linearity = nonlinearity 18 | #self.residual = residual 19 | self.convInput = torch.nn.Conv2d(1,self.hidden_channels,self.fitler_size,padding='same', padding_mode=self.padding_mode) 20 | self.convHidden = nn.ModuleList() 21 | for i in range(self.hidden_layers): 22 | self.convHidden.append(torch.nn.Conv2d(self.hidden_channels , 23 | self.hidden_channels , 24 | self.fitler_size, 25 | padding='same', padding_mode=self.padding_mode)) 26 | 27 | self.convOutput = torch.nn.Conv2d(self.hidden_channels,1,self.fitler_size,padding='same', padding_mode=self.padding_mode) 28 | 29 | def forward(self,x): 30 | """ 31 | x = N_projectionxNxN 32 | """ 33 | input = x[:,None].clone() 34 | x = self.convInput(x[:,None]) 35 | 36 | for i in range(self.hidden_layers): 37 | x = self.convHidden[i](x) 38 | x = self.non_linearity(x) 39 | 40 | x = self.convOutput(x) 41 | #if self.residual: 42 | # x = x + input 43 | return x.squeeze() 44 | 45 | # class CNN(torch.nn.Module): 46 | # def __init__(self, 47 | # fitler_size=3, 48 | # hidden_channels = 3, 49 | # hidden_layers= 3, 50 | # padding_mode = 'reflect', 51 | # nonlinearity = nn.LeakyReLU(), 52 | # residual = False): 53 | # super(CNN, self).__init__() 54 | # self.fitler_size = fitler_size 55 | # self.hidden_channels = hidden_channels 56 | # self.hidden_layers = hidden_layers 57 | # self.padding_mode = padding_mode 58 | # self.non_linearity = nonlinearity 59 | # self.residual = residual 60 | # self.convInput = torch.nn.Conv2d(1,self.hidden_channels,self.fitler_size,padding='same', padding_mode=self.padding_mode) 61 | # self.convHidden = nn.ModuleList() 62 | # for i in range(self.hidden_layers): 63 | # self.convHidden.append(torch.nn.Conv2d(self.hidden_channels , 64 | # self.hidden_channels , 65 | # self.fitler_size, 66 | # padding='same', padding_mode=self.padding_mode)) 67 | 68 | # self.convOutput = torch.nn.Conv2d(self.hidden_channels,1,self.fitler_size,padding='same', padding_mode=self.padding_mode) 69 | 70 | # def forward(self,x): 71 | # """ 72 | # x = N_projectionxNxN 73 | # """ 74 | # input = x[:,None].clone() 75 | # x = self.convInput(x[:,None]) 76 | 77 | # for i in range(self.hidden_layers): 78 | # x = self.convHidden[i](x) 79 | # x = self.non_linearity(x) 80 | 81 | # x = self.convOutput(x) 82 | # if self.residual: 83 | # x = x + input 84 | # return x.squeeze() -------------------------------------------------------------------------------- /filter_models/vector.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script contains the vector model for the filter. 
3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | from skimage.transform.radon_transform import _get_fourier_filter 8 | 9 | class VectorModel(nn.Module): 10 | def __init__(self, init: str, size: int, symmetric: bool = True): 11 | super(VectorModel, self).__init__() 12 | self.size = size 13 | self.init = init 14 | self.symmetric = symmetric 15 | 16 | if self.init == 'ones': 17 | self.vector = nn.Parameter(torch.ones(size)) 18 | else: 19 | vector = _get_fourier_filter(2*size, init)[:size,0] 20 | self.vector = nn.Parameter(torch.tensor(vector, dtype=torch.float32)) 21 | 22 | def forward(self, x: int): 23 | if self.symmetric: 24 | return torch.cat([self.vector,self.vector.flipud()]) 25 | return self.vector 26 | 27 | 28 | # class VectorModel_symmetric(nn.Module): 29 | # """ 30 | # Corrected symmetric version 31 | # """ 32 | # def __init__(self, init: str, size: int): 33 | # super(VectorModel_symmetric, self).__init__() 34 | # self.size = size 35 | # self.init = init 36 | 37 | # if self.init == 'ones': 38 | # self.vector = nn.Parameter(torch.ones(size+1)) 39 | # else: 40 | # vector = _get_fourier_filter(2*size, init)[:size+1,0] 41 | # self.vector = nn.Parameter(torch.tensor(vector, dtype=torch.float32)) 42 | 43 | # def forward(self, x: int): 44 | # return torch.cat([self.vector,self.vector.flipud()[1:-1]]) 45 | 46 | 47 | class VectorModel_symmetric(nn.Module): 48 | """ 49 | Corrected symmetric version 50 | """ 51 | def __init__(self, init: str, size: int, linear_filter: bool = False): 52 | super(VectorModel_symmetric, self).__init__() 53 | self.size = size 54 | self.init = init 55 | self.linear_filter = linear_filter 56 | 57 | if self.init == 'ones': 58 | self.vector = nn.Parameter(torch.ones(size+1)) 59 | else: 60 | vector = _get_fourier_filter(2*size, init)[:size+1,0] 61 | self.vector = nn.Parameter(torch.tensor(vector, dtype=torch.float32)) 62 | 63 | def forward(self, x: int): 64 | if self.linear_filter: 65 | filter_value = torch.cat([self.vector,self.vector.flipud()[1:-1]]) 66 | filter_value = torch.fft.fftshift(torch.fft.ifft(filter_value)) 67 | filter_value[:self.size//2] = 0 68 | filter_value[-self.size//2:] = 0 69 | filter_value = torch.fft.fft(torch.fft.ifftshift(filter_value)) 70 | return filter_value 71 | else: 72 | return torch.cat([self.vector,self.vector.flipud()[1:-1]]) 73 | 74 | 75 | 76 | class VectorModel_real(nn.Module): 77 | """ 78 | Corrected symmetric version 79 | """ 80 | def __init__(self, init: str, size: int): 81 | super(VectorModel_real, self).__init__() 82 | self.size = size 83 | self.init = init 84 | 85 | if self.init == 'ones': 86 | self.vector = nn.Parameter(torch.ones(size+1)) 87 | else: 88 | vector = torch.tensor(_get_fourier_filter(size, init)[:,0], dtype=torch.float32) 89 | vector = torch.fft.ifft(vector).real 90 | self.vector = nn.Parameter(vector) 91 | 92 | def forward(self, x: int): 93 | """ 94 | x: size of the projection along one dimension 95 | Note that output should be twice the size of x, as we we use double the size of the FFT 96 | """ 97 | 98 | response_real = torch.zeros(2*x,dtype=torch.float32).to(self.vector.device) 99 | response_size = min(self.size,x) 100 | response_real[:response_size//2] = self.vector[:response_size//2] 101 | response_real[-response_size//2:] = self.vector[-response_size//2:] 102 | return torch.fft.fft(response_real) 103 | 104 | -------------------------------------------------------------------------------- /super.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script to 
load the traned model and reconstruct the volumes present in the yaml file 3 | """ 4 | 5 | 6 | 7 | 8 | import os 9 | import yaml 10 | import argparse 11 | from evaluator import Evaluator 12 | import mrcfile 13 | import torch 14 | import numpy as np 15 | 16 | args = argparse.ArgumentParser(description="Load the trained model and reconstruct the volumes present in the yaml file") 17 | 18 | args.add_argument("--config", type=str, help="Path to the yaml file") 19 | 20 | 21 | 22 | 23 | if __name__ == "__main__": 24 | # Load the yaml file 25 | 26 | args = args.parse_args() 27 | with open(args.config, "r") as file: 28 | config = yaml.load(file, Loader=yaml.FullLoader) 29 | 30 | device = config["device"] 31 | model_path = config["model_dir"] 32 | batch_size = config["batch_size"] 33 | downsample = config["downsample_projections"] 34 | N3 = config["N3"] 35 | save_dir = config["save_dir"] 36 | save_name = config["save_name"] 37 | angles = np.loadtxt(config["angle_file"]) 38 | 39 | # check if the parameter is preset 40 | if "num_workers" in config: 41 | num_workers = config["num_workers"] 42 | print('num_workers:',num_workers) 43 | else: 44 | num_workers = 0 45 | 46 | 47 | if type(device) is int: 48 | multi_gpu = False 49 | else: 50 | # mult_gpu = True 51 | # GPUS = device 52 | # device = GPUS[0] 53 | GPUS = [] 54 | for i in range(torch.cuda.device_count()): 55 | try: 56 | torch.cuda.get_device_properties(i) 57 | GPUS.append(i) 58 | except AssertionError: 59 | pass 60 | if len(GPUS)>1: 61 | print("Using multiple GPUs") 62 | multi_gpu = True 63 | device = GPUS[0] 64 | else: 65 | multi_gpu = False 66 | device = GPUS[0] 67 | 68 | print("Using GPUs: ", GPUS) 69 | 70 | 71 | eval = Evaluator(model_path = model_path , device = device) 72 | 73 | proj_path = config["proj_file"] 74 | projection = mrcfile.open(proj_path, permissive=True).data 75 | projection = projection - np.mean(projection) 76 | projection = projection/np.std(projection) 77 | 78 | 79 | 80 | 81 | if downsample: 82 | downsample_factor = config["downsample_factor"] 83 | anti_alias = config["anti_alias"] 84 | proj_ds_set = [] 85 | for proj in projection: 86 | proj_t = torch.tensor(proj,device = device,dtype =torch.float32) 87 | proj_ds = torch.nn.functional.interpolate(proj_t[None,None] , 88 | scale_factor=downsample_factor, 89 | align_corners=True, 90 | antialias = anti_alias, 91 | mode='bicubic').squeeze() 92 | proj_ds_set.append(proj_ds.cpu().numpy()) 93 | 94 | projection = np.array(proj_ds_set) 95 | 96 | 97 | 98 | # Zero pad the projections to make them square 99 | N1 = projection.shape[1] 100 | N2 = projection.shape[2] 101 | 102 | if N1>N2: 103 | pad = (N1-N2)//2 104 | projection = np.pad(projection,((0,0),(0,0),(pad,pad))) 105 | elif N2>N1: 106 | pad = (N2-N1)//2 107 | projection = np.pad(projection,((0,0),(pad,pad),(0,0))) 108 | 109 | if N3 > int(max(N1,N2)): 110 | print("Changed value of N3 to be same as max(N1,N2)") 111 | N3 = int(max(N1,N2)) 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | if not os.path.exists(save_dir): 120 | os.makedirs(save_dir) 121 | 122 | 123 | if multi_gpu: 124 | vol = eval.reconstruct(projection = projection, 125 | angles= angles, 126 | N3 = N3, 127 | N3_scale = 0.5, 128 | batch_size = batch_size, 129 | num_workers=num_workers, 130 | gpu_ids= GPUS) 131 | else: 132 | vol = eval.reconstruct(projection = projection, 133 | angles= angles, 134 | N3 = N3, 135 | N3_scale = 0.5, 136 | batch_size = batch_size, 137 | num_workers=num_workers) 138 | 139 | vol = np.moveaxis(vol,2,0) 140 | if N1 > N2: 141 | vol = 
vol[:,:,pad:-pad] 142 | elif N2 > N1: 143 | vol = vol[:,pad:-pad] 144 | 145 | save_path = os.path.join(save_dir,save_name) 146 | 147 | out = mrcfile.new(save_path,overwrite = True) 148 | out.set_data(vol.astype(np.float32)) 149 | out.close() 150 | -------------------------------------------------------------------------------- /super-list.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script to load the traned model and reconstruct the volumes present in the yaml file 3 | """ 4 | 5 | 6 | 7 | 8 | import os 9 | import yaml 10 | import argparse 11 | from evaluator import Evaluator 12 | import mrcfile 13 | import torch 14 | import numpy as np 15 | 16 | args = argparse.ArgumentParser(description="Load the trained model and reconstruct the volumes present in the yaml file") 17 | 18 | args.add_argument("--config", type=str, help="Path to the yaml file") 19 | 20 | 21 | 22 | 23 | if __name__ == "__main__": 24 | # Load the yaml file 25 | 26 | args = args.parse_args() 27 | with open(args.config, "r") as file: 28 | config = yaml.load(file, Loader=yaml.FullLoader) 29 | 30 | device = config["device"] 31 | model_path = config["model_dir"] 32 | batch_size = config["batch_size"] 33 | downsample = config["downsample_projections"] 34 | N3_list = config["N3"] 35 | save_dir = config["save_dir"] 36 | save_name_list = config["save_name"] 37 | angles_file_list = config["angle_file"] 38 | 39 | 40 | # check if the parameter is preset 41 | if "num_workers" in config: 42 | num_workers = config["num_workers"] 43 | print('num_workers:',num_workers) 44 | else: 45 | num_workers = 0 46 | 47 | 48 | if type(device) is int: 49 | multi_gpu = False 50 | else: 51 | # mult_gpu = True 52 | # GPUS = device 53 | # device = GPUS[0] 54 | GPUS = [] 55 | for i in range(torch.cuda.device_count()): 56 | try: 57 | torch.cuda.get_device_properties(i) 58 | GPUS.append(i) 59 | except AssertionError: 60 | pass 61 | if len(GPUS)>1: 62 | print("Using multiple GPUs") 63 | multi_gpu = True 64 | device = GPUS[0] 65 | else: 66 | multi_gpu = False 67 | device = GPUS[0] 68 | 69 | print("Using GPUs: ", GPUS) 70 | 71 | 72 | eval = Evaluator(model_path = model_path , device = device) 73 | 74 | proj_path_list = config["proj_file"] 75 | 76 | 77 | for i, proj_path in enumerate(proj_path_list): 78 | angles = np.loadtxt(angles_file_list[i]) 79 | N3 = N3_list[i] 80 | save_name = save_name_list[i] 81 | projection = mrcfile.open(proj_path, permissive=True).data 82 | projection = projection - np.mean(projection) 83 | projection = projection/np.std(projection) 84 | 85 | 86 | 87 | 88 | if downsample: 89 | downsample_factor = config["downsample_factor"] 90 | anti_alias = config["anti_alias"] 91 | proj_ds_set = [] 92 | for proj in projection: 93 | proj_t = torch.tensor(proj,device = device,dtype =torch.float32) 94 | proj_ds = torch.nn.functional.interpolate(proj_t[None,None] , 95 | scale_factor=downsample_factor, 96 | align_corners=True, 97 | antialias = anti_alias, 98 | mode='bicubic').squeeze() 99 | proj_ds_set.append(proj_ds.cpu().numpy()) 100 | 101 | projection = np.array(proj_ds_set) 102 | 103 | 104 | 105 | # Zero pad the projections to make them square 106 | N1 = projection.shape[1] 107 | N2 = projection.shape[2] 108 | 109 | if N1>N2: 110 | pad = (N1-N2)//2 111 | projection = np.pad(projection,((0,0),(0,0),(pad,pad))) 112 | elif N2>N1: 113 | pad = (N2-N1)//2 114 | projection = np.pad(projection,((0,0),(pad,pad),(0,0))) 115 | 116 | if N3 > int(max(N1,N2)): 117 | print("Changed value of N3 to be same as 
max(N1,N2)")
118 |             N3 = int(max(N1,N2))
119 | 
120 | 
121 | 
122 | 
123 | 
124 | 
125 | 
126 |         if not os.path.exists(save_dir):
127 |             os.makedirs(save_dir)
128 | 
129 | 
130 |         if multi_gpu:
131 |             vol = eval.reconstruct(projection = projection,
132 |                                    angles= angles,
133 |                                    N3 = N3,
134 |                                    N3_scale = 0.5,
135 |                                    batch_size = batch_size,
136 |                                    num_workers=num_workers,
137 |                                    gpu_ids= GPUS)
138 |         else:
139 |             vol = eval.reconstruct(projection = projection,
140 |                                    angles= angles,
141 |                                    N3 = N3,
142 |                                    N3_scale = 0.5,
143 |                                    batch_size = batch_size,
144 |                                    num_workers=num_workers)
145 | 
146 |         vol = np.moveaxis(vol,2,0)
147 |         if N1 > N2:
148 |             vol = vol[:,:,pad:-pad]
149 |         elif N2 > N1:
150 |             vol = vol[:,pad:-pad]
151 | 
152 |         save_path = os.path.join(save_dir,save_name)
153 | 
154 |         out = mrcfile.new(save_path,overwrite = True)
155 |         out.set_data(vol.astype(np.float32))
156 |         out.close()
157 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # End-to-end localized deep learning for Cryo-ET
2 | Official repo for CryoLithe ([paper](https://arxiv.org/abs/2501.15246))
3 | 
4 | CryoLithe is a supervised machine learning method that directly reconstructs the tomogram from an aligned cryo-ET tilt series. The method is trained on real measurements using FBP+cryo-CARE+IsoNet reconstructions as the reference. The network exploits the imaging geometry to extract small patches from the tilt series to recover the volume, which makes it practically robust to varying data distributions. The method provides FBP+cryo-CARE+IsoNet-style reconstructions in a fraction of the time.
5 | 
6 | 
7 | ## Updates
8 | - 18.04.2025:
9 |   - New models that can recover the volume from tilt series with an arbitrary number of projections.
10 |   - Updated the PyTorch version to 2.6.0 (the code was tested with 2.6.0).
11 |   - Updated the README file to include the new models and the new requirements.
12 |   - Multi-GPU inference is now supported.
13 | - 05.08.2025
14 |   - Added support to reconstruct a list of volumes from a single yaml file.
15 |   - Added a new script `super-list.py` to run the model on a list of projections.
16 | - 05.09.2025
17 |   - Added new trained models that were trained on a larger dataset.
18 | 
19 | ## Installation
20 | Download the repository using the command:
21 | ```bash
22 | git clone git@github.com:swing-research/CryoLithe.git
23 | ```
24 | 
25 | 
26 | Create a new conda environment using the command:
27 | ```bash
28 | conda create -n CryoLithe python=3.9
29 | ```
30 | 
31 | 
32 | 
33 | Activate the environment using the command:
34 | ```bash
35 | conda activate CryoLithe
36 | ```
37 | Install PyTorch 2.6 (or a compatible version). The code was tested with PyTorch 2.6.
38 | ```bash
39 | pip3 install torch torchvision torchaudio
40 | ```
41 | 
42 | 
43 | Install the required packages using the command:
44 | 
45 | ```bash
46 | pip install -r requirements.txt
47 | ```
48 | 
49 | ## Downloading the trained models
50 | The models are stored on SWITCHdrive and can be downloaded using the provided download.sh script:
51 | ```bash
52 | bash download.sh
53 | ```
54 | 
55 | This will download the trained models and place them in the `trained_models` directory.
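A sketch of the expected layout after unzipping (the exact set of model folders depends on the downloaded archive):
```
trained_models/
├── cryolithe_pixel/
│   ├── checkpoint.pth
│   └── config.json
└── cryolithe_21/
    ├── checkpoint.pth
    └── config.json
```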
Each model directory should contain the following files:
56 | - `checkpoint.pth` - the trained model
57 | - `config.json` - the configuration file used to train the model; it contains the model architecture and hyperparameters
58 | 
59 | Currently, we provide three models:
60 | - 'cryolithe_pixel' - trained to recover the volume one pixel at a time
61 | - 'cryolithe_21' - trained to recover the wavelet coefficients of the volume. Uses patches of size 21x21 from the projections.
62 | - 'cryolithe_11' - trained to recover the wavelet coefficients of the volume. Uses patches of size 11x11 from the projections. The model is a bit faster, but the reconstructions are slightly worse than with the 21x21 model.
63 | 
64 | **Note**: The wavelet model is 8x faster than the sliceset model. However, the reconstruction looks slightly lower resolution than the
65 | sliceset model.
66 | ## Running the model
67 | 
68 | We are actively working on extending the model to support arbitrary tilt series, which will be released in an upcoming update.
69 | 
70 | The script 'super.py' is used to run the trained model on any new projection of your choice. The script requires a configuration file that contains the necessary information to run the model.
71 | The configuration file is a yaml file that contains the following fields:
72 | - 'model_dir' - path to the directory containing the trained model
73 | - 'proj_file' - path to the projection file
74 | - 'angle_file' - path to the angle file
75 | - 'save_dir' - path to the directory where the output will be saved
76 | - 'save_name' - name of the output volume
77 | - 'device' - GPU id to run the model on (use a list of ids, e.g. [0,1], for multiple GPUs)
78 | - 'downsample_projections' - whether to downsample the projections or not
79 | - 'downsample_factor' - factor by which to downsample the projections
80 | - 'anti_alias' - whether to apply anti-aliasing to the projections or not
81 | - 'N3' - the size of the volume along the z-axis
82 | - 'batch_size' - batch size to use when running the model
83 | 
84 | 
85 | The script can be run using the following command:
86 | ```bash
87 | python3 super.py --config <path_to_yaml>
88 | ```
89 | 
90 | A sample yaml file is provided as 'ribo80.yaml', which contains the necessary information to run the model on the ribosome dataset.
91 | 
92 | ## Running the model on the ribosome dataset
93 | 
94 | Download the ribosome dataset using the provided script:
95 | ```bash
96 | bash download_ribosome.sh
97 | ```
98 | This will download the ribosome dataset and place it in the `data` directory. The dataset contains the following files:
99 | - `projections.mrcs` - the projections of the ribosome dataset
100 | - `angles.tlt` - the angles of the projections
101 | 
102 | The data is downloaded from the EMPIAR-10045 dataset and is a subset of the full dataset.
103 | 
104 | To run the script, use the following command:
105 | ```bash
106 | python3 super.py --config ribo80.yaml
107 | ```
108 | 
109 | ## Using the Wavelet Model
110 | Run the script using the following command:
111 | ```bash
112 | python3 super.py --config ribo80_wavelet.yaml
113 | ```
114 | 
115 | ## Running the model on a list of projections
116 | The script `super-list.py` is used to run the trained model on a list of projections. Additionally, we provide a yaml file that can run the model on a list of projections. In the `ribo80_list.yaml` file, you can specify multiple projection files, angle files, save names and N3 values for each projection. The script will then process each set of files in the list and save the corresponding volumes.
Note that the example yaml file runs the model on the same data twice; you can modify it to use different projection data. You need to change the following fields in the yaml file:
117 | - `proj_file` - list of paths to the projection files
118 | - `angle_file` - list of paths to the angle files
119 | - `save_name` - list of names for the output volumes
120 | - `N3` - list of sizes for the volumes along the z-axis
121 | 
122 | You can run the script using the following command:
123 | ```bash
124 | python3 super-list.py --config ribo80_list.yaml
125 | ```
126 | 
127 | ## Downloading the older models
128 | If you want to use the older models, you can download them using the provided download_old.sh script:
129 | ```bash
130 | bash download_old.sh
131 | ```
132 | 
133 | 
--------------------------------------------------------------------------------
/evaluator.py:
--------------------------------------------------------------------------------
1 | """
2 | Load the model and run on the given data
3 | """
4 | from ml_collections import config_dict
5 | import json
6 | import torch
7 | from models import get_model, model_wrapper
8 | from utils.utils import custom_ramp_fft
9 | from tqdm import tqdm
10 | import numpy as np
11 | from torch.utils.data import DataLoader
12 | from utils.wavelet_utils import wavelet_multilevel_decomposition, wavelet_multilevel_reconstruction
13 | 
14 | 
15 | 
16 | class Evaluator:
17 |     def __init__(self, model_path , device):
18 | 
19 |         configs = config_dict.ConfigDict(json.load(open(model_path + '/config.json')))
20 |         self.configs = configs
21 |         self.device = device
22 | 
23 |         self.n_projections = configs.data.n_projections
24 | 
25 |         checkpoint = torch.load(model_path + '/checkpoint.pth',map_location=torch.device(device), weights_only=False)
26 | 
27 | 
28 |         model = get_model(n_projections = self.n_projections, **configs.model).to(device)
29 |         model.load_state_dict(checkpoint['model_state_dict'])
30 | 
31 | 
32 |         if configs.filter_projections:
33 |             ramp = checkpoint['ramp']
34 |             if ramp is not None:
35 |                 ramp = ramp.to(device)
36 |         else:
37 |             ramp = None
38 | 
39 |         if configs.training.learn_patch_scale:
40 |             patch_scale = checkpoint['patch_scale'].to(device)
41 |             patch_scale = patch_scale.detach()
42 |         else:
43 |             patch_scale = torch.tensor([configs.training.patch_scale_init]).to(device)
44 | 
45 | 
46 |         if configs.use_2D_fitlers or configs.use_2D_filters:
47 |             filter_2D = checkpoint['filter_2D'].to(device)
48 |         else:
49 |             filter_2D = None
50 | 
51 | 
52 |         self.model = model
53 |         self.ramp = ramp
54 |         self.patch_scale = patch_scale
55 |         self.filter_2D = filter_2D
56 | 
57 | 
58 |     def generate_points(self,n1,n2,n3):
59 |         """
60 |         Generate points in 3D space
61 |         """
62 | 
63 |         n = max(n1,n2,n3)
64 |         scale = np.ones(3)
65 |         scale[0] = n1/n
66 |         scale[1] = n2/n
67 |         scale[2] = n3/n
68 | 
69 |         x_index = torch.linspace(-1,1,n1)
70 |         x_index = x_index*scale[0]
71 |         y_index = torch.linspace(-1,1,n2)
72 |         y_index = y_index*scale[1]
73 |         z_index = torch.linspace(-1,1,n3)
74 |         z_index = z_index*scale[2]
75 | 
76 |         zz,yy,xx = torch.meshgrid(z_index,y_index,x_index)
77 |         points = torch.stack([zz,yy,xx],dim=3).reshape(-1,3)
78 | 
79 |         return points,scale
80 | 
81 | 
82 |     def pre_process(self, projection, angles, N3_scale = 0.5, N3 =None):
83 |         """
84 |         projections: projections of shape (n_projections, N_1, N_2)
85 |         angles: angles of shape (n_projections) in degrees
86 |         """
87 | 
88 | 
89 |         projection_filt = self.filter_projections(projection)
90 |         angles_t = torch.tensor(angles,
dtype=torch.float32, device=self.device)*torch.pi/180 91 | 92 | N1,N2 = projection_filt.shape[-2], projection_filt.shape[-1] 93 | if N3 is None: 94 | N3 = int(max(N1,N2)*N3_scale) 95 | 96 | vol_dummy = torch.randn(N1,N2,N3,dtype=torch.float32,device='cpu') 97 | 98 | 99 | if self.configs.training.use_wavelet_trainer: 100 | vol_wavelet = wavelet_multilevel_decomposition(vol_dummy, 101 | self.configs.training.wavelet, 102 | levels = self.configs.training.wavelet_levels) 103 | vol_lp = vol_wavelet[0] 104 | N1,N2,N3 = vol_lp.shape 105 | 106 | 107 | points,scale = self.generate_points(N1,N2,N3) 108 | 109 | return projection_filt, angles_t, points, vol_dummy, N1,N2,N3,scale 110 | 111 | 112 | def reconstruct(self, projection, 113 | angles, 114 | N3_scale = 0.5, 115 | batch_size = int(4e4) , 116 | N3 = None, 117 | num_workers =4, 118 | gpu_ids = None): 119 | """ 120 | projections: projections of shape (n_projections, N_1, N_2) 121 | angles: angles of shape (n_projections) in degrees 122 | """ 123 | 124 | 125 | projection_filt, angles_t, points, vol_dummy, N1,N2,N3,scale = self.pre_process(projection, 126 | angles, 127 | N3_scale = N3_scale, 128 | N3 = N3) 129 | 130 | modl_wrapper = model_wrapper(self.model, 131 | projections = projection_filt, 132 | angles = angles_t, 133 | volume_dummy= vol_dummy, 134 | patch_scale = self.patch_scale, 135 | scale = scale, 136 | configs = self.configs) 137 | 138 | modl_wrapper = modl_wrapper.to(self.device).half() 139 | if gpu_ids is not None: 140 | modl_wrapper = torch.nn.DataParallel(modl_wrapper, device_ids = gpu_ids) 141 | point_loader = DataLoader(points.half(),shuffle=False,batch_size=batch_size,num_workers=num_workers) 142 | 143 | 144 | with torch.no_grad(): 145 | modl_wrapper.eval() 146 | vol_est_set = [] 147 | for points in tqdm(point_loader): 148 | if gpu_ids is None: 149 | points = points.to(self.device) 150 | vol_est = modl_wrapper(points).permute(1,0).cpu().numpy() 151 | vol_est_set.append(vol_est) 152 | v_est_set_np = np.moveaxis(np.moveaxis(np.concatenate(vol_est_set,axis=1).reshape(-1,N3,N1,N2),1,-1),2,1) 153 | 154 | if self.configs.training.use_wavelet_trainer: 155 | v_est_set_t = torch.tensor(v_est_set_np, dtype=torch.float32, device='cpu') 156 | vol_est_rec = wavelet_multilevel_reconstruction(v_est_set_t, 157 | wavelet= self.configs.training.wavelet).cpu().numpy() 158 | else: 159 | vol_est_rec = v_est_set_np[0] 160 | 161 | return vol_est_rec 162 | 163 | 164 | def filter_projections(self,projections): 165 | """ 166 | Filter the projections 167 | """ 168 | 169 | if type(projections) == list: 170 | proj_filt = [] 171 | 172 | for proj in projections: 173 | proj_filt.append(self.filter_single_projection(proj)) 174 | else: 175 | proj_filt = self.filter_single_projection(projections) 176 | 177 | return proj_filt 178 | 179 | 180 | def filter_single_projection(self,projection): 181 | """ 182 | Filter a single projection 183 | """ 184 | proj_t = torch.tensor(projection, dtype=torch.float32, device=self.device) 185 | with torch.no_grad(): 186 | if self.configs.use_2D_filters: 187 | proj_t = self.filter_2D(proj_t) 188 | 189 | if self.configs.filter_projections: 190 | ramp_filt = self.ramp(proj_t.shape[-1]) 191 | proj_t = custom_ramp_fft(proj_t, ramp_filt, use_splits= True) 192 | 193 | return proj_t 194 | 195 | 196 | 197 | 198 | 199 | if __name__ == '__main__': 200 | import matplotlib.pyplot as plt 201 | import numpy as np 202 | import mrcfile 203 | import torch 204 | DEVICE = 1 205 | eval = Evaluator(model_path = './trained_models/real_trained/' , device = 
DEVICE) 206 | 207 | path = '/home/kishor0000/Work/cryoET/ET_data_supervised/10045-80S/IS002_291013_005.mrcs' 208 | proj = mrcfile.open(path).data 209 | proj = proj - np.mean(proj) 210 | proj = proj/np.std(proj) 211 | 212 | 213 | angles = np.loadtxt('/home/kishor0000/Work/cryoET/ET_data_supervised/10045-80S/angle_5.rawtlt') 214 | 215 | DOWNSAMPLE = True 216 | DOWNSAMPLE_FACTOR = 0.25 217 | proj_ds_set = [] 218 | if DOWNSAMPLE: 219 | for proj in proj: 220 | proj_t = torch.tensor(proj,device = DEVICE,dtype =torch.float32) 221 | proj_ds = torch.nn.functional.interpolate(proj_t[None,None] , 222 | scale_factor=DOWNSAMPLE_FACTOR, 223 | align_corners=False, 224 | antialias = True, 225 | mode='bicubic').squeeze() 226 | proj_ds_set.append(proj_ds.cpu().numpy()) 227 | 228 | proj_real = np.array(proj_ds_set) 229 | 230 | 231 | op = eval.orthogonal_reconstruction(proj_real,angles,512) 232 | -------------------------------------------------------------------------------- /models/sliceSet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Set model for combining the informations from the patches 3 | """ 4 | 5 | 6 | 7 | 8 | 9 | 10 | from torch import nn 11 | import torch 12 | from models.standardmlp import standardMLP 13 | 14 | 15 | 16 | 17 | 18 | # Acts along the slices 19 | class RadonSet(nn.Module): 20 | def __init__(self, input = 21, 21 | output = 1, 22 | set_input =512, 23 | set_output = None, 24 | transformer_positional_encoding_size = 128, 25 | use_learned_positional_encoding = False, 26 | transformer_positional_encoding_base = 1000, 27 | transformer_positional_encoding_add_angle : bool = False, 28 | transformer_positional_encoding_mult_angle: bool = False, 29 | set_hidden_size = 1024, 30 | set_num_layers = 3, 31 | set_skip_connection = False, 32 | set_bias = False, 33 | transformer_avg_pooling = True, 34 | mlp_hidden_size=512, 35 | mlp_num_layers=2, 36 | mlp_skip_connection=False, 37 | mlp_bias = False, 38 | bias = True): 39 | 40 | super(RadonSet, self).__init__() 41 | 42 | self.up_layer = nn.Linear(input, set_input, bias=bias) 43 | self.transformer_positional_encoding_size = transformer_positional_encoding_size 44 | self.transformer_positional_encoding_base = transformer_positional_encoding_base 45 | self.transformer_positional_encoding_add_angle = transformer_positional_encoding_add_angle 46 | self.transformer_avg_pooling = transformer_avg_pooling 47 | self.use_learned_positional_encoding = use_learned_positional_encoding 48 | self.transformer_positional_encoding_mult_angle = transformer_positional_encoding_mult_angle 49 | if set_output is None: 50 | set_output = set_input 51 | 52 | 53 | self.setMlP = standardMLP(input_size = set_input, 54 | output_size = set_output, 55 | mlp_hidden = set_hidden_size, 56 | mlp_layers = set_num_layers, 57 | skip_connection = set_skip_connection, 58 | bias = set_bias) 59 | 60 | 61 | 62 | self.mlp = standardMLP(input_size = set_output, 63 | output_size = output, 64 | mlp_hidden = mlp_hidden_size, 65 | mlp_layers = mlp_num_layers, 66 | skip_connection = mlp_skip_connection, 67 | bias = mlp_bias) 68 | 69 | def forward(self, x, angle): 70 | x = self.up_layer(x) 71 | if self.transformer_positional_encoding_add_angle: 72 | x = x + angle 73 | elif self.transformer_positional_encoding_mult_angle: 74 | x = x*angle 75 | else: 76 | #print('concat angle') 77 | x = torch.cat((x,angle),2) 78 | 79 | x = self.setMlP(x) 80 | 81 | # Average pooling 82 | if self.transformer_avg_pooling: 83 | x = torch.mean(x, dim=1) 84 | else: 85 | x = 
x[:,0] 86 | x = self.mlp(x) 87 | return x 88 | 89 | 90 | 91 | class SliceSet(nn.Module): 92 | """ 93 | Separate transformer for each slice and the combine the output 94 | """ 95 | def __init__(self, n_projections: int, 96 | mlp_output: int, 97 | patch_size: int, 98 | slice_index: int = 2, 99 | compare_index: int = 3, 100 | learn_residual: bool = False, 101 | set_input : int = 512, 102 | set_output : int = None, 103 | set_hidden_size = 1024, 104 | set_num_layers = 3, 105 | set_skip_connection = False, 106 | set_bias = False, 107 | radon_bias = True, 108 | learned_positional_encoding = False, 109 | learned_positional_encoding_use_softmax = False, 110 | slice_transformer_avg_pooling : bool = True, 111 | slice_transformer_positional_encoding_size : int = 128, 112 | slice_transformer_positional_encoding_base : int = 1000, 113 | sice_transformer_transformer_positional_encoding_add_angle : bool = False, 114 | sice_transformer_positional_encoding_mult_angle: bool = False, 115 | slice_mlp_hidden_size : int = 512, 116 | slice_mlp_num_layers : int = 3, 117 | slice_mlp_skip_connection : bool = False, 118 | slice_mlp_bias : bool = False, 119 | combine_mlp_layers : int = 5, 120 | combine_mlp_hidden : int = 256, 121 | combine_dropout : int = 0, 122 | combine_batch_norm : bool = False, 123 | combine_learn_residual : bool = False, 124 | combine_skip_connection : bool = False, 125 | combine_mlp_bias : bool = False, 126 | output_size : int = 1): 127 | super(SliceSet, self).__init__() 128 | 129 | self.n_projections = n_projections 130 | self.mlp_output = mlp_output 131 | self.patch_size = patch_size 132 | self.slice_index = slice_index 133 | self.learn_residual = learn_residual 134 | self.compare_index = compare_index 135 | self.slice_transformers = nn.ModuleList() 136 | self.use_learned_positional_encoding = learned_positional_encoding 137 | self.slice_transformer_positional_encoding_size = slice_transformer_positional_encoding_size 138 | self.slice_transformer_positional_encoding_base = slice_transformer_positional_encoding_base 139 | self.learned_positional_encoding_use_softmax = learned_positional_encoding_use_softmax 140 | 141 | if self.learned_positional_encoding_use_softmax: 142 | self.pos_nonlinear = nn.Softmax(dim=-1) 143 | else: 144 | self.pos_nonlinear = nn.Identity() 145 | 146 | 147 | if self.use_learned_positional_encoding: 148 | self.pos_encoder = nn.Linear(1, slice_transformer_positional_encoding_size, bias=radon_bias) 149 | 150 | 151 | for i in range(patch_size): 152 | self.slice_transformers.append(RadonSet(input = self.patch_size, 153 | output = self.mlp_output, 154 | set_input = set_input, 155 | set_output = set_output, 156 | set_hidden_size = set_hidden_size, 157 | set_num_layers = set_num_layers, 158 | set_skip_connection = set_skip_connection, 159 | set_bias = set_bias, 160 | use_learned_positional_encoding = learned_positional_encoding, 161 | transformer_positional_encoding_size = slice_transformer_positional_encoding_size, 162 | transformer_positional_encoding_base = slice_transformer_positional_encoding_base, 163 | transformer_positional_encoding_add_angle = sice_transformer_transformer_positional_encoding_add_angle, 164 | transformer_positional_encoding_mult_angle = sice_transformer_positional_encoding_mult_angle, 165 | mlp_hidden_size= slice_mlp_hidden_size, 166 | mlp_num_layers= slice_mlp_num_layers, 167 | mlp_skip_connection= slice_mlp_skip_connection, 168 | mlp_bias = slice_mlp_bias, 169 | transformer_avg_pooling = slice_transformer_avg_pooling, 170 | bias = radon_bias)) 171 | 
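        # Each RadonSet built above handles one slice of the projection patch stack; in
        # forward() the per-slice outputs (each of size mlp_output) are concatenated and
        # passed through the combination MLP below, whose input size is therefore
        # patch_size*mlp_output.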
172 | self.combination_mlp = standardMLP(input_size = self.patch_size*self.mlp_output, 173 | output_size = output_size, 174 | mlp_hidden = combine_mlp_hidden, 175 | mlp_layers = combine_mlp_layers, 176 | batch_norm = combine_batch_norm, 177 | dropout = combine_dropout, 178 | learn_residual=combine_learn_residual, 179 | skip_connection = combine_skip_connection, 180 | bias = combine_mlp_bias) 181 | 182 | 183 | def forward(self, x,angles): 184 | """ 185 | Forward pass of the network 186 | """ 187 | if self.learn_residual: 188 | mid_pix = x[:,:,self.patch_size//2,self.patch_size//2] 189 | mid_pix_sum = torch.mean(mid_pix,1) 190 | x = x.permute(0,self.slice_index,1,self.compare_index).contiguous() 191 | 192 | if self.use_learned_positional_encoding: 193 | angles = self.pos_nonlinear(self.pos_encoder(angles.unsqueeze(-1))) 194 | else: 195 | angles = self.angle_encoding(angles) 196 | x = [self.slice_transformers[i](x[:,i],angles) for i in range(self.patch_size)] 197 | x = torch.cat(x,1) 198 | x = self.combination_mlp(x) 199 | if self.learn_residual: 200 | x = x + mid_pix_sum.unsqueeze(1) 201 | return x 202 | 203 | def angle_encoding(self, angle): 204 | """ 205 | Angle encoding using sin and cos 206 | angle: tensor of shape (batch_size, n_projections) 207 | output: tensor of shape (batch_size, n_projections, self.transformer_positional_encoding_size) 208 | """ 209 | 210 | d = self.slice_transformer_positional_encoding_size 211 | base_frequency = self.slice_transformer_positional_encoding_base 212 | 213 | samples = angle.shape[0] 214 | 215 | pe = torch.zeros(samples, angle.shape[1], d, device = angle.device, dtype = angle.dtype) 216 | 217 | for i in range(d//2): 218 | pe[:,:,2*i] = torch.sin(angle[0,:]*(180/torch.pi)/(base_frequency**(2*i/d))) 219 | pe[:,:,2*i+1] = torch.cos(angle[0,:]*(180/torch.pi)/(base_frequency**(2*i/d))) 220 | 221 | return pe 222 | 223 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script containing utility functions 3 | """ 4 | 5 | import torch 6 | import numpy as np 7 | from typing import List 8 | import torch.nn.functional as F 9 | from typing import List, Union 10 | 11 | 12 | ALIGN_CORNERS = True 13 | 14 | 15 | def generate_projections_location(points: torch.FloatTensor, angles: torch.FloatTensor): 16 | """ 17 | Generate the points in the local coordinates of a plane passing through the origin 18 | tilted by angle along y axis 19 | points: (n_points,3) tensor of points 20 | angles: (n_projections) tensor of angles 21 | 22 | Note: the projections is only for single axis tilts along the y axis 23 | tilts between (-pi/2,pi/2) are valid pi/2 is not valid 24 | """ 25 | 26 | n_projections = len(angles) 27 | n_points = len(points) 28 | 29 | normals = torch.zeros((n_projections,3), dtype = points.dtype).to(points.device) 30 | 31 | normals[:,0] = torch.cos(angles) 32 | normals[:,1] = -torch.sin(angles) 33 | 34 | projection = points@ normals.T 35 | points_proj = points[:,None,:] - projection[:,:,None] * normals[None,:,:] 36 | points_proj[:,:,1] = points_proj[:,:,1]/torch.cos(angles)[None,:] 37 | local_coords = points_proj[:,:,1:] 38 | 39 | return local_coords.permute(1,0,2) 40 | 41 | 42 | 43 | 44 | 45 | 46 | def volume_sampler(volume,volume_location): 47 | """ 48 | Uses the sampling function: 49 | volume : HxWxD 50 | volume_location: (N_points,3) 51 | outputs = N_points 52 | """ 53 | # volume_location = volume_location[:,[2,1,0]] 54 | # 
return interpolate_grid_3d(volume[None],volume_location/2+0.5,CUBIC_B_SPLINE_MATRIX.to(volume.device)).squeeze() 55 | 56 | 57 | 58 | return torch.nn.functional.grid_sample(volume.unsqueeze(0).unsqueeze(0), 59 | volume_location.unsqueeze(0).unsqueeze(0).unsqueeze(0), 60 | mode='bilinear', 61 | padding_mode='zeros', 62 | align_corners=ALIGN_CORNERS).squeeze() 63 | 64 | 65 | def image_sampler(image,image_location): 66 | """ 67 | image: HXW 68 | img_location: N_points,2 69 | outputs: N_points 70 | """ 71 | # image_location = image_location[:,[1,0]] 72 | # return interpolate_grid_2d(image[None],image_location/2+0.5,CUBIC_B_SPLINE_MATRIX.to(image.device)).squeeze() 73 | 74 | 75 | 76 | return torch.nn.functional.grid_sample(image.unsqueeze(0).unsqueeze(0), 77 | image_location.unsqueeze(0).unsqueeze(0), 78 | mode='bilinear', 79 | padding_mode='zeros', 80 | align_corners=ALIGN_CORNERS).squeeze() 81 | 82 | 83 | 84 | def generate_patches_from_volume_location(volume_location : torch.Tensor, 85 | vol: Union[torch.Tensor, List[torch.Tensor]], 86 | projections: Union[torch.Tensor, List[torch.Tensor]], 87 | angles:Union[torch.Tensor, List[torch.Tensor]], 88 | patch_size: int, 89 | scale: torch.Tensor = None, 90 | patch_scale: torch.Tensor = 1, 91 | discrete_sampling: bool = False, 92 | scaled_patches: bool = False): 93 | 94 | """ 95 | using the volume location and projections generate the projections patches 96 | volume_location: (n_points,3) tensor of points between -1 and 1 97 | Note: This now works for rectangular volumes 98 | TODO: make vol optional 99 | 100 | 101 | """ 102 | 103 | device = volume_location.device 104 | dtype = volume_location.dtype 105 | 106 | if type(projections) == list: 107 | n_1 = projections[0].shape[1] 108 | n_2 = projections[0].shape[2] 109 | n_projections = projections[0].shape[0] 110 | else: 111 | n_1 = projections.shape[1] 112 | n_2 = projections.shape[2] 113 | n_projections = projections.shape[0] 114 | #projection_size = projections.shape[1] 115 | n_points = volume_location.shape[0] 116 | 117 | proj_scale = torch.tensor([n_1,n_2],device=device) 118 | proj_scale = proj_scale/max(proj_scale) 119 | vol_samples = None 120 | 121 | 122 | if ALIGN_CORNERS: 123 | x_patch = (torch.linspace(-1,1,n_2,device=device,dtype = dtype)[:patch_size]) 124 | x_patch = x_patch - x_patch[patch_size//2] 125 | y_patch = (torch.linspace(-1,1,n_1,device=device,dtype = dtype)[:patch_size]) 126 | y_patch = y_patch - y_patch[patch_size//2] 127 | else: 128 | x_patch = (torch.arange(-patch_size//2+1,patch_size//2+1, device=device , dtype = dtype)*2/n_2) 129 | y_patch = (torch.arange(-patch_size//2+1,patch_size//2+1,device=device,dtype = dtype)*2/n_1) 130 | 131 | xx_pathc, yy_pathc = torch.meshgrid(x_patch, y_patch, indexing='xy') 132 | points_patch = torch.zeros((patch_size*patch_size,2),device=device,dtype = dtype) 133 | points_patch[:,0] = xx_pathc.flatten() 134 | points_patch[:,1] = yy_pathc.flatten() 135 | 136 | 137 | if type(angles) == list: 138 | projection_locations = [] 139 | for i,angle_vals in enumerate(angles): 140 | projection_centers = generate_projections_location(volume_location,angle_vals) 141 | projection_centers[:,:,0] = projection_centers[:,:,0]/proj_scale[1] 142 | projection_centers[:,:,1] = projection_centers[:,:,1]/proj_scale[0] 143 | if scaled_patches: 144 | # Compute the scaling value 145 | n_curr = projections[i].shape[-1] 146 | p_scale = n_2/n_curr 147 | projection_locations.append(projection_centers.unsqueeze(2) + patch_scale*points_patch.unsqueeze(0).unsqueeze(0)*p_scale) 148 | 
else: 149 | projection_locations.append(projection_centers.unsqueeze(2) + patch_scale*points_patch.unsqueeze(0).unsqueeze(0)) 150 | else: 151 | projection_centers = generate_projections_location(volume_location,angles) 152 | projection_centers[:,:,0] = projection_centers[:,:,0]/proj_scale[1] 153 | projection_centers[:,:,1] = projection_centers[:,:,1]/proj_scale[0] 154 | # Generate patch coordinates 155 | #print(projection_centers) 156 | projection_locations = projection_centers.unsqueeze(2) + patch_scale*points_patch.unsqueeze(0).unsqueeze(0) 157 | 158 | 159 | 160 | 161 | if type(projections) == list: 162 | 163 | patches_list = [] 164 | 165 | for index, projections_i in enumerate(projections): 166 | if type(projection_locations) == list: 167 | n_projections = len(angles[index]) 168 | patches = torch.zeros((n_projections,n_points,patch_size,patch_size),device=device,dtype = dtype) 169 | projection_locations_current = projection_locations[index] 170 | else: 171 | patches = torch.zeros((n_projections,n_points,patch_size,patch_size),device=device,dtype = dtype) 172 | projection_locations_current =projection_locations 173 | 174 | for i in range(n_projections): 175 | i_patch_points = projection_locations_current[i].reshape(-1,2) 176 | pp = image_sampler(projections_i[i],i_patch_points) 177 | # pp = torch.nn.functional.grid_sample(projections_i[i].unsqueeze(0).unsqueeze(0), 178 | # i_patch_points.unsqueeze( 179 | # 0).unsqueeze(0),mode='bilinear',padding_mode='zeros',align_corners=ALIGN_CORNERS).squeeze() 180 | 181 | patches[i,:,:,:] = pp.reshape(n_points,patch_size,patch_size) 182 | patches = patches.permute(1,0,2,3) 183 | 184 | patches_list.append(patches) 185 | 186 | return vol_samples,patches_list 187 | else: 188 | 189 | patches = torch.zeros((n_projections,n_points,patch_size,patch_size),device=device,dtype = dtype) 190 | 191 | for i in range(n_projections): 192 | i_patch_points = projection_locations[i].reshape(-1,2) 193 | pp = image_sampler(projections[i],i_patch_points) 194 | # pp = torch.nn.functional.grid_sample(projections[i].unsqueeze(0).unsqueeze(0), 195 | # i_patch_points.unsqueeze( 196 | # 0).unsqueeze(0),mode='bilinear',padding_mode='zeros',align_corners=ALIGN_CORNERS).squeeze() 197 | 198 | # if i == 30: 199 | # #print(i_patch_points) 200 | # #print(projections[i].max()) 201 | 202 | # #print(i_patch_points.shape) 203 | # #pts = torch.cat([pp.unsqueeze(1),i_patch_points],dim=1) 204 | # #print(pts) 205 | 206 | patches[i,:,:,:] = pp.reshape(n_points,patch_size,patch_size) 207 | 208 | patches = patches.permute(1,0,2,3) 209 | return vol_samples,patches 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | def custom_ramp_fft(x,t_cust, use_splits = False): 218 | """ 219 | ramp filtering using torch fft 220 | x: (n_projections, N , N) tensor 221 | t_cust: (2*N) tensor 222 | use_splits: use the fft for each projections separately so that it can be used in low GPU cards 223 | """ 224 | if use_splits: 225 | projection_filtered = [] 226 | for proj in x: 227 | projection_fft = torch.fft.fftn(proj, dim=(-1), s = t_cust.shape[0]) 228 | projection_fft = projection_fft*t_cust[None,:] 229 | projection_filtered.append(torch.fft.ifftn(projection_fft, dim=(-1), s =t_cust.shape[0]).real[:,0:x.shape[-1]]) 230 | projection_filtered = torch.stack(projection_filtered,dim=0) 231 | else: 232 | projection_fft = torch.fft.fftn(x, dim=(-1), s = t_cust.shape[0]) 233 | projection_fft = projection_fft*t_cust[None,None,:] 234 | projection_filtered = torch.fft.ifftn(projection_fft, dim=(-1), s = 
t_cust.shape[0]).real[:,:,0:x.shape[-1]] 235 | return projection_filtered 236 | 237 | 238 | 239 | 240 | def vol_normalize(vol,min_Val,max_Val): 241 | vol = (vol- np.min(vol))/(np.max(vol)-np.min(vol)) 242 | vol = vol*(max_Val-min_Val)+min_Val 243 | return vol 244 | 245 | 246 | 247 | 248 | 249 | def downsample_anti_aliasing(vol_t, scale=0.5): 250 | # anti-aliasing in the frequency domain 251 | vol_fft = torch.fft.fftn(vol_t) 252 | vol_shape = ((torch.tensor(vol_t.shape)//2)*scale).int() 253 | vol_fft[vol_shape[0]:-vol_shape[0], :, :] = 0 254 | vol_fft[:, vol_shape[1]:-vol_shape[1], :] = 0 255 | vol_fft[:, :, vol_shape[2]:-vol_shape[2]] = 0 256 | vol_filtered = torch.fft.ifftn(vol_fft).real 257 | 258 | del vol_fft 259 | 260 | vol_downsampled = F.interpolate(vol_filtered[None,None], scale_factor=scale, mode='trilinear', align_corners=False) 261 | return vol_downsampled[0,0] 262 | 263 | 264 | 265 | 266 | 267 | 268 | def normalize_numpy(inp, batchwise=True): 269 | """ 270 | Normalize the input by subtracting the mean and dividing by the std. 271 | 272 | INPUT: 273 | -inp, (B,*n): input numpy array 274 | -batchwise, bool: if True, normalize each batch element separately 275 | """ 276 | if batchwise: 277 | s = np.std(inp,axis=(1,2),keepdims=True) 278 | out = (inp - np.mean(inp,axis=0,keepdims=True))/s 279 | else: 280 | s = inp.std() 281 | if s == 0: s = 1.0  # guard against constant input so "out" is always defined 282 | out = (inp - inp.mean())/s 283 | return out 284 | 285 | 286 | def normalize_torch(inp, batchwise=True): 287 | """ 288 | Normalize the input by subtracting the mean and dividing by the std. 289 | 290 | INPUT: 291 | -inp, (B,*n): input torch tensor 292 | -batchwise, bool: if True, normalize each batch element separately 293 | """ 294 | if batchwise: 295 | s = torch.std(inp,dim=(1,2),keepdim=True) 296 | out = (inp - torch.mean(inp,dim=0,keepdim=True))/s 297 | else: 298 | s = inp.std() 299 | if s == 0: s = 1.0  # guard against constant input so "out" is always defined 300 | out = (inp - inp.mean())/s 301 | return out 302 | 303 | def SNR(x_ref,x): 304 | dif = np.sum((x_ref-x)**2) 305 | nref = np.sum(x_ref**2) 306 | res=10*np.log10((nref+1e-16)/(dif+1e-16)) 307 | return res 308 | 309 | def torch_filter(img: torch.Tensor,img_filter: torch.Tensor, padding = 0): 310 | """ 311 | Filter the image with the given filter using conv2d 312 | img: (N,N) tensor or (B,N,N) tensor (batched images) 313 | img_filter: (Nf,Nf) tensor 314 | """ 315 | if len(img.shape) == 2: 316 | return F.conv2d(img.unsqueeze(0).unsqueeze(0),img_filter.unsqueeze(0).unsqueeze(0),padding=padding).squeeze(0).squeeze(0) 317 | else: 318 | return F.conv2d(img.unsqueeze(1),img_filter.unsqueeze(0).unsqueeze(0),padding=padding).squeeze() 319 | 320 | 321 | 322 | 323 | --------------------------------------------------------------------------------
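Before the model definitions, here is a brief usage sketch for custom_ramp_fft defined above. It is illustrative only: the projection sizes are arbitrary, and building the length-2N frequency response with scikit-image's private _get_fourier_filter helper is an assumption made for this example, not something shown in the repository's scripts.

# Usage sketch (illustrative, not part of the repository); shapes follow the custom_ramp_fft docstring.
import torch
from skimage.transform.radon_transform import _get_fourier_filter
from utils.utils import custom_ramp_fft

n_projections, N = 41, 256
projections = torch.randn(n_projections, N, N)      # tilt-series stack, (n_projections, N, N)
ramp = _get_fourier_filter(2 * N, "ramp")            # (2N, 1) numpy array
t_cust = torch.from_numpy(ramp[:, 0]).float()        # (2N,) frequency response expected by custom_ramp_fft
filtered = custom_ramp_fft(projections, t_cust, use_splits=False)
print(filtered.shape)                                # torch.Size([41, 256, 256])

Setting use_splits=True filters one projection at a time, which trades speed for a smaller peak GPU memory footprint.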
/models/slicemlp.py: -------------------------------------------------------------------------------- 1 | """ 2 | slice mlp model for tomography reconstruction 3 | """ 4 | from typing import Union, List 5 | import torch 6 | import torch.nn as nn 7 | 8 | from models.standardmlp import standardMLP 9 | 10 | 11 | class sliceMLP(nn.Module): 12 | """ 13 | MLP along the slices of the projection and an mlp to combine the slices 14 | """ 15 | def __init__(self, n_projections: int, 16 | mlp_output: int, 17 | patch_size: int, 18 | mlp_layers : int, 19 | mlp_hidden : int, 20 | dropout : int = 0, 21 | batch_norm : bool = False, 22 | learn_residual : bool = False, 23 | skip_connection : bool = False, 24 | slice_index: int = 2, 25 | compare_index: int = 3): 26 | super(sliceMLP, self).__init__() 27 | #TODO: Add skip connection 28 | 29 | self.n_projections = n_projections 30 | self.mlp_output = mlp_output 31 | self.patch_size = patch_size 32 | self.dropout = dropout 33 | self.mlp_layers = mlp_layers 34 | self.mlp_hidden = mlp_hidden 35 | self.slice_index = slice_index 36 | self.compare_index = compare_index 37 | self.batch_norm = batch_norm 38 | self.learn_residual = learn_residual 39 | self.skip_connection = skip_connection 40 | 41 | self.slice_mlp = standardMLP(input_size = self.patch_size*self.n_projections, 42 | output_size = self.mlp_output, 43 | mlp_hidden = self.mlp_hidden, 44 | mlp_layers = self.mlp_layers, 45 | batch_norm = self.batch_norm, 46 | dropout = self.dropout, 47 | learn_residual = self.learn_residual, 48 | skip_connection = self.skip_connection) 49 | 50 | self.combination_mlp = standardMLP(input_size = self.patch_size*self.mlp_output, 51 | output_size = 1, 52 | mlp_hidden = self.mlp_hidden, 53 | mlp_layers = self.mlp_layers, 54 | batch_norm = self.batch_norm, 55 | dropout = self.dropout, 56 | learn_residual=self.learn_residual, 57 | skip_connection = self.skip_connection) 58 | 59 | 60 | def forward(self, x): 61 | """ 62 | Forward pass of the network 63 | """ 64 | if self.learn_residual: 65 | mid_pix = x[:,:,self.patch_size//2,self.patch_size//2] 66 | mid_pix_sum = torch.mean(mid_pix,1) 67 | x = x.permute(0,self.slice_index,1,self.compare_index).contiguous() 68 | x = x.reshape(-1,self.patch_size*self.n_projections) 69 | x = self.slice_mlp(x).reshape(-1,self.patch_size,self.mlp_output).reshape(-1,self.patch_size*self.mlp_output) 70 | 71 | x = self.combination_mlp(x) 72 | if self.learn_residual: 73 | x = x + mid_pix_sum.unsqueeze(1) 74 | return x 75 | 76 | 77 | class sliceMlp_v2(nn.Module): 78 | """ 79 | MLP along the slices of the projection and an mlp to combine the slices with a different architecture for 80 | encoding and combining 81 | """ 82 | def __init__(self, n_projections: int, 83 | mlp_output: int, 84 | patch_size: int, 85 | slice_index: int = 2, 86 | compare_index: int = 3, 87 | learn_residual: bool = False, 88 | slice_mlp_layers : int = 5, 89 | slice_mlp_hidden : int = 256, 90 | slice_dropout : int = 0, 91 | slice_batch_norm : bool = False, 92 | slice_learn_residual : bool = False, 93 | slice_skip_connection : bool = False, 94 | combine_mlp_layers : int = 5, 95 | combine_mlp_hidden : int = 256, 96 | combine_dropout : int = 0, 97 | combine_batch_norm : bool = False, 98 | combine_learn_residual : bool = False, 99 | combine_skip_connection : bool = False, 100 | output_size : int = 1): 101 | super(sliceMlp_v2, self).__init__() 102 | 103 | self.n_projections = n_projections 104 | self.mlp_output = mlp_output 105 | self.patch_size = patch_size 106 | self.slice_index = slice_index 107 | self.learn_residual = learn_residual 108 | self.compare_index = compare_index 109 | 110 | self.slice_mlp = standardMLP(input_size = self.patch_size*self.n_projections, 111 | output_size = self.mlp_output, 112 | mlp_hidden = slice_mlp_hidden, 113 | mlp_layers = slice_mlp_layers, 114 | batch_norm = slice_batch_norm, 115 | dropout = slice_dropout, 116 | learn_residual = slice_learn_residual, 117 | skip_connection = slice_skip_connection) 118 | 119 | self.combination_mlp = standardMLP(input_size = self.patch_size*self.mlp_output, 120 | output_size = output_size, 121 | mlp_hidden = combine_mlp_hidden, 122 | mlp_layers = combine_mlp_layers, 123 | batch_norm = combine_batch_norm, 124 | dropout = combine_dropout, 125 | learn_residual=combine_learn_residual, 126 | skip_connection = combine_skip_connection) 127 | 128 | 129 | def 
forward(self, x): 130 | """ 131 | Forward pass of the network 132 | """ 133 | if self.learn_residual: 134 | mid_pix = x[:,:,self.patch_size//2,self.patch_size//2] 135 | mid_pix_sum = torch.mean(mid_pix,1) 136 | 137 | 138 | x = x.permute(0,self.slice_index,1,self.compare_index).contiguous() 139 | x = x.reshape(-1,self.patch_size*self.n_projections) 140 | x = self.slice_mlp(x).reshape(-1,self.patch_size,self.mlp_output).reshape(-1,self.patch_size*self.mlp_output) 141 | 142 | x = self.combination_mlp(x) 143 | if self.learn_residual: 144 | x = x + mid_pix_sum.unsqueeze(1) 145 | return x 146 | 147 | class sliceMlp_mulitnet(nn.Module): 148 | """ 149 | Separate mlp for each slice of the patch and then combine the output of the mlp with another mlp 150 | """ 151 | def __init__(self, n_projections: int, 152 | mlp_output: int, 153 | patch_size: int, 154 | slice_index: int = 2, 155 | compare_index: int = 3, 156 | learn_residual: bool = False, 157 | slice_mlp_layers : int = 5, 158 | slice_mlp_hidden : int = 256, 159 | slice_dropout : int = 0, 160 | slice_batch_norm : bool = False, 161 | slice_learn_residual : bool = False, 162 | slice_skip_connection : bool = False, 163 | slice_bias : bool = True, 164 | combine_mlp_layers : int = 5, 165 | combine_mlp_hidden : int = 256, 166 | combine_dropout : int = 0, 167 | combine_batch_norm : bool = False, 168 | combine_learn_residual : bool = False, 169 | combine_skip_connection : bool = False, 170 | combine_bias : bool = True, 171 | output_size : int = 1): 172 | super(sliceMlp_mulitnet, self).__init__() 173 | 174 | self.n_projections = n_projections 175 | self.mlp_output = mlp_output 176 | self.patch_size = patch_size 177 | self.slice_index = slice_index 178 | self.learn_residual = learn_residual 179 | self.compare_index = compare_index 180 | self.slice_mlps = nn.ModuleList() 181 | 182 | for i in range(patch_size): 183 | self.slice_mlps.append(standardMLP(input_size = self.patch_size*self.n_projections, 184 | output_size = self.mlp_output, 185 | mlp_hidden = slice_mlp_hidden, 186 | mlp_layers = slice_mlp_layers, 187 | batch_norm = slice_batch_norm, 188 | dropout = slice_dropout, 189 | learn_residual = slice_learn_residual, 190 | skip_connection = slice_skip_connection, 191 | bias = slice_bias)) 192 | 193 | self.combination_mlp = standardMLP(input_size = self.patch_size*self.mlp_output, 194 | output_size = output_size, 195 | mlp_hidden = combine_mlp_hidden, 196 | mlp_layers = combine_mlp_layers, 197 | batch_norm = combine_batch_norm, 198 | dropout = combine_dropout, 199 | learn_residual=combine_learn_residual, 200 | skip_connection = combine_skip_connection, 201 | bias = combine_bias) 202 | 203 | 204 | def forward(self, x): 205 | """ 206 | Forward pass of the network 207 | """ 208 | if self.learn_residual: 209 | mid_pix = x[:,:,self.patch_size//2,self.patch_size//2] 210 | mid_pix_sum = torch.mean(mid_pix,1) 211 | x = x.permute(0,self.slice_index,1,self.compare_index).contiguous() 212 | x = x.reshape(-1,self.patch_size,self.patch_size*self.n_projections) 213 | x = [self.slice_mlps[i](x[:,i]) for i in range(self.patch_size)] 214 | x = torch.cat(x,1) 215 | x = self.combination_mlp(x) 216 | if self.learn_residual: 217 | x = x + mid_pix_sum.unsqueeze(1) 218 | #print(x.shape) 219 | return x 220 | 221 | 222 | class sliceMlp_multinet_multiproj(nn.Module): 223 | """ 224 | Separate mlp for each slice of the patch and then combine the output of the mlp with another mlp 225 | """ 226 | def __init__(self, n_projections: int, 227 | mlp_output: int, 228 | patch_size: int, 229 | 
slice_index: int = 2, 230 | compare_index: int = 3, 231 | n_series: int = 2, 232 | learn_residual: bool = False, 233 | slice_mlp_layers : int = 5, 234 | slice_mlp_hidden : int = 256, 235 | slice_dropout : int = 0, 236 | slice_batch_norm : bool = False, 237 | slice_learn_residual : bool = False, 238 | slice_skip_connection : bool = False, 239 | combine_mlp_layers : int = 5, 240 | combine_mlp_hidden : int = 256, 241 | combine_dropout : int = 0, 242 | combine_batch_norm : bool = False, 243 | combine_learn_residual : bool = False, 244 | combine_skip_connection : bool = False, 245 | output_size : int = 1): 246 | super(sliceMlp_multinet_multiproj, self).__init__() 247 | 248 | self.n_projections = n_projections 249 | self.mlp_output = mlp_output 250 | self.patch_size = patch_size 251 | self.slice_index = slice_index 252 | self.learn_residual = learn_residual 253 | self.compare_index = compare_index 254 | self.n_series = n_series 255 | self.slice_mlps = nn.ModuleList() 256 | 257 | for i in range(patch_size*n_series): 258 | self.slice_mlps.append(standardMLP(input_size = self.patch_size*self.n_projections, 259 | output_size = self.mlp_output, 260 | mlp_hidden = slice_mlp_hidden, 261 | mlp_layers = slice_mlp_layers, 262 | batch_norm = slice_batch_norm, 263 | dropout = slice_dropout, 264 | learn_residual = slice_learn_residual, 265 | skip_connection = slice_skip_connection)) 266 | 267 | self.combination_mlp = standardMLP(input_size = self.patch_size*self.mlp_output*self.n_series, 268 | output_size = output_size, 269 | mlp_hidden = combine_mlp_hidden, 270 | mlp_layers = combine_mlp_layers, 271 | batch_norm = combine_batch_norm, 272 | dropout = combine_dropout, 273 | learn_residual=combine_learn_residual, 274 | skip_connection = combine_skip_connection) 275 | 276 | 277 | def forward(self, x): 278 | """ 279 | Forward pass of the network 280 | """ 281 | if self.learn_residual: 282 | mid_pix = x[:,:self.n_projections,self.patch_size//2,self.patch_size//2] 283 | mid_pix_sum = torch.mean(mid_pix,1) 284 | 285 | 286 | x = x.permute(0,self.slice_index,1,self.compare_index).contiguous() 287 | x_op = [] 288 | for i in range(self.n_series): 289 | x_sub = x[:,:,i*(self.n_projections):(i+1)*(self.n_projections)] 290 | x_sub = x_sub.reshape(-1,self.patch_size,self.patch_size*self.n_projections) 291 | x_sub = [self.slice_mlps[i*self.patch_size + j](x_sub[:,j]) for j in range(self.patch_size)]  # one dedicated slice MLP per (series, slice) pair, matching the ModuleList built in __init__ 292 | x_sub = torch.cat(x_sub,1) 293 | x_op.append(x_sub) 294 | x_op = torch.cat(x_op,1) 295 | x_op = self.combination_mlp(x_op) 296 | if self.learn_residual: 297 | x_op = x_op + mid_pix_sum.unsqueeze(1) 298 | #print(x.shape) 299 | return x_op 300 | 301 | 302 | 303 | 304 | --------------------------------------------------------------------------------
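To make the expected tensor layout of these slice-MLP models concrete, the sketch below instantiates sliceMlp_v2 on random patches. The hyperparameters (41 projections, 9x9 patches, hidden sizes) are illustrative assumptions, not the settings of the released trained models; the model consumes patches of shape (batch, n_projections, patch_size, patch_size) and returns a tensor of shape (batch, output_size).

# Usage sketch (illustrative assumptions, not the released models' configuration).
import torch
from models.slicemlp import sliceMlp_v2

model = sliceMlp_v2(n_projections=41,       # projections in the tilt series
                    mlp_output=32,          # per-slice embedding size
                    patch_size=9,           # 9x9 patch extracted around each query point
                    learn_residual=True,    # add the mean center pixel back to the output
                    slice_mlp_layers=3, slice_mlp_hidden=128,
                    combine_mlp_layers=3, combine_mlp_hidden=128,
                    output_size=1)

patches = torch.randn(256, 41, 9, 9)        # (batch, n_projections, patch_size, patch_size)
values = model(patches)                     # (batch, 1): one regressed value (e.g. a voxel density) per query point
print(values.shape)

The slice MLP sees each of the patch_size slices as a flattened (patch_size*n_projections) vector, and the combination MLP merges the patch_size slice embeddings into the final prediction; the multinet variants follow the same layout but assign a separate slice MLP to each slice (and, for the multiproj variant, to each tilt series).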