├── DeepCAD_pytorch ├── results │ └── results.txt ├── pth │ ├── ModelForPytorch │ │ └── DownloadedModel │ └── README.md ├── datasets │ ├── DataForPytorch │ │ └── DownloadedData │ └── README.md ├── network.py ├── script.py ├── utils.py ├── train.py ├── test.py ├── buildingblocks.py ├── model_3DUnet.py └── data_process.py ├── DeepCAD_Fiji ├── DeepCAD_java │ ├── DeepCAD_java.txt │ └── README.md ├── DeepCAD_tensorflow │ ├── datasets │ │ └── datasets.txt │ ├── results │ │ └── results.txt │ ├── DeepCAD_model │ │ └── DeepCAD_model.txt │ ├── script.py │ ├── README.md │ ├── utils.py │ ├── basic_ops.py │ ├── network.py │ ├── test_pb.py │ ├── main.py │ └── data_process.py ├── DeepCAD_Fiji_plugin │ ├── DeepCAD-0.3.0.jar │ ├── DeepCAD-0.3.6.jar │ └── README.md └── README.md ├── images ├── fiji.png ├── logo.PNG ├── soma.png ├── dendrite.png ├── parameter.png ├── schematic.png └── cross-system.png ├── dataset └── README.md ├── README.md └── LICENSE /DeepCAD_pytorch/results/results.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /DeepCAD_Fiji/DeepCAD_java/DeepCAD_java.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /DeepCAD_Fiji/DeepCAD_tensorflow/datasets/datasets.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /DeepCAD_Fiji/DeepCAD_tensorflow/results/results.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /DeepCAD_pytorch/pth/ModelForPytorch/DownloadedModel: -------------------------------------------------------------------------------- 1 | 2 | 
-------------------------------------------------------------------------------- /DeepCAD_pytorch/datasets/DataForPytorch/DownloadedData: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /DeepCAD_Fiji/DeepCAD_tensorflow/DeepCAD_model/DeepCAD_model.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /DeepCAD_Fiji/DeepCAD_java/README.md: -------------------------------------------------------------------------------- 1 | # Java source code of DeepCAD Fiji plugin 2 | -------------------------------------------------------------------------------- /images/fiji.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabooster/DeepCAD/HEAD/images/fiji.png -------------------------------------------------------------------------------- /images/logo.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabooster/DeepCAD/HEAD/images/logo.PNG -------------------------------------------------------------------------------- /images/soma.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabooster/DeepCAD/HEAD/images/soma.png -------------------------------------------------------------------------------- /images/dendrite.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabooster/DeepCAD/HEAD/images/dendrite.png -------------------------------------------------------------------------------- /images/parameter.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabooster/DeepCAD/HEAD/images/parameter.png 
-------------------------------------------------------------------------------- /images/schematic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabooster/DeepCAD/HEAD/images/schematic.png -------------------------------------------------------------------------------- /images/cross-system.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabooster/DeepCAD/HEAD/images/cross-system.png -------------------------------------------------------------------------------- /DeepCAD_Fiji/DeepCAD_Fiji_plugin/DeepCAD-0.3.0.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabooster/DeepCAD/HEAD/DeepCAD_Fiji/DeepCAD_Fiji_plugin/DeepCAD-0.3.0.jar -------------------------------------------------------------------------------- /DeepCAD_Fiji/DeepCAD_Fiji_plugin/DeepCAD-0.3.6.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabooster/DeepCAD/HEAD/DeepCAD_Fiji/DeepCAD_Fiji_plugin/DeepCAD-0.3.6.jar -------------------------------------------------------------------------------- /DeepCAD_pytorch/datasets/README.md: -------------------------------------------------------------------------------- 1 | # Data download guide 2 | Download the demo data(.tif file) [[DataForPytorch](https://drive.google.com/drive/folders/1w9v1SrEkmvZal5LH79HloHhz6VXSPfI_)] and put it into the *./DataForPytorch* folder. 
3 | -------------------------------------------------------------------------------- /DeepCAD_pytorch/pth/README.md: -------------------------------------------------------------------------------- 1 | # Model download guide 2 | Download the pre-trained model(.pth file and .yaml file) [[ModelForPytorch](https://drive.google.com/drive/folders/12LEFsAopTolaRyRpJtFpzOYH3tBZMGUP)] and put it into the *./ModelForPytorch* folder. 3 | -------------------------------------------------------------------------------- /DeepCAD_Fiji/DeepCAD_Fiji_plugin/README.md: -------------------------------------------------------------------------------- 1 | ## Packaged plugin file (.jar) for Fiji 2 | ### Please download and install the latest version. 3 | 4 | ________________________________________________________________________________________________________________________ 5 | ### Logs 6 | - 0.3.6 - 2020-11-17 7 | 8 | Fixed an error for language display. 9 | 10 | - 0.3.0 - 2020-11-17 11 | 12 | Initial relese. 
13 | 14 | 15 | -------------------------------------------------------------------------------- /DeepCAD_Fiji/DeepCAD_tensorflow/script.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | # for train 5 | os.system('python main2.py --GPU 0 --img_h 64 --img_w 64 --img_s 320 --train_epochs 30 --datasets_folder DataForPytorch --normalize_factor 1 --lr 0.00005 --train_datasets_size 10000') 6 | 7 | # for test 8 | os.system('python test_pb2.py --GPU 3 --denoise_model pb_unet3d_10AMP_0.3_0001_20201108-2139 \ 9 | --datasets_folder 10AMP_0.3_0001 --model_name 25_1000') 10 | 11 | -------------------------------------------------------------------------------- /DeepCAD_pytorch/network.py: -------------------------------------------------------------------------------- 1 | from model_3DUnet import UNet3D 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch 5 | 6 | class Network_3D_Unet(nn.Module): 7 | def __init__(self, UNet_type = '3DUNet', in_channels=1, out_channels=1, final_sigmoid = True): 8 | super(Network_3D_Unet, self).__init__() 9 | 10 | self.in_channels = in_channels 11 | self.out_channels = out_channels 12 | self.final_sigmoid = final_sigmoid 13 | 14 | if UNet_type == '3DUNet': 15 | self.Generator = UNet3D( in_channels = in_channels, 16 | out_channels = out_channels, 17 | final_sigmoid = final_sigmoid) 18 | 19 | def forward(self, x): 20 | fake_x = self.Generator(x) 21 | return fake_x 22 | -------------------------------------------------------------------------------- /DeepCAD_pytorch/script.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import sys 4 | 5 | flag = sys.argv[1] 6 | 7 | if flag == 'train': 8 | # for train 9 | os.system('python train.py --datasets_folder DataForPytorch --lr 0.00005 \ 10 | --img_h 150 --img_w 150 --img_s 150 --gap_h 60 --gap_w 60 --gap_s 60 \ 11 | --n_epochs 20 --GPU 0 
--normalize_factor 1 \ 12 | --train_datasets_size 1200 --select_img_num 10000') 13 | 14 | if flag == 'test': 15 | # for test 16 | os.system('python test.py --denoise_model ModelForPytorch \ 17 | --datasets_folder DataForPytorch \ 18 | --test_datasize 6000') 19 | 20 | if flag == 'all': 21 | # train and then test 22 | os.system('python train.py --datasets_folder DataForPytorch --lr 0.00005 \ 23 | --img_h 150 --img_w 150 --img_s 150 --gap_h 60 --gap_w 60 --gap_s 60 \ 24 | --n_epochs 20 --GPU 0 --normalize_factor 1 \ 25 | --train_datasets_size 1200 --select_img_num 10000') 26 | 27 | os.system('python test.py --denoise_model ModelForPytorch \ 28 | --datasets_folder DataForPytorch \ 29 | --test_datasize 6000') 30 | -------------------------------------------------------------------------------- /DeepCAD_Fiji/DeepCAD_tensorflow/README.md: -------------------------------------------------------------------------------- 1 | ## Tensorflow implementation compatible with Fiji plugin 2 | 3 | ### Directory structure 4 | ``` 5 | DeepCAD_tensorflow #Tensorflow implementation compatible with Fiji plugin# 6 | |---basic_ops.py 7 | |---train.py 8 | |---network.py 9 | |---script.py 10 | |---test_pb.py 11 | |---data_process.py 12 | |---datasets 13 | |---|---qwd_7 #project_name# 14 | |---|---|---train_raw.tif #raw data for train# 15 | |---DeepCAD_model 16 | |---|---qwd_7_20201115-0913 17 | |---|---|--- #pb model# 18 | |---results 19 | |---|---#Results of training process and final test# 20 | ``` 21 | 22 | ### Environment 23 | 24 | * Ubuntu 16.04 25 | * Python 3.6 26 | * Tensorflow 1.4.0 27 | * NVIDIA GPU + CUDA 28 | 29 | ### Environment configuration 30 | 31 | Open the terminal of ubuntu system. 
32 | 33 | * Create an Anaconda environment 34 | 35 | ``` 36 | $ conda create -n tensorflow python=3.6 37 | ``` 38 | 39 | * Install Tensorflow 40 | 41 | ``` 42 | $ source activate tensorflow 43 | $ pip install tensorflow-gpu==1.4.0 44 | ``` 45 | 46 | ### Training 47 | 48 | ``` 49 | $ source activate tensorflow 50 | $ python main.py --GPU 0 --img_h 64 --img_w 64 --img_s 320 --train_epochs 30 --datasets_folder DataForPytorch --normalize_factor 1 --lr 0.00005 --train_datasets_size 1000 51 | ``` 52 | 53 | Parameters can be modified as required. 54 | 55 | ``` 56 | $ python main.py --GPU #GPU index# --img_h #stack height# --img_w #stack width# --img_s #stack length# --train_epochs #training epoch number# --datasets_folder #project name# 57 | ``` 58 | 59 | The trained model is saved at *DeepCAD_Fiji/DeepCAD_tensorflow/DeepCAD_model/*. 60 | 61 | ### Test 62 | 63 | Run script.py (test part) to begin testing. Parameters saved in the .yaml file will be loaded automatically. 64 | 65 | ``` 66 | $ source activate tensorflow 67 | $ python test_pb.py --GPU 3 --denoise_model ModelForTestPlugin --datasets_folder DataForPytorch --model_name 25_1000 --test_datasize 500 68 | ``` 69 | 70 | Parameters can be modified as required. 71 | 72 | ``` 73 | $ python test_pb.py --denoise_model #model name# --test_datasize #the number of images used for testing# 74 | ``` 75 | -------------------------------------------------------------------------------- /DeepCAD_Fiji/README.md: -------------------------------------------------------------------------------- 1 | # DeepCAD Fiji plugin 2 | 3 | To ease the use of our deep self-supervised learning-based method, we developed a user-friendly Fiji plugin that is easy to install and convenient to use (it has been tested on a Windows desktop with an Intel i9 CPU and 128 GB of RAM). Researchers without expertise in computer science or machine learning can learn to use it in a very short time.
4 | 5 | 6 | 7 | ### Install the Fiji plugin 8 | To avoid unnecessary trouble, the following installation steps are recommended: 9 | 1. Download and install Fiji from the [[Fiji download page](https://imagej.net/Fiji/Downloads)]. Install the CSBDeep dependency following the steps at [[CSBDeep in Fiji – Installation](https://github.com/CSBDeep/CSBDeep_website/wiki/CSBDeep-in-Fiji-%E2%80%93-Installation)]. 10 | 2. Download the packaged plugin file (.jar) from [[DeepCAD_Fiji/DeepCAD_Fiji_plugin](https://github.com/cabooster/DeepCAD/tree/master/DeepCAD_Fiji/DeepCAD_Fiji_plugin)]. 11 | 3. Install the plugin via **Fiji > Plugins > Install**. 12 | 13 | We provide lightweight data for testing. Please download the demo data (.tif) [[DataForTestPlugin](https://drive.google.com/drive/folders/1JVbuCwIxRKr4_NNOD7fY61NnVeCA2UnP)] and the pre-trained model (.zip) [[ModelForTestPlugin](https://drive.google.com/drive/folders/14wSuMFhWKxW5Oq93GHxTsGixpB3T4lOL)]. 14 | 15 | ### Use the Fiji plugin 16 | 17 | 1. Open Fiji. 18 | 2. Open the calcium imaging stack to be denoised. 19 | 3. Open the plugin at **Plugins > DeepCAD**. The six parameters will be shown on the panel, pre-filled with default values; no changes are required unless necessary. 20 | 4. Specify the pre-trained model using the '*Browse*' button (select the .zip file). 21 | 5. Click 'OK'. After processing, the denoised result will be displayed in a new window (processing time depends on the data size). 22 | 23 | 24 | 25 | ### Train a customized model for your microscope 26 | 27 | Since imaging systems and experimental conditions vary, a customized DeepCAD model trained on your own data is recommended for optimal performance. A Tensorflow implementation of DeepCAD compatible with the plugin is publicly available at *[DeepCAD_Fiji/DeepCAD_tensorflow](https://github.com/cabooster/DeepCAD/tree/master/DeepCAD_Fiji/DeepCAD_tensorflow)*.
28 | -------------------------------------------------------------------------------- /DeepCAD_Fiji/DeepCAD_tensorflow/utils.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import logging 3 | import os 4 | import shutil 5 | import sys 6 | import io 7 | 8 | import h5py 9 | from PIL import Image 10 | import numpy as np 11 | import scipy.sparse as sparse 12 | import matplotlib.pyplot as plt 13 | import uuid 14 | import warnings 15 | import pylab 16 | import cv2 17 | import yaml 18 | 19 | def save_yaml(opt, yaml_name): 20 | para = {} 21 | para["train_epochs"] = opt.train_epochs 22 | para["datasets_folder"] = opt.datasets_folder 23 | para["GPU"] = opt.GPU 24 | para["img_s"] = opt.img_s 25 | para["img_w"] = opt.img_w 26 | para["img_h"] = opt.img_h 27 | para["gap_h"] = opt.gap_h 28 | para["gap_w"] = opt.gap_w 29 | para["gap_s"] = opt.gap_s 30 | para["lr"] = opt.lr 31 | para["normalize_factor"] = opt.normalize_factor 32 | with open(yaml_name, 'w') as f: 33 | data = yaml.dump(para, f) 34 | 35 | 36 | def read_yaml(opt, yaml_name): 37 | with open(yaml_name) as f: 38 | para = yaml.load(f, Loader=yaml.FullLoader) 39 | print(para) 40 | opt.datasets_folder = para["datasets_folder"] 41 | opt.GPU = para["GPU"] 42 | opt.img_s = para["img_s"] 43 | opt.img_w = para["img_w"] 44 | opt.img_h = para["img_h"] 45 | # opt.gap_h = para["gap_h"] 46 | # opt.gap_w = para["gap_w"] 47 | # opt.gap_s = para["gap_s"] 48 | opt.normalize_factor = para["normalize_factor"] 49 | 50 | def name2index(opt, input_name, num_h, num_w, num_s): 51 | # print(input_name) 52 | name_list = input_name.split('_') 53 | # print(name_list) 54 | z_part = name_list[-1] 55 | # print(z_part) 56 | y_part = name_list[-2] 57 | # print(y_part) 58 | x_part = name_list[-3] 59 | # print(x_part) 60 | z_index = int(z_part.replace('z','')) 61 | y_index = int(y_part.replace('y','')) 62 | x_index = int(x_part.replace('x','')) 63 | # print("x_index ---> ",x_index,"y_index ---> 
", y_index,"z_index ---> ", z_index) 64 | 65 | cut_w = (opt.img_w - opt.gap_w)/2 66 | cut_h = (opt.img_h - opt.gap_h)/2 67 | cut_s = (opt.img_s - opt.gap_s)/2 68 | # print("z_index ---> ",cut_w, "cut_h ---> ",cut_h, "cut_s ---> ",cut_s) 69 | if x_index == 0: 70 | stack_start_w = x_index*opt.gap_w 71 | stack_end_w = x_index*opt.gap_w+opt.img_w-cut_w 72 | patch_start_w = 0 73 | patch_end_w = opt.img_w-cut_w 74 | elif x_index == num_w-1: 75 | stack_start_w = x_index*opt.gap_w+cut_w 76 | stack_end_w = x_index*opt.gap_w+opt.img_w 77 | patch_start_w = cut_w 78 | patch_end_w = opt.img_w 79 | else: 80 | stack_start_w = x_index*opt.gap_w+cut_w 81 | stack_end_w = x_index*opt.gap_w+opt.img_w-cut_w 82 | patch_start_w = cut_w 83 | patch_end_w = opt.img_w-cut_w 84 | 85 | if y_index == 0: 86 | stack_start_h = y_index*opt.gap_h 87 | stack_end_h = y_index*opt.gap_h+opt.img_h-cut_h 88 | patch_start_h = 0 89 | patch_end_h = opt.img_h-cut_h 90 | elif y_index == num_h-1: 91 | stack_start_h = y_index*opt.gap_h+cut_h 92 | stack_end_h = y_index*opt.gap_h+opt.img_h 93 | patch_start_h = cut_h 94 | patch_end_h = opt.img_h 95 | else: 96 | stack_start_h = y_index*opt.gap_h+cut_h 97 | stack_end_h = y_index*opt.gap_h+opt.img_h-cut_h 98 | patch_start_h = cut_h 99 | patch_end_h = opt.img_h-cut_h 100 | 101 | if z_index == 0: 102 | stack_start_s = z_index*opt.gap_s 103 | stack_end_s = z_index*opt.gap_s+opt.img_s-cut_s 104 | patch_start_s = 0 105 | patch_end_s = opt.img_s-cut_s 106 | elif z_index == num_s-1: 107 | stack_start_s = z_index*opt.gap_s+cut_s 108 | stack_end_s = z_index*opt.gap_s+opt.img_s 109 | patch_start_s = cut_s 110 | patch_end_s = opt.img_s 111 | else: 112 | stack_start_s = z_index*opt.gap_s+cut_s 113 | stack_end_s = z_index*opt.gap_s+opt.img_s-cut_s 114 | patch_start_s = cut_s 115 | patch_end_s = opt.img_s-cut_s 116 | return int(stack_start_w) ,int(stack_end_w) ,int(patch_start_w) ,int(patch_end_w) ,\ 117 | int(stack_start_h) ,int(stack_end_h) ,int(patch_start_h) ,int(patch_end_h), 
\ 118 | int(stack_start_s) ,int(stack_end_s) ,int(patch_start_s) ,int(patch_end_s) 119 | -------------------------------------------------------------------------------- /DeepCAD_Fiji/DeepCAD_tensorflow/basic_ops.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from keras.layers import UpSampling3D 3 | import numpy as np 4 | """This script defines basic operations. 5 | """ 6 | def Pool3d(inputs, name): 7 | layer = tf.layers.max_pooling3d(inputs=inputs, 8 | pool_size=2, 9 | strides=2, 10 | name=name, 11 | padding='same') 12 | print(name,'.get_shape -----> ',str(layer.get_shape())) 13 | return layer 14 | 15 | 16 | def Deconv3D(inputs, filters, name): 17 | layer = tf.layers.conv3d_transpose(inputs=inputs, 18 | filters=filters, 19 | kernel_size=2, 20 | strides=2, 21 | padding='same', 22 | use_bias=True, 23 | name=name, 24 | kernel_initializer=tf.truncated_normal_initializer(stddev=0.05, seed=1)) 25 | print(name,'.get_shape -----> ',str(layer.get_shape())) 26 | return layer 27 | 28 | 29 | def Deconv3D_upsample(inputs, filters, name): 30 | print(name,'.inputs -----> ',str(inputs.get_shape())) 31 | kernel = tf.constant(1.0, shape=[2,2,2,filters,filters]) 32 | inputs_shape = inputs.get_shape().as_list() 33 | layer = tf.nn.conv3d_transpose(value=inputs, filter=kernel, 34 | output_shape=[inputs_shape[0], inputs_shape[1]*2, inputs_shape[2]*2, inputs_shape[3]*2, filters], 35 | strides=[1, 2, 2, 2, 1], 36 | padding="SAME") 37 | print(name,'.get_shape -----> ',str(layer.get_shape())) 38 | return layer 39 | 40 | 41 | def Conv3D(inputs, filters, name): 42 | inputs_shape = inputs.get_shape().as_list() 43 | fan_in =3*3*inputs_shape[-1]*filters 44 | std = np.sqrt(2) / np.sqrt(fan_in) 45 | layer = tf.layers.conv3d(inputs=inputs, 46 | filters=filters, 47 | kernel_size=3, 48 | strides=1, 49 | padding='same', 50 | use_bias=True, 51 | name=name, 52 | kernel_initializer=tf.truncated_normal_initializer(stddev=std)) 53 | 
print(name,'.get_shape -----> ',str(layer.get_shape()),' std -----> ',std) 54 | return layer 55 | 56 | def ReLU(inputs, name): 57 | layer = tf.nn.relu(inputs, name) 58 | print(name,'.get_shape -----> ',str(layer.get_shape())) 59 | return layer 60 | 61 | def leak_ReLU(inputs, name): 62 | layer = tf.nn.leaky_relu(inputs, alpha=0.1, name=name) #tf.nn.relu(inputs, name) 63 | print(name,'.get_shape -----> ',str(layer.get_shape())) 64 | return layer 65 | 66 | def Deconv3D_keras(inputs): 67 | layer = UpSampling3D(size=2)(inputs) 68 | print('.get_shape -----> ',str(layer.get_shape())) 69 | return layer 70 | 71 | def up_conv3d(input, conv_filter_size, num_input_channels, num_filters, feature_map_size, feature_map_len, train=True, padding='SAME',relu=True): 72 | # num_input_channels 73 | # num_filters 74 | # feature_map_size 75 | # feature_map_len 76 | weights = create_weights(shape=[conv_filter_size, conv_filter_size, conv_filter_size, num_filters, num_input_channels]) 77 | biases = create_biases(num_filters) 78 | if train: 79 | batch_size_0 = 1 #batch_size 80 | else: 81 | batch_size_0 = 1 82 | layer = tf.nn.conv3d_transpose(value=input, filter=weights, 83 | output_shape=[batch_size_0, feature_map_size, feature_map_size, feature_map_len, num_filters], 84 | strides=[1, 2, 2, 2, 1], 85 | padding=padding) 86 | layer += biases 87 | if relu: 88 | layer = tf.nn.relu(layer) 89 | return layer 90 | 91 | 92 | 93 | def BN_ReLU(inputs, training, name): 94 | """Performs a batch normalization followed by a ReLU6.""" 95 | inputs = tf.layers.batch_normalization(inputs=inputs, 96 | axis=-1, 97 | momentum=0.997, 98 | epsilon=1e-5, 99 | center=True, 100 | scale=True, 101 | training=training, 102 | fused=True) 103 | return tf.nn.relu(inputs) 104 | 105 | def GN_ReLU(inputs, name): 106 | """Performs a batch normalization followed by a ReLU6.""" 107 | inputs = group_norm(inputs, name=name) 108 | return tf.nn.relu(inputs) 109 | 110 | def GN_leakReLU(inputs, name): 111 | """Performs a batch 
normalization followed by a ReLU6.""" 112 | # inputs = group_norm(inputs, name=name) 113 | inputs = tf.nn.relu(inputs) #tf.nn.relu(inputs, name) 114 | layer = group_norm(inputs, name=name) 115 | return layer 116 | 117 | def group_norm(x, name, G=8, eps=1e-5, scope='group_norm') : 118 | with tf.variable_scope(scope+name, reuse=tf.AUTO_REUSE) : 119 | N, H, W, S, C = x.get_shape().as_list() 120 | G = min(G, C) 121 | # [N, H, W, G, C // G] 122 | x = tf.reshape(x, [N, H, W, S, G, C // G]) 123 | mean, var = tf.nn.moments(x, [1, 2, 3, 5], keep_dims=True) 124 | x = (x - mean) / tf.sqrt(var + eps) 125 | 126 | # print(' -----> GroupNorm ',x.get_shape()) 127 | gamma = tf.get_variable('gamma_'+name, [1, 1, 1, 1, C], initializer=tf.constant_initializer(1)) 128 | beta = tf.get_variable('beta_'+name, [1, 1, 1, 1, C], initializer=tf.constant_initializer(0)) 129 | x = tf.reshape(x, [N, H, W, S, C]) * gamma + beta 130 | return x -------------------------------------------------------------------------------- /DeepCAD_pytorch/utils.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import logging 3 | import os 4 | import shutil 5 | import sys 6 | import io 7 | 8 | import h5py 9 | from PIL import Image 10 | import numpy as np 11 | import scipy.sparse as sparse 12 | import torch 13 | import matplotlib.pyplot as plt 14 | import uuid 15 | from sklearn.decomposition import PCA 16 | import warnings 17 | import pylab 18 | import cv2 19 | import yaml 20 | ######################################################################################################################## 21 | plt.ioff() 22 | plt.switch_backend('agg') 23 | 24 | ######################################################################################################################## 25 | def create_feature_maps(init_channel_number, number_of_fmaps): 26 | return [init_channel_number * 2 ** k for k in range(number_of_fmaps)] 27 | 28 | def save_yaml(opt, yaml_name): 29 | 
para = {'epoch':0, 30 | 'n_epochs':0, 31 | 'datasets_folder':0, 32 | 'GPU':0, 33 | 'output_dir':0, 34 | 'batch_size':0, 35 | 'img_s':0, 36 | 'img_w':0, 37 | 'img_h':0, 38 | 'gap_h':0, 39 | 'gap_w':0, 40 | 'gap_s':0, 41 | 'lr':0, 42 | 'b1':0, 43 | 'b2':0, 44 | 'normalize_factor':0} 45 | para["epoch"] = opt.epoch 46 | para["n_epochs"] = opt.n_epochs 47 | para["datasets_folder"] = opt.datasets_folder 48 | para["GPU"] = opt.GPU 49 | para["output_dir"] = opt.output_dir 50 | para["batch_size"] = opt.batch_size 51 | para["img_s"] = opt.img_s 52 | para["img_w"] = opt.img_w 53 | para["img_h"] = opt.img_h 54 | para["gap_h"] = opt.gap_h 55 | para["gap_w"] = opt.gap_w 56 | para["gap_s"] = opt.gap_s 57 | para["lr"] = opt.lr 58 | para["b1"] = opt.b1 59 | para["b2"] = opt.b2 60 | para["normalize_factor"] = opt.normalize_factor 61 | para["datasets_path"] = opt.datasets_path 62 | para["train_datasets_size"] = opt.train_datasets_size 63 | with open(yaml_name, 'w') as f: 64 | data = yaml.dump(para, f) 65 | 66 | 67 | def read_yaml(opt, yaml_name): 68 | with open(yaml_name) as f: 69 | para = yaml.load(f, Loader=yaml.FullLoader) 70 | print(para) 71 | opt.epoch = para["epoch"] 72 | opt.n_epochspara = ["n_epochs"] 73 | # opt.datasets_folder = para["datasets_folder"] 74 | opt.output_dir = para["output_dir"] 75 | opt.batch_size = para["batch_size"] 76 | # opt.img_s = para["img_s"] 77 | # opt.img_w = para["img_w"] 78 | # opt.img_h = para["img_h"] 79 | # opt.gap_h = para["gap_h"] 80 | # opt.gap_w = para["gap_w"] 81 | # opt.gap_s = para["gap_s"] 82 | opt.lr = para["lr"] 83 | opt.b1 = para["b1"] 84 | para["b2"] = opt.b2 85 | para["normalize_factor"] = opt.normalize_factor 86 | 87 | 88 | def name2index(opt, input_name, num_h, num_w, num_s): 89 | # print(input_name) 90 | name_list = input_name.split('_') 91 | # print(name_list) 92 | z_part = name_list[-1] 93 | # print(z_part) 94 | y_part = name_list[-2] 95 | # print(y_part) 96 | x_part = name_list[-3] 97 | # print(x_part) 98 | z_index = 
int(z_part.replace('z','')) 99 | y_index = int(y_part.replace('y','')) 100 | x_index = int(x_part.replace('x','')) 101 | # print("x_index ---> ",x_index,"y_index ---> ", y_index,"z_index ---> ", z_index) 102 | 103 | cut_w = (opt.img_w - opt.gap_w)/2 104 | cut_h = (opt.img_h - opt.gap_h)/2 105 | cut_s = (opt.img_s - opt.gap_s)/2 106 | # print("z_index ---> ",cut_w, "cut_h ---> ",cut_h, "cut_s ---> ",cut_s) 107 | if x_index == 0: 108 | stack_start_w = x_index*opt.gap_w 109 | stack_end_w = x_index*opt.gap_w+opt.img_w-cut_w 110 | patch_start_w = 0 111 | patch_end_w = opt.img_w-cut_w 112 | elif x_index == num_w-1: 113 | stack_start_w = x_index*opt.gap_w+cut_w 114 | stack_end_w = x_index*opt.gap_w+opt.img_w 115 | patch_start_w = cut_w 116 | patch_end_w = opt.img_w 117 | else: 118 | stack_start_w = x_index*opt.gap_w+cut_w 119 | stack_end_w = x_index*opt.gap_w+opt.img_w-cut_w 120 | patch_start_w = cut_w 121 | patch_end_w = opt.img_w-cut_w 122 | 123 | if y_index == 0: 124 | stack_start_h = y_index*opt.gap_h 125 | stack_end_h = y_index*opt.gap_h+opt.img_h-cut_h 126 | patch_start_h = 0 127 | patch_end_h = opt.img_h-cut_h 128 | elif y_index == num_h-1: 129 | stack_start_h = y_index*opt.gap_h+cut_h 130 | stack_end_h = y_index*opt.gap_h+opt.img_h 131 | patch_start_h = cut_h 132 | patch_end_h = opt.img_h 133 | else: 134 | stack_start_h = y_index*opt.gap_h+cut_h 135 | stack_end_h = y_index*opt.gap_h+opt.img_h-cut_h 136 | patch_start_h = cut_h 137 | patch_end_h = opt.img_h-cut_h 138 | 139 | if z_index == 0: 140 | stack_start_s = z_index*opt.gap_s 141 | stack_end_s = z_index*opt.gap_s+opt.img_s-cut_s 142 | patch_start_s = 0 143 | patch_end_s = opt.img_s-cut_s 144 | elif z_index == num_s-1: 145 | stack_start_s = z_index*opt.gap_s+cut_s 146 | stack_end_s = z_index*opt.gap_s+opt.img_s 147 | patch_start_s = cut_s 148 | patch_end_s = opt.img_s 149 | else: 150 | stack_start_s = z_index*opt.gap_s+cut_s 151 | stack_end_s = z_index*opt.gap_s+opt.img_s-cut_s 152 | patch_start_s = cut_s 153 | 
patch_end_s = opt.img_s-cut_s 154 | return int(stack_start_w) ,int(stack_end_w) ,int(patch_start_w) ,int(patch_end_w) ,\ 155 | int(stack_start_h) ,int(stack_end_h) ,int(patch_start_h) ,int(patch_end_h), \ 156 | int(stack_start_s) ,int(stack_end_s) ,int(patch_start_s) ,int(patch_end_s) 157 | 158 | -------------------------------------------------------------------------------- /dataset/README.md: -------------------------------------------------------------------------------- 1 | # Data download guide 2 | 3 | The data used for training and validation of DeepCAD are made publicly available here. These data were captured by our customized two-photon microscope with two strictly synchronized detection path. The signal intensity of the high-SNR path is 10-fold higher than that of the low-SNR path. We provided 14 groups of recordings with various imaging depths, excitation power, and cell structures. All data are listed in the table below. You can download these data directly by clicking the `hyperlinks` appended in the 'Power' column (Warning: ~4.5 GB each). 4 | 5 | |No. 
|FOV (V×H)a |Frame rate | Imaging depthb |Powerc|AMPd|Structures | 6 | |:----:| ---- |:----: | :----: | :----: |:----: | :----: | 7 | |1 |550×575 μm |30 Hz |80 μm |[66](https://zenodo.org/record/8079069/files/1_ZOOM1.3_550Vx575H_FOV_30Hz_80umdepth_0.2power_1AMP.zip?download=1)/[99](https://zenodo.org/record/8079069/files/1_ZOOM1.3_550Vx575H_FOV_30Hz_80umdepth_0.3power_1AMP.zip?download=1) mW |1 |Dendrite | 8 | |2 |550×575 μm |30 Hz |100 μm |[66](https://zenodo.org/record/8079069/files/2_ZOOM1.3_550Vx575H_FOV_30Hz_100umdepth_0.2power_1AMP.zip?download=1)/[99](https://zenodo.org/record/8079069/files/2_ZOOM1.3_550Vx575H_FOV_30Hz_100umdepth_0.3power_1AMP.zip?download=1) mW |1 |Soma and dendrite| 9 | |3 |550×575 μm |30 Hz |120 μm |[66](https://zenodo.org/record/8079069/files/3_ZOOM1.3_550Vx575H_FOV_30Hz_120umdepth_0.2power_1AMP.zip?download=1)/[99](https://zenodo.org/record/8079069/files/3_ZOOM1.3_550Vx575H_FOV_30Hz_120umdepth_0.3power_1AMP.zip?download=1) mW |1 |Soma | 10 | |4 |550×575 μm |30 Hz |140 μm |[66](https://zenodo.org/record/8079069/files/4_ZOOM1.3_550Vx575H_FOV_30Hz_140umdepth_0.2power_1AMP.zip?download=1)/[99](https://zenodo.org/record/8079069/files/4_ZOOM1.3_550Vx575H_FOV_30Hz_140umdepth_0.3power_1AMP.zip?download=1) mW |1 |Soma | 11 | |5 |550×575 μm |30 Hz |160 μm |[66](https://zenodo.org/record/8079069/files/5_ZOOM1.3_550Vx575H_FOV_30Hz_160umdepth_0.2power_1AMP.zip?download=1)/[99](https://zenodo.org/record/8079069/files/5_ZOOM1.3_550Vx575H_FOV_30Hz_160umdepth_0.3power_1AMP.zip?download=1) mW |1 |Soma | 12 | |6 |550×575 μm |30 Hz |180 μm |[99](https://zenodo.org/record/8080615/files/6_ZOOM1.3_550Vx575H_FOV_30Hz_180umdepth_0.3power_1AMP.zip?download=1)/[132](https://zenodo.org/record/8080615/files/6_ZOOM1.3_550Vx575H_FOV_30Hz_180umdepth_0.4power_1AMP.zip?download=1) mW |1 |Soma | 13 | |7 |550×575 μm |30 Hz |200 μm 
|[99](https://zenodo.org/record/8080615/files/7_ZOOM1.3_550Vx575H_FOV_30Hz_200umdepth_0.3power_10AMP.zip?download=1)/[132](https://zenodo.org/record/8080615/files/7_ZOOM1.3_550Vx575H_FOV_30Hz_200umdepth_0.4power_10AMP.zip?download=1) mW |10 |Soma | 14 | |8 |550×575 μm |30 Hz |200 μm |[99](https://zenodo.org/record/8080615/files/8_ZOOM1.3_550Vx575H_FOV_30Hz_200umdepth_0.3power_10AMP.zip?download=1)/[132](https://zenodo.org/record/8080615/files/8_ZOOM1.3_550Vx575H_FOV_30Hz_200umdepth_0.4power_10AMP.zip?download=1) mW |10 |Soma | 15 | |9 |550×575 μm |30 Hz |150 μm |[66](https://zenodo.org/record/8080615/files/9_ZOOM1.3_550Vx575H_FOV_30Hz_150umdepth_0.2power_10AMP.zip?download=1)/[99](https://zenodo.org/record/8080615/files/9_ZOOM1.3_550Vx575H_FOV_30Hz_150umdepth_0.3power_10AMP.zip?download=1) mW |10 |Soma | 16 | |10 |550×575 μm |30 Hz |170 μm |[99](https://zenodo.org/record/8080615/files/10_ZOOM1.3_550Vx575H_FOV_30Hz_170umdepth_0.3power_10AMP.zip?download=1) mW |10 |Soma | 17 | |11 |550×575 μm |30 Hz |80 μm |[66](https://zenodo.org/record/8079117/files/11_ZOOM1.3_550Vx575H_FOV_30Hz_80umdepth_dendrite_0.2power_10AMP.zip?download=1)/[99](https://zenodo.org/record/8079117/files/11_ZOOM1.3_550Vx575H_FOV_30Hz_80umdepth_dendrite_0.3power_10AMP.zip?download=1) mW |10 |Dendrite | 18 | |12 |550×575 μm |30 Hz |110 μm |[66](https://zenodo.org/record/8079117/files/12_ZOOM1.3_550Vx575H_FOV_30Hz_110umdepth_somadendrite_0.2power_10AMP.zip?download=1)/[99](https://zenodo.org/record/8079117/files/12_ZOOM1.3_550Vx575H_FOV_30Hz_110umdepth_somadendrite_0.3power_10AMP.zip?download=1) mW |10 |Soma and dendrite| 19 | |13 |550×575 μm |30 Hz |185 μm |[99](https://zenodo.org/record/8079117/files/13_ZOOM1.3_550Vx575H_FOV_30Hz_185umdepth_soma_0.3power_10AMP.zip?download=1)/[132](https://zenodo.org/record/8079117/files/13_ZOOM1.3_550Vx575H_FOV_30Hz_185umdepth_soma_0.4power_10AMP.zip?download=1) mW |10 |Soma | 20 | |14 |550×575 μm |30 Hz |210 μm 
|[99](https://zenodo.org/record/8079117/files/14_ZOOM1.3_550Vx575H_FOV_30Hz_210umdepth_soma_0.3power_10AMP.zip?download=1)/[132](https://zenodo.org/record/8079117/files/14_ZOOM1.3_550Vx575H_FOV_30Hz_210umdepth_soma_0.4power_10AMP.zip?download=1) mW |10 |Soma |
```
a. FOV: field-of-view; V: vertical length; H: horizontal length.
b. Depth: imaging depth below the pia mater.
c. Two different excitation powers were used in each experiment for data diversity.
d. AMP: the amplifier gain of the two PMTs.
```

## Citing the data
If you use this data, please cite our paper:

Li, X., Zhang, G., Wu, J. et al. Reinforcing neuron extraction and spike inference in calcium imaging using deep self-supervised denoising. Nat Methods (2021). [https://doi.org/10.1038/s41592-021-01225-0](https://www.nature.com/articles/s41592-021-01225-0)


--------------------------------------------------------------------------------
/DeepCAD_Fiji/DeepCAD_tensorflow/network.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import numpy as np


def get_weight(shape, gain=np.sqrt(2)):
    """Create a variable named 'weight' with He-style normal initialisation.

    shape: [kernel, kernel, kernel, fmaps_in, fmaps_out] for conv weights (or
           [in, out]); fan-in is the product of every dimension but the last.
    gain:  multiplier on the standard deviation (sqrt(2) is He init for ReLU).
    """
    fan_in = np.prod(shape[:-1])
    std = gain / np.sqrt(fan_in)
    return tf.get_variable('weight', shape=shape,
                           initializer=tf.initializers.random_normal(0, std))


def apply_bias(x):
    """Add a zero-initialised learned variable 'bias' along the last axis of x.

    The bias is sized from the last axis so rank-2 ([N, C]) inputs work as
    well as rank-5 NDHWC ones. Indexing x.shape[4] first, as this used to do,
    raised for rank-2 inputs and made that branch unreachable.
    """
    b = tf.get_variable('bias', shape=[x.shape[-1]], initializer=tf.initializers.zeros())
    b = tf.cast(b, x.dtype)
    if len(x.shape) == 2:
        return x + b
    # Broadcast over every leading axis ([1, 1, 1, 1, C] for rank-5 input).
    return x + tf.reshape(b, [1] * (len(x.shape) - 1) + [-1])


def conv3d(x, fmaps, kernel, gain=np.sqrt(2)):
    """Bias-free stride-1 'SAME' 3-D convolution (NDHWC) with an odd cubic kernel."""
    assert kernel >= 1 and kernel % 2 == 1
    w = get_weight([kernel, kernel, kernel, x.shape[4].value, fmaps], gain=gain)
    w = tf.cast(w, x.dtype)
    return tf.nn.conv3d(x, w, strides=[1, 1, 1, 1, 1], padding='SAME', data_format='NDHWC')


def conv3d_bias(x, fmaps, kernel, gain=np.sqrt(2)):
    """conv3d followed by apply_bias ('weight' then 'bias' in the current scope)."""
    return apply_bias(conv3d(x, fmaps, kernel, gain=gain))


def maxpool3d(x, k=2):
    """Non-overlapping k x k x k max pooling (NDHWC, 'SAME' padding)."""
    ksize = [1, k, k, k, 1]
    return tf.nn.max_pool3d(x, ksize=ksize, strides=ksize, padding='SAME', data_format='NDHWC')


def upscale3d(x, factor=2):
    """Nearest-neighbour upsampling of the three spatial axes by an integer factor."""
    assert isinstance(factor, int) and factor >= 1
    if factor == 1:
        return x
    with tf.variable_scope('Upscale3D'):
        s = x.shape
        # Insert a singleton after each spatial axis, tile it, then fold back.
        x = tf.reshape(x, [-1, s[1], 1, s[2], 1, s[3], 1, s[4]])
        x = tf.tile(x, [1, 1, factor, 1, factor, 1, factor, 1])
        x = tf.reshape(x, [-1, s[1] * factor, s[2] * factor, s[3] * factor, s[4]])
        return x


def conv_lr_bias(name, x, fmaps):
    """3x3x3 conv + bias + leaky ReLU (alpha 0.1) inside variable scope `name`."""
    with tf.variable_scope(name):
        return tf.nn.leaky_relu(conv3d_bias(x, fmaps, 3), alpha=0.1)


def conv_r_bias(name, x, fmaps):
    """3x3x3 conv + bias + ReLU inside variable scope `name`."""
    with tf.variable_scope(name):
        return tf.nn.relu(conv3d_bias(x, fmaps, 3))


def conv_r(name, x, fmaps):
    """3x3x3 bias-free conv + ReLU inside variable scope `name`."""
    with tf.variable_scope(name):
        return tf.nn.relu(conv3d(x, fmaps, 3))


def final_conv(name, x, fmaps, gain):
    """1x1x1 conv + bias inside variable scope `name`."""
    with tf.variable_scope(name):
        return apply_bias(conv3d(x, fmaps, 1, gain))


def output_block_layer(inputs):
    """Bias-free 1x1x1 convolution projecting `inputs` to a single channel.

    The weight variable is named 'end_con3d' and the result op 'output'. Keep
    these names: exported graphs may be looked up by them (not verified here).
    The input channel count is read from `inputs` instead of being fixed at 64.
    """
    in_channels = inputs.shape[-1].value
    w = tf.Variable(tf.truncated_normal([1, 1, 1, in_channels, 1], stddev=0.01), name='end_con3d')
    return tf.nn.conv3d(input=inputs,
                        filter=w,
                        strides=[1, 1, 1, 1, 1],
                        padding='SAME',
                        name='output')


def group_norm(x, name, G=8, eps=1e-5, scope='group_norm'):
    """Group normalisation of a rank-5 tensor, with learned gamma/beta per channel.

    Channels are split into min(G, C) groups. Mean and variance are taken over
    axes 1-3 and over the channels inside each group. Requires a fully defined
    static shape, and C must be divisible by the group count.
    """
    with tf.variable_scope(scope + name, reuse=tf.AUTO_REUSE):
        N, H, W, S, C = x.get_shape().as_list()
        G = min(G, C)
        if C % G != 0:
            raise ValueError('group_norm: %d channels not divisible into %d groups' % (C, G))
        # [N, H, W, S, G, C // G]
        x = tf.reshape(x, [N, H, W, S, G, C // G])
        mean, var = tf.nn.moments(x, [1, 2, 3, 5], keep_dims=True)
        x = (x - mean) / tf.sqrt(var + eps)

        gamma = tf.get_variable('gamma_' + name, [1, 1, 1, 1, C], initializer=tf.constant_initializer(1))
        beta = tf.get_variable('beta_' + name, [1, 1, 1, 1, C], initializer=tf.constant_initializer(0))
        x = tf.reshape(x, [N, H, W, S, C]) * gamma + beta
    return x


####################################################################################################################
def autoencoder(x, width=256, height=256, length=256, **_kwargs):
    """3-D U-Net: three pooling levels, group norm after every conv, 1-channel output.

    x is pinned to shape [1, height, width, length, 1]. Encoder features at
    each resolution are concatenated back in the decoder. skips[0] (the raw
    input) is never popped, so it is not used as a skip connection.
    """
    x.set_shape([1, height, width, length, 1])

    skips = [x]

    n = x
    n = conv_r('enc_conv0', n, 32)
    n = group_norm(n, 'gn_enc_conv0')
    n = conv_r('enc_conv0b', n, 64)
    n = group_norm(n, 'gn_enc_conv0b')
    skips.append(n)

    n = maxpool3d(n)
    n = conv_r('enc_conv1', n, 64)
    n = group_norm(n, 'gn_enc_conv1')
    n = conv_r('enc_conv1b', n, 128)
    n = group_norm(n, 'gn_enc_conv1b')
    skips.append(n)

    n = maxpool3d(n)
    n = conv_r('enc_conv2', n, 128)
    n = group_norm(n, 'gn_enc_conv2')
    n = conv_r('enc_conv2b', n, 256)
    n = group_norm(n, 'gn_enc_conv2b')
    skips.append(n)

    n = maxpool3d(n)
    n = conv_r('enc_conv3', n, 256)
    n = group_norm(n, 'gn_enc_conv3')
    n = conv_r('enc_conv3b', n, 512)
    n = group_norm(n, 'gn_enc_conv3b')
    # -----------------------------------------------
    n = upscale3d(n)
    n = tf.concat([n, skips.pop()], axis=-1)
    n = conv_r('dec_conv4', n, 256)
    n = group_norm(n, 'gn_dec_conv4')
    n = conv_r('dec_conv4b', n, 256)
    n = group_norm(n, 'gn_dec_conv4b')

    n = upscale3d(n)
    n = tf.concat([n, skips.pop()], axis=-1)
    n = conv_r('dec_conv3', n, 128)
    n = group_norm(n, 'gn_dec_conv3')
    n = conv_r('dec_conv3b', n, 128)
    n = group_norm(n, 'gn_dec_conv3b')

    n = upscale3d(n)
    n = tf.concat([n, skips.pop()], axis=-1)
    n = conv_r('dec_conv2', n, 64)
    n = group_norm(n, 'gn_dec_conv2')
    n = conv_r('dec_conv2b', n, 64)
    n = group_norm(n, 'gn_dec_conv2b')

    output = output_block_layer(n)
    return output

--------------------------------------------------------------------------------
/DeepCAD_Fiji/DeepCAD_tensorflow/test_pb.py:
--------------------------------------------------------------------------------
import numpy as np
from keras.models import load_model
import argparse
import os
import tifffile as tiff
import time
import datetime
import random
from skimage import io
import tensorflow as tf
import logging
import time
from data_process import test_preprocess_lessMemory
from utils import read_yaml, name2index
import math


def main(args):
    test(args)


def _intensity_scale(input_patch, output_patch):
    """Return sqrt(sum(input_patch) / sum(output_patch)), the brightness-matching
    factor applied to each denoised patch.

    Falls back to 1.0 when the ratio is undefined (zero output sum) or not
    positive, instead of writing inf/NaN into the stitched stack.
    """
    out_sum = np.sum(output_patch)
    if out_sum == 0:
        return 1.0
    ratio = np.sum(input_patch) / out_sum
    return ratio ** 0.5 if ratio > 0 else 1.0


def test(args):
    """Denoise a stack with an exported DeepCAD SavedModel and save TIFFs.

    Reads the first .yaml file in CBSDeep_model_folder/denoise_model (via
    read_yaml, which presumably updates `args` in place — confirm in utils).
    Tiles the noisy stack into patches and runs each one through the
    SavedModel in .../denoise_model/model_name (tag '3D_N2N', signature
    'my_signature'). The predictions are stitched back together, and
    output_<model_name>.tif and input_<model_name>.tif are written under
    results_folder.
    """
    tf.reset_default_graph()
    model_path = args.CBSDeep_model_folder + '//' + args.denoise_model
    # With topdown=False, os.walk yields model_path itself last.
    files = list(os.walk(model_path, topdown=False))[-1][-1]
    # Prefer the parameter YAML. Fall back to the first file, which is what
    # this script previously picked unconditionally.
    yaml_name = next((f for f in files if f.endswith('.yaml')), files[0])
    print(yaml_name)
    read_yaml(args, model_path + '//' + yaml_name)
    print('hhhhh ----->', args)

    name_list, noise_img, coordinate_list = test_preprocess_lessMemory(args)
    # noise_img is indexed (S, H, W) throughout.
    num_h = (math.floor((noise_img.shape[1] - args.img_h) / args.gap_h) + 1)
    num_w = (math.floor((noise_img.shape[2] - args.img_w) / args.gap_w) + 1)
    num_s = (math.floor((noise_img.shape[0] - args.img_s) / args.gap_s) + 1)

    TIME = args.datasets_folder + '_' + args.denoise_model
    results_path = args.results_folder + '//' + 'unet3d_' + TIME + '//'
    model_name = args.model_name
    print('model_name -----> ', model_name)
    output_path = results_path + '//' + model_name
    # Creates results_folder and results_path on the way.
    os.makedirs(output_path, exist_ok=True)

    output_graph_path = args.CBSDeep_model_folder + '//' + args.denoise_model + '//' + model_name + '//'
    print('output_graph_path -----> ', output_graph_path)

    denoise_img = np.zeros(noise_img.shape)
    input_img = np.zeros(noise_img.shape)
    with tf.Session() as sess:
        meta_graph_def = tf.saved_model.loader.load(sess, ['3D_N2N'], output_graph_path)
        signature = meta_graph_def.signature_def['my_signature']
        input_tensor = sess.graph.get_tensor_by_name(signature.inputs['input0'].name)
        output_tensor = sess.graph.get_tensor_by_name(signature.outputs['output0'].name)

        for patch_name in name_list:
            c = coordinate_list[patch_name]
            noise_patch = noise_img[c['init_s']:c['end_s'], c['init_h']:c['end_h'], c['init_w']:c['end_w']]
            # (S, H, W) -> (1, H, W, S, 1), the layout the graph was built with.
            net_input = np.expand_dims(np.expand_dims(noise_patch.transpose(1, 2, 0), 3), 0)
            net_output = sess.run(output_tensor, feed_dict={input_tensor: net_input})

            # Back to (S, H, W).
            net_input = np.squeeze(net_input).transpose(2, 0, 1)
            net_output = np.squeeze(net_output).transpose(2, 0, 1)
            (stack_start_w, stack_end_w, patch_start_w, patch_end_w,
             stack_start_h, stack_end_h, patch_start_h, patch_end_h,
             stack_start_s, stack_end_s, patch_start_s, patch_end_s) = name2index(args, patch_name, num_h, num_w, num_s)

            # Every array here is (S, H, W): H slices go on axis 1, W on axis 2.
            # The previous code put the W slices on axis 1 when writing the
            # stacks, which only lined up for square tiling.
            patch_sl = np.s_[patch_start_s:patch_end_s, patch_start_h:patch_end_h, patch_start_w:patch_end_w]
            stack_sl = np.s_[stack_start_s:stack_end_s, stack_start_h:stack_end_h, stack_start_w:stack_end_w]
            scale = _intensity_scale(net_input[patch_sl], net_output[patch_sl])
            denoise_img[stack_sl] = net_output[patch_sl] * scale
            input_img[stack_sl] = net_input[patch_sl]

    # Stretch the denoised stack to the full uint16 range (min -> 0, max -> 65535).
    output_img = denoise_img.squeeze().astype(np.float32) * args.normalize_factor
    output_img = output_img - output_img.min()
    peak = output_img.max()
    if peak > 0:
        output_img = output_img / peak * 65535
    output_img = np.clip(output_img, 0, 65535).astype('uint16')
    input_img = input_img.squeeze().astype(np.float32) * args.normalize_factor
    input_img = np.clip(input_img, 0, 65535).astype('uint16')
    result_name = output_path + '//' + 'output_' + model_name + '.tif'
    input_path = output_path + '//' + 'input_' + model_name + '.tif'
    io.imsave(result_name, output_img)
    io.imsave(input_path, input_img)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--img_h', type=int, default=64, help='the height of patch stack')
    parser.add_argument('--img_w', type=int, default=64, help='the width of patch stack')
    parser.add_argument('--img_s', type=int, default=464, help='the image number of patch stack')
    parser.add_argument('--img_c', type=int, default=1, help='the channel of image')
    parser.add_argument('--gap_h', type=int, default=56, help='the height of patch gap')
    parser.add_argument('--gap_w', type=int, default=56, help='the width of patch gap')
    parser.add_argument('--gap_s', type=int, default=144, help='the image number of patch gap')
    parser.add_argument('--normalize_factor', type=int, default=1, help='Image normalization factor')
    parser.add_argument('--datasets_folder', type=str, default='test2', help='the folders for datasets')
    parser.add_argument('--model_name', type=str, default='test2', help='the checkpoint sub-folder inside denoise_model to load')
    parser.add_argument('--model_folder', type=str, default='log', help='model folder')
    parser.add_argument('--model_epoch', type=int, default=0, help='model epoch')
    parser.add_argument('--CBSDeep_model_folder', type=str, default='DeepCAD_model', help='the folder containing DeepCAD(pb) models')
    parser.add_argument('--results_folder', type=str, default='results', help='the folders for results')
    parser.add_argument('--GPU', type=int, default=3, help='the index of GPU you will use for computation')

    parser.add_argument('--denoise_model', type=str, default='unet3d_test2_20200924-1707', help='the model folder inside CBSDeep_model_folder')
    parser.add_argument('--test_datasize', type=int, default=512, help='dataset size to be tested')
    args = parser.parse_args()
    print('hhhhh ----->', args)
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.GPU)
    main(args)

--------------------------------------------------------------------------------
/DeepCAD_Fiji/DeepCAD_tensorflow/main.py:
--------------------------------------------------------------------------------
import numpy as np
from keras.models import load_model
import argparse
import os
import tifffile as tiff
import time
import datetime
import random
from skimage import io
from network import autoencoder
import tensorflow as tf
import logging
import time
from utils import save_yaml
from data_process import train_preprocess_lessMemory, shuffle_datasets_lessMemory,train_preprocess_lessMemoryMulStacks
def main(args):
    train(args)


def _export_saved_model(sess, export_dir, input_tensor, output_tensor):
    """Export the current graph and variables as a SavedModel.

    The model is tagged '3D_N2N' and carries one predict signature,
    'my_signature' (input0 -> output0).
    """
    builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
    input0 = {'input0': tf.saved_model.utils.build_tensor_info(input_tensor)}
    output0 = {'output0': tf.saved_model.utils.build_tensor_info(output_tensor)}
    method_name = tf.saved_model.signature_constants.PREDICT_METHOD_NAME
    my_signature = tf.saved_model.signature_def_utils.build_signature_def(input0, output0, method_name)
    # One meta graph is enough. The former extra builder.add_meta_graph() call
    # with the same tag only stored a second, variable-less copy of the graph.
    builder.add_meta_graph_and_variables(sess, ["3D_N2N"], signature_def_map={'my_signature': my_signature})
    builder.save()


def train(args):
    """Train the TF 3-D U-Net and periodically export SavedModel checkpoints.

    For each patch, every other frame starting at init_s is the network
    input, and the interleaved frames starting at init_s + 1 are the target.
    Every 100 steps the losses are printed and input/target/prediction TIFFs
    are written to results_folder. Every 1000 steps a SavedModel is exported
    to DeepCAD_model_folder/pb_unet3d_<TIME>/<epoch>_<step>/.
    """
    TIME = args.datasets_folder + '_' + datetime.datetime.now().strftime("%Y%m%d-%H%M")
    DeepCAD_model_path = args.DeepCAD_model_folder + '//' + 'pb_unet3d_' + TIME + '//'
    os.makedirs(DeepCAD_model_path, exist_ok=True)
    yaml_name = DeepCAD_model_path + '//para.yaml'
    save_yaml(args, yaml_name)
    results_path = args.results_folder + '//' + 'unet3d_' + TIME + '//'
    os.makedirs(results_path, exist_ok=True)

    name_list, noise_img, coordinate_list = train_preprocess_lessMemoryMulStacks(args)
    data_size = len(name_list)

    input_shape = [1, args.img_h, args.img_w, args.img_s, args.img_c]
    input_tensor = tf.placeholder(tf.float32, shape=input_shape, name='input')
    output_GT = tf.placeholder(tf.float32, shape=input_shape, name='output_GT')
    output = autoencoder(input_tensor, height=args.img_h, width=args.img_w, length=args.img_s)

    L2_loss = tf.reduce_mean(tf.square(output - output_GT))
    # absolute_difference already reduces to a scalar, so the reduce_sum is a no-op.
    L1_loss = tf.reduce_sum(tf.losses.absolute_difference(output, output_GT))
    loss = tf.add(L1_loss, L2_loss)
    optimizer = tf.train.AdamOptimizer(learning_rate=args.lr)
    train_step = optimizer.minimize(loss)
    start_time = time.time()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for i in range(args.train_epochs):
            name_list = shuffle_datasets_lessMemory(name_list)
            for index in range(data_size):
                data_name = name_list[index]
                c = coordinate_list[data_name]
                noise_patch1 = noise_img[c['init_s']:c['end_s']:2, c['init_h']:c['end_h'], c['init_w']:c['end_w']]
                noise_patch2 = noise_img[c['init_s'] + 1:c['end_s']:2, c['init_h']:c['end_h'], c['init_w']:c['end_w']]
                # (S, H, W) -> (1, H, W, S, 1) to match the placeholders.
                train_input = np.expand_dims(np.expand_dims(noise_patch1.transpose(1, 2, 0), 3), 0)
                train_GT = np.expand_dims(np.expand_dims(noise_patch2.transpose(1, 2, 0), 3), 0)
                feed = {input_tensor: train_input, output_GT: train_GT}
                sess.run(train_step, feed_dict=feed)

                if index % 100 == 0:
                    output_img, L1_loss_va, L2_loss_va = sess.run([output, L1_loss, L2_loss], feed_dict=feed)
                    print('--- Epoch ', i, ' --- Step ', index, '/', data_size, ' --- L1_loss ', L1_loss_va,
                          ' --- L2_loss ', L2_loss_va, ' --- Time ', (time.time() - start_time))
                    print('train_input ---> ', train_input.max(), '---> ', train_input.min())
                    print('output_img ---> ', output_img.max(), '---> ', output_img.min())
                    # Rescale to uint16 and write as (S, H, W) stacks.
                    input_u16 = np.clip(train_input.squeeze().astype(np.float32) * args.normalize_factor, 0, 65535).astype('uint16')
                    gt_u16 = np.clip(train_GT.squeeze().astype(np.float32) * args.normalize_factor, 0, 65535).astype('uint16')
                    output_u16 = np.clip(output_img.squeeze().astype(np.float32) * args.normalize_factor, 0, 65535).astype('uint16')
                    prefix = results_path + str(i) + '_' + str(index) + '_' + data_name
                    io.imsave(prefix + '_output.tif', output_u16.transpose(2, 0, 1))
                    io.imsave(prefix + '_noise1.tif', input_u16.transpose(2, 0, 1))
                    io.imsave(prefix + '_noise2.tif', gt_u16.transpose(2, 0, 1))

                if index % 1000 == 0:
                    DeepCAD_model_name = DeepCAD_model_path + '//' + str(i) + '_' + str(index) + '//'
                    _export_saved_model(sess, DeepCAD_model_name, input_tensor, output)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--img_h', type=int, default=64, help='the height of patch stack')
    parser.add_argument('--img_w', type=int, default=64, help='the width of patch stack')
    parser.add_argument('--img_s', type=int, default=320, help='the image number of patch stack')
    parser.add_argument('--img_c', type=int, default=1, help='the channel of image')
    parser.add_argument('--gap_h', type=int, default=56, help='the height of patch gap')
    parser.add_argument('--gap_w', type=int, default=56, help='the width of patch gap')
    parser.add_argument('--gap_s', type=int, default=128, help='the image number of patch gap')
    parser.add_argument('--normalize_factor', type=int, default=1, help='Image normalization factor')
    parser.add_argument('--train_epochs', type=int, default=30, help='train epochs')
    parser.add_argument('--datasets_path', type=str, default='datasets', help="dataset root path")
    parser.add_argument('--datasets_folder', type=str, default='3', help='the folders for datasets')
    parser.add_argument('--DeepCAD_model_folder', type=str, default='DeepCAD_model', help='the folders for DeepCAD(pb) model')
    parser.add_argument('--results_folder', type=str, default='results', help='the folders for results')
    parser.add_argument('--GPU', type=int, default=3, help='the index of GPU you will use for computation')
    # NOTE(review): type=bool converts any non-empty string (even "False") to True.
    parser.add_argument('--is_training', type=bool, default=True, help='train or test')
    parser.add_argument('--lr', type=float, default=0.001, help='initial learning rate')
    parser.add_argument('--train_datasets_size', type=int, default=1000, help='datasets size for training')
    parser.add_argument('--select_img_num', type=int, default=6000, help='select the number of images')
    args = parser.parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.GPU)
    main(args)

--------------------------------------------------------------------------------
/DeepCAD_pytorch/train.py:
--------------------------------------------------------------------------------
import os
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
import argparse
import time
import datetime
import sys
import math
import scipy.io as scio
from network import Network_3D_Unet
from tensorboardX import SummaryWriter
import numpy as np
from data_process import shuffle_datasets, train_preprocess_lessMemoryMulStacks, shuffle_datasets_lessMemory
from utils import save_yaml
from skimage import io
#############################################################################################################################################
# Command-line options for DeepCAD (PyTorch) training.
parser = argparse.ArgumentParser()
parser.add_argument("--epoch", type=int, default=0, help="epoch to start training from")
parser.add_argument("--n_epochs", type=int, default=100, help="number of training epochs")
parser.add_argument('--cuda', action='store_true', help='use GPU computation')
parser.add_argument('--GPU', type=int, default=0, help="the index of GPU you will use for computation")

# Patch size (img_*) and stride between neighbouring patches (gap_*).
parser.add_argument('--batch_size', type=int, default=1, help="batch size")
parser.add_argument('--img_s', type=int, default=150, help="the slices of image sequence")
parser.add_argument('--img_w', type=int, default=150, help="the width of image sequence")
parser.add_argument('--img_h', type=int, default=150, help="the height of image sequence")
parser.add_argument('--gap_s', type=int, default=60, help='the slices of image gap')
parser.add_argument('--gap_w', type=int, default=90, help='the width of image gap')
parser.add_argument('--gap_h', type=int, default=90, help='the height of image gap')

# Optimiser settings (Adam betas were previously documented as "bata").
parser.add_argument('--lr', type=float, default=0.001, help='initial learning rate')
parser.add_argument("--b1", type=float, default=0.5, help="Adam: beta1")
parser.add_argument("--b2", type=float, default=0.999, help="Adam: beta2")
parser.add_argument('--normalize_factor', type=int, default=65535, help='normalize factor')

parser.add_argument('--output_dir', type=str, default='./results', help="output directory")
parser.add_argument('--datasets_folder', type=str, default='DataForPytorch', help="A folder containing files for training")
parser.add_argument('--datasets_path', type=str, default='datasets', help="dataset root path")
parser.add_argument('--pth_path', type=str, default='pth', help="pth file root path")
parser.add_argument('--select_img_num', type=int, default=6000, help='select the number of images')
parser.add_argument('--train_datasets_size', type=int, default=1000, help='datasets size for training')
opt = parser.parse_args()

print('the parameter of your training ----->')
print(opt)
########################################################################################################################
# Per-run folders, both named <datasets_folder>_<timestamp>:
#   <output_dir>/...  sample TIFFs written during training
#   <pth_path>/...    checkpoints and the parameter YAML
current_time = opt.datasets_folder + '_' + datetime.datetime.now().strftime("%Y%m%d-%H%M")
output_path = opt.output_dir + '/' + current_time
# Honour --pth_path. It was previously hard-coded to 'pth', which is also its
# default. makedirs also creates the root folders when they are missing
# (os.mkdir failed on a fresh checkout without a 'pth' directory).
pth_path = opt.pth_path + '//' + current_time
os.makedirs(output_path, exist_ok=True)
os.makedirs(pth_path, exist_ok=True)

yaml_name = pth_path + '//para.yaml'
save_yaml(opt, yaml_name)

# Set before any CUDA use below.
os.environ["CUDA_VISIBLE_DEVICES"] = str(opt.GPU)
batch_size = opt.batch_size
lr = opt.lr

name_list, noise_img, coordinate_list = train_preprocess_lessMemoryMulStacks(opt)
########################################################################################################################
L1_pixelwise = torch.nn.L1Loss()
L2_pixelwise = torch.nn.MSELoss()
########################################################################################################################
denoise_generator = Network_3D_Unet(in_channels=1,
                                    out_channels=1,
                                    final_sigmoid=True)
if torch.cuda.is_available():
    print('Using GPU.')
    denoise_generator.cuda()
    L2_pixelwise.cuda()
    L1_pixelwise.cuda()
########################################################################################################################
optimizer_G = torch.optim.Adam(denoise_generator.parameters(),
                               lr=opt.lr, betas=(opt.b1, opt.b2))
########################################################################################################################

cuda = True if torch.cuda.is_available() else False
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
# Batches follow the model: GPU when available, otherwise CPU. They were
# previously moved with an unconditional .cuda(), which crashed on CPU-only
# machines even though the model setup handles that case.
device = torch.device('cuda' if cuda else 'cpu')
prev_time = time.time()
########################################################################################################################
time_start = time.time()
for epoch in range(opt.epoch, opt.n_epochs):
    name_list = shuffle_datasets_lessMemory(name_list)
    ####################################################################################################################
    for index in range(len(name_list)):
        input_name = name_list[index]
        single_coordinate = coordinate_list[input_name]
        init_h = single_coordinate['init_h']
        end_h = single_coordinate['end_h']
        init_w = single_coordinate['init_w']
        end_w = single_coordinate['end_w']
        init_s = single_coordinate['init_s']
        end_s = single_coordinate['end_s']
        # Every other frame from init_s is the input; the interleaved frames
        # from init_s + 1 are the target.
        noise_patch1 = noise_img[init_s:end_s:2, init_h:end_h, init_w:end_w]
        noise_patch2 = noise_img[init_s + 1:end_s:2, init_h:end_h, init_w:end_w]
        # (S, H, W) -> (1, S, H, W, 1) -> permute -> (1, 1, S, H, W)
        real_A = torch.from_numpy(np.expand_dims(np.expand_dims(noise_patch1, 3), 0)).to(device)
        real_A = real_A.permute([0, 4, 1, 2, 3])
        real_B = torch.from_numpy(np.expand_dims(np.expand_dims(noise_patch2, 3), 0)).to(device)
        real_B = real_B.permute([0, 4, 1, 2, 3])
        fake_B = denoise_generator(real_A)
        # Pixel-wise loss
        L1_loss = L1_pixelwise(fake_B, real_B)
        L2_loss = L2_pixelwise(fake_B, real_B)
        ################################################################################################################
        optimizer_G.zero_grad()
        # Total loss
        Total_loss = 0.5 * L1_loss + 0.5 * L2_loss
        Total_loss.backward()
        optimizer_G.step()
        ################################################################################################################
        # ETA extrapolated from the duration of the most recent step only.
        batches_done = epoch * len(name_list) + index
        batches_left = opt.n_epochs * len(name_list) - batches_done
        time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
        prev_time = time.time()
        ################################################################################################################
        if index % 50 == 0:
            time_end = time.time()
            print('time cost', time_end - time_start, 's \n')
            sys.stdout.write(
                "\r[Epoch %d/%d] [Batch %d/%d] [Total loss: %f, L1 Loss: %f, L2 Loss: %f] ETA: %s"
                % (
                    epoch,
                    opt.n_epochs,
                    index,
                    len(name_list),
                    Total_loss.item(),
                    L1_loss.item(),
                    L2_loss.item(),
                    time_left,
                )
            )
        ################################################################################################################
        if (index + 1) % 300 == 0:
            torch.save(denoise_generator.state_dict(), pth_path + '//G_' + str(epoch) + '_' + str(index) + '.pth')
    # NOTE(review): the nesting of this block was reconstructed from a
    # flattened source. It is placed at epoch level (it dumps the last batch of
    # each epoch); confirm against the original train.py indentation.
    if (epoch + 1) % 1 == 0:
        output_img = fake_B.cpu().detach().numpy()
        train_GT = real_B.cpu().detach().numpy()
        train_input = real_A.cpu().detach().numpy()

        train_input = train_input.squeeze().astype(np.float32) * opt.normalize_factor
        train_GT = train_GT.squeeze().astype(np.float32) * opt.normalize_factor
        output_img = output_img.squeeze().astype(np.float32) * opt.normalize_factor
        train_input = np.clip(train_input, 0, 65535).astype('uint16')
        train_GT = np.clip(train_GT, 0, 65535).astype('uint16')
        output_img = np.clip(output_img, 0, 65535).astype('uint16')
        result_name = output_path + '/' + str(epoch) + '_' + str(index) + '_' + input_name + '_output.tif'
        noise_img1_name = output_path + '/' + str(epoch) + '_' + str(index) + '_' + input_name + '_noise1.tif'
        noise_img2_name = output_path + '/' + str(epoch) + '_' + str(index) + '_' + input_name + '_noise2.tif'
        io.imsave(result_name, output_img)
        io.imsave(noise_img1_name, train_input)
        io.imsave(noise_img2_name, train_GT)


torch.save(denoise_generator.state_dict(), pth_path + '//G_' + str(opt.n_epochs) + '.pth')

--------------------------------------------------------------------------------
/DeepCAD_pytorch/test.py:
--------------------------------------------------------------------------------
import os
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
import argparse
import time
import datetime
import sys
import math
import scipy.io as scio
from network import Network_3D_Unet
from tensorboardX import SummaryWriter
import numpy as np
from utils import save_yaml, read_yaml, name2index
from data_process import shuffle_datasets, train_preprocess, test_preprocess, test_preprocess_lessMemory,test_preprocess_lessMemoryNoTail
from skimage import io
#############################################################################################################################################
parser = argparse.ArgumentParser()
parser.add_argument("--epoch", type=int, default=0, help="epoch to start training from")
parser.add_argument("--n_epochs", type=int, default=100, help="number of training epochs")
parser.add_argument('--cuda', action='store_true', help='use GPU computation')
parser.add_argument('--GPU', type=int, default=0, help="the index of GPU you will use for computation")

parser.add_argument('--batch_size', type=int, default=1, help="batch size")
parser.add_argument('--img_s', type=int, default=150, help="the slices of image sequence")
parser.add_argument('--img_w', type=int, default=150, help="the width of image sequence")
parser.add_argument('--img_h', type=int, default=150, help="the height of image sequence")
parser.add_argument('--gap_s', type=int, default=60, help='the slices of image gap')
parser.add_argument('--gap_w', type=int, default=90, help='the width of image gap')
parser.add_argument('--gap_h', type=int, default=90, help='the height of image gap')

parser.add_argument('--lr', type=float, default=0.001, help='initial learning rate')
parser.add_argument("--b1", type=float, default=0.5, help="Adam: bata1")
parser.add_argument("--b2", type=float, default=0.999, help="Adam: bata2")
parser.add_argument('--normalize_factor', type=int, default=65535, help='normalize factor')

parser.add_argument('--output_dir', type=str, default='./results', help="output directory")
parser.add_argument('--datasets_path', type=str, default='datasets', help="dataset root path")
parser.add_argument('--pth_path', type=str, default='pth', help="pth file root path")
parser.add_argument('--datasets_folder', type=str, default='DataForPytorch', help="A folder containing files to be tested")
parser.add_argument('--denoise_model', type=str, default='ModelForPytorch', help='A folder containing models to be tested')
parser.add_argument('--test_datasize', type=int, default=6000, help='dataset size to be tested')
parser.add_argument('--train_datasets_size', type=int, default=1000, help='datasets size for training')

opt = parser.parse_args()
print('the parameter of your training ----->')
print(opt)
########################################################################################################################
os.environ["CUDA_VISIBLE_DEVICES"] = str(opt.GPU)
model_path = opt.pth_path+'//'+opt.denoise_model
# print(model_path)
| model_list = list(os.walk(model_path, topdown=False))[-1][-1] 54 | # print(model_list) 55 | 56 | for i in range(len(model_list)): 57 | aaa = model_list[i] 58 | if '.yaml' in aaa: 59 | yaml_name = model_list[i] 60 | print(yaml_name) 61 | read_yaml(opt, model_path+'//'+yaml_name) 62 | # print(opt.datasets_folder) 63 | 64 | name_list, noise_img, coordinate_list= test_preprocess_lessMemoryNoTail(opt) 65 | # name_list, noise_img, coordinate_list = test_preprocess_lessMemory(opt) 66 | # trainX = np.expand_dims(np.array(train_raw),4) 67 | num_h = (math.floor((noise_img.shape[1]-opt.img_h)/opt.gap_h)+1) 68 | num_w = (math.floor((noise_img.shape[2]-opt.img_w)/opt.gap_w)+1) 69 | num_s = (math.floor((noise_img.shape[0]-opt.img_s)/opt.gap_s)+1) 70 | # print(num_h, num_w, num_s) 71 | # print(coordinate_list) 72 | 73 | if not os.path.exists(opt.output_dir): 74 | os.mkdir(opt.output_dir) 75 | current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M") 76 | output_path1 = opt.output_dir + '//' + opt.datasets_folder + '_' + current_time + '_' + opt.denoise_model 77 | if not os.path.exists(output_path1): 78 | os.mkdir(output_path1) 79 | 80 | yaml_name = output_path1+'//para.yaml' 81 | save_yaml(opt, yaml_name) 82 | denoise_generator = Network_3D_Unet(in_channels = 1, 83 | out_channels = 1, 84 | final_sigmoid = True) 85 | if torch.cuda.is_available(): 86 | print('Using GPU.') 87 | for pth_index in range(len(model_list)): 88 | aaa = model_list[pth_index] 89 | if '.pth' in aaa: 90 | pth_name = model_list[pth_index] 91 | output_path = output_path1 + '//' + pth_name.replace('.pth','') 92 | if not os.path.exists(output_path): 93 | os.mkdir(output_path) 94 | denoise_generator.load_state_dict(torch.load(opt.pth_path+'//'+opt.denoise_model+'//'+pth_name)) 95 | 96 | denoise_generator.cuda() 97 | prev_time = time.time() 98 | time_start=time.time() 99 | denoise_img = np.zeros(noise_img.shape) 100 | input_img = np.zeros(noise_img.shape) 101 | for index in range(len(name_list)): 102 | 
single_coordinate = coordinate_list[name_list[index]] 103 | init_h = single_coordinate['init_h'] 104 | end_h = single_coordinate['end_h'] 105 | init_w = single_coordinate['init_w'] 106 | end_w = single_coordinate['end_w'] 107 | init_s = single_coordinate['init_s'] 108 | end_s = single_coordinate['end_s'] 109 | noise_patch = noise_img[init_s:end_s,init_h:end_h,init_w:end_w] 110 | # print(noise_patch.shape) 111 | real_A = torch.from_numpy(np.expand_dims(np.expand_dims(noise_patch, 3),0)).cuda() 112 | # print('real_A -----> ',real_A.shape) 113 | real_A = real_A.permute([0,4,1,2,3]) 114 | input_name = name_list[index] 115 | print(' input_name -----> ',input_name) 116 | print(' single_coordinate -----> ',single_coordinate) 117 | print('real_A -----> ',real_A.shape) 118 | real_A = Variable(real_A) 119 | fake_B = denoise_generator(real_A) 120 | ################################################################################################################ 121 | # Determine approximate time left 122 | batches_done = index 123 | batches_left = 1 * len(name_list) - batches_done 124 | time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time)) 125 | prev_time = time.time() 126 | prev_time = time.time() 127 | ################################################################################################################ 128 | if index%1 == 0: 129 | time_end=time.time() 130 | time_cost=datetime.timedelta(seconds= (time_end - time_start)) 131 | sys.stdout.write("\r [Batch %d/%d] [Time Left: %s] [Time Cost: %s]" 132 | % (index, 133 | len(name_list), 134 | time_left, 135 | time_cost,)) 136 | ################################################################################################################ 137 | output_image = np.squeeze(fake_B.cpu().detach().numpy()) 138 | raw_image = np.squeeze(real_A.cpu().detach().numpy()) 139 | stack_start_w = int(single_coordinate['stack_start_w']) 140 | stack_end_w = int(single_coordinate['stack_end_w']) 141 | patch_start_w 
= int(single_coordinate['patch_start_w']) 142 | patch_end_w = int(single_coordinate['patch_end_w']) 143 | 144 | stack_start_h = int(single_coordinate['stack_start_h']) 145 | stack_end_h = int(single_coordinate['stack_end_h']) 146 | patch_start_h = int(single_coordinate['patch_start_h']) 147 | patch_end_h = int(single_coordinate['patch_end_h']) 148 | 149 | stack_start_s = int(single_coordinate['stack_start_s']) 150 | stack_end_s = int(single_coordinate['stack_end_s']) 151 | patch_start_s = int(single_coordinate['patch_start_s']) 152 | patch_end_s = int(single_coordinate['patch_end_s']) 153 | 154 | aaaa = output_image[patch_start_s:patch_end_s, patch_start_h:patch_end_h, patch_start_w:patch_end_w] 155 | bbbb = raw_image[patch_start_s:patch_end_s, patch_start_h:patch_end_h, patch_start_w:patch_end_w] 156 | 157 | denoise_img[stack_start_s:stack_end_s, stack_start_h:stack_end_h, stack_start_w:stack_end_w] \ 158 | = output_image[patch_start_s:patch_end_s, patch_start_h:patch_end_h, patch_start_w:patch_end_w]*(np.sum(bbbb)/np.sum(aaaa))**0.5 159 | input_img[stack_start_s:stack_end_s, stack_start_h:stack_end_h, stack_start_w:stack_end_w] \ 160 | = raw_image[patch_start_s:patch_end_s, patch_start_h:patch_end_h, patch_start_w:patch_end_w] 161 | 162 | # del noise_img 163 | output_img = denoise_img.squeeze().astype(np.float32)*opt.normalize_factor 164 | del denoise_img 165 | # output_img = output_img1[0:raw_noise_img.shape[0],0:raw_noise_img.shape[1],0:raw_noise_img.shape[2]] 166 | output_img = output_img-output_img.min() 167 | output_img = output_img/output_img.max()*65535 168 | output_img = np.clip(output_img, 0, 65535).astype('uint16') 169 | output_img = output_img-output_img.min() 170 | # output_img = output_img.astype('uint16') 171 | input_img = input_img.squeeze().astype(np.float32)*opt.normalize_factor 172 | # input_img = input_img1[0:raw_noise_img.shape[0],0:raw_noise_img.shape[1],0:raw_noise_img.shape[2]] 173 | input_img = np.clip(input_img, 0, 65535).astype('uint16') 
174 | result_name = output_path + '//' +pth_name.replace('.pth','')+ '_output.tif' 175 | input_name = output_path + '//' +pth_name.replace('.pth','')+ '_input.tif' 176 | io.imsave(result_name, output_img) 177 | io.imsave(input_name, input_img) 178 | 179 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DeepCAD: Deep self-supervised learning for calcium imaging denoising 2 | 3 | 4 | 5 | ### :triangular_flag_on_post: [[New version released! :DeepCAD-RT]](https://github.com/cabooster/DeepCAD-RT) 6 | 7 | ## Contents 8 | 9 | - [Overview](#overview) 10 | - [Directory structure](#directory-structure) 11 | - [Pytorch code](#pytorch-code) 12 | - [Fiji plugin](#fiji-plugin) 13 | - [Results](#results) 14 | - [License](./LICENSE) 15 | - [Citation](#citation) 16 | 17 | ## Overview 18 | 19 | 20 | 21 | Calcium imaging is inherently susceptible to detection noise especially when imaging with high frame rate or under low excitation dosage. However, calcium transients are highly dynamic, non-repetitive activities and a firing pattern cannot be captured twice. Clean images for supervised training of deep neural networks are not accessible. Here, we present DeepCAD, a **deep** self-supervised learning-based method for **ca**lcium imaging **d**enoising. Using our method, detection noise can be effectively removed and the accuracy of neuron extraction and spike inference can be highly improved. 22 | 23 | DeepCAD is based on the insight that a deep learning network for image denoising can achieve satisfactory convergence even the target image used for training is another corrupted sampling of the same scene [[paper link]](https://arxiv.org/abs/1803.04189). We explored the temporal redundancy of calcium imaging and found that any two consecutive frames can be regarded as two independent samplings of the same underlying firing pattern. 
A single low-SNR stack is sufficient to serve as a complete training set for DeepCAD. Furthermore, to boost its performance on 3D temporal stacks, the input and output data are designed to be 3D volumes rather than 2D frames to fully incorporate the abundant information along the time axis. 24 | 25 | For more details, please see the companion paper where the method first appeared: 26 | ["*Reinforcing neuron extraction and spike inference in calcium imaging using deep self-supervised denoising*".](https://www.nature.com/articles/s41592-021-01225-0) 27 | 28 | ## Directory structure 29 | 30 | ``` 31 | DeepCAD 32 | |---DeepCAD_pytorch #Pytorch implementation of DeepCAD# 33 | |---|---train.py 34 | |---|---test.py 35 | |---|---script.py 36 | |---|---network.py 37 | |---|---model_3DUnet.py 38 | |---|---data_process.py 39 | |---|---buildingblocks.py 40 | |---|---utils.py 41 | |---|---datasets 42 | |---|---|---DataForPytorch #project_name# 43 | |---|---|---|---data.tif 44 | |---|---pth 45 | |---|---|---ModelForPytorch 46 | |---|---|---|---model.pth 47 | |---|---results 48 | |---|---|--- # Intermediate and final results# 49 | |---DeepCAD_Fiji 50 | |---|---DeepCAD_Fiji_plugin 51 | |---|---|---DeepCAD-0.3.0 #executable jar file/.jar# 52 | |---|---DeepCAD_java #java source code of DeepCAD Fiji plugin# 53 | |---|---DeepCAD_tensorflow #Tensorflow implementation compatible with Fiji plugin# 54 | ``` 55 | - **DeepCAD_pytorch** is the [Pytorch](https://pytorch.org/) implementation of DeepCAD. 56 | - **DeepCAD_Fiji** is a user-friendly [Fiji](https://imagej.net/Fiji) plugin. This plugin is easy to install and convenient to use. Researchers without expertise in computer science and machine learning can learn to use it in a very short time. 57 | - **DeepCAD_Fiji_plugin** contains the executable .jar file that can be installed on Fiji. 58 | - **DeepCAD_java** is the Java source code of our Fiji plugin based on [CSBDeep](https://csbdeep.bioimagecomputing.com).
59 | - **DeepCAD_tensorflow** is the [Tensorflow](https://www.tensorflow.org/) implementation of DeepCAD, which is used for training models compatible with the Fiji plugin. 60 | 61 | ## Pytorch code 62 | 63 | PyTorch code is the recommended implementation of DeepCAD. 64 | 65 | ### :triangular_flag_on_post: [[New version released]](https://github.com/cabooster/DeepCAD-RT) 66 | **An upgraded version of this PyTorch code has been released and is maintained in another repository [[new code link]](https://github.com/cabooster/DeepCAD-RT), with several new features such as much faster processing speed, lower memory cost, improved pre- and post-processing, multi-GPU acceleration, more stable performance, etc**. 67 | 68 | ### Environment 69 | 70 | * Ubuntu 16.04 71 | * Python 3.6 72 | * Pytorch >= 1.3.1 73 | * NVIDIA GPU (24 GB Memory) + CUDA 74 | 75 | ### Environment configuration 76 | 77 | * Create a virtual environment and install Pytorch. In the fourth command below, please select the correct Pytorch version that matches your CUDA version from https://pytorch.org/get-started/previous-versions/ 78 | 79 | ``` 80 | $ conda create -n deepcad python=3.6 81 | $ source activate deepcad 82 | $ pip install torch==1.3.1 83 | $ conda install pytorch torchvision cudatoolkit -c pytorch 84 | ``` 85 | 86 | * Install other dependencies 87 | 88 | ``` 89 | $ conda install -c anaconda matplotlib opencv scikit-learn scikit-image 90 | $ conda install -c conda-forge h5py pyyaml tensorboardx tifffile 91 | ``` 92 | ### Download the source code 93 | 94 | ``` 95 | $ git clone https://github.com/cabooster/DeepCAD.git 96 | $ cd DeepCAD/DeepCAD_pytorch/ 97 | ``` 98 | 99 | ### Training 100 | 101 | Download the demo data (.tif file) [[DataForPytorch](https://drive.google.com/drive/folders/1w9v1SrEkmvZal5LH79HloHhz6VXSPfI_)] and put it into *DeepCAD_pytorch/datasets/DataForPytorch*. 102 | 103 | Run **script.py** to start training.
104 | 105 | ``` 106 | $ source activate deepcad 107 | $ python script.py train 108 | ``` 109 | 110 | Parameters can be modified as required in **script.py**. If your GPU is running out of memory, you can use smaller `img_h`, `img_w`, `img_s` and `gap_h`, `gap_w`, `gap_s`. 111 | 112 | ``` 113 | $ os.system('python train.py --datasets_folder --img_h --img_w --img_s --gap_h --gap_w --gap_s --n_epochs --GPU --normalize_factor --train_datasets_size --select_img_num') 114 | 115 | @parameters 116 | --datasets_folder: the folder containing your training data (one or more stacks) 117 | --img_h, --img_w, --img_s: patch size in three dimensions 118 | --gap_h, --gap_w, --gap_s: the spacing to extract training patches from the input stack(s) 119 | --n_epochs: the number of training epochs 120 | --GPU: specify the GPU used for training 121 | --lr: learning rate, please use the default value 122 | --normalize_factor: a constant for image normalization 123 | --train_datasets_size: the number of patches extracted for training 124 | --select_img_num: the number of slices used for training. 125 | ``` 126 | 127 | ### Test 128 | 129 | Download our pre-trained model (.pth file and .yaml file) [[ModelForPytorch](https://drive.google.com/drive/folders/12LEFsAopTolaRyRpJtFpzOYH3tBZMGUP)] and put it into *DeepCAD_pytorch/pth/ModelForPytorch*. 130 | 131 | Run **script.py** to start the test process. Parameters saved in the .yaml file will be automatically loaded. If your GPU is running out of memory, you can use smaller `img_h`, `img_w`, `img_s` and `gap_h`, `gap_w`, `gap_s`. 132 | 133 | ``` 134 | $ source activate deepcad 135 | $ python script.py test 136 | ``` 137 | 138 | Parameters can be modified as required in **script.py**. All models in the `--denoise_model` folder will be tested and manual inspection should be made for **model screening**.
139 | 140 | ``` 141 | $ os.system('python test.py --denoise_model --datasets_folder --test_datasize') 142 | 143 | @parameters 144 | --denoise_model: the folder containing all the pre-trained models. 145 | --datasets_folder: the folder containing the testing data (one or more stacks). 146 | --test_datasize: the number of frames used for testing 147 | --img_h, --img_w, --img_s: patch size in three dimensions 148 | --gap_h, --gap_w, --gap_s: the spacing to extract test patches from the input stack(s) 149 | ``` 150 | 151 | ## Fiji plugin 152 | 153 | To ameliorate the difficulty of using our deep self-supervised learning-based method, we developed a user-friendly Fiji plugin, which is easy to install and convenient to use (it has been tested on a Windows desktop with an Intel i9 CPU and 128 GB RAM). Researchers without expertise in computer science and machine learning can learn to use it in a very short time. **Tutorials** on installing and using the plugin have been moved to [**this page**](https://github.com/cabooster/DeepCAD/tree/master/DeepCAD_Fiji). 154 | 155 | 156 | 157 | 158 | ## Results 159 | 160 | ### 1. The performance of DeepCAD on denoising two-photon calcium imaging of neurite activities. 161 | 162 | 163 | 164 | ### 2. The performance of DeepCAD on denoising two-photon calcium imaging of large neuronal populations. 165 | 166 | 167 | 168 | ### 3. Cross-system validation. 169 | 170 | 171 | 172 | Denoising performance of DeepCAD on three two-photon laser-scanning microscopes (2PLSMs) with different system setups. **Our system** was equipped with alkali PMTs (PMT1001, Thorlabs) and a 25×/1.05 NA commercial objective (XLPLN25XWMP2, Olympus). The **standard 2PLSM** was equipped with a GaAsP PMT (H10770PA-40, Hamamatsu) and a 25×/1.05 NA commercial objective (XLPLN25XWMP2, Olympus). The **two-photon mesoscope** was equipped with a GaAsP PMT (H11706-40, Hamamatsu) and a 2.3×/0.6 NA custom objective. The same pre-trained model was used for processing these data.
173 | 174 | ## Citation 175 | 176 | If you use this code please cite the companion paper where the original method appeared: 177 | 178 | Li, X., Zhang, G., Wu, J. et al. Reinforcing neuron extraction and spike inference in calcium imaging using deep self-supervised denoising. Nat Methods (2021). [https://doi.org/10.1038/s41592-021-01225-0](https://www.nature.com/articles/s41592-021-01225-0) 179 | 180 | ``` 181 | @article{li2021reinforcing, 182 | title={Reinforcing neuron extraction and spike inference in calcium imaging using deep self-supervised denoising}, 183 | author={Li, Xinyang and Zhang, Guoxun and Wu, Jiamin and Zhang, Yuanlong and Zhao, Zhifeng and Lin, Xing and Qiao, Hui and Xie, Hao and Wang, Haoqian and Fang, Lu and others}, 184 | journal={Nature Methods}, 185 | pages={1--6}, 186 | year={2021}, 187 | publisher={Nature Publishing Group} 188 | } 189 | ``` 190 | -------------------------------------------------------------------------------- /DeepCAD_pytorch/buildingblocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | from torch.nn import functional as F 4 | 5 | 6 | def conv3d(in_channels, out_channels, kernel_size, bias, padding=1): 7 | return nn.Conv3d(in_channels, out_channels, kernel_size, padding=padding, bias=bias) 8 | 9 | 10 | def create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding=1): 11 | """ 12 | Create a list of modules with together constitute a single conv layer with non-linearity 13 | and optional batchnorm/groupnorm. 14 | 15 | Args: 16 | in_channels (int): number of input channels 17 | out_channels (int): number of output channels 18 | order (string): order of things, e.g. 
19 | 'cr' -> conv + ReLU 20 | 'crg' -> conv + ReLU + groupnorm 21 | 'cl' -> conv + LeakyReLU 22 | 'ce' -> conv + ELU 23 | num_groups (int): number of groups for the GroupNorm 24 | padding (int): add zero-padding to the input 25 | 26 | Return: 27 | list of tuple (name, module) 28 | """ 29 | assert 'c' in order, "Conv layer MUST be present" 30 | assert order[0] not in 'rle', 'Non-linearity cannot be the first operation in the layer' 31 | 32 | modules = [] 33 | for i, char in enumerate(order): 34 | if char == 'r': 35 | modules.append(('ReLU', nn.ReLU(inplace=True))) 36 | elif char == 'l': 37 | modules.append(('LeakyReLU', nn.LeakyReLU(negative_slope=0.1, inplace=True))) 38 | elif char == 'e': 39 | modules.append(('ELU', nn.ELU(inplace=True))) 40 | elif char == 'c': 41 | # add learnable bias only in the absence of gatchnorm/groupnorm 42 | bias = not ('g' in order or 'b' in order) 43 | modules.append(('conv', conv3d(in_channels, out_channels, kernel_size, bias, padding=padding))) 44 | elif char == 'g': 45 | is_before_conv = i < order.index('c') 46 | assert not is_before_conv, 'GroupNorm MUST go after the Conv3d' 47 | # number of groups must be less or equal the number of channels 48 | if out_channels < num_groups: 49 | num_groups = out_channels 50 | modules.append(('groupnorm', nn.GroupNorm(num_groups=num_groups, num_channels=out_channels))) 51 | elif char == 'b': 52 | is_before_conv = i < order.index('c') 53 | if is_before_conv: 54 | modules.append(('batchnorm', nn.BatchNorm3d(in_channels))) 55 | else: 56 | modules.append(('batchnorm', nn.BatchNorm3d(out_channels))) 57 | else: 58 | raise ValueError(f"Unsupported layer type '{char}'. MUST be one of ['b', 'g', 'r', 'l', 'e', 'c']") 59 | 60 | return modules 61 | 62 | 63 | class SingleConv(nn.Sequential): 64 | """ 65 | Basic convolutional module consisting of a Conv3d, non-linearity and optional batchnorm/groupnorm. 
The order 66 | of operations can be specified via the `order` parameter 67 | 68 | Args: 69 | in_channels (int): number of input channels 70 | out_channels (int): number of output channels 71 | kernel_size (int): size of the convolving kernel 72 | order (string): determines the order of layers, e.g. 73 | 'cr' -> conv + ReLU 74 | 'crg' -> conv + ReLU + groupnorm 75 | 'cl' -> conv + LeakyReLU 76 | 'ce' -> conv + ELU 77 | num_groups (int): number of groups for the GroupNorm 78 | """ 79 | 80 | def __init__(self, in_channels, out_channels, kernel_size=3, order='cr', num_groups=8, padding=1): 81 | super(SingleConv, self).__init__() 82 | 83 | for name, module in create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding=padding): 84 | self.add_module(name, module) 85 | 86 | 87 | class DoubleConv(nn.Sequential): 88 | """ 89 | A module consisting of two consecutive convolution layers (e.g. BatchNorm3d+ReLU+Conv3d). 90 | We use (Conv3d+ReLU+GroupNorm3d) by default. 91 | This can be changed however by providing the 'order' argument, e.g. in order 92 | to change to Conv3d+BatchNorm3d+ELU use order='cbe'. 93 | Use padded convolutions to make sure that the output (H_out, W_out) is the same 94 | as (H_in, W_in), so that you don't have to crop in the decoder path. 95 | 96 | Args: 97 | in_channels (int): number of input channels 98 | out_channels (int): number of output channels 99 | encoder (bool): if True we're in the encoder path, otherwise we're in the decoder 100 | kernel_size (int): size of the convolving kernel 101 | order (string): determines the order of layers, e.g. 
102 | 'cr' -> conv + ReLU 103 | 'crg' -> conv + ReLU + groupnorm 104 | 'cl' -> conv + LeakyReLU 105 | 'ce' -> conv + ELU 106 | num_groups (int): number of groups for the GroupNorm 107 | """ 108 | 109 | def __init__(self, in_channels, out_channels, encoder, kernel_size=3, order='cr', num_groups=8): 110 | super(DoubleConv, self).__init__() 111 | if encoder: 112 | # we're in the encoder path 113 | conv1_in_channels = in_channels 114 | conv1_out_channels = out_channels // 2 115 | if conv1_out_channels < in_channels: 116 | conv1_out_channels = in_channels 117 | conv2_in_channels, conv2_out_channels = conv1_out_channels, out_channels 118 | else: 119 | # we're in the decoder path, decrease the number of channels in the 1st convolution 120 | conv1_in_channels, conv1_out_channels = in_channels, out_channels 121 | conv2_in_channels, conv2_out_channels = out_channels, out_channels 122 | 123 | # conv1 124 | self.add_module('SingleConv1', 125 | SingleConv(conv1_in_channels, conv1_out_channels, kernel_size, order, num_groups)) 126 | # conv2 127 | self.add_module('SingleConv2', 128 | SingleConv(conv2_in_channels, conv2_out_channels, kernel_size, order, num_groups)) 129 | 130 | 131 | class ExtResNetBlock(nn.Module): 132 | """ 133 | Basic UNet block consisting of a SingleConv followed by the residual block. 134 | The SingleConv takes care of increasing/decreasing the number of channels and also ensures that the number 135 | of output channels is compatible with the residual block that follows. 136 | This block can be used instead of standard DoubleConv in the Encoder module. 137 | Motivated by: https://arxiv.org/pdf/1706.00120.pdf 138 | 139 | Notice we use ELU instead of ReLU (order='cge') and put non-linearity after the groupnorm. 
140 | """ 141 | 142 | def __init__(self, in_channels, out_channels, kernel_size=3, order='cge', num_groups=8, **kwargs): 143 | super(ExtResNetBlock, self).__init__() 144 | 145 | # first convolution 146 | self.conv1 = SingleConv(in_channels, out_channels, kernel_size=kernel_size, order=order, num_groups=num_groups) 147 | # residual block 148 | self.conv2 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=order, num_groups=num_groups) 149 | # remove non-linearity from the 3rd convolution since it's going to be applied after adding the residual 150 | n_order = order 151 | for c in 'rel': 152 | n_order = n_order.replace(c, '') 153 | self.conv3 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=n_order, 154 | num_groups=num_groups) 155 | 156 | # create non-linearity separately 157 | if 'l' in order: 158 | self.non_linearity = nn.LeakyReLU(negative_slope=0.1, inplace=True) 159 | elif 'e' in order: 160 | self.non_linearity = nn.ELU(inplace=True) 161 | else: 162 | self.non_linearity = nn.ReLU(inplace=True) 163 | 164 | def forward(self, x): 165 | # apply first convolution and save the output as a residual 166 | out = self.conv1(x) 167 | residual = out 168 | 169 | # residual block 170 | out = self.conv2(out) 171 | out = self.conv3(out) 172 | 173 | out += residual 174 | out = self.non_linearity(out) 175 | 176 | return out 177 | 178 | 179 | class Encoder(nn.Module): 180 | """ 181 | A single module from the encoder path consisting of the optional max 182 | pooling layer (one may specify the MaxPool kernel_size to be different 183 | than the standard (2,2,2), e.g. if the volumetric data is anisotropic 184 | (make sure to use complementary scale_factor in the decoder path) followed by 185 | a DoubleConv module. 
186 | Args: 187 | in_channels (int): number of input channels 188 | out_channels (int): number of output channels 189 | conv_kernel_size (int): size of the convolving kernel 190 | apply_pooling (bool): if True use MaxPool3d before DoubleConv 191 | pool_kernel_size (tuple): the size of the window to take a max over 192 | pool_type (str): pooling layer: 'max' or 'avg' 193 | basic_module(nn.Module): either ResNetBlock or DoubleConv 194 | conv_layer_order (string): determines the order of layers 195 | in `DoubleConv` module. See `DoubleConv` for more info. 196 | num_groups (int): number of groups for the GroupNorm 197 | """ 198 | 199 | def __init__(self, in_channels, out_channels, conv_kernel_size=3, apply_pooling=True, 200 | pool_kernel_size=(2, 2, 2), pool_type='max', basic_module=DoubleConv, conv_layer_order='cr', 201 | num_groups=8): 202 | super(Encoder, self).__init__() 203 | assert pool_type in ['max', 'avg'] 204 | if apply_pooling: 205 | if pool_type == 'max': 206 | self.pooling = nn.MaxPool3d(kernel_size=pool_kernel_size) 207 | else: 208 | self.pooling = nn.AvgPool3d(kernel_size=pool_kernel_size) 209 | else: 210 | self.pooling = None 211 | 212 | self.basic_module = basic_module(in_channels, out_channels, 213 | encoder=True, 214 | kernel_size=conv_kernel_size, 215 | order=conv_layer_order, 216 | num_groups=num_groups) 217 | 218 | def forward(self, x): 219 | if self.pooling is not None: 220 | x = self.pooling(x) 221 | x = self.basic_module(x) 222 | return x 223 | 224 | 225 | class Decoder(nn.Module): 226 | """ 227 | A single module for decoder path consisting of the upsample layer 228 | (either learned ConvTranspose3d or interpolation) followed by a DoubleConv 229 | module. 
230 | Args: 231 | in_channels (int): number of input channels 232 | out_channels (int): number of output channels 233 | kernel_size (int): size of the convolving kernel 234 | scale_factor (tuple): used as the multiplier for the image H/W/D in 235 | case of nn.Upsample or as stride in case of ConvTranspose3d, must reverse the MaxPool3d operation 236 | from the corresponding encoder 237 | basic_module(nn.Module): either ResNetBlock or DoubleConv 238 | conv_layer_order (string): determines the order of layers 239 | in `DoubleConv` module. See `DoubleConv` for more info. 240 | num_groups (int): number of groups for the GroupNorm 241 | """ 242 | 243 | def __init__(self, in_channels, out_channels, kernel_size=3, 244 | scale_factor=(2, 2, 2), basic_module=DoubleConv, conv_layer_order='cr', num_groups=8): 245 | super(Decoder, self).__init__() 246 | if basic_module == DoubleConv: 247 | # if DoubleConv is the basic_module use nearest neighbor interpolation for upsampling 248 | self.upsample = None 249 | else: 250 | # otherwise use ConvTranspose3d (bear in mind your GPU memory) 251 | # make sure that the output size reverses the MaxPool3d from the corresponding encoder 252 | # (D_out = (D_in − 1) ×  stride[0] − 2 ×  padding[0] +  kernel_size[0] +  output_padding[0]) 253 | # also scale the number of channels from in_channels to out_channels so that summation joining 254 | # works correctly 255 | self.upsample = nn.ConvTranspose3d(in_channels, 256 | out_channels, 257 | kernel_size=kernel_size, 258 | stride=scale_factor, 259 | padding=1, 260 | output_padding=1) 261 | # adapt the number of in_channels for the ExtResNetBlock 262 | in_channels = out_channels 263 | 264 | self.basic_module = basic_module(in_channels, out_channels, 265 | encoder=False, 266 | kernel_size=kernel_size, 267 | order=conv_layer_order, 268 | num_groups=num_groups) 269 | 270 | def forward(self, encoder_features, x): 271 | if self.upsample is None: 272 | # use nearest neighbor interpolation and concatenation 
joining 273 | output_size = encoder_features.size()[2:] 274 | x = F.interpolate(x, size=output_size, mode='nearest') 275 | # concatenate encoder_features (encoder path) with the upsampled input across channel dimension 276 | x = torch.cat((encoder_features, x), dim=1) 277 | else: 278 | # use ConvTranspose3d and summation joining 279 | x = self.upsample(x) 280 | x += encoder_features 281 | 282 | x = self.basic_module(x) 283 | return x 284 | 285 | 286 | class FinalConv(nn.Sequential): 287 | """ 288 | A module consisting of a convolution layer (e.g. Conv3d+ReLU+GroupNorm3d) and the final 1x1 convolution 289 | which reduces the number of channels to 'out_channels'. 290 | with the number of output channels 'out_channels // 2' and 'out_channels' respectively. 291 | We use (Conv3d+ReLU+GroupNorm3d) by default. 292 | This can be change however by providing the 'order' argument, e.g. in order 293 | to change to Conv3d+BatchNorm3d+ReLU use order='cbr'. 294 | Args: 295 | in_channels (int): number of input channels 296 | out_channels (int): number of output channels 297 | kernel_size (int): size of the convolving kernel 298 | order (string): determines the order of layers, e.g. 
299 | 'cr' -> conv + ReLU 300 | 'crg' -> conv + ReLU + groupnorm 301 | num_groups (int): number of groups for the GroupNorm 302 | """ 303 | 304 | def __init__(self, in_channels, out_channels, kernel_size=3, order='cr', num_groups=8): 305 | super(FinalConv, self).__init__() 306 | 307 | # conv1 308 | self.add_module('SingleConv', SingleConv(in_channels, in_channels, kernel_size, order, num_groups)) 309 | 310 | # in the last layer a 1×1 convolution reduces the number of output channels to out_channels 311 | final_conv = nn.Conv3d(in_channels, out_channels, 1) 312 | self.add_module('final_conv', final_conv) 313 | -------------------------------------------------------------------------------- /DeepCAD_Fiji/DeepCAD_tensorflow/data_process.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import argparse 3 | import os 4 | import tifffile as tiff 5 | import time 6 | import datetime 7 | import random 8 | from skimage import io 9 | import logging 10 | import math 11 | 12 | def train_preprocess_lessMemoryMulStacks(args): 13 | img_h = args.img_h 14 | img_w = args.img_w 15 | img_s2 = args.img_s*2 16 | gap_h = args.gap_h 17 | gap_w = args.gap_w 18 | gap_s2 = args.gap_s*2 19 | im_folder = args.datasets_path+'//'+args.datasets_folder 20 | 21 | name_list = [] 22 | # train_raw = [] 23 | coordinate_list={} 24 | 25 | print('list(os.walk(im_folder, topdown=False)) -----> ',list(os.walk(im_folder, topdown=False))) 26 | stack_num = len(list(os.walk(im_folder, topdown=False))[-1][-1]) 27 | print('stack_num -----> ',stack_num) 28 | for im_name in list(os.walk(im_folder, topdown=False))[-1][-1]: 29 | # print('im_name -----> ',im_name) 30 | im_dir = im_folder+'//'+im_name 31 | noise_im = tiff.imread(im_dir) 32 | if noise_im.shape[0]>args.select_img_num: 33 | noise_im = noise_im[0:args.select_img_num,:,:] 34 | gap_s2 = get_gap_s(args, noise_im, stack_num) 35 | # print('noise_im shape -----> ',noise_im.shape) 36 | # print('noise_im 
max -----> ',noise_im.max()) 37 | # print('noise_im min -----> ',noise_im.min()) 38 | noise_im = (noise_im-noise_im.min()).astype(np.float32)/args.normalize_factor 39 | 40 | whole_w = noise_im.shape[2] 41 | whole_h = noise_im.shape[1] 42 | whole_s = noise_im.shape[0] 43 | # print('int((whole_h-img_h+gap_h)/gap_h) -----> ',int((whole_h-img_h+gap_h)/gap_h)) 44 | # print('int((whole_w-img_w+gap_w)/gap_w) -----> ',int((whole_w-img_w+gap_w)/gap_w)) 45 | # print('int((whole_s-img_s2+gap_s2)/gap_s2) -----> ',int((whole_s-img_s2+gap_s2)/gap_s2)) 46 | for x in range(0,int((whole_h-img_h+gap_h)/gap_h)): 47 | for y in range(0,int((whole_w-img_w+gap_w)/gap_w)): 48 | for z in range(0,int((whole_s-img_s2+gap_s2)/gap_s2)): 49 | single_coordinate={'init_h':0, 'end_h':0, 'init_w':0, 'end_w':0, 'init_s':0, 'end_s':0} 50 | init_h = gap_h*x 51 | end_h = gap_h*x + img_h 52 | init_w = gap_w*y 53 | end_w = gap_w*y + img_w 54 | init_s = gap_s2*z 55 | end_s = gap_s2*z + img_s2 56 | single_coordinate['init_h'] = init_h 57 | single_coordinate['end_h'] = end_h 58 | single_coordinate['init_w'] = init_w 59 | single_coordinate['end_w'] = end_w 60 | single_coordinate['init_s'] = init_s 61 | single_coordinate['end_s'] = end_s 62 | # noise_patch1 = noise_im[init_s:end_s,init_h:end_h,init_w:end_w] 63 | patch_name = args.datasets_folder+'_'+im_name.replace('.tif','')+'_x'+str(x)+'_y'+str(y)+'_z'+str(z) 64 | # train_raw.append(noise_patch1.transpose(1,2,0)) 65 | name_list.append(patch_name) 66 | # print(' single_coordinate -----> ',single_coordinate) 67 | coordinate_list[patch_name] = single_coordinate 68 | return name_list, noise_im, coordinate_list 69 | 70 | def get_gap_s(args, img, stack_num): 71 | whole_w = img.shape[2] 72 | whole_h = img.shape[1] 73 | whole_s = img.shape[0] 74 | print('whole_w -----> ',whole_w) 75 | print('whole_h -----> ',whole_h) 76 | print('whole_s -----> ',whole_s) 77 | w_num = math.floor((whole_w-args.img_w)/args.gap_w)+1 78 | h_num = 
math.floor((whole_h-args.img_h)/args.gap_h)+1 79 | s_num = math.ceil(args.train_datasets_size/w_num/h_num/stack_num) 80 | print('w_num -----> ',w_num) 81 | print('h_num -----> ',h_num) 82 | print('s_num -----> ',s_num) 83 | gap_s = math.floor((whole_s-args.img_s*2)/(s_num-1)) 84 | print('gap_s -----> ',gap_s) 85 | return gap_s 86 | 87 | def shuffle_datasets(train_raw, train_GT, name_list): 88 | index_list = list(range(0, len(name_list))) 89 | # print('index_list -----> ',index_list) 90 | random.shuffle(index_list) 91 | random_index_list = index_list 92 | # print('index_list -----> ',index_list) 93 | new_name_list = list(range(0, len(name_list))) 94 | train_raw = np.array(train_raw) 95 | # print('train_raw shape -----> ',train_raw.shape) 96 | train_GT = np.array(train_GT) 97 | # print('train_GT shape -----> ',train_GT.shape) 98 | new_train_raw = train_raw 99 | new_train_GT = train_GT 100 | for i in range(0,len(random_index_list)): 101 | # print('i -----> ',i) 102 | new_train_raw[i,:,:,:] = train_raw[random_index_list[i],:,:,:] 103 | new_train_GT[i,:,:,:] = train_GT[random_index_list[i],:,:,:] 104 | new_name_list[i] = name_list[random_index_list[i]] 105 | # new_train_raw = np.expand_dims(new_train_raw, 4) 106 | # new_train_GT = np.expand_dims(new_train_GT, 4) 107 | return new_train_raw, new_train_GT, new_name_list 108 | 109 | def train_preprocess(args): 110 | img_h = args.img_h 111 | img_w = args.img_w 112 | img_s2 = args.img_s*2 113 | gap_h = args.gap_h 114 | gap_w = args.gap_w 115 | gap_s2 = args.gap_s*2 116 | im_folder = 'datasets//'+args.datasets_folder 117 | 118 | name_list = [] 119 | train_raw = [] 120 | train_GT = [] 121 | for im_name in list(os.walk(im_folder, topdown=False))[-1][-1]: 122 | print('im_name -----> ',im_name) 123 | im_dir = im_folder+'//'+im_name 124 | noise_im = tiff.imread(im_dir) 125 | print('noise_im shape -----> ',noise_im.shape) 126 | # print('noise_im max -----> ',noise_im.max()) 127 | # print('noise_im min -----> ',noise_im.min()) 128 | 
noise_im = (noise_im-noise_im.min()).astype(np.float32)/args.normalize_factor 129 | 130 | whole_w = noise_im.shape[2] 131 | whole_h = noise_im.shape[1] 132 | whole_s = noise_im.shape[0] 133 | print('int((whole_h-img_h+gap_h)/gap_h) -----> ',int((whole_h-img_h+gap_h)/gap_h)) 134 | print('int((whole_w-img_w+gap_w)/gap_w) -----> ',int((whole_w-img_w+gap_w)/gap_w)) 135 | print('int((whole_s-img_s2+gap_s2)/gap_s2) -----> ',int((whole_s-img_s2+gap_s2)/gap_s2)) 136 | for x in range(0,int((whole_h-img_h+gap_h)/gap_h)): 137 | for y in range(0,int((whole_w-img_w+gap_w)/gap_w)): 138 | for z in range(0,int((whole_s-img_s2+gap_s2)/gap_s2)): 139 | init_h = gap_h*x 140 | end_h = gap_h*x + img_h 141 | init_w = gap_w*y 142 | end_w = gap_w*y + img_w 143 | init_s = gap_s2*z 144 | end_s = gap_s2*z + img_s2 145 | noise_patch1 = noise_im[init_s:end_s:2,init_h:end_h,init_w:end_w] 146 | noise_patch2 = noise_im[init_s+1:end_s:2,init_h:end_h,init_w:end_w] 147 | patch_name = args.datasets_folder+'_x'+str(x)+'_y'+str(y)+'_z'+str(z) 148 | train_raw.append(noise_patch1.transpose(1,2,0)) 149 | train_GT.append(noise_patch2.transpose(1,2,0)) 150 | name_list.append(patch_name) 151 | return train_raw, train_GT, name_list, noise_im 152 | 153 | def test_preprocess(args): 154 | img_h = args.img_h 155 | img_w = args.img_w 156 | img_s2 = args.img_s 157 | gap_h = args.gap_h 158 | gap_w = args.gap_w 159 | gap_s2 = args.gap_s 160 | im_folder = 'datasets//'+args.datasets_folder 161 | 162 | name_list = [] 163 | train_raw = [] 164 | for im_name in list(os.walk(im_folder, topdown=False))[-1][-1]: 165 | # print('im_name -----> ',im_name) 166 | im_dir = im_folder+'//'+im_name 167 | noise_im = tiff.imread(im_dir) 168 | # print('noise_im shape -----> ',noise_im.shape) 169 | # print('noise_im max -----> ',noise_im.max()) 170 | # print('noise_im min -----> ',noise_im.min()) 171 | noise_im = (noise_im-noise_im.min()).astype(np.float32)/args.normalize_factor 172 | 173 | whole_w = noise_im.shape[2] 174 | whole_h = 
noise_im.shape[1] 175 | whole_s = noise_im.shape[0] 176 | # print('int((whole_h-img_h+gap_h)/gap_h) -----> ',int((whole_h-img_h+gap_h)/gap_h)) 177 | # print('int((whole_w-img_w+gap_w)/gap_w) -----> ',int((whole_w-img_w+gap_w)/gap_w)) 178 | # print('int((whole_s-img_s2+gap_s2)/gap_s2) -----> ',int((whole_s-img_s2+gap_s2)/gap_s2)) 179 | for x in range(0,int((whole_h-img_h+gap_h)/gap_h)): 180 | for y in range(0,int((whole_w-img_w+gap_w)/gap_w)): 181 | for z in range(0,int((whole_s-img_s2+gap_s2)/gap_s2)): 182 | init_h = gap_h*x 183 | end_h = gap_h*x + img_h 184 | init_w = gap_w*y 185 | end_w = gap_w*y + img_w 186 | init_s = gap_s2*z 187 | end_s = gap_s2*z + img_s2 188 | noise_patch1 = noise_im[init_s:end_s,init_h:end_h,init_w:end_w] 189 | patch_name = args.datasets_folder+'_x'+str(x)+'_y'+str(y)+'_z'+str(z) 190 | train_raw.append(noise_patch1.transpose(1,2,0)) 191 | name_list.append(patch_name) 192 | return train_raw, name_list, noise_im 193 | 194 | def test_preprocess_lessMemory (args): 195 | img_h = args.img_h 196 | img_w = args.img_w 197 | img_s2 = args.img_s 198 | gap_h = args.gap_h 199 | gap_w = args.gap_w 200 | gap_s2 = args.gap_s 201 | im_folder = 'datasets//'+args.datasets_folder 202 | 203 | name_list = [] 204 | # train_raw = [] 205 | coordinate_list={} 206 | for im_name in list(os.walk(im_folder, topdown=False))[-1][-1]: 207 | # print('im_name -----> ',im_name) 208 | im_dir = im_folder+'//'+im_name 209 | noise_im = tiff.imread(im_dir) 210 | # print('noise_im shape -----> ',noise_im.shape) 211 | # print('noise_im max -----> ',noise_im.max()) 212 | # print('noise_im min -----> ',noise_im.min()) 213 | if noise_im.shape[0]>args.test_datasize: 214 | noise_im = noise_im[0:args.test_datasize,:,:] 215 | noise_im = (noise_im-noise_im.min()).astype(np.float32)/args.normalize_factor 216 | 217 | whole_w = noise_im.shape[2] 218 | whole_h = noise_im.shape[1] 219 | whole_s = noise_im.shape[0] 220 | # print('int((whole_h-img_h+gap_h)/gap_h) -----> 
',int((whole_h-img_h+gap_h)/gap_h)) 221 | # print('int((whole_w-img_w+gap_w)/gap_w) -----> ',int((whole_w-img_w+gap_w)/gap_w)) 222 | # print('int((whole_s-img_s2+gap_s2)/gap_s2) -----> ',int((whole_s-img_s2+gap_s2)/gap_s2)) 223 | for x in range(0,int((whole_h-img_h+gap_h)/gap_h)): 224 | for y in range(0,int((whole_w-img_w+gap_w)/gap_w)): 225 | for z in range(0,int((whole_s-img_s2+gap_s2)/gap_s2)): 226 | single_coordinate={'init_h':0, 'end_h':0, 'init_w':0, 'end_w':0, 'init_s':0, 'end_s':0} 227 | init_h = gap_h*x 228 | end_h = gap_h*x + img_h 229 | init_w = gap_w*y 230 | end_w = gap_w*y + img_w 231 | init_s = gap_s2*z 232 | end_s = gap_s2*z + img_s2 233 | single_coordinate['init_h'] = init_h 234 | single_coordinate['end_h'] = end_h 235 | single_coordinate['init_w'] = init_w 236 | single_coordinate['end_w'] = end_w 237 | single_coordinate['init_s'] = init_s 238 | single_coordinate['end_s'] = end_s 239 | # noise_patch1 = noise_im[init_s:end_s,init_h:end_h,init_w:end_w] 240 | patch_name = args.datasets_folder+'_x'+str(x)+'_y'+str(y)+'_z'+str(z) 241 | # train_raw.append(noise_patch1.transpose(1,2,0)) 242 | name_list.append(patch_name) 243 | # print(' single_coordinate -----> ',single_coordinate) 244 | coordinate_list[patch_name] = single_coordinate 245 | return name_list, noise_im, coordinate_list 246 | 247 | def train_preprocess_lessMemory(args): 248 | img_h = args.img_h 249 | img_w = args.img_w 250 | img_s2 = args.img_s*2 251 | gap_h = args.gap_h 252 | gap_w = args.gap_w 253 | gap_s2 = args.gap_s*2 254 | im_folder = 'datasets//'+args.datasets_folder 255 | 256 | name_list = [] 257 | # train_raw = [] 258 | coordinate_list={} 259 | for im_name in list(os.walk(im_folder, topdown=False))[-1][-1]: 260 | # print('im_name -----> ',im_name) 261 | im_dir = im_folder+'//'+im_name 262 | noise_im = tiff.imread(im_dir) 263 | # print('noise_im shape -----> ',noise_im.shape) 264 | # print('noise_im max -----> ',noise_im.max()) 265 | # print('noise_im min -----> ',noise_im.min()) 266 
| noise_im = (noise_im-noise_im.min()).astype(np.float32)/args.normalize_factor 267 | 268 | whole_w = noise_im.shape[2] 269 | whole_h = noise_im.shape[1] 270 | whole_s = noise_im.shape[0] 271 | # print('int((whole_h-img_h+gap_h)/gap_h) -----> ',int((whole_h-img_h+gap_h)/gap_h)) 272 | # print('int((whole_w-img_w+gap_w)/gap_w) -----> ',int((whole_w-img_w+gap_w)/gap_w)) 273 | # print('int((whole_s-img_s2+gap_s2)/gap_s2) -----> ',int((whole_s-img_s2+gap_s2)/gap_s2)) 274 | for x in range(0,int((whole_h-img_h+gap_h)/gap_h)): 275 | for y in range(0,int((whole_w-img_w+gap_w)/gap_w)): 276 | for z in range(0,int((whole_s-img_s2+gap_s2)/gap_s2)): 277 | single_coordinate={'init_h':0, 'end_h':0, 'init_w':0, 'end_w':0, 'init_s':0, 'end_s':0} 278 | init_h = gap_h*x 279 | end_h = gap_h*x + img_h 280 | init_w = gap_w*y 281 | end_w = gap_w*y + img_w 282 | init_s = gap_s2*z 283 | end_s = gap_s2*z + img_s2 284 | single_coordinate['init_h'] = init_h 285 | single_coordinate['end_h'] = end_h 286 | single_coordinate['init_w'] = init_w 287 | single_coordinate['end_w'] = end_w 288 | single_coordinate['init_s'] = init_s 289 | single_coordinate['end_s'] = end_s 290 | # noise_patch1 = noise_im[init_s:end_s,init_h:end_h,init_w:end_w] 291 | patch_name = args.datasets_folder+'_x'+str(x)+'_y'+str(y)+'_z'+str(z) 292 | # train_raw.append(noise_patch1.transpose(1,2,0)) 293 | name_list.append(patch_name) 294 | # print(' single_coordinate -----> ',single_coordinate) 295 | coordinate_list[patch_name] = single_coordinate 296 | return name_list, noise_im, coordinate_list 297 | 298 | 299 | def shuffle_datasets_lessMemory(name_list): 300 | index_list = list(range(0, len(name_list))) 301 | # print('index_list -----> ',index_list) 302 | random.shuffle(index_list) 303 | random_index_list = index_list 304 | # print('index_list -----> ',index_list) 305 | new_name_list = list(range(0, len(name_list))) 306 | for i in range(0,len(random_index_list)): 307 | new_name_list[i] = name_list[random_index_list[i]] 308 | 
return new_name_list -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 2, June 1991 3 | 4 | Copyright (C) 1989, 1991 Free Software Foundation, Inc., 5 | 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA 6 | Everyone is permitted to copy and distribute verbatim copies 7 | of this license document, but changing it is not allowed. 8 | 9 | Preamble 10 | 11 | The licenses for most software are designed to take away your 12 | freedom to share and change it. By contrast, the GNU General Public 13 | License is intended to guarantee your freedom to share and change free 14 | software--to make sure the software is free for all its users. This 15 | General Public License applies to most of the Free Software 16 | Foundation's software and to any other program whose authors commit to 17 | using it. (Some other Free Software Foundation software is covered by 18 | the GNU Lesser General Public License instead.) You can apply it to 19 | your programs, too. 20 | 21 | When we speak of free software, we are referring to freedom, not 22 | price. Our General Public Licenses are designed to make sure that you 23 | have the freedom to distribute copies of free software (and charge for 24 | this service if you wish), that you receive source code or can get it 25 | if you want it, that you can change the software or use pieces of it 26 | in new free programs; and that you know you can do these things. 27 | 28 | To protect your rights, we need to make restrictions that forbid 29 | anyone to deny you these rights or to ask you to surrender the rights. 30 | These restrictions translate to certain responsibilities for you if you 31 | distribute copies of the software, or if you modify it. 
32 | 33 | For example, if you distribute copies of such a program, whether 34 | gratis or for a fee, you must give the recipients all the rights that 35 | you have. You must make sure that they, too, receive or can get the 36 | source code. And you must show them these terms so they know their 37 | rights. 38 | 39 | We protect your rights with two steps: (1) copyright the software, and 40 | (2) offer you this license which gives you legal permission to copy, 41 | distribute and/or modify the software. 42 | 43 | Also, for each author's protection and ours, we want to make certain 44 | that everyone understands that there is no warranty for this free 45 | software. If the software is modified by someone else and passed on, we 46 | want its recipients to know that what they have is not the original, so 47 | that any problems introduced by others will not reflect on the original 48 | authors' reputations. 49 | 50 | Finally, any free program is threatened constantly by software 51 | patents. We wish to avoid the danger that redistributors of a free 52 | program will individually obtain patent licenses, in effect making the 53 | program proprietary. To prevent this, we have made it clear that any 54 | patent must be licensed for everyone's free use or not licensed at all. 55 | 56 | The precise terms and conditions for copying, distribution and 57 | modification follow. 58 | 59 | GNU GENERAL PUBLIC LICENSE 60 | TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION 61 | 62 | 0. This License applies to any program or other work which contains 63 | a notice placed by the copyright holder saying it may be distributed 64 | under the terms of this General Public License. 
The "Program", below, 65 | refers to any such program or work, and a "work based on the Program" 66 | means either the Program or any derivative work under copyright law: 67 | that is to say, a work containing the Program or a portion of it, 68 | either verbatim or with modifications and/or translated into another 69 | language. (Hereinafter, translation is included without limitation in 70 | the term "modification".) Each licensee is addressed as "you". 71 | 72 | Activities other than copying, distribution and modification are not 73 | covered by this License; they are outside its scope. The act of 74 | running the Program is not restricted, and the output from the Program 75 | is covered only if its contents constitute a work based on the 76 | Program (independent of having been made by running the Program). 77 | Whether that is true depends on what the Program does. 78 | 79 | 1. You may copy and distribute verbatim copies of the Program's 80 | source code as you receive it, in any medium, provided that you 81 | conspicuously and appropriately publish on each copy an appropriate 82 | copyright notice and disclaimer of warranty; keep intact all the 83 | notices that refer to this License and to the absence of any warranty; 84 | and give any other recipients of the Program a copy of this License 85 | along with the Program. 86 | 87 | You may charge a fee for the physical act of transferring a copy, and 88 | you may at your option offer warranty protection in exchange for a fee. 89 | 90 | 2. You may modify your copy or copies of the Program or any portion 91 | of it, thus forming a work based on the Program, and copy and 92 | distribute such modifications or work under the terms of Section 1 93 | above, provided that you also meet all of these conditions: 94 | 95 | a) You must cause the modified files to carry prominent notices 96 | stating that you changed the files and the date of any change. 
97 | 98 | b) You must cause any work that you distribute or publish, that in 99 | whole or in part contains or is derived from the Program or any 100 | part thereof, to be licensed as a whole at no charge to all third 101 | parties under the terms of this License. 102 | 103 | c) If the modified program normally reads commands interactively 104 | when run, you must cause it, when started running for such 105 | interactive use in the most ordinary way, to print or display an 106 | announcement including an appropriate copyright notice and a 107 | notice that there is no warranty (or else, saying that you provide 108 | a warranty) and that users may redistribute the program under 109 | these conditions, and telling the user how to view a copy of this 110 | License. (Exception: if the Program itself is interactive but 111 | does not normally print such an announcement, your work based on 112 | the Program is not required to print an announcement.) 113 | 114 | These requirements apply to the modified work as a whole. If 115 | identifiable sections of that work are not derived from the Program, 116 | and can be reasonably considered independent and separate works in 117 | themselves, then this License, and its terms, do not apply to those 118 | sections when you distribute them as separate works. But when you 119 | distribute the same sections as part of a whole which is a work based 120 | on the Program, the distribution of the whole must be on the terms of 121 | this License, whose permissions for other licensees extend to the 122 | entire whole, and thus to each and every part regardless of who wrote it. 123 | 124 | Thus, it is not the intent of this section to claim rights or contest 125 | your rights to work written entirely by you; rather, the intent is to 126 | exercise the right to control the distribution of derivative or 127 | collective works based on the Program. 
128 | 129 | In addition, mere aggregation of another work not based on the Program 130 | with the Program (or with a work based on the Program) on a volume of 131 | a storage or distribution medium does not bring the other work under 132 | the scope of this License. 133 | 134 | 3. You may copy and distribute the Program (or a work based on it, 135 | under Section 2) in object code or executable form under the terms of 136 | Sections 1 and 2 above provided that you also do one of the following: 137 | 138 | a) Accompany it with the complete corresponding machine-readable 139 | source code, which must be distributed under the terms of Sections 140 | 1 and 2 above on a medium customarily used for software interchange; or, 141 | 142 | b) Accompany it with a written offer, valid for at least three 143 | years, to give any third party, for a charge no more than your 144 | cost of physically performing source distribution, a complete 145 | machine-readable copy of the corresponding source code, to be 146 | distributed under the terms of Sections 1 and 2 above on a medium 147 | customarily used for software interchange; or, 148 | 149 | c) Accompany it with the information you received as to the offer 150 | to distribute corresponding source code. (This alternative is 151 | allowed only for noncommercial distribution and only if you 152 | received the program in object code or executable form with such 153 | an offer, in accord with Subsection b above.) 154 | 155 | The source code for a work means the preferred form of the work for 156 | making modifications to it. For an executable work, complete source 157 | code means all the source code for all modules it contains, plus any 158 | associated interface definition files, plus the scripts used to 159 | control compilation and installation of the executable. 
However, as a 160 | special exception, the source code distributed need not include 161 | anything that is normally distributed (in either source or binary 162 | form) with the major components (compiler, kernel, and so on) of the 163 | operating system on which the executable runs, unless that component 164 | itself accompanies the executable. 165 | 166 | If distribution of executable or object code is made by offering 167 | access to copy from a designated place, then offering equivalent 168 | access to copy the source code from the same place counts as 169 | distribution of the source code, even though third parties are not 170 | compelled to copy the source along with the object code. 171 | 172 | 4. You may not copy, modify, sublicense, or distribute the Program 173 | except as expressly provided under this License. Any attempt 174 | otherwise to copy, modify, sublicense or distribute the Program is 175 | void, and will automatically terminate your rights under this License. 176 | However, parties who have received copies, or rights, from you under 177 | this License will not have their licenses terminated so long as such 178 | parties remain in full compliance. 179 | 180 | 5. You are not required to accept this License, since you have not 181 | signed it. However, nothing else grants you permission to modify or 182 | distribute the Program or its derivative works. These actions are 183 | prohibited by law if you do not accept this License. Therefore, by 184 | modifying or distributing the Program (or any work based on the 185 | Program), you indicate your acceptance of this License to do so, and 186 | all its terms and conditions for copying, distributing or modifying 187 | the Program or works based on it. 188 | 189 | 6. 
Each time you redistribute the Program (or any work based on the 190 | Program), the recipient automatically receives a license from the 191 | original licensor to copy, distribute or modify the Program subject to 192 | these terms and conditions. You may not impose any further 193 | restrictions on the recipients' exercise of the rights granted herein. 194 | You are not responsible for enforcing compliance by third parties to 195 | this License. 196 | 197 | 7. If, as a consequence of a court judgment or allegation of patent 198 | infringement or for any other reason (not limited to patent issues), 199 | conditions are imposed on you (whether by court order, agreement or 200 | otherwise) that contradict the conditions of this License, they do not 201 | excuse you from the conditions of this License. If you cannot 202 | distribute so as to satisfy simultaneously your obligations under this 203 | License and any other pertinent obligations, then as a consequence you 204 | may not distribute the Program at all. For example, if a patent 205 | license would not permit royalty-free redistribution of the Program by 206 | all those who receive copies directly or indirectly through you, then 207 | the only way you could satisfy both it and this License would be to 208 | refrain entirely from distribution of the Program. 209 | 210 | If any portion of this section is held invalid or unenforceable under 211 | any particular circumstance, the balance of the section is intended to 212 | apply and the section as a whole is intended to apply in other 213 | circumstances. 214 | 215 | It is not the purpose of this section to induce you to infringe any 216 | patents or other property right claims or to contest validity of any 217 | such claims; this section has the sole purpose of protecting the 218 | integrity of the free software distribution system, which is 219 | implemented by public license practices. 
Many people have made 220 | generous contributions to the wide range of software distributed 221 | through that system in reliance on consistent application of that 222 | system; it is up to the author/donor to decide if he or she is willing 223 | to distribute software through any other system and a licensee cannot 224 | impose that choice. 225 | 226 | This section is intended to make thoroughly clear what is believed to 227 | be a consequence of the rest of this License. 228 | 229 | 8. If the distribution and/or use of the Program is restricted in 230 | certain countries either by patents or by copyrighted interfaces, the 231 | original copyright holder who places the Program under this License 232 | may add an explicit geographical distribution limitation excluding 233 | those countries, so that distribution is permitted only in or among 234 | countries not thus excluded. In such case, this License incorporates 235 | the limitation as if written in the body of this License. 236 | 237 | 9. The Free Software Foundation may publish revised and/or new versions 238 | of the General Public License from time to time. Such new versions will 239 | be similar in spirit to the present version, but may differ in detail to 240 | address new problems or concerns. 241 | 242 | Each version is given a distinguishing version number. If the Program 243 | specifies a version number of this License which applies to it and "any 244 | later version", you have the option of following the terms and conditions 245 | either of that version or of any later version published by the Free 246 | Software Foundation. If the Program does not specify a version number of 247 | this License, you may choose any version ever published by the Free Software 248 | Foundation. 249 | 250 | 10. If you wish to incorporate parts of the Program into other free 251 | programs whose distribution conditions are different, write to the author 252 | to ask for permission. 
For software which is copyrighted by the Free 253 | Software Foundation, write to the Free Software Foundation; we sometimes 254 | make exceptions for this. Our decision will be guided by the two goals 255 | of preserving the free status of all derivatives of our free software and 256 | of promoting the sharing and reuse of software generally. 257 | 258 | NO WARRANTY 259 | 260 | 11. BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY 261 | FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN 262 | OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES 263 | PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED 264 | OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF 265 | MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS 266 | TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE 267 | PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, 268 | REPAIR OR CORRECTION. 269 | 270 | 12. IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 271 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MAY MODIFY AND/OR 272 | REDISTRIBUTE THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, 273 | INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING 274 | OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED 275 | TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY 276 | YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER 277 | PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE 278 | POSSIBILITY OF SUCH DAMAGES. 
279 | 280 | END OF TERMS AND CONDITIONS 281 | 282 | How to Apply These Terms to Your New Programs 283 | 284 | If you develop a new program, and you want it to be of the greatest 285 | possible use to the public, the best way to achieve this is to make it 286 | free software which everyone can redistribute and change under these terms. 287 | 288 | To do so, attach the following notices to the program. It is safest 289 | to attach them to the start of each source file to most effectively 290 | convey the exclusion of warranty; and each file should have at least 291 | the "copyright" line and a pointer to where the full notice is found. 292 | 293 | <one line to give the program's name and a brief idea of what it does.> 294 | Copyright (C) <year>  <name of author> 295 | 296 | This program is free software; you can redistribute it and/or modify 297 | it under the terms of the GNU General Public License as published by 298 | the Free Software Foundation; either version 2 of the License, or 299 | (at your option) any later version. 300 | 301 | This program is distributed in the hope that it will be useful, 302 | but WITHOUT ANY WARRANTY; without even the implied warranty of 303 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 304 | GNU General Public License for more details. 305 | 306 | You should have received a copy of the GNU General Public License along 307 | with this program; if not, write to the Free Software Foundation, Inc., 308 | 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 309 | 310 | Also add information on how to contact you by electronic and paper mail. 311 | 312 | If the program is interactive, make it output a short notice like this 313 | when it starts in an interactive mode: 314 | 315 | Gnomovision version 69, Copyright (C) year name of author 316 | Gnomovision comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 317 | This is free software, and you are welcome to redistribute it 318 | under certain conditions; type `show c' for details.
319 | 320 | The hypothetical commands `show w' and `show c' should show the appropriate 321 | parts of the General Public License. Of course, the commands you use may 322 | be called something other than `show w' and `show c'; they could even be 323 | mouse-clicks or menu items--whatever suits your program. 324 | 325 | You should also get your employer (if you work as a programmer) or your 326 | school, if any, to sign a "copyright disclaimer" for the program, if 327 | necessary. Here is a sample; alter the names: 328 | 329 | Yoyodyne, Inc., hereby disclaims all copyright interest in the program 330 | `Gnomovision' (which makes passes at compilers) written by James Hacker. 331 | 332 | , 1 April 1989 333 | Ty Coon, President of Vice 334 | 335 | This General Public License does not permit incorporating your program into 336 | proprietary programs. If your program is a subroutine library, you may 337 | consider it more useful to permit linking proprietary applications with the 338 | library. If this is what you want to do, use the GNU Lesser General 339 | Public License instead of this License. 340 | -------------------------------------------------------------------------------- /DeepCAD_pytorch/model_3DUnet.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from buildingblocks import Encoder, Decoder, FinalConv, DoubleConv, ExtResNetBlock, SingleConv 7 | from utils import create_feature_maps 8 | 9 | 10 | class UNet3D(nn.Module): 11 | """ 12 | 3DUnet model from 13 | `"3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation" 14 | `. 15 | 16 | Args: 17 | in_channels (int): number of input channels 18 | out_channels (int): number of output segmentation masks; 19 | Note that that the of out_channels might correspond to either 20 | different semantic classes or to different binary segmentation mask. 
21 | It's up to the user of the class to interpret the out_channels and 22 | use the proper loss criterion during training (i.e. CrossEntropyLoss (multi-class) 23 | or BCEWithLogitsLoss (two-class) respectively) 24 | f_maps (int, tuple): number of feature maps at each level of the encoder; if it's an integer the number 25 | of feature maps is given by the geometric progression: f_maps ^ k, k=1,2,3,4 26 | final_sigmoid (bool): if True apply element-wise nn.Sigmoid after the 27 | final 1x1 convolution, otherwise apply nn.Softmax. MUST be True if nn.BCELoss (two-class) is used 28 | to train the model. MUST be False if nn.CrossEntropyLoss (multi-class) is used to train the model. 29 | layer_order (string): determines the order of layers 30 | in `SingleConv` module. e.g. 'crg' stands for Conv3d+ReLU+GroupNorm3d. 31 | See `SingleConv` for more info 32 | init_channel_number (int): number of feature maps in the first conv layer of the encoder; default: 64 33 | num_groups (int): number of groups for the GroupNorm 34 | """ 35 | 36 | def __init__(self, in_channels, out_channels, final_sigmoid, f_maps=64, layer_order='cr', num_groups=8, 37 | **kwargs): 38 | super(UNet3D, self).__init__() 39 | 40 | if isinstance(f_maps, int): 41 | # use 4 levels in the encoder path as suggested in the paper 42 | f_maps = create_feature_maps(f_maps, number_of_fmaps=4) 43 | 44 | # create encoder path consisting of Encoder modules. 
The length of the encoder is equal to `len(f_maps)` 45 | # uses DoubleConv as a basic_module for the Encoder 46 | encoders = [] 47 | for i, out_feature_num in enumerate(f_maps): 48 | if i == 0: 49 | encoder = Encoder(in_channels, out_feature_num, apply_pooling=False, basic_module=DoubleConv, 50 | conv_layer_order=layer_order, num_groups=num_groups) 51 | else: 52 | encoder = Encoder(f_maps[i - 1], out_feature_num, basic_module=DoubleConv, 53 | conv_layer_order=layer_order, num_groups=num_groups) 54 | encoders.append(encoder) 55 | 56 | self.encoders = nn.ModuleList(encoders) 57 | 58 | # create decoder path consisting of the Decoder modules. The length of the decoder is equal to `len(f_maps) - 1` 59 | # uses DoubleConv as a basic_module for the Decoder 60 | decoders = [] 61 | reversed_f_maps = list(reversed(f_maps)) 62 | for i in range(len(reversed_f_maps) - 1): 63 | in_feature_num = reversed_f_maps[i] + reversed_f_maps[i + 1] 64 | out_feature_num = reversed_f_maps[i + 1] 65 | decoder = Decoder(in_feature_num, out_feature_num, basic_module=DoubleConv, 66 | conv_layer_order=layer_order, num_groups=num_groups) 67 | decoders.append(decoder) 68 | 69 | self.decoders = nn.ModuleList(decoders) 70 | 71 | # in the last layer a 1×1 convolution reduces the number of output 72 | # channels to the number of labels 73 | self.final_conv = nn.Conv3d(f_maps[0], out_channels, 1) 74 | 75 | if final_sigmoid: 76 | self.final_activation = nn.Sigmoid() 77 | else: 78 | self.final_activation = nn.Softmax(dim=1) 79 | 80 | def forward(self, x): 81 | # encoder part 82 | encoders_features = [] 83 | for encoder in self.encoders: 84 | x = encoder(x) 85 | # reverse the encoder outputs to be aligned with the decoder 86 | encoders_features.insert(0, x) 87 | 88 | # remove the last encoder's output from the list 89 | # !!remember: it's the 1st in the list 90 | encoders_features = encoders_features[1:] 91 | 92 | # decoder part 93 | for decoder, encoder_features in zip(self.decoders, encoders_features): 
94 | # pass the output from the corresponding encoder and the output 95 | # of the previous decoder 96 | x = decoder(encoder_features, x) 97 | 98 | x = self.final_conv(x) 99 | 100 | # apply final_activation (i.e. Sigmoid or Softmax) only for prediction. During training the network outputs 101 | # logits and it's up to the user to normalize it before visualising with tensorboard or computing validation metric 102 | # if not self.training: 103 | # x = self.final_activation(x) 104 | 105 | return x 106 | 107 | 108 | class ResidualUNet3D(nn.Module): 109 | """ 110 | Residual 3DUnet model implementation based on https://arxiv.org/pdf/1706.00120.pdf. 111 | Uses ExtResNetBlock instead of DoubleConv as a basic building block as well as summation joining instead 112 | of concatenation joining. Since the model effectively becomes a residual net, in theory it allows for deeper UNet. 113 | 114 | Args: 115 | in_channels (int): number of input channels 116 | out_channels (int): number of output segmentation masks; 117 | Note that that the of out_channels might correspond to either 118 | different semantic classes or to different binary segmentation mask. 119 | It's up to the user of the class to interpret the out_channels and 120 | use the proper loss criterion during training (i.e. NLLLoss (multi-class) 121 | or BCELoss (two-class) respectively) 122 | f_maps (int, tuple): number of feature maps at each level of the encoder; if it's an integer the number 123 | of feature maps is given by the geometric progression: f_maps ^ k, k=1,2,3,4,5 124 | final_sigmoid (bool): if True apply element-wise nn.Sigmoid after the 125 | final 1x1 convolution, otherwise apply nn.Softmax. MUST be True if nn.BCELoss (two-class) is used 126 | to train the model. MUST be False if nn.CrossEntropyLoss (multi-class) is used to train the model. 127 | conv_layer_order (string): determines the order of layers 128 | in `SingleConv` module. e.g. 'crg' stands for Conv3d+ReLU+GroupNorm3d. 
129 | See `SingleConv` for more info 130 | init_channel_number (int): number of feature maps in the first conv layer of the encoder; default: 64 131 | num_groups (int): number of groups for the GroupNorm 132 | skip_final_activation (bool): if True, skips the final normalization layer (sigmoid/softmax) and returns the 133 | logits directly 134 | """ 135 | 136 | def __init__(self, in_channels, out_channels, final_sigmoid, f_maps=32, conv_layer_order='cge', num_groups=8, 137 | skip_final_activation=False, **kwargs): 138 | super(ResidualUNet3D, self).__init__() 139 | 140 | if isinstance(f_maps, int): 141 | # use 5 levels in the encoder path as suggested in the paper 142 | f_maps = create_feature_maps(f_maps, number_of_fmaps=5) 143 | 144 | # create encoder path consisting of Encoder modules. The length of the encoder is equal to `len(f_maps)` 145 | # uses ExtResNetBlock as a basic_module for the Encoder 146 | encoders = [] 147 | for i, out_feature_num in enumerate(f_maps): 148 | if i == 0: 149 | encoder = Encoder(in_channels, out_feature_num, apply_pooling=False, basic_module=ExtResNetBlock, 150 | conv_layer_order=conv_layer_order, num_groups=num_groups) 151 | else: 152 | encoder = Encoder(f_maps[i - 1], out_feature_num, basic_module=ExtResNetBlock, 153 | conv_layer_order=conv_layer_order, num_groups=num_groups) 154 | encoders.append(encoder) 155 | 156 | self.encoders = nn.ModuleList(encoders) 157 | 158 | # create decoder path consisting of the Decoder modules. 
The length of the decoder is equal to `len(f_maps) - 1` 159 | # uses ExtResNetBlock as a basic_module for the Decoder 160 | decoders = [] 161 | reversed_f_maps = list(reversed(f_maps)) 162 | for i in range(len(reversed_f_maps) - 1): 163 | decoder = Decoder(reversed_f_maps[i], reversed_f_maps[i + 1], basic_module=ExtResNetBlock, 164 | conv_layer_order=conv_layer_order, num_groups=num_groups) 165 | decoders.append(decoder) 166 | 167 | self.decoders = nn.ModuleList(decoders) 168 | 169 | # in the last layer a 1×1 convolution reduces the number of output 170 | # channels to the number of labels 171 | self.final_conv = nn.Conv3d(f_maps[0], out_channels, 1) 172 | 173 | if not skip_final_activation: 174 | if final_sigmoid: 175 | self.final_activation = nn.Sigmoid() 176 | else: 177 | self.final_activation = nn.Softmax(dim=1) 178 | else: 179 | self.final_activation = None 180 | 181 | def forward(self, x): 182 | # encoder part 183 | encoders_features = [] 184 | for encoder in self.encoders: 185 | x = encoder(x) 186 | # reverse the encoder outputs to be aligned with the decoder 187 | encoders_features.insert(0, x) 188 | 189 | # remove the last encoder's output from the list 190 | # !!remember: it's the 1st in the list 191 | encoders_features = encoders_features[1:] 192 | 193 | # decoder part 194 | for decoder, encoder_features in zip(self.decoders, encoders_features): 195 | # pass the output from the corresponding encoder and the output 196 | # of the previous decoder 197 | x = decoder(encoder_features, x) 198 | 199 | x = self.final_conv(x) 200 | 201 | # apply final_activation (i.e. Sigmoid or Softmax) only for prediction. 
During training the network outputs 202 | # logits and it's up to the user to normalize it before visualising with tensorboard or computing validation metric 203 | if not self.training and self.final_activation is not None: 204 | x = self.final_activation(x) 205 | 206 | return x 207 | 208 | 209 | class Noise2NoiseUNet3D(nn.Module): 210 | """ 211 | Residual 3DUnet model implementation based on https://arxiv.org/pdf/1706.00120.pdf. 212 | Uses ExtResNetBlock instead of DoubleConv as a basic building block as well as summation joining instead 213 | of concatenation joining. Since the model effectively becomes a residual net, in theory it allows for deeper UNet. 214 | 215 | Args: 216 | in_channels (int): number of input channels 217 | out_channels (int): number of output segmentation masks; 218 | Note that that the of out_channels might correspond to either 219 | different semantic classes or to different binary segmentation mask. 220 | It's up to the user of the class to interpret the out_channels and 221 | use the proper loss criterion during training (i.e. NLLLoss (multi-class) 222 | or BCELoss (two-class) respectively) 223 | f_maps (int, tuple): number of feature maps at each level of the encoder; if it's an integer the number 224 | of feature maps is given by the geometric progression: f_maps ^ k, k=1,2,3,4,5 225 | init_channel_number (int): number of feature maps in the first conv layer of the encoder; default: 64 226 | num_groups (int): number of groups for the GroupNorm 227 | """ 228 | 229 | def __init__(self, in_channels, out_channels, f_maps=16, num_groups=8, **kwargs): 230 | super(Noise2NoiseUNet3D, self).__init__() 231 | 232 | # Use LeakyReLU activation everywhere except the last layer 233 | conv_layer_order = 'clg' 234 | 235 | if isinstance(f_maps, int): 236 | # use 5 levels in the encoder path as suggested in the paper 237 | f_maps = create_feature_maps(f_maps, number_of_fmaps=5) 238 | 239 | # create encoder path consisting of Encoder modules. 
The length of the encoder is equal to `len(f_maps)` 240 | # uses DoubleConv as a basic_module for the Encoder 241 | encoders = [] 242 | for i, out_feature_num in enumerate(f_maps): 243 | if i == 0: 244 | encoder = Encoder(in_channels, out_feature_num, apply_pooling=False, basic_module=DoubleConv, 245 | conv_layer_order=conv_layer_order, num_groups=num_groups) 246 | else: 247 | encoder = Encoder(f_maps[i - 1], out_feature_num, basic_module=DoubleConv, 248 | conv_layer_order=conv_layer_order, num_groups=num_groups) 249 | encoders.append(encoder) 250 | 251 | self.encoders = nn.ModuleList(encoders) 252 | 253 | # create decoder path consisting of the Decoder modules. The length of the decoder is equal to `len(f_maps) - 1` 254 | # uses DoubleConv as a basic_module for the Decoder 255 | decoders = [] 256 | reversed_f_maps = list(reversed(f_maps)) 257 | for i in range(len(reversed_f_maps) - 1): 258 | in_feature_num = reversed_f_maps[i] + reversed_f_maps[i + 1] 259 | out_feature_num = reversed_f_maps[i + 1] 260 | decoder = Decoder(in_feature_num, out_feature_num, basic_module=DoubleConv, 261 | conv_layer_order=conv_layer_order, num_groups=num_groups) 262 | decoders.append(decoder) 263 | 264 | self.decoders = nn.ModuleList(decoders) 265 | 266 | # 1x1x1 conv + simple ReLU in the final convolution 267 | self.final_conv = SingleConv(f_maps[0], out_channels, kernel_size=1, order='cr', padding=0) 268 | 269 | def forward(self, x): 270 | # encoder part 271 | encoders_features = [] 272 | for encoder in self.encoders: 273 | x = encoder(x) 274 | # reverse the encoder outputs to be aligned with the decoder 275 | encoders_features.insert(0, x) 276 | 277 | # remove the last encoder's output from the list 278 | # !!remember: it's the 1st in the list 279 | encoders_features = encoders_features[1:] 280 | 281 | # decoder part 282 | for decoder, encoder_features in zip(self.decoders, encoders_features): 283 | # pass the output from the corresponding encoder and the output 284 | # of the 
previous decoder 285 | x = decoder(encoder_features, x) 286 | 287 | x = self.final_conv(x) 288 | 289 | return x 290 | 291 | 292 | def get_model(config): 293 | def _model_class(class_name): 294 | m = importlib.import_module('unet3d.model') #向a模块中导入c.py中的对象 295 | clazz = getattr(m, class_name) #getattr() 函数用于返回一个对象属性值。 296 | return clazz 297 | 298 | assert 'model' in config, 'Could not find model configuration' 299 | model_config = config['model'] 300 | model_class = _model_class(model_config['name']) 301 | return model_class(**model_config) 302 | 303 | 304 | ###############################################Supervised Tags 3DUnet################################################### 305 | 306 | class TagsUNet3D(nn.Module): 307 | """ 308 | Supervised tags 3DUnet 309 | Args: 310 | in_channels (int): number of input channels 311 | out_channels (int): number of output channels; since most often we're trying to learn 312 | 3D unit vectors we use 3 as a default value 313 | output_heads (int): number of output heads from the network, each head corresponds to different 314 | semantic tag/direction to be learned 315 | conv_layer_order (string): determines the order of layers 316 | in `DoubleConv` module. e.g. 'crg' stands for Conv3d+ReLU+GroupNorm3d. 
317 | See `DoubleConv` for more info 318 | init_channel_number (int): number of feature maps in the first conv layer of the encoder; default: 64 319 | """ 320 | 321 | def __init__(self, in_channels, out_channels=3, output_heads=1, conv_layer_order='crg', init_channel_number=32, 322 | **kwargs): 323 | super(TagsUNet3D, self).__init__() 324 | 325 | # number of groups for the GroupNorm 326 | num_groups = min(init_channel_number // 2, 32) 327 | 328 | # encoder path consist of 4 subsequent Encoder modules 329 | # the number of features maps is the same as in the paper 330 | self.encoders = nn.ModuleList([ 331 | Encoder(in_channels, init_channel_number, apply_pooling=False, conv_layer_order=conv_layer_order, 332 | num_groups=num_groups), 333 | Encoder(init_channel_number, 2 * init_channel_number, conv_layer_order=conv_layer_order, 334 | num_groups=num_groups), 335 | Encoder(2 * init_channel_number, 4 * init_channel_number, conv_layer_order=conv_layer_order, 336 | num_groups=num_groups), 337 | Encoder(4 * init_channel_number, 8 * init_channel_number, conv_layer_order=conv_layer_order, 338 | num_groups=num_groups) 339 | ]) 340 | 341 | self.decoders = nn.ModuleList([ 342 | Decoder(4 * init_channel_number + 8 * init_channel_number, 4 * init_channel_number, 343 | conv_layer_order=conv_layer_order, num_groups=num_groups), 344 | Decoder(2 * init_channel_number + 4 * init_channel_number, 2 * init_channel_number, 345 | conv_layer_order=conv_layer_order, num_groups=num_groups), 346 | Decoder(init_channel_number + 2 * init_channel_number, init_channel_number, 347 | conv_layer_order=conv_layer_order, num_groups=num_groups) 348 | ]) 349 | 350 | self.final_heads = nn.ModuleList( 351 | [FinalConv(init_channel_number, out_channels, num_groups=num_groups) for _ in 352 | range(output_heads)]) 353 | 354 | def forward(self, x): 355 | # encoder part 356 | encoders_features = [] 357 | for encoder in self.encoders: 358 | x = encoder(x) 359 | # reverse the encoder outputs to be aligned with the 
decoder 360 | encoders_features.insert(0, x) 361 | 362 | # remove the last encoder's output from the list 363 | # !!remember: it's the 1st in the list 364 | encoders_features = encoders_features[1:] 365 | 366 | # decoder part 367 | for decoder, encoder_features in zip(self.decoders, encoders_features): 368 | # pass the output from the corresponding encoder and the output 369 | # of the previous decoder 370 | x = decoder(encoder_features, x) 371 | 372 | # apply final layer per each output head 373 | tags = [final_head(x) for final_head in self.final_heads] 374 | 375 | # normalize directions with L2 norm 376 | return [tag / torch.norm(tag, p=2, dim=1).detach().clamp(min=1e-8) for tag in tags] 377 | 378 | 379 | ################################################Distance transform 3DUNet############################################## 380 | class DistanceTransformUNet3D(nn.Module): 381 | """ 382 | Predict Distance Transform to the boundary signal based on the output from the Tags3DUnet. Fore training use either: 383 | 1. PixelWiseCrossEntropyLoss if the distance transform is quantized (classification) 384 | 2. MSELoss if the distance transform is continuous (regression) 385 | Args: 386 | in_channels (int): number of input channels 387 | out_channels (int): number of output segmentation masks; 388 | Note that that the of out_channels might correspond to either 389 | different semantic classes or to different binary segmentation mask. 390 | It's up to the user of the class to interpret the out_channels and 391 | use the proper loss criterion during training (i.e. 
NLLLoss (multi-class) 392 | or BCELoss (two-class) respectively) 393 | final_sigmoid (bool): 'sigmoid'/'softmax' whether element-wise nn.Sigmoid or nn.Softmax should be applied after 394 | the final 1x1 convolution 395 | init_channel_number (int): number of feature maps in the first conv layer of the encoder; default: 64 396 | """ 397 | 398 | def __init__(self, in_channels, out_channels, final_sigmoid, init_channel_number=32, **kwargs): 399 | super(DistanceTransformUNet3D, self).__init__() 400 | 401 | # number of groups for the GroupNorm 402 | num_groups = min(init_channel_number // 2, 32) 403 | 404 | # encoder path consist of 4 subsequent Encoder modules 405 | # the number of features maps is the same as in the paper 406 | self.encoders = nn.ModuleList([ 407 | Encoder(in_channels, init_channel_number, apply_pooling=False, conv_layer_order='crg', 408 | num_groups=num_groups), 409 | Encoder(init_channel_number, 2 * init_channel_number, pool_type='avg', conv_layer_order='crg', 410 | num_groups=num_groups) 411 | ]) 412 | 413 | self.decoders = nn.ModuleList([ 414 | Decoder(3 * init_channel_number, init_channel_number, conv_layer_order='crg', num_groups=num_groups) 415 | ]) 416 | 417 | # in the last layer a 1×1 convolution reduces the number of output 418 | # channels to the number of labels 419 | self.final_conv = nn.Conv3d(init_channel_number, out_channels, 1) 420 | 421 | if final_sigmoid: 422 | self.final_activation = nn.Sigmoid() 423 | else: 424 | self.final_activation = nn.Softmax(dim=1) 425 | 426 | def forward(self, inputs): 427 | # allow multiple heads 428 | if isinstance(inputs, list) or isinstance(inputs, tuple): 429 | x = torch.cat(inputs, dim=1) 430 | else: 431 | x = inputs 432 | 433 | # encoder part 434 | encoders_features = [] 435 | for encoder in self.encoders: 436 | x = encoder(x) 437 | # reverse the encoder outputs to be aligned with the decoder 438 | encoders_features.insert(0, x) 439 | 440 | # remove the last encoder's output from the list 441 | # 
!!remember: it's the 1st in the list 442 | encoders_features = encoders_features[1:] 443 | 444 | # decoder part 445 | for decoder, encoder_features in zip(self.decoders, encoders_features): 446 | # pass the output from the corresponding encoder and the output 447 | # of the previous decoder 448 | x = decoder(encoder_features, x) 449 | 450 | # apply final 1x1 convolution 451 | x = self.final_conv(x) 452 | 453 | # apply final_activation (i.e. Sigmoid or Softmax) only for prediction. During training the network outputs 454 | # logits and it's up to the user to normalize it before visualising with tensorboard or computing validation metric 455 | if not self.training: 456 | x = self.final_activation(x) 457 | 458 | return x 459 | 460 | 461 | class EndToEndDTUNet3D(nn.Module): 462 | def __init__(self, tags_in_channels, tags_out_channels, tags_output_heads, tags_init_channel_number, 463 | dt_in_channels, dt_out_channels, dt_final_sigmoid, dt_init_channel_number, 464 | tags_net_path=None, dt_net_path=None, **kwargs): 465 | super(EndToEndDTUNet3D, self).__init__() 466 | 467 | self.tags_net = TagsUNet3D(tags_in_channels, tags_out_channels, tags_output_heads, 468 | init_channel_number=tags_init_channel_number) 469 | if tags_net_path is not None: 470 | # load pre-trained TagsUNet3D 471 | self.tags_net = self._load_net(tags_net_path, self.tags_net) 472 | 473 | self.dt_net = DistanceTransformUNet3D(dt_in_channels, dt_out_channels, dt_final_sigmoid, 474 | init_channel_number=dt_init_channel_number) 475 | if dt_net_path is not None: 476 | # load pre-trained DistanceTransformUNet3D 477 | self.dt_net = self._load_net(dt_net_path, self.dt_net) 478 | 479 | @staticmethod 480 | def _load_net(checkpoint_path, model): 481 | state = torch.load(checkpoint_path) 482 | model.load_state_dict(state['model_state_dict']) 483 | return model 484 | 485 | def forward(self, x): 486 | x = self.tags_net(x) 487 | return self.dt_net(x) 488 | 
-------------------------------------------------------------------------------- /DeepCAD_pytorch/data_process.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import argparse 3 | import os 4 | import tifffile as tiff 5 | import time 6 | import datetime 7 | import random 8 | from skimage import io 9 | import logging 10 | import math 11 | 12 | 13 | def shuffle_datasets(train_raw, train_GT, name_list): 14 | index_list = list(range(0, len(name_list))) 15 | # print('index_list -----> ',index_list) 16 | random.shuffle(index_list) 17 | random_index_list = index_list 18 | # print('index_list -----> ',index_list) 19 | new_name_list = list(range(0, len(name_list))) 20 | train_raw = np.array(train_raw) 21 | # print('train_raw shape -----> ',train_raw.shape) 22 | train_GT = np.array(train_GT) 23 | # print('train_GT shape -----> ',train_GT.shape) 24 | new_train_raw = train_raw 25 | new_train_GT = train_GT 26 | for i in range(0,len(random_index_list)): 27 | # print('i -----> ',i) 28 | new_train_raw[i,:,:,:] = train_raw[random_index_list[i],:,:,:] 29 | new_train_GT[i,:,:,:] = train_GT[random_index_list[i],:,:,:] 30 | new_name_list[i] = name_list[random_index_list[i]] 31 | # new_train_raw = np.expand_dims(new_train_raw, 4) 32 | # new_train_GT = np.expand_dims(new_train_GT, 4) 33 | return new_train_raw, new_train_GT, new_name_list 34 | 35 | def train_preprocess(args): 36 | img_h = args.img_h 37 | img_w = args.img_w 38 | img_s2 = args.img_s*2 39 | gap_h = args.gap_h 40 | gap_w = args.gap_w 41 | gap_s2 = args.gap_s*2 42 | im_folder = 'datasets//'+args.datasets_folder 43 | 44 | name_list = [] 45 | train_raw = [] 46 | train_GT = [] 47 | for im_name in list(os.walk(im_folder, topdown=False))[-1][-1]: 48 | print('im_name -----> ',im_name) 49 | im_dir = im_folder+'//'+im_name 50 | noise_im = tiff.imread(im_dir) 51 | print('noise_im shape -----> ',noise_im.shape) 52 | # print('noise_im max -----> ',noise_im.max()) 53 | # 
print('noise_im min -----> ',noise_im.min()) 54 | noise_im = (noise_im-noise_im.min()).astype(np.float32)/args.normalize_factor 55 | 56 | whole_w = noise_im.shape[2] 57 | whole_h = noise_im.shape[1] 58 | whole_s = noise_im.shape[0] 59 | print('int((whole_h-img_h+gap_h)/gap_h) -----> ',int((whole_h-img_h+gap_h)/gap_h)) 60 | print('int((whole_w-img_w+gap_w)/gap_w) -----> ',int((whole_w-img_w+gap_w)/gap_w)) 61 | print('int((whole_s-img_s2+gap_s2)/gap_s2) -----> ',int((whole_s-img_s2+gap_s2)/gap_s2)) 62 | for x in range(0,int((whole_h-img_h+gap_h)/gap_h)): 63 | for y in range(0,int((whole_w-img_w+gap_w)/gap_w)): 64 | for z in range(0,int((whole_s-img_s2+gap_s2)/gap_s2)): 65 | init_h = gap_h*x 66 | end_h = gap_h*x + img_h 67 | init_w = gap_w*y 68 | end_w = gap_w*y + img_w 69 | init_s = gap_s2*z 70 | end_s = gap_s2*z + img_s2 71 | noise_patch1 = noise_im[init_s:end_s:2,init_h:end_h,init_w:end_w] 72 | noise_patch2 = noise_im[init_s+1:end_s:2,init_h:end_h,init_w:end_w] 73 | patch_name = args.datasets_folder+'_x'+str(x)+'_y'+str(y)+'_z'+str(z) 74 | train_raw.append(noise_patch1.transpose(1,2,0)) 75 | train_GT.append(noise_patch2.transpose(1,2,0)) 76 | name_list.append(patch_name) 77 | return train_raw, train_GT, name_list, noise_im 78 | 79 | def test_preprocess(args): 80 | img_h = args.img_h 81 | img_w = args.img_w 82 | img_s2 = args.img_s 83 | gap_h = args.gap_h 84 | gap_w = args.gap_w 85 | gap_s2 = args.gap_s 86 | im_folder = 'datasets//'+args.datasets_folder 87 | 88 | name_list = [] 89 | train_raw = [] 90 | for im_name in list(os.walk(im_folder, topdown=False))[-1][-1]: 91 | # print('im_name -----> ',im_name) 92 | im_dir = im_folder+'//'+im_name 93 | noise_im = tiff.imread(im_dir) 94 | # print('noise_im shape -----> ',noise_im.shape) 95 | # print('noise_im max -----> ',noise_im.max()) 96 | # print('noise_im min -----> ',noise_im.min()) 97 | noise_im = (noise_im-noise_im.min()).astype(np.float32)/args.normalize_factor 98 | 99 | whole_w = noise_im.shape[2] 100 | whole_h = 
noise_im.shape[1] 101 | whole_s = noise_im.shape[0] 102 | # print('int((whole_h-img_h+gap_h)/gap_h) -----> ',int((whole_h-img_h+gap_h)/gap_h)) 103 | # print('int((whole_w-img_w+gap_w)/gap_w) -----> ',int((whole_w-img_w+gap_w)/gap_w)) 104 | # print('int((whole_s-img_s2+gap_s2)/gap_s2) -----> ',int((whole_s-img_s2+gap_s2)/gap_s2)) 105 | for x in range(0,int((whole_h-img_h+gap_h)/gap_h)): 106 | for y in range(0,int((whole_w-img_w+gap_w)/gap_w)): 107 | for z in range(0,int((whole_s-img_s2+gap_s2)/gap_s2)): 108 | init_h = gap_h*x 109 | end_h = gap_h*x + img_h 110 | init_w = gap_w*y 111 | end_w = gap_w*y + img_w 112 | init_s = gap_s2*z 113 | end_s = gap_s2*z + img_s2 114 | noise_patch1 = noise_im[init_s:end_s,init_h:end_h,init_w:end_w] 115 | patch_name = args.datasets_folder+'_x'+str(x)+'_y'+str(y)+'_z'+str(z) 116 | train_raw.append(noise_patch1.transpose(1,2,0)) 117 | name_list.append(patch_name) 118 | return train_raw, name_list, noise_im 119 | 120 | def test_preprocess_lessMemory (args): 121 | img_h = args.img_h 122 | img_w = args.img_w 123 | img_s2 = args.img_s 124 | gap_h = args.gap_h 125 | gap_w = args.gap_w 126 | gap_s2 = args.gap_s 127 | im_folder = 'datasets//'+args.datasets_folder 128 | 129 | name_list = [] 130 | # train_raw = [] 131 | coordinate_list={} 132 | for im_name in list(os.walk(im_folder, topdown=False))[-1][-1]: 133 | # print('im_name -----> ',im_name) 134 | im_dir = im_folder+'//'+im_name 135 | noise_im = tiff.imread(im_dir) 136 | # print('noise_im shape -----> ',noise_im.shape) 137 | # print('noise_im max -----> ',noise_im.max()) 138 | # print('noise_im min -----> ',noise_im.min()) 139 | if noise_im.shape[0]>args.test_datasize: 140 | noise_im = noise_im[0:args.test_datasize,:,:] 141 | noise_im = (noise_im-noise_im.min()).astype(np.float32)/args.normalize_factor 142 | 143 | whole_w = noise_im.shape[2] 144 | whole_h = noise_im.shape[1] 145 | whole_s = noise_im.shape[0] 146 | # print('int((whole_h-img_h+gap_h)/gap_h) -----> 
',int((whole_h-img_h+gap_h)/gap_h)) 147 | # print('int((whole_w-img_w+gap_w)/gap_w) -----> ',int((whole_w-img_w+gap_w)/gap_w)) 148 | # print('int((whole_s-img_s2+gap_s2)/gap_s2) -----> ',int((whole_s-img_s2+gap_s2)/gap_s2)) 149 | for x in range(0,int((whole_h-img_h+gap_h)/gap_h)): 150 | for y in range(0,int((whole_w-img_w+gap_w)/gap_w)): 151 | for z in range(0,int((whole_s-img_s2+gap_s2)/gap_s2)): 152 | single_coordinate={'init_h':0, 'end_h':0, 'init_w':0, 'end_w':0, 'init_s':0, 'end_s':0} 153 | init_h = gap_h*x 154 | end_h = gap_h*x + img_h 155 | init_w = gap_w*y 156 | end_w = gap_w*y + img_w 157 | init_s = gap_s2*z 158 | end_s = gap_s2*z + img_s2 159 | single_coordinate['init_h'] = init_h 160 | single_coordinate['end_h'] = end_h 161 | single_coordinate['init_w'] = init_w 162 | single_coordinate['end_w'] = end_w 163 | single_coordinate['init_s'] = init_s 164 | single_coordinate['end_s'] = end_s 165 | # noise_patch1 = noise_im[init_s:end_s,init_h:end_h,init_w:end_w] 166 | patch_name = args.datasets_folder+'_x'+str(x)+'_y'+str(y)+'_z'+str(z) 167 | # train_raw.append(noise_patch1.transpose(1,2,0)) 168 | name_list.append(patch_name) 169 | # print(' single_coordinate -----> ',single_coordinate) 170 | coordinate_list[patch_name] = single_coordinate 171 | return name_list, noise_im, coordinate_list 172 | 173 | def get_gap_s(args, img, stack_num): 174 | whole_w = img.shape[2] 175 | whole_h = img.shape[1] 176 | whole_s = img.shape[0] 177 | print('whole_w -----> ',whole_w) 178 | print('whole_h -----> ',whole_h) 179 | print('whole_s -----> ',whole_s) 180 | w_num = math.floor((whole_w-args.img_w)/args.gap_w)+1 181 | h_num = math.floor((whole_h-args.img_h)/args.gap_h)+1 182 | s_num = math.ceil(args.train_datasets_size/w_num/h_num/stack_num) 183 | print('w_num -----> ',w_num) 184 | print('h_num -----> ',h_num) 185 | print('s_num -----> ',s_num) 186 | gap_s = math.floor((whole_s-args.img_s*2)/(s_num-1)) 187 | print('gap_s -----> ',gap_s) 188 | return gap_s 189 | 190 | def 
train_preprocess_lessMemory(args): 191 | img_h = args.img_h 192 | img_w = args.img_w 193 | img_s2 = args.img_s*2 194 | gap_h = args.gap_h 195 | gap_w = args.gap_w 196 | gap_s2 = args.gap_s*2 197 | im_folder = args.datasets_path+'//'+args.datasets_folder 198 | 199 | name_list = [] 200 | # train_raw = [] 201 | coordinate_list={} 202 | for im_name in list(os.walk(im_folder, topdown=False))[-1][-1]: 203 | # print('im_name -----> ',im_name) 204 | im_dir = im_folder+'//'+im_name 205 | noise_im = tiff.imread(im_dir) 206 | if noise_im.shape[0]>args.select_img_num: 207 | noise_im = noise_im[0:args.select_img_num,:,:] 208 | gap_s2 = get_gap_s(args, noise_im) 209 | # print('noise_im shape -----> ',noise_im.shape) 210 | # print('noise_im max -----> ',noise_im.max()) 211 | # print('noise_im min -----> ',noise_im.min()) 212 | noise_im = (noise_im-noise_im.min()).astype(np.float32)/args.normalize_factor 213 | 214 | whole_w = noise_im.shape[2] 215 | whole_h = noise_im.shape[1] 216 | whole_s = noise_im.shape[0] 217 | # print('int((whole_h-img_h+gap_h)/gap_h) -----> ',int((whole_h-img_h+gap_h)/gap_h)) 218 | # print('int((whole_w-img_w+gap_w)/gap_w) -----> ',int((whole_w-img_w+gap_w)/gap_w)) 219 | # print('int((whole_s-img_s2+gap_s2)/gap_s2) -----> ',int((whole_s-img_s2+gap_s2)/gap_s2)) 220 | for x in range(0,int((whole_h-img_h+gap_h)/gap_h)): 221 | for y in range(0,int((whole_w-img_w+gap_w)/gap_w)): 222 | for z in range(0,int((whole_s-img_s2+gap_s2)/gap_s2)): 223 | single_coordinate={'init_h':0, 'end_h':0, 'init_w':0, 'end_w':0, 'init_s':0, 'end_s':0} 224 | init_h = gap_h*x 225 | end_h = gap_h*x + img_h 226 | init_w = gap_w*y 227 | end_w = gap_w*y + img_w 228 | init_s = gap_s2*z 229 | end_s = gap_s2*z + img_s2 230 | single_coordinate['init_h'] = init_h 231 | single_coordinate['end_h'] = end_h 232 | single_coordinate['init_w'] = init_w 233 | single_coordinate['end_w'] = end_w 234 | single_coordinate['init_s'] = init_s 235 | single_coordinate['end_s'] = end_s 236 | # noise_patch1 = 
noise_im[init_s:end_s,init_h:end_h,init_w:end_w] 237 | patch_name = args.datasets_folder+'_x'+str(x)+'_y'+str(y)+'_z'+str(z) 238 | # train_raw.append(noise_patch1.transpose(1,2,0)) 239 | name_list.append(patch_name) 240 | # print(' single_coordinate -----> ',single_coordinate) 241 | coordinate_list[patch_name] = single_coordinate 242 | return name_list, noise_im, coordinate_list 243 | 244 | def train_preprocess_lessMemoryMulStacks(args): 245 | img_h = args.img_h 246 | img_w = args.img_w 247 | img_s2 = args.img_s*2 248 | gap_h = args.gap_h 249 | gap_w = args.gap_w 250 | gap_s2 = args.gap_s*2 251 | im_folder = args.datasets_path+'//'+args.datasets_folder 252 | 253 | name_list = [] 254 | # train_raw = [] 255 | coordinate_list={} 256 | 257 | print('list(os.walk(im_folder, topdown=False)) -----> ',list(os.walk(im_folder, topdown=False))) 258 | stack_num = len(list(os.walk(im_folder, topdown=False))[-1][-1]) 259 | print('stack_num -----> ',stack_num) 260 | for im_name in list(os.walk(im_folder, topdown=False))[-1][-1]: 261 | # print('im_name -----> ',im_name) 262 | im_dir = im_folder+'//'+im_name 263 | noise_im = tiff.imread(im_dir) 264 | if noise_im.shape[0]>args.select_img_num: 265 | noise_im = noise_im[0:args.select_img_num,:,:] 266 | gap_s2 = get_gap_s(args, noise_im, stack_num) 267 | # print('noise_im shape -----> ',noise_im.shape) 268 | # print('noise_im max -----> ',noise_im.max()) 269 | # print('noise_im min -----> ',noise_im.min()) 270 | noise_im = (noise_im-noise_im.min()).astype(np.float32)/args.normalize_factor 271 | 272 | whole_w = noise_im.shape[2] 273 | whole_h = noise_im.shape[1] 274 | whole_s = noise_im.shape[0] 275 | # print('int((whole_h-img_h+gap_h)/gap_h) -----> ',int((whole_h-img_h+gap_h)/gap_h)) 276 | # print('int((whole_w-img_w+gap_w)/gap_w) -----> ',int((whole_w-img_w+gap_w)/gap_w)) 277 | # print('int((whole_s-img_s2+gap_s2)/gap_s2) -----> ',int((whole_s-img_s2+gap_s2)/gap_s2)) 278 | for x in range(0,int((whole_h-img_h+gap_h)/gap_h)): 279 | for y 
in range(0,int((whole_w-img_w+gap_w)/gap_w)): 280 | for z in range(0,int((whole_s-img_s2+gap_s2)/gap_s2)): 281 | single_coordinate={'init_h':0, 'end_h':0, 'init_w':0, 'end_w':0, 'init_s':0, 'end_s':0} 282 | init_h = gap_h*x 283 | end_h = gap_h*x + img_h 284 | init_w = gap_w*y 285 | end_w = gap_w*y + img_w 286 | init_s = gap_s2*z 287 | end_s = gap_s2*z + img_s2 288 | single_coordinate['init_h'] = init_h 289 | single_coordinate['end_h'] = end_h 290 | single_coordinate['init_w'] = init_w 291 | single_coordinate['end_w'] = end_w 292 | single_coordinate['init_s'] = init_s 293 | single_coordinate['end_s'] = end_s 294 | # noise_patch1 = noise_im[init_s:end_s,init_h:end_h,init_w:end_w] 295 | patch_name = args.datasets_folder+'_'+im_name.replace('.tif','')+'_x'+str(x)+'_y'+str(y)+'_z'+str(z) 296 | # train_raw.append(noise_patch1.transpose(1,2,0)) 297 | name_list.append(patch_name) 298 | # print(' single_coordinate -----> ',single_coordinate) 299 | coordinate_list[patch_name] = single_coordinate 300 | return name_list, noise_im, coordinate_list 301 | 302 | def shuffle_datasets_lessMemory(name_list): 303 | index_list = list(range(0, len(name_list))) 304 | # print('index_list -----> ',index_list) 305 | random.shuffle(index_list) 306 | random_index_list = index_list 307 | # print('index_list -----> ',index_list) 308 | new_name_list = list(range(0, len(name_list))) 309 | for i in range(0,len(random_index_list)): 310 | new_name_list[i] = name_list[random_index_list[i]] 311 | return new_name_list 312 | 313 | def test_preprocess_lessMemoryPadding (args): 314 | img_h = args.img_h 315 | img_w = args.img_w 316 | img_s2 = args.img_s 317 | gap_h = args.gap_h 318 | gap_w = args.gap_w 319 | gap_s2 = args.gap_s 320 | im_folder = 'datasets//'+args.datasets_folder 321 | 322 | name_list = [] 323 | # train_raw = [] 324 | coordinate_list={} 325 | for im_name in list(os.walk(im_folder, topdown=False))[-1][-1]: 326 | # print('im_name -----> ',im_name) 327 | im_dir = im_folder+'//'+im_name 328 | 
raw_noise_im = tiff.imread(im_dir) 329 | if raw_noise_im.shape[0]>args.test_datasize: 330 | raw_noise_im = raw_noise_im[0:args.test_datasize,:,:] 331 | raw_noise_im = (raw_noise_im-raw_noise_im.min()).astype(np.float32)/args.normalize_factor 332 | 333 | print('raw_noise_im shape -----> ',raw_noise_im.shape) 334 | noise_im_w = math.ceil((raw_noise_im.shape[2]-img_w)/gap_w)*gap_w+img_w 335 | noise_im_h = math.ceil((raw_noise_im.shape[1]-img_h)/gap_h)*gap_h+img_h 336 | noise_im_s = math.ceil((raw_noise_im.shape[0]-img_s2)/gap_s2)*gap_s2+img_s2 337 | noise_im = np.zeros([noise_im_s,noise_im_h,noise_im_w]) 338 | noise_im[0:raw_noise_im.shape[0], 0:raw_noise_im.shape[1], 0:raw_noise_im.shape[2]]=raw_noise_im 339 | noise_im = noise_im.astype(np.float32) 340 | print('noise_im shape -----> ',noise_im.shape) 341 | # print('noise_im max -----> ',noise_im.max()) 342 | # print('noise_im min -----> ',noise_im.min()) 343 | 344 | whole_w = noise_im.shape[2] 345 | whole_h = noise_im.shape[1] 346 | whole_s = noise_im.shape[0] 347 | # print('int((whole_h-img_h+gap_h)/gap_h) -----> ',int((whole_h-img_h+gap_h)/gap_h)) 348 | # print('int((whole_w-img_w+gap_w)/gap_w) -----> ',int((whole_w-img_w+gap_w)/gap_w)) 349 | # print('int((whole_s-img_s2+gap_s2)/gap_s2) -----> ',int((whole_s-img_s2+gap_s2)/gap_s2)) 350 | for x in range(0,int((whole_h-img_h+gap_h)/gap_h)): 351 | for y in range(0,int((whole_w-img_w+gap_w)/gap_w)): 352 | for z in range(0,int((whole_s-img_s2+gap_s2)/gap_s2)): 353 | single_coordinate={'init_h':0, 'end_h':0, 'init_w':0, 'end_w':0, 'init_s':0, 'end_s':0} 354 | init_h = gap_h*x 355 | end_h = gap_h*x + img_h 356 | init_w = gap_w*y 357 | end_w = gap_w*y + img_w 358 | init_s = gap_s2*z 359 | end_s = gap_s2*z + img_s2 360 | single_coordinate['init_h'] = init_h 361 | single_coordinate['end_h'] = end_h 362 | single_coordinate['init_w'] = init_w 363 | single_coordinate['end_w'] = end_w 364 | single_coordinate['init_s'] = init_s 365 | single_coordinate['end_s'] = end_s 366 | # 
noise_patch1 = noise_im[init_s:end_s,init_h:end_h,init_w:end_w] 367 | patch_name = args.datasets_folder+'_x'+str(x)+'_y'+str(y)+'_z'+str(z) 368 | # train_raw.append(noise_patch1.transpose(1,2,0)) 369 | name_list.append(patch_name) 370 | # print(' single_coordinate -----> ',single_coordinate) 371 | coordinate_list[patch_name] = single_coordinate 372 | return name_list, noise_im, coordinate_list, raw_noise_im 373 | 374 | def test_preprocess_lessMemoryNoTail (args): 375 | img_h = args.img_h 376 | img_w = args.img_w 377 | img_s2 = args.img_s 378 | gap_h = args.gap_h 379 | gap_w = args.gap_w 380 | gap_s2 = args.gap_s 381 | cut_w = (img_w - gap_w)/2 382 | cut_h = (img_h - gap_h)/2 383 | cut_s = (img_s2 - gap_s2)/2 384 | im_folder = args.datasets_path+'//'+args.datasets_folder 385 | 386 | name_list = [] 387 | # train_raw = [] 388 | coordinate_list={} 389 | for im_name in list(os.walk(im_folder, topdown=False))[-1][-1]: 390 | # print('im_name -----> ',im_name) 391 | im_dir = im_folder+'//'+im_name 392 | noise_im = tiff.imread(im_dir) 393 | # print('noise_im shape -----> ',noise_im.shape) 394 | # print('noise_im max -----> ',noise_im.max()) 395 | # print('noise_im min -----> ',noise_im.min()) 396 | if noise_im.shape[0]>args.test_datasize: 397 | noise_im = noise_im[0:args.test_datasize,:,:] 398 | noise_im = (noise_im-noise_im.min()).astype(np.float32)/args.normalize_factor 399 | 400 | whole_w = noise_im.shape[2] 401 | whole_h = noise_im.shape[1] 402 | whole_s = noise_im.shape[0] 403 | 404 | num_w = math.ceil((whole_w-img_w+gap_w)/gap_w) 405 | num_h = math.ceil((whole_h-img_h+gap_h)/gap_h) 406 | num_s = math.ceil((whole_s-img_s2+gap_s2)/gap_s2) 407 | # print('int((whole_h-img_h+gap_h)/gap_h) -----> ',int((whole_h-img_h+gap_h)/gap_h)) 408 | # print('int((whole_w-img_w+gap_w)/gap_w) -----> ',int((whole_w-img_w+gap_w)/gap_w)) 409 | # print('int((whole_s-img_s2+gap_s2)/gap_s2) -----> ',int((whole_s-img_s2+gap_s2)/gap_s2)) 410 | for x in range(0,num_h): 411 | for y in 
range(0,num_w): 412 | for z in range(0,num_s): 413 | single_coordinate={'init_h':0, 'end_h':0, 'init_w':0, 'end_w':0, 'init_s':0, 'end_s':0} 414 | if x != (num_h-1): 415 | init_h = gap_h*x 416 | end_h = gap_h*x + img_h 417 | elif x == (num_h-1): 418 | init_h = whole_h - img_h 419 | end_h = whole_h 420 | 421 | if y != (num_w-1): 422 | init_w = gap_w*y 423 | end_w = gap_w*y + img_w 424 | elif y == (num_w-1): 425 | init_w = whole_w - img_w 426 | end_w = whole_w 427 | 428 | if z != (num_s-1): 429 | init_s = gap_s2*z 430 | end_s = gap_s2*z + img_s2 431 | elif z == (num_s-1): 432 | init_s = whole_s - img_s2 433 | end_s = whole_s 434 | single_coordinate['init_h'] = init_h 435 | single_coordinate['end_h'] = end_h 436 | single_coordinate['init_w'] = init_w 437 | single_coordinate['end_w'] = end_w 438 | single_coordinate['init_s'] = init_s 439 | single_coordinate['end_s'] = end_s 440 | 441 | if y == 0: 442 | single_coordinate['stack_start_w'] = y*gap_w 443 | single_coordinate['stack_end_w'] = y*gap_w+img_w-cut_w 444 | single_coordinate['patch_start_w'] = 0 445 | single_coordinate['patch_end_w'] = img_w-cut_w 446 | elif y == num_w-1: 447 | single_coordinate['stack_start_w'] = whole_w-img_w+cut_w 448 | single_coordinate['stack_end_w'] = whole_w 449 | single_coordinate['patch_start_w'] = cut_w 450 | single_coordinate['patch_end_w'] = img_w 451 | else: 452 | single_coordinate['stack_start_w'] = y*gap_w+cut_w 453 | single_coordinate['stack_end_w'] = y*gap_w+img_w-cut_w 454 | single_coordinate['patch_start_w'] = cut_w 455 | single_coordinate['patch_end_w'] = img_w-cut_w 456 | 457 | if x == 0: 458 | single_coordinate['stack_start_h'] = x*gap_h 459 | single_coordinate['stack_end_h'] = x*gap_h+img_h-cut_h 460 | single_coordinate['patch_start_h'] = 0 461 | single_coordinate['patch_end_h'] = img_h-cut_h 462 | elif x == num_h-1: 463 | single_coordinate['stack_start_h'] = whole_h-img_h+cut_h 464 | single_coordinate['stack_end_h'] = whole_h 465 | single_coordinate['patch_start_h'] = cut_h 
466 | single_coordinate['patch_end_h'] = img_h 467 | else: 468 | single_coordinate['stack_start_h'] = x*gap_h+cut_h 469 | single_coordinate['stack_end_h'] = x*gap_h+img_h-cut_h 470 | single_coordinate['patch_start_h'] = cut_h 471 | single_coordinate['patch_end_h'] = img_h-cut_h 472 | 473 | if z == 0: 474 | single_coordinate['stack_start_s'] = z*gap_s2 475 | single_coordinate['stack_end_s'] = z*gap_s2+img_s2-cut_s 476 | single_coordinate['patch_start_s'] = 0 477 | single_coordinate['patch_end_s'] = img_s2-cut_s 478 | elif z == num_s-1: 479 | single_coordinate['stack_start_s'] = whole_s-img_s2+cut_s 480 | single_coordinate['stack_end_s'] = whole_s 481 | single_coordinate['patch_start_s'] = cut_s 482 | single_coordinate['patch_end_s'] = img_s2 483 | else: 484 | single_coordinate['stack_start_s'] = z*gap_s2+cut_s 485 | single_coordinate['stack_end_s'] = z*gap_s2+img_s2-cut_s 486 | single_coordinate['patch_start_s'] = cut_s 487 | single_coordinate['patch_end_s'] = img_s2-cut_s 488 | 489 | # noise_patch1 = noise_im[init_s:end_s,init_h:end_h,init_w:end_w] 490 | patch_name = args.datasets_folder+'_x'+str(x)+'_y'+str(y)+'_z'+str(z) 491 | # train_raw.append(noise_patch1.transpose(1,2,0)) 492 | name_list.append(patch_name) 493 | # print(' single_coordinate -----> ',single_coordinate) 494 | coordinate_list[patch_name] = single_coordinate 495 | return name_list, noise_im, coordinate_list 496 | 497 | 498 | # stack_start_w ,stack_end_w ,patch_start_w ,patch_end_w , 499 | # stack_start_h ,stack_end_h ,patch_start_h ,patch_end_h , 500 | # stack_start_s ,stack_end_s ,patch_start_s ,patch_end_s 501 | 502 | --------------------------------------------------------------------------------