├── DeepCAD_pytorch
│   ├── results
│   │   └── results.txt
│   ├── pth
│   │   ├── ModelForPytorch
│   │   │   └── DownloadedModel
│   │   └── README.md
│   ├── datasets
│   │   ├── DataForPytorch
│   │   │   └── DownloadedData
│   │   └── README.md
│   ├── network.py
│   ├── script.py
│   ├── utils.py
│   ├── train.py
│   ├── test.py
│   ├── buildingblocks.py
│   ├── model_3DUnet.py
│   └── data_process.py
├── DeepCAD_Fiji
│   ├── DeepCAD_java
│   │   ├── DeepCAD_java.txt
│   │   └── README.md
│   ├── DeepCAD_tensorflow
│   │   ├── datasets
│   │   │   └── datasets.txt
│   │   ├── results
│   │   │   └── results.txt
│   │   ├── DeepCAD_model
│   │   │   └── DeepCAD_model.txt
│   │   ├── script.py
│   │   ├── README.md
│   │   ├── utils.py
│   │   ├── basic_ops.py
│   │   ├── network.py
│   │   ├── test_pb.py
│   │   ├── main.py
│   │   └── data_process.py
│   ├── DeepCAD_Fiji_plugin
│   │   ├── DeepCAD-0.3.0.jar
│   │   ├── DeepCAD-0.3.6.jar
│   │   └── README.md
│   └── README.md
├── images
│   ├── fiji.png
│   ├── logo.PNG
│   ├── soma.png
│   ├── dendrite.png
│   ├── parameter.png
│   ├── schematic.png
│   └── cross-system.png
├── dataset
│   └── README.md
├── README.md
└── LICENSE
/DeepCAD_pytorch/results/results.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/DeepCAD_Fiji/DeepCAD_java/DeepCAD_java.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/DeepCAD_Fiji/DeepCAD_tensorflow/datasets/datasets.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/DeepCAD_Fiji/DeepCAD_tensorflow/results/results.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/DeepCAD_pytorch/pth/ModelForPytorch/DownloadedModel:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/DeepCAD_pytorch/datasets/DataForPytorch/DownloadedData:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/DeepCAD_Fiji/DeepCAD_tensorflow/DeepCAD_model/DeepCAD_model.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/DeepCAD_Fiji/DeepCAD_java/README.md:
--------------------------------------------------------------------------------
1 | # Java source code of DeepCAD Fiji plugin
2 |
--------------------------------------------------------------------------------
/images/fiji.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cabooster/DeepCAD/HEAD/images/fiji.png
--------------------------------------------------------------------------------
/images/logo.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cabooster/DeepCAD/HEAD/images/logo.PNG
--------------------------------------------------------------------------------
/images/soma.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cabooster/DeepCAD/HEAD/images/soma.png
--------------------------------------------------------------------------------
/images/dendrite.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cabooster/DeepCAD/HEAD/images/dendrite.png
--------------------------------------------------------------------------------
/images/parameter.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cabooster/DeepCAD/HEAD/images/parameter.png
--------------------------------------------------------------------------------
/images/schematic.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cabooster/DeepCAD/HEAD/images/schematic.png
--------------------------------------------------------------------------------
/images/cross-system.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cabooster/DeepCAD/HEAD/images/cross-system.png
--------------------------------------------------------------------------------
/DeepCAD_Fiji/DeepCAD_Fiji_plugin/DeepCAD-0.3.0.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cabooster/DeepCAD/HEAD/DeepCAD_Fiji/DeepCAD_Fiji_plugin/DeepCAD-0.3.0.jar
--------------------------------------------------------------------------------
/DeepCAD_Fiji/DeepCAD_Fiji_plugin/DeepCAD-0.3.6.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cabooster/DeepCAD/HEAD/DeepCAD_Fiji/DeepCAD_Fiji_plugin/DeepCAD-0.3.6.jar
--------------------------------------------------------------------------------
/DeepCAD_pytorch/datasets/README.md:
--------------------------------------------------------------------------------
1 | # Data download guide
2 | Download the demo data (.tif file) [[DataForPytorch](https://drive.google.com/drive/folders/1w9v1SrEkmvZal5LH79HloHhz6VXSPfI_)] and put it into the *./DataForPytorch* folder.
3 |
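4 | A quick way to confirm the downloaded stack can be read is sketched below (a minimal sketch assuming `tifffile` is installed, as used elsewhere in this repository; `demo.tif` is a placeholder for the actual file name):
5 | 
6 | ```
7 | import tifffile as tiff
8 | 
9 | # Read the downloaded stack and report its shape (frames, height, width) and dtype.
10 | stack = tiff.imread('./DataForPytorch/demo.tif')
11 | print(stack.shape, stack.dtype)
12 | ```
13 | 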
--------------------------------------------------------------------------------
/DeepCAD_pytorch/pth/README.md:
--------------------------------------------------------------------------------
1 | # Model download guide
2 | Download the pre-trained model (.pth file and .yaml file) [[ModelForPytorch](https://drive.google.com/drive/folders/12LEFsAopTolaRyRpJtFpzOYH3tBZMGUP)] and put it into the *./ModelForPytorch* folder.
3 |
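4 | A minimal sketch for inspecting the downloaded files is shown below (`para.yaml` is the parameter file written by `train.py`; the `.pth` file name is a placeholder, so replace it with the downloaded one):
5 | 
6 | ```
7 | import torch
8 | import yaml
9 | 
10 | # Load the training parameters saved alongside the model.
11 | with open('./ModelForPytorch/para.yaml') as f:
12 |     para = yaml.load(f, Loader=yaml.FullLoader)
13 | print(para)
14 | 
15 | # Load the checkpoint on the CPU for a quick sanity check.
16 | checkpoint = torch.load('./ModelForPytorch/model.pth', map_location='cpu')
17 | print(type(checkpoint))
18 | ```
19 | 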
--------------------------------------------------------------------------------
/DeepCAD_Fiji/DeepCAD_Fiji_plugin/README.md:
--------------------------------------------------------------------------------
1 | ## Packaged plugin file (.jar) for Fiji
2 | ### Please download and install the latest version.
3 |
4 | ________________________________________________________________________________________________________________________
5 | ### Logs
6 | - 0.3.6 - 2020-11-17
7 |
8 | Fixed a language display error.
9 |
10 | - 0.3.0 - 2020-11-17
11 |
12 | Initial release.
13 |
14 |
15 |
--------------------------------------------------------------------------------
/DeepCAD_Fiji/DeepCAD_tensorflow/script.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 |
4 | # for train
5 | os.system('python main.py --GPU 0 --img_h 64 --img_w 64 --img_s 320 --train_epochs 30 --datasets_folder DataForPytorch --normalize_factor 1 --lr 0.00005 --train_datasets_size 10000')
6 |
7 | # for test
8 | os.system('python test_pb.py --GPU 3 --denoise_model pb_unet3d_10AMP_0.3_0001_20201108-2139 \
9 | --datasets_folder 10AMP_0.3_0001 --model_name 25_1000')
10 |
11 |
--------------------------------------------------------------------------------
/DeepCAD_pytorch/network.py:
--------------------------------------------------------------------------------
1 | from model_3DUnet import UNet3D
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch
5 |
6 | class Network_3D_Unet(nn.Module):
7 | def __init__(self, UNet_type = '3DUNet', in_channels=1, out_channels=1, final_sigmoid = True):
8 | super(Network_3D_Unet, self).__init__()
9 |
10 | self.in_channels = in_channels
11 | self.out_channels = out_channels
12 | self.final_sigmoid = final_sigmoid
13 |
14 | if UNet_type == '3DUNet':
15 | self.Generator = UNet3D( in_channels = in_channels,
16 | out_channels = out_channels,
17 | final_sigmoid = final_sigmoid)
18 |
19 | def forward(self, x):
20 | fake_x = self.Generator(x)
21 | return fake_x
22 |
--------------------------------------------------------------------------------
/DeepCAD_pytorch/script.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import sys
4 |
5 | flag = sys.argv[1]
6 |
7 | if flag == 'train':
8 | # for train
9 | os.system('python train.py --datasets_folder DataForPytorch --lr 0.00005 \
10 | --img_h 150 --img_w 150 --img_s 150 --gap_h 60 --gap_w 60 --gap_s 60 \
11 | --n_epochs 20 --GPU 0 --normalize_factor 1 \
12 | --train_datasets_size 1200 --select_img_num 10000')
13 |
14 | if flag == 'test':
15 | # for test
16 | os.system('python test.py --denoise_model ModelForPytorch \
17 | --datasets_folder DataForPytorch \
18 | --test_datasize 6000')
19 |
20 | if flag == 'all':
21 | # train and then test
22 | os.system('python train.py --datasets_folder DataForPytorch --lr 0.00005 \
23 | --img_h 150 --img_w 150 --img_s 150 --gap_h 60 --gap_w 60 --gap_s 60 \
24 | --n_epochs 20 --GPU 0 --normalize_factor 1 \
25 | --train_datasets_size 1200 --select_img_num 10000')
26 |
27 | os.system('python test.py --denoise_model ModelForPytorch \
28 | --datasets_folder DataForPytorch \
29 | --test_datasize 6000')
30 |
--------------------------------------------------------------------------------
/DeepCAD_Fiji/DeepCAD_tensorflow/README.md:
--------------------------------------------------------------------------------
1 | ## Tensorflow implementation compatible with Fiji plugin
2 |
3 | ### Directory structure
4 | ```
5 | DeepCAD_tensorflow #Tensorflow implementation compatible with Fiji plugin#
6 | |---basic_ops.py
7 | |---main.py
8 | |---network.py
9 | |---script.py
10 | |---test_pb.py
11 | |---data_process.py
12 | |---datasets
13 | |---|---qwd_7 #project_name#
14 | |---|---|---train_raw.tif #raw data for train#
15 | |---DeepCAD_model
16 | |---|---qwd_7_20201115-0913
17 | |---|---|--- #pb model#
18 | |---results
19 | |---|---#Results of training process and final test#
20 | ```
21 |
22 | ### Environment
23 |
24 | * Ubuntu 16.04
25 | * Python 3.6
26 | * Tensorflow 1.4.0
27 | * NVIDIA GPU + CUDA
28 |
29 | ### Environment configuration
30 |
31 | Open a terminal in the Ubuntu system.
32 |
33 | * Create anaconda environment
34 |
35 | ```
36 | $ conda create -n tensorflow python=3.6
37 | ```
38 |
39 | * Install Tensorflow
40 |
41 | ```
42 | $ source activate tensorflow
43 | $ pip install tensorflow-gpu==1.4.0
44 | ```
45 |
46 | ### Training
47 |
48 | ```
49 | $ source activate tensorflow
50 | $ python main.py --GPU 0 --img_h 64 --img_w 64 --img_s 320 --train_epochs 30 --datasets_folder DataForPytorch --normalize_factor 1 --lr 0.00005 --train_datasets_size 1000
51 | ```
52 |
53 | Parameters can be modified as required.
54 |
55 | ```
56 | $ python main.py --GPU #GPU index# --img_h #stack height# --img_w #stack width# --img_s #stack length# --train_epochs #training epoch number# --datasets_folder #project name#
57 | ```
58 |
59 | The pre-trained model is saved at *DeepCAD_Fiji/DeepCAD_tensorflow/DeepCAD_model/*.
60 |
61 | #### Test
62 |
63 | Run script.py (the test part) to start the test. Parameters saved in the .yaml file will be loaded automatically.
64 |
65 | ```
66 | $ source activate tensorflow
67 | $ python test_pb.py --GPU 3 --denoise_model ModelForTestPlugin --datasets_folder DataForPytorch --model_name 25_1000 --test_datasize 500
68 | ```
69 |
70 | Parameters can be modified as required.
71 |
72 | ```
73 | $ python test_pb.py --denoise_model #model name# --datasets_folder #project name# --model_name #sub-model name# --test_datasize #the number of images used for test#
74 | ```
75 |
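76 | For reference, the `para.yaml` written by `save_yaml` during training stores the patch and training parameters that `read_yaml` restores at test time (train_epochs, datasets_folder, GPU, img_s, img_w, img_h, gap_h, gap_w, gap_s, lr, normalize_factor). A minimal sketch of reading it manually, where the model folder name below is a placeholder:
77 | 
78 | ```
79 | import yaml
80 | 
81 | # The yaml file sits inside the saved model folder, e.g. DeepCAD_model/pb_unet3d_<project>_<timestamp>/para.yaml
82 | with open('DeepCAD_model/pb_unet3d_project_timestamp/para.yaml') as f:
83 |     para = yaml.load(f, Loader=yaml.FullLoader)
84 | print(para)
85 | ```
86 | 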
--------------------------------------------------------------------------------
/DeepCAD_Fiji/README.md:
--------------------------------------------------------------------------------
1 | # DeepCAD Fiji plugin
2 |
3 | To make our deep self-supervised learning-based method easier to use, we developed a user-friendly Fiji plugin that is easy to install and convenient to operate (tested on a Windows desktop with an Intel i9 CPU and 128 GB RAM). Researchers without expertise in computer science or machine learning can get started with it in a very short time.
4 |
5 |
6 |
7 | ### Install Fiji plugin
8 | To avoid unnecessary troubles, the following steps are recommended for installation:
9 | 1. Download and install Fiji from the [[Fiji download page](https://imagej.net/Fiji/Downloads)]. Install the CSBDeep dependency following the steps at [[CSBDeep in Fiji – Installation](https://github.com/CSBDeep/CSBDeep_website/wiki/CSBDeep-in-Fiji-%E2%80%93-Installation)]
10 | 2. Download the packaged plugin file (.jar) from [[DeepCAD_Fiji/DeepCAD_Fiji_plugin](https://github.com/cabooster/DeepCAD/tree/master/DeepCAD_Fiji/DeepCAD_Fiji_plugin)].
11 | 3. Install the plugin via **Fiji > Plugins > Install**.
12 |
13 | We provide lightweight data for testing. Please download the demo data (.tif) [[DataForTestPlugin](https://drive.google.com/drive/folders/1JVbuCwIxRKr4_NNOD7fY61NnVeCA2UnP)] and the pre-trained model (.zip) [[ModelForTestPlugin](https://drive.google.com/drive/folders/14wSuMFhWKxW5Oq93GHxTsGixpB3T4lOL)].
14 |
15 | ### Use Fiji plugin
16 |
17 | 1. Open Fiji.
18 | 2. Open the calcium imaging stack to be denoised.
19 | 3. Open the plugin at **Plugins > DeepCAD**. The six parameters will be shown on the panel (with default values; no changes are required unless necessary).
20 | 4. Specify the pre-trained model using the '*Browse*' button (select the .zip file).
21 | 5. Click ‘OK’ and the denoised result will be displayed in another window after processing (processing time depends on the data size).
22 |
23 |
24 |
25 | ### Train a customized model for your microscope
26 |
27 | Since imaging systems and experimental conditions vary, a customized DeepCAD model trained on your own data is recommended for optimal performance. A Tensorflow implementation of DeepCAD compatible with the plugin is publicly available at *[DeepCAD_Fiji/DeepCAD_tensorflow](https://github.com/cabooster/DeepCAD/tree/master/DeepCAD_Fiji/DeepCAD_tensorflow)*.
28 |
--------------------------------------------------------------------------------
/DeepCAD_Fiji/DeepCAD_tensorflow/utils.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import logging
3 | import os
4 | import shutil
5 | import sys
6 | import io
7 |
8 | import h5py
9 | from PIL import Image
10 | import numpy as np
11 | import scipy.sparse as sparse
12 | import matplotlib.pyplot as plt
13 | import uuid
14 | import warnings
15 | import pylab
16 | import cv2
17 | import yaml
18 |
19 | def save_yaml(opt, yaml_name):
20 | para = {}
21 | para["train_epochs"] = opt.train_epochs
22 | para["datasets_folder"] = opt.datasets_folder
23 | para["GPU"] = opt.GPU
24 | para["img_s"] = opt.img_s
25 | para["img_w"] = opt.img_w
26 | para["img_h"] = opt.img_h
27 | para["gap_h"] = opt.gap_h
28 | para["gap_w"] = opt.gap_w
29 | para["gap_s"] = opt.gap_s
30 | para["lr"] = opt.lr
31 | para["normalize_factor"] = opt.normalize_factor
32 | with open(yaml_name, 'w') as f:
33 | data = yaml.dump(para, f)
34 |
35 |
36 | def read_yaml(opt, yaml_name):
37 | with open(yaml_name) as f:
38 | para = yaml.load(f, Loader=yaml.FullLoader)
39 | print(para)
40 | opt.datasets_folder = para["datasets_folder"]
41 | opt.GPU = para["GPU"]
42 | opt.img_s = para["img_s"]
43 | opt.img_w = para["img_w"]
44 | opt.img_h = para["img_h"]
45 | # opt.gap_h = para["gap_h"]
46 | # opt.gap_w = para["gap_w"]
47 | # opt.gap_s = para["gap_s"]
48 | opt.normalize_factor = para["normalize_factor"]
49 |
50 | def name2index(opt, input_name, num_h, num_w, num_s):
51 | # print(input_name)
52 | name_list = input_name.split('_')
53 | # print(name_list)
54 | z_part = name_list[-1]
55 | # print(z_part)
56 | y_part = name_list[-2]
57 | # print(y_part)
58 | x_part = name_list[-3]
59 | # print(x_part)
60 | z_index = int(z_part.replace('z',''))
61 | y_index = int(y_part.replace('y',''))
62 | x_index = int(x_part.replace('x',''))
63 | # print("x_index ---> ",x_index,"y_index ---> ", y_index,"z_index ---> ", z_index)
64 |
65 | cut_w = (opt.img_w - opt.gap_w)/2
66 | cut_h = (opt.img_h - opt.gap_h)/2
67 | cut_s = (opt.img_s - opt.gap_s)/2
68 | # print("z_index ---> ",cut_w, "cut_h ---> ",cut_h, "cut_s ---> ",cut_s)
69 | if x_index == 0:
70 | stack_start_w = x_index*opt.gap_w
71 | stack_end_w = x_index*opt.gap_w+opt.img_w-cut_w
72 | patch_start_w = 0
73 | patch_end_w = opt.img_w-cut_w
74 | elif x_index == num_w-1:
75 | stack_start_w = x_index*opt.gap_w+cut_w
76 | stack_end_w = x_index*opt.gap_w+opt.img_w
77 | patch_start_w = cut_w
78 | patch_end_w = opt.img_w
79 | else:
80 | stack_start_w = x_index*opt.gap_w+cut_w
81 | stack_end_w = x_index*opt.gap_w+opt.img_w-cut_w
82 | patch_start_w = cut_w
83 | patch_end_w = opt.img_w-cut_w
84 |
85 | if y_index == 0:
86 | stack_start_h = y_index*opt.gap_h
87 | stack_end_h = y_index*opt.gap_h+opt.img_h-cut_h
88 | patch_start_h = 0
89 | patch_end_h = opt.img_h-cut_h
90 | elif y_index == num_h-1:
91 | stack_start_h = y_index*opt.gap_h+cut_h
92 | stack_end_h = y_index*opt.gap_h+opt.img_h
93 | patch_start_h = cut_h
94 | patch_end_h = opt.img_h
95 | else:
96 | stack_start_h = y_index*opt.gap_h+cut_h
97 | stack_end_h = y_index*opt.gap_h+opt.img_h-cut_h
98 | patch_start_h = cut_h
99 | patch_end_h = opt.img_h-cut_h
100 |
101 | if z_index == 0:
102 | stack_start_s = z_index*opt.gap_s
103 | stack_end_s = z_index*opt.gap_s+opt.img_s-cut_s
104 | patch_start_s = 0
105 | patch_end_s = opt.img_s-cut_s
106 | elif z_index == num_s-1:
107 | stack_start_s = z_index*opt.gap_s+cut_s
108 | stack_end_s = z_index*opt.gap_s+opt.img_s
109 | patch_start_s = cut_s
110 | patch_end_s = opt.img_s
111 | else:
112 | stack_start_s = z_index*opt.gap_s+cut_s
113 | stack_end_s = z_index*opt.gap_s+opt.img_s-cut_s
114 | patch_start_s = cut_s
115 | patch_end_s = opt.img_s-cut_s
116 | return int(stack_start_w) ,int(stack_end_w) ,int(patch_start_w) ,int(patch_end_w) ,\
117 | int(stack_start_h) ,int(stack_end_h) ,int(patch_start_h) ,int(patch_end_h), \
118 | int(stack_start_s) ,int(stack_end_s) ,int(patch_start_s) ,int(patch_end_s)
119 |
--------------------------------------------------------------------------------
/DeepCAD_Fiji/DeepCAD_tensorflow/basic_ops.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from keras.layers import UpSampling3D
3 | import numpy as np
4 | """This script defines basic operations.
5 | """
6 | def Pool3d(inputs, name):
7 | layer = tf.layers.max_pooling3d(inputs=inputs,
8 | pool_size=2,
9 | strides=2,
10 | name=name,
11 | padding='same')
12 | print(name,'.get_shape -----> ',str(layer.get_shape()))
13 | return layer
14 |
15 |
16 | def Deconv3D(inputs, filters, name):
17 | layer = tf.layers.conv3d_transpose(inputs=inputs,
18 | filters=filters,
19 | kernel_size=2,
20 | strides=2,
21 | padding='same',
22 | use_bias=True,
23 | name=name,
24 | kernel_initializer=tf.truncated_normal_initializer(stddev=0.05, seed=1))
25 | print(name,'.get_shape -----> ',str(layer.get_shape()))
26 | return layer
27 |
28 |
29 | def Deconv3D_upsample(inputs, filters, name):
30 | print(name,'.inputs -----> ',str(inputs.get_shape()))
31 | kernel = tf.constant(1.0, shape=[2,2,2,filters,filters])
32 | inputs_shape = inputs.get_shape().as_list()
33 | layer = tf.nn.conv3d_transpose(value=inputs, filter=kernel,
34 | output_shape=[inputs_shape[0], inputs_shape[1]*2, inputs_shape[2]*2, inputs_shape[3]*2, filters],
35 | strides=[1, 2, 2, 2, 1],
36 | padding="SAME")
37 | print(name,'.get_shape -----> ',str(layer.get_shape()))
38 | return layer
39 |
40 |
41 | def Conv3D(inputs, filters, name):
42 | inputs_shape = inputs.get_shape().as_list()
43 | fan_in =3*3*inputs_shape[-1]*filters
44 | std = np.sqrt(2) / np.sqrt(fan_in)
45 | layer = tf.layers.conv3d(inputs=inputs,
46 | filters=filters,
47 | kernel_size=3,
48 | strides=1,
49 | padding='same',
50 | use_bias=True,
51 | name=name,
52 | kernel_initializer=tf.truncated_normal_initializer(stddev=std))
53 | print(name,'.get_shape -----> ',str(layer.get_shape()),' std -----> ',std)
54 | return layer
55 |
56 | def ReLU(inputs, name):
57 | layer = tf.nn.relu(inputs, name)
58 | print(name,'.get_shape -----> ',str(layer.get_shape()))
59 | return layer
60 |
61 | def leak_ReLU(inputs, name):
62 | layer = tf.nn.leaky_relu(inputs, alpha=0.1, name=name) #tf.nn.relu(inputs, name)
63 | print(name,'.get_shape -----> ',str(layer.get_shape()))
64 | return layer
65 |
66 | def Deconv3D_keras(inputs):
67 | layer = UpSampling3D(size=2)(inputs)
68 | print('.get_shape -----> ',str(layer.get_shape()))
69 | return layer
70 |
71 | def up_conv3d(input, conv_filter_size, num_input_channels, num_filters, feature_map_size, feature_map_len, train=True, padding='SAME',relu=True):
72 | # conv_filter_size: cubic kernel size; num_input_channels / num_filters: input / output channel counts
73 | # feature_map_size / feature_map_len: spatial size and depth of the upsampled output
74 | # NOTE: create_weights() and create_biases() are referenced below but not defined
75 | # in this file; they must be provided elsewhere before calling this function.
76 | weights = create_weights(shape=[conv_filter_size, conv_filter_size, conv_filter_size, num_filters, num_input_channels])
77 | biases = create_biases(num_filters)
78 | if train:
79 | batch_size_0 = 1 #batch_size
80 | else:
81 | batch_size_0 = 1
82 | layer = tf.nn.conv3d_transpose(value=input, filter=weights,
83 | output_shape=[batch_size_0, feature_map_size, feature_map_size, feature_map_len, num_filters],
84 | strides=[1, 2, 2, 2, 1],
85 | padding=padding)
86 | layer += biases
87 | if relu:
88 | layer = tf.nn.relu(layer)
89 | return layer
90 |
91 |
92 |
93 | def BN_ReLU(inputs, training, name):
94 | """Performs a batch normalization followed by a ReLU6."""
95 | inputs = tf.layers.batch_normalization(inputs=inputs,
96 | axis=-1,
97 | momentum=0.997,
98 | epsilon=1e-5,
99 | center=True,
100 | scale=True,
101 | training=training,
102 | fused=True)
103 | return tf.nn.relu(inputs)
104 |
105 | def GN_ReLU(inputs, name):
106 | """Performs a batch normalization followed by a ReLU6."""
107 | inputs = group_norm(inputs, name=name)
108 | return tf.nn.relu(inputs)
109 |
110 | def GN_leakReLU(inputs, name):
111 | """Performs a batch normalization followed by a ReLU6."""
112 | # inputs = group_norm(inputs, name=name)
113 | inputs = tf.nn.relu(inputs) #tf.nn.relu(inputs, name)
114 | layer = group_norm(inputs, name=name)
115 | return layer
116 |
117 | def group_norm(x, name, G=8, eps=1e-5, scope='group_norm') :
118 | with tf.variable_scope(scope+name, reuse=tf.AUTO_REUSE) :
119 | N, H, W, S, C = x.get_shape().as_list()
120 | G = min(G, C)
121 | # [N, H, W, G, C // G]
122 | x = tf.reshape(x, [N, H, W, S, G, C // G])
123 | mean, var = tf.nn.moments(x, [1, 2, 3, 5], keep_dims=True)
124 | x = (x - mean) / tf.sqrt(var + eps)
125 |
126 | # print(' -----> GroupNorm ',x.get_shape())
127 | gamma = tf.get_variable('gamma_'+name, [1, 1, 1, 1, C], initializer=tf.constant_initializer(1))
128 | beta = tf.get_variable('beta_'+name, [1, 1, 1, 1, C], initializer=tf.constant_initializer(0))
129 | x = tf.reshape(x, [N, H, W, S, C]) * gamma + beta
130 | return x
--------------------------------------------------------------------------------
/DeepCAD_pytorch/utils.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import logging
3 | import os
4 | import shutil
5 | import sys
6 | import io
7 |
8 | import h5py
9 | from PIL import Image
10 | import numpy as np
11 | import scipy.sparse as sparse
12 | import torch
13 | import matplotlib.pyplot as plt
14 | import uuid
15 | from sklearn.decomposition import PCA
16 | import warnings
17 | import pylab
18 | import cv2
19 | import yaml
20 | ########################################################################################################################
21 | plt.ioff()
22 | plt.switch_backend('agg')
23 |
24 | ########################################################################################################################
25 | def create_feature_maps(init_channel_number, number_of_fmaps):
26 | return [init_channel_number * 2 ** k for k in range(number_of_fmaps)]
27 |
28 | def save_yaml(opt, yaml_name):
29 | para = {'epoch':0,
30 | 'n_epochs':0,
31 | 'datasets_folder':0,
32 | 'GPU':0,
33 | 'output_dir':0,
34 | 'batch_size':0,
35 | 'img_s':0,
36 | 'img_w':0,
37 | 'img_h':0,
38 | 'gap_h':0,
39 | 'gap_w':0,
40 | 'gap_s':0,
41 | 'lr':0,
42 | 'b1':0,
43 | 'b2':0,
44 | 'normalize_factor':0}
45 | para["epoch"] = opt.epoch
46 | para["n_epochs"] = opt.n_epochs
47 | para["datasets_folder"] = opt.datasets_folder
48 | para["GPU"] = opt.GPU
49 | para["output_dir"] = opt.output_dir
50 | para["batch_size"] = opt.batch_size
51 | para["img_s"] = opt.img_s
52 | para["img_w"] = opt.img_w
53 | para["img_h"] = opt.img_h
54 | para["gap_h"] = opt.gap_h
55 | para["gap_w"] = opt.gap_w
56 | para["gap_s"] = opt.gap_s
57 | para["lr"] = opt.lr
58 | para["b1"] = opt.b1
59 | para["b2"] = opt.b2
60 | para["normalize_factor"] = opt.normalize_factor
61 | para["datasets_path"] = opt.datasets_path
62 | para["train_datasets_size"] = opt.train_datasets_size
63 | with open(yaml_name, 'w') as f:
64 | data = yaml.dump(para, f)
65 |
66 |
67 | def read_yaml(opt, yaml_name):
68 | with open(yaml_name) as f:
69 | para = yaml.load(f, Loader=yaml.FullLoader)
70 | print(para)
71 | opt.epoch = para["epoch"]
72 | opt.n_epochs = para["n_epochs"]
73 | # opt.datasets_folder = para["datasets_folder"]
74 | opt.output_dir = para["output_dir"]
75 | opt.batch_size = para["batch_size"]
76 | # opt.img_s = para["img_s"]
77 | # opt.img_w = para["img_w"]
78 | # opt.img_h = para["img_h"]
79 | # opt.gap_h = para["gap_h"]
80 | # opt.gap_w = para["gap_w"]
81 | # opt.gap_s = para["gap_s"]
82 | opt.lr = para["lr"]
83 | opt.b1 = para["b1"]
84 | para["b2"] = opt.b2
85 | para["normalize_factor"] = opt.normalize_factor
86 |
87 |
88 | def name2index(opt, input_name, num_h, num_w, num_s):
89 | # print(input_name)
90 | name_list = input_name.split('_')
91 | # print(name_list)
92 | z_part = name_list[-1]
93 | # print(z_part)
94 | y_part = name_list[-2]
95 | # print(y_part)
96 | x_part = name_list[-3]
97 | # print(x_part)
98 | z_index = int(z_part.replace('z',''))
99 | y_index = int(y_part.replace('y',''))
100 | x_index = int(x_part.replace('x',''))
101 | # print("x_index ---> ",x_index,"y_index ---> ", y_index,"z_index ---> ", z_index)
102 |
103 | cut_w = (opt.img_w - opt.gap_w)/2
104 | cut_h = (opt.img_h - opt.gap_h)/2
105 | cut_s = (opt.img_s - opt.gap_s)/2
106 | # print("z_index ---> ",cut_w, "cut_h ---> ",cut_h, "cut_s ---> ",cut_s)
107 | if x_index == 0:
108 | stack_start_w = x_index*opt.gap_w
109 | stack_end_w = x_index*opt.gap_w+opt.img_w-cut_w
110 | patch_start_w = 0
111 | patch_end_w = opt.img_w-cut_w
112 | elif x_index == num_w-1:
113 | stack_start_w = x_index*opt.gap_w+cut_w
114 | stack_end_w = x_index*opt.gap_w+opt.img_w
115 | patch_start_w = cut_w
116 | patch_end_w = opt.img_w
117 | else:
118 | stack_start_w = x_index*opt.gap_w+cut_w
119 | stack_end_w = x_index*opt.gap_w+opt.img_w-cut_w
120 | patch_start_w = cut_w
121 | patch_end_w = opt.img_w-cut_w
122 |
123 | if y_index == 0:
124 | stack_start_h = y_index*opt.gap_h
125 | stack_end_h = y_index*opt.gap_h+opt.img_h-cut_h
126 | patch_start_h = 0
127 | patch_end_h = opt.img_h-cut_h
128 | elif y_index == num_h-1:
129 | stack_start_h = y_index*opt.gap_h+cut_h
130 | stack_end_h = y_index*opt.gap_h+opt.img_h
131 | patch_start_h = cut_h
132 | patch_end_h = opt.img_h
133 | else:
134 | stack_start_h = y_index*opt.gap_h+cut_h
135 | stack_end_h = y_index*opt.gap_h+opt.img_h-cut_h
136 | patch_start_h = cut_h
137 | patch_end_h = opt.img_h-cut_h
138 |
139 | if z_index == 0:
140 | stack_start_s = z_index*opt.gap_s
141 | stack_end_s = z_index*opt.gap_s+opt.img_s-cut_s
142 | patch_start_s = 0
143 | patch_end_s = opt.img_s-cut_s
144 | elif z_index == num_s-1:
145 | stack_start_s = z_index*opt.gap_s+cut_s
146 | stack_end_s = z_index*opt.gap_s+opt.img_s
147 | patch_start_s = cut_s
148 | patch_end_s = opt.img_s
149 | else:
150 | stack_start_s = z_index*opt.gap_s+cut_s
151 | stack_end_s = z_index*opt.gap_s+opt.img_s-cut_s
152 | patch_start_s = cut_s
153 | patch_end_s = opt.img_s-cut_s
154 | return int(stack_start_w) ,int(stack_end_w) ,int(patch_start_w) ,int(patch_end_w) ,\
155 | int(stack_start_h) ,int(stack_end_h) ,int(patch_start_h) ,int(patch_end_h), \
156 | int(stack_start_s) ,int(stack_end_s) ,int(patch_start_s) ,int(patch_end_s)
157 |
158 |
--------------------------------------------------------------------------------
/dataset/README.md:
--------------------------------------------------------------------------------
1 | # Data download guide
2 |
3 | The data used for training and validation of DeepCAD are made publicly available here. These data were captured by our customized two-photon microscope with two strictly synchronized detection paths. The signal intensity of the high-SNR path is 10-fold higher than that of the low-SNR path. We provide 14 groups of recordings with various imaging depths, excitation powers, and cell structures. All data are listed in the table below. You can download them directly by clicking the hyperlinks in the 'Power' column (warning: ~4.5 GB each).
4 |
5 | |No. |FOV (V×H)ᵃ |Frame rate |Imaging depthᵇ |Powerᶜ|AMPᵈ|Structures |
6 | |:----:| ---- |:----: | :----: | :----: |:----: | :----: |
7 | |1 |550×575 μm |30 Hz |80 μm |[66](https://zenodo.org/record/8079069/files/1_ZOOM1.3_550Vx575H_FOV_30Hz_80umdepth_0.2power_1AMP.zip?download=1)/[99](https://zenodo.org/record/8079069/files/1_ZOOM1.3_550Vx575H_FOV_30Hz_80umdepth_0.3power_1AMP.zip?download=1) mW |1 |Dendrite |
8 | |2 |550×575 μm |30 Hz |100 μm |[66](https://zenodo.org/record/8079069/files/2_ZOOM1.3_550Vx575H_FOV_30Hz_100umdepth_0.2power_1AMP.zip?download=1)/[99](https://zenodo.org/record/8079069/files/2_ZOOM1.3_550Vx575H_FOV_30Hz_100umdepth_0.3power_1AMP.zip?download=1) mW |1 |Soma and dendrite|
9 | |3 |550×575 μm |30 Hz |120 μm |[66](https://zenodo.org/record/8079069/files/3_ZOOM1.3_550Vx575H_FOV_30Hz_120umdepth_0.2power_1AMP.zip?download=1)/[99](https://zenodo.org/record/8079069/files/3_ZOOM1.3_550Vx575H_FOV_30Hz_120umdepth_0.3power_1AMP.zip?download=1) mW |1 |Soma |
10 | |4 |550×575 μm |30 Hz |140 μm |[66](https://zenodo.org/record/8079069/files/4_ZOOM1.3_550Vx575H_FOV_30Hz_140umdepth_0.2power_1AMP.zip?download=1)/[99](https://zenodo.org/record/8079069/files/4_ZOOM1.3_550Vx575H_FOV_30Hz_140umdepth_0.3power_1AMP.zip?download=1) mW |1 |Soma |
11 | |5 |550×575 μm |30 Hz |160 μm |[66](https://zenodo.org/record/8079069/files/5_ZOOM1.3_550Vx575H_FOV_30Hz_160umdepth_0.2power_1AMP.zip?download=1)/[99](https://zenodo.org/record/8079069/files/5_ZOOM1.3_550Vx575H_FOV_30Hz_160umdepth_0.3power_1AMP.zip?download=1) mW |1 |Soma |
12 | |6 |550×575 μm |30 Hz |180 μm |[99](https://zenodo.org/record/8080615/files/6_ZOOM1.3_550Vx575H_FOV_30Hz_180umdepth_0.3power_1AMP.zip?download=1)/[132](https://zenodo.org/record/8080615/files/6_ZOOM1.3_550Vx575H_FOV_30Hz_180umdepth_0.4power_1AMP.zip?download=1) mW |1 |Soma |
13 | |7 |550×575 μm |30 Hz |200 μm |[99](https://zenodo.org/record/8080615/files/7_ZOOM1.3_550Vx575H_FOV_30Hz_200umdepth_0.3power_10AMP.zip?download=1)/[132](https://zenodo.org/record/8080615/files/7_ZOOM1.3_550Vx575H_FOV_30Hz_200umdepth_0.4power_10AMP.zip?download=1) mW |10 |Soma |
14 | |8 |550×575 μm |30 Hz |200 μm |[99](https://zenodo.org/record/8080615/files/8_ZOOM1.3_550Vx575H_FOV_30Hz_200umdepth_0.3power_10AMP.zip?download=1)/[132](https://zenodo.org/record/8080615/files/8_ZOOM1.3_550Vx575H_FOV_30Hz_200umdepth_0.4power_10AMP.zip?download=1) mW |10 |Soma |
15 | |9 |550×575 μm |30 Hz |150 μm |[66](https://zenodo.org/record/8080615/files/9_ZOOM1.3_550Vx575H_FOV_30Hz_150umdepth_0.2power_10AMP.zip?download=1)/[99](https://zenodo.org/record/8080615/files/9_ZOOM1.3_550Vx575H_FOV_30Hz_150umdepth_0.3power_10AMP.zip?download=1) mW |10 |Soma |
16 | |10 |550×575 μm |30 Hz |170 μm |[99](https://zenodo.org/record/8080615/files/10_ZOOM1.3_550Vx575H_FOV_30Hz_170umdepth_0.3power_10AMP.zip?download=1) mW |10 |Soma |
17 | |11 |550×575 μm |30 Hz |80 μm |[66](https://zenodo.org/record/8079117/files/11_ZOOM1.3_550Vx575H_FOV_30Hz_80umdepth_dendrite_0.2power_10AMP.zip?download=1)/[99](https://zenodo.org/record/8079117/files/11_ZOOM1.3_550Vx575H_FOV_30Hz_80umdepth_dendrite_0.3power_10AMP.zip?download=1) mW |10 |Dendrite |
18 | |12 |550×575 μm |30 Hz |110 μm |[66](https://zenodo.org/record/8079117/files/12_ZOOM1.3_550Vx575H_FOV_30Hz_110umdepth_somadendrite_0.2power_10AMP.zip?download=1)/[99](https://zenodo.org/record/8079117/files/12_ZOOM1.3_550Vx575H_FOV_30Hz_110umdepth_somadendrite_0.3power_10AMP.zip?download=1) mW |10 |Soma and dendrite|
19 | |13 |550×575 μm |30 Hz |185 μm |[99](https://zenodo.org/record/8079117/files/13_ZOOM1.3_550Vx575H_FOV_30Hz_185umdepth_soma_0.3power_10AMP.zip?download=1)/[132](https://zenodo.org/record/8079117/files/13_ZOOM1.3_550Vx575H_FOV_30Hz_185umdepth_soma_0.4power_10AMP.zip?download=1) mW |10 |Soma |
20 | |14 |550×575 μm |30 Hz |210 μm |[99](https://zenodo.org/record/8079117/files/14_ZOOM1.3_550Vx575H_FOV_30Hz_210umdepth_soma_0.3power_10AMP.zip?download=1)/[132](https://zenodo.org/record/8079117/files/14_ZOOM1.3_550Vx575H_FOV_30Hz_210umdepth_soma_0.4power_10AMP.zip?download=1) mW |10 |Soma |
21 | ```
22 | a. FOV: field-of-view; V: vertical length; H: horizontal length.
23 | b. Depth: imaging depth below the pia mater.
24 | c. Two different excitation powers were used in each experiment for data diversity.
25 | d. AMP: the amplifier gain of the two PMTs.
26 | ```
27 |
28 | ## Citing the data
29 | If you use this data, please cite our paper:
30 |
31 | Li, X., Zhang, G., Wu, J. et al. Reinforcing neuron extraction and spike inference in calcium imaging using deep self-supervised denoising. Nat Methods (2021). [https://doi.org/10.1038/s41592-021-01225-0](https://www.nature.com/articles/s41592-021-01225-0)
32 |
33 |
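34 | ## Downloading from a script
35 | 
36 | The archives in the table above can also be fetched programmatically. Below is a minimal sketch using only the Python standard library (the URL is the 66 mW recording of group No. 1; each archive is several GB, so make sure enough disk space is available):
37 | 
38 | ```
39 | import urllib.request
40 | 
41 | url = ('https://zenodo.org/record/8079069/files/'
42 |        '1_ZOOM1.3_550Vx575H_FOV_30Hz_80umdepth_0.2power_1AMP.zip?download=1')
43 | 
44 | # Download the archive to the current directory; this can take a while for a multi-GB file.
45 | urllib.request.urlretrieve(url, '1_ZOOM1.3_550Vx575H_FOV_30Hz_80umdepth_0.2power_1AMP.zip')
46 | ```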
--------------------------------------------------------------------------------
/DeepCAD_Fiji/DeepCAD_tensorflow/network.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 |
4 |
5 | def get_weight(shape, gain=np.sqrt(2)):
6 | fan_in = np.prod(shape[:-1]) # [kernel, kernel, fmaps_in, fmaps_out] or [in, out]
7 | std = gain / np.sqrt(fan_in) # He init
8 | w = tf.get_variable('weight', shape=shape, initializer=tf.initializers.random_normal(0, std))
9 | return w
10 |
11 |
12 | def apply_bias(x):
13 | b = tf.get_variable('bias', shape=[x.shape[4]], initializer=tf.initializers.zeros())
14 | b = tf.cast(b, x.dtype)
15 | if len(x.shape) == 2:
16 | return x + b
17 | return x + tf.reshape(b, [1, 1, 1, 1, -1])
18 |
19 | def conv3d_bias(x, fmaps, kernel, gain=np.sqrt(2)):
20 | assert kernel >= 1 and kernel % 2 == 1
21 | w = get_weight([kernel, kernel, kernel, x.shape[4].value, fmaps], gain=gain)
22 | w = tf.cast(w, x.dtype)
23 | return apply_bias(tf.nn.conv3d(x, w, strides=[1,1,1,1,1], padding='SAME', data_format='NDHWC'))
24 |
25 | def conv3d(x, fmaps, kernel, gain=np.sqrt(2)):
26 | assert kernel >= 1 and kernel % 2 == 1
27 | w = get_weight([kernel, kernel, kernel, x.shape[4].value, fmaps], gain=gain)
28 | w = tf.cast(w, x.dtype)
29 | return tf.nn.conv3d(x, w, strides=[1,1,1,1,1], padding='SAME', data_format='NDHWC')
30 |
31 | def maxpool3d(x, k=2):
32 | ksize = [1, k, k, k, 1]
33 | return tf.nn.max_pool3d(x, ksize=ksize, strides=ksize, padding='SAME', data_format='NDHWC')
34 |
35 | def upscale3d(x, factor=2):
36 | assert isinstance(factor, int) and factor >= 1
37 | if factor == 1: return x
38 | with tf.variable_scope('Upscale3D'):
39 | s = x.shape
40 | x = tf.reshape(x, [-1, s[1], 1, s[2], 1, s[3], 1, s[4]])
41 | x = tf.tile(x, [1, 1, factor, 1, factor, 1, factor, 1])
42 | x = tf.reshape(x, [-1, s[1] * factor, s[2] * factor, s[3] * factor, s[4]])
43 | return x
44 |
45 | def conv_lr_bias(name, x, fmaps):
46 | with tf.variable_scope(name):
47 | return tf.nn.leaky_relu(conv3d_bias(x, fmaps, 3), alpha=0.1)
48 |
49 | def conv_r_bias(name, x, fmaps):
50 | with tf.variable_scope(name):
51 | return tf.nn.relu(conv3d_bias(x, fmaps, 3))
52 |
53 | def conv_r(name, x, fmaps):
54 | with tf.variable_scope(name):
55 | return tf.nn.relu(conv3d(x, fmaps, 3))
56 |
57 | def final_conv(name, x, fmaps, gain):
58 | with tf.variable_scope(name):
59 | return apply_bias(conv3d(x, fmaps, 1, gain))
60 | '''
61 | def output_block_layer(inputs):
62 | w = tf.Variable(tf.truncated_normal([1, 1, 1, 64, 1], stddev=0.01), name='end_con3d')
63 | x = tf.nn.conv3d(input=inputs,
64 | filter=w,
65 | strides=[1,1,1,1,1],
66 | padding='SAME')
67 | b = tf.get_variable('bias', shape=[x.shape[4]], initializer=tf.initializers.zeros())
68 | output = tf.Variable(tf.ones(shape=x.shape), name='output')
69 | output = x + b
70 | print('output -----> ',output.get_shape())
71 | return output
72 | '''
73 | def output_block_layer(inputs):
74 | w = tf.Variable(tf.truncated_normal([1, 1, 1, 64, 1], stddev=0.01), name='end_con3d')
75 | output = tf.nn.conv3d(input=inputs,
76 | filter=w,
77 | strides=[1,1,1,1,1],
78 | padding='SAME',
79 | name='output')
80 | return output
81 |
82 | def group_norm(x, name, G=8, eps=1e-5, scope='group_norm') :
83 | with tf.variable_scope(scope+name, reuse=tf.AUTO_REUSE) :
84 | N, H, W, S, C = x.get_shape().as_list()
85 | G = min(G, C)
86 | # [N, H, W, G, C // G]
87 | x = tf.reshape(x, [N, H, W, S, G, C // G])
88 | mean, var = tf.nn.moments(x, [1, 2, 3, 5], keep_dims=True)
89 | x = (x - mean) / tf.sqrt(var + eps)
90 |
91 | gamma = tf.get_variable('gamma_'+name, [1, 1, 1, 1, C], initializer=tf.constant_initializer(1))
92 | beta = tf.get_variable('beta_'+name, [1, 1, 1, 1, C], initializer=tf.constant_initializer(0))
93 | x = tf.reshape(x, [N, H, W, S, C]) * gamma + beta
94 | return x
95 | ####################################################################################################################
96 | def autoencoder(x, width=256, height=256, length=256, **_kwargs):
97 | x.set_shape([1, height, width, length, 1])
98 |
99 | skips = [x]
100 |
101 | n = x
102 | n = conv_r('enc_conv0', n, 32)
103 | n = group_norm(n, 'gn_enc_conv0')
104 | n = conv_r('enc_conv0b', n, 64)
105 | n = group_norm(n, 'gn_enc_conv0b')
106 | skips.append(n)
107 |
108 | n = maxpool3d(n)
109 | n = conv_r('enc_conv1', n, 64)
110 | n = group_norm(n, 'gn_enc_conv1')
111 | n = conv_r('enc_conv1b', n, 128)
112 | n = group_norm(n, 'gn_enc_conv1b')
113 | skips.append(n)
114 |
115 | n = maxpool3d(n)
116 | n = conv_r('enc_conv2', n, 128)
117 | n = group_norm(n, 'gn_enc_conv2')
118 | n = conv_r('enc_conv2b', n, 256)
119 | n = group_norm(n, 'gn_enc_conv2b')
120 | skips.append(n)
121 |
122 | n = maxpool3d(n)
123 | n = conv_r('enc_conv3', n, 256)
124 | n = group_norm(n, 'gn_enc_conv3')
125 | n = conv_r('enc_conv3b', n, 512)
126 | n = group_norm(n, 'gn_enc_conv3b')
127 | #-----------------------------------------------
128 | n = upscale3d(n)
129 | # print('upscale1 -----> ',str(n.get_shape()))
130 | n = tf.concat([n, skips.pop()], axis=-1)
131 | # print('upscale1 -----> ',str(n.get_shape()))
132 | n = conv_r('dec_conv4', n, 256)
133 | n = group_norm(n, 'gn_dec_conv4')
134 | n = conv_r('dec_conv4b', n, 256)
135 | n = group_norm(n, 'gn_dec_conv4b')
136 |
137 | n = upscale3d(n)
138 | # print('upscale2 -----> ',str(n.get_shape()))
139 | n = tf.concat([n, skips.pop()], axis=-1)
140 | # print('upscale2 -----> ',str(n.get_shape()))
141 | n = conv_r('dec_conv3', n, 128)
142 | n = group_norm(n, 'gn_dec_conv3')
143 | n = conv_r('dec_conv3b', n, 128)
144 | n = group_norm(n, 'gn_dec_conv3b')
145 |
146 | n = upscale3d(n)
147 | # print('upscale3 -----> ',str(n.get_shape()))
148 | n = tf.concat([n, skips.pop()], axis=-1)
149 | # print('upscale3 -----> ',str(n.get_shape()))
150 | n = conv_r('dec_conv2', n, 64)
151 | n = group_norm(n, 'gn_dec_conv2')
152 | n = conv_r('dec_conv2b', n, 64)
153 | n = group_norm(n, 'gn_dec_conv2b')
154 |
155 | #output = final_conv('final_conv', n, 1, gain=1.0)
156 | output = output_block_layer(n)
157 | return output
158 |
--------------------------------------------------------------------------------
/DeepCAD_Fiji/DeepCAD_tensorflow/test_pb.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from keras.models import load_model
3 | import argparse
4 | import os
5 | import tifffile as tiff
6 | import time
7 | import datetime
8 | import random
9 | from skimage import io
10 | import tensorflow as tf
11 | import logging
12 | import time
13 | from data_process import test_preprocess_lessMemory
14 | from utils import read_yaml, name2index
15 | import math
16 |
17 |
18 | def main(args):
19 | # train_3d_new(args)
20 | test(args)
21 |
22 | def test(args):
23 | tf.reset_default_graph()
24 | model_path = args.CBSDeep_model_folder+'//'+args.denoise_model
25 | # print(list(os.walk(model_path, topdown=False))[-1])
26 | # print(list(os.walk(model_path, topdown=False))[-1][-1][0])
27 | # print(list(os.walk(model_path, topdown=False))[-1][-2])
28 | model_list = list(os.walk(model_path, topdown=False))[-1][-2]
29 | yaml_name = list(os.walk(model_path, topdown=False))[-1][-1][0]
30 | print(yaml_name)
31 | read_yaml(args, model_path+'//'+yaml_name)
32 | print('parameters ----->', args)
33 |
34 | name_list, noise_img, coordinate_list = test_preprocess_lessMemory(args)
35 | num_h = (math.floor((noise_img.shape[1]-args.img_h)/args.gap_h)+1)
36 | num_w = (math.floor((noise_img.shape[2]-args.img_w)/args.gap_w)+1)
37 | num_s = (math.floor((noise_img.shape[0]-args.img_s)/args.gap_s)+1)
38 |
39 | TIME = args.datasets_folder+'_'+args.denoise_model #+'_'+args.model_name #+'_'+datetime.datetime.now().strftime("%Y%m%d-%H%M")
40 | results_path = args.results_folder+'//'+'unet3d_'+TIME+'//'
41 | if not os.path.exists(args.results_folder):
42 | os.mkdir(args.results_folder)
43 | if not os.path.exists(results_path):
44 | os.mkdir(results_path)
45 |
46 | model_name = args.model_name
47 | print('model_name -----> ',model_name)
48 | output_path = results_path + '//' + model_name
49 | if not os.path.exists(output_path):
50 | os.mkdir(output_path)
51 |
52 | output_graph_path = args.CBSDeep_model_folder+'//'+args.denoise_model+'//'+model_name+'//'
53 | print('output_graph_path -----> ',output_graph_path)
54 | start_time=time.time()
55 | sess = tf.Session()
56 | with sess.as_default():
57 | # sess.run(tf.global_variables_initializer())
58 | meta_graph_def = tf.saved_model.loader.load(sess, ['3D_N2N'], output_graph_path)
59 | signature = meta_graph_def.signature_def
60 | in_tensor_name = signature['my_signature'].inputs['input0'].name
61 | out_tensor_name = signature['my_signature'].outputs['output0'].name
62 | input = sess.graph.get_tensor_by_name(in_tensor_name)
63 | output = sess.graph.get_tensor_by_name(out_tensor_name)
64 | # sess.run(tf.global_variables_initializer())
65 | '''
66 | variable_names = [v.name for v in tf.trainable_variables()]
67 | values = sess.run(variable_names)
68 | for k,v in zip(variable_names, values):
69 | if len(v.shape)==5:
70 | print("Variable: ", k, "Shape: ", v.shape,"value: ",v[0][0][0][0][0])
71 | if len(v.shape)==1:
72 | print("Variable: ", k, "Shape: ", v.shape,"value: ",v[0])
73 | '''
74 | denoise_img = np.zeros(noise_img.shape)
75 | input_img = np.zeros(noise_img.shape)
76 | for index in range(len(name_list)):
77 | input_name = name_list[index]
78 | single_coordinate = coordinate_list[name_list[index]]
79 | init_h = single_coordinate['init_h']
80 | end_h = single_coordinate['end_h']
81 | init_w = single_coordinate['init_w']
82 | end_w = single_coordinate['end_w']
83 | init_s = single_coordinate['init_s']
84 | end_s = single_coordinate['end_s']
85 | noise_patch1 = noise_img[init_s:end_s,init_h:end_h,init_w:end_w]
86 | train_input = np.expand_dims(np.expand_dims(noise_patch1.transpose(1,2,0), 3),0)
87 | # print('train_input -----> ',train_input.shape)
88 | data_name = name_list[index]
89 | train_output = sess.run(output, feed_dict={input: train_input})
90 |
91 | train_input = np.squeeze(train_input).transpose(2,0,1)
92 | train_output = np.squeeze(train_output).transpose(2,0,1)
93 | stack_start_w ,stack_end_w ,patch_start_w ,patch_end_w ,\
94 | stack_start_h ,stack_end_h ,patch_start_h ,patch_end_h ,\
95 | stack_start_s ,stack_end_s ,patch_start_s ,patch_end_s = name2index(args, input_name, num_h, num_w, num_s)
96 |
97 | aaaa = train_output[patch_start_s:patch_end_s, patch_start_h:patch_end_h, patch_start_w:patch_end_w]
98 | bbbb = train_input[patch_start_s:patch_end_s, patch_start_h:patch_end_h, patch_start_w:patch_end_w]
99 |
100 | denoise_img[stack_start_s:stack_end_s, stack_start_w:stack_end_w, stack_start_h:stack_end_h] \
101 | = train_output[patch_start_s:patch_end_s, patch_start_w:patch_end_w, patch_start_h:patch_end_h]*(np.sum(bbbb)/np.sum(aaaa))**0.5
102 | input_img[stack_start_s:stack_end_s, stack_start_w:stack_end_w, stack_start_h:stack_end_h] \
103 | = train_input[patch_start_s:patch_end_s, patch_start_w:patch_end_w, patch_start_h:patch_end_h]
104 | # print('output_img shape -----> ',output_img.shape)
105 |
106 | '''
107 | output_img = denoise_img.squeeze().astype(np.float32)*args.normalize_factor
108 | output_img = np.clip(output_img, 0, 65535).astype('uint16')
109 | result_name = results_path+str(i)+'_'+str(index)+'_'+data_name+'_output.tif'
110 | io.imsave(result_name, output_img.transpose(2,0,1))
111 | '''
112 |
113 | # if index % 100 == 0:
114 | # print('denoise_img ---> ',denoise_img.max(),'---> ',denoise_img.min())
115 | # print('input_img ---> ',input_img.max(),'---> ',input_img.min())
116 | output_img = denoise_img.squeeze().astype(np.float32)*args.normalize_factor
117 | output_img = output_img-output_img.min()
118 | output_img = output_img/output_img.max()*65535
119 | output_img = np.clip(output_img, 0, 65535).astype('uint16')
120 | output_img = output_img-output_img.min()
121 | input_img = input_img.squeeze().astype(np.float32)*args.normalize_factor
122 | input_img = np.clip(input_img, 0, 65535).astype('uint16')
123 | result_name = output_path + '//' + 'output_'+model_name+'.tif'
124 | input_name = output_path + '//' + 'input_'+model_name+'.tif'
125 | io.imsave(result_name, output_img)
126 | io.imsave(input_name, input_img)
127 |
128 |
129 |
130 |
131 | if __name__ == '__main__':
132 | parser = argparse.ArgumentParser()
133 | parser.add_argument('--img_h', type=int, default=64, help='the height of patch stack')
134 | parser.add_argument('--img_w', type=int, default=64, help='the width of patch stack')
135 | parser.add_argument('--img_s', type=int, default=464, help='the image number of patch stack')
136 | parser.add_argument('--img_c', type=int, default=1, help='the channel of image')
137 | parser.add_argument('--gap_h', type=int, default=56, help='the height of the patch gap')
138 | parser.add_argument('--gap_w', type=int, default=56, help='the width of the patch gap')
139 | parser.add_argument('--gap_s', type=int, default=144, help='the image number of the patch gap')
140 | parser.add_argument('--normalize_factor', type=int, default=1, help='image normalization factor')
141 | parser.add_argument('--datasets_folder', type=str, default='test2', help='the folder containing the test data')
142 | parser.add_argument('--model_name', type=str, default='test2', help='the name of the saved sub-model (pb folder) to load')
143 | parser.add_argument('--model_folder', type=str, default='log', help='the folder for logs')
144 | parser.add_argument('--model_epoch', type=int, default=0, help='the epoch of the model to load')
145 | parser.add_argument('--CBSDeep_model_folder', type=str, default='DeepCAD_model', help='the folder containing DeepCAD (pb) models')
146 | parser.add_argument('--results_folder', type=str, default='results', help='the folder for results')
147 | parser.add_argument('--GPU', type=int, default=3, help='the index of GPU you will use for computation')
148 | 
149 | parser.add_argument('--denoise_model', type=str, default='unet3d_test2_20200924-1707', help='the pre-trained model used for denoising')
150 | parser.add_argument('--test_datasize', type=int, default=512, help='the number of images used for the test')
151 | args = parser.parse_args()
152 | print('parameters ----->', args)
153 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.GPU)
154 | main(args)
155 |
--------------------------------------------------------------------------------
/DeepCAD_Fiji/DeepCAD_tensorflow/main.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from keras.models import load_model
3 | import argparse
4 | import os
5 | import tifffile as tiff
6 | import time
7 | import datetime
8 | import random
9 | from skimage import io
10 | from network import autoencoder
11 | import tensorflow as tf
12 | import logging
13 | import time
14 | from utils import save_yaml
15 | from data_process import train_preprocess_lessMemory, shuffle_datasets_lessMemory,train_preprocess_lessMemoryMulStacks
16 |
17 |
18 | def main(args):
19 | train(args)
20 |
21 | def train(args):
22 | TIME = args.datasets_folder+'_'+datetime.datetime.now().strftime("%Y%m%d-%H%M")
23 | DeepCAD_model_path = args.DeepCAD_model_folder+'//'+'pb_unet3d_'+TIME+'//'
24 | if not os.path.exists(args.DeepCAD_model_folder):
25 | os.mkdir(args.DeepCAD_model_folder)
26 | if not os.path.exists(DeepCAD_model_path):
27 | os.mkdir(DeepCAD_model_path)
28 | yaml_name = DeepCAD_model_path+'//para.yaml'
29 | save_yaml(args, yaml_name)
30 | results_path = args.results_folder+'//'+'unet3d_'+TIME+'//'
31 | if not os.path.exists(args.results_folder):
32 | os.mkdir(args.results_folder)
33 | if not os.path.exists(results_path):
34 | os.mkdir(results_path)
35 |
36 | name_list, noise_img, coordinate_list = train_preprocess_lessMemoryMulStacks(args)
37 | data_size = len(name_list)
38 |
39 | sess = tf.Session()
40 | input_shape = [1, args.img_h, args.img_w, args.img_s, args.img_c]
41 | input = tf.placeholder(tf.float32, shape=input_shape, name='input')
42 | # output = tf.placeholder(tf.float32, shape=input_shape, name='output')
43 | output_GT = tf.placeholder(tf.float32, shape=input_shape, name='output_GT')
44 | # net = Network(training = args.is_training)
45 | output = autoencoder(input, height=args.img_h, width=args.img_w, length=args.img_s)
46 |
47 | L2_loss = tf.reduce_mean(tf.square(output - output_GT))
48 | L1_loss = tf.reduce_sum(tf.losses.absolute_difference(output, output_GT))
49 | loss = tf.add(L1_loss, L2_loss)
50 | optimizer = tf.train.AdamOptimizer(learning_rate=args.lr)
51 | train_step = optimizer.minimize(loss)
52 | start_time=time.time()
53 | with sess.as_default():
54 | sess.run(tf.global_variables_initializer())
55 | for i in range(args.train_epochs):
56 | name_list = shuffle_datasets_lessMemory(name_list)
57 | for index in range(data_size):
58 | single_coordinate = coordinate_list[name_list[index]]
59 | init_h = single_coordinate['init_h']
60 | end_h = single_coordinate['end_h']
61 | init_w = single_coordinate['init_w']
62 | end_w = single_coordinate['end_w']
63 | init_s = single_coordinate['init_s']
64 | end_s = single_coordinate['end_s']
65 | noise_patch1 = noise_img[init_s:end_s:2,init_h:end_h,init_w:end_w]
66 | noise_patch2 = noise_img[init_s+1:end_s:2,init_h:end_h,init_w:end_w]
67 | train_input = np.expand_dims(np.expand_dims(noise_patch1.transpose(1,2,0), 3),0)
68 | train_GT = np.expand_dims(np.expand_dims(noise_patch2.transpose(1,2,0), 3),0)
69 | # print(train_input.shape)
70 | data_name = name_list[index]
71 | sess.run(train_step, feed_dict={input: train_input, output_GT: train_GT})
72 |
73 | if index % 100 == 0:
74 | output_img, L1_loss_va, L2_loss_va = sess.run([output, L1_loss, L2_loss], feed_dict={input: train_input, output_GT: train_GT})
75 | print('--- Epoch ',i,' --- Step ',index,'/',data_size,' --- L1_loss ', L1_loss_va,' --- L2_loss ', L2_loss_va,' --- Time ',(time.time()-start_time))
76 | print('train_input ---> ',train_input.max(),'---> ',train_input.min())
77 | print('output_img ---> ',output_img.max(),'---> ',output_img.min())
78 | train_input = train_input.squeeze().astype(np.float32)*args.normalize_factor
79 | train_GT = train_GT.squeeze().astype(np.float32)*args.normalize_factor
80 | output_img = output_img.squeeze().astype(np.float32)*args.normalize_factor
81 | train_input = np.clip(train_input, 0, 65535).astype('uint16')
82 | train_GT = np.clip(train_GT, 0, 65535).astype('uint16')
83 | output_img = np.clip(output_img, 0, 65535).astype('uint16')
84 | result_name = results_path+str(i)+'_'+str(index)+'_'+data_name+'_output.tif'
85 | noise_img1_name = results_path+str(i)+'_'+str(index)+'_'+data_name+'_noise1.tif'
86 | noise_img2_name = results_path+str(i)+'_'+str(index)+'_'+data_name+'_noise2.tif'
87 | io.imsave(result_name, output_img.transpose(2,0,1))
88 | io.imsave(noise_img1_name, train_input.transpose(2,0,1))
89 | io.imsave(noise_img2_name, train_GT.transpose(2,0,1))
90 | '''
91 | variable_names = [v.name for v in tf.trainable_variables()]
92 | values = sess.run(variable_names)
93 | for k,v in zip(variable_names, values):
94 | if len(v.shape)==5:
95 | print("Variable: ", k, "Shape: ", v.shape,"value: ",v[0][0][0][0][0])
96 | if len(v.shape)==1:
97 | print("Variable: ", k, "Shape: ", v.shape,"value: ",v[0])
98 | '''
99 | '''
100 | aaaaa=0
101 | for op in tf.get_default_graph().get_operations():
102 | aaaaa=aaaaa+1
103 | if aaaaa<50:
104 | # print('-----> ',op.name)
105 | print('-----> ',op.values())
106 | '''
107 |
108 | if index % 1000 == 0:
109 | DeepCAD_model_name=DeepCAD_model_path+'//'+str(i)+'_'+str(index)+'//'
110 | builder = tf.saved_model.builder.SavedModelBuilder(DeepCAD_model_name)
111 | input0 = {'input0': tf.saved_model.utils.build_tensor_info(input)}
112 | output0 = {'output0': tf.saved_model.utils.build_tensor_info(output)}
113 | method_name = tf.saved_model.signature_constants.PREDICT_METHOD_NAME
114 | my_signature = tf.saved_model.signature_def_utils.build_signature_def(input0, output0, method_name)
115 | builder.add_meta_graph_and_variables(sess, ["3D_N2N"], signature_def_map={'my_signature': my_signature})
116 | builder.add_meta_graph(["3D_N2N"], signature_def_map={'my_signature': my_signature})
117 | builder.save()
118 |
119 |
120 |
121 |
122 | if __name__ == '__main__':
123 | parser = argparse.ArgumentParser()
124 | parser.add_argument('--img_h', type=int, default=64, help='the height of patch stack')
125 | parser.add_argument('--img_w', type=int, default=64, help='the width of patch stack')
126 | parser.add_argument('--img_s', type=int, default=320, help='the image number of patch stack')
127 | parser.add_argument('--img_c', type=int, default=1, help='the channel of image')
128 | parser.add_argument('--gap_h', type=int, default=56, help='the height of patch gap')
129 | parser.add_argument('--gap_w', type=int, default=56, help='the width of patch gap')
130 | parser.add_argument('--gap_s', type=int, default=128, help='the image number of patch gap')
131 | parser.add_argument('--normalize_factor', type=int, default=1, help='Image normalization factor')
132 | parser.add_argument('--train_epochs', type=int, default=30, help='train epochs')
133 | parser.add_argument('--datasets_path', type=str, default='datasets', help="the name of your project")
134 | parser.add_argument('--datasets_folder', type=str, default='3', help='the folders for datasets')
135 | parser.add_argument('--DeepCAD_model_folder', type=str, default='DeepCAD_model', help='the folders for DeepCAD(pb) model')
136 | parser.add_argument('--results_folder', type=str, default='results', help='the folders for results')
137 | parser.add_argument('--GPU', type=int, default=3, help='the index of GPU you will use for computation')
138 | parser.add_argument('--is_training', type=bool, default=True, help='train or test')
139 | parser.add_argument('--lr', type=float, default=0.001, help='initial learning rate')
140 | parser.add_argument('--train_datasets_size', type=int, default=1000, help='datasets size for training')
141 | parser.add_argument('--select_img_num', type=int, default=6000, help='select the number of images used for training')
142 | args = parser.parse_args()
143 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.GPU)
144 | main(args)
145 |
--------------------------------------------------------------------------------
/DeepCAD_pytorch/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | from torch.autograd import Variable
5 | from torch.utils.data import DataLoader
6 | import argparse
7 | import time
8 | import datetime
9 | import sys
10 | import math
11 | import scipy.io as scio
12 | from network import Network_3D_Unet
13 | from tensorboardX import SummaryWriter
14 | import numpy as np
15 | from data_process import shuffle_datasets, train_preprocess_lessMemoryMulStacks, shuffle_datasets_lessMemory
16 | from utils import save_yaml
17 | from skimage import io
18 | #############################################################################################################################################
19 | parser = argparse.ArgumentParser()
20 | parser.add_argument("--epoch", type=int, default=0, help="epoch to start training from")
21 | parser.add_argument("--n_epochs", type=int, default=100, help="number of training epochs")
22 | parser.add_argument('--cuda', action='store_true', help='use GPU computation')
23 | parser.add_argument('--GPU', type=int, default=0, help="the index of GPU you will use for computation")
24 |
25 | parser.add_argument('--batch_size', type=int, default=1, help="batch size")
26 | parser.add_argument('--img_s', type=int, default=150, help="the slices of image sequence")
27 | parser.add_argument('--img_w', type=int, default=150, help="the width of image sequence")
28 | parser.add_argument('--img_h', type=int, default=150, help="the height of image sequence")
29 | parser.add_argument('--gap_s', type=int, default=60, help='the slices of image gap')
30 | parser.add_argument('--gap_w', type=int, default=90, help='the width of image gap')
31 | parser.add_argument('--gap_h', type=int, default=90, help='the height of image gap')
32 |
33 | parser.add_argument('--lr', type=float, default=0.001, help='initial learning rate')
34 | parser.add_argument("--b1", type=float, default=0.5, help="Adam: beta1")
35 | parser.add_argument("--b2", type=float, default=0.999, help="Adam: beta2")
36 | parser.add_argument('--normalize_factor', type=int, default=65535, help='normalize factor')
37 |
38 | parser.add_argument('--output_dir', type=str, default='./results', help="output directory")
39 | parser.add_argument('--datasets_folder', type=str, default='DataForPytorch', help="A folder containing files for training")
40 | parser.add_argument('--datasets_path', type=str, default='datasets', help="dataset root path")
41 | parser.add_argument('--pth_path', type=str, default='pth', help="pth file root path")
42 | parser.add_argument('--select_img_num', type=int, default=6000, help='select the number of images')
43 | parser.add_argument('--train_datasets_size', type=int, default=1000, help='datasets size for training')
44 | opt = parser.parse_args()
45 |
46 | print('the parameters of your training ----->')
47 | print(opt)
48 | ########################################################################################################################
49 | if not os.path.exists(opt.output_dir):
50 | os.mkdir(opt.output_dir)
51 | current_time = opt.datasets_folder+'_'+datetime.datetime.now().strftime("%Y%m%d-%H%M")
52 | output_path = opt.output_dir + '/' + current_time
53 | pth_path = 'pth//'+ current_time
54 | if not os.path.exists(output_path):
55 | os.mkdir(output_path)
56 | if not os.path.exists(pth_path):
57 | os.mkdir(pth_path)
58 |
59 | yaml_name = pth_path+'//para.yaml'
60 | save_yaml(opt, yaml_name)
61 |
62 | os.environ["CUDA_VISIBLE_DEVICES"] = str(opt.GPU)
63 | batch_size = opt.batch_size
64 | lr = opt.lr
65 |
66 | name_list, noise_img, coordinate_list = train_preprocess_lessMemoryMulStacks(opt)
67 | # print('name_list -----> ',name_list)
68 | ########################################################################################################################
69 | L1_pixelwise = torch.nn.L1Loss()
70 | L2_pixelwise = torch.nn.MSELoss()
71 | ########################################################################################################################
72 | denoise_generator = Network_3D_Unet(in_channels = 1,
73 | out_channels = 1,
74 | final_sigmoid = True)
75 | if torch.cuda.is_available():
76 | print('Using GPU.')
77 | denoise_generator.cuda()
78 | L2_pixelwise.cuda()
79 | L1_pixelwise.cuda()
80 | ########################################################################################################################
81 | optimizer_G = torch.optim.Adam( denoise_generator.parameters(),
82 | lr=opt.lr, betas=(opt.b1, opt.b2))
83 | ########################################################################################################################
84 |
85 | cuda = True if torch.cuda.is_available() else False
86 | Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
87 | prev_time = time.time()
88 | ########################################################################################################################
89 | time_start=time.time()
90 | for epoch in range(opt.epoch, opt.n_epochs):
91 | name_list = shuffle_datasets_lessMemory(name_list)
92 | # print('name list -----> ',name_list)
93 | ####################################################################################################################
94 | for index in range(len(name_list)):
95 | single_coordinate = coordinate_list[name_list[index]]
96 | init_h = single_coordinate['init_h']
97 | end_h = single_coordinate['end_h']
98 | init_w = single_coordinate['init_w']
99 | end_w = single_coordinate['end_w']
100 | init_s = single_coordinate['init_s']
101 | end_s = single_coordinate['end_s']
102 | noise_patch1 = noise_img[init_s:end_s:2,init_h:end_h,init_w:end_w]
103 | noise_patch2 = noise_img[init_s+1:end_s:2,init_h:end_h,init_w:end_w]
104 | real_A = torch.from_numpy(np.expand_dims(np.expand_dims(noise_patch1, 3),0)).cuda()
105 | real_A = real_A.permute([0,4,1,2,3])
106 | real_B = torch.from_numpy(np.expand_dims(np.expand_dims(noise_patch2, 3),0)).cuda()
107 | real_B = real_B.permute([0,4,1,2,3])
108 | # print('real_A shape -----> ',real_A.shape)
109 | # print('real_B shape -----> ',real_B.shape)
110 | input_name = name_list[index]
111 | real_A = Variable(real_A)
112 | fake_B = denoise_generator(real_A)
113 | # Pixel-wise loss
114 | L1_loss = L1_pixelwise(fake_B, real_B)
115 | L2_loss = L2_pixelwise(fake_B, real_B)
116 | ################################################################################################################
117 | optimizer_G.zero_grad()
118 | # Total loss
119 | Total_loss = 0.5*L1_loss + 0.5*L2_loss
120 | Total_loss.backward()
121 | optimizer_G.step()
122 | ################################################################################################################
123 | batches_done = epoch * len(name_list) + index
124 | batches_left = opt.n_epochs * len(name_list) - batches_done
125 | time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
126 | prev_time = time.time()
127 | ################################################################################################################
128 | if index%50 == 0:
129 | time_end=time.time()
130 | print('time cost',time_end-time_start,'s \n')
131 | sys.stdout.write(
132 | "\r[Epoch %d/%d] [Batch %d/%d] [Total loss: %f, L1 Loss: %f, L2 Loss: %f] ETA: %s"
133 | % (
134 | epoch,
135 | opt.n_epochs,
136 | index,
137 | len(name_list),
138 | Total_loss.item(),
139 | L1_loss.item(),
140 | L2_loss.item(),
141 | time_left,
142 | )
143 | )
144 | ################################################################################################################
145 | # if (epoch+1)%1 == 0:
146 | # torch.save(denoise_generator.state_dict(), pth_path + '//G_' + str(epoch) + '.pth')
147 | if (index+1)%300 == 0:
148 | torch.save(denoise_generator.state_dict(), pth_path + '//G_' + str(epoch) +'_'+ str(index) + '.pth')
149 | if (epoch+1)%1 == 0:
150 | output_img = fake_B.cpu().detach().numpy()
151 | train_GT = real_B.cpu().detach().numpy()
152 | train_input = real_A.cpu().detach().numpy()
153 | image_name = input_name
154 |
155 | train_input = train_input.squeeze().astype(np.float32)*opt.normalize_factor
156 | train_GT = train_GT.squeeze().astype(np.float32)*opt.normalize_factor
157 | output_img = output_img.squeeze().astype(np.float32)*opt.normalize_factor
158 | train_input = np.clip(train_input, 0, 65535).astype('uint16')
159 | train_GT = np.clip(train_GT, 0, 65535).astype('uint16')
160 | output_img = np.clip(output_img, 0, 65535).astype('uint16')
161 | result_name = output_path + '/' + str(epoch) + '_' + str(index) + '_' + input_name+'_output.tif'
162 | noise_img1_name = output_path + '/' + str(epoch) + '_' + str(index) + '_' + input_name+'_noise1.tif'
163 | noise_img2_name = output_path + '/' + str(epoch) + '_' + str(index) + '_' + input_name+'_noise2.tif'
164 | io.imsave(result_name, output_img)
165 | io.imsave(noise_img1_name, train_input)
166 | io.imsave(noise_img2_name, train_GT)
167 |
168 |
169 | torch.save(denoise_generator.state_dict(), pth_path +'//G_' + str(opt.n_epochs) + '.pth')
170 |
--------------------------------------------------------------------------------
/DeepCAD_pytorch/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | from torch.autograd import Variable
5 | from torch.utils.data import DataLoader
6 | import argparse
7 | import time
8 | import datetime
9 | import sys
10 | import math
11 | import scipy.io as scio
12 | from network import Network_3D_Unet
13 | from tensorboardX import SummaryWriter
14 | import numpy as np
15 | from utils import save_yaml, read_yaml, name2index
16 | from data_process import shuffle_datasets, train_preprocess, test_preprocess, test_preprocess_lessMemory,test_preprocess_lessMemoryNoTail
17 | from skimage import io
18 | #############################################################################################################################################
19 | parser = argparse.ArgumentParser()
20 | parser.add_argument("--epoch", type=int, default=0, help="epoch to start training from")
21 | parser.add_argument("--n_epochs", type=int, default=100, help="number of training epochs")
22 | parser.add_argument('--cuda', action='store_true', help='use GPU computation')
23 | parser.add_argument('--GPU', type=int, default=0, help="the index of GPU you will use for computation")
24 |
25 | parser.add_argument('--batch_size', type=int, default=1, help="batch size")
26 | parser.add_argument('--img_s', type=int, default=150, help="the slices of image sequence")
27 | parser.add_argument('--img_w', type=int, default=150, help="the width of image sequence")
28 | parser.add_argument('--img_h', type=int, default=150, help="the height of image sequence")
29 | parser.add_argument('--gap_s', type=int, default=60, help='the slices of image gap')
30 | parser.add_argument('--gap_w', type=int, default=90, help='the width of image gap')
31 | parser.add_argument('--gap_h', type=int, default=90, help='the height of image gap')
32 |
33 | parser.add_argument('--lr', type=float, default=0.001, help='initial learning rate')
34 | parser.add_argument("--b1", type=float, default=0.5, help="Adam: beta1")
35 | parser.add_argument("--b2", type=float, default=0.999, help="Adam: beta2")
36 | parser.add_argument('--normalize_factor', type=int, default=65535, help='normalize factor')
37 |
38 | parser.add_argument('--output_dir', type=str, default='./results', help="output directory")
39 | parser.add_argument('--datasets_path', type=str, default='datasets', help="dataset root path")
40 | parser.add_argument('--pth_path', type=str, default='pth', help="pth file root path")
41 | parser.add_argument('--datasets_folder', type=str, default='DataForPytorch', help="A folder containing files to be tested")
42 | parser.add_argument('--denoise_model', type=str, default='ModelForPytorch', help='A folder containing models to be tested')
43 | parser.add_argument('--test_datasize', type=int, default=6000, help='dataset size to be tested')
44 | parser.add_argument('--train_datasets_size', type=int, default=1000, help='datasets size for training')
45 |
46 | opt = parser.parse_args()
47 | print('the parameters of your testing ----->')
48 | print(opt)
49 | ########################################################################################################################
50 | os.environ["CUDA_VISIBLE_DEVICES"] = str(opt.GPU)
51 | model_path = opt.pth_path+'//'+opt.denoise_model
52 | # print(model_path)
53 | model_list = list(os.walk(model_path, topdown=False))[-1][-1]
54 | # print(model_list)
55 |
56 | for i in range(len(model_list)):
57 | aaa = model_list[i]
58 | if '.yaml' in aaa:
59 | yaml_name = model_list[i]
60 | print(yaml_name)
61 | read_yaml(opt, model_path+'//'+yaml_name)
62 | # print(opt.datasets_folder)
63 |
64 | name_list, noise_img, coordinate_list= test_preprocess_lessMemoryNoTail(opt)
65 | # name_list, noise_img, coordinate_list = test_preprocess_lessMemory(opt)
66 | # trainX = np.expand_dims(np.array(train_raw),4)
67 | num_h = (math.floor((noise_img.shape[1]-opt.img_h)/opt.gap_h)+1)
68 | num_w = (math.floor((noise_img.shape[2]-opt.img_w)/opt.gap_w)+1)
69 | num_s = (math.floor((noise_img.shape[0]-opt.img_s)/opt.gap_s)+1)
70 | # print(num_h, num_w, num_s)
71 | # print(coordinate_list)
72 |
73 | if not os.path.exists(opt.output_dir):
74 | os.mkdir(opt.output_dir)
75 | current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M")
76 | output_path1 = opt.output_dir + '//' + opt.datasets_folder + '_' + current_time + '_' + opt.denoise_model
77 | if not os.path.exists(output_path1):
78 | os.mkdir(output_path1)
79 |
80 | yaml_name = output_path1+'//para.yaml'
81 | save_yaml(opt, yaml_name)
82 | denoise_generator = Network_3D_Unet(in_channels = 1,
83 | out_channels = 1,
84 | final_sigmoid = True)
85 | if torch.cuda.is_available():
86 | print('Using GPU.')
87 | for pth_index in range(len(model_list)):
88 | aaa = model_list[pth_index]
89 | if '.pth' in aaa:
90 | pth_name = model_list[pth_index]
91 | output_path = output_path1 + '//' + pth_name.replace('.pth','')
92 | if not os.path.exists(output_path):
93 | os.mkdir(output_path)
94 | denoise_generator.load_state_dict(torch.load(opt.pth_path+'//'+opt.denoise_model+'//'+pth_name))
95 |
96 | denoise_generator.cuda()
97 | prev_time = time.time()
98 | time_start=time.time()
99 | denoise_img = np.zeros(noise_img.shape)
100 | input_img = np.zeros(noise_img.shape)
101 | for index in range(len(name_list)):
102 | single_coordinate = coordinate_list[name_list[index]]
103 | init_h = single_coordinate['init_h']
104 | end_h = single_coordinate['end_h']
105 | init_w = single_coordinate['init_w']
106 | end_w = single_coordinate['end_w']
107 | init_s = single_coordinate['init_s']
108 | end_s = single_coordinate['end_s']
109 | noise_patch = noise_img[init_s:end_s,init_h:end_h,init_w:end_w]
110 | # print(noise_patch.shape)
111 | real_A = torch.from_numpy(np.expand_dims(np.expand_dims(noise_patch, 3),0)).cuda()
112 | # print('real_A -----> ',real_A.shape)
113 | real_A = real_A.permute([0,4,1,2,3])
114 | input_name = name_list[index]
115 | print(' input_name -----> ',input_name)
116 | print(' single_coordinate -----> ',single_coordinate)
117 | print('real_A -----> ',real_A.shape)
118 | real_A = Variable(real_A)
119 | fake_B = denoise_generator(real_A)
120 | ################################################################################################################
121 | # Determine approximate time left
122 | batches_done = index
123 | batches_left = 1 * len(name_list) - batches_done
124 | time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
125 | prev_time = time.time()
126 | prev_time = time.time()
127 | ################################################################################################################
128 | if index%1 == 0:
129 | time_end=time.time()
130 | time_cost=datetime.timedelta(seconds= (time_end - time_start))
131 | sys.stdout.write("\r [Batch %d/%d] [Time Left: %s] [Time Cost: %s]"
132 | % (index,
133 | len(name_list),
134 | time_left,
135 | time_cost,))
136 | ################################################################################################################
137 | output_image = np.squeeze(fake_B.cpu().detach().numpy())
138 | raw_image = np.squeeze(real_A.cpu().detach().numpy())
139 | stack_start_w = int(single_coordinate['stack_start_w'])
140 | stack_end_w = int(single_coordinate['stack_end_w'])
141 | patch_start_w = int(single_coordinate['patch_start_w'])
142 | patch_end_w = int(single_coordinate['patch_end_w'])
143 |
144 | stack_start_h = int(single_coordinate['stack_start_h'])
145 | stack_end_h = int(single_coordinate['stack_end_h'])
146 | patch_start_h = int(single_coordinate['patch_start_h'])
147 | patch_end_h = int(single_coordinate['patch_end_h'])
148 |
149 | stack_start_s = int(single_coordinate['stack_start_s'])
150 | stack_end_s = int(single_coordinate['stack_end_s'])
151 | patch_start_s = int(single_coordinate['patch_start_s'])
152 | patch_end_s = int(single_coordinate['patch_end_s'])
153 |
154 | aaaa = output_image[patch_start_s:patch_end_s, patch_start_h:patch_end_h, patch_start_w:patch_end_w]
155 | bbbb = raw_image[patch_start_s:patch_end_s, patch_start_h:patch_end_h, patch_start_w:patch_end_w]
156 |
157 | denoise_img[stack_start_s:stack_end_s, stack_start_h:stack_end_h, stack_start_w:stack_end_w] \
158 | = output_image[patch_start_s:patch_end_s, patch_start_h:patch_end_h, patch_start_w:patch_end_w]*(np.sum(bbbb)/np.sum(aaaa))**0.5
159 | input_img[stack_start_s:stack_end_s, stack_start_h:stack_end_h, stack_start_w:stack_end_w] \
160 | = raw_image[patch_start_s:patch_end_s, patch_start_h:patch_end_h, patch_start_w:patch_end_w]
161 |
162 | # del noise_img
163 | output_img = denoise_img.squeeze().astype(np.float32)*opt.normalize_factor
164 | del denoise_img
165 | # output_img = output_img1[0:raw_noise_img.shape[0],0:raw_noise_img.shape[1],0:raw_noise_img.shape[2]]
166 | output_img = output_img-output_img.min()
167 | output_img = output_img/output_img.max()*65535
168 | output_img = np.clip(output_img, 0, 65535).astype('uint16')
169 | output_img = output_img-output_img.min()
170 | # output_img = output_img.astype('uint16')
171 | input_img = input_img.squeeze().astype(np.float32)*opt.normalize_factor
172 | # input_img = input_img1[0:raw_noise_img.shape[0],0:raw_noise_img.shape[1],0:raw_noise_img.shape[2]]
173 | input_img = np.clip(input_img, 0, 65535).astype('uint16')
174 | result_name = output_path + '//' +pth_name.replace('.pth','')+ '_output.tif'
175 | input_name = output_path + '//' +pth_name.replace('.pth','')+ '_input.tif'
176 | io.imsave(result_name, output_img)
177 | io.imsave(input_name, input_img)
178 |
179 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DeepCAD: Deep self-supervised learning for calcium imaging denoising
2 |
3 |
4 |
5 | ### :triangular_flag_on_post: [[New version released: DeepCAD-RT]](https://github.com/cabooster/DeepCAD-RT)
6 |
7 | ## Contents
8 |
9 | - [Overview](#overview)
10 | - [Directory structure](#directory-structure)
11 | - [Pytorch code](#pytorch-code)
12 | - [Fiji plugin](#fiji-plugin)
13 | - [Results](#results)
14 | - [License](./LICENSE)
15 | - [Citation](#citation)
16 |
17 | ## Overview
18 |
19 |
20 |
21 | Calcium imaging is inherently susceptible to detection noise, especially when imaging at a high frame rate or under a low excitation dosage. However, calcium transients are highly dynamic, non-repetitive activities, and the same firing pattern can never be captured twice, so clean images for supervised training of deep neural networks are not available. Here, we present DeepCAD, a **deep** self-supervised learning-based method for **ca**lcium imaging **d**enoising. Using our method, detection noise can be effectively removed, and the accuracy of neuron extraction and spike inference can be greatly improved.
22 |
23 | DeepCAD is based on the insight that a deep learning network for image denoising can achieve satisfactory convergence even when the target image used for training is just another corrupted sampling of the same scene [[paper link]](https://arxiv.org/abs/1803.04189). We explored the temporal redundancy of calcium imaging and found that any two consecutive frames can be regarded as two independent samplings of the same underlying firing pattern, so a single low-SNR stack is sufficient to serve as a complete training set for DeepCAD. Furthermore, to boost its performance on 3D temporal stacks, the input and output data are designed to be 3D volumes rather than 2D frames, fully incorporating the abundant information along the time axis.
24 |
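As a minimal sketch of this pairing (the array name and shape below are placeholders for illustration, not part of the code base), one noisy stack can be split into two interleaved sub-stacks that serve as the input and the target, mirroring the patch slicing used in `DeepCAD_pytorch/train.py`:

```
import numpy as np

noisy_stack = np.random.rand(300, 150, 150).astype(np.float32)  # placeholder (t, h, w) stack
input_vol = noisy_stack[0::2]    # frames 0, 2, 4, ... -> network input
target_vol = noisy_stack[1::2]   # frames 1, 3, 5, ... -> training target
```
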
25 | For more details, please see the companion paper where the method first appeared:
26 | ["*Reinforcing neuron extraction and spike inference in calcium imaging using deep self-supervised denoising*".](https://www.nature.com/articles/s41592-021-01225-0)
27 |
28 | ## Directory structure
29 |
30 | ```
31 | DeepCAD
32 | |---DeepCAD_pytorch #Pytorch implementation of DeepCAD#
33 | |---|---train.py
34 | |---|---test.py
35 | |---|---script.py
36 | |---|---network.py
37 | |---|---model_3DUnet.py
38 | |---|---data_process.py
39 | |---|---buildingblocks.py
40 | |---|---utils.py
41 | |---|---datasets
42 | |---|---|---DataForPytorch #project_name#
43 | |---|---|---|---data.tif
44 | |---|---pth
45 | |---|---|---ModelForPytorch
46 | |---|---|---|---model.pth
47 | |---|---results
48 | |---|---|--- # Intermediate and final results#
49 | |---DeepCAD_Fiji
50 | |---|---DeepCAD_Fiji_plugin
51 | |---|---|---DeepCAD-0.3.0 #executable jar file/.jar#
52 | |---|---DeepCAD_java #java source code of DeepCAD Fiji plugin#
53 | |---|---DeepCAD_tensorflow #Tensorflow implementation compatible with Fiji plugin#
54 | ```
55 | - **DeepCAD_pytorch** is the [Pytorch](https://pytorch.org/) implementation of DeepCAD.
56 | - **DeepCAD_Fiji** is a user-friendly [Fiji](https://imagej.net/Fiji) plugin. This plugin is easy to install and convenient to use. Researchers without expertise in computer science and machine learning can learn to use it in a very short time.
57 | - **DeepCAD_Fiji_plugin** contains the executable .jar file that can be installed on Fiji.
58 | - **DeepCAD_java** is the java source code of our Fiji plugin based on [CSBDeep](https://csbdeep.bioimagecomputing.com).
59 | - **DeepCAD_tensorflow** is the [Tensorflow](https://www.tensorflow.org/) implementation of DeepCAD, which is used for training models compatible with the Fiji plugin.
60 |
61 | ## Pytorch code
62 |
63 | PyTorch code is the recommended implementation of DeepCAD.
64 |
65 | ### :triangular_flag_on_post: [[New version released]](https://github.com/cabooster/DeepCAD-RT)
66 | **An upgraded version of this PyTorch code has been released and is maintained in another repository [[new code link]](https://github.com/cabooster/DeepCAD-RT), with several new features such as much faster processing, lower memory cost, improved pre- and post-processing, multi-GPU acceleration, and more stable performance**.
67 |
68 | ### Environment
69 |
70 | * Ubuntu 16.04
71 | * Python 3.6
72 | * Pytorch >= 1.3.1
73 | * NVIDIA GPU (24 GB Memory) + CUDA
74 |
75 | ### Environment configuration
76 |
77 | * Create a virtual environment and install Pytorch. In the 4th command below, please select the Pytorch version that matches your CUDA version from https://pytorch.org/get-started/previous-versions/
78 |
79 | ```
80 | $ conda create -n deepcad python=3.6
81 | $ source activate deepcad
82 | $ pip install torch==1.3.1
83 | $ conda install pytorch torchvision cudatoolkit -c pytorch
84 | ```
85 |
86 | * Install other dependencies
87 |
88 | ```
89 | $ conda install -c anaconda matplotlib opencv scikit-learn scikit-image
90 | $ conda install -c conda-forge h5py pyyaml tensorboardx tifffile
91 | ```
92 | ### Download the source code
93 |
94 | ```
95 | $ git clone https://github.com/cabooster/DeepCAD.git
96 | $ cd DeepCAD/DeepCAD_pytorch/
97 | ```
98 |
99 | ### Training
100 |
101 | Download the demo data (.tif file) [[DataForPytorch](https://drive.google.com/drive/folders/1w9v1SrEkmvZal5LH79HloHhz6VXSPfI_)] and put it into *DeepCAD_pytorch/datasets/DataForPytorch*.
102 |
103 | Run **script.py** to start training.
104 |
105 | ```
106 | $ source activate deepcad
107 | $ python script.py train
108 | ```
109 |
110 | Parameters can be modified as required in **script.py**. If your GPU runs out of memory, you can use smaller `img_h`, `img_w`, `img_s` and `gap_h`, `gap_w`, `gap_s` (see the example command after the parameter list below).
111 |
112 | ```
113 | $ os.system('python train.py --datasets_folder --img_h --img_w --img_s --gap_h --gap_w --gap_s --n_epochs --GPU --normalize_factor --train_datasets_size --select_img_num')
114 |
115 | @parameters
116 | --datasets_folder: the folder containing your training data (one or more stacks)
117 | --img_h, --img_w, --img_s: patch size in three dimensions
118 | --gap_h, --gap_w, --gap_s: the spacing to extract training patches from the input stack(s)
119 | --n_epochs: the number of training epochs
120 | --GPU: specify the GPU used for training
121 | --lr: learning rate, please use the default value
122 | --normalize_factor: a constant for image normalization
123 | --train_datasets_size: the number of patches extracted for training
124 | --select_img_num: the number of slices used for training.
125 | ```
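
For example, a memory-saving configuration in **script.py** might look like the following sketch (the specific values are illustrative placeholders, not recommended settings):

```
$ os.system('python train.py --datasets_folder DataForPytorch --img_h 64 --img_w 64 --img_s 64 --gap_h 32 --gap_w 32 --gap_s 32 --n_epochs 30 --GPU 0 --train_datasets_size 1000 --select_img_num 6000')
```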
126 |
127 | ### Test
128 |
129 | Download our pre-trained model (.pth file and .yaml file) [[ModelForPytorch](https://drive.google.com/drive/folders/12LEFsAopTolaRyRpJtFpzOYH3tBZMGUP)] and put it into *DeepCAD_pytorch/pth/ModelForPytorch*.
130 |
131 | Run **script.py** to start the test process. Parameters saved in the .yaml file will be loaded automatically. If your GPU runs out of memory, you can use smaller `img_h`, `img_w`, `img_s` and `gap_h`, `gap_w`, `gap_s`.
132 |
133 | ```
134 | $ source activate deepcad
135 | $ python script.py test
136 | ```
137 |
138 | Parameters can be modified as required in **script.py**. All models in the `--denoise_model` folder will be tested, and their outputs should be inspected manually for **model screening** (see the example command after the parameter list below).
139 |
140 | ```
141 | $ os.system('python test.py --denoise_model --datasets_folder --test_datasize')
142 |
143 | @parameters
144 | --denoise_model: the folder containing all the pre-trained models.
145 | --datasets_folder: the folder containing the testing data (one or more stacks).
146 | --test_datasize: the number of frames used for testing
147 | --img_h, --img_w, --img_s: patch size in three dimensions
148 | --gap_h, --gap_w, --gap_s: the spacing to extract test patches from the input stack(s)
149 | ```
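
For example, a test configuration in **script.py** might look like the following sketch (`ModelForPytorch` and `DataForPytorch` are the default folder names in this repository; the `--test_datasize` value is an illustrative placeholder):

```
$ os.system('python test.py --denoise_model ModelForPytorch --datasets_folder DataForPytorch --test_datasize 1000')
```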
150 |
151 | ## Fiji plugin
152 |
153 | To ameliorate the difficulty of using our deep self-supervised learning-based method, we developed a user-friendly Fiji plugin, which is easy to install and convenient to use (it has been tested on a Windows desktop with an Intel i9 CPU and 128 GB RAM). Researchers without expertise in computer science and machine learning can learn to use it in a very short time. **Tutorials** on installing and using the plugin have been moved to [**this page**](https://github.com/cabooster/DeepCAD/tree/master/DeepCAD_Fiji).
154 |
155 |
156 |
157 |
158 | ## Results
159 |
160 | ### 1. The performance of DeepCAD on denoising two-photon calcium imaging of neurite activities.
161 |
162 |
163 |
164 | ### 2. The performance of DeepCAD on denoising two-photon calcium imaging of large neuronal populations.
165 |
166 |
167 |
168 | ### 3. Cross-system validation.
169 |
170 |
171 |
172 | Denoising performance of DeepCAD on three two-photon laser-scanning microscopes (2PLSMs) with different system setups. **Our system** was equipped with alkali PMTs (PMT1001, Thorlabs) and a 25×/1.05 NA commercial objective (XLPLN25XWMP2, Olympus). The **standard 2PLSM** was equipped with a GaAsP PMT (H10770PA-40, Hamamatsu) and a 25×/1.05 NA commercial objective (XLPLN25XWMP2, Olympus). The **two-photon mesoscope** was equipped with a GaAsP PMT (H11706-40, Hamamatsu) and a 2.3×/0.6 NA custom objective. The same pre-trained model was used for processing these data.
173 |
174 | ## Citation
175 |
176 | If you use this code, please cite the companion paper where the original method appeared:
177 |
178 | Li, X., Zhang, G., Wu, J. et al. Reinforcing neuron extraction and spike inference in calcium imaging using deep self-supervised denoising. Nat Methods (2021). [https://doi.org/10.1038/s41592-021-01225-0](https://www.nature.com/articles/s41592-021-01225-0)
179 |
180 | ```
181 | @article{li2021reinforcing,
182 | title={Reinforcing neuron extraction and spike inference in calcium imaging using deep self-supervised denoising},
183 | author={Li, Xinyang and Zhang, Guoxun and Wu, Jiamin and Zhang, Yuanlong and Zhao, Zhifeng and Lin, Xing and Qiao, Hui and Xie, Hao and Wang, Haoqian and Fang, Lu and others},
184 | journal={Nature Methods},
185 | pages={1--6},
186 | year={2021},
187 | publisher={Nature Publishing Group}
188 | }
189 | ```
190 |
--------------------------------------------------------------------------------
/DeepCAD_pytorch/buildingblocks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn as nn
3 | from torch.nn import functional as F
4 |
5 |
6 | def conv3d(in_channels, out_channels, kernel_size, bias, padding=1):
7 | return nn.Conv3d(in_channels, out_channels, kernel_size, padding=padding, bias=bias)
8 |
9 |
10 | def create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding=1):
11 | """
12 | Create a list of modules which together constitute a single conv layer with non-linearity
13 | and optional batchnorm/groupnorm.
14 |
15 | Args:
16 | in_channels (int): number of input channels
17 | out_channels (int): number of output channels
18 | order (string): order of things, e.g.
19 | 'cr' -> conv + ReLU
20 | 'crg' -> conv + ReLU + groupnorm
21 | 'cl' -> conv + LeakyReLU
22 | 'ce' -> conv + ELU
23 | num_groups (int): number of groups for the GroupNorm
24 | padding (int): add zero-padding to the input
25 |
26 | Return:
27 | list of tuple (name, module)
28 | """
29 | assert 'c' in order, "Conv layer MUST be present"
30 | assert order[0] not in 'rle', 'Non-linearity cannot be the first operation in the layer'
31 |
32 | modules = []
33 | for i, char in enumerate(order):
34 | if char == 'r':
35 | modules.append(('ReLU', nn.ReLU(inplace=True)))
36 | elif char == 'l':
37 | modules.append(('LeakyReLU', nn.LeakyReLU(negative_slope=0.1, inplace=True)))
38 | elif char == 'e':
39 | modules.append(('ELU', nn.ELU(inplace=True)))
40 | elif char == 'c':
41 | # add learnable bias only in the absence of batchnorm/groupnorm
42 | bias = not ('g' in order or 'b' in order)
43 | modules.append(('conv', conv3d(in_channels, out_channels, kernel_size, bias, padding=padding)))
44 | elif char == 'g':
45 | is_before_conv = i < order.index('c')
46 | assert not is_before_conv, 'GroupNorm MUST go after the Conv3d'
47 | # number of groups must be less than or equal to the number of channels
48 | if out_channels < num_groups:
49 | num_groups = out_channels
50 | modules.append(('groupnorm', nn.GroupNorm(num_groups=num_groups, num_channels=out_channels)))
51 | elif char == 'b':
52 | is_before_conv = i < order.index('c')
53 | if is_before_conv:
54 | modules.append(('batchnorm', nn.BatchNorm3d(in_channels)))
55 | else:
56 | modules.append(('batchnorm', nn.BatchNorm3d(out_channels)))
57 | else:
58 | raise ValueError(f"Unsupported layer type '{char}'. MUST be one of ['b', 'g', 'r', 'l', 'e', 'c']")
59 |
60 | return modules
61 |
62 |
63 | class SingleConv(nn.Sequential):
64 | """
65 | Basic convolutional module consisting of a Conv3d, non-linearity and optional batchnorm/groupnorm. The order
66 | of operations can be specified via the `order` parameter
67 |
68 | Args:
69 | in_channels (int): number of input channels
70 | out_channels (int): number of output channels
71 | kernel_size (int): size of the convolving kernel
72 | order (string): determines the order of layers, e.g.
73 | 'cr' -> conv + ReLU
74 | 'crg' -> conv + ReLU + groupnorm
75 | 'cl' -> conv + LeakyReLU
76 | 'ce' -> conv + ELU
77 | num_groups (int): number of groups for the GroupNorm
78 | """
79 |
80 | def __init__(self, in_channels, out_channels, kernel_size=3, order='cr', num_groups=8, padding=1):
81 | super(SingleConv, self).__init__()
82 |
83 | for name, module in create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding=padding):
84 | self.add_module(name, module)
85 |
86 |
87 | class DoubleConv(nn.Sequential):
88 | """
89 | A module consisting of two consecutive convolution layers (e.g. BatchNorm3d+ReLU+Conv3d).
90 | We use (Conv3d+ReLU+GroupNorm3d) by default.
91 | This can be changed however by providing the 'order' argument, e.g. in order
92 | to change to Conv3d+BatchNorm3d+ELU use order='cbe'.
93 | Use padded convolutions to make sure that the output (H_out, W_out) is the same
94 | as (H_in, W_in), so that you don't have to crop in the decoder path.
95 |
96 | Args:
97 | in_channels (int): number of input channels
98 | out_channels (int): number of output channels
99 | encoder (bool): if True we're in the encoder path, otherwise we're in the decoder
100 | kernel_size (int): size of the convolving kernel
101 | order (string): determines the order of layers, e.g.
102 | 'cr' -> conv + ReLU
103 | 'crg' -> conv + ReLU + groupnorm
104 | 'cl' -> conv + LeakyReLU
105 | 'ce' -> conv + ELU
106 | num_groups (int): number of groups for the GroupNorm
107 | """
108 |
109 | def __init__(self, in_channels, out_channels, encoder, kernel_size=3, order='cr', num_groups=8):
110 | super(DoubleConv, self).__init__()
111 | if encoder:
112 | # we're in the encoder path
113 | conv1_in_channels = in_channels
114 | conv1_out_channels = out_channels // 2
115 | if conv1_out_channels < in_channels:
116 | conv1_out_channels = in_channels
117 | conv2_in_channels, conv2_out_channels = conv1_out_channels, out_channels
118 | else:
119 | # we're in the decoder path, decrease the number of channels in the 1st convolution
120 | conv1_in_channels, conv1_out_channels = in_channels, out_channels
121 | conv2_in_channels, conv2_out_channels = out_channels, out_channels
122 |
123 | # conv1
124 | self.add_module('SingleConv1',
125 | SingleConv(conv1_in_channels, conv1_out_channels, kernel_size, order, num_groups))
126 | # conv2
127 | self.add_module('SingleConv2',
128 | SingleConv(conv2_in_channels, conv2_out_channels, kernel_size, order, num_groups))
129 |
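# For example, with the default order='cr':
#   DoubleConv(in_channels=1, out_channels=32, encoder=True) builds
#       SingleConv1: Conv3d(1, 16, kernel_size=3, padding=1) + ReLU
#       SingleConv2: Conv3d(16, 32, kernel_size=3, padding=1) + ReLU
#   while encoder=False builds Conv3d(1, 32) + ReLU followed by Conv3d(32, 32) + ReLU.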
130 |
131 | class ExtResNetBlock(nn.Module):
132 | """
133 | Basic UNet block consisting of a SingleConv followed by the residual block.
134 | The SingleConv takes care of increasing/decreasing the number of channels and also ensures that the number
135 | of output channels is compatible with the residual block that follows.
136 | This block can be used instead of standard DoubleConv in the Encoder module.
137 | Motivated by: https://arxiv.org/pdf/1706.00120.pdf
138 |
139 | Notice we use ELU instead of ReLU (order='cge') and put non-linearity after the groupnorm.
140 | """
141 |
142 | def __init__(self, in_channels, out_channels, kernel_size=3, order='cge', num_groups=8, **kwargs):
143 | super(ExtResNetBlock, self).__init__()
144 |
145 | # first convolution
146 | self.conv1 = SingleConv(in_channels, out_channels, kernel_size=kernel_size, order=order, num_groups=num_groups)
147 | # residual block
148 | self.conv2 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=order, num_groups=num_groups)
149 | # remove non-linearity from the 3rd convolution since it's going to be applied after adding the residual
150 | n_order = order
151 | for c in 'rel':
152 | n_order = n_order.replace(c, '')
153 | self.conv3 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=n_order,
154 | num_groups=num_groups)
155 |
156 | # create non-linearity separately
157 | if 'l' in order:
158 | self.non_linearity = nn.LeakyReLU(negative_slope=0.1, inplace=True)
159 | elif 'e' in order:
160 | self.non_linearity = nn.ELU(inplace=True)
161 | else:
162 | self.non_linearity = nn.ReLU(inplace=True)
163 |
164 | def forward(self, x):
165 | # apply first convolution and save the output as a residual
166 | out = self.conv1(x)
167 | residual = out
168 |
169 | # residual block
170 | out = self.conv2(out)
171 | out = self.conv3(out)
172 |
173 | out += residual
174 | out = self.non_linearity(out)
175 |
176 | return out
177 |
178 |
179 | class Encoder(nn.Module):
180 | """
181 | A single module from the encoder path consisting of the optional max
182 | pooling layer (one may specify the MaxPool kernel_size to be different
183 | than the standard (2,2,2), e.g. if the volumetric data is anisotropic
184 | (make sure to use complementary scale_factor in the decoder path) followed by
185 | a DoubleConv module.
186 | Args:
187 | in_channels (int): number of input channels
188 | out_channels (int): number of output channels
189 | conv_kernel_size (int): size of the convolving kernel
190 | apply_pooling (bool): if True use MaxPool3d before DoubleConv
191 | pool_kernel_size (tuple): the size of the window to take a max over
192 | pool_type (str): pooling layer: 'max' or 'avg'
193 | basic_module(nn.Module): either ResNetBlock or DoubleConv
194 | conv_layer_order (string): determines the order of layers
195 | in `DoubleConv` module. See `DoubleConv` for more info.
196 | num_groups (int): number of groups for the GroupNorm
197 | """
198 |
199 | def __init__(self, in_channels, out_channels, conv_kernel_size=3, apply_pooling=True,
200 | pool_kernel_size=(2, 2, 2), pool_type='max', basic_module=DoubleConv, conv_layer_order='cr',
201 | num_groups=8):
202 | super(Encoder, self).__init__()
203 | assert pool_type in ['max', 'avg']
204 | if apply_pooling:
205 | if pool_type == 'max':
206 | self.pooling = nn.MaxPool3d(kernel_size=pool_kernel_size)
207 | else:
208 | self.pooling = nn.AvgPool3d(kernel_size=pool_kernel_size)
209 | else:
210 | self.pooling = None
211 |
212 | self.basic_module = basic_module(in_channels, out_channels,
213 | encoder=True,
214 | kernel_size=conv_kernel_size,
215 | order=conv_layer_order,
216 | num_groups=num_groups)
217 |
218 | def forward(self, x):
219 | if self.pooling is not None:
220 | x = self.pooling(x)
221 | x = self.basic_module(x)
222 | return x
223 |
224 |
225 | class Decoder(nn.Module):
226 | """
227 | A single module for decoder path consisting of the upsample layer
228 | (either learned ConvTranspose3d or interpolation) followed by a DoubleConv
229 | module.
230 | Args:
231 | in_channels (int): number of input channels
232 | out_channels (int): number of output channels
233 | kernel_size (int): size of the convolving kernel
234 | scale_factor (tuple): used as the multiplier for the image H/W/D in
235 | case of nn.Upsample or as stride in case of ConvTranspose3d, must reverse the MaxPool3d operation
236 | from the corresponding encoder
237 | basic_module(nn.Module): either ResNetBlock or DoubleConv
238 | conv_layer_order (string): determines the order of layers
239 | in `DoubleConv` module. See `DoubleConv` for more info.
240 | num_groups (int): number of groups for the GroupNorm
241 | """
242 |
243 | def __init__(self, in_channels, out_channels, kernel_size=3,
244 | scale_factor=(2, 2, 2), basic_module=DoubleConv, conv_layer_order='cr', num_groups=8):
245 | super(Decoder, self).__init__()
246 | if basic_module == DoubleConv:
247 | # if DoubleConv is the basic_module use nearest neighbor interpolation for upsampling
248 | self.upsample = None
249 | else:
250 | # otherwise use ConvTranspose3d (bear in mind your GPU memory)
251 | # make sure that the output size reverses the MaxPool3d from the corresponding encoder
252 | # (D_out = (D_in − 1) × stride[0] − 2 × padding[0] + kernel_size[0] + output_padding[0])
253 | # also scale the number of channels from in_channels to out_channels so that summation joining
254 | # works correctly
255 | self.upsample = nn.ConvTranspose3d(in_channels,
256 | out_channels,
257 | kernel_size=kernel_size,
258 | stride=scale_factor,
259 | padding=1,
260 | output_padding=1)
261 | # adapt the number of in_channels for the ExtResNetBlock
262 | in_channels = out_channels
263 |
264 | self.basic_module = basic_module(in_channels, out_channels,
265 | encoder=False,
266 | kernel_size=kernel_size,
267 | order=conv_layer_order,
268 | num_groups=num_groups)
269 |
270 | def forward(self, encoder_features, x):
271 | if self.upsample is None:
272 | # use nearest neighbor interpolation and concatenation joining
273 | output_size = encoder_features.size()[2:]
274 | x = F.interpolate(x, size=output_size, mode='nearest')
275 | # concatenate encoder_features (encoder path) with the upsampled input across channel dimension
276 | x = torch.cat((encoder_features, x), dim=1)
277 | else:
278 | # use ConvTranspose3d and summation joining
279 | x = self.upsample(x)
280 | x += encoder_features
281 |
282 | x = self.basic_module(x)
283 | return x
284 |
285 |
286 | class FinalConv(nn.Sequential):
287 | """
288 | A module consisting of a convolution layer (e.g. Conv3d+ReLU+GroupNorm3d) and the final 1x1 convolution
289 | which reduces the number of channels to 'out_channels'.
290 | The intermediate SingleConv keeps the number of channels at 'in_channels'.
291 | We use (Conv3d+ReLU+GroupNorm3d) by default.
292 | This can be changed, however, by providing the 'order' argument, e.g. in order
293 | to change to Conv3d+BatchNorm3d+ReLU use order='cbr'.
294 | Args:
295 | in_channels (int): number of input channels
296 | out_channels (int): number of output channels
297 | kernel_size (int): size of the convolving kernel
298 | order (string): determines the order of layers, e.g.
299 | 'cr' -> conv + ReLU
300 | 'crg' -> conv + ReLU + groupnorm
301 | num_groups (int): number of groups for the GroupNorm
302 | """
303 |
304 | def __init__(self, in_channels, out_channels, kernel_size=3, order='cr', num_groups=8):
305 | super(FinalConv, self).__init__()
306 |
307 | # conv1
308 | self.add_module('SingleConv', SingleConv(in_channels, in_channels, kernel_size, order, num_groups))
309 |
310 | # in the last layer a 1×1 convolution reduces the number of output channels to out_channels
311 | final_conv = nn.Conv3d(in_channels, out_channels, 1)
312 | self.add_module('final_conv', final_conv)
313 |
--------------------------------------------------------------------------------
/DeepCAD_Fiji/DeepCAD_tensorflow/data_process.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import argparse
3 | import os
4 | import tifffile as tiff
5 | import time
6 | import datetime
7 | import random
8 | from skimage import io
9 | import logging
10 | import math
11 |
12 | def train_preprocess_lessMemoryMulStacks(args):
13 | img_h = args.img_h
14 | img_w = args.img_w
15 | img_s2 = args.img_s*2
16 | gap_h = args.gap_h
17 | gap_w = args.gap_w
18 | gap_s2 = args.gap_s*2
19 | im_folder = args.datasets_path+'//'+args.datasets_folder
20 |
21 | name_list = []
22 | # train_raw = []
23 | coordinate_list={}
24 |
25 | print('list(os.walk(im_folder, topdown=False)) -----> ',list(os.walk(im_folder, topdown=False)))
26 | stack_num = len(list(os.walk(im_folder, topdown=False))[-1][-1])
27 | print('stack_num -----> ',stack_num)
28 | for im_name in list(os.walk(im_folder, topdown=False))[-1][-1]:
29 | # print('im_name -----> ',im_name)
30 | im_dir = im_folder+'//'+im_name
31 | noise_im = tiff.imread(im_dir)
32 | if noise_im.shape[0]>args.select_img_num:
33 | noise_im = noise_im[0:args.select_img_num,:,:]
34 | gap_s2 = get_gap_s(args, noise_im, stack_num)
35 | # print('noise_im shape -----> ',noise_im.shape)
36 | # print('noise_im max -----> ',noise_im.max())
37 | # print('noise_im min -----> ',noise_im.min())
38 | noise_im = (noise_im-noise_im.min()).astype(np.float32)/args.normalize_factor
39 |
40 | whole_w = noise_im.shape[2]
41 | whole_h = noise_im.shape[1]
42 | whole_s = noise_im.shape[0]
43 | # print('int((whole_h-img_h+gap_h)/gap_h) -----> ',int((whole_h-img_h+gap_h)/gap_h))
44 | # print('int((whole_w-img_w+gap_w)/gap_w) -----> ',int((whole_w-img_w+gap_w)/gap_w))
45 | # print('int((whole_s-img_s2+gap_s2)/gap_s2) -----> ',int((whole_s-img_s2+gap_s2)/gap_s2))
46 | for x in range(0,int((whole_h-img_h+gap_h)/gap_h)):
47 | for y in range(0,int((whole_w-img_w+gap_w)/gap_w)):
48 | for z in range(0,int((whole_s-img_s2+gap_s2)/gap_s2)):
49 | single_coordinate={'init_h':0, 'end_h':0, 'init_w':0, 'end_w':0, 'init_s':0, 'end_s':0}
50 | init_h = gap_h*x
51 | end_h = gap_h*x + img_h
52 | init_w = gap_w*y
53 | end_w = gap_w*y + img_w
54 | init_s = gap_s2*z
55 | end_s = gap_s2*z + img_s2
56 | single_coordinate['init_h'] = init_h
57 | single_coordinate['end_h'] = end_h
58 | single_coordinate['init_w'] = init_w
59 | single_coordinate['end_w'] = end_w
60 | single_coordinate['init_s'] = init_s
61 | single_coordinate['end_s'] = end_s
62 | # noise_patch1 = noise_im[init_s:end_s,init_h:end_h,init_w:end_w]
63 | patch_name = args.datasets_folder+'_'+im_name.replace('.tif','')+'_x'+str(x)+'_y'+str(y)+'_z'+str(z)
64 | # train_raw.append(noise_patch1.transpose(1,2,0))
65 | name_list.append(patch_name)
66 | # print(' single_coordinate -----> ',single_coordinate)
67 | coordinate_list[patch_name] = single_coordinate
68 | return name_list, noise_im, coordinate_list
69 |
70 | def get_gap_s(args, img, stack_num):
71 | whole_w = img.shape[2]
72 | whole_h = img.shape[1]
73 | whole_s = img.shape[0]
74 | print('whole_w -----> ',whole_w)
75 | print('whole_h -----> ',whole_h)
76 | print('whole_s -----> ',whole_s)
77 | w_num = math.floor((whole_w-args.img_w)/args.gap_w)+1
78 | h_num = math.floor((whole_h-args.img_h)/args.gap_h)+1
79 | s_num = math.ceil(args.train_datasets_size/w_num/h_num/stack_num)
80 | print('w_num -----> ',w_num)
81 | print('h_num -----> ',h_num)
82 | print('s_num -----> ',s_num)
83 | gap_s = math.floor((whole_s-args.img_s*2)/(s_num-1))
84 | print('gap_s -----> ',gap_s)
85 | return gap_s
86 |
87 | def shuffle_datasets(train_raw, train_GT, name_list):
88 | index_list = list(range(0, len(name_list)))
89 | # print('index_list -----> ',index_list)
90 | random.shuffle(index_list)
91 | random_index_list = index_list
92 | # print('index_list -----> ',index_list)
93 | new_name_list = list(range(0, len(name_list)))
94 | train_raw = np.array(train_raw)
95 | # print('train_raw shape -----> ',train_raw.shape)
96 | train_GT = np.array(train_GT)
97 | # print('train_GT shape -----> ',train_GT.shape)
98 | new_train_raw = train_raw
99 | new_train_GT = train_GT
100 | for i in range(0,len(random_index_list)):
101 | # print('i -----> ',i)
102 | new_train_raw[i,:,:,:] = train_raw[random_index_list[i],:,:,:]
103 | new_train_GT[i,:,:,:] = train_GT[random_index_list[i],:,:,:]
104 | new_name_list[i] = name_list[random_index_list[i]]
105 | # new_train_raw = np.expand_dims(new_train_raw, 4)
106 | # new_train_GT = np.expand_dims(new_train_GT, 4)
107 | return new_train_raw, new_train_GT, new_name_list
108 |
109 | def train_preprocess(args):
110 | img_h = args.img_h
111 | img_w = args.img_w
112 | img_s2 = args.img_s*2
113 | gap_h = args.gap_h
114 | gap_w = args.gap_w
115 | gap_s2 = args.gap_s*2
116 | im_folder = 'datasets//'+args.datasets_folder
117 |
118 | name_list = []
119 | train_raw = []
120 | train_GT = []
121 | for im_name in list(os.walk(im_folder, topdown=False))[-1][-1]:
122 | print('im_name -----> ',im_name)
123 | im_dir = im_folder+'//'+im_name
124 | noise_im = tiff.imread(im_dir)
125 | print('noise_im shape -----> ',noise_im.shape)
126 | # print('noise_im max -----> ',noise_im.max())
127 | # print('noise_im min -----> ',noise_im.min())
128 | noise_im = (noise_im-noise_im.min()).astype(np.float32)/args.normalize_factor
129 |
130 | whole_w = noise_im.shape[2]
131 | whole_h = noise_im.shape[1]
132 | whole_s = noise_im.shape[0]
133 | print('int((whole_h-img_h+gap_h)/gap_h) -----> ',int((whole_h-img_h+gap_h)/gap_h))
134 | print('int((whole_w-img_w+gap_w)/gap_w) -----> ',int((whole_w-img_w+gap_w)/gap_w))
135 | print('int((whole_s-img_s2+gap_s2)/gap_s2) -----> ',int((whole_s-img_s2+gap_s2)/gap_s2))
136 | for x in range(0,int((whole_h-img_h+gap_h)/gap_h)):
137 | for y in range(0,int((whole_w-img_w+gap_w)/gap_w)):
138 | for z in range(0,int((whole_s-img_s2+gap_s2)/gap_s2)):
139 | init_h = gap_h*x
140 | end_h = gap_h*x + img_h
141 | init_w = gap_w*y
142 | end_w = gap_w*y + img_w
143 | init_s = gap_s2*z
144 | end_s = gap_s2*z + img_s2
145 | noise_patch1 = noise_im[init_s:end_s:2,init_h:end_h,init_w:end_w]
146 | noise_patch2 = noise_im[init_s+1:end_s:2,init_h:end_h,init_w:end_w]
147 | patch_name = args.datasets_folder+'_x'+str(x)+'_y'+str(y)+'_z'+str(z)
148 | train_raw.append(noise_patch1.transpose(1,2,0))
149 | train_GT.append(noise_patch2.transpose(1,2,0))
150 | name_list.append(patch_name)
151 | return train_raw, train_GT, name_list, noise_im
152 |
153 | def test_preprocess(args):
154 | img_h = args.img_h
155 | img_w = args.img_w
156 | img_s2 = args.img_s
157 | gap_h = args.gap_h
158 | gap_w = args.gap_w
159 | gap_s2 = args.gap_s
160 | im_folder = 'datasets//'+args.datasets_folder
161 |
162 | name_list = []
163 | train_raw = []
164 | for im_name in list(os.walk(im_folder, topdown=False))[-1][-1]:
165 | # print('im_name -----> ',im_name)
166 | im_dir = im_folder+'//'+im_name
167 | noise_im = tiff.imread(im_dir)
168 | # print('noise_im shape -----> ',noise_im.shape)
169 | # print('noise_im max -----> ',noise_im.max())
170 | # print('noise_im min -----> ',noise_im.min())
171 | noise_im = (noise_im-noise_im.min()).astype(np.float32)/args.normalize_factor
172 |
173 | whole_w = noise_im.shape[2]
174 | whole_h = noise_im.shape[1]
175 | whole_s = noise_im.shape[0]
176 | # print('int((whole_h-img_h+gap_h)/gap_h) -----> ',int((whole_h-img_h+gap_h)/gap_h))
177 | # print('int((whole_w-img_w+gap_w)/gap_w) -----> ',int((whole_w-img_w+gap_w)/gap_w))
178 | # print('int((whole_s-img_s2+gap_s2)/gap_s2) -----> ',int((whole_s-img_s2+gap_s2)/gap_s2))
179 | for x in range(0,int((whole_h-img_h+gap_h)/gap_h)):
180 | for y in range(0,int((whole_w-img_w+gap_w)/gap_w)):
181 | for z in range(0,int((whole_s-img_s2+gap_s2)/gap_s2)):
182 | init_h = gap_h*x
183 | end_h = gap_h*x + img_h
184 | init_w = gap_w*y
185 | end_w = gap_w*y + img_w
186 | init_s = gap_s2*z
187 | end_s = gap_s2*z + img_s2
188 | noise_patch1 = noise_im[init_s:end_s,init_h:end_h,init_w:end_w]
189 | patch_name = args.datasets_folder+'_x'+str(x)+'_y'+str(y)+'_z'+str(z)
190 | train_raw.append(noise_patch1.transpose(1,2,0))
191 | name_list.append(patch_name)
192 | return train_raw, name_list, noise_im
193 |
194 | def test_preprocess_lessMemory(args):
195 | img_h = args.img_h
196 | img_w = args.img_w
197 | img_s2 = args.img_s
198 | gap_h = args.gap_h
199 | gap_w = args.gap_w
200 | gap_s2 = args.gap_s
201 | im_folder = 'datasets//'+args.datasets_folder
202 |
203 | name_list = []
204 | # train_raw = []
205 | coordinate_list={}
206 | for im_name in list(os.walk(im_folder, topdown=False))[-1][-1]:
207 | # print('im_name -----> ',im_name)
208 | im_dir = im_folder+'//'+im_name
209 | noise_im = tiff.imread(im_dir)
210 | # print('noise_im shape -----> ',noise_im.shape)
211 | # print('noise_im max -----> ',noise_im.max())
212 | # print('noise_im min -----> ',noise_im.min())
213 | if noise_im.shape[0]>args.test_datasize:
214 | noise_im = noise_im[0:args.test_datasize,:,:]
215 | noise_im = (noise_im-noise_im.min()).astype(np.float32)/args.normalize_factor
216 |
217 | whole_w = noise_im.shape[2]
218 | whole_h = noise_im.shape[1]
219 | whole_s = noise_im.shape[0]
220 | # print('int((whole_h-img_h+gap_h)/gap_h) -----> ',int((whole_h-img_h+gap_h)/gap_h))
221 | # print('int((whole_w-img_w+gap_w)/gap_w) -----> ',int((whole_w-img_w+gap_w)/gap_w))
222 | # print('int((whole_s-img_s2+gap_s2)/gap_s2) -----> ',int((whole_s-img_s2+gap_s2)/gap_s2))
223 | for x in range(0,int((whole_h-img_h+gap_h)/gap_h)):
224 | for y in range(0,int((whole_w-img_w+gap_w)/gap_w)):
225 | for z in range(0,int((whole_s-img_s2+gap_s2)/gap_s2)):
226 | single_coordinate={'init_h':0, 'end_h':0, 'init_w':0, 'end_w':0, 'init_s':0, 'end_s':0}
227 | init_h = gap_h*x
228 | end_h = gap_h*x + img_h
229 | init_w = gap_w*y
230 | end_w = gap_w*y + img_w
231 | init_s = gap_s2*z
232 | end_s = gap_s2*z + img_s2
233 | single_coordinate['init_h'] = init_h
234 | single_coordinate['end_h'] = end_h
235 | single_coordinate['init_w'] = init_w
236 | single_coordinate['end_w'] = end_w
237 | single_coordinate['init_s'] = init_s
238 | single_coordinate['end_s'] = end_s
239 | # noise_patch1 = noise_im[init_s:end_s,init_h:end_h,init_w:end_w]
240 | patch_name = args.datasets_folder+'_x'+str(x)+'_y'+str(y)+'_z'+str(z)
241 | # train_raw.append(noise_patch1.transpose(1,2,0))
242 | name_list.append(patch_name)
243 | # print(' single_coordinate -----> ',single_coordinate)
244 | coordinate_list[patch_name] = single_coordinate
245 | return name_list, noise_im, coordinate_list
246 |
247 | def train_preprocess_lessMemory(args):
248 | img_h = args.img_h
249 | img_w = args.img_w
250 | img_s2 = args.img_s*2
251 | gap_h = args.gap_h
252 | gap_w = args.gap_w
253 | gap_s2 = args.gap_s*2
254 | im_folder = 'datasets//'+args.datasets_folder
255 |
256 | name_list = []
257 | # train_raw = []
258 | coordinate_list={}
259 | for im_name in list(os.walk(im_folder, topdown=False))[-1][-1]:
260 | # print('im_name -----> ',im_name)
261 | im_dir = im_folder+'//'+im_name
262 | noise_im = tiff.imread(im_dir)
263 | # print('noise_im shape -----> ',noise_im.shape)
264 | # print('noise_im max -----> ',noise_im.max())
265 | # print('noise_im min -----> ',noise_im.min())
266 | noise_im = (noise_im-noise_im.min()).astype(np.float32)/args.normalize_factor
267 |
268 | whole_w = noise_im.shape[2]
269 | whole_h = noise_im.shape[1]
270 | whole_s = noise_im.shape[0]
271 | # print('int((whole_h-img_h+gap_h)/gap_h) -----> ',int((whole_h-img_h+gap_h)/gap_h))
272 | # print('int((whole_w-img_w+gap_w)/gap_w) -----> ',int((whole_w-img_w+gap_w)/gap_w))
273 | # print('int((whole_s-img_s2+gap_s2)/gap_s2) -----> ',int((whole_s-img_s2+gap_s2)/gap_s2))
274 | for x in range(0,int((whole_h-img_h+gap_h)/gap_h)):
275 | for y in range(0,int((whole_w-img_w+gap_w)/gap_w)):
276 | for z in range(0,int((whole_s-img_s2+gap_s2)/gap_s2)):
277 | single_coordinate={'init_h':0, 'end_h':0, 'init_w':0, 'end_w':0, 'init_s':0, 'end_s':0}
278 | init_h = gap_h*x
279 | end_h = gap_h*x + img_h
280 | init_w = gap_w*y
281 | end_w = gap_w*y + img_w
282 | init_s = gap_s2*z
283 | end_s = gap_s2*z + img_s2
284 | single_coordinate['init_h'] = init_h
285 | single_coordinate['end_h'] = end_h
286 | single_coordinate['init_w'] = init_w
287 | single_coordinate['end_w'] = end_w
288 | single_coordinate['init_s'] = init_s
289 | single_coordinate['end_s'] = end_s
290 | # noise_patch1 = noise_im[init_s:end_s,init_h:end_h,init_w:end_w]
291 | patch_name = args.datasets_folder+'_x'+str(x)+'_y'+str(y)+'_z'+str(z)
292 | # train_raw.append(noise_patch1.transpose(1,2,0))
293 | name_list.append(patch_name)
294 | # print(' single_coordinate -----> ',single_coordinate)
295 | coordinate_list[patch_name] = single_coordinate
296 | return name_list, noise_im, coordinate_list
297 |
298 |
299 | def shuffle_datasets_lessMemory(name_list):
300 | index_list = list(range(0, len(name_list)))
301 | # print('index_list -----> ',index_list)
302 | random.shuffle(index_list)
303 | random_index_list = index_list
304 | # print('index_list -----> ',index_list)
305 | new_name_list = list(range(0, len(name_list)))
306 | for i in range(0,len(random_index_list)):
307 | new_name_list[i] = name_list[random_index_list[i]]
308 | return new_name_list
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU GENERAL PUBLIC LICENSE
2 | Version 2, June 1991
3 |
4 | Copyright (C) 1989, 1991 Free Software Foundation, Inc.,
5 | 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
6 | Everyone is permitted to copy and distribute verbatim copies
7 | of this license document, but changing it is not allowed.
8 |
9 | Preamble
10 |
11 | The licenses for most software are designed to take away your
12 | freedom to share and change it. By contrast, the GNU General Public
13 | License is intended to guarantee your freedom to share and change free
14 | software--to make sure the software is free for all its users. This
15 | General Public License applies to most of the Free Software
16 | Foundation's software and to any other program whose authors commit to
17 | using it. (Some other Free Software Foundation software is covered by
18 | the GNU Lesser General Public License instead.) You can apply it to
19 | your programs, too.
20 |
21 | When we speak of free software, we are referring to freedom, not
22 | price. Our General Public Licenses are designed to make sure that you
23 | have the freedom to distribute copies of free software (and charge for
24 | this service if you wish), that you receive source code or can get it
25 | if you want it, that you can change the software or use pieces of it
26 | in new free programs; and that you know you can do these things.
27 |
28 | To protect your rights, we need to make restrictions that forbid
29 | anyone to deny you these rights or to ask you to surrender the rights.
30 | These restrictions translate to certain responsibilities for you if you
31 | distribute copies of the software, or if you modify it.
32 |
33 | For example, if you distribute copies of such a program, whether
34 | gratis or for a fee, you must give the recipients all the rights that
35 | you have. You must make sure that they, too, receive or can get the
36 | source code. And you must show them these terms so they know their
37 | rights.
38 |
39 | We protect your rights with two steps: (1) copyright the software, and
40 | (2) offer you this license which gives you legal permission to copy,
41 | distribute and/or modify the software.
42 |
43 | Also, for each author's protection and ours, we want to make certain
44 | that everyone understands that there is no warranty for this free
45 | software. If the software is modified by someone else and passed on, we
46 | want its recipients to know that what they have is not the original, so
47 | that any problems introduced by others will not reflect on the original
48 | authors' reputations.
49 |
50 | Finally, any free program is threatened constantly by software
51 | patents. We wish to avoid the danger that redistributors of a free
52 | program will individually obtain patent licenses, in effect making the
53 | program proprietary. To prevent this, we have made it clear that any
54 | patent must be licensed for everyone's free use or not licensed at all.
55 |
56 | The precise terms and conditions for copying, distribution and
57 | modification follow.
58 |
59 | GNU GENERAL PUBLIC LICENSE
60 | TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION
61 |
62 | 0. This License applies to any program or other work which contains
63 | a notice placed by the copyright holder saying it may be distributed
64 | under the terms of this General Public License. The "Program", below,
65 | refers to any such program or work, and a "work based on the Program"
66 | means either the Program or any derivative work under copyright law:
67 | that is to say, a work containing the Program or a portion of it,
68 | either verbatim or with modifications and/or translated into another
69 | language. (Hereinafter, translation is included without limitation in
70 | the term "modification".) Each licensee is addressed as "you".
71 |
72 | Activities other than copying, distribution and modification are not
73 | covered by this License; they are outside its scope. The act of
74 | running the Program is not restricted, and the output from the Program
75 | is covered only if its contents constitute a work based on the
76 | Program (independent of having been made by running the Program).
77 | Whether that is true depends on what the Program does.
78 |
79 | 1. You may copy and distribute verbatim copies of the Program's
80 | source code as you receive it, in any medium, provided that you
81 | conspicuously and appropriately publish on each copy an appropriate
82 | copyright notice and disclaimer of warranty; keep intact all the
83 | notices that refer to this License and to the absence of any warranty;
84 | and give any other recipients of the Program a copy of this License
85 | along with the Program.
86 |
87 | You may charge a fee for the physical act of transferring a copy, and
88 | you may at your option offer warranty protection in exchange for a fee.
89 |
90 | 2. You may modify your copy or copies of the Program or any portion
91 | of it, thus forming a work based on the Program, and copy and
92 | distribute such modifications or work under the terms of Section 1
93 | above, provided that you also meet all of these conditions:
94 |
95 | a) You must cause the modified files to carry prominent notices
96 | stating that you changed the files and the date of any change.
97 |
98 | b) You must cause any work that you distribute or publish, that in
99 | whole or in part contains or is derived from the Program or any
100 | part thereof, to be licensed as a whole at no charge to all third
101 | parties under the terms of this License.
102 |
103 | c) If the modified program normally reads commands interactively
104 | when run, you must cause it, when started running for such
105 | interactive use in the most ordinary way, to print or display an
106 | announcement including an appropriate copyright notice and a
107 | notice that there is no warranty (or else, saying that you provide
108 | a warranty) and that users may redistribute the program under
109 | these conditions, and telling the user how to view a copy of this
110 | License. (Exception: if the Program itself is interactive but
111 | does not normally print such an announcement, your work based on
112 | the Program is not required to print an announcement.)
113 |
114 | These requirements apply to the modified work as a whole. If
115 | identifiable sections of that work are not derived from the Program,
116 | and can be reasonably considered independent and separate works in
117 | themselves, then this License, and its terms, do not apply to those
118 | sections when you distribute them as separate works. But when you
119 | distribute the same sections as part of a whole which is a work based
120 | on the Program, the distribution of the whole must be on the terms of
121 | this License, whose permissions for other licensees extend to the
122 | entire whole, and thus to each and every part regardless of who wrote it.
123 |
124 | Thus, it is not the intent of this section to claim rights or contest
125 | your rights to work written entirely by you; rather, the intent is to
126 | exercise the right to control the distribution of derivative or
127 | collective works based on the Program.
128 |
129 | In addition, mere aggregation of another work not based on the Program
130 | with the Program (or with a work based on the Program) on a volume of
131 | a storage or distribution medium does not bring the other work under
132 | the scope of this License.
133 |
134 | 3. You may copy and distribute the Program (or a work based on it,
135 | under Section 2) in object code or executable form under the terms of
136 | Sections 1 and 2 above provided that you also do one of the following:
137 |
138 | a) Accompany it with the complete corresponding machine-readable
139 | source code, which must be distributed under the terms of Sections
140 | 1 and 2 above on a medium customarily used for software interchange; or,
141 |
142 | b) Accompany it with a written offer, valid for at least three
143 | years, to give any third party, for a charge no more than your
144 | cost of physically performing source distribution, a complete
145 | machine-readable copy of the corresponding source code, to be
146 | distributed under the terms of Sections 1 and 2 above on a medium
147 | customarily used for software interchange; or,
148 |
149 | c) Accompany it with the information you received as to the offer
150 | to distribute corresponding source code. (This alternative is
151 | allowed only for noncommercial distribution and only if you
152 | received the program in object code or executable form with such
153 | an offer, in accord with Subsection b above.)
154 |
155 | The source code for a work means the preferred form of the work for
156 | making modifications to it. For an executable work, complete source
157 | code means all the source code for all modules it contains, plus any
158 | associated interface definition files, plus the scripts used to
159 | control compilation and installation of the executable. However, as a
160 | special exception, the source code distributed need not include
161 | anything that is normally distributed (in either source or binary
162 | form) with the major components (compiler, kernel, and so on) of the
163 | operating system on which the executable runs, unless that component
164 | itself accompanies the executable.
165 |
166 | If distribution of executable or object code is made by offering
167 | access to copy from a designated place, then offering equivalent
168 | access to copy the source code from the same place counts as
169 | distribution of the source code, even though third parties are not
170 | compelled to copy the source along with the object code.
171 |
172 | 4. You may not copy, modify, sublicense, or distribute the Program
173 | except as expressly provided under this License. Any attempt
174 | otherwise to copy, modify, sublicense or distribute the Program is
175 | void, and will automatically terminate your rights under this License.
176 | However, parties who have received copies, or rights, from you under
177 | this License will not have their licenses terminated so long as such
178 | parties remain in full compliance.
179 |
180 | 5. You are not required to accept this License, since you have not
181 | signed it. However, nothing else grants you permission to modify or
182 | distribute the Program or its derivative works. These actions are
183 | prohibited by law if you do not accept this License. Therefore, by
184 | modifying or distributing the Program (or any work based on the
185 | Program), you indicate your acceptance of this License to do so, and
186 | all its terms and conditions for copying, distributing or modifying
187 | the Program or works based on it.
188 |
189 | 6. Each time you redistribute the Program (or any work based on the
190 | Program), the recipient automatically receives a license from the
191 | original licensor to copy, distribute or modify the Program subject to
192 | these terms and conditions. You may not impose any further
193 | restrictions on the recipients' exercise of the rights granted herein.
194 | You are not responsible for enforcing compliance by third parties to
195 | this License.
196 |
197 | 7. If, as a consequence of a court judgment or allegation of patent
198 | infringement or for any other reason (not limited to patent issues),
199 | conditions are imposed on you (whether by court order, agreement or
200 | otherwise) that contradict the conditions of this License, they do not
201 | excuse you from the conditions of this License. If you cannot
202 | distribute so as to satisfy simultaneously your obligations under this
203 | License and any other pertinent obligations, then as a consequence you
204 | may not distribute the Program at all. For example, if a patent
205 | license would not permit royalty-free redistribution of the Program by
206 | all those who receive copies directly or indirectly through you, then
207 | the only way you could satisfy both it and this License would be to
208 | refrain entirely from distribution of the Program.
209 |
210 | If any portion of this section is held invalid or unenforceable under
211 | any particular circumstance, the balance of the section is intended to
212 | apply and the section as a whole is intended to apply in other
213 | circumstances.
214 |
215 | It is not the purpose of this section to induce you to infringe any
216 | patents or other property right claims or to contest validity of any
217 | such claims; this section has the sole purpose of protecting the
218 | integrity of the free software distribution system, which is
219 | implemented by public license practices. Many people have made
220 | generous contributions to the wide range of software distributed
221 | through that system in reliance on consistent application of that
222 | system; it is up to the author/donor to decide if he or she is willing
223 | to distribute software through any other system and a licensee cannot
224 | impose that choice.
225 |
226 | This section is intended to make thoroughly clear what is believed to
227 | be a consequence of the rest of this License.
228 |
229 | 8. If the distribution and/or use of the Program is restricted in
230 | certain countries either by patents or by copyrighted interfaces, the
231 | original copyright holder who places the Program under this License
232 | may add an explicit geographical distribution limitation excluding
233 | those countries, so that distribution is permitted only in or among
234 | countries not thus excluded. In such case, this License incorporates
235 | the limitation as if written in the body of this License.
236 |
237 | 9. The Free Software Foundation may publish revised and/or new versions
238 | of the General Public License from time to time. Such new versions will
239 | be similar in spirit to the present version, but may differ in detail to
240 | address new problems or concerns.
241 |
242 | Each version is given a distinguishing version number. If the Program
243 | specifies a version number of this License which applies to it and "any
244 | later version", you have the option of following the terms and conditions
245 | either of that version or of any later version published by the Free
246 | Software Foundation. If the Program does not specify a version number of
247 | this License, you may choose any version ever published by the Free Software
248 | Foundation.
249 |
250 | 10. If you wish to incorporate parts of the Program into other free
251 | programs whose distribution conditions are different, write to the author
252 | to ask for permission. For software which is copyrighted by the Free
253 | Software Foundation, write to the Free Software Foundation; we sometimes
254 | make exceptions for this. Our decision will be guided by the two goals
255 | of preserving the free status of all derivatives of our free software and
256 | of promoting the sharing and reuse of software generally.
257 |
258 | NO WARRANTY
259 |
260 | 11. BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY
261 | FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN
262 | OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES
263 | PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED
264 | OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
265 | MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS
266 | TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE
267 | PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING,
268 | REPAIR OR CORRECTION.
269 |
270 | 12. IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
271 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MAY MODIFY AND/OR
272 | REDISTRIBUTE THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES,
273 | INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING
274 | OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED
275 | TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY
276 | YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER
277 | PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE
278 | POSSIBILITY OF SUCH DAMAGES.
279 |
280 | END OF TERMS AND CONDITIONS
281 |
282 | How to Apply These Terms to Your New Programs
283 |
284 | If you develop a new program, and you want it to be of the greatest
285 | possible use to the public, the best way to achieve this is to make it
286 | free software which everyone can redistribute and change under these terms.
287 |
288 | To do so, attach the following notices to the program. It is safest
289 | to attach them to the start of each source file to most effectively
290 | convey the exclusion of warranty; and each file should have at least
291 | the "copyright" line and a pointer to where the full notice is found.
292 |
293 |     <one line to give the program's name and a brief idea of what it does.>
294 |     Copyright (C) <year>  <name of author>
295 |
296 | This program is free software; you can redistribute it and/or modify
297 | it under the terms of the GNU General Public License as published by
298 | the Free Software Foundation; either version 2 of the License, or
299 | (at your option) any later version.
300 |
301 | This program is distributed in the hope that it will be useful,
302 | but WITHOUT ANY WARRANTY; without even the implied warranty of
303 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
304 | GNU General Public License for more details.
305 |
306 | You should have received a copy of the GNU General Public License along
307 | with this program; if not, write to the Free Software Foundation, Inc.,
308 | 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
309 |
310 | Also add information on how to contact you by electronic and paper mail.
311 |
312 | If the program is interactive, make it output a short notice like this
313 | when it starts in an interactive mode:
314 |
315 | Gnomovision version 69, Copyright (C) year name of author
316 | Gnomovision comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
317 | This is free software, and you are welcome to redistribute it
318 | under certain conditions; type `show c' for details.
319 |
320 | The hypothetical commands `show w' and `show c' should show the appropriate
321 | parts of the General Public License. Of course, the commands you use may
322 | be called something other than `show w' and `show c'; they could even be
323 | mouse-clicks or menu items--whatever suits your program.
324 |
325 | You should also get your employer (if you work as a programmer) or your
326 | school, if any, to sign a "copyright disclaimer" for the program, if
327 | necessary. Here is a sample; alter the names:
328 |
329 | Yoyodyne, Inc., hereby disclaims all copyright interest in the program
330 | `Gnomovision' (which makes passes at compilers) written by James Hacker.
331 |
332 |   <signature of Ty Coon>, 1 April 1989
333 | Ty Coon, President of Vice
334 |
335 | This General Public License does not permit incorporating your program into
336 | proprietary programs. If your program is a subroutine library, you may
337 | consider it more useful to permit linking proprietary applications with the
338 | library. If this is what you want to do, use the GNU Lesser General
339 | Public License instead of this License.
340 |
--------------------------------------------------------------------------------
/DeepCAD_pytorch/model_3DUnet.py:
--------------------------------------------------------------------------------
1 | import importlib
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from buildingblocks import Encoder, Decoder, FinalConv, DoubleConv, ExtResNetBlock, SingleConv
7 | from utils import create_feature_maps
8 |
9 |
10 | class UNet3D(nn.Module):
11 | """
12 | 3DUnet model from
13 | `"3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation"
14 |         <https://arxiv.org/abs/1606.06650>`.
15 |
16 | Args:
17 | in_channels (int): number of input channels
18 | out_channels (int): number of output segmentation masks;
19 |             Note that out_channels might correspond to either
20 |             different semantic classes or to different binary segmentation masks.
21 | It's up to the user of the class to interpret the out_channels and
22 | use the proper loss criterion during training (i.e. CrossEntropyLoss (multi-class)
23 | or BCEWithLogitsLoss (two-class) respectively)
24 | f_maps (int, tuple): number of feature maps at each level of the encoder; if it's an integer the number
25 |             of feature maps is given by the geometric progression: f_maps * 2^k, for k = 0, 1, 2, 3
26 | final_sigmoid (bool): if True apply element-wise nn.Sigmoid after the
27 | final 1x1 convolution, otherwise apply nn.Softmax. MUST be True if nn.BCELoss (two-class) is used
28 | to train the model. MUST be False if nn.CrossEntropyLoss (multi-class) is used to train the model.
29 | layer_order (string): determines the order of layers
30 | in `SingleConv` module. e.g. 'crg' stands for Conv3d+ReLU+GroupNorm3d.
31 | See `SingleConv` for more info
32 | init_channel_number (int): number of feature maps in the first conv layer of the encoder; default: 64
33 | num_groups (int): number of groups for the GroupNorm
34 | """
35 |
36 | def __init__(self, in_channels, out_channels, final_sigmoid, f_maps=64, layer_order='cr', num_groups=8,
37 | **kwargs):
38 | super(UNet3D, self).__init__()
39 |
40 | if isinstance(f_maps, int):
41 | # use 4 levels in the encoder path as suggested in the paper
42 | f_maps = create_feature_maps(f_maps, number_of_fmaps=4)
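            # e.g. f_maps=64 -> [64, 128, 256, 512]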
43 |
44 | # create encoder path consisting of Encoder modules. The length of the encoder is equal to `len(f_maps)`
45 | # uses DoubleConv as a basic_module for the Encoder
46 | encoders = []
47 | for i, out_feature_num in enumerate(f_maps):
48 | if i == 0:
49 | encoder = Encoder(in_channels, out_feature_num, apply_pooling=False, basic_module=DoubleConv,
50 | conv_layer_order=layer_order, num_groups=num_groups)
51 | else:
52 | encoder = Encoder(f_maps[i - 1], out_feature_num, basic_module=DoubleConv,
53 | conv_layer_order=layer_order, num_groups=num_groups)
54 | encoders.append(encoder)
55 |
56 | self.encoders = nn.ModuleList(encoders)
57 |
58 | # create decoder path consisting of the Decoder modules. The length of the decoder is equal to `len(f_maps) - 1`
59 | # uses DoubleConv as a basic_module for the Decoder
60 | decoders = []
61 | reversed_f_maps = list(reversed(f_maps))
62 | for i in range(len(reversed_f_maps) - 1):
63 | in_feature_num = reversed_f_maps[i] + reversed_f_maps[i + 1]
64 | out_feature_num = reversed_f_maps[i + 1]
65 | decoder = Decoder(in_feature_num, out_feature_num, basic_module=DoubleConv,
66 | conv_layer_order=layer_order, num_groups=num_groups)
67 | decoders.append(decoder)
68 |
69 | self.decoders = nn.ModuleList(decoders)
70 |
71 | # in the last layer a 1×1 convolution reduces the number of output
72 | # channels to the number of labels
73 | self.final_conv = nn.Conv3d(f_maps[0], out_channels, 1)
74 |
75 | if final_sigmoid:
76 | self.final_activation = nn.Sigmoid()
77 | else:
78 | self.final_activation = nn.Softmax(dim=1)
79 |
80 | def forward(self, x):
81 | # encoder part
82 | encoders_features = []
83 | for encoder in self.encoders:
84 | x = encoder(x)
85 | # reverse the encoder outputs to be aligned with the decoder
86 | encoders_features.insert(0, x)
87 |
88 | # remove the last encoder's output from the list
89 | # !!remember: it's the 1st in the list
90 | encoders_features = encoders_features[1:]
91 |
92 | # decoder part
93 | for decoder, encoder_features in zip(self.decoders, encoders_features):
94 | # pass the output from the corresponding encoder and the output
95 | # of the previous decoder
96 | x = decoder(encoder_features, x)
97 |
98 | x = self.final_conv(x)
99 |
100 | # apply final_activation (i.e. Sigmoid or Softmax) only for prediction. During training the network outputs
101 | # logits and it's up to the user to normalize it before visualising with tensorboard or computing validation metric
102 | # if not self.training:
103 | # x = self.final_activation(x)
104 |
105 | return x
106 |
107 |
108 | class ResidualUNet3D(nn.Module):
109 | """
110 | Residual 3DUnet model implementation based on https://arxiv.org/pdf/1706.00120.pdf.
111 | Uses ExtResNetBlock instead of DoubleConv as a basic building block as well as summation joining instead
112 | of concatenation joining. Since the model effectively becomes a residual net, in theory it allows for deeper UNet.
113 |
114 | Args:
115 | in_channels (int): number of input channels
116 | out_channels (int): number of output segmentation masks;
117 |             Note that out_channels might correspond to either
118 |             different semantic classes or to different binary segmentation masks.
119 | It's up to the user of the class to interpret the out_channels and
120 | use the proper loss criterion during training (i.e. NLLLoss (multi-class)
121 | or BCELoss (two-class) respectively)
122 | f_maps (int, tuple): number of feature maps at each level of the encoder; if it's an integer the number
123 |             of feature maps is given by the geometric progression: f_maps * 2^k, for k = 0, 1, 2, 3, 4
124 | final_sigmoid (bool): if True apply element-wise nn.Sigmoid after the
125 | final 1x1 convolution, otherwise apply nn.Softmax. MUST be True if nn.BCELoss (two-class) is used
126 | to train the model. MUST be False if nn.CrossEntropyLoss (multi-class) is used to train the model.
127 | conv_layer_order (string): determines the order of layers
128 | in `SingleConv` module. e.g. 'crg' stands for Conv3d+ReLU+GroupNorm3d.
129 | See `SingleConv` for more info
130 | init_channel_number (int): number of feature maps in the first conv layer of the encoder; default: 64
131 | num_groups (int): number of groups for the GroupNorm
132 | skip_final_activation (bool): if True, skips the final normalization layer (sigmoid/softmax) and returns the
133 | logits directly
134 | """
135 |
136 | def __init__(self, in_channels, out_channels, final_sigmoid, f_maps=32, conv_layer_order='cge', num_groups=8,
137 | skip_final_activation=False, **kwargs):
138 | super(ResidualUNet3D, self).__init__()
139 |
140 | if isinstance(f_maps, int):
141 | # use 5 levels in the encoder path as suggested in the paper
142 | f_maps = create_feature_maps(f_maps, number_of_fmaps=5)
143 |
144 | # create encoder path consisting of Encoder modules. The length of the encoder is equal to `len(f_maps)`
145 | # uses ExtResNetBlock as a basic_module for the Encoder
146 | encoders = []
147 | for i, out_feature_num in enumerate(f_maps):
148 | if i == 0:
149 | encoder = Encoder(in_channels, out_feature_num, apply_pooling=False, basic_module=ExtResNetBlock,
150 | conv_layer_order=conv_layer_order, num_groups=num_groups)
151 | else:
152 | encoder = Encoder(f_maps[i - 1], out_feature_num, basic_module=ExtResNetBlock,
153 | conv_layer_order=conv_layer_order, num_groups=num_groups)
154 | encoders.append(encoder)
155 |
156 | self.encoders = nn.ModuleList(encoders)
157 |
158 | # create decoder path consisting of the Decoder modules. The length of the decoder is equal to `len(f_maps) - 1`
159 | # uses ExtResNetBlock as a basic_module for the Decoder
160 | decoders = []
161 | reversed_f_maps = list(reversed(f_maps))
162 | for i in range(len(reversed_f_maps) - 1):
163 | decoder = Decoder(reversed_f_maps[i], reversed_f_maps[i + 1], basic_module=ExtResNetBlock,
164 | conv_layer_order=conv_layer_order, num_groups=num_groups)
165 | decoders.append(decoder)
166 |
167 | self.decoders = nn.ModuleList(decoders)
168 |
169 | # in the last layer a 1×1 convolution reduces the number of output
170 | # channels to the number of labels
171 | self.final_conv = nn.Conv3d(f_maps[0], out_channels, 1)
172 |
173 | if not skip_final_activation:
174 | if final_sigmoid:
175 | self.final_activation = nn.Sigmoid()
176 | else:
177 | self.final_activation = nn.Softmax(dim=1)
178 | else:
179 | self.final_activation = None
180 |
181 | def forward(self, x):
182 | # encoder part
183 | encoders_features = []
184 | for encoder in self.encoders:
185 | x = encoder(x)
186 | # reverse the encoder outputs to be aligned with the decoder
187 | encoders_features.insert(0, x)
188 |
189 | # remove the last encoder's output from the list
190 | # !!remember: it's the 1st in the list
191 | encoders_features = encoders_features[1:]
192 |
193 | # decoder part
194 | for decoder, encoder_features in zip(self.decoders, encoders_features):
195 | # pass the output from the corresponding encoder and the output
196 | # of the previous decoder
197 | x = decoder(encoder_features, x)
198 |
199 | x = self.final_conv(x)
200 |
201 | # apply final_activation (i.e. Sigmoid or Softmax) only for prediction. During training the network outputs
202 | # logits and it's up to the user to normalize it before visualising with tensorboard or computing validation metric
203 | if not self.training and self.final_activation is not None:
204 | x = self.final_activation(x)
205 |
206 | return x
207 |
208 |
209 | class Noise2NoiseUNet3D(nn.Module):
210 | """
211 |     3DUnet variant intended for Noise2Noise-style denoising: it keeps the UNet3D layout
212 |     (DoubleConv blocks with concatenation joining) but uses Conv3d+LeakyReLU+GroupNorm ('clg')
213 |     layers throughout and ends with a 1x1x1 convolution + ReLU instead of a sigmoid/softmax head.
214 |
215 | Args:
216 | in_channels (int): number of input channels
217 | out_channels (int): number of output segmentation masks;
218 |             Note that out_channels might correspond to either
219 |             different semantic classes or to different binary segmentation masks.
220 | It's up to the user of the class to interpret the out_channels and
221 | use the proper loss criterion during training (i.e. NLLLoss (multi-class)
222 | or BCELoss (two-class) respectively)
223 | f_maps (int, tuple): number of feature maps at each level of the encoder; if it's an integer the number
224 |             of feature maps is given by the geometric progression: f_maps * 2^k, for k = 0, 1, 2, 3, 4
225 | init_channel_number (int): number of feature maps in the first conv layer of the encoder; default: 64
226 | num_groups (int): number of groups for the GroupNorm
227 | """
228 |
229 | def __init__(self, in_channels, out_channels, f_maps=16, num_groups=8, **kwargs):
230 | super(Noise2NoiseUNet3D, self).__init__()
231 |
232 | # Use LeakyReLU activation everywhere except the last layer
233 | conv_layer_order = 'clg'
234 |
235 | if isinstance(f_maps, int):
236 | # use 5 levels in the encoder path as suggested in the paper
237 | f_maps = create_feature_maps(f_maps, number_of_fmaps=5)
238 |
239 | # create encoder path consisting of Encoder modules. The length of the encoder is equal to `len(f_maps)`
240 | # uses DoubleConv as a basic_module for the Encoder
241 | encoders = []
242 | for i, out_feature_num in enumerate(f_maps):
243 | if i == 0:
244 | encoder = Encoder(in_channels, out_feature_num, apply_pooling=False, basic_module=DoubleConv,
245 | conv_layer_order=conv_layer_order, num_groups=num_groups)
246 | else:
247 | encoder = Encoder(f_maps[i - 1], out_feature_num, basic_module=DoubleConv,
248 | conv_layer_order=conv_layer_order, num_groups=num_groups)
249 | encoders.append(encoder)
250 |
251 | self.encoders = nn.ModuleList(encoders)
252 |
253 | # create decoder path consisting of the Decoder modules. The length of the decoder is equal to `len(f_maps) - 1`
254 | # uses DoubleConv as a basic_module for the Decoder
255 | decoders = []
256 | reversed_f_maps = list(reversed(f_maps))
257 | for i in range(len(reversed_f_maps) - 1):
258 | in_feature_num = reversed_f_maps[i] + reversed_f_maps[i + 1]
259 | out_feature_num = reversed_f_maps[i + 1]
260 | decoder = Decoder(in_feature_num, out_feature_num, basic_module=DoubleConv,
261 | conv_layer_order=conv_layer_order, num_groups=num_groups)
262 | decoders.append(decoder)
263 |
264 | self.decoders = nn.ModuleList(decoders)
265 |
266 | # 1x1x1 conv + simple ReLU in the final convolution
267 | self.final_conv = SingleConv(f_maps[0], out_channels, kernel_size=1, order='cr', padding=0)
268 |
269 | def forward(self, x):
270 | # encoder part
271 | encoders_features = []
272 | for encoder in self.encoders:
273 | x = encoder(x)
274 | # reverse the encoder outputs to be aligned with the decoder
275 | encoders_features.insert(0, x)
276 |
277 | # remove the last encoder's output from the list
278 | # !!remember: it's the 1st in the list
279 | encoders_features = encoders_features[1:]
280 |
281 | # decoder part
282 | for decoder, encoder_features in zip(self.decoders, encoders_features):
283 | # pass the output from the corresponding encoder and the output
284 | # of the previous decoder
285 | x = decoder(encoder_features, x)
286 |
287 | x = self.final_conv(x)
288 |
289 | return x
290 |
291 |
292 | def get_model(config):
293 | def _model_class(class_name):
294 |         m = importlib.import_module('unet3d.model')  # dynamically import the module that defines the model classes
295 |         clazz = getattr(m, class_name)  # getattr() returns the attribute (here, the model class) with the given name
296 | return clazz
297 |
298 | assert 'model' in config, 'Could not find model configuration'
299 | model_config = config['model']
300 | model_class = _model_class(model_config['name'])
301 | return model_class(**model_config)
302 |
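# Illustrative sketch (not part of the original file): get_model() expects a config dict such as
#   config = {'model': {'name': 'UNet3D', 'in_channels': 1, 'out_channels': 1,
#                       'final_sigmoid': True, 'f_maps': 64}}
#   model = get_model(config)
# Note that _model_class() imports 'unet3d.model', so the lookup only resolves if the model classes
# are installed under that package name; within this repository they live in model_3DUnet.py instead.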
303 |
304 | ###############################################Supervised Tags 3DUnet###################################################
305 |
306 | class TagsUNet3D(nn.Module):
307 | """
308 | Supervised tags 3DUnet
309 | Args:
310 | in_channels (int): number of input channels
311 | out_channels (int): number of output channels; since most often we're trying to learn
312 | 3D unit vectors we use 3 as a default value
313 | output_heads (int): number of output heads from the network, each head corresponds to different
314 | semantic tag/direction to be learned
315 | conv_layer_order (string): determines the order of layers
316 | in `DoubleConv` module. e.g. 'crg' stands for Conv3d+ReLU+GroupNorm3d.
317 | See `DoubleConv` for more info
318 | init_channel_number (int): number of feature maps in the first conv layer of the encoder; default: 64
319 | """
320 |
321 | def __init__(self, in_channels, out_channels=3, output_heads=1, conv_layer_order='crg', init_channel_number=32,
322 | **kwargs):
323 | super(TagsUNet3D, self).__init__()
324 |
325 | # number of groups for the GroupNorm
326 | num_groups = min(init_channel_number // 2, 32)
327 |
328 | # encoder path consist of 4 subsequent Encoder modules
329 | # the number of features maps is the same as in the paper
330 | self.encoders = nn.ModuleList([
331 | Encoder(in_channels, init_channel_number, apply_pooling=False, conv_layer_order=conv_layer_order,
332 | num_groups=num_groups),
333 | Encoder(init_channel_number, 2 * init_channel_number, conv_layer_order=conv_layer_order,
334 | num_groups=num_groups),
335 | Encoder(2 * init_channel_number, 4 * init_channel_number, conv_layer_order=conv_layer_order,
336 | num_groups=num_groups),
337 | Encoder(4 * init_channel_number, 8 * init_channel_number, conv_layer_order=conv_layer_order,
338 | num_groups=num_groups)
339 | ])
340 |
341 | self.decoders = nn.ModuleList([
342 | Decoder(4 * init_channel_number + 8 * init_channel_number, 4 * init_channel_number,
343 | conv_layer_order=conv_layer_order, num_groups=num_groups),
344 | Decoder(2 * init_channel_number + 4 * init_channel_number, 2 * init_channel_number,
345 | conv_layer_order=conv_layer_order, num_groups=num_groups),
346 | Decoder(init_channel_number + 2 * init_channel_number, init_channel_number,
347 | conv_layer_order=conv_layer_order, num_groups=num_groups)
348 | ])
349 |
350 | self.final_heads = nn.ModuleList(
351 | [FinalConv(init_channel_number, out_channels, num_groups=num_groups) for _ in
352 | range(output_heads)])
353 |
354 | def forward(self, x):
355 | # encoder part
356 | encoders_features = []
357 | for encoder in self.encoders:
358 | x = encoder(x)
359 | # reverse the encoder outputs to be aligned with the decoder
360 | encoders_features.insert(0, x)
361 |
362 | # remove the last encoder's output from the list
363 | # !!remember: it's the 1st in the list
364 | encoders_features = encoders_features[1:]
365 |
366 | # decoder part
367 | for decoder, encoder_features in zip(self.decoders, encoders_features):
368 | # pass the output from the corresponding encoder and the output
369 | # of the previous decoder
370 | x = decoder(encoder_features, x)
371 |
372 | # apply final layer per each output head
373 | tags = [final_head(x) for final_head in self.final_heads]
374 |
375 | # normalize directions with L2 norm
376 | return [tag / torch.norm(tag, p=2, dim=1).detach().clamp(min=1e-8) for tag in tags]
377 |
378 |
379 | ################################################Distance transform 3DUNet##############################################
380 | class DistanceTransformUNet3D(nn.Module):
381 | """
382 |     Predict Distance Transform to the boundary signal based on the output from the Tags3DUnet. For training use either:
383 | 1. PixelWiseCrossEntropyLoss if the distance transform is quantized (classification)
384 | 2. MSELoss if the distance transform is continuous (regression)
385 | Args:
386 | in_channels (int): number of input channels
387 | out_channels (int): number of output segmentation masks;
388 |             Note that out_channels might correspond to either
389 |             different semantic classes or to different binary segmentation masks.
390 | It's up to the user of the class to interpret the out_channels and
391 | use the proper loss criterion during training (i.e. NLLLoss (multi-class)
392 | or BCELoss (two-class) respectively)
393 |         final_sigmoid (bool): if True apply element-wise nn.Sigmoid after the final 1x1 convolution,
394 |             otherwise apply nn.Softmax
395 | init_channel_number (int): number of feature maps in the first conv layer of the encoder; default: 64
396 | """
397 |
398 | def __init__(self, in_channels, out_channels, final_sigmoid, init_channel_number=32, **kwargs):
399 | super(DistanceTransformUNet3D, self).__init__()
400 |
401 | # number of groups for the GroupNorm
402 | num_groups = min(init_channel_number // 2, 32)
403 |
404 | # encoder path consist of 4 subsequent Encoder modules
405 | # the number of features maps is the same as in the paper
406 | self.encoders = nn.ModuleList([
407 | Encoder(in_channels, init_channel_number, apply_pooling=False, conv_layer_order='crg',
408 | num_groups=num_groups),
409 | Encoder(init_channel_number, 2 * init_channel_number, pool_type='avg', conv_layer_order='crg',
410 | num_groups=num_groups)
411 | ])
412 |
413 | self.decoders = nn.ModuleList([
414 | Decoder(3 * init_channel_number, init_channel_number, conv_layer_order='crg', num_groups=num_groups)
415 | ])
416 |
417 | # in the last layer a 1×1 convolution reduces the number of output
418 | # channels to the number of labels
419 | self.final_conv = nn.Conv3d(init_channel_number, out_channels, 1)
420 |
421 | if final_sigmoid:
422 | self.final_activation = nn.Sigmoid()
423 | else:
424 | self.final_activation = nn.Softmax(dim=1)
425 |
426 | def forward(self, inputs):
427 | # allow multiple heads
428 | if isinstance(inputs, list) or isinstance(inputs, tuple):
429 | x = torch.cat(inputs, dim=1)
430 | else:
431 | x = inputs
432 |
433 | # encoder part
434 | encoders_features = []
435 | for encoder in self.encoders:
436 | x = encoder(x)
437 | # reverse the encoder outputs to be aligned with the decoder
438 | encoders_features.insert(0, x)
439 |
440 | # remove the last encoder's output from the list
441 | # !!remember: it's the 1st in the list
442 | encoders_features = encoders_features[1:]
443 |
444 | # decoder part
445 | for decoder, encoder_features in zip(self.decoders, encoders_features):
446 | # pass the output from the corresponding encoder and the output
447 | # of the previous decoder
448 | x = decoder(encoder_features, x)
449 |
450 | # apply final 1x1 convolution
451 | x = self.final_conv(x)
452 |
453 | # apply final_activation (i.e. Sigmoid or Softmax) only for prediction. During training the network outputs
454 | # logits and it's up to the user to normalize it before visualising with tensorboard or computing validation metric
455 | if not self.training:
456 | x = self.final_activation(x)
457 |
458 | return x
459 |
460 |
461 | class EndToEndDTUNet3D(nn.Module):
462 | def __init__(self, tags_in_channels, tags_out_channels, tags_output_heads, tags_init_channel_number,
463 | dt_in_channels, dt_out_channels, dt_final_sigmoid, dt_init_channel_number,
464 | tags_net_path=None, dt_net_path=None, **kwargs):
465 | super(EndToEndDTUNet3D, self).__init__()
466 |
467 | self.tags_net = TagsUNet3D(tags_in_channels, tags_out_channels, tags_output_heads,
468 | init_channel_number=tags_init_channel_number)
469 | if tags_net_path is not None:
470 | # load pre-trained TagsUNet3D
471 | self.tags_net = self._load_net(tags_net_path, self.tags_net)
472 |
473 | self.dt_net = DistanceTransformUNet3D(dt_in_channels, dt_out_channels, dt_final_sigmoid,
474 | init_channel_number=dt_init_channel_number)
475 | if dt_net_path is not None:
476 | # load pre-trained DistanceTransformUNet3D
477 | self.dt_net = self._load_net(dt_net_path, self.dt_net)
478 |
479 | @staticmethod
480 | def _load_net(checkpoint_path, model):
481 | state = torch.load(checkpoint_path)
482 | model.load_state_dict(state['model_state_dict'])
483 | return model
484 |
485 | def forward(self, x):
486 | x = self.tags_net(x)
487 | return self.dt_net(x)
488 |
--------------------------------------------------------------------------------
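A minimal usage sketch for the UNet3D class defined above (not part of the repository): it assumes that buildingblocks.py and utils.py from the same folder are importable and that their Encoder/Decoder blocks preserve spatial size, as in pytorch-3dunet. Because the final activation is commented out in forward(), the network returns raw values of the same shape as the input.

    import torch
    from model_3DUnet import UNet3D

    # one-channel volumetric input, e.g. a 64x64x64 sub-volume of an imaging stack
    model = UNet3D(in_channels=1, out_channels=1, final_sigmoid=True, f_maps=64, layer_order='cr')
    dummy = torch.randn(1, 1, 64, 64, 64)   # (N, C, D, H, W); sides divisible by 8 for the 4-level encoder
    with torch.no_grad():
        out = model(dummy)
    print(out.shape)                        # should equal the input shape if the blocks use 'same' padding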
/DeepCAD_pytorch/data_process.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import argparse
3 | import os
4 | import tifffile as tiff
5 | import time
6 | import datetime
7 | import random
8 | from skimage import io
9 | import logging
10 | import math
11 |
12 |
13 | def shuffle_datasets(train_raw, train_GT, name_list):
14 | index_list = list(range(0, len(name_list)))
15 | # print('index_list -----> ',index_list)
16 | random.shuffle(index_list)
17 | random_index_list = index_list
18 | # print('index_list -----> ',index_list)
19 | new_name_list = list(range(0, len(name_list)))
20 | train_raw = np.array(train_raw)
21 | # print('train_raw shape -----> ',train_raw.shape)
22 | train_GT = np.array(train_GT)
23 | # print('train_GT shape -----> ',train_GT.shape)
24 |     new_train_raw = train_raw.copy()  # copy, otherwise the in-place shuffle below overwrites rows before they are read
25 |     new_train_GT = train_GT.copy()
26 | for i in range(0,len(random_index_list)):
27 | # print('i -----> ',i)
28 | new_train_raw[i,:,:,:] = train_raw[random_index_list[i],:,:,:]
29 | new_train_GT[i,:,:,:] = train_GT[random_index_list[i],:,:,:]
30 | new_name_list[i] = name_list[random_index_list[i]]
31 | # new_train_raw = np.expand_dims(new_train_raw, 4)
32 | # new_train_GT = np.expand_dims(new_train_GT, 4)
33 | return new_train_raw, new_train_GT, new_name_list
34 |
35 | def train_preprocess(args):
36 | img_h = args.img_h
37 | img_w = args.img_w
38 | img_s2 = args.img_s*2
39 | gap_h = args.gap_h
40 | gap_w = args.gap_w
41 | gap_s2 = args.gap_s*2
42 | im_folder = 'datasets//'+args.datasets_folder
43 |
44 | name_list = []
45 | train_raw = []
46 | train_GT = []
47 | for im_name in list(os.walk(im_folder, topdown=False))[-1][-1]:
48 | print('im_name -----> ',im_name)
49 | im_dir = im_folder+'//'+im_name
50 | noise_im = tiff.imread(im_dir)
51 | print('noise_im shape -----> ',noise_im.shape)
52 | # print('noise_im max -----> ',noise_im.max())
53 | # print('noise_im min -----> ',noise_im.min())
54 | noise_im = (noise_im-noise_im.min()).astype(np.float32)/args.normalize_factor
55 |
56 | whole_w = noise_im.shape[2]
57 | whole_h = noise_im.shape[1]
58 | whole_s = noise_im.shape[0]
59 | print('int((whole_h-img_h+gap_h)/gap_h) -----> ',int((whole_h-img_h+gap_h)/gap_h))
60 | print('int((whole_w-img_w+gap_w)/gap_w) -----> ',int((whole_w-img_w+gap_w)/gap_w))
61 | print('int((whole_s-img_s2+gap_s2)/gap_s2) -----> ',int((whole_s-img_s2+gap_s2)/gap_s2))
62 | for x in range(0,int((whole_h-img_h+gap_h)/gap_h)):
63 | for y in range(0,int((whole_w-img_w+gap_w)/gap_w)):
64 | for z in range(0,int((whole_s-img_s2+gap_s2)/gap_s2)):
65 | init_h = gap_h*x
66 | end_h = gap_h*x + img_h
67 | init_w = gap_w*y
68 | end_w = gap_w*y + img_w
69 | init_s = gap_s2*z
70 | end_s = gap_s2*z + img_s2
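                        # Interleaved sampling: even-numbered frames of the window become the input patch
                        # and odd-numbered frames become the target patch, so the pair shares the same
                        # underlying signal but carries independent noise.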
71 | noise_patch1 = noise_im[init_s:end_s:2,init_h:end_h,init_w:end_w]
72 | noise_patch2 = noise_im[init_s+1:end_s:2,init_h:end_h,init_w:end_w]
73 | patch_name = args.datasets_folder+'_x'+str(x)+'_y'+str(y)+'_z'+str(z)
74 | train_raw.append(noise_patch1.transpose(1,2,0))
75 | train_GT.append(noise_patch2.transpose(1,2,0))
76 | name_list.append(patch_name)
77 | return train_raw, train_GT, name_list, noise_im
78 |
79 | def test_preprocess(args):
80 | img_h = args.img_h
81 | img_w = args.img_w
82 | img_s2 = args.img_s
83 | gap_h = args.gap_h
84 | gap_w = args.gap_w
85 | gap_s2 = args.gap_s
86 | im_folder = 'datasets//'+args.datasets_folder
87 |
88 | name_list = []
89 | train_raw = []
90 | for im_name in list(os.walk(im_folder, topdown=False))[-1][-1]:
91 | # print('im_name -----> ',im_name)
92 | im_dir = im_folder+'//'+im_name
93 | noise_im = tiff.imread(im_dir)
94 | # print('noise_im shape -----> ',noise_im.shape)
95 | # print('noise_im max -----> ',noise_im.max())
96 | # print('noise_im min -----> ',noise_im.min())
97 | noise_im = (noise_im-noise_im.min()).astype(np.float32)/args.normalize_factor
98 |
99 | whole_w = noise_im.shape[2]
100 | whole_h = noise_im.shape[1]
101 | whole_s = noise_im.shape[0]
102 | # print('int((whole_h-img_h+gap_h)/gap_h) -----> ',int((whole_h-img_h+gap_h)/gap_h))
103 | # print('int((whole_w-img_w+gap_w)/gap_w) -----> ',int((whole_w-img_w+gap_w)/gap_w))
104 | # print('int((whole_s-img_s2+gap_s2)/gap_s2) -----> ',int((whole_s-img_s2+gap_s2)/gap_s2))
105 | for x in range(0,int((whole_h-img_h+gap_h)/gap_h)):
106 | for y in range(0,int((whole_w-img_w+gap_w)/gap_w)):
107 | for z in range(0,int((whole_s-img_s2+gap_s2)/gap_s2)):
108 | init_h = gap_h*x
109 | end_h = gap_h*x + img_h
110 | init_w = gap_w*y
111 | end_w = gap_w*y + img_w
112 | init_s = gap_s2*z
113 | end_s = gap_s2*z + img_s2
114 | noise_patch1 = noise_im[init_s:end_s,init_h:end_h,init_w:end_w]
115 | patch_name = args.datasets_folder+'_x'+str(x)+'_y'+str(y)+'_z'+str(z)
116 | train_raw.append(noise_patch1.transpose(1,2,0))
117 | name_list.append(patch_name)
118 | return train_raw, name_list, noise_im
119 |
120 | def test_preprocess_lessMemory (args):
121 | img_h = args.img_h
122 | img_w = args.img_w
123 | img_s2 = args.img_s
124 | gap_h = args.gap_h
125 | gap_w = args.gap_w
126 | gap_s2 = args.gap_s
127 | im_folder = 'datasets//'+args.datasets_folder
128 |
129 | name_list = []
130 | # train_raw = []
131 | coordinate_list={}
132 | for im_name in list(os.walk(im_folder, topdown=False))[-1][-1]:
133 | # print('im_name -----> ',im_name)
134 | im_dir = im_folder+'//'+im_name
135 | noise_im = tiff.imread(im_dir)
136 | # print('noise_im shape -----> ',noise_im.shape)
137 | # print('noise_im max -----> ',noise_im.max())
138 | # print('noise_im min -----> ',noise_im.min())
139 | if noise_im.shape[0]>args.test_datasize:
140 | noise_im = noise_im[0:args.test_datasize,:,:]
141 | noise_im = (noise_im-noise_im.min()).astype(np.float32)/args.normalize_factor
142 |
143 | whole_w = noise_im.shape[2]
144 | whole_h = noise_im.shape[1]
145 | whole_s = noise_im.shape[0]
146 | # print('int((whole_h-img_h+gap_h)/gap_h) -----> ',int((whole_h-img_h+gap_h)/gap_h))
147 | # print('int((whole_w-img_w+gap_w)/gap_w) -----> ',int((whole_w-img_w+gap_w)/gap_w))
148 | # print('int((whole_s-img_s2+gap_s2)/gap_s2) -----> ',int((whole_s-img_s2+gap_s2)/gap_s2))
149 | for x in range(0,int((whole_h-img_h+gap_h)/gap_h)):
150 | for y in range(0,int((whole_w-img_w+gap_w)/gap_w)):
151 | for z in range(0,int((whole_s-img_s2+gap_s2)/gap_s2)):
152 | single_coordinate={'init_h':0, 'end_h':0, 'init_w':0, 'end_w':0, 'init_s':0, 'end_s':0}
153 | init_h = gap_h*x
154 | end_h = gap_h*x + img_h
155 | init_w = gap_w*y
156 | end_w = gap_w*y + img_w
157 | init_s = gap_s2*z
158 | end_s = gap_s2*z + img_s2
159 | single_coordinate['init_h'] = init_h
160 | single_coordinate['end_h'] = end_h
161 | single_coordinate['init_w'] = init_w
162 | single_coordinate['end_w'] = end_w
163 | single_coordinate['init_s'] = init_s
164 | single_coordinate['end_s'] = end_s
165 | # noise_patch1 = noise_im[init_s:end_s,init_h:end_h,init_w:end_w]
166 | patch_name = args.datasets_folder+'_x'+str(x)+'_y'+str(y)+'_z'+str(z)
167 | # train_raw.append(noise_patch1.transpose(1,2,0))
168 | name_list.append(patch_name)
169 | # print(' single_coordinate -----> ',single_coordinate)
170 | coordinate_list[patch_name] = single_coordinate
171 | return name_list, noise_im, coordinate_list
172 |
173 | def get_gap_s(args, img, stack_num):
174 | whole_w = img.shape[2]
175 | whole_h = img.shape[1]
176 | whole_s = img.shape[0]
177 | print('whole_w -----> ',whole_w)
178 | print('whole_h -----> ',whole_h)
179 | print('whole_s -----> ',whole_s)
180 | w_num = math.floor((whole_w-args.img_w)/args.gap_w)+1
181 | h_num = math.floor((whole_h-args.img_h)/args.gap_h)+1
182 | s_num = math.ceil(args.train_datasets_size/w_num/h_num/stack_num)
183 | print('w_num -----> ',w_num)
184 | print('h_num -----> ',h_num)
185 | print('s_num -----> ',s_num)
186 | gap_s = math.floor((whole_s-args.img_s*2)/(s_num-1))
187 | print('gap_s -----> ',gap_s)
188 | return gap_s
189 |
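# Worked example for get_gap_s (illustrative numbers only): for a 512 x 512 x 6000 stack with
# img_w = img_h = 150, gap_w = gap_h = 60, img_s = 150, train_datasets_size = 6000 and stack_num = 1:
#   w_num = h_num = floor((512-150)/60) + 1 = 7
#   s_num = ceil(6000/7/7/1) = 123
#   gap_s = floor((6000 - 2*150)/(123-1)) = 46
# so consecutive temporal windows start 46 frames apart, yielding roughly train_datasets_size patches
# per stack. Note that s_num must be greater than 1, otherwise the division above fails.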
190 | def train_preprocess_lessMemory(args):
191 | img_h = args.img_h
192 | img_w = args.img_w
193 | img_s2 = args.img_s*2
194 | gap_h = args.gap_h
195 | gap_w = args.gap_w
196 | gap_s2 = args.gap_s*2
197 | im_folder = args.datasets_path+'//'+args.datasets_folder
198 |
199 | name_list = []
200 | # train_raw = []
201 | coordinate_list={}
202 | for im_name in list(os.walk(im_folder, topdown=False))[-1][-1]:
203 | # print('im_name -----> ',im_name)
204 | im_dir = im_folder+'//'+im_name
205 | noise_im = tiff.imread(im_dir)
206 | if noise_im.shape[0]>args.select_img_num:
207 | noise_im = noise_im[0:args.select_img_num,:,:]
208 |         gap_s2 = get_gap_s(args, noise_im, 1)  # stack_num=1: this function processes one stack at a time
209 | # print('noise_im shape -----> ',noise_im.shape)
210 | # print('noise_im max -----> ',noise_im.max())
211 | # print('noise_im min -----> ',noise_im.min())
212 | noise_im = (noise_im-noise_im.min()).astype(np.float32)/args.normalize_factor
213 |
214 | whole_w = noise_im.shape[2]
215 | whole_h = noise_im.shape[1]
216 | whole_s = noise_im.shape[0]
217 | # print('int((whole_h-img_h+gap_h)/gap_h) -----> ',int((whole_h-img_h+gap_h)/gap_h))
218 | # print('int((whole_w-img_w+gap_w)/gap_w) -----> ',int((whole_w-img_w+gap_w)/gap_w))
219 | # print('int((whole_s-img_s2+gap_s2)/gap_s2) -----> ',int((whole_s-img_s2+gap_s2)/gap_s2))
220 | for x in range(0,int((whole_h-img_h+gap_h)/gap_h)):
221 | for y in range(0,int((whole_w-img_w+gap_w)/gap_w)):
222 | for z in range(0,int((whole_s-img_s2+gap_s2)/gap_s2)):
223 | single_coordinate={'init_h':0, 'end_h':0, 'init_w':0, 'end_w':0, 'init_s':0, 'end_s':0}
224 | init_h = gap_h*x
225 | end_h = gap_h*x + img_h
226 | init_w = gap_w*y
227 | end_w = gap_w*y + img_w
228 | init_s = gap_s2*z
229 | end_s = gap_s2*z + img_s2
230 | single_coordinate['init_h'] = init_h
231 | single_coordinate['end_h'] = end_h
232 | single_coordinate['init_w'] = init_w
233 | single_coordinate['end_w'] = end_w
234 | single_coordinate['init_s'] = init_s
235 | single_coordinate['end_s'] = end_s
236 | # noise_patch1 = noise_im[init_s:end_s,init_h:end_h,init_w:end_w]
237 | patch_name = args.datasets_folder+'_x'+str(x)+'_y'+str(y)+'_z'+str(z)
238 | # train_raw.append(noise_patch1.transpose(1,2,0))
239 | name_list.append(patch_name)
240 | # print(' single_coordinate -----> ',single_coordinate)
241 | coordinate_list[patch_name] = single_coordinate
242 | return name_list, noise_im, coordinate_list
243 |
244 | def train_preprocess_lessMemoryMulStacks(args):
245 | img_h = args.img_h
246 | img_w = args.img_w
247 | img_s2 = args.img_s*2
248 | gap_h = args.gap_h
249 | gap_w = args.gap_w
250 | gap_s2 = args.gap_s*2
251 | im_folder = args.datasets_path+'//'+args.datasets_folder
252 |
253 | name_list = []
254 | # train_raw = []
255 | coordinate_list={}
256 |
257 | print('list(os.walk(im_folder, topdown=False)) -----> ',list(os.walk(im_folder, topdown=False)))
258 | stack_num = len(list(os.walk(im_folder, topdown=False))[-1][-1])
259 | print('stack_num -----> ',stack_num)
260 | for im_name in list(os.walk(im_folder, topdown=False))[-1][-1]:
261 | # print('im_name -----> ',im_name)
262 | im_dir = im_folder+'//'+im_name
263 | noise_im = tiff.imread(im_dir)
264 | if noise_im.shape[0]>args.select_img_num:
265 | noise_im = noise_im[0:args.select_img_num,:,:]
266 | gap_s2 = get_gap_s(args, noise_im, stack_num)
267 | # print('noise_im shape -----> ',noise_im.shape)
268 | # print('noise_im max -----> ',noise_im.max())
269 | # print('noise_im min -----> ',noise_im.min())
270 | noise_im = (noise_im-noise_im.min()).astype(np.float32)/args.normalize_factor
271 |
272 | whole_w = noise_im.shape[2]
273 | whole_h = noise_im.shape[1]
274 | whole_s = noise_im.shape[0]
275 | # print('int((whole_h-img_h+gap_h)/gap_h) -----> ',int((whole_h-img_h+gap_h)/gap_h))
276 | # print('int((whole_w-img_w+gap_w)/gap_w) -----> ',int((whole_w-img_w+gap_w)/gap_w))
277 | # print('int((whole_s-img_s2+gap_s2)/gap_s2) -----> ',int((whole_s-img_s2+gap_s2)/gap_s2))
278 | for x in range(0,int((whole_h-img_h+gap_h)/gap_h)):
279 | for y in range(0,int((whole_w-img_w+gap_w)/gap_w)):
280 | for z in range(0,int((whole_s-img_s2+gap_s2)/gap_s2)):
281 | single_coordinate={'init_h':0, 'end_h':0, 'init_w':0, 'end_w':0, 'init_s':0, 'end_s':0}
282 | init_h = gap_h*x
283 | end_h = gap_h*x + img_h
284 | init_w = gap_w*y
285 | end_w = gap_w*y + img_w
286 | init_s = gap_s2*z
287 | end_s = gap_s2*z + img_s2
288 | single_coordinate['init_h'] = init_h
289 | single_coordinate['end_h'] = end_h
290 | single_coordinate['init_w'] = init_w
291 | single_coordinate['end_w'] = end_w
292 | single_coordinate['init_s'] = init_s
293 | single_coordinate['end_s'] = end_s
294 | # noise_patch1 = noise_im[init_s:end_s,init_h:end_h,init_w:end_w]
295 | patch_name = args.datasets_folder+'_'+im_name.replace('.tif','')+'_x'+str(x)+'_y'+str(y)+'_z'+str(z)
296 | # train_raw.append(noise_patch1.transpose(1,2,0))
297 | name_list.append(patch_name)
298 | # print(' single_coordinate -----> ',single_coordinate)
299 | coordinate_list[patch_name] = single_coordinate
300 | return name_list, noise_im, coordinate_list
301 |
302 | def shuffle_datasets_lessMemory(name_list):
303 | index_list = list(range(0, len(name_list)))
304 | # print('index_list -----> ',index_list)
305 | random.shuffle(index_list)
306 | random_index_list = index_list
307 | # print('index_list -----> ',index_list)
308 | new_name_list = list(range(0, len(name_list)))
309 | for i in range(0,len(random_index_list)):
310 | new_name_list[i] = name_list[random_index_list[i]]
311 | return new_name_list
312 |
313 | def test_preprocess_lessMemoryPadding (args):
314 | img_h = args.img_h
315 | img_w = args.img_w
316 | img_s2 = args.img_s
317 | gap_h = args.gap_h
318 | gap_w = args.gap_w
319 | gap_s2 = args.gap_s
320 | im_folder = 'datasets//'+args.datasets_folder
321 |
322 | name_list = []
323 | # train_raw = []
324 | coordinate_list={}
325 | for im_name in list(os.walk(im_folder, topdown=False))[-1][-1]:
326 | # print('im_name -----> ',im_name)
327 | im_dir = im_folder+'//'+im_name
328 | raw_noise_im = tiff.imread(im_dir)
329 | if raw_noise_im.shape[0]>args.test_datasize:
330 | raw_noise_im = raw_noise_im[0:args.test_datasize,:,:]
331 | raw_noise_im = (raw_noise_im-raw_noise_im.min()).astype(np.float32)/args.normalize_factor
332 |
333 | print('raw_noise_im shape -----> ',raw_noise_im.shape)
334 | noise_im_w = math.ceil((raw_noise_im.shape[2]-img_w)/gap_w)*gap_w+img_w
335 | noise_im_h = math.ceil((raw_noise_im.shape[1]-img_h)/gap_h)*gap_h+img_h
336 | noise_im_s = math.ceil((raw_noise_im.shape[0]-img_s2)/gap_s2)*gap_s2+img_s2
337 | noise_im = np.zeros([noise_im_s,noise_im_h,noise_im_w])
338 | noise_im[0:raw_noise_im.shape[0], 0:raw_noise_im.shape[1], 0:raw_noise_im.shape[2]]=raw_noise_im
339 | noise_im = noise_im.astype(np.float32)
340 | print('noise_im shape -----> ',noise_im.shape)
341 | # print('noise_im max -----> ',noise_im.max())
342 | # print('noise_im min -----> ',noise_im.min())
343 |
344 | whole_w = noise_im.shape[2]
345 | whole_h = noise_im.shape[1]
346 | whole_s = noise_im.shape[0]
347 | # print('int((whole_h-img_h+gap_h)/gap_h) -----> ',int((whole_h-img_h+gap_h)/gap_h))
348 | # print('int((whole_w-img_w+gap_w)/gap_w) -----> ',int((whole_w-img_w+gap_w)/gap_w))
349 | # print('int((whole_s-img_s2+gap_s2)/gap_s2) -----> ',int((whole_s-img_s2+gap_s2)/gap_s2))
350 | for x in range(0,int((whole_h-img_h+gap_h)/gap_h)):
351 | for y in range(0,int((whole_w-img_w+gap_w)/gap_w)):
352 | for z in range(0,int((whole_s-img_s2+gap_s2)/gap_s2)):
353 | single_coordinate={'init_h':0, 'end_h':0, 'init_w':0, 'end_w':0, 'init_s':0, 'end_s':0}
354 | init_h = gap_h*x
355 | end_h = gap_h*x + img_h
356 | init_w = gap_w*y
357 | end_w = gap_w*y + img_w
358 | init_s = gap_s2*z
359 | end_s = gap_s2*z + img_s2
360 | single_coordinate['init_h'] = init_h
361 | single_coordinate['end_h'] = end_h
362 | single_coordinate['init_w'] = init_w
363 | single_coordinate['end_w'] = end_w
364 | single_coordinate['init_s'] = init_s
365 | single_coordinate['end_s'] = end_s
366 | # noise_patch1 = noise_im[init_s:end_s,init_h:end_h,init_w:end_w]
367 | patch_name = args.datasets_folder+'_x'+str(x)+'_y'+str(y)+'_z'+str(z)
368 | # train_raw.append(noise_patch1.transpose(1,2,0))
369 | name_list.append(patch_name)
370 | # print(' single_coordinate -----> ',single_coordinate)
371 | coordinate_list[patch_name] = single_coordinate
372 | return name_list, noise_im, coordinate_list, raw_noise_im
373 |
374 | def test_preprocess_lessMemoryNoTail (args):
375 | img_h = args.img_h
376 | img_w = args.img_w
377 | img_s2 = args.img_s
378 | gap_h = args.gap_h
379 | gap_w = args.gap_w
380 | gap_s2 = args.gap_s
381 | cut_w = (img_w - gap_w)/2
382 | cut_h = (img_h - gap_h)/2
383 | cut_s = (img_s2 - gap_s2)/2
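    # cut_* is half the overlap between neighbouring patches (img_* - gap_* is assumed to be even).
    # When the denoised patches are stitched back together, this margin is trimmed from each side so
    # only the central region of every patch is kept; the stack_*/patch_* coordinates computed below
    # record where that central region sits in the whole stack and inside the patch, respectively.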
384 | im_folder = args.datasets_path+'//'+args.datasets_folder
385 |
386 | name_list = []
387 | # train_raw = []
388 | coordinate_list={}
389 | for im_name in list(os.walk(im_folder, topdown=False))[-1][-1]:
390 | # print('im_name -----> ',im_name)
391 | im_dir = im_folder+'//'+im_name
392 | noise_im = tiff.imread(im_dir)
393 | # print('noise_im shape -----> ',noise_im.shape)
394 | # print('noise_im max -----> ',noise_im.max())
395 | # print('noise_im min -----> ',noise_im.min())
396 | if noise_im.shape[0]>args.test_datasize:
397 | noise_im = noise_im[0:args.test_datasize,:,:]
398 | noise_im = (noise_im-noise_im.min()).astype(np.float32)/args.normalize_factor
399 |
400 | whole_w = noise_im.shape[2]
401 | whole_h = noise_im.shape[1]
402 | whole_s = noise_im.shape[0]
403 |
404 |         num_w = math.ceil((whole_w-img_w+gap_w)/gap_w)   # number of patch columns
405 |         num_h = math.ceil((whole_h-img_h+gap_h)/gap_h)   # number of patch rows
406 |         num_s = math.ceil((whole_s-img_s2+gap_s2)/gap_s2)   # number of patch positions along the frame axis
407 | # print('int((whole_h-img_h+gap_h)/gap_h) -----> ',int((whole_h-img_h+gap_h)/gap_h))
408 | # print('int((whole_w-img_w+gap_w)/gap_w) -----> ',int((whole_w-img_w+gap_w)/gap_w))
409 | # print('int((whole_s-img_s2+gap_s2)/gap_s2) -----> ',int((whole_s-img_s2+gap_s2)/gap_s2))
410 | for x in range(0,num_h):
411 | for y in range(0,num_w):
412 | for z in range(0,num_s):
413 | single_coordinate={'init_h':0, 'end_h':0, 'init_w':0, 'end_w':0, 'init_s':0, 'end_s':0}
414 |                     if x != (num_h-1):   # interior rows start at multiples of the stride
415 |                         init_h = gap_h*x
416 |                         end_h = gap_h*x + img_h
417 |                     elif x == (num_h-1):   # the last row is anchored to the stack edge instead of padding a tail
418 |                         init_h = whole_h - img_h
419 |                         end_h = whole_h
420 |
421 | if y != (num_w-1):
422 | init_w = gap_w*y
423 | end_w = gap_w*y + img_w
424 | elif y == (num_w-1):
425 | init_w = whole_w - img_w
426 | end_w = whole_w
427 |
428 | if z != (num_s-1):
429 | init_s = gap_s2*z
430 | end_s = gap_s2*z + img_s2
431 | elif z == (num_s-1):
432 | init_s = whole_s - img_s2
433 | end_s = whole_s
434 | single_coordinate['init_h'] = init_h
435 | single_coordinate['end_h'] = end_h
436 | single_coordinate['init_w'] = init_w
437 | single_coordinate['end_w'] = end_w
438 | single_coordinate['init_s'] = init_s
439 | single_coordinate['end_s'] = end_s
440 |
441 |                     if y == 0:   # first column: keep the leading border, trim half the overlap on the right
442 | single_coordinate['stack_start_w'] = y*gap_w
443 | single_coordinate['stack_end_w'] = y*gap_w+img_w-cut_w
444 | single_coordinate['patch_start_w'] = 0
445 | single_coordinate['patch_end_w'] = img_w-cut_w
446 |                     elif y == num_w-1:   # last column: anchored to the stack edge, trim half the overlap on the left
447 | single_coordinate['stack_start_w'] = whole_w-img_w+cut_w
448 | single_coordinate['stack_end_w'] = whole_w
449 | single_coordinate['patch_start_w'] = cut_w
450 | single_coordinate['patch_end_w'] = img_w
451 |                     else:   # interior column: trim half the overlap on both sides
452 | single_coordinate['stack_start_w'] = y*gap_w+cut_w
453 | single_coordinate['stack_end_w'] = y*gap_w+img_w-cut_w
454 | single_coordinate['patch_start_w'] = cut_w
455 | single_coordinate['patch_end_w'] = img_w-cut_w
456 |
457 | if x == 0:
458 | single_coordinate['stack_start_h'] = x*gap_h
459 | single_coordinate['stack_end_h'] = x*gap_h+img_h-cut_h
460 | single_coordinate['patch_start_h'] = 0
461 | single_coordinate['patch_end_h'] = img_h-cut_h
462 | elif x == num_h-1:
463 | single_coordinate['stack_start_h'] = whole_h-img_h+cut_h
464 | single_coordinate['stack_end_h'] = whole_h
465 | single_coordinate['patch_start_h'] = cut_h
466 | single_coordinate['patch_end_h'] = img_h
467 | else:
468 | single_coordinate['stack_start_h'] = x*gap_h+cut_h
469 | single_coordinate['stack_end_h'] = x*gap_h+img_h-cut_h
470 | single_coordinate['patch_start_h'] = cut_h
471 | single_coordinate['patch_end_h'] = img_h-cut_h
472 |
473 | if z == 0:
474 | single_coordinate['stack_start_s'] = z*gap_s2
475 | single_coordinate['stack_end_s'] = z*gap_s2+img_s2-cut_s
476 | single_coordinate['patch_start_s'] = 0
477 | single_coordinate['patch_end_s'] = img_s2-cut_s
478 | elif z == num_s-1:
479 | single_coordinate['stack_start_s'] = whole_s-img_s2+cut_s
480 | single_coordinate['stack_end_s'] = whole_s
481 | single_coordinate['patch_start_s'] = cut_s
482 | single_coordinate['patch_end_s'] = img_s2
483 | else:
484 | single_coordinate['stack_start_s'] = z*gap_s2+cut_s
485 | single_coordinate['stack_end_s'] = z*gap_s2+img_s2-cut_s
486 | single_coordinate['patch_start_s'] = cut_s
487 | single_coordinate['patch_end_s'] = img_s2-cut_s
488 |
489 | # noise_patch1 = noise_im[init_s:end_s,init_h:end_h,init_w:end_w]
490 | patch_name = args.datasets_folder+'_x'+str(x)+'_y'+str(y)+'_z'+str(z)
491 | # train_raw.append(noise_patch1.transpose(1,2,0))
492 | name_list.append(patch_name)
493 | # print(' single_coordinate -----> ',single_coordinate)
494 | coordinate_list[patch_name] = single_coordinate
495 | return name_list, noise_im, coordinate_list
496 |
497 |
498 | # Per-patch keys used when stitching the NoTail patches back together:
499 | #   stack_start_/stack_end_{w,h,s}: destination slice in the assembled output stack
500 | #   patch_start_/patch_end_{w,h,s}: corresponding slice taken from the network output patch
501 |
502 |
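503 |
504 | # --- Illustrative sketch (not part of the original DeepCAD code) -------------
505 | # A minimal example of how the NoTail coordinates above might be consumed when
506 | # the denoised patches are stitched back into a full stack. The helper name and
507 | # its arguments are hypothetical; the actual stitching is done in the test
508 | # script and may differ in detail.
509 | def _paste_patch_noTail_example(output_stack, denoised_patch, c):
510 |     # destination slices in the full output stack
511 |     s_out = (slice(int(c['stack_start_s']), int(c['stack_end_s'])),
512 |              slice(int(c['stack_start_h']), int(c['stack_end_h'])),
513 |              slice(int(c['stack_start_w']), int(c['stack_end_w'])))
514 |     # source slices in the network output patch (the overlap margins are cropped away)
515 |     s_patch = (slice(int(c['patch_start_s']), int(c['patch_end_s'])),
516 |                slice(int(c['patch_start_h']), int(c['patch_end_h'])),
517 |                slice(int(c['patch_start_w']), int(c['patch_end_w'])))
518 |     output_stack[s_out] = denoised_patch[s_patch]
519 |     return output_stack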
--------------------------------------------------------------------------------