├── .gitignore
├── LICENSE
├── MainWindow_3D_PP.py
├── README.md
├── bin
│   ├── bash_command.md
│   ├── generate_train_config.py
│   ├── mk_images
│   │   └── 20240520163736.png
│   ├── preprocess.py
│   ├── test_bash.py
│   └── train_bash.py
├── configs
│   ├── EMPIAR_10045_preprocess.py
│   └── SHREC_2021_preprocess.py
├── dataset
│   └── dataloader_DynamicLoad.py
├── main.py
├── model
│   ├── __init__.py
│   ├── model_loader.py
│   └── residual_unet_att.py
├── options
│   └── option.py
├── requirement.txt
├── test.py
├── train.py
├── tutorials
│   ├── A_tutorial_of_particlePicking_on_EMPIAR10045_dataset.md
│   ├── A_tutorial_of_particlePicking_on_SHREC2021_dataset.md
│   └── images
│       ├── EMPIAR_10045_GIF
│       │   ├── Inference.gif
│       │   ├── Preprocessing.gif
│       │   ├── Training.gif
│       │   └── Visualization.gif
│       ├── Inference.gif
│       ├── Preprocessing.gif
│       ├── Training.gif
│       └── Visualization.gif
└── utils
    ├── __init__.py
    ├── colors.py
    ├── coordFormatConvert.py
    ├── coord_gen.py
    ├── coordconv_torch.py
    ├── coords2labels.py
    ├── coords_to_relion4.py
    ├── loss.py
    ├── metrics.py
    ├── misc.py
    ├── normalization.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.DS_Store
2 | *.git
3 | *.idea
4 | bin/*.sh
5 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # **DeepETpicker**
2 |
3 | An open-source, deep-learning-based software package with a user-friendly interface for picking 3D particles rapidly and accurately from cryo-electron tomograms. Weak labels, a lightweight architecture and GPU-accelerated pooling operations significantly reduce both the annotation cost and the inference time, while a Gaussian-type mask and a customized architecture design greatly improve the picking accuracy.
4 |
5 | [Guole Liu, Tongxin Niu, Mengxuan Qiu, Yun Zhu, Fei Sun, and Ge Yang, “DeepETPicker: Fast and accurate 3D particle picking for cryo-electron tomography using weakly supervised deep learning,” Nature Communications, vol. 15, no. 1, pp. 2090, 2024.](https://www.nature.com/articles/s41467-024-46041-0)
6 |
7 |
8 | **Note**: DeepETPicker is a PyTorch implementation.
9 |
10 | ## **Setup**
11 |
12 | ### **Prerequisites**
13 |
14 | - Linux (Ubuntu 18.04.5 LTS; CentOS)
15 | - NVIDIA GPU
16 |
17 | ### **Installation**
18 |
19 | #### **Option 1: Using conda**
20 |
21 | The following instructions assume that `pip` and `anaconda` or `miniconda` are available. If you have an old `deepetpicker` environment installed, first remove it with:
22 |
23 | ```bash
24 | conda env remove --name deepetpicker
25 | ```
26 |
27 | The first step is to create a new conda virtual environment:
28 |
29 | ```bash
30 | conda create -n deepetpicker -c conda-forge python=3.8.3 -y
31 | ```
32 |
33 | Activate the environment:
34 |
35 | ```bash
36 | conda activate deepetpicker
37 | ```
38 |
39 | To download the code, please do:
40 | ```bash
41 | git clone https://github.com/cbmi-group/DeepETPicker
42 | cd DeepETPicker
43 | ```
44 |
45 | Next, install PyTorch and the related packages needed by DeepETPicker:
46 |
47 | ```bash
48 | conda install pytorch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 cudatoolkit=11.0 -c pytorch -y
49 |
50 | pip install -r requirement.txt
51 | ```
52 |
53 | - **If using the GUI**
54 |
55 | To use the GUI on Linux, you will need to install the following additional dependencies for Qt.
56 | 1. For `CentOS`, to install packages, please do:
57 | ```bash
58 | sudo yum install -y mesa-libGL libXext libSM libXrender fontconfig xcb-util-wm xcb-util-image xcb-util-keysyms xcb-util-renderutil libxkbcommon-x11
59 | ```
60 |
61 | 2. For `Ubuntu`, to install packages, please do:
62 | ```bash
63 | sudo apt-get install -y libgl1-mesa-glx libglib2.0-dev libsm6 libxrender1 libfontconfig1 libxcb-icccm4 libxcb-image0 libxcb-keysyms1 libxcb-render-util0 libxcb-shape0 libxcb-xinerama0 libxcb-xkb1 libxkbcommon-x11-dev libdbus-1-3
64 | ```
65 |
66 |
67 | To run DeepETPicker, please do:
68 |
69 | ```bash
70 | conda activate deepetpicker
71 | python PATH_TO_DEEPETPICKER/main.py
72 | ```
73 |
74 | Note: `PATH_TO_DEEPETPICKER` is the directory where the code is located.
75 | - **Non-GUI version**
76 |
77 | In addition to the `GUI version` of DeepETPicker, we also provide a `non-GUI version` for users familiar with Python and deep learning. It consists of four steps: `preprocessing`, `train config generation`, `training` and `testing`. A sample tutorial can be found in `bin/bash_command.md`.
78 |
79 | #### **Option 2: Using Docker**
80 |
81 | The following steps are required in order to run DeepETPicker:
82 | 1. Install [Docker](https://www.docker.com/)
83 |
84 | Note: the Docker engine version should be >= 19.03. The DeepETPicker Docker image is 7.21 GB; please ensure that there is enough disk space.
85 |
86 | 2. Install [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) for GPU support.
87 |
88 | 3. Download Docker image of DeepETPicker.
89 |
90 | ```bash
91 | docker pull docker.io/lgl603/deepetpicker:latest
92 | ```
93 |
94 | 4. Run the Docker image of DeepETPicker.
95 |
96 | ```bash
97 | docker run --gpus=all -itd \
98 | --restart=always \
99 | --shm-size=100G \
100 | -e DISPLAY=unix$DISPLAY \
101 | --name deepetpicker \
102 | -p 50022:22 \
103 | --mount type=bind,source='/host_path/to/data',target='/container_path/to/data' \
104 | lgl603/deepetpicker:latest
105 | ```
106 |
107 | - The option `--shm-size` sets the shared-memory size available to the Docker container.
108 | - The option `--mount` mounts a file or directory from the host machine into the Docker container, where `source='/host_path/to/data'` is a data directory that actually exists on the host machine and `target='/container_path/to/data'` is the path inside the container where that directory is mounted.
109 |
110 | **Note: `'/host_path/to/data'` should be made writable, e.g. by running `chmod -R 777 '/host_path/to/data'`, and should be replaced by a data directory that actually exists on the host machine. For convenience, `'/container_path/to/data'` can be set to the same path as `'/host_path/to/data'`.**
111 |
112 |
113 |
114 | 5. DeepETPicker can be used directly on this machine, and it can also be used from another machine on the same LAN.
115 | - Directly open DeepETPicker on this machine:
116 | ```bash
117 | ssh -X test@'ip_address' DeepETPicker
118 | # where the 'ip_address' of DeepETPicker container can be obtained as follows:
119 | docker inspect --format='{{.NetworkSettings.IPAddress}}' deepetpicker
120 | ```
121 |
122 |
123 |
124 | - Connect to this server remotely and open DeepETPicker software via a client machine:
125 | ```bash
126 | ssh -X -p 50022 test@ip DeepETPicker
127 | ```
128 | Here `ip` is the IP address of the server machine, and the password is `password`.
129 |
130 | `Installation time`: the DeepETPicker Docker image is 7.21 GB, and the installation time depends on your network speed. With a fast enough connection, setup completes within a few minutes.
131 |
132 | ## **Particle picking tutorial**
133 |
134 | Detailed tutorials for two sample datasets, [SHREC2021](https://github.com/cbmi-group/DeepETPicker/blob/main/tutorials/A_tutorial_of_particlePicking_on_SHREC2021_dataset.md) and [EMPIAR-10045](https://github.com/cbmi-group/DeepETPicker/blob/main/tutorials/A_tutorial_of_particlePicking_on_EMPIAR10045_dataset.md), are provided. The main steps of DeepETPicker include preprocessing, training, inference, and particle visualization.
135 |
136 | ### **Preprocessing**
137 | - Data preparation
138 |
139 | Before launching the graphical user interface, we recommend creating a single base folder to hold the inputs and outputs of DeepETPicker. Inside this base folder, create a subfolder (`raw_data`) to store the raw data. This `raw_data` folder should contain:
140 | - tomograms (with extension `.mrc` or `.rec`)
141 | - coordinate files with the same names as the tomograms except for the extension (`.csv`, `.coords`, or `.txt`; `.coords` is generally recommended)
142 |
143 |
144 |
145 | Here, we provide two sample datasets, EMPIAR-10045 and SHREC_2021, so that you can learn the DeepETPicker processing workflow better and faster. The sample datasets can be downloaded in one of two ways:
146 | - Baidu Netdisk link: [https://pan.baidu.com/s/1aijM4IgGSRMwBvBk5XbBmw](https://pan.baidu.com/s/1aijM4IgGSRMwBvBk5XbBmw); verification code: cbmi
147 | - Zenodo link: [https://zenodo.org/records/12512436](https://zenodo.org/records/12512436)
148 |
149 | - Data structure
150 |
151 | The data should be organized as follows:
152 | ```
153 | ├── /base/path
154 | │ ├── raw_data
155 | │ │ ├── tomo1.coords
156 | │ │ └── tomo1.mrc
157 | │ │ └── tomo2.coords
158 | │ │ └── tomo2.mrc
159 | │ │ └── tomo3.mrc
160 | │ │ └── tomo4.mrc
161 | │ │ └── ...
162 | ```
163 |
164 | For the data above, `tomo1.mrc` and `tomo2.mrc` can be used as train/val datasets, since they both have coordinate files (manual annotations). If a tomogram has no manual annotation (such as `tomo3.mrc`), it can only be used as a test dataset.
165 |
166 |
167 |
168 |
169 | - Input & Output
170 |
171 |
172 |

173 |
174 |
175 | Launch the graphical user interface of DeepETPicker. On the `Preprocessing` page, set the key parameters as follows:
176 | - `input`
177 | - Dataset name: e.g. SHREC_2021_preprocess
178 | - Base path: path to base folder
179 | - Coords path: path to raw_data folder
180 | - Coords format: .csv, .coords or .txt
181 | - Tomogram path: path to raw_data folder
182 | - Tomogram format: .mrc or .rec
183 | - Number of classes: the number of particle classes; multiple classes of macromolecules can also be localized separately
184 | - `Output`
185 | - Label diameter (in voxels): the diameter of the generated weak label, which is usually smaller than the average particle diameter. Empirically, set it as large as possible while keeping it smaller than the real diameter.
186 | - Ocp diameter (in voxels): the real diameter of the particles. Empirically, to obtain good picking results, we recommend binning the data so that the particle diameter falls in the range of 20~30 voxels. For multiple particle classes, separate the diameters with commas.
187 | - Configs: if you click `Save configs`, this is the path of the file that stores all the parameters filled in on this page
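The Gaussian-type weak label described above can be sketched roughly as follows. This is an illustrative approximation, not DeepETPicker's exact kernel; `sigma_ratio` is a made-up parameter used only for this sketch:

```python
import numpy as np

def gaussian_label(diameter, sigma_ratio=0.25):
    """Build a Gaussian-type spherical weak label of the given diameter (in voxels).

    Illustrative sketch: intensity falls off with distance from the centre
    and is zeroed outside the labelled sphere.
    """
    r = diameter // 2
    zz, yy, xx = np.mgrid[-r:r + 1, -r:r + 1, -r:r + 1]
    dist2 = zz ** 2 + yy ** 2 + xx ** 2
    sigma = max(diameter * sigma_ratio, 1e-6)
    mask = np.exp(-dist2 / (2 * sigma ** 2))
    mask[dist2 > r ** 2] = 0.0  # restrict support to the labelled sphere
    return mask

label = gaussian_label(11)  # e.g. the EMPIAR-10045 label diameter from configs/
print(label.shape)          # (11, 11, 11); peak value 1.0 at the centre
```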
188 |
189 |
190 | ### **Training of DeepETPicker**
191 |
192 |
193 |
194 | Note: before `Training of DeepETPicker`, please run `Preprocessing` first to ensure that the basic parameters required for training are available.
195 |
196 |
197 |

198 |
199 |
200 | In practice, the default parameters usually give good enough results.
201 |
202 | *Training parameter description:*
203 |
204 | - Dataset name: e.g. SHREC_2021_train
205 | - Dataset list: shows the list of train/val tomograms. The first column is the particle number, the second the tomogram name, and the third the tomogram id. If you have n tomograms, the ids are {0, 1, 2, ..., n-1}.
206 | - Train dataset ids: tomograms used for training. Click `Dataset list` first to obtain the dataset ids. One or multiple tomograms can be used for training, but make sure the `train dataset ids` are selected from {0, 1, 2, ..., n-1}, where n is the total number of tomograms shown in `Dataset list`. Two ways to specify the ids are provided:
207 | - 0, 2, ...: separate individual tomogram ids with commas.
208 | - 0-m: selects the ids {0, 1, 2, ..., m} (an inclusive range). Note: this form can only be used for tomograms with continuous ids.
209 | - Val dataset ids: tomograms used for validation. Click `Dataset list` first to obtain the dataset ids. Note: only one tomogram can be selected as the validation dataset.
210 | - Number of classes: particle classes you want to pick
211 | - Batch size: the number of samples processed before the model is updated. It is limited by your GPU memory; reducing it may help if you encounter an out-of-memory error.
212 | - Patch size: the size of the subtomograms. It must be a multiple of 8; a value of at least 64 is recommended, and the default is 72.
213 | - Padding size: a hyperparameter of the overlap-tile strategy. Usually between 6 and 12; the default value is 12.
214 | - Learning rate: the step size at each iteration while moving toward a minimum of a loss function.
215 | - Max epoch: total training epochs. The default value 60 is usually sufficient.
216 | - GPU id: the GPUs used for training, e.g. 0,1,2,3 denotes using GPUs 0-3. Run the command `nvidia-smi` to list the available GPUs.
217 | - Save Configs: saves the configs listed above. The saved configuration contains all the parameters filled in on this page and can be loaded directly via *`Load configs`* next time instead of filling them in again.
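The two dataset-id formats above can be expanded with a small helper that mirrors the parsing in `bin/train_bash.py`:

```python
def parse_set_ids(ids_str):
    """Expand a dataset-id string such as '0,2' or '0-3' into a list of ids.

    Mirrors the parsing in bin/train_bash.py: comma-separated entries,
    where 'a-b' expands to the inclusive range a..b.
    """
    ids = []
    for item in ids_str.split(','):
        if '-' in item:
            lo, hi = (int(i) for i in item.split('-'))
            ids.extend(range(lo, hi + 1))
        else:
            ids.append(int(item))
    return ids

print(parse_set_ids('0,2'))   # [0, 2]
print(parse_set_ids('0-3'))   # [0, 1, 2, 3]
```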
218 |
219 | ### **Inference of DeepETPicker**
220 |
221 |
222 |

223 |
224 |
225 | In practice, the default parameters usually give good enough results.
226 |
227 | *Inference parameter description:*
228 |
229 | - Train Configs: path to the configuration file saved in the `Training` step
230 | - Network weights: path to the model generated in the `Training` step
231 | - Patch size & Pad_size: in this stage the tomogram is scanned with a patch size of N and a stride S = N - 2*Pad_size.
232 | - GPU id: the GPUs used for inference, e.g. 0,1,2,3 denotes using GPUs 0-3. Run the command `nvidia-smi` to list the available GPUs.
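For the default values (patch size 72, padding size 12), the scanning stride works out as follows:

```python
def scan_stride(patch_size, pad_size):
    """Stride S = N - 2 * pad_size used by the overlap-tile scanning strategy:
    neighbouring patches overlap by 2 * pad_size voxels on each axis."""
    return patch_size - 2 * pad_size

# Defaults from the Training page: patch size 72, padding size 12
print(scan_stride(72, 12))  # 48
```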
233 |
234 | *Coord format conversion*
235 |
236 | The predicted coordinate files with extension `*.coords` have four columns: `class_id, x, y, z`. To facilitate subsequent subtomogram averaging, format conversion of the coordinate files is provided.
237 |
238 | - Coords path: the path of the coordinate files predicted by the trained DeepETPicker model.
239 | - Output format: conversion to three formats is provided, including `*.box` for EMAN2, `*.star` for RELION, and `*.coords` for RELION.
240 |
241 |
242 | Note: after converting the coordinates to `*.coords`, the coordinates in RELION 4 format can be obtained directly using the script at `https://github.com/cbmi-group/DeepETPicker/blob/main/utils/coords_to_relion4.py`.
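For reference, a `*.coords` file can be loaded with NumPy as follows. The coordinate values below are made-up examples, not real picks:

```python
import io
import numpy as np

# A *.coords file is plain text with four columns: class_id, x, y, z.
# Here we parse an inline example; in practice pass the path of a file
# predicted by DeepETPicker (found under the 'runs' sub-folder of base_path).
example = """1 120 344 87
1 455 210 90
2 33 87 112
"""
coords = np.loadtxt(io.StringIO(example))
class_ids = coords[:, 0].astype(int)
xyz = coords[:, 1:4]
print(len(coords), sorted(set(class_ids.tolist())))  # 3 [1, 2]
```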
243 |
244 |
245 | ### **Particle visualization and manual picking**
246 |
247 |
248 |

249 |
250 |
251 | - *Showing Tomogram*
252 |
253 | You can click the `Load tomogram` button on this page to load the tomogram.
254 |
255 | - *Showing Labels*
256 |
257 | After loading the coordinates file by clicking `Load labels`, you can click `Show result` to visualize the labels. The labels' diameter and width can also be tuned in the GUI.
258 |
259 | - *Parameter Adjustment*
260 |
261 | To improve the visibility of particles, Gaussian filtering and histogram equalization are provided:
262 | - Filter: choosing Gaussian applies a Gaussian filter to the displayed tomogram; kernel_size and sigma (standard deviation) can be tuned to adjust the visual effect
263 | - Contrast: choosing hist-equ performs histogram equalization
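As a rough illustration of these two options (not the GUI's exact implementation), applied to a synthetic slice:

```python
import numpy as np
from scipy.ndimage import gaussian_filter

# Synthetic 2D slice standing in for a tomogram slice
slice2d = np.random.rand(256, 256).astype(np.float32)

# Gaussian filtering: sigma plays the same role as in the GUI
smoothed = gaussian_filter(slice2d, sigma=2.0)

# Histogram equalization via the cumulative distribution of intensities
hist, bins = np.histogram(slice2d, bins=256)
cdf = hist.cumsum() / hist.sum()
equalized = np.interp(slice2d, bins[:-1], cdf)

print(smoothed.shape, equalized.shape)
```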
264 |
265 |
266 | - *Manual picking*
267 |
268 | After loading the tomogram and pressing `enable`, you can pick particles manually by double-clicking the left mouse button on a slice. To delete a mislabeled particle, simply right-click on it. You can specify a different category id per class. Always remember to save the results when you finish.
269 |
270 | - *Position Slider*
271 |
272 | You can scan through the volume in the x, y and z directions by changing their values. For z-axis scanning, the Up/Down arrow keys can be used as shortcuts.
273 |
274 | # Troubleshooting
275 |
276 | If you encounter any problems during the installation or use of DeepETPicker, please contact us by email at [guole.liu@ia.ac.cn](mailto:guole.liu@ia.ac.cn). We will help you as soon as possible.
277 |
278 | # Citation
279 |
280 | If you use this code for your research, please cite our paper [DeepETPicker: Fast and accurate 3D particle picking for cryo-electron tomography using weakly supervised deep learning](https://www.nature.com/articles/s41467-024-46041-0).
281 |
282 | ```
283 | @article{DeepETPicker,
284 | title={DeepETPicker: Fast and accurate 3D particle picking for cryo-electron tomography using weakly supervised deep learning},
285 | author={Liu, Guole and Niu, Tongxin and Qiu, Mengxuan and Zhu, Yun and Sun, Fei and Yang, Ge},
286 | journal={Nature Communications},
287 | year={2024}
288 | }
289 | ```
290 |
--------------------------------------------------------------------------------
/bin/bash_command.md:
--------------------------------------------------------------------------------
1 |
2 | In addition to the GUI version of DeepETPicker, we also provide a non-GUI version for users familiar with Python and deep learning. It consists of four steps: `preprocessing`, `train config generation`, `training` and `testing`.
3 |
4 | Note: If you are not familiar with python and deep learning, the GUI version is recommended.
5 |
6 | First, enter the bash command directory:
7 | ```bash
8 | cd PATH_TO_DEEPETPICKER/bin
9 | ```
10 |
11 | where `PATH_TO_DEEPETPICKER` is the directory where the code is located.
12 |
13 | ## Preprocessing
14 |
15 | ```bash
16 | python preprocess.py \
17 | --pre_configs 'PATH_TO_Preprocess_CONFIG'
18 | ```
19 |
20 | where `pre_configs` is the preprocessing configuration file. We have provided a sample for the EMPIAR-10045 dataset in `configs/EMPIAR_10045_preprocess.py`. The items of the configuration file are the same as those in the file generated by the `Preprocessing` panel of the GUI version. More details can be found in the section `Preprocessing` of https://github.com/cbmi-group/DeepETPicker/tree/main.
21 |
22 | Note: `pre_configs` can also directly load the configuration file generated by the `Preprocessing` panel of the GUI version of DeepETPicker.
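For reference, these configuration files store a JSON dict behind a `pre_config=` prefix (as in `configs/EMPIAR_10045_preprocess.py`); a minimal loader sketch, using a throwaway demo file instead of a real config:

```python
import json
import os

def load_pre_config(path):
    """Load a DeepETPicker preprocessing config file: a JSON dict
    stored behind a 'pre_config=' prefix."""
    with open(path) as f:
        return json.loads(f.read().split('=', 1)[1])

# Round-trip demo with a throwaway file standing in for a real config
demo = {"dset_name": "EMPIAR_10045_preprocess", "num_cls": 1}
with open('demo_preprocess.py', 'w') as f:
    f.write('pre_config=' + json.dumps(demo))

cfg = load_pre_config('demo_preprocess.py')
os.remove('demo_preprocess.py')  # clean up the demo file
print(cfg['dset_name'])  # EMPIAR_10045_preprocess
```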
23 |
24 | ## Generate configuration file for training and testing
25 |
26 | ```bash
27 | python generate_train_config.py \
28 | --pre_configs 'PATH_TO_Preprocess_CONFIG' \
29 | --dset_name 'Train_Config_Name' \
30 | --cfg_save_path 'Save_Path_for_Config_Name' \
31 | --train_set_ids '0' \
32 | --val_set_ids '0' \
33 | --batch_size 8 \
34 | --block_size 72 \
35 | --pad_size 12 \
36 | --learning_rate 0.001 \
37 | --max_epoch 60 \
38 | --threshold 0.5 \
39 | --gpu_id 0 1
40 | ```
41 |
42 | where `dset_name` and `cfg_save_path` are the name and save path of the training configuration file, respectively. The other parameters are the same as the inputs of the `Training` panel of the GUI version. More details can be found in the section `Training of DeepETPicker` of https://github.com/cbmi-group/DeepETPicker/tree/main.
43 |
44 | ## Training
45 |
46 | ```bash
47 | python train_bash.py \
48 | --train_configs 'Train_Config_Name'
49 | ```
50 |
51 | where `train_configs` is the training configuration file generated by `generate_train_config.py`.
52 |
53 | Note: `train_configs` can also directly load the configuration file generated by the `Training` panel of the GUI version of DeepETPicker.
54 |
55 | The checkpoints are saved in the sub-folder `runs` of `base_path`.
56 |
57 | ## Testing
58 |
59 | ```bash
60 | python test_bash.py \
61 | --train_configs 'Train_Config_Name' \
62 | --checkpoints 'Checkpoint_Path'
63 | ```
64 |
65 | The predicted results are saved in the sub-folder `runs` of `base_path`.
66 |
67 | 
68 |
69 |
70 |
71 |
--------------------------------------------------------------------------------
/bin/generate_train_config.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import json
4 | from os.path import dirname, abspath
5 | import importlib
6 |
7 | DeepETPickerHome = dirname(abspath(__file__))
8 | DeepETPickerHome = os.path.split(DeepETPickerHome)[0]
9 | sys.path.append(DeepETPickerHome)
10 | sys.path.append(os.path.split(DeepETPickerHome)[0])
11 | option = importlib.import_module(f".options.option", package=os.path.split(DeepETPickerHome)[1])
12 |
13 | if __name__ == "__main__":
14 | options = option.BaseOptions()
15 | args = options.gather_options()
16 |
17 | with open(args.pre_configs, 'r') as f:
18 | train_config = json.loads(f.read().split('=', 1)[1])  # parse the JSON after the 'pre_config=' prefix; str.lstrip strips characters, not a prefix
19 | train_config['dset_name'] = args.dset_name
20 | train_config['coord_path'] = train_config['base_path'] + '/coords'
21 | train_config['tomo_path'] = train_config['base_path'] + '/data_std'
22 | train_config['label_name'] = train_config['label_type'] + f"{train_config['label_diameter']}"
23 | train_config['label_path'] = train_config['base_path'] + f"/{train_config['label_name']}"
24 | train_config['ocp_name'] = 'data_ocp'
25 | train_config['ocp_path'] = train_config['base_path'] + '/data_ocp'
26 | train_config['model_name'] = 'ResUNet'
27 | train_config['train_set_ids'] = args.train_set_ids
28 | train_config['val_set_ids'] = args.val_set_ids
29 | train_config['batch_size'] = args.batch_size
30 | train_config['patch_size'] = args.block_size
31 | train_config['padding_size'] = args.pad_size[0]
32 | train_config['lr'] = args.learning_rate
33 | train_config['max_epochs'] = args.max_epoch
34 | train_config['seg_thresh'] = args.threshold
35 | train_config['gpu_ids'] = ','.join([str(i) for i in args.gpu_id])
36 | print(train_config['gpu_ids'])
37 |
38 |
39 | with open(f"{args.cfg_save_path}/{args.dset_name}.py", 'w') as f:
40 | f.write("train_configs=")
41 | json.dump(train_config, f, separators=(',\n' + ' ' * len('train_configs={'), ': '))
42 |
43 |
44 |
45 |
46 |
--------------------------------------------------------------------------------
/bin/mk_images/20240520163736.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cbmi-group/DeepETPicker/0b3f8cb298128805a23cfd2909f81010006d2390/bin/mk_images/20240520163736.png
--------------------------------------------------------------------------------
/bin/preprocess.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | from os.path import dirname, abspath
4 | import importlib
5 | import json
6 |
7 | DeepETPickerHome = dirname(abspath(__file__))
8 | DeepETPickerHome = os.path.split(DeepETPickerHome)[0]
9 | sys.path.append(DeepETPickerHome)
10 | sys.path.append(os.path.split(DeepETPickerHome)[0])
11 | coords2labels = importlib.import_module(".utils.coords2labels", package=os.path.split(DeepETPickerHome)[1])
12 | coord_gen = importlib.import_module(f".utils.coord_gen", package=os.path.split(DeepETPickerHome)[1])
13 | norm = importlib.import_module(f".utils.normalization", package=os.path.split(DeepETPickerHome)[1])
14 | option = importlib.import_module(f".options.option", package=os.path.split(DeepETPickerHome)[1])
15 |
16 | if __name__ == "__main__":
17 | options = option.BaseOptions()
18 | args = options.gather_options()
19 |
20 | with open(args.pre_configs, 'r') as f:
21 | pre_config = json.loads(f.read().split('=', 1)[1])  # parse the JSON after the 'pre_config=' prefix; str.lstrip strips characters, not a prefix
22 |
23 | # initial coords
24 | coord_gen.coords_gen_show(args=(pre_config["coord_path"],
25 | pre_config["coord_format"],
26 | pre_config["base_path"],
27 | None,
28 | )
29 | )
30 |
31 | # normalization
32 | norm.norm_show(args=(pre_config["tomo_path"],
33 | pre_config["tomo_format"],
34 | pre_config["base_path"],
35 | pre_config["norm_type"],
36 | None,
37 | )
38 | )
39 |
40 | # generate labels based on coords
41 | coords2labels.label_gen_show(args=(pre_config["base_path"],
42 | pre_config["coord_path"],
43 | pre_config["coord_format"],
44 | pre_config["tomo_path"],
45 | pre_config["tomo_format"],
46 | pre_config["num_cls"],
47 | pre_config["label_type"],
48 | pre_config["label_diameter"],
49 | None,
50 | )
51 | )
52 |
53 | # generate occupancy based on coords
54 | coords2labels.label_gen_show(args=(pre_config["base_path"],
55 | pre_config["coord_path"],
56 | pre_config["coord_format"],
57 | pre_config["tomo_path"],
58 | pre_config["tomo_format"],
59 | pre_config["num_cls"],
60 | 'data_ocp',
61 | pre_config["ocp_diameter"],
62 | None,
63 | )
64 | )
65 |
--------------------------------------------------------------------------------
/bin/test_bash.py:
--------------------------------------------------------------------------------
1 | import json
2 | import sys
3 | import os
4 | import importlib
5 | from os.path import dirname, abspath
6 | import numpy as np
7 | DeepETPickerHome = dirname(abspath(__file__))
8 | DeepETPickerHome = os.path.split(DeepETPickerHome)[0]
9 | sys.path.append(DeepETPickerHome)
10 | sys.path.append(os.path.split(DeepETPickerHome)[0])
11 | test = importlib.import_module(".test", package=os.path.split(DeepETPickerHome)[1])
12 | option = importlib.import_module(f".options.option", package=os.path.split(DeepETPickerHome)[1])
13 |
14 | if __name__ == '__main__':
15 | options = option.BaseOptions()
16 | args = options.gather_options()
17 |
18 | # config
19 | with open(args.train_configs, 'r') as f:
20 | cfg = json.loads(f.read().split('=', 1)[1])  # parse the JSON after the 'train_configs=' prefix; str.lstrip strips characters, not a prefix
21 |
22 | # parameters
23 | args.use_bg = True
24 | args.use_IP = True
25 | args.use_coord = True
26 | args.test_use_pad = True
27 | args.use_seg = True
28 | args.meanPool_NMS = True
29 | args.f_maps = [24, 48, 72, 108]
30 | args.num_classes = cfg['num_cls']
31 | train_cls_num = cfg['num_cls']
32 | if args.num_classes == 1:
33 | args.use_sigmoid = True
34 | args.use_softmax = False
35 | else:
36 | train_cls_num = train_cls_num + 1
37 | args.use_sigmoid = False
38 | args.use_softmax = True
39 | args.batch_size = cfg['batch_size']
40 | args.block_size = cfg['patch_size']
41 | args.val_batch_size = args.batch_size
42 | args.val_block_size = args.block_size
43 | args.pad_size = [cfg['padding_size']]
44 | args.learning_rate = cfg['lr']
45 | args.max_epoch = cfg['max_epochs']
46 | args.threshold = cfg['seg_thresh']
47 | args.gpu_id = [int(i) for i in cfg['gpu_ids'].split(',')]
48 | args.test_mode = 'test_only'
49 | args.out_name = 'PredictedLabels'
50 | args.de_duplication = True
51 | args.de_dup_fmt = 'fmt4'
52 | args.mini_dist = sorted([int(i) // 2 + 1 for i in cfg['ocp_diameter'].split(',')])[0]
53 | args.data_split = [0, 1, 0, 1, 0, 1]
54 | args.configs = args.train_configs
55 | args.num_classes = train_cls_num
56 |
57 | # test_idxs
58 | dset_list = np.array(
59 | [i[:-(len(i.split('.')[-1]) + 1)] for i in os.listdir(cfg['tomo_path']) if cfg['tomo_format'] in i])
60 | dset_num = dset_list.shape[0]
61 | num_name = np.concatenate([np.arange(dset_num).reshape(-1, 1), dset_list.reshape(-1, 1)], axis=1)
62 | np.savetxt(os.path.join(cfg['tomo_path'], 'num_name.csv'),
63 | num_name,
64 | delimiter='\t',
65 | fmt='%s',
66 | newline='\n')
67 |
68 | # tomo_list = [i for i in os.listdir(cfg[f"{cfg['base_path']}/data_std"]) if cfg['tomo_format'] in i]
69 | tomo_list = np.loadtxt(f"{cfg['base_path']}/data_std/num_name.csv",
70 | delimiter='\t',
71 | dtype=str)
72 | args.test_idxs = np.arange(len(tomo_list))
73 |
74 | for k, v in sorted(vars(args).items()):
75 | print(k, '=', v)
76 |
77 | # Testing
78 | test.test_func(args, stdout=None)
--------------------------------------------------------------------------------
/bin/train_bash.py:
--------------------------------------------------------------------------------
1 | import json
2 | import sys
3 | import os
4 | from os.path import dirname, abspath
5 | import importlib
6 | import numpy as np
7 |
8 | DeepETPickerHome = dirname(abspath(__file__))
9 | DeepETPickerHome = os.path.split(DeepETPickerHome)[0]
10 | sys.path.append(DeepETPickerHome)
11 | sys.path.append(os.path.split(DeepETPickerHome)[0])
12 | train = importlib.import_module(".train", package=os.path.split(DeepETPickerHome)[1])
13 | option = importlib.import_module(f".options.option", package=os.path.split(DeepETPickerHome)[1])
14 |
15 |
16 | if __name__ == '__main__':
17 | options = option.BaseOptions()
18 | args = options.gather_options()
19 |
20 | # config
21 | with open(args.train_configs, 'r') as f:
22 | cfg = json.loads(f.read().split('=', 1)[1])  # parse the JSON after the 'train_configs=' prefix; str.lstrip strips characters, not a prefix
23 |
24 | # parameters
25 | args.use_bg = True
26 | args.use_IP = True
27 | args.use_coord = True
28 | args.test_use_pad = True
29 | args.meanPool_NMS = True
30 | args.f_maps = [24, 48, 72, 108]
31 | args.num_classes = cfg['num_cls']
32 | train_cls_num = cfg['num_cls']
33 | if args.num_classes == 1:
34 | args.use_sigmoid = True
35 | args.use_softmax = False
36 | else:
37 | train_cls_num = train_cls_num + 1
38 | args.use_sigmoid = False
39 | args.use_softmax = True
40 | args.batch_size = cfg['batch_size']
41 | args.block_size = cfg['patch_size']
42 | args.val_batch_size = args.batch_size
43 | args.val_block_size = args.block_size
44 | args.pad_size = [cfg['padding_size']]
45 | args.learning_rate = cfg['lr']
46 | args.max_epoch = cfg['max_epochs']
47 | args.threshold = cfg['seg_thresh']
48 | args.gpu_id = [int(i) for i in cfg['gpu_ids'].split(',')]
49 | args.configs = args.train_configs
50 | args.test_mode = 'val'
51 | args.train_set_ids = cfg['train_set_ids']
52 | args.val_set_ids = cfg['val_set_ids']
53 | args.num_classes = train_cls_num
54 |
55 | train_list = []
56 | for item in args.train_set_ids.split(','):
57 | if '-' in item:
58 | tmp = [int(i) for i in item.split('-')]
59 | train_list.extend(np.arange(tmp[0], tmp[1] + 1).tolist())
60 | else:
61 | train_list.append(int(item))
62 |
63 | val_list = []
64 | for item in args.val_set_ids.split(','):
65 | if '-' in item:
66 | tmp = [int(i) for i in item.split('-')]
67 | val_list.extend(np.arange(tmp[0], tmp[1] + 1).tolist())
68 | else:
69 | val_list.append(int(item))
70 | val_first = len(train_list) if val_list[0] not in train_list else len(train_list) - 1
71 | args.data_split = [0, len(train_list), # train
72 | val_first, val_first + 1, # val
73 | val_first, val_first + 1] # test_val
74 |
75 | for k, v in sorted(vars(args).items()):
76 | print(k, '=', v)
77 |
78 | # Training
79 | train.train_func(args, stdout=None)
--------------------------------------------------------------------------------
/configs/EMPIAR_10045_preprocess.py:
--------------------------------------------------------------------------------
1 | pre_config={"dset_name": "EMPIAR_10045_preprocess",
2 | "base_path": "/mnt/data1/ET/DeepETPicker_test/SampleDatasets/EMPIAR_10045",
3 | "coord_path": "/mnt/data1/ET/DeepETPicker_test/SampleDatasets/EMPIAR_10045/raw_data",
4 | "coord_format": ".coords",
5 | "tomo_path": "/mnt/data1/ET/DeepETPicker_test/SampleDatasets/EMPIAR_10045/raw_data",
6 | "tomo_format": ".mrc",
7 | "num_cls": 1,
8 | "label_type": "gaussian",
9 | "label_diameter": 11,
10 | "ocp_type": "sphere",
11 | "ocp_diameter": "23",
12 | "norm_type": "standardization"}
--------------------------------------------------------------------------------
/configs/SHREC_2021_preprocess.py:
--------------------------------------------------------------------------------
1 | pre_config={"dset_name": "SHREC_2021_preprocess",
2 | "base_path": "/mnt/data1/ET/DeepETPicker_test/SampleDatasets/SHREC_2021",
3 | "coord_path": "/mnt/data1/ET/DeepETPicker_test/SampleDatasets/SHREC_2021/raw_data",
4 | "coord_format": ".coords",
5 | "tomo_path": "/mnt/data1/ET/DeepETPicker_test/SampleDatasets/SHREC_2021/raw_data",
6 | "tomo_format": ".mrc",
7 | "num_cls": 14,
8 | "label_type": "gaussian",
9 | "label_diameter": 9,
10 | "ocp_type": "sphere",
11 | "ocp_diameter": "7,7,7,7,7,9,9,9,11,11,13,13,13,17",
12 | "norm_type": "standardization"}
--------------------------------------------------------------------------------
/dataset/dataloader_DynamicLoad.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch.utils.data as data
3 | import numpy as np
4 | import mrcfile
5 | import pandas as pd
6 | import torch
7 | import warnings
8 | from batchgenerators.transforms.spatial_transforms import SpatialTransform_2, MirrorTransform
9 | from torch.utils.data import DataLoader
10 |
11 |
12 | class Dataset_ClsBased(data.Dataset):
13 | def __init__(self,
14 | mode='train',
15 | block_size=72,
16 | num_class=1,
17 | random_num=0,
18 | use_bg=True,
19 | data_split=[7, 7, 7],
20 | test_use_pad=False,
21 | pad_size=18,
22 | use_paf=False,
23 | cfg=None,
24 | args=None):
25 |
26 | self.args = args
27 | self.mode = mode
28 |         # use_CL: CL = contrastive learning
29 | self.use_CL = args.use_CL
30 | if args.use_CL:
31 | self.radius = 0
32 | else:
33 | self.radius = block_size // 5
34 | self.use_bg = use_bg
35 | self.use_paf = use_paf
36 | self.use_CL_DA = args.use_CL_DA
37 | self.use_bg_part = args.use_bg_part
38 | self.use_ice_part = args.use_ice_part
39 | self.Sel_Referance = args.Sel_Referance
40 |
41 | pad_size = pad_size[0] if isinstance(pad_size, list) else pad_size
42 | base_dir = cfg['base_path']
43 |         label_name = cfg['label_name']
44 | coord_format = cfg['coord_format']
45 |         tomo_format = cfg['tomo_format']
46 | norm_type = cfg['norm_type']
47 |
48 | base_dir = base_dir[0] if isinstance(base_dir, tuple) else base_dir
49 | label_name = label_name[0] if isinstance(label_name, tuple) else label_name
50 | coord_format = coord_format[0] if isinstance(coord_format, tuple) else coord_format
51 | tomo_format = tomo_format[0] if isinstance(tomo_format, tuple) else tomo_format
52 | norm_type = norm_type[0] if isinstance(norm_type, tuple) else norm_type
53 |
54 | if 'label_path' not in cfg.keys():
55 | label_path = os.path.join(base_dir, label_name)
56 | else:
57 | label_path = cfg['label_path']
58 | label_path = label_path[0] if isinstance(label_path, tuple) else label_path
59 |
60 | if 'coord_path' not in cfg.keys():
61 | coord_path = os.path.join(base_dir, 'coords')
62 | else:
63 | coord_path = cfg['coord_path']
64 | coord_path = coord_path[0] if isinstance(coord_path, tuple) else coord_path
65 |
66 | if 'tomo_path' not in cfg.keys():
67 | if norm_type == 'standardization':
68 | tomo_path = base_dir + '/data_std'
69 | elif norm_type == 'normalization':
70 | tomo_path = base_dir + '/data_norm'
71 | else:
72 | tomo_path = cfg['tomo_path']
73 | tomo_path = tomo_path[0] if isinstance(tomo_path, tuple) else tomo_path
74 |
75 | ocp_name = cfg["ocp_name"]
76 |
77 | if 'ocp_path' not in cfg.keys():
78 | ocp_path = os.path.join(base_dir, ocp_name)
79 | else:
80 | ocp_path = cfg['ocp_path']
81 | ocp_path = ocp_path[0] if isinstance(ocp_path, tuple) else ocp_path
82 |
83 | print('*' * 100)
84 | print('num_name:', os.path.join(coord_path, 'num_name.csv'))
85 | print(f'base_path:{base_dir}')
86 | print(f"coord_path:{coord_path}")
87 | print(f"tomo_path:{tomo_path}")
88 | print(f"label_path:{label_path}")
89 | print(f"ocp_path:{ocp_path}")
90 | if self.use_paf:
91 | print(f"paf_path:{cfg['paf_path']}")
92 | print(f"label_name:{label_name}")
93 | print(f"coord_format:{coord_format}")
94 | print(f"tomo_format:{tomo_format}")
95 | print(f"norm_type:{norm_type}")
96 | print(f"ocp_name:{ocp_name}")
97 | print('*' * 100)
98 |
99 | if self.mode == 'test_only':
100 | num_name = pd.read_csv(os.path.join(tomo_path, 'num_name.csv'), sep='\t', header=None)
101 | else:
102 | num_name = pd.read_csv(os.path.join(coord_path, 'num_name.csv'), sep='\t', header=None)
103 |
104 | dir_names = num_name.iloc[:, 1].to_numpy().tolist()
105 | print(num_name)
106 | # print(dir_names)
107 |
108 | if self.mode == 'train':
109 | self.data_range = np.arange(data_split[0], data_split[1])
110 | elif self.mode == 'val':
111 | self.data_range = np.arange(data_split[2], data_split[3])
112 | else: # test or test_val or val_v1
113 | self.data_range = np.arange(data_split[4], data_split[5])
114 | print(f"data_range:{self.data_range}")
115 |
116 | # print(f'data_range:{self.data_range}')
117 | self.shift = block_size // 2 # bigger than self.radius to cover full particle
118 | self.num_class = num_class
119 |
120 | self.ground_truth_volume = []
121 | self.class_mask = []
122 | self.location = []
123 | self.origin = []
124 | self.label = []
125 |
126 |         # initialize data augmentation
127 | if self.use_bg and self.mode == 'train':
128 | patch_size = [block_size] * 3
129 | self.st = SpatialTransform_2(
130 | patch_size, [i // 2 for i in patch_size],
131 | do_elastic_deform=True, deformation_scale=(0, 0.05),
132 | do_rotation=True,
133 | angle_x=(- 15 / 360. * 2 * np.pi, 15 / 360. * 2 * np.pi),
134 | angle_y=(- 15 / 360. * 2 * np.pi, 15 / 360. * 2 * np.pi),
135 | angle_z=(- 15 / 360. * 2 * np.pi, 15 / 360. * 2 * np.pi),
136 | do_scale=True, scale=(0.95, 1.05),
137 | border_mode_data='constant', border_cval_data=0,
138 | border_mode_seg='constant', border_cval_seg=0,
139 | order_seg=0, order_data=3,
140 | random_crop=True,
141 | p_el_per_sample=0.1, p_rot_per_sample=0.1, p_scale_per_sample=0.1
142 | )
143 |
144 | self.mt = MirrorTransform(axes=(0, 1, 2))
145 | # to avoid mrcfile warnings
146 | warnings.simplefilter('ignore')
147 |
148 | if self.mode == 'train' or self.mode == 'test_val' or self.mode == 'val':
149 | self.position = [
150 | pd.read_csv(os.path.join(coord_path, dir_names[i] + coord_format),
151 | sep='\t', header=None).to_numpy() for i in self.data_range]
152 |
153 | # load Tomo
154 | if self.mode == 'test' or self.mode == 'test_val' or self.mode == 'val_v1' \
155 | or self.mode == 'test_only':
156 | self.tomo_shape = []
157 | for idx in self.data_range:
158 | if not args.input_cat:
159 | with mrcfile.open(os.path.join(tomo_path, dir_names[idx] + tomo_format),
160 | permissive=True) as gm:
161 | shape = gm.data.shape
162 | self.data_shape = gm.data.shape
163 | shape_pad = [i + 2 * pad_size for i in shape]
164 |                         temp = np.zeros(shape_pad, dtype=np.float32)
168 | temp[pad_size:shape_pad[0] - pad_size,
169 | pad_size:shape_pad[1] - pad_size,
170 | pad_size:shape_pad[2] - pad_size] = gm.data
171 | self.origin.append(temp)
172 | self.tomo_shape.append(temp.shape)
173 | else:
174 | for idx, p_suffix in enumerate(args.input_cat_items):
175 | p_suffix = p_suffix.rstrip(',')
176 | p_suffix = '' if p_suffix == 'None' else p_suffix
177 | with mrcfile.open(
178 | os.path.join(tomo_path + p_suffix, dir_names[self.data_range[0]] + tomo_format), permissive=True) as tmp:
179 | if idx == 0:
180 | gm = np.array(tmp.data)[None, ...]
181 | else:
182 | gm = np.concatenate([gm, np.array(tmp.data)[None, ...]], axis=0)
183 |
184 | shape = gm.shape
185 | self.data_shape = gm.shape
186 | shape_pad = [shape[0]]
187 | shape_pad.extend([i + 2 * pad_size for i in shape[1:]])
188 |                     temp = np.zeros(shape_pad, dtype=np.float32)
192 | temp[:, pad_size:shape_pad[1] - pad_size,
193 | pad_size:shape_pad[2] - pad_size,
194 | pad_size:shape_pad[3] - pad_size] = gm
195 | self.origin.append(temp)
196 | self.tomo_shape.append(temp.shape)
197 |
198 | else:
199 | if not args.input_cat:
200 | print([os.path.join(tomo_path, dir_names[i] + tomo_format) for i in self.data_range])
201 | self.origin = [mrcfile.open(os.path.join(tomo_path, dir_names[i] + tomo_format)) for i
202 | in self.data_range]
203 | else:
204 | for idx, p_suffix in enumerate(args.input_cat_items):
205 | p_suffix = p_suffix.rstrip(',')
206 | p_suffix = '' if p_suffix == 'None' else p_suffix
207 | with mrcfile.open(os.path.join(tomo_path + p_suffix, dir_names[self.data_range[0]] + tomo_format), permissive=True) as tmp:
208 | if idx == 0:
209 | self.origin = np.array(tmp.data)[None, ...]
210 | else:
211 | self.origin = np.concatenate([self.origin, np.array(tmp.data)[None, ...]], axis=0)
212 |
213 | # load Labels
214 | if self.mode == 'test' or self.mode == 'test_val' or self.mode == 'val_v1':
215 | for idx in self.data_range:
216 | if os.path.exists(os.path.join(label_path, dir_names[idx] + tomo_format)):
217 | with mrcfile.open(os.path.join(label_path, dir_names[idx] + tomo_format),
218 | permissive=True) as cm:
219 | shape = cm.data.shape
220 | shape_pad = [i + 2 * pad_size if i > pad_size else i for i in shape]
221 |                         temp = np.zeros(shape_pad, dtype=np.float32)
225 |
226 | if len(shape) == 3:
227 | temp[pad_size:shape_pad[-3] - pad_size,
228 | pad_size:shape_pad[-2] - pad_size,
229 | pad_size:shape_pad[-1] - pad_size] = cm.data
230 | elif len(shape) == 4:
231 | print(temp.shape)
232 | temp[:, pad_size:shape_pad[-3] - pad_size,
233 | pad_size:shape_pad[-2] - pad_size,
234 | pad_size:shape_pad[-1] - pad_size] = cm.data
235 | self.label.append(temp)
236 | elif self.mode == 'test_val' and args.use_cluster:
237 | self.label = [
238 | np.zeros_like(self.origin[idx]) for idx, _ in enumerate(self.data_range)]
239 | elif self.mode == 'test_only':
240 | self.label = [
241 | np.zeros_like(self.origin[idx]) for idx, _ in enumerate(self.data_range)]
242 | else:
243 | self.label = [
244 | mrcfile.open(os.path.join(label_path, dir_names[idx] + tomo_format)) for idx
245 | in self.data_range]
246 |
247 | # load paf
248 | if self.use_paf:
249 | paf_path = cfg["paf_path"]
250 | paf_path = paf_path[0] if isinstance(paf_path, tuple) else paf_path
251 |
252 | self.paf_label = []
253 | if self.mode == 'test' or self.mode == 'test_val' or self.mode == 'val_v1':
254 | for idx in self.data_range:
255 | with mrcfile.open(os.path.join(paf_path, dir_names[idx] + tomo_format),
256 | permissive=True) as cm:
257 | shape = cm.data.shape
258 | shape_pad = [i + 2 * pad_size for i in shape]
259 |                         temp = np.zeros(shape_pad, dtype=np.float32)
263 | temp[pad_size:shape_pad[0] - pad_size,
264 | pad_size:shape_pad[1] - pad_size,
265 | pad_size:shape_pad[2] - pad_size] = cm.data
266 | self.paf_label.append(temp)
267 | else:
268 | self.paf_label = [
269 | mrcfile.open(os.path.join(paf_path, dir_names[idx] + tomo_format)) for idx
270 | in self.data_range]
271 |
272 | # Generate BlockData
273 | self.coords = []
274 | self.data = []
275 | if self.mode == 'train' or self.mode == 'val':
276 | for i in range(len(self.data_range)):
277 | for j, point1 in enumerate(self.position[i]):
278 | # if sel_train_num > 0 and j >= sel_train_num:
279 | # continue
280 | if args.Sel_Referance:
281 | if j in args.sel_train_num:
282 | self.coords.append([i, point1[-3], point1[-2], point1[-1]])
283 | else:
284 |                         if point1[0] == 15:  # class id 15 is duplicated 13x, presumably to oversample a rare class
285 | for _ in range(13):
286 | self.coords.append([i, point1[-3], point1[-2], point1[-1]])
287 | else:
288 | self.coords.append([i, point1[-3], point1[-2], point1[-1]])
289 | else:
290 | if test_use_pad:
291 | step_size = block_size - 2 * pad_size
292 | else:
293 | step_size = int(self.shift * 2)
294 | print(self.shift, step_size)
295 | for i in range(len(self.data_range)):
296 | shape = self.origin[i].shape[-3:]
297 | for j in range((shape[0] - 2 * pad_size) // step_size + (
298 | 1 if (shape[0] - 2 * pad_size) % step_size > 0 else 0)):
299 | for k in range((shape[1] - 2 * pad_size) // step_size + (
300 | 1 if (shape[1] - 2 * pad_size) % step_size > 0 else 0)):
301 | for l in range((shape[2] - 2 * pad_size) // step_size + (
302 | 1 if (shape[2] - 2 * pad_size) % step_size > 0 else 0)):
303 | if j == (shape[0] - 2 * pad_size) // step_size + (
304 | 1 if (shape[0] - 2 * pad_size) % step_size > 0 else 0) - 1:
305 | z = shape[0] - block_size // 2
306 | else:
307 | z = j * step_size + block_size // 2
308 |
309 | if k == (shape[1] - 2 * pad_size) // step_size + (
310 | 1 if (shape[1] - 2 * pad_size) % step_size > 0 else 0) - 1:
311 | y = shape[1] - block_size // 2
312 | else:
313 | y = k * step_size + block_size // 2
314 |
315 | if l == (shape[2] - 2 * pad_size) // step_size + (
316 | 1 if (shape[2] - 2 * pad_size) % step_size > 0 else 0) - 1:
317 | x = shape[2] - block_size // 2
318 | else:
319 | x = l * step_size + block_size // 2
320 | self.coords.append([i, x, y, z])
321 |
322 | if len(self.origin[i].shape) == 4:
323 | img = self.origin[i][:, z - self.shift: z + self.shift,
324 | y - self.shift: y + self.shift,
325 | x - self.shift: x + self.shift]
326 | else:
327 | img = self.origin[i][z - self.shift: z + self.shift,
328 | y - self.shift: y + self.shift,
329 | x - self.shift: x + self.shift]
330 |
331 | if len(self.label[i].shape) == 4:
332 | lab = self.label[i][:, z - self.shift: z + self.shift,
333 | y - self.shift: y + self.shift,
334 | x - self.shift: x + self.shift]
335 | else:
336 | lab = self.label[i][z - self.shift: z + self.shift,
337 | y - self.shift: y + self.shift,
338 | x - self.shift: x + self.shift]
339 |
340 | if self.use_paf:
341 | paf = self.paf_label[i][z - self.shift: z + self.shift,
342 | y - self.shift: y + self.shift,
343 | x - self.shift: x + self.shift]
344 | self.data.append([img, lab, paf, [z, y, x]])
345 | else:
346 | self.data.append([img, lab, [z, y, x]])
347 |
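The nested loops above compute, per axis, evenly spaced block centers with a step of `block_size - 2 * pad_size` (or `2 * shift`), clamping the last center to the volume edge so no voxels are missed. The per-axis logic can be sketched in isolation (`tile_centers` is an illustrative name):

```python
def tile_centers(dim, block_size, pad_size, step_size):
    # Number of tiles needed to cover the unpadded extent, rounding up.
    extent = dim - 2 * pad_size
    n = extent // step_size + (1 if extent % step_size > 0 else 0)
    centers = []
    for j in range(n):
        if j == n - 1:
            # Clamp the last tile so the full block stays inside the volume.
            centers.append(dim - block_size // 2)
        else:
            centers.append(j * step_size + block_size // 2)
    return centers
```

With `test_use_pad=True`, `block_size=72` and `pad_size=18`, the step is 36 and adjacent blocks overlap by the padding on each side.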
348 | # add random samples
349 | if self.mode == 'train' and random_num > 0:
350 | print('random samples num:', random_num)
351 | for j in range(random_num):
352 | i = np.random.randint(len(self.data_range))
353 | data_shape = self.origin[i].data.shape
354 | z = np.random.randint(self.shift + 1, data_shape[0] - self.shift)
355 | y = np.random.randint(self.shift + 1, data_shape[1] - self.shift)
356 | x = np.random.randint(self.shift + 1, data_shape[2] - self.shift)
357 | self.coords.append([i, x, y, z])
358 |
359 | if self.mode == 'train':
360 | print("Training dataset contains {} samples".format((len(self.coords))))
361 | if self.mode == 'val':
362 | print("Validation dataset contains {} samples".format((len(self.coords))))
363 | if self.mode == 'test' or self.mode == 'test_val' or self.mode == 'val_v1' or self.mode == 'test_only':
364 | print("Test dataset contains {} samples".format((len(self.coords))))
365 | self.test_len = len(self.coords)
366 |
367 | if self.mode == 'test' or self.mode == 'test_val' or self.mode == 'test_only':
368 | self.dir_name = dir_names[self.data_range[0]]
369 | if self.mode == 'test' or self.mode == 'test_val':
370 | print(os.path.join(ocp_path, dir_names[self.data_range[0]] + tomo_format))
371 | if os.path.exists(os.path.join(ocp_path, dir_names[self.data_range[0]] + tomo_format)):
372 | with mrcfile.open(os.path.join(ocp_path, dir_names[self.data_range[0]] + tomo_format), permissive=True) as f:
373 | self.occupancy_map = f.data
374 | elif self.mode == 'test_val' and args.use_cluster:
375 | self.occupancy_map = np.zeros_like(self.origin[0])
376 | self.gt_coords = pd.read_csv(os.path.join(coord_path, "%s.coords" % dir_names[self.data_range[0]]),
377 | sep='\t', header=None).to_numpy()
378 |
379 | if args.use_bg_part and self.Sel_Referance:
380 | self.coords_bg = pd.read_csv(os.path.join(coord_path, dir_names[self.data_range[0]] + '_bg' + coord_format),
381 | sep='\t', header=None).to_numpy()[:len(self.coords)]
382 | if args.use_ice_part and self.Sel_Referance:
383 | self.coords_ice = pd.read_csv(os.path.join(coord_path, dir_names[self.data_range[0]] + '_ice' + coord_format),
384 | sep='\t', header=None).to_numpy()[:len(self.coords)]
385 |
386 | if self.mode == 'test_val':
387 | for i in range(len(self.data_range)):
388 | for j, point1 in enumerate(self.position[i]):
389 | x, y, z = point1[-3] + pad_size, point1[-2] + pad_size, point1[-1] + pad_size
390 | z_max, y_max, x_max = self.origin[i].data.shape
391 | x, y, z = self.__sample(np.array([x, y, z]),
392 | np.array([x_max, y_max, z_max]))
393 | img = self.origin[i][z - self.shift: z + self.shift,
394 | y - self.shift: y + self.shift,
395 | x - self.shift: x + self.shift]
396 | if len(self.label[i].shape) == 4:
397 | lab = self.label[i][:, z - self.shift: z + self.shift,
398 | y - self.shift: y + self.shift,
399 | x - self.shift: x + self.shift]
400 | else:
401 | lab = self.label[i][z - self.shift: z + self.shift,
402 | y - self.shift: y + self.shift,
403 | x - self.shift: x + self.shift]
404 | if self.use_paf:
405 | paf = self.paf_label[i][z - self.shift: z + self.shift,
406 | y - self.shift: y + self.shift,
407 | x - self.shift: x + self.shift]
408 | self.data.append([img, lab, paf, [z, y, x]])
409 | else:
410 | self.data.append([img, lab, [z, y, x]])
411 | if self.mode == 'test_val' and args.use_cluster:
412 | self.data = self.data[-len(self.position[0]):]
413 |
414 | def __getitem__(self, index):
415 |         if self.mode in ('test', 'test_val', 'val_v1', 'test_only'):
416 | if self.use_paf:
417 | img, label, paf_label, position = self.data[index]
418 | else:
419 | img, label, position = self.data[index]
420 |
421 | else:
422 | idx, x, y, z = self.coords[index]
423 | z_max, y_max, x_max = self.origin[idx].data.shape
424 |
425 | point = self.__sample(np.array([x, y, z]),
426 | np.array([x_max, y_max, z_max]))
427 |
428 | if self.args.input_cat:
429 | img = self.origin[:, point[2] - self.shift:point[2] + self.shift,
430 | point[1] - self.shift:point[1] + self.shift,
431 | point[0] - self.shift:point[0] + self.shift]
432 | else:
433 | img = self.origin[idx].data[point[2] - self.shift:point[2] + self.shift,
434 | point[1] - self.shift:point[1] + self.shift,
435 | point[0] - self.shift:point[0] + self.shift]
436 |
437 | if len(self.label[idx].data.shape) == 4:
438 | label = self.label[idx].data[:, point[2] - self.shift:point[2] + self.shift,
439 | point[1] - self.shift:point[1] + self.shift,
440 | point[0] - self.shift:point[0] + self.shift]
441 | else:
442 | label = self.label[idx].data[point[2] - self.shift:point[2] + self.shift,
443 | point[1] - self.shift:point[1] + self.shift,
444 | point[0] - self.shift:point[0] + self.shift]
445 | position = [point[2], point[1], point[0]]
446 | if self.use_paf:
447 | paf_label = self.paf_label[idx].data[point[2] - self.shift:point[2] + self.shift,
448 | point[1] - self.shift:point[1] + self.shift,
449 | point[0] - self.shift:point[0] + self.shift]
450 | # print(img.shape, label.shape)
451 | img = np.array(img)
452 |         label = np.array(label).astype(np.float32)
456 |
457 | if self.num_class > 1 and len(label.shape) == 3:
458 | label = multiclass_label(label,
459 | num_classes=self.num_class,
460 | first_idx=1 if self.use_paf else 0)
461 | else:
462 | if self.mode == 'test' and label.shape != (self.shift * 2, self.shift * 2, self.shift * 2):
463 | label = np.zeros((1, self.shift * 2, self.shift * 2, self.shift * 2))
464 | else:
465 | label = label.reshape(-1, self.shift * 2, self.shift * 2, self.shift * 2)
466 |
467 | if self.use_paf:
468 |             paf_label = np.array(paf_label).astype(np.float32).reshape(-1, self.shift * 2, self.shift * 2, self.shift * 2)
472 |
473 | label = np.concatenate([label, paf_label], axis=0)
474 |
475 | if self.use_CL_DA:
476 | img = self.__DA_SelReference(img)
477 |
478 | # random 3D rotation
479 | if self.mode == 'train' and not self.use_CL:
480 | if self.use_bg:
481 | img_label = {'data': img.reshape(1, -1, self.shift * 2, self.shift * 2, self.shift * 2),
482 | 'seg': label.reshape(1, -1, self.shift * 2, self.shift * 2, self.shift * 2)}
483 | if torch.rand(1) < 0.5:
484 | img_label = self.st(**img_label)
485 | else:
486 | img_label = self.mt(**img_label)
487 | img = img_label['data'].reshape(-1, self.shift * 2, self.shift * 2, self.shift * 2)
488 | label = img_label['seg'].reshape(-1, self.shift * 2, self.shift * 2, self.shift * 2)
489 | else:
490 | img = np.array(img).reshape(-1, self.shift * 2, self.shift * 2, self.shift * 2)
491 | label = np.array(label).reshape(-1, self.shift * 2, self.shift * 2, self.shift * 2)
492 | # degree = np.random.randint(4, size=3)
493 | # img = self.__rotation3D(img, degree)
494 | # label = self.__rotation3D(label, degree)
495 | # img = np.array(img)
496 | # label = np.array(label)
497 | else:
498 | img = img.reshape(-1, self.shift * 2, self.shift * 2, self.shift * 2)
499 | label = label.reshape(-1, self.shift * 2, self.shift * 2, self.shift * 2)
500 |
501 | img = torch.as_tensor(img).float()
502 | label = torch.as_tensor(label).float()
503 |
504 | if self.use_bg_part and self.Sel_Referance:
505 | idx, x, y, z = self.coords_bg[index]
506 | z_max, y_max, x_max = self.origin[0].data.shape
507 |
508 | # point = self.__sample(np.array([x, y, z]),
509 | # np.array([x_max, y_max, z_max]))
510 | point = [x, y, z]
511 |
512 | img_bg = self.origin[0].data[point[2] - self.shift:point[2] + self.shift,
513 | point[1] - self.shift:point[1] + self.shift,
514 | point[0] - self.shift:point[0] + self.shift]
515 |
516 | if self.use_CL_DA:
517 | img_bg = self.__DA_SelReference(np.array(img_bg))
518 |
519 | img_bg = np.array(img_bg).reshape(-1, self.shift * 2, self.shift * 2, self.shift * 2)
520 | img_bg = torch.as_tensor(img_bg).float()
521 |
522 | if self.use_ice_part and self.Sel_Referance:
523 | idx, x, y, z = self.coords_ice[index]
524 | z_max, y_max, x_max = self.origin[0].data.shape
525 |
526 | # point = self.__sample(np.array([x, y, z]),
527 | # np.array([x_max, y_max, z_max]))
528 | point = [x, y, z]
529 | img_ice = self.origin[0].data[point[2] - self.shift:point[2] + self.shift,
530 | point[1] - self.shift:point[1] + self.shift,
531 | point[0] - self.shift:point[0] + self.shift]
532 |
533 | if self.use_CL_DA:
534 | img_ice = self.__DA_SelReference(np.array(img_ice))
535 |
536 | img_ice = np.array(img_ice).reshape(-1, self.shift * 2, self.shift * 2, self.shift * 2)
537 | img_ice = torch.as_tensor(img_ice).float()
538 |
539 | if self.use_bg_part and self.Sel_Referance:
540 | if self.use_ice_part:
541 | return img, img_bg, img_ice
542 | else:
543 | return img, img_bg, position
544 | else:
545 | return img, label, position
546 |
547 | def __len__(self):
548 | return max(len(self.coords), len(self.data))
549 | # return min(len(self.coords), len(self.data))
550 |
551 | def __rotation3D(self, data, degree):
552 | data = np.rot90(data, degree[0], (0, 1))
553 | data = np.rot90(data, degree[1], (1, 2))
554 | data = np.rot90(data, degree[2], (0, 2))
555 | return data
556 |
557 | def __DA_SelReference(self, data):
558 | D, H, W = data.shape
559 | out = data.reshape(1, D, H, W)
560 | out = np.concatenate([out, np.rot90(data, 1, (0, 2)).reshape(1, D, H, W)], axis=0)
561 | out = np.concatenate([out, np.rot90(data, 2, (0, 2)).reshape(1, D, H, W)], axis=0)
562 | out = np.concatenate([out, np.rot90(data, 3, (0, 2)).reshape(1, D, H, W)], axis=0)
563 | out = np.concatenate([out, np.rot90(data, 1, (1, 2)).reshape(1, D, H, W)], axis=0)
564 | out = np.concatenate([out, np.rot90(data, 3, (1, 2)).reshape(1, D, H, W)], axis=0)
565 | for axis_idx in range(out.shape[0]):
566 | data = out[axis_idx]
567 | out = np.concatenate([out, data[::-1, :, :].reshape(1, D, H, W)], axis=0)
568 | for idx in range(1, 4):
569 | out = np.concatenate([out, np.rot90(data, idx, (0, 1)).reshape(1, D, H, W)], axis=0)
570 | out = np.concatenate([out, np.rot90(data, idx, (0, 1)).reshape(1, D, H, W)[:, ::-1, :, :]], axis=0)
571 | return out
572 |
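`__DA_SelReference` builds its augmented views by composing `np.rot90` calls: six base orientations, each expanded with in-plane rotations and flips. A related, self-contained sketch that generates the 24 proper rotations of a 3-D array from the same `np.rot90` primitive (a standard construction, not the repo's exact augmentation set):

```python
import numpy as np

def rotations24(a):
    # All 24 proper rotations of a 3-D array: orient axis 0 toward each of
    # the 6 face directions, then spin 4 times about that direction.
    def spins(vol, axes):
        for k in range(4):
            yield np.rot90(vol, k, axes)
    yield from spins(a, (1, 2))                        # axis 0 unchanged
    yield from spins(np.rot90(a, 2, (0, 2)), (1, 2))   # axis 0 reversed
    yield from spins(np.rot90(a, 1, (0, 2)), (0, 1))   # axis 0 -> axis 2
    yield from spins(np.rot90(a, 3, (0, 2)), (0, 1))   # axis 0 -> -axis 2
    yield from spins(np.rot90(a, 1, (0, 1)), (0, 2))   # axis 0 -> axis 1
    yield from spins(np.rot90(a, 3, (0, 1)), (0, 2))   # axis 0 -> -axis 1
```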
573 | def __DA_SelReference_inital(self, data):
574 | out = data
575 | for idx in range(1, 4):
576 | out = np.concatenate([out, np.rot90(data, idx, (0, 1))], axis=0)
577 | out = np.concatenate([out, np.rot90(data, idx, (1, 2))], axis=0)
578 | out = np.concatenate([out, np.rot90(data, idx, (0, 2))], axis=0)
579 | out = np.concatenate([out, data[::-1, :, :]], axis=0)
580 | out = np.concatenate([out, data[:, ::-1, :]], axis=0)
581 | out = np.concatenate([out, data[:, :, ::-1]], axis=0)
582 | return out
583 |
584 |
585 | def __sample(self, point, bound):
586 | # point: z, y, x
587 | new_point = point + np.random.randint(-self.radius, self.radius + 1, size=3)
588 | new_point[new_point < self.shift] = self.shift
589 | new_point[new_point + self.shift > bound] = bound[new_point + self.shift > bound] - self.shift
590 | return new_point
591 |
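`__sample` jitters a candidate center by up to `radius` voxels per axis, then clamps it so a `shift`-sized half-block fits inside the volume on every side. A standalone sketch of the same jitter-and-clamp step (names mirror the method; `jitter_and_clamp` itself is illustrative):

```python
import numpy as np

def jitter_and_clamp(point, bound, radius, shift):
    # Random offset in [-radius, radius] per axis (always zero when radius == 0),
    # then clamp so [coord - shift, coord + shift] stays inside bound.
    new_point = point + np.random.randint(-radius, radius + 1, size=3)
    new_point[new_point < shift] = shift
    over = new_point + shift > bound
    new_point[over] = bound[over] - shift
    return new_point
```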
592 |
593 | # transform label to multichannel
594 | def multiclass_label(x, num_classes, first_idx=0):
595 | for i in range(first_idx, first_idx + num_classes):
596 |         label_temp = np.where(x == i, 1, 0)
598 | if i == first_idx:
599 | label_new = label_temp
600 | else:
601 | label_new = np.concatenate((label_new, label_temp))
602 |
603 | return label_new
604 |
605 |
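`multiclass_label` turns an integer class-ID volume into per-class binary masks concatenated along axis 0, which the dataloader later reshapes to `(num_classes, D, H, W)`. A small usage sketch (the helper is copied from above so the example is self-contained):

```python
import numpy as np

# multiclass_label copied from above for a self-contained demo
def multiclass_label(x, num_classes, first_idx=0):
    for i in range(first_idx, first_idx + num_classes):
        label_temp = np.where(x == i, 1, 0)
        if i == first_idx:
            label_new = label_temp
        else:
            label_new = np.concatenate((label_new, label_temp))
    return label_new

# 2x2x2 volume with class IDs 0 (background) and 1
x = np.array([[[0, 1], [1, 0]],
              [[0, 0], [1, 1]]])
# Concatenation gives shape (num_classes * D, H, W); reshape to (C, D, H, W)
masks = multiclass_label(x, num_classes=2, first_idx=0).reshape(2, 2, 2, 2)
```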
606 | if __name__ == '__main__':
607 |     from argparse import Namespace
608 |
609 |     num_cls = 13
610 |     # Dataset_ClsBased expects a config dict (cfg) and an args namespace,
611 |     # not individual keyword arguments; the values below are illustrative
612 |     cfg = {"base_path": "/ldap_shared/synology_shared/shrec_2020/shrec2020_new",
613 |            "label_name": "sphere7",
614 |            "ocp_name": "sphere7",  # example value; point this at your occupancy maps
615 |            "coord_format": ".coords",
616 |            "tomo_format": ".mrc",
617 |            "norm_type": "normalization"}
618 |     args = Namespace(use_CL=False, use_CL_DA=False, use_bg_part=False,
619 |                      use_ice_part=False, Sel_Referance=False,
620 |                      input_cat=False, use_cluster=False)
621 |     dataloader = DataLoader(Dataset_ClsBased(mode='val',
622 |                                              block_size=32,
623 |                                              num_class=num_cls,
624 |                                              random_num=0,
625 |                                              use_bg=False,
626 |                                              test_use_pad=False, pad_size=18,
627 |                                              data_split=[0, 6, 6, 7, 6, 7],
628 |                                              cfg=cfg,
629 |                                              args=args),
630 |                             batch_size=1,
631 |                             num_workers=1,
632 |                             shuffle=False,
633 |                             pin_memory=False)
624 | import matplotlib.pyplot as plt
625 |
626 | rows = 2
627 | cols = 11
628 | plt.figure(1, figsize=(cols * 7, rows * 7))
629 | for idx, (img, label, position) in enumerate(dataloader):
630 | if idx == 100:
631 | print(position)
632 | print(img.shape)
633 | print(img.max(), img.min())
634 | print(label.max(), label.min())
635 | print(label.shape)
636 |
637 | imgs = img[0, 0, 0:31:3, ...]
638 | labels = label[0, 0, 0:31:3, ...]
639 | z, h, w = imgs.shape
640 | for i in range(z):
641 | plt.subplot(rows, cols, i + 1)
642 | plt.imshow(imgs[i, ...], cmap=plt.cm.gray)
643 | plt.axis('off')
644 |
645 | plt.subplot(rows, cols, i + 1 + cols)
646 | plt.imshow(labels[i, ...], cmap=plt.cm.gray)
647 | plt.axis('off')
648 |
649 | plt.tight_layout()
650 | plt.savefig('temp.png')
651 | print(position)
652 | break
653 |
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 | from model.model_loader import *
2 | from model.residual_unet_att import *
--------------------------------------------------------------------------------
/model/model_loader.py:
--------------------------------------------------------------------------------
1 | from model.residual_unet_att import ResidualUNet3D
2 |
3 | def get_model(args):
4 |
5 |     if args.network == 'ResUNet':
6 |         model = ResidualUNet3D(f_maps=args.f_maps, out_channels=args.num_classes,
7 |                                args=args, in_channels=args.in_channels, use_att=args.use_att,
8 |                                use_paf=args.use_paf, use_uncert=args.use_uncert)
9 |     else:
10 |         raise ValueError('unsupported network: {}'.format(args.network))
11 |
12 |     return model
11 |
--------------------------------------------------------------------------------
/model/residual_unet_att.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from torch.nn import functional as F
3 | import torch
4 | from functools import partial
5 | import sys
6 | from torch.nn import Conv3d, Module, Linear, BatchNorm3d, ReLU
7 | from torch.nn.modules.utils import _pair, _triple
8 |
9 | sys.path.append("..")
10 | from utils.coordconv_torch import AddCoords
11 |
12 | # from SoftPool import SoftPool3d
13 | try:
14 | from model.sync_batchnorm import SynchronizedBatchNorm3d
15 | except ImportError:
16 | pass
17 |
18 |
19 | def normalization(planes, norm='bn'):
20 | if norm == 'bn':
21 | m = nn.BatchNorm3d(planes)
22 | elif norm == 'gn':
23 | m = nn.GroupNorm(4, planes)
24 | elif norm == 'in':
25 | m = nn.InstanceNorm3d(planes)
26 | elif norm == 'sync_bn':
27 | m = SynchronizedBatchNorm3d(planes)
28 | else:
29 | raise ValueError('normalization type {} is not supported'.format(norm))
30 | return m
31 |
32 |
33 | # Residual 3D UNet
34 | class ResidualUNet3D(nn.Module):
35 | def __init__(self, f_maps=[32, 64, 128, 256], in_channels=1, out_channels=13,
36 | args=None, use_att=False, use_paf=None, use_uncert=None):
37 | super(ResidualUNet3D, self).__init__()
38 | if use_att:
39 | norm = BatchNorm3d
40 | else:
41 | norm = args.norm
42 | act = args.act
43 | use_lw = args.use_lw
44 | lw_kernel = args.lw_kernel
45 |
46 | self.use_aspp = args.use_aspp
47 | self.pif_sigmoid = args.pif_sigmoid
48 | self.paf_sigmoid = args.paf_sigmoid
49 | self.use_tanh = args.use_tanh
50 | self.use_IP = args.use_IP
51 | self.out_channels = out_channels
52 | if self.out_channels > 1:
53 | self.use_softmax = args.use_softmax
54 | else:
55 | self.use_sigmoid = args.use_sigmoid
56 | self.use_coord = args.use_coord
57 |
58 | self.use_softpool = args.use_softpool
59 |
60 | self.use_paf = use_paf
61 | self.use_uncert = use_uncert
62 |
63 | if self.use_softpool:
64 | # pool_layer = SoftPool3d
65 | pass
66 | else:
67 | pool_layer = nn.AvgPool3d
68 |
69 | if self.use_IP:
70 | pools = []
71 | for _ in range(len(f_maps) - 1):
72 | pools.append(pool_layer(2))
73 | self.pools = nn.ModuleList(pools)
74 |
75 | # create encoder path consisting of Encoder modules. Depth of the encoder is equal to `len(f_maps)`
76 | encoders = []
77 | for i, out_feature_num in enumerate(f_maps):
78 | if i == 0:
79 | encoder = Encoder(in_channels, out_feature_num, apply_pooling=False, use_IP=False,
80 | use_coord=self.use_coord,
81 | pool_layer=pool_layer, norm=norm, act=act, use_att=use_att,
82 | use_lw=use_lw, lw_kernel=lw_kernel)
83 | else:
84 | # TODO: adapt for anisotropy in the data, i.e. use proper pooling kernel to make the data isotropic after 1-2 pooling operations
85 | encoder = Encoder(f_maps[i - 1], out_feature_num, use_IP=self.use_IP, use_coord=self.use_coord,
86 | pool_layer=pool_layer, norm=norm, act=act, use_att=use_att,
87 | use_lw=use_lw, lw_kernel=lw_kernel)
88 |
89 | encoders.append(encoder)
90 |
91 | self.encoders = nn.ModuleList(encoders)
92 |
93 |         # use ASPP to further refine the bottleneck features
94 | if self.use_aspp:
95 | self.aspp = ASPP(in_channels=f_maps[-1], inter_channels=f_maps[-1], out_channels=f_maps[-1])
96 |
97 | self.se_loss = args.use_se_loss
98 | if self.se_loss:
99 | self.avgpool = nn.AdaptiveAvgPool3d(1)
100 | self.fc1 = nn.Linear(f_maps[-1], f_maps[-1])
101 | self.fc2 = nn.Linear(f_maps[-1], out_channels)
102 |
103 | # create decoder path consisting of the Decoder modules. The length of the decoder is equal to `len(f_maps) - 1`
104 | decoders = []
105 | reversed_f_maps = list(reversed(f_maps))
106 | for i in range(len(reversed_f_maps) - 1):
107 | in_feature_num = reversed_f_maps[i]
108 | out_feature_num = reversed_f_maps[i + 1]
109 | # TODO: if non-standard pooling was used, make sure to use correct striding for transpose conv
110 | # currently strides with a constant stride: (2, 2, 2)
111 | decoder = Decoder(in_feature_num, out_feature_num, use_coord=self.use_coord, norm=norm, act=act,
112 | use_att=use_att, use_lw=use_lw, lw_kernel=lw_kernel)
113 | decoders.append(decoder)
114 |
115 | self.decoders = nn.ModuleList(decoders)
116 |
117 | # in the last layer a 1×1 convolution reduces the number of output
118 | # channels to the number of labels
119 | if args.final_double:
120 | self.final_conv = nn.Sequential(
121 | nn.Conv3d(f_maps[0], f_maps[0] // 2, kernel_size=1),
122 | nn.Conv3d(f_maps[0] // 2, out_channels, 1)
123 | )
124 | if self.use_paf:
125 | self.paf_conv = nn.Sequential(
126 | nn.Conv3d(f_maps[0], f_maps[0] // 2, kernel_size=1),
127 | nn.Conv3d(f_maps[0] // 2, 1, 1)
128 | )
129 | self.dropout = nn.Dropout3d
130 | else:
131 | self.final_conv = nn.Conv3d(f_maps[0], out_channels, 1)
132 | if self.use_paf:
133 | self.paf_conv = nn.Conv3d(f_maps[0], 1, 1)
134 | self.dropout = nn.Dropout3d
135 |
136 | if self.use_paf:
137 | if self.use_uncert:
138 | self.logsigma = nn.Parameter(torch.FloatTensor([0.5] * 2))
139 | else:
140 | self.logsigma = torch.FloatTensor([0.5] * 2)
141 |
142 | def forward(self, x):
143 | if self.use_IP:
144 | img_pyramid = []
145 | img_d = x
146 | for pool in self.pools:
147 | img_d = pool(img_d)
148 | img_pyramid.append(img_d)
149 |
150 | encoders_features = []
151 | for idx, encoder in enumerate(self.encoders):
152 | if self.use_IP and idx > 0:
153 | x = encoder(x, img_pyramid[idx - 1])
154 | else:
155 | x = encoder(x)
156 | encoders_features.insert(0, x)
157 |
158 | if self.use_aspp:
159 | x = self.aspp(x)
160 |         # drop the deepest encoder output, which is already held in x
161 |         encoders_features = encoders_features[1:]
162 |
163 | if self.se_loss:
164 | se_out = self.avgpool(x)
165 | se_out = se_out.view(se_out.size(0), -1)
166 | se_out = self.fc1(se_out)
167 | se_out = self.fc2(se_out)
168 |
169 | for decoder, encoder_features in zip(self.decoders, encoders_features):
170 | x = decoder(encoder_features, x)
171 |
172 | out = self.final_conv(x)
173 |
174 | if self.out_channels > 1:
175 | if self.use_softmax:
176 | out = torch.softmax(out, dim=1)
177 | elif self.pif_sigmoid:
178 | out = torch.sigmoid(out)
179 | elif self.use_tanh:
180 | out = torch.tanh(out)
181 | else:
182 | if self.use_sigmoid:
183 | out = torch.sigmoid(out)
184 | elif self.use_tanh:
185 | out = torch.tanh(out)
186 |
187 | if self.use_paf:
188 | paf_out = self.paf_conv(x)
189 | if self.paf_sigmoid:
190 | paf_out = torch.sigmoid(paf_out)
191 |
192 | if self.se_loss:
193 | return [out, se_out]
194 | else:
195 | if self.use_paf:
196 | return [out, paf_out, self.logsigma]
197 | else:
198 | return out
199 |
200 |
201 | class Encoder(nn.Module):
202 | def __init__(self, in_channels, out_channels, apply_pooling=True, use_IP=False, use_coord=False,
203 | pool_layer=nn.MaxPool3d, norm='bn', act='relu', use_att=False,
204 | use_lw=False, lw_kernel=3, input_channels=1):
205 | super(Encoder, self).__init__()
206 | if apply_pooling:
207 | self.pooling = pool_layer(kernel_size=2)
208 | else:
209 | self.pooling = None
210 |
211 | self.use_IP = use_IP
212 | self.use_coord = use_coord
213 | inplaces = in_channels + input_channels if self.use_IP else in_channels
214 | inplaces = inplaces + 3 if self.use_coord else inplaces
215 |
216 | if use_att:
217 | self.basic_module = ExtResNetBlock_att(inplaces, out_channels, norm=norm, act=act)
218 | else:
219 | if use_lw:
220 | self.basic_module = ExtResNetBlock_lightWeight(inplaces, out_channels, lw_kernel=lw_kernel)
221 | else:
222 | self.basic_module = ExtResNetBlock(inplaces, out_channels, norm=norm, act=act)
223 | if self.use_coord:
224 | self.coord_conv = AddCoords(rank=3, with_r=False)
225 |
226 | def forward(self, x, scaled_img=None):
227 | if self.pooling is not None:
228 | x = self.pooling(x)
229 | if self.use_IP:
230 | x = torch.cat([x, scaled_img], dim=1)
231 | if self.use_coord:
232 | x = self.coord_conv(x)
233 | x = self.basic_module(x)
234 | return x
235 |
236 |
237 | class Decoder(nn.Module):
238 | def __init__(self, in_channels, out_channels, scale_factor=(2, 2, 2), mode='nearest',
239 | padding=1, use_coord=False, norm='bn', act='relu', use_att=False,
240 | use_lw=False, lw_kernel=3):
241 | super(Decoder, self).__init__()
242 | self.use_coord = use_coord
243 | if self.use_coord:
244 | self.coord_conv = AddCoords(rank=3, with_r=False)
245 |
246 | # if basic_module=ExtResNetBlock use transposed convolution upsampling and summation joining
247 | self.upsampling = Upsampling(transposed_conv=True, in_channels=in_channels, out_channels=out_channels,
248 | scale_factor=scale_factor, mode=mode)
249 | # sum joining
250 | self.joining = partial(self._joining, concat=False)
251 | # adapt the number of in_channels for the ExtResNetBlock
252 | in_channels = out_channels + 3 if self.use_coord else out_channels
253 |
254 | if use_att:
255 | self.basic_module = ExtResNetBlock_att(in_channels, out_channels, norm=norm, act=act)
256 | else:
257 | if use_lw:
258 | self.basic_module = ExtResNetBlock_lightWeight(in_channels, out_channels, lw_kernel=lw_kernel)
259 | else:
260 |                 self.basic_module = ExtResNetBlock(in_channels, out_channels, norm=norm, act=act)  # pass the configured norm instead of hard-coded 'bn'
261 |
262 | def forward(self, encoder_features, x, ReturnInput=False):
263 | x = self.upsampling(encoder_features, x)
264 | x = self.joining(encoder_features, x)
265 | if self.use_coord:
266 | x = self.coord_conv(x)
267 | if ReturnInput:
268 | x1 = self.basic_module(x)
269 | return x1, x
270 | x = self.basic_module(x)
271 | return x
272 |
273 | @staticmethod
274 | def _joining(encoder_features, x, concat):
275 | if concat:
276 | return torch.cat((encoder_features, x), dim=1)
277 | else:
278 | return encoder_features + x
279 |
280 |
281 | class ExtResNetBlock(nn.Module):
282 | def __init__(self, in_channels, out_channels, norm='bn', act='relu'):
283 | super(ExtResNetBlock, self).__init__()
284 | # first convolution
285 | self.conv1 = SingleConv(in_channels, out_channels, norm=norm, act=act)
286 | # residual block
287 | self.conv2 = SingleConv(out_channels, out_channels, norm=norm, act=act)
288 |         # third convolution; its non-linearity is kept here, and an extra ELU follows the residual addition
289 | self.conv3 = SingleConv(out_channels, out_channels, norm=norm, act=act)
290 | self.non_linearity = nn.ELU(inplace=False)
291 |
292 | def forward(self, x):
293 | # apply first convolution and save the output as a residual
294 | out = self.conv1(x)
295 | residual = out
296 | # residual block
297 | out = self.conv2(out)
298 | out = self.conv3(out)
299 |
300 | out += residual
301 | out = self.non_linearity(out)
302 | return out
303 |
304 |
305 | class ExtResNetBlock_att(nn.Module):
306 | def __init__(self, in_channels, out_channels, norm='bn', act='relu'):
307 | super(ExtResNetBlock_att, self).__init__()
308 | # first convolution
309 | self.conv1 = SingleConv(in_channels, out_channels, norm=norm, act=act)
310 | # residual block
311 | self.conv2 = SplAtConv3d(out_channels, out_channels // 2, norm_layer=norm)
312 |         # third convolution; its non-linearity is kept here, and an extra ELU follows the residual addition
313 | self.conv3 = SplAtConv3d(out_channels, out_channels // 2, norm_layer=norm)
314 | self.non_linearity = nn.ELU(inplace=False)
315 |
316 | def forward(self, x):
317 | # apply first convolution and save the output as a residual
318 | out = self.conv1(x)
319 | residual = out
320 | # residual block
321 | out = self.conv2(out)
322 | out = self.conv3(out)
323 |
324 | out += residual
325 | out = self.non_linearity(out)
326 | return out
327 |
328 |
329 | class SingleConv(nn.Sequential):
330 | def __init__(self, in_channels, out_channels, norm='bn', act='relu'):
331 | super(SingleConv, self).__init__()
332 | self.add_module('conv', nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1))
333 | self.add_module('batchnorm', normalization(out_channels, norm=norm))
334 | if act == 'relu':
335 | self.add_module('relu', nn.ReLU(inplace=False))
336 | elif act == 'lrelu':
337 | self.add_module('lrelu', nn.LeakyReLU(negative_slope=0.1, inplace=False))
338 | elif act == 'elu':
339 | self.add_module('elu', nn.ELU(inplace=False))
340 | elif act == 'gelu':
341 |             self.add_module('gelu', nn.GELU())  # nn.GELU takes no inplace argument
342 |
343 |
344 | class ExtResNetBlock_lightWeight(nn.Module):
345 | def __init__(self, in_channels, out_channels, lw_kernel=3):
346 | super(ExtResNetBlock_lightWeight, self).__init__()
347 | # first convolution
348 | self.conv1 = SingleConv_lightWeight(in_channels, out_channels, lw_kernel=lw_kernel)
349 | # residual block
350 | self.conv2 = SingleConv_lightWeight(out_channels, out_channels, lw_kernel=lw_kernel)
351 | # remove non-linearity from the 3rd convolution since it's going to be applied after adding the residual
352 | self.conv3 = SingleConv_lightWeight(out_channels, out_channels, lw_kernel=lw_kernel)
353 | self.non_linearity = nn.ELU(inplace=False)
354 |
355 | def forward(self, x):
356 | # apply first convolution and save the output as a residual
357 | out = self.conv1(x)
358 | residual = out
359 | # residual block
360 | out = self.conv2(out)
361 | out = self.conv3(out)
362 |
363 | out += residual
364 | out = self.non_linearity(out)
365 | return out
366 |
367 |
368 | class SingleConv_lightWeight(nn.Sequential):
369 | def __init__(self, in_channels, out_channels, lw_kernel=3, layer_scale_init_value=1e-6):
370 | super(SingleConv_lightWeight, self).__init__()
371 |
372 | self.dwconv = nn.Conv3d(in_channels, in_channels, kernel_size=lw_kernel, padding=lw_kernel//2, groups=in_channels)
373 | self.norm = nn.LayerNorm(in_channels, eps=1e-6)
374 | self.pwconv1 = nn.Linear(in_channels, 2 * in_channels)
375 | self.act = nn.GELU()
376 | self.pwconv2 = nn.Linear(2 * in_channels, out_channels)
377 | self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((out_channels)),
378 | requires_grad=True) if layer_scale_init_value > 0 else None
379 | if in_channels != out_channels:
380 | self.skip = nn.Conv3d(in_channels, out_channels, 1)
381 | self.in_channels = in_channels
382 | self.out_channels = out_channels
383 |
384 | def forward(self, x):
385 | input = x
386 | x = self.dwconv(x)
387 | x = x.permute(0, 2, 3, 4, 1) # (N, C, D, H, W) -> (N, D, H, W, C)
388 | x = self.norm(x)
389 | x = self.pwconv1(x)
390 | x = self.act(x)
391 | x = self.pwconv2(x)
392 | if self.gamma is not None:
393 | x = self.gamma * x
394 |         x = x.permute(0, 4, 1, 2, 3)  # (N, D, H, W, C) -> (N, C, D, H, W)
395 |
396 | x = x + (input if self.in_channels == self.out_channels else self.skip(input))
397 | return x
398 |
399 |
400 | class Upsampling(nn.Module):
401 | def __init__(self, transposed_conv, in_channels=None, out_channels=None, scale_factor=(2, 2, 2), mode='nearest'):
402 | super(Upsampling, self).__init__()
403 |
404 | if transposed_conv:
405 | # make sure that the output size reverses the MaxPool3d from the corresponding encoder
406 | # (D_out = (D_in − 1) × stride[0] − 2 × padding[0] + kernel_size[0] + output_padding[0])
407 | self.upsample = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=3, stride=scale_factor, padding=1)
408 | else:
409 | self.upsample = partial(self._interpolate, mode=mode)
410 |
411 | def forward(self, encoder_features, x):
412 | output_size = encoder_features.size()[2:]
413 | return self.upsample(x, output_size)
414 |
415 | @staticmethod
416 | def _interpolate(x, size, mode):
417 | return F.interpolate(x, size=size, mode=mode)
418 |
419 |
420 | class SplAtConv3d(Module):
421 |     """Split-Attention Conv3d
422 | """
423 |
424 | def __init__(self, in_channels, channels, kernel_size=3, stride=(1, 1, 1), padding=(1, 1, 1),
425 | dilation=(1, 1, 1), groups=1, bias=True,
426 | radix=2, reduction_factor=4,
427 | rectify=False, rectify_avg=False, norm_layer=BatchNorm3d,
428 | dropblock_prob=0.0, **kwargs):
429 | super(SplAtConv3d, self).__init__()
430 | padding = _triple(padding)
431 | self.rectify = rectify and (padding[0] > 0 or padding[1] > 0)
432 | self.rectify_avg = rectify_avg
433 | inter_channels = max(in_channels * radix // reduction_factor, 32)
434 | self.radix = radix
435 | self.cardinality = groups
436 | self.channels = channels
437 | self.dropblock_prob = dropblock_prob
438 | if self.rectify:
439 | pass
440 | # from rfconv import RFConv2d
441 | # self.conv = RFConv2d(in_channels, channels*radix, kernel_size, stride, padding, dilation,
442 | # groups=groups*radix, bias=bias, average_mode=rectify_avg, **kwargs)
443 | else:
444 | self.conv = Conv3d(in_channels, channels * radix, kernel_size, stride, padding, dilation,
445 | groups=groups * radix, bias=bias, **kwargs)
446 | self.use_bn = norm_layer is not None
447 | if self.use_bn:
448 | self.bn0 = norm_layer(channels * radix)
449 | self.relu = ReLU(inplace=False)
450 | self.fc1 = Conv3d(channels, inter_channels, 1, groups=self.cardinality)
451 | if self.use_bn:
452 | self.bn1 = norm_layer(inter_channels)
453 | self.fc2 = Conv3d(inter_channels, channels * radix, 1, groups=self.cardinality)
454 | # if dropblock_prob > 0.0:
455 | # self.dropblock = DropBlock2D(dropblock_prob, 3)
456 | self.rsoftmax = rSoftMax(radix, groups)
457 | self.conv3 = Conv3d(channels, channels * radix, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
458 | self.bn3 = BatchNorm3d(channels * radix, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
459 | self.relu3 = ReLU(inplace=False)
460 |
461 | def forward(self, x):
462 | x = self.conv(x)
463 | if self.use_bn:
464 | x = self.bn0(x)
465 | # if self.dropblock_prob > 0.0:
466 | # x = self.dropblock(x)
467 | x = self.relu(x)
468 |
469 | batch, rchannel = x.shape[:2]
470 | if self.radix > 1:
471 | if torch.__version__ < '1.5':
472 | splited = torch.split(x, int(rchannel // self.radix), dim=1)
473 | else:
474 | splited = torch.split(x, rchannel // self.radix, dim=1)
475 | gap = sum(splited)
476 | else:
477 | gap = x
478 | gap = F.adaptive_avg_pool3d(gap, 1)
479 | gap = self.fc1(gap)
480 |
481 | if self.use_bn:
482 | gap = self.bn1(gap)
483 | gap = self.relu(gap)
484 |
485 | atten = self.fc2(gap)
486 | atten = self.rsoftmax(atten).view(batch, -1, 1, 1, 1)
487 |
488 | if self.radix > 1:
489 | if torch.__version__ < '1.5':
490 | attens = torch.split(atten, int(rchannel // self.radix), dim=1)
491 | else:
492 | attens = torch.split(atten, rchannel // self.radix, dim=1)
493 | out = sum([att * split for (att, split) in zip(attens, splited)])
494 | else:
495 | out = atten * x
496 |
497 | out = self.relu3(self.bn3(self.conv3(out)))
498 | return out.contiguous()
499 |
500 |
501 | class rSoftMax(nn.Module):
502 | def __init__(self, radix, cardinality):
503 | super().__init__()
504 | self.radix = radix
505 | self.cardinality = cardinality
506 |
507 | def forward(self, x):
508 | batch = x.size(0)
509 | if self.radix > 1:
510 | x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2)
511 | x = F.softmax(x, dim=1)
512 | x = x.reshape(batch, -1)
513 | else:
514 | x = torch.sigmoid(x)
515 | return x
516 |
517 |
518 | if __name__ == "__main__":
519 |     import argparse
520 |
521 |     parser = argparse.ArgumentParser(description='Training for 3D U-Net models')  # NOTE: type=bool below treats any non-empty string as True; options/option.py uses str2bool to avoid this
522 |     parser.add_argument('--use_IP', type=bool, help='whether use image pyramid', default=False)
523 |     parser.add_argument('--use_DS', type=bool, help='whether use deep supervision', default=False)
524 |     parser.add_argument('--use_Res', type=bool, help='whether use residual connectivity', default=False)
525 |     parser.add_argument('--use_bg', type=bool, help='whether use batch generator', default=False)
526 |     parser.add_argument('--use_coord', type=bool, help='whether use coord conv', default=False)
527 |     parser.add_argument('--use_softmax', type=bool, help='whether use softmax', default=False)
528 |     parser.add_argument('--use_softpool', type=bool, help='whether use softpool', default=False)
529 |     parser.add_argument('--use_aspp', type=bool, help='whether use aspp', default=False)
530 |     parser.add_argument('--use_att', type=bool, help='whether use attention', default=False)
531 |     parser.add_argument('--use_se_loss', type=bool, help='whether use SE loss', default=False)
532 |     parser.add_argument('--pif_sigmoid', type=bool, help='whether use sigmoid for the part intensity field branch', default=False)
533 |     parser.add_argument('--paf_sigmoid', type=bool, help='whether use sigmoid for the part affinity field branch', default=False)
534 |     parser.add_argument('--final_double', type=bool, help='whether use a double 1x1 convolution head', default=False)
535 |     parser.add_argument('--use_tanh', type=bool, help='whether use tanh', default=False)
536 |     parser.add_argument('--norm', help='type of normalization', type=str, default='sync_bn',
537 |                         choices=['bn', 'gn', 'in', 'sync_bn'])
538 |     parser.add_argument('--use_lw', type=bool, help='whether use lightweight', default=True)
539 |     parser.add_argument('--lw_kernel', type=int, default=5)
540 |     parser.add_argument('--act', help='type of activation function', type=str, default='relu',
541 |                         choices=['relu', 'lrelu', 'elu', 'gelu'])
542 |     args = parser.parse_args()
543 |
544 |     net = ResidualUNet3D(args=args, use_att=args.use_att, f_maps=[24, 48, 72, 108])
545 |     print(net)
546 |
547 |     # conv = SplAtConv3d(64, 32, 3)
548 |     # print(conv)
549 |     data = torch.rand([2, 1, 56, 56, 56])
550 |     out = net(data)
551 |     print(out.shape)
552 |
--------------------------------------------------------------------------------
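The `Upsampling` module in `residual_unet_att.py` above passes the encoder feature size as `output_size` to `ConvTranspose3d`, because a strided transposed convolution alone does not determine the output length uniquely. A minimal shape-arithmetic sketch of why (plain Python, not part of the repository; the function name is illustrative):

```python
def conv_transpose_out(d_in, kernel_size=3, stride=2, padding=1, output_padding=0):
    # Transposed-convolution output length along one dimension:
    # D_out = (D_in - 1) * stride - 2 * padding + kernel_size + output_padding
    return (d_in - 1) * stride - 2 * padding + kernel_size + output_padding

# An encoder MaxPool3d(kernel_size=2) halves 56 -> 28. With kernel_size=3,
# stride=2, padding=1 the decoder's transposed conv maps 28 back to either
# 55 or 56 depending on output_padding; supplying output_size lets PyTorch
# resolve the ambiguity and exactly reverse the pooling.
print(conv_transpose_out(28))                    # 55: one voxel short
print(conv_transpose_out(28, output_padding=1))  # 56: matches the encoder feature
```

This is why `Upsampling.forward` reads `encoder_features.size()[2:]` and calls `self.upsample(x, output_size)` rather than relying on the stride alone.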
/options/option.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
4 | def str2bool(v):
5 | if isinstance(v, bool):
6 | return v
7 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
8 | return True
9 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
10 | return False
11 | else:
12 | raise argparse.ArgumentTypeError('Boolean value expected.')
13 |
14 |
15 | class BaseOptions():
16 | def __init__(self):
17 | self.parser = argparse.ArgumentParser(description='Parameters for 3D particle picking')
18 |
19 | # dataloader parameters
20 | self.parser.add_argument('--block_size', help='block size', type=int, default=72)
21 |         self.parser.add_argument('--val_block_size', help='validation block size', type=int, default=0)
22 | self.parser.add_argument('--random_num', help='random number', type=int, default=0)
23 | self.parser.add_argument('--num_classes', help='number of classes', type=int, default=13)
24 | self.parser.add_argument('--use_bg', type=str2bool, help='whether use batch generator', default=False)
25 |         self.parser.add_argument('--test_use_pad', type=str2bool, help='whether use padding at test time', default=False)
26 | self.parser.add_argument('--pad_size', nargs='+', type=int, default=[12])
27 | self.parser.add_argument('--data_split', nargs='+', type=int, default=[0, 1, 0, 1, 0, 1])
28 | self.parser.add_argument('--configs', type=str, default='')
29 | self.parser.add_argument('--pre_configs', type=str, default='')
30 | self.parser.add_argument('--train_configs', type=str, default='')
31 | self.parser.add_argument('--ck_mode', type=str, default='')
32 | self.parser.add_argument('--val_configs', type=str, default='')
33 |         self.parser.add_argument('--loader_type', type=str, default='dataloader_DynamicLoad', help="dataloader type",
34 | # choices=["dataloader", "dataloader_DynamicLoad",
35 | # 'dataloader_DynamicLoad_CellSeg',
36 | # "dataloader_DynamicLoad_Semi"]
37 | )
38 | self.parser.add_argument('--sel_train_num', nargs='+', type=int)
39 | self.parser.add_argument('--rand_min', type=int, default=0)
40 | self.parser.add_argument('--rand_max', type=int, default=99)
41 | self.parser.add_argument('--rand_scale', type=int, default=100)
42 | self.parser.add_argument('--input_cat', type=str2bool, help='whether use input cat', default=False)
43 | self.parser.add_argument('--input_cat_items', nargs='+', type=str, default='None')
44 |
45 | # model parameters
46 | self.parser.add_argument('--network', help='network type', type=str, default='ResUNet',
47 | # choices=['unet', 'UMC', 'ResUnet', 'DoubleUnet', 'MFNet', 'DMFNet', 'DMFNet_down3',
48 | # 'NestUnet', 'VoxResNet', 'HighRes3DNet', 'HRNetv1']
49 | )
50 | self.parser.add_argument('--in_channels', help='input channels of the network', type=int, default=1)
51 | self.parser.add_argument('--f_maps', nargs='+', type=int, help="Feature numbers of ResUnet")
52 | self.parser.add_argument('--use_LAM', type=str2bool, help='whether use LAM', default=False)
53 | self.parser.add_argument('--use_IP', type=str2bool, help='whether use image pyramid', default=False)
54 | self.parser.add_argument('--use_DS', type=str2bool, help='whether use deep supervision', default=False)
55 |         self.parser.add_argument('--use_wDS', type=str2bool, help='whether use weighted deep supervision', default=False)
56 | self.parser.add_argument('--use_Res', type=str2bool, help='whether use residual connectivity', default=False)
57 | self.parser.add_argument('--use_coord', type=str2bool, help='whether use coord conv', default=False)
58 | self.parser.add_argument('--use_softmax', type=str2bool, help='whether use softmax', default=False)
59 | self.parser.add_argument('--use_sigmoid', type=str2bool, help='whether use sigmoid', default=False)
60 | self.parser.add_argument('--use_tanh', type=str2bool, help='whether use tanh', default=False)
61 | self.parser.add_argument('--use_softpool', type=str2bool, help='whether use softpool', default=False)
62 | self.parser.add_argument('--use_aspp', type=str2bool, help='whether use aspp', default=False)
63 | self.parser.add_argument('--use_se_loss', type=str2bool, help='whether use SE loss', default=False)
64 |         self.parser.add_argument('--use_att', type=str2bool, help='whether use attention', default=False)
65 | self.parser.add_argument('--initial_channels', help='initial_channels of NestUnet', type=int, default=16)
66 | self.parser.add_argument('--mf_groups', help='number of groups', type=int, default=16)
67 | self.parser.add_argument('--norm', help='type of normalization', type=str, default='bn',
68 | choices=['bn', 'gn', 'in', 'sync_bn'])
69 | self.parser.add_argument('--act', help='type of activation function', type=str, default='relu',
70 | choices=['relu', 'lrelu', 'elu', 'gelu'])
71 | self.parser.add_argument('--use_wds', type=str2bool, help='whether use weighted deep supervision',
72 | default=False)
73 | self.parser.add_argument('--add_dropout_layer', type=str2bool, help='whether use dropout layer in HighResNet',
74 | default=False)
75 | self.parser.add_argument('--dimensions', type=int, default=3, help='Dimensions of HighResNet')
76 | self.parser.add_argument('--use_paf', type=str2bool, help='PostFusion_orit: whether use part affinity field',
77 | default=False)
78 | self.parser.add_argument('--paf_sigmoid', type=str2bool, help='whether use sigmoid for the branch of '
79 | 'part affinity field', default=False)
80 | self.parser.add_argument('--pif_sigmoid', type=str2bool, help='whether use sigmoid for the branch of '
81 | 'part intensity field', default=False)
82 |         self.parser.add_argument('--final_double', type=str2bool, help='whether use a double 1x1 convolution '
83 |                                                                        'in the final layer', default=False)
84 | self.parser.add_argument('--HRNet_c', type=int, default=12)
85 | self.parser.add_argument('--n_block', type=int, default=2)
86 | self.parser.add_argument('--reduce_ratio', type=int, default=1)
87 | self.parser.add_argument('--n_stages', nargs='+', type=int, default=[1, 1, 1, 1])
88 | self.parser.add_argument('--use_uncert', type=str2bool, help='whether use uncert for loss weights',
89 | default=False)
90 | self.parser.add_argument('--Gau_num', type=int, default=2)
91 | self.parser.add_argument('--use_seg_gau', type=str2bool, help='whether use seg and gau', default=False)
92 | self.parser.add_argument('--gau_thresh', type=float, default=0.5)
93 | self.parser.add_argument('--use_lw', type=str2bool, help='whether use lightweight', default=False)
94 | self.parser.add_argument('--lw_kernel', type=int, default=3)
95 |
96 | # training hyper-parameters
97 | self.parser.add_argument('--learning_rate', type=float, default=5e-5)
98 | self.parser.add_argument('--batch_size', help='batch size', type=int, default=32)
99 |         self.parser.add_argument('--val_batch_size', help='validation batch size', type=int, default=0)
100 | self.parser.add_argument('--max_epoch', help='number of epochs', type=int, default=100)
101 | self.parser.add_argument('--loss_func_seg', help='seg loss function type', type=str, default='Dice')
102 | self.parser.add_argument('--loss_func_dn', help='denoising loss function type', type=str, default='MSE')
103 | self.parser.add_argument('--loss_func_paf', help='paf loss function type', type=str, default='MSE')
104 |         self.parser.add_argument('--pred2d_3d', type=str2bool, help='whether use 2D-to-3D prediction', default=False)
105 |         self.parser.add_argument('--threshold', type=float, default=0.5, help="threshold used when computing segmentation metrics")
106 | self.parser.add_argument('--others', help='others', type=str, default='')
107 | self.parser.add_argument('--paf_weight', type=int, default=1, help='Weight for Paf branch')
108 | self.parser.add_argument('--border_value', type=int, help='border width', default=0)
109 | self.parser.add_argument('--dset_name', type=str, help="the name of dataset")
110 | self.parser.add_argument('--train_mode', type=str, default='train', help='train mode')
111 | self.parser.add_argument('--gpu_id', nargs='+', type=int, default=[0, 1, 2, 3], help='gpu id')
112 |         self.parser.add_argument('--prf1_alpha', type=float, default=3, help="alpha used when computing precision/recall/F1 metrics")
113 | self.parser.add_argument('--running', type=str2bool, help='whether use LAM', default=False)
114 |
115 | # Contrastive Learning hyper-parameters
116 | self.parser.add_argument('--checkpoints', type=str, help='Checkpoint directory',
117 | default=None)
118 |         self.parser.add_argument('--checkpoints_version', type=str, help='Checkpoint version',
119 |                                  default=None)
120 |         self.parser.add_argument('--cent_feats', type=str, help='Path to center features',
121 |                                  default=None)
122 | self.parser.add_argument('--particle_idx', type=int, default=70, help='Index of reference particle')
123 |         self.parser.add_argument('--sel_particle_num', type=int, default=100, help='Number of selected particles')
124 | self.parser.add_argument('--iteration_idx', type=int, default=0, help='Iteration index')
125 |         self.parser.add_argument('--cent_kernel', type=int, default=1, help='Kernel size for center features')
126 | self.parser.add_argument('--Sel_Referance', type=str2bool, default=False, help='Select Reference Particle')
127 |         self.parser.add_argument('--step1', type=str2bool, default=False, help='Run step 1')
128 |         self.parser.add_argument('--step2', type=str2bool, default=False, help='Run step 2')
129 | self.parser.add_argument('--dir_name', type=str, help='Directory name',
130 | default=None)
131 |         self.parser.add_argument('--stride', type=int, default=8, help='Stride')
132 | self.parser.add_argument('--seg_tau', type=float, default=0.95, help='Segmentation threshold')
133 | self.parser.add_argument('--use_mask', type=str2bool, help='use mask to cal loss for SSL',
134 | default=False)
135 | self.parser.add_argument('--use_ema', type=str2bool, default=False,
136 | help='use EMA model')
137 | self.parser.add_argument('--use_bg_part', type=str2bool, default=False,
138 | help='use background particle')
139 | self.parser.add_argument('--use_ice_part', type=str2bool, default=False,
140 | help='use ice bg')
141 | self.parser.add_argument('--use_SimSeg_iteration', type=str2bool, default=False,
142 | help='use SimSeg_iteration')
143 | self.parser.add_argument('--ema_decay', default=0.999, type=float,
144 | help='EMA decay rate')
145 |         self.parser.add_argument('--T', type=float, default=0.5, help='Sharpening temperature')
146 |         self.parser.add_argument('--coord_path', type=str, help='Coordinate path name',
147 |                                  default=None)
148 |
149 | # test_parameters
150 | self.parser.add_argument('--test_idxs', nargs='+', type=int, default=[0])
151 |         self.parser.add_argument('--save_pred', type=str2bool, help='whether save predictions', default=False)
152 | self.parser.add_argument('--max_pxs', type=int, help='dilation pixel numbers', default=18)
153 |         self.parser.add_argument('--de_duplication', type=str2bool, default=False, help='Whether de-duplicate detections')
154 | self.parser.add_argument('--test_mode', type=str, default='test_val', help='test mode')
155 |         self.parser.add_argument('--paf_connect', type=str2bool, default=False, help='Whether use PAF connection')
156 | self.parser.add_argument('--Gau_nms', type=str2bool, default=False, help='Whether use gaussian NMS')
157 | self.parser.add_argument('--save_mrc', type=str2bool, default=False, help='Whether save .mrc file')
158 | self.parser.add_argument('--nms_kernel', type=int, help='kernel size for Gaussian NMS', default=3)
159 | self.parser.add_argument('--nms_topK', type=int, help='topK for Gaussian NMS', default=3)
160 | self.parser.add_argument('--pif_model', type=str, default='', help='pif model for paf-connect')
161 | self.parser.add_argument('--first_idx', type=int, help='first_idx', default=0)
162 | self.parser.add_argument('--use_CL', type=str2bool, default=False, help='Whether use Contrastive Learning')
163 |         self.parser.add_argument('--use_cluster', type=str2bool, default=False, help='Whether use clustering')
164 | self.parser.add_argument('--use_CL_DA', type=str2bool, default=False,
165 | help='Whether use DA for reference particle of Contrastive Learning')
166 | self.parser.add_argument('--CL_DA_fmt', type=str, default='mean',
167 | help='format of calculating similarity map under CL_DA')
168 |         self.parser.add_argument('--ResearchTitle', type=str, default='None',
169 |                                  help='research title tag used for naming')
170 |         self.parser.add_argument('--skip_4v94', type=str2bool, default=False,
171 |                                  help='Whether to skip 4V94 evaluation or not. True in SHREC Cryo-ET 2021 results.')
172 |         self.parser.add_argument('--skip_vesicles', type=str2bool, default=False,
173 |                                  help='Whether to skip vesicles or not. True in SHREC Cryo-ET 2021 results.')
174 | self.parser.add_argument('--out_name', type=str, default='TestRes',
175 | help='file name for saving the predicted coordinates')
176 |
177 | self.parser.add_argument('--train_set_ids', type=str, default="0")
178 | self.parser.add_argument('--val_set_ids', type=str, default="0")
179 | self.parser.add_argument('--cfg_save_path', type=str, default=".")
180 | # optim parameters
181 | self.parser.add_argument('--optim', type=str, default='AdamW')
182 | self.parser.add_argument('--scheduler', type=str, default='OneCycleLR')
183 | self.parser.add_argument('--weight_decay', type=float, default=0.01, help="torch.optim: weight decay")
184 | self.parser.add_argument('--use_dilation', type=str2bool, default=False, help='Whether use dilation')
185 |         self.parser.add_argument('--use_seg', type=str2bool, default=False, help='Whether use segmentation')
186 |         self.parser.add_argument('--use_eval', type=str2bool, default=False, help='Whether use evaluation')
187 |
188 | # loss parameters
189 | self.parser.add_argument('--use_weight', type=str2bool, help='whether use different weights for cls losses',
190 | default=False)
191 | self.parser.add_argument('--NoBG', type=str2bool,
192 | help='whether calculate BG loss (the 0th dim of softmax outputs)',
193 | default=False)
194 | self.parser.add_argument('--pad_loss', type=str2bool, help='whether use padding loss', default=False)
195 | self.parser.add_argument('--alpha', type=float, default=0.7, help="Focal Tversky Loss: alpha * FP")
196 | self.parser.add_argument('--beta', type=float, default=0.3, help="Focal Tversky Loss: beta * FN")
197 | self.parser.add_argument('--gamma', type=float, default=0.75, help="Focal Tversky Loss: focal gamma")
198 | self.parser.add_argument('--eta', type=float, default=0.3, help="Dice_SE_Loss: weight of SE loss")
199 | self.parser.add_argument('--FL_a0', type=float, default=0.1, help="Soft_FL Loss: weight of a0")
200 | self.parser.add_argument('--FL_a1', type=float, default=0.9, help="Soft_FL Loss: weight of a1")
201 |
202 | # eval parameters
203 | self.parser.add_argument('--JudgeInDilation', type=str2bool, default=False)
204 | self.parser.add_argument('--save_FPsTPs', type=str2bool, default=False,
205 | help='Whether save the results of FP and TP')
206 | self.parser.add_argument('--de_dup_fmt', type=str, default='fmt4', help='de-duplication format')
207 |         self.parser.add_argument('--eval_str', type=str, default='class', help='evaluation mode string')
208 |         self.parser.add_argument('--min_vol', type=int, default=100, help='Minimum volume')
209 |         self.parser.add_argument('--mini_dist', type=int, default=10, help='Minimum distance')
210 |         self.parser.add_argument('--min_dist', type=int, default=10, help='Minimum distance')
211 |         self.parser.add_argument('--eval_cls', type=str2bool, default=False, help='Whether evaluate per class')
212 | self.parser.add_argument('--class_checkpoints', type=str, help='Checkpoint directory',
213 | default=None)
214 | self.parser.add_argument('--meanPool_NMS', type=str2bool, default=False, help='mean_pool NMS')
215 | self.parser.add_argument('--meanPool_kernel', type=int, default=5, help='mean_pool NMS')
216 |
217 | def gather_options(self):
218 | args = self.parser.parse_args()
219 | return args
220 |
--------------------------------------------------------------------------------
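The `str2bool` helper in `options/option.py` above exists because argparse's `type=bool` simply calls `bool()` on the raw argument string, so `--flag False` still parses as `True`. A small self-contained sketch of the difference (hypothetical flag names, standard library only):

```python
import argparse

def str2bool(v):
    # same logic as options/option.py
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

parser = argparse.ArgumentParser()
parser.add_argument('--naive', type=bool, default=False)       # pitfall
parser.add_argument('--robust', type=str2bool, default=False)  # correct

args = parser.parse_args(['--naive', 'False', '--robust', 'False'])
print(args.naive)   # True  -- bool('False') is truthy
print(args.robust)  # False
```

With `type=str2bool`, `--flag False`, `--flag no`, and `--flag 0` all parse to `False` as expected.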
/requirement.txt:
--------------------------------------------------------------------------------
1 | matplotlib
2 | pandas
3 | mrcfile
4 | scikit-image
5 | pillow==8.3.2
6 | pycm
7 | scikit-plot
8 | pyqtgraph==0.12.1
9 | PyQt5==5.15.4
10 | batchgenerators==0.21
11 | tinyaes
12 | pytorch_lightning==1.1.0
13 | opencv-python==4.2.0.34
14 | numpy==1.24.0
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import time
4 | import torch
5 | from torch import nn
6 | import mrcfile
7 | import pytorch_lightning as pl
8 | from torch.utils.data import DataLoader
9 | import sys
10 | # from dataset.dataloader import Dataset_ClsBased
11 | import pandas as pd
12 | import importlib
13 | from glob import glob
14 | import matplotlib.pyplot as plt
15 | from pytorch_lightning import Trainer
16 | from options.option import BaseOptions
17 | from model.model_loader import get_model
18 | from utils.misc import combine, get_centroids, de_dup, cal_metrics_NMS_OneCls
19 | import json
20 | from dataset.dataloader_DynamicLoad import Dataset_ClsBased
21 |
22 |
23 | def test_func(args, stdout=None):
24 | if stdout is not None:
25 | save_stdout = sys.stdout
26 | save_stderr = sys.stderr
27 | sys.stdout = stdout
28 | sys.stderr = stdout
29 | for test_idx in args.test_idxs:
30 | model_name = args.checkpoints.split('/')[-4] + '_' + args.checkpoints.split('/')[-1].split('-')[0]
31 | # load config parameters
32 | if len(args.configs) > 0:
33 | with open(args.configs, 'r') as f:
34 | cfg = json.loads(''.join(f.readlines()).lstrip('train_configs='))
35 |
36 | start_time = time.time()
37 | args.data_split[-2] = test_idx
38 | args.data_split[-1] = test_idx + 1
39 |
40 | num_name = pd.read_csv(os.path.join(cfg["tomo_path"], 'num_name.csv'), sep='\t', header=None)
41 | dir_list = num_name.iloc[:, 1]
42 | dir_name = dir_list[args.data_split[-2]]
43 | print(dir_name)
44 |
45 | tomo_file = glob(cfg["tomo_path"] + "/*%s" % cfg["tomo_format"])[0]
46 | data_file = mrcfile.open(tomo_file, permissive=True)
47 | data_shape = data_file.data.shape
48 | print(data_shape)
49 | dataset = cfg["dset_name"]
50 |
51 | if args.use_seg:
52 | for pad_size in args.pad_size:
53 |
54 | class UNetTest(pl.LightningModule):
55 | def __init__(self):
56 | super(UNetTest, self).__init__()
57 | self.model = get_model(args)
58 | self.partical_volume = 4 / 3 * np.pi * (cfg["label_diameter"] / 2) ** 3
59 | self.num_classes = args.num_classes
60 |
61 | def forward(self, x):
62 | return self.model(x)
63 |
64 | def test_step(self, test_batch, batch_idx):
65 | with torch.no_grad():
66 | img, label, index = test_batch
67 | index = torch.cat([i.view(1, -1) for i in index], dim=0).permute(1, 0)
68 | if args.use_paf:
69 | seg_output, paf_output, logsigma1 = self.forward(img)
70 | else:
71 | seg_output = self.forward(img)
72 |
73 | if args.test_use_pad:
74 | mp_num = int(sorted([int(i) for i in cfg["ocp_diameter"].split(',')])[-1] / (args.meanPool_kernel - 1) + 1)
75 | if args.num_classes > 1:
76 | return self._nms_v2(seg_output[:, 1:], kernel=args.meanPool_kernel,
77 | mp_num=mp_num, positions=index)
78 | else:
79 | return self._nms_v2(seg_output[:, :], kernel=args.meanPool_kernel,
80 | mp_num=mp_num, positions=index)
81 |
82 | def test_step_end(self, outputs):
83 | return outputs
84 |
85 | def test_epoch_end(self, epoch_output):
86 | with torch.no_grad():
87 | if args.meanPool_NMS:
88 | coords_out = torch.cat(epoch_output, dim=0).detach().cpu().numpy()
89 | print('coords_out:', coords_out.shape)
90 | if args.de_duplication:
91 | centroids = de_dup(coords_out, args)
92 | out_dir = '/'.join(args.checkpoints.split('/')[:-2]) + f'/{args.out_name}'
93 | os.makedirs(os.path.join(out_dir, 'Coords_withArea'), exist_ok=True)
94 | np.savetxt(os.path.join(out_dir, 'Coords_withArea', dir_name + '.coords'),
95 | centroids.astype(float),
96 | fmt='%s',
97 | delimiter='\t')
98 |
99 | coords = centroids[:, 0:4]
100 | os.makedirs(os.path.join(out_dir, 'Coords_All'), exist_ok=True)
101 | np.savetxt(os.path.join(out_dir, 'Coords_All', dir_name + '.coords'),
102 | coords.astype(int),
103 | fmt='%s',
104 | delimiter='\t')
105 |
106 |
107 | def test_dataloader(self):
108 | if args.test_mode == 'test':
109 | test_dataset = Dataset_ClsBased(mode='test',
110 | block_size=args.block_size,
111 | num_class=args.num_classes,
112 | random_num=args.random_num,
113 | use_bg=args.use_bg,
114 | data_split=args.data_split,
115 | test_use_pad=args.test_use_pad,
116 | pad_size=pad_size,
117 | cfg=cfg,
118 | args=args)
119 | test_dataloader = DataLoader(test_dataset,
120 | shuffle=False,
121 | batch_size=args.batch_size,
122 | num_workers=8 if args.batch_size >= 32 else 4,
123 | pin_memory=False)
124 |
125 | self.len_block = test_dataset.test_len
126 | self.data_shape = test_dataset.data_shape
127 | self.occupancy_map = test_dataset.occupancy_map
128 | self.gt_coords = test_dataset.gt_coords
129 | self.dir_name = test_dataset.dir_name
130 | return test_dataloader
131 | elif args.test_mode == 'test_only':
132 | test_dataset = Dataset_ClsBased(mode='test_only',
133 | block_size=args.block_size,
134 | num_class=args.num_classes,
135 | random_num=args.random_num,
136 | use_bg=args.use_bg,
137 | data_split=args.data_split,
138 | test_use_pad=args.test_use_pad,
139 | pad_size=pad_size,
140 | cfg=cfg,
141 | args=args)
142 | if args.batch_size <= 32:
143 | num_work = 4
144 | elif args.batch_size <= 64:
145 | num_work = 8
146 | elif args.batch_size <= 128:
147 | num_work = 8
148 | else:
149 | num_work = 16
150 | test_dataloader = DataLoader(test_dataset,
151 | shuffle=False,
152 | batch_size=args.batch_size,
153 | num_workers=num_work,
154 | pin_memory=False)
155 | self.len_block = test_dataset.test_len
156 | self.data_shape = test_dataset.data_shape
157 | self.dir_name = test_dataset.dir_name
158 | return test_dataloader
159 |
160 | def _nms_v2(self, pred, kernel=3, mp_num=5, positions=None):
161 | pred = torch.where(pred > 0.5, 1, 0)
162 | meanPool = nn.AvgPool3d(kernel, 1, kernel // 2).cuda()
163 | maxPool = nn.MaxPool3d(kernel, 1, kernel // 2).cuda()
164 | hmax = pred.clone().float()
165 | for _ in range(mp_num):
166 | hmax = meanPool(hmax)
167 | pred = hmax.clone()
168 | hmax = maxPool(hmax)
169 | keep = ((hmax == pred).float()) * ((pred > 0.1).float())
170 | coords = keep.nonzero() # [N, 5]
171 | coords = coords[coords[:, 2] >= args.pad_size[0]]
172 | coords = coords[coords[:, 2] <= args.block_size - args.pad_size[0]]
173 | coords = coords[coords[:, 3] >= args.pad_size[0]]
174 | coords = coords[coords[:, 3] <= args.block_size - args.pad_size[0]]
175 | coords = coords[coords[:, 4] >= args.pad_size[0]]
176 | coords = coords[coords[:, 4] <= args.block_size - args.pad_size[0]]
177 |
178 | try:
179 | h_val = torch.cat(
180 | [hmax[item[0], item[1], item[2], item[3]:item[3] + 1, item[4]:item[4] + 1] for item in
181 | coords], dim=0)
182 | leftTop_coords = positions[coords[:, 0]] - (args.block_size // 2) - args.pad_size[0]
183 | coords[:, 2:5] = coords[:, 2:5] + leftTop_coords
184 |
185 | pred_final = torch.cat(
186 | [coords[:, 1:2] + 1, coords[:, 4:5], coords[:, 3:4], coords[:, 2:3], h_val],
187 | dim=1)
188 |
189 | return pred_final
190 | except Exception:
191 | # no candidate peaks survived filtering; return an empty prediction
192 | return torch.zeros([0, 5]).cuda()
193 |
194 | # load trained checkpoints to model
195 | model = UNetTest.load_from_checkpoint(args.checkpoints)
196 |
197 | # model = UNetTest().model
198 | model.eval()
199 | runner = Trainer(gpus=args.gpu_id, #
200 | accelerator='dp'
201 | )
202 | os.makedirs(f'result/{dataset}/{model_name}/', exist_ok=True)
203 |
204 | runner.test(model=model)
205 |
206 | end_time = time.time()
207 | used_time = end_time - start_time
208 | save_path = '/'.join(args.checkpoints.split('/')[:-2]) + f'/{args.out_name}'
209 | os.makedirs(save_path, exist_ok=True)
210 | pad_size = args.pad_size[0]
211 |
212 | print('*' * 100)
213 | print('Testing Finished!')
214 | print('*' * 100)
215 | if stdout is not None:
216 | sys.stdout = save_stdout
217 | sys.stderr = save_stderr
218 |
219 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import os
5 | import sys
6 | from dataset.dataloader_DynamicLoad import Dataset_ClsBased
7 | from torch.utils.data import DataLoader
8 | from torchvision.utils import make_grid
9 | import pytorch_lightning as pl
10 | from pytorch_lightning import Trainer
11 | from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
12 | from pytorch_lightning import loggers
13 | from pytorch_lightning.callbacks.early_stopping import EarlyStopping
14 | from utils.loss import DiceLoss
15 | from utils.metrics import seg_metrics
16 | from utils.colors import COLORS
17 | from model.model_loader import get_model
18 | from utils.misc import combine, cal_metrics_NMS_OneCls, get_centroids, cal_metrics_MultiCls, combine_torch
19 | from sklearn.metrics import precision_recall_fscore_support
20 | import time
21 | import json
22 |
23 | if not sys.warnoptions:
24 | import warnings
25 |
26 | warnings.simplefilter("ignore")
27 |
28 |
29 | class UNetExperiment(pl.LightningModule):
30 | def __init__(self, args):
31 | if args.f_maps is None:
32 | args.f_maps = [32, 64, 128, 256]
33 | print(args.pad_size)
34 |
35 | if len(args.configs) > 0:
36 | with open(args.configs, 'r') as f:
37 | self.cfg = json.loads(''.join(f.readlines()).lstrip('train_configs='))
38 | else:
39 | self.cfg = {}
40 | if len(args.train_configs) > 0:
41 | with open(args.train_configs, 'r') as f:
42 | self.train_cfg = json.loads(''.join(f.readlines()).lstrip('train_configs='))
43 | else:
44 | self.train_cfg = self.cfg
45 |
46 | if len(args.val_configs) > 0:
47 | with open(args.val_configs, 'r') as f:
48 | self.val_cfg = json.loads(''.join(f.readlines()).lstrip('train_configs='))
49 | else:
50 | self.val_cfg = self.cfg
51 |
52 | super(UNetExperiment, self).__init__()
53 | self.save_hyperparameters()
54 | self.model = get_model(args)
55 | print(self.model)
56 |
57 | if args.loss_func_seg == 'Dice':
58 | self.loss_function_seg = DiceLoss(args=args)
59 |
60 | if 'gaussian' in self.val_cfg["label_type"]:
61 | self.thresholds = np.linspace(0.15, 0.45, 7)
62 | elif 'sphere' in self.val_cfg["label_type"]:
63 | self.thresholds = np.linspace(0.2, 0.80, 13)
64 | self.partical_volume = 4 / 3 * np.pi * (self.val_cfg["label_diameter"] / 2) ** 3
65 | self.args = args
66 |
67 | def forward(self, x):
68 | return self.model(x)
69 |
70 | def training_step(self, train_batch, batch_idx):
71 | args = self.args
72 | img, label, index = train_batch
73 | img = img.to(torch.float32)
74 | seg_output = self.forward(img)
75 | if args.use_mask:
76 | mask = label.clone().detach()
77 | mask[mask > 0] = 1
78 | label[label < 255] = 0
79 | label[label > 0] = 1
80 |
81 | # update label and mask according to label-threshold
82 | label[seg_output > args.seg_tau] = 1
83 | mask[seg_output > args.seg_tau] = 1
84 | mask[seg_output < (1 - args.seg_tau)] = 1
85 |
86 | seg_output = seg_output * mask
87 | loss_seg = self.loss_function_seg(seg_output, label)
88 | self.log('train_loss', loss_seg, on_step=False, on_epoch=True)
89 | return loss_seg
90 |
91 | def validation_step(self, val_batch, batch_idx):
92 | args = self.args
93 | with torch.no_grad():
94 | img, label, index = val_batch
95 | index = torch.cat([i.view(1, -1) for i in index], dim=0).permute(1, 0)
96 | img = img.to(torch.float32)
97 | self.seg_output = self.forward(img)
98 |
99 | if (batch_idx >= self.len_block // args.batch_size and args.test_mode == "test_val") or \
100 | args.test_mode == "test" or args.test_mode == "val" or args.test_mode == "val_v1":
101 | loss_seg = self.loss_function_seg(self.seg_output, label)
102 |
103 | precision, recall, f1_score, iou = seg_metrics(self.seg_output, label, threshold=args.threshold)
104 |
105 | self.log('val_loss', loss_seg, on_step=False, on_epoch=True)
106 | self.log('val_precision', precision, on_step=False, on_epoch=True)
107 | self.log('val_recall', recall, on_step=False, on_epoch=True)
108 | self.log('val_f1', f1_score, on_step=False, on_epoch=True)
109 | self.log('val_iou', iou, on_step=False, on_epoch=True)
110 |
111 | # return loss_seg
112 | tensorboard = self.logger.experiment
113 |
114 | if (batch_idx == (self.len_block // args.batch_size + 1) and args.test_mode == 'test_val') or \
115 | (batch_idx == 0 and args.test_mode == 'test') or \
116 | (batch_idx == 0 and args.test_mode == 'val') or \
117 | (
118 | batch_idx == 0 and args.test_mode == 'val_v1'): # and True == False and self.current_epoch % 1 == 0
119 | img /= img.abs().max() # [-1,1]
120 | img = img * 0.5 + 0.5 # [0, 1]
121 | img_ = img[0, :, 0:(args.block_size - 1):5, :, :].permute(1, 0, 2, 3).repeat(
122 | (1, 3, 1, 1)) # sample0: [5, 3, y, x]
123 |
124 | label_ = label[0, :, 0:(args.block_size - 1):5, :, :] # sample0 [15, y, x]
125 | temp = torch.zeros(
126 | (len(np.arange(0, args.block_size - 1, 5)), args.block_size, args.block_size, 3)).float()
127 | # print(label.shape, temp.shape)
128 | for idx in np.arange(label_.shape[0]):
129 | temp[label_[idx] > 0.5] = torch.tensor(
130 | COLORS[(idx + 1) if (args.num_classes == 1
131 | or args.use_paf or
132 | label_.shape[0] == 1) else idx]).float()
133 | label__ = temp.permute(0, 3, 1, 2).contiguous().cuda() # [15, 3, y, x]
134 |
135 | seg_output_ = self.seg_output[0, :, 0:(args.block_size - 1):5, :, :] # sample0 [15, y, x]
136 | seg_threshes = [0.5, 0.3, 0.2, 0.15, 0.1, 0.05]
137 | seg_preds = []
138 | for thresh in seg_threshes:
139 | temp = torch.zeros(
140 | (len(np.arange(0, args.block_size - 1, 5)), args.block_size, args.block_size, 3)).float()
141 | for idx in np.arange(seg_output_.shape[0]):
142 | temp[seg_output_[idx] > thresh] = torch.tensor(
143 | COLORS[(idx + 1) if (args.num_classes == 1
144 | or args.use_paf or
145 | seg_output_.shape[0] == 1) else idx]).float()
146 | seg_preds.append(temp.permute(0, 3, 1, 2).contiguous().cuda()) # [15, 3, y, x]
147 |
148 | seg_preds = torch.cat(seg_preds, dim=0)
149 |
150 | img_label_seg = torch.cat([img_, label__, seg_preds], dim=0)
151 | img_label_seg = make_grid(img_label_seg, (args.block_size - 1) // 5 + 1, padding=2, pad_value=120)
152 |
153 | tensorboard.add_image('img_label_seg', img_label_seg, self.current_epoch, dataformats="CHW")
154 |
155 | if args.num_classes > 1:
156 | return self._nms_v2(self.seg_output[:, 1:], kernel=args.meanPool_kernel, mp_num=6, positions=index)
157 | else:
158 | return self._nms_v2(self.seg_output[:, :], kernel=args.meanPool_kernel, mp_num=6, positions=index)
159 |
160 | def validation_step_end(self, outputs):
161 | args = self.args
162 | if 'test' in args.test_mode:
163 | return outputs
164 |
165 | def validation_epoch_end(self, epoch_output):
166 | args = self.args
167 | with torch.no_grad():
168 | if 'test' in args.test_mode:
169 | if args.meanPool_NMS:
170 | if args.num_classes == 1:
171 | # coords_out: [N, 5]
172 | coords_out = torch.cat(epoch_output, dim=0).detach().cpu().numpy()
173 | if coords_out.shape[0] > 50000:
174 | loc_p, loc_r, loc_f1, avg_dist = 1e-10, 1e-10, 1e-10, 100
175 | else:
176 | loc_p, loc_r, loc_f1, avg_dist = \
177 | cal_metrics_NMS_OneCls(coords_out,
178 | self.gt_coords,
179 | self.occupancy_map,
180 | self.cfg,
181 | )
182 | print("*" * 100)
183 | print(f"Precision:{loc_p}")
184 | print(f"Recall:{loc_r}")
185 | print(f"F1-score:{loc_f1}")
186 | print(f"Avg-dist:{avg_dist}")
187 | print("*" * 100)
188 | self.log('cls_precision', loc_p, on_step=False, on_epoch=True)
189 | self.log('cls_recall', loc_r, on_step=False, on_epoch=True)
190 | self.log('cls_f1', loc_f1, on_step=False, on_epoch=True)
191 | self.log('cls_dist', avg_dist, on_step=False, on_epoch=True)
192 | pr = (loc_p * (loc_r ** args.prf1_alpha)) / (loc_p + (loc_r ** args.prf1_alpha) + 1e-10)
193 | self.log(f'cls_pr_alpha{args.prf1_alpha:.1f}', pr, on_step=False, on_epoch=True)
194 | time.sleep(0.5)
195 | else:
196 | coords_out = torch.cat(epoch_output, dim=0).detach().cpu().numpy()
197 | loc_p, loc_r, loc_f1, loc_miss, avg_dist, gt_classes, pred_classes, self.num2pdb, cls_f1 = \
198 | cal_metrics_MultiCls(coords_out, self.gt_coords, self.occupancy_map, self.cfg, args,
199 | args.pad_size, self.dir_name, self.partical_volume)
200 | self.log('cls_f1', cls_f1, on_step=False, on_epoch=True)
201 |
202 | def train_dataloader(self):
203 | args = self.args
204 | train_dataset = Dataset_ClsBased(mode=args.train_mode,
205 | block_size=args.block_size,
206 | num_class=args.num_classes,
207 | random_num=args.random_num,
208 | use_bg=args.use_bg,
209 | data_split=args.data_split,
210 | use_paf=args.use_paf,
211 | cfg=self.train_cfg,
212 | args=args)
213 | return DataLoader(train_dataset,
214 | batch_size=args.batch_size,
215 | num_workers=8 if args.batch_size >= 32 else 4,
216 | shuffle=True,
217 | pin_memory=False)
218 |
219 | def val_dataloader(self):
220 | args = self.args
221 | val_dataset = Dataset_ClsBased(mode=args.test_mode,
222 | block_size=args.val_block_size,
223 | num_class=args.num_classes,
224 | random_num=args.random_num,
225 | use_bg=args.use_bg,
226 | data_split=args.data_split,
227 | test_use_pad=args.test_use_pad,
228 | pad_size=args.pad_size,
229 | use_paf=args.use_paf,
230 | cfg=self.val_cfg,
231 | args=args)
232 |
233 | self.len_block = val_dataset.test_len
234 | if 'test' in args.test_mode:
235 | self.data_shape = val_dataset.data_shape
236 | self.occupancy_map = val_dataset.occupancy_map
237 | self.gt_coords = val_dataset.gt_coords
238 | self.dir_name = val_dataset.dir_name
239 |
240 | val_dataloader1 = DataLoader(val_dataset,
241 | batch_size=args.val_batch_size,
242 | num_workers=8 if args.batch_size >= 32 else 4,
243 | shuffle=False,
244 | pin_memory=False)
245 | return val_dataloader1
246 |
247 | def _nms_v2(self, pred, kernel=3, mp_num=5, positions=None):
248 | args = self.args
249 | pred = torch.where(pred > 0.5, 1, 0)
250 | meanPool = nn.AvgPool3d(kernel, 1, kernel // 2).cuda()
251 | maxPool = nn.MaxPool3d(kernel, 1, kernel // 2).cuda()
252 | hmax = pred.clone().float()
253 | for _ in range(mp_num):
254 | hmax = meanPool(hmax)
255 | pred = hmax.clone()
256 | hmax = maxPool(hmax)
257 | keep = ((hmax == pred).float()) * ((pred > 0.1).float())
258 | coords = keep.nonzero() # [N, 5]
259 | if coords.shape[0] > 2000:
260 | return torch.zeros([1, 5]).cuda()
261 | coords = coords[coords[:, 2] >= args.pad_size]
262 | coords = coords[coords[:, 2] < args.block_size - args.pad_size]
263 | coords = coords[coords[:, 3] >= args.pad_size]
264 | coords = coords[coords[:, 3] < args.block_size - args.pad_size]
265 | coords = coords[coords[:, 4] >= args.pad_size]
266 | coords = coords[coords[:, 4] < args.block_size - args.pad_size]
267 |
268 | try:
269 | h_val = torch.cat(
270 | [hmax[item[0], item[1], item[2], item[3]:item[3] + 1, item[4]:item[4] + 1] for item in
271 | coords], dim=0)
272 | leftTop_coords = positions[coords[:, 0]] - (args.block_size // 2) - args.pad_size
273 | coords[:, 2:5] = coords[:, 2:5] + leftTop_coords
274 |
275 | pred_final = torch.cat(
276 | [coords[:, 1:2] + 1, coords[:, 4:5], coords[:, 3:4], coords[:, 2:3], h_val],
277 | dim=1)
278 |
279 | return pred_final
280 | except Exception:
281 | return torch.zeros([0, 5]).cuda()
282 |
283 | def configure_optimizers(self):
284 | args = self.args
285 | if args.optim == 'SGD':
286 | optimizer = torch.optim.SGD(self.parameters(),
287 | lr=args.learning_rate,
288 | momentum=0.9, weight_decay=0.001
289 | )
290 | elif args.optim == 'Adam':
291 | optimizer = torch.optim.Adam(self.parameters(),
292 | lr=args.learning_rate,
293 | betas=(0.9, 0.99)
294 | )
295 | elif args.optim == 'AdamW':
296 | optimizer = torch.optim.AdamW(self.parameters(),
297 | lr=args.learning_rate,
298 | betas=(0.9, 0.99),
299 | weight_decay=args.weight_decay
300 | )
301 |
302 | if args.scheduler == 'OneCycleLR':
303 | sched = torch.optim.lr_scheduler.OneCycleLR(optimizer,
304 | max_lr=args.learning_rate,
305 | total_steps=args.max_epoch,
306 | pct_start=0.1,
307 | anneal_strategy='cos',
308 | div_factor=30,
309 | final_div_factor=100)
310 | lr_dict = {
311 | "scheduler": sched,
312 | "interval": "epoch",
313 | "frequency": 1
314 | }
315 |
316 | if args.scheduler is None:
317 | return [optimizer]
318 | else:
319 | return [optimizer], [lr_dict]
320 |
321 |
322 | def train_func(args, stdout=None):
323 | if stdout is not None:
324 | save_stdout = sys.stdout
325 | save_stderr = sys.stderr
326 | sys.stdout = stdout
327 | sys.stderr = stdout
328 |
329 | args.pad_size = args.pad_size[0]
330 | if 'test' in args.test_mode:
331 | checkpoint_callback = ModelCheckpoint(save_top_k=1,
332 | monitor=f'cls_pr_alpha{args.prf1_alpha:.1f}' if args.num_classes == 1 else 'cls_f1',
333 | mode='max')
334 | else:
335 | checkpoint_callback = ModelCheckpoint(save_top_k=1,
336 | monitor='val_loss',
337 | mode='min')
338 |
339 | model = UNetExperiment(args)
340 | logger_name = "{}_{}_BlockSize{}_{}Loss_MaxEpoch{}_bs{}_lr{}_IP{}_bg{}_coord{}_Softmax{}_{}_{}_TN{}".format(
341 | model.train_cfg["dset_name"], args.network, args.block_size, args.loss_func_seg, args.max_epoch,
342 | args.batch_size,
343 | args.learning_rate,
344 | int(args.use_IP), int(args.use_bg), int(args.use_coord),
345 | int(args.use_softmax), args.norm, args.others, args.sel_train_num)
346 |
347 | os.makedirs(f"{model.train_cfg['base_path']}/runs/{model.train_cfg['dset_name']}", exist_ok=True)
348 | tb_logger = loggers.TensorBoardLogger(f"{model.train_cfg['base_path']}/runs/{model.train_cfg['dset_name']}",
349 | name=logger_name)
350 | lr_monitor = LearningRateMonitor(logging_interval='step')
351 |
352 | runner = Trainer(min_epochs=min(50, args.max_epoch),
353 | max_epochs=args.max_epoch,
354 | logger=tb_logger,
355 | gpus=args.gpu_id,
356 | checkpoint_callback=checkpoint_callback,
357 | callbacks=[lr_monitor],
358 | accelerator='dp',
359 | precision=32,
360 | profiler=True,
361 | sync_batchnorm=True,
362 | resume_from_checkpoint=args.checkpoints)
363 |
364 |
365 | try:
366 | runner.fit(model)
367 | print('*' * 100)
368 | print('Training Finished')
369 | print(f'Training pid:{os.getpid()}')
370 | print('*' * 100)
371 | torch.cuda.empty_cache()
372 | if stdout is not None:
373 | sys.stderr = save_stderr
374 | sys.stdout = save_stdout
375 | return os.getpid()
376 | except:
377 | torch.cuda.empty_cache()
378 | if stdout is not None:
379 | stdout.flush()
380 | stdout.write('Training Exception!')
381 | sys.stderr = save_stderr
382 | sys.stdout = save_stdout
383 | return os.getpid()
--------------------------------------------------------------------------------
/tutorials/A_tutorial_of_particlePicking_on_EMPIAR10045_dataset.md:
--------------------------------------------------------------------------------
1 | # **A tutorial of single-class particle picking on EMPIAR-10045 dataset**
2 | ## **Step 1. Preprocessing**
3 |
4 | - Data preparation
5 |
6 |
7 | The sample dataset of EMPIAR-10045 can be downloaded in either of two ways:
8 | - Baidu Netdisk Link: [https://pan.baidu.com/s/1aijM4IgGSRMwBvBk5XbBmw](https://pan.baidu.com/s/1aijM4IgGSRMwBvBk5XbBmw ); verification code: cbmi
9 | - Microsoft onedrive Link: [https://1drv.ms/u/s!AmcdnIXL3Vo4hWf05lhsQWZWWSV3?e=dCWvew](https://1drv.ms/u/s!AmcdnIXL3Vo4hWf05lhsQWZWWSV3?e=dCWvew); verification code: cbmi
10 |
11 | - Data structure
12 |
13 | Before launching the graphical user interface, we recommend creating a single folder to save the inputs and outputs of DeepETPicker. Inside this base folder, make a subfolder to store the raw data. This raw_data folder should contain:
14 | - tomograms (with extension .mrc or .rec)
15 | - coordinate files with the same names as the tomograms but a different extension (*.csv, *.coords or *.txt; *.coords is generally recommended)
16 |
17 |
18 | The data should be organized as follows:
19 | ```
20 | ├── /base/path
21 | │ ├── raw_data
22 | │ │ ├── IS002_291013_005_iconmask2_norm_rot_cutZ.coords
23 | │ │ └── IS002_291013_005_iconmask2_norm_rot_cutZ.mrc
24 | │ │ └── IS002_291013_006_iconmask2_norm_rot_cutZ.mrc
25 | │ │ └── IS002_291013_007_iconmask2_norm_rot_cutZ.mrc
26 | │ │ └── IS002_291013_008_iconmask2_norm_rot_cutZ.mrc
27 | │ │ └── IS002_291013_009_iconmask2_norm_rot_cutZ.mrc
28 | │ │ └── IS002_291013_010_iconmask2_norm_rot_cutZ.mrc
29 | │ │ └── IS002_291013_011_iconmask2_norm_rot_cutZ.mrc
30 | ```
31 |
32 | For the above data, `IS002_291013_005_iconmask2_norm_rot_cutZ.mrc` can be used as a train/val dataset, since it has a coordinate file (`IS002_291013_005_iconmask2_norm_rot_cutZ.coords`). If a tomogram (e.g. `IS002_291013_006_iconmask2_norm_rot_cutZ.mrc`) has no manual annotation (`IS002_291013_006_iconmask2_norm_rot_cutZ.coords`), it cannot be used for training or validation.
33 |
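The pairing rule above can be checked with a few lines of Python before preprocessing. A minimal sketch (the helper name is ours, not part of DeepETPicker):

```python
import os
from glob import glob

def split_train_val_candidates(raw_dir, tomo_ext=".mrc"):
    """Return (labeled, unlabeled) tomogram paths in raw_dir.

    A tomogram is a train/val candidate only if a .coords file with
    the same stem sits next to it.
    """
    labeled, unlabeled = [], []
    for tomo in sorted(glob(os.path.join(raw_dir, "*" + tomo_ext))):
        stem = os.path.splitext(tomo)[0]
        (labeled if os.path.exists(stem + ".coords") else unlabeled).append(tomo)
    return labeled, unlabeled
```

Tomograms reported as unlabeled can still be picked at inference time; they just cannot contribute to training or validation.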
34 | - Input & Output
35 |
36 |
37 |

38 |
39 |
40 | Launch the graphical user interface of DeepETPicker. On the `Preprocessing` page, please set some key parameters as follows:
41 | - `input`
42 | - Dataset name: e.g. EMPIAR_10045_preprocess
43 | - Base path: path to base folder
44 | - Coords path: path to raw_data folder
45 | - Coords format: .csv, .coords or .txt
46 | - Tomogram path: path to raw_data folder
47 | - Tomogram format: .mrc or .rec
48 | - Number of classes: the number of particle classes to pick; multiple classes of macromolecules can be localized separately
49 | - `Output`
50 | - Label diameter (in voxels): the diameter of the generated weak label, which is usually smaller than the average particle diameter. Empirically, set it as large as possible while keeping it smaller than the real diameter.
51 | - Ocp diameter (in voxels): the real diameter of the particles. Empirically, to obtain good picking results, we recommend binning the tomogram so that the particle diameter falls in the range of 20~30 voxels. For multiple particle classes, separate the diameters with commas.
52 | - Configs: after clicking `Save configs`, this field holds the path to the file containing all the parameters filled in on this page
53 |
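To make the distinction between `Label diameter` and `Ocp diameter` concrete, here is a sketch of how a spherical (or Gaussian-type) weak label of a given diameter can be rasterized. This is an illustration of the idea only, not DeepETPicker's exact preprocessing code:

```python
import numpy as np

def spherical_label(diameter, gaussian=False):
    """Build a cubic patch holding a spherical (or Gaussian-type) weak label.

    Voxels within diameter/2 of the particle center are foreground;
    with gaussian=True the foreground falls off smoothly from the center.
    """
    r = diameter / 2.0
    size = int(diameter) | 1  # odd size so a center voxel exists
    zz, yy, xx = np.mgrid[:size, :size, :size] - size // 2
    dist2 = zz ** 2 + yy ** 2 + xx ** 2
    if gaussian:
        sigma = r / 2.0
        return np.exp(-dist2 / (2 * sigma ** 2)) * (dist2 <= r ** 2)
    return (dist2 <= r ** 2).astype(np.float32)
```

A `Label diameter` smaller than `Ocp diameter` simply means the rasterized sphere above is smaller than the particle it marks.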
54 |
55 | ### **Step 2. Training of DeepETPicker**
56 |
57 |
58 |
59 | Note: Before `Training of DeepETPicker`, please do `Step 1. Preprocessing` first to ensure that the basic parameters required for training are provided.
60 |
61 |
62 |

63 |
64 |
65 | In practice, default parameters can give you a good enough result.
66 |
67 | *Training parameter description:*
68 |
69 | - Dataset name: e.g. EMPIAR_10045_train
70 | - Dataset list: get the list of train/val tomograms. The first column denotes the particle number, the second column the tomogram name, and the third column the tomogram id. If you have n tomograms, the ids are {0, 1, 2, ..., n-1}.
71 | - Train dataset ids: tomograms used for training. Click `Dataset list` first to obtain the dataset ids. One or multiple tomograms can be used for training, but make sure the training dataset ids are selected from {0, 1, 2, ..., n-1}, where n is the total number of tomograms shown in `Dataset list`. Two formats are supported:
72 | - 0,2, ...: different tomogram ids separated by commas.
73 | - 0-m: selects the ids {0, 1, 2, ..., m-1}. Note: this format can only be used for tomograms with continuous ids.
74 | - Val dataset ids: tomograms used for validation. Click `Dataset list` first to obtain the dataset ids. Note: only one tomogram can be selected as the val dataset.
75 | - Number of classes: particle classes you want to pick
76 | - Batch size: the number of samples processed before the model is updated. It is limited by your GPU memory; reducing it may help if you encounter an out-of-memory error.
77 | - Patch size: the size of each subtomogram. It must be a multiple of 8; a value of at least 64 is recommended, and the default is 72.
78 | - Padding size: a hyperparameter of overlap-tile strategy. Usually, it can be from 6 to 12, and the default value is 12.
79 | - Learning rate: the step size at each iteration while moving toward a minimum of a loss function.
80 | - Max epoch: total training epochs. The default value 60 is usually sufficient.
81 | - GPU id: the GPUs used for training, e.g. 0,1,2,3 denotes using GPUs 0-3. You can run `nvidia-smi` to list the available GPUs.
82 | - Save Configs: save the configs listed above. The saved configuration contains all the parameters filled in on this page and can be loaded directly via *`Load configs`* next time instead of filling them in again.
83 |
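The two dataset-id formats accepted above behave as in this sketch (`parse_dataset_ids` is a hypothetical helper for illustration, not a function of the GUI):

```python
def parse_dataset_ids(spec, n_tomos):
    """Parse the two id formats: '0,2,5' -> [0, 2, 5]; '0-5' -> [0, 1, 2, 3, 4].

    The range form '0-m' selects {0, 1, ..., m-1}. Ids must lie in
    {0, ..., n_tomos - 1}, where n_tomos comes from `Dataset list`.
    """
    if "-" in spec:
        start, end = (int(v) for v in spec.split("-"))
        ids = list(range(start, end))
    else:
        ids = [int(v) for v in spec.split(",")]
    if any(i < 0 or i >= n_tomos for i in ids):
        raise ValueError(f"ids must lie in [0, {n_tomos - 1}]")
    return ids
```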
84 | ### **Step 3. Inference of DeepETPicker**
85 |
86 |
87 |

88 |
89 |
90 | In practice, default parameters can give you a good enough result.
91 |
92 | *Inference parameter description:*
93 |
94 | - Train Configs: path to the configuration file saved in the `Training` step
95 | - Network weights: path to the model generated in the `Training` step
96 | - Patch size & Pad_size: in this stage the tomogram is scanned with a patch size of N and a stride S, where `S = N - 2*Pad_size`.
97 | - GPU id: the GPUs used for inference, e.g. 0,1,2,3 denotes using GPUs 0-3. You can run `nvidia-smi` to list the available GPUs.
98 |
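The relation `S = N - 2*Pad_size` determines how the overlap-tile scan covers each axis. A sketch of the resulting patch offsets (names are ours, not the actual inference code):

```python
def scan_positions(dim_len, patch=72, pad=12):
    """Patch start offsets along one axis for stride S = N - 2*pad."""
    stride = patch - 2 * pad  # S = N - 2*Pad_size
    starts = list(range(0, max(dim_len - patch, 0) + 1, stride))
    if starts[-1] + patch < dim_len:  # add one last patch to cover the edge
        starts.append(dim_len - patch)
    return stride, starts
```

With the defaults (N=72, Pad_size=12) the stride is 48, so neighboring patches overlap by 24 voxels and only the interior of each patch is kept.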
99 | ### **Step 4. Particle visualization and manual picking**
100 |
101 |
102 |

103 |
104 |
105 | - *Showing Tomogram*
106 |
107 | You can click the `Load tomogram` button on this page to load the tomogram.
108 |
109 | - *Showing Labels*
110 |
111 | After loading the coordinates file by clicking `Load labels`, you can click `Show result` to visualize the labels. The label diameter and width can also be tuned in the GUI.
112 |
113 | - *Parameter Adjustment*
114 |
115 | To improve the visibility of particles, Gaussian filtering and histogram equalization are provided:
116 | - Filter: choosing Gaussian applies a Gaussian filter to the displayed tomogram; kernel_size and sigma (standard deviation) can be tuned to adjust the visual effect
117 | - Contrast: choosing hist-equ performs histogram equalization
118 |
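The two display options can be reproduced outside the GUI. A plain NumPy/SciPy sketch (not the GUI's exact code), useful for checking filter settings on a single slice:

```python
import numpy as np
from scipy.ndimage import gaussian_filter

def enhance_slice(slice2d, sigma=1.0, hist_equ=True, bins=256):
    """Denoise a tomogram slice and optionally stretch its contrast.

    Mirrors the two GUI options: Gaussian filtering, then histogram
    equalization mapping intensities to [0, 1].
    """
    out = gaussian_filter(slice2d.astype(np.float32), sigma=sigma)
    if hist_equ:
        hist, edges = np.histogram(out, bins=bins)
        cdf = hist.cumsum().astype(np.float32)
        cdf = (cdf - cdf[0]) / max(cdf[-1] - cdf[0], 1)  # normalize to [0, 1]
        out = np.interp(out, edges[:-1], cdf)
    return out
```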
119 |
120 | - *Position Slider*
121 |
122 | You can scan through the volume in the x, y and z directions by changing their values. For z-axis scanning, the Up/Down arrow keys can be used as shortcuts.
123 |
124 |
125 |
--------------------------------------------------------------------------------
/tutorials/A_tutorial_of_particlePicking_on_SHREC2021_dataset.md:
--------------------------------------------------------------------------------
1 | # **A tutorial of multiple-class particle picking on SHREC2021 dataset**
2 | ## **Step 1. Preprocessing**
3 |
4 | - Data preparation
5 |
6 |
7 | The sample dataset of SHREC_2021 can be downloaded in either of two ways:
8 | - Baidu Netdisk Link: [https://pan.baidu.com/s/1aijM4IgGSRMwBvBk5XbBmw](https://pan.baidu.com/s/1aijM4IgGSRMwBvBk5XbBmw ); verification code: cbmi
9 | - Microsoft onedrive Link: [https://1drv.ms/u/s!AmcdnIXL3Vo4hWf05lhsQWZWWSV3?e=dCWvew](https://1drv.ms/u/s!AmcdnIXL3Vo4hWf05lhsQWZWWSV3?e=dCWvew); verification code: cbmi
10 |
11 | - Data structure
12 |
13 | Before launching the graphical user interface, we recommend creating a single folder to save the inputs and outputs of DeepETPicker. Inside this base folder, make a subfolder to store the raw data. This raw_data folder should contain:
14 | - tomograms (with extension .mrc or .rec)
15 | - coordinate files with the same names as the tomograms but a different extension (*.csv, *.coords or *.txt; *.coords is generally recommended)
16 |
17 |
18 |
19 | The complete data should be organized as follows:
20 | ```
21 | ├── /base/path
22 | │  └── raw_data
23 | │     ├── model_0.coords
24 | │     ├── model_0.mrc
25 | │     ├── model_1.coords
26 | │     ├── model_1.mrc
27 | │     ├── model_2.coords
28 | │     ├── model_2.mrc
29 | │     ├── model_3.coords
30 | │     ├── model_3.mrc
31 | │     ├── model_4.coords
32 | │     ├── model_4.mrc
33 | │     ├── model_5.coords
34 | │     ├── model_5.mrc
35 | │     ├── model_6.coords
36 | │     ├── model_6.mrc
37 | │     ├── model_7.coords
38 | │     ├── model_7.mrc
39 | │     ├── model_8.coords
40 | │     ├── model_8.mrc
41 | │     └── model_9.mrc
42 | ```
43 |
44 | For the above data, `model_0.mrc` to `model_8.mrc` can be used as train/val datasets, since they all have coordinate files (`model_0.coords` to `model_8.coords`). A tomogram without a manual annotation, such as `model_9.mrc` (no `model_9.coords`), cannot be used for training or validation.
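The pairing rule above can be sketched in a few lines of Python (a hypothetical helper for checking your raw_data folder, not part of DeepETpicker):

```python
def split_train_val_candidates(filenames):
    """Split raw_data contents into tomograms usable for train/val
    (a matching .coords file exists) and tomograms without labels."""
    names = set(filenames)
    usable, unlabeled = [], []
    for name in sorted(names):
        stem, _dot, ext = name.rpartition(".")
        if ext in ("mrc", "rec"):
            # a tomogram is train/val material only if it is annotated
            (usable if stem + ".coords" in names else unlabeled).append(name)
    return usable, unlabeled
```

For the layout above, `model_0.mrc` to `model_8.mrc` would land in the first list and `model_9.mrc` in the second.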
45 |
46 |
47 | - Input & Output
48 |
49 |
50 |

51 |
52 |
53 | Launch the graphical user interface of DeepETPicker. On the `Preprocessing` page, please set some key parameters as follows:
54 | - `input`
55 | - Dataset name: e.g. SHREC_2021_preprocess
56 | - Base path: path to base folder
57 | - Coords path: path to raw_data folder
58 | - Coords format: .csv, .coords or .txt
59 | - Tomogram path: path to raw_data folder
60 | - Tomogram format: .mrc or .rec
61 | - Number of classes: the number of particle classes; multiple classes of macromolecules can be localized separately
62 | - `Output`
63 | - Label diameter(in voxels): the diameter of the generated weak labels, which is usually smaller than the average particle diameter. Empirically, you can set it as large as possible, as long as it stays smaller than the real diameter.
64 | - Ocp diameter(in voxels): the real diameter of the particles. Empirically, to obtain good picking results, we recommend binning the tomograms so that the particle diameter falls in the range of 20~30 voxels. For particles of multiple classes, separate their diameters with commas.
65 | - Configs: if you click 'Save configs', this is the path to the file that stores all the parameters filled in on this page
66 |
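As a rough illustration of the 20~30-voxel recommendation for the Ocp diameter, the following hypothetical helper picks a power-of-two binning factor (binning by k divides the particle diameter by k; it is not part of DeepETpicker):

```python
def suggest_binning(diameter_voxels, target_high=30):
    """Return the smallest power-of-two binning factor that brings the
    particle diameter down to at most `target_high` voxels."""
    bin_factor = 1
    binned = diameter_voxels
    while binned > target_high:
        bin_factor *= 2
        binned = diameter_voxels / bin_factor
    return bin_factor, binned
```

For example, a particle of about 100 voxels would be binned by 4, giving a 25-voxel diameter, inside the recommended range.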
67 |
68 | ### **Step 2. Training of DeepETPicker**
69 |
70 |
71 |
72 | Note: before `Training of DeepETPicker`, please complete `Step 1. Preprocessing` first to ensure that the basic parameters required for training are available.
73 |
74 |
75 |

76 |
77 |
78 | In practice, the default parameters usually give sufficiently good results.
79 |
80 | *Training parameter description:*
81 |
82 | - Dataset name: e.g. SHREC_2021_train
83 | - Model name: the name of segmentation model.
84 | - Dataset list: shows the list of train/val tomograms. The first column is the particle number, the second the tomogram name, and the third the tomogram id. If you have n tomograms, the ids are {0, 1, 2, ..., n-1}.
85 | - Train dataset ids: tomograms used for training. Click `Dataset list` first to obtain the dataset ids. One or more tomograms can be used for training, but make sure the training dataset ids are selected from {0, 1, 2, ..., n-1}, where n is the total number of tomograms shown in `Dataset list`. Two ways of specifying dataset ids are provided:
86 | - 0,2, ...: individual tomogram ids separated by commas.
87 | - 0-m: selects the ids {0, 1, 2, ..., m-1}. Note: this form can only be used for tomograms with continuous ids.
88 | - Val dataset ids: tomograms used for validation. Click `Dataset list` first to obtain the dataset ids. Note: only one tomogram can be selected as the val dataset.
89 | - Number of classes: particle classes you want to pick
90 | - Batch size: the number of samples processed before the model is updated. It is limited by your GPU memory; reducing it may help if you encounter an out-of-memory error.
91 | - Patch size: the size of the subtomograms. It must be a multiple of 8. A value of at least 64 is recommended; the default is 72.
92 | - Padding size: a hyperparameter of the overlap-tile strategy. It usually ranges from 6 to 12; the default is 12.
93 | - Learning rate: the step size at each iteration while moving toward a minimum of a loss function.
94 | - Max epoch: total training epochs. The default value 60 is usually sufficient.
95 | - GPU id: the GPUs used for training, e.g. 0,1,2,3 denotes using GPUs 0-3. You can run `nvidia-smi` to get the information of available GPUs.
96 | - Save Configs: saves the configs listed above. The saved configuration contains all the parameters filled in on this page and can be loaded directly via *`Load configs`* next time instead of filling them in again.
97 |
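The two id formats can be parsed as sketched below (an illustrative snippet, not the GUI's actual code; per the description above, `0-m` selects the ids 0 to m-1):

```python
def parse_dataset_ids(spec, n_tomograms):
    """Parse a '0,2,5'-style or '0-m'-style dataset id specification."""
    if "-" in spec:
        start, end = spec.split("-")
        ids = list(range(int(start), int(end)))  # '0-m' -> {0, ..., m-1}
    else:
        ids = [int(tok) for tok in spec.split(",") if tok.strip()]
    bad = [i for i in ids if not 0 <= i < n_tomograms]
    if bad:
        raise ValueError(f"ids {bad} are outside 0..{n_tomograms - 1}")
    return ids
```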
98 | ### **Step 3. Inference of DeepETPicker**
99 |
100 |
101 |

102 |
103 |
104 | In practice, the default parameters usually give sufficiently good results.
105 |
106 | *Inference parameter description:*
107 |
108 | - Train Configs: path to the configuration file saved in the `Training` step
109 | - Network weights: path to the model generated in the `Training` step
110 | - Patch size & Pad_size: in this stage the tomogram is scanned with a patch size of N and a stride S, where `S = N - 2*Pad_size`.
111 | - GPU id: the GPUs used for inference, e.g. 0,1,2,3 denotes using GPUs 0-3. You can run `nvidia-smi` to get the information of available GPUs.
112 |
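From `S = N - 2*Pad_size`, the effective stride and a rough count of patches per axis follow directly (an illustrative sketch; the actual tiling code may differ):

```python
import math

def scan_plan(axis_length, patch_size, pad_size):
    """Effective stride of the overlap-tile scan and an approximate
    number of patches needed along one tomogram axis."""
    stride = patch_size - 2 * pad_size  # S = N - 2 * Pad_size
    n_patches = math.ceil(axis_length / stride)
    return stride, n_patches
```

With the defaults (patch size 72, padding 12) the stride is 48, so a 512-voxel axis needs about 11 patches.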
113 |
114 |
115 |
116 | ### **Step 4. Particle visualization and manual picking**
117 |
118 |
119 |

120 |
121 |
122 | - *Showing Tomogram*
123 |
124 | You can click the `Load tomogram` button on this page to load the tomogram.
125 |
126 | - *Showing Labels*
127 |
128 | After loading the coordinates file by clicking `Load labels`, you can click `Show result` to visualize the labels. The diameter and width of the labels can also be tuned in the GUI.
129 |
130 | - *Parameter Adjustment*
131 |
132 | To improve the visibility of particles, Gaussian filtering and histogram equalization are provided:
133 | - Filter: choosing Gaussian applies a Gaussian filter to the displayed tomogram; kernel_size and sigma (standard deviation) can be tuned to adjust the visual effect
134 | - Contrast: choosing hist-equ performs histogram equalization
135 |
136 |
137 | - *Position Slider*
138 |
139 | You can scan through the volume in the x, y and z directions by changing their values. For z-axis scanning, the Up/Down arrow keys can be used as shortcuts.
140 |
141 |
142 |
--------------------------------------------------------------------------------
/tutorials/images/EMPIAR_10045_GIF/Inference.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cbmi-group/DeepETPicker/0b3f8cb298128805a23cfd2909f81010006d2390/tutorials/images/EMPIAR_10045_GIF/Inference.gif
--------------------------------------------------------------------------------
/tutorials/images/EMPIAR_10045_GIF/Preprocessing.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cbmi-group/DeepETPicker/0b3f8cb298128805a23cfd2909f81010006d2390/tutorials/images/EMPIAR_10045_GIF/Preprocessing.gif
--------------------------------------------------------------------------------
/tutorials/images/EMPIAR_10045_GIF/Training.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cbmi-group/DeepETPicker/0b3f8cb298128805a23cfd2909f81010006d2390/tutorials/images/EMPIAR_10045_GIF/Training.gif
--------------------------------------------------------------------------------
/tutorials/images/EMPIAR_10045_GIF/Visualization.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cbmi-group/DeepETPicker/0b3f8cb298128805a23cfd2909f81010006d2390/tutorials/images/EMPIAR_10045_GIF/Visualization.gif
--------------------------------------------------------------------------------
/tutorials/images/Inference.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cbmi-group/DeepETPicker/0b3f8cb298128805a23cfd2909f81010006d2390/tutorials/images/Inference.gif
--------------------------------------------------------------------------------
/tutorials/images/Preprocessing.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cbmi-group/DeepETPicker/0b3f8cb298128805a23cfd2909f81010006d2390/tutorials/images/Preprocessing.gif
--------------------------------------------------------------------------------
/tutorials/images/Training.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cbmi-group/DeepETPicker/0b3f8cb298128805a23cfd2909f81010006d2390/tutorials/images/Training.gif
--------------------------------------------------------------------------------
/tutorials/images/Visualization.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cbmi-group/DeepETPicker/0b3f8cb298128805a23cfd2909f81010006d2390/tutorials/images/Visualization.gif
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from utils.colors import *
2 | from utils.coord_gen import *
3 | from utils.coords2labels import *
4 | from utils.loss import *
5 | from utils.metrics import *
6 | from utils.misc import *
7 | from utils.normalization import *
8 | from utils.coordconv_torch import *
--------------------------------------------------------------------------------
/utils/colors.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import cv2
4 |
5 | COLORS = (
6 | (255, 255, 255), # 0
7 | (244, 67, 54), # 1
8 | (233, 30, 99), # 2
9 | (156, 39, 176), # 3
10 | (103, 58, 183), # 4
11 | (63, 81, 181), # 5
12 | (33, 150, 243), # 6
13 | (3, 169, 244), # 7
14 | (0, 188, 212), # 8
15 | (0, 150, 136), # 9
16 | (76, 175, 80), # 10
17 | (139, 195, 74), # 11
18 | (205, 220, 57), # 12
19 | (255, 235, 59), # 13
20 | (255, 193, 7), # 14
21 | (255, 152, 0), # 15
22 | (255, 87, 34),
23 | (121, 85, 72),
24 | (158, 158, 158),
25 | (96, 125, 139))
26 |
27 |
28 | def plot_legend(COLORS):
29 | plt.figure(figsize=(6, 0.8), dpi=100)
30 | for idx in np.arange(1, 13):
31 | plt.subplot(1, 12, idx)
32 | data = np.array(COLORS[idx] * 30000).reshape(100, 300, 3)
33 | plt.imshow(data)
34 | plt.axis('off')
35 | plt.title('%d' % idx, fontdict={'size': 15, 'weight': 'bold'})
36 | plt.tight_layout()
37 | plt.savefig('temp.png')
38 |
39 | img = cv2.imread('temp.png')
40 | img_sum = np.sum(np.array(img), axis=2)
41 | img[img_sum == 765] = [0, 0, 0]
42 | img[img_sum == 0] = [255, 255, 255]
43 | cv2.imwrite('temp.png', img)
--------------------------------------------------------------------------------
/utils/coordFormatConvert.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def coords2star(data, save_path):
4 | """Relion star"""
5 | string = """
6 | data_
7 |
8 | loop_
9 | _rlnCoordinateX #1
10 | _rlnCoordinateY #2
11 | _rlnCoordinateZ #3"""
12 |
13 | data = np.round(data, 1).astype(str)
14 |
15 | with open(save_path, 'w') as f:
16 | f.writelines(string + '\n')
17 | for item in data:
18 | line = ' '.join(item)
19 | f.write(line + '\n')
20 |
21 |
22 | def coords2box(data, save_path):
23 | """EMAN2 box"""
24 | with open(save_path, 'w') as f:
25 | for item in data:
26 | line = f"{item[0]:.1f}\t{item[1]:.1f}\t{item[2]:.0f}"
27 | f.write(line + '\n')
28 |
29 | def coords2coords(data, save_path):
30 |     """Plain tab-separated integer coordinates (x, y, z)"""
31 | with open(save_path, 'w') as f:
32 | for item in data:
33 | line = f"{item[0]:.0f}\t{item[1]:.0f}\t{item[2]:.0f}"
34 | f.write(line + '\n')
35 |
--------------------------------------------------------------------------------
/utils/coord_gen.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | import os
3 | import glob
4 | import numpy as np
5 | import pandas as pd
6 | import sys
7 |
8 | warnings.simplefilter('ignore')
9 | base_dir = "/ldap_shared/synology_shared/IBP_ribosome/liver_NewConstruct/Pick/reconstruction_2400"
10 |
11 | def coords_gen(coord_path, coord_format, base_dir):
12 | os.makedirs(os.path.join(base_dir, 'coords'), exist_ok=True)
13 |
14 | coords_list = [i.split('/')[-1] for i in glob.glob(coord_path + f'/*{coord_format}')]
15 | coords_list = sorted(coords_list)
16 |
17 | num_all = []
18 | dir_names = []
19 | for dir in coords_list:
20 | data = []
21 | with open(os.path.join(coord_path, dir), 'r') as f:
22 | for idx, item in enumerate(f):
23 | data.append(item.rstrip('\n').split())
24 |         # np.float was removed in NumPy 1.24; float32 suffices here,
25 |         # since the values are cast to int immediately afterwards
26 |         data = np.array(data).astype(np.float32).astype(int)
28 |
29 | np.savetxt(os.path.join(base_dir, "coords", dir),
30 | data, delimiter='\t', newline="\n", fmt="%s")
31 | num_all.append(data.shape[0])
32 |
33 | dir_name = dir[:-len(coord_format)]
34 | dir_names.append(dir_name)
35 |
36 | str_ = "|"
37 | for i in np.arange(0, len(num_all), 1):
38 | tmp = np.array(num_all)[:i+1].sum()
39 | print("0 to %d:" % (i+1), tmp)
40 | str_ += "%d|" % tmp
41 | # print(str_)
42 | # print(list(enumerate(num_all)))
43 |
44 | # gen num_name.csv
45 | num_name = np.array([num_all]).transpose()
46 |     # np.float was removed in NumPy 1.24
47 |     df = pd.DataFrame(num_name).astype(np.float32)
50 | df['dir_names'] = np.array([dir_names]).transpose()
51 | df['idx'] = np.arange(len(dir_names)).reshape(-1, 1)
52 | df.to_csv(os.path.join(base_dir, "coords", "num_name.csv"), sep='\t', header=False, index=False)
53 |
54 |
55 | def coords_gen_show(args):
56 | coord_path, coord_format, base_dir, stdout = args
57 | if stdout is not None:
58 | save_stdout = sys.stdout
59 | save_stderr = sys.stderr
60 | sys.stdout = stdout
61 | sys.stderr = stdout
62 |
63 | try:
64 | coords_gen(coord_path, coord_format, base_dir)
65 | print('Coord generation finished!')
66 | print('*' * 100)
67 |     except Exception as ex:
68 |         # guard against stdout being None before writing to it
69 |         if stdout is not None:
70 |             stdout.flush()
71 |             stdout.write(f'{ex}\nCoordinates Generation Exception!')
72 |         else:
73 |             print(f'{ex}\nCoordinates Generation Exception!')
74 |         print('*' * 100)
75 |         return 0
72 |
73 | if stdout is not None:
74 | sys.stderr = save_stderr
75 | sys.stdout = save_stdout
--------------------------------------------------------------------------------
/utils/coordconv_torch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.modules.conv as conv
4 |
5 |
6 | class AddCoords(nn.Module):
7 | def __init__(self, rank, with_r=False):
8 | super(AddCoords, self).__init__()
9 | self.rank = rank
10 | self.with_r = with_r
11 |
12 | def forward(self, input_tensor):
13 |         """
14 |         :param input_tensor: shape (N, C_in, *spatial_dims), with 1, 2 or 3 spatial dims
15 |         :return: input tensor with normalized coordinate channels appended
16 |         """
17 | if self.rank == 1:
18 | batch_size_shape, channel_in_shape, dim_x = input_tensor.shape
19 | x_range = torch.linspace(-1, 1, dim_x, device=input_tensor.device)
20 | xx_channel = x_range.expand([batch_size_shape, 1, -1])
21 |
22 | out = torch.cat([input_tensor, xx_channel], dim=1)
23 |
24 | if self.with_r:
25 | rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2))
26 | out = torch.cat([out, rr], dim=1)
27 |
28 | elif self.rank == 2:
29 | batch_size_shape, channel_in_shape, dim_y, dim_x = input_tensor.shape
30 | x_range = torch.linspace(-1, 1, dim_x, device=input_tensor.device)
31 | y_range = torch.linspace(-1, 1, dim_y, device=input_tensor.device)
32 | yy_channel, xx_channel = torch.meshgrid(y_range, x_range)
33 | yy_channel = yy_channel.expand([batch_size_shape, 1, -1, -1])
34 | xx_channel = xx_channel.expand([batch_size_shape, 1, -1, -1])
35 |
36 | out = torch.cat([input_tensor, xx_channel, yy_channel], dim=1)
37 |
38 | if self.with_r:
39 | rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2) + torch.pow(yy_channel - 0.5, 2))
40 | out = torch.cat([out, rr], dim=1)
41 |
42 | elif self.rank == 3:
43 | batch_size_shape, channel_in_shape, dim_z, dim_y, dim_x = input_tensor.shape
44 | x_range = torch.linspace(-1, 1, dim_x, device=input_tensor.device)
45 | y_range = torch.linspace(-1, 1, dim_y, device=input_tensor.device)
46 | z_range = torch.linspace(-1, 1, dim_z, device=input_tensor.device)
47 | zz_channel, yy_channel, xx_channel = torch.meshgrid(z_range, y_range, x_range)
48 | zz_channel = zz_channel.expand([batch_size_shape, 1, -1, -1, -1])
49 | yy_channel = yy_channel.expand([batch_size_shape, 1, -1, -1, -1])
50 | xx_channel = xx_channel.expand([batch_size_shape, 1, -1, -1, -1])
51 | out = torch.cat([input_tensor, xx_channel, yy_channel, zz_channel], dim=1)
52 |
53 | if self.with_r:
54 | rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2) +
55 | torch.pow(yy_channel - 0.5, 2) +
56 | torch.pow(zz_channel - 0.5, 2))
57 | out = torch.cat([out, rr], dim=1)
58 | else:
59 | raise NotImplementedError
60 |
61 | return out
62 |
63 |
64 | class CoordConv1d(conv.Conv1d):
65 | def __init__(self, in_channels, out_channels, kernel_size, stride=1,
66 | padding=0, dilation=1, groups=1, bias=True, with_r=False, use_cuda=True):
67 | super(CoordConv1d, self).__init__(in_channels, out_channels, kernel_size,
68 | stride, padding, dilation, groups, bias)
69 | self.rank = 1
70 |         self.addcoords = AddCoords(self.rank, with_r)
71 | self.conv = nn.Conv1d(in_channels + self.rank + int(with_r), out_channels,
72 | kernel_size, stride, padding, dilation, groups, bias)
73 |
74 | def forward(self, input_tensor):
75 |         """
76 |         input_tensor shape: (N, C_in, L)
77 |         output_tensor shape: (N, C_out, L_out)
78 |         :return: CoordConv1d result
79 |         """
80 | out = self.addcoords(input_tensor)
81 | out = self.conv(out)
82 |
83 | return out
84 |
85 |
86 | class CoordConv2d(conv.Conv2d):
87 | def __init__(self, in_channels, out_channels, kernel_size, stride=1,
88 | padding=0, dilation=1, groups=1, bias=True, with_r=False, use_cuda=True):
89 | super(CoordConv2d, self).__init__(in_channels, out_channels, kernel_size,
90 | stride, padding, dilation, groups, bias)
91 | self.rank = 2
92 |         self.addcoords = AddCoords(self.rank, with_r)
93 | self.conv = nn.Conv2d(in_channels + self.rank + int(with_r), out_channels,
94 | kernel_size, stride, padding, dilation, groups, bias)
95 |
96 | def forward(self, input_tensor):
97 |         """
98 |         input_tensor shape: (N, C_in, H, W)
99 |         output_tensor shape: (N, C_out, H_out, W_out)
100 |         :return: CoordConv2d result
101 |         """
102 | out = self.addcoords(input_tensor)
103 | out = self.conv(out)
104 |
105 | return out
106 |
107 |
108 | class CoordConv3d(conv.Conv3d):
109 | def __init__(self, in_channels, out_channels, kernel_size, stride=1,
110 | padding=0, dilation=1, groups=1, bias=True, with_r=False, use_cuda=True):
111 | super(CoordConv3d, self).__init__(in_channels, out_channels, kernel_size,
112 | stride, padding, dilation, groups, bias)
113 | self.rank = 3
114 |         self.addcoords = AddCoords(self.rank, with_r)
115 | self.conv = nn.Conv3d(in_channels + self.rank + int(with_r), out_channels,
116 | kernel_size, stride, padding, dilation, groups, bias)
117 |
118 | def forward(self, input_tensor):
119 |         """
120 |         input_tensor shape: (N, C_in, D, H, W)
121 |         output_tensor shape: (N, C_out, D_out, H_out, W_out)
122 |         :return: CoordConv3d result
123 |         """
124 | out = self.addcoords(input_tensor)
125 | out = self.conv(out)
126 |
127 | return out
--------------------------------------------------------------------------------
/utils/coords2labels.py:
--------------------------------------------------------------------------------
1 | import mrcfile
2 | from multiprocessing import Pool
3 | import pandas as pd
4 | import os
5 | import numpy as np
6 | from glob import glob
7 | import sys
8 | import traceback
9 |
10 | def gaussian3D(shape, sigma=1):
11 | l, m, n = [(ss - 1.) / 2. for ss in shape]
12 | z, y, x = np.ogrid[-l:l + 1, -m:m + 1, -n:n + 1]
13 | sigma = (sigma - 1.) / 2.
14 | h = np.exp(-(x * x + y * y + z * z) / (2 * sigma * sigma))
15 | # h[h < np.finfo(float).eps * h.max()] = 0
16 | return h
17 |
18 |
19 | class Coord_to_Label():
20 | def __init__(self, base_path, coord_path, coord_format, tomo_path, tomo_format,
21 | num_cls, label_type, label_diameter):
22 |
23 | self.base_path = base_path
24 | self.coord_path = coord_path
25 | self.coord_format = coord_format
26 | self.tomo_path = tomo_path
27 | self.tomo_format = tomo_format
28 | self.num_cls = num_cls
29 | self.label_type = label_type
30 | if not isinstance(label_diameter, int):
31 | self.label_diameter = [int(i) for i in label_diameter.split(',')]
32 | else:
33 | self.label_diameter = [label_diameter]
34 |
35 | if 'ocp' in self.label_type.lower():
36 | self.label_path = os.path.join(self.base_path, self.label_type)
37 | else:
38 | self.label_path = os.path.join(self.base_path,
39 | self.label_type + str(self.label_diameter[0]))
40 | os.makedirs(self.label_path, exist_ok=True)
41 |
42 | self.dir_list = [i[:-len(self.coord_format)] for i in os.listdir(self.coord_path) if self.coord_format in i]
43 | self.names = [i + self.tomo_format for i in self.dir_list] # if self.tomo_format not in i else i
44 |
45 | def single_handle(self, i):
46 | self.tomo_file = f"{self.tomo_path}/{self.names[i]}"
47 | data_file = mrcfile.open(self.tomo_file, permissive=True)
48 | # print(os.path.join(self.label_path, self.names[i]))
49 | label_file = mrcfile.new(os.path.join(self.label_path, self.names[i]),
50 | overwrite=True)
51 |
52 | label_positions = pd.read_csv(os.path.join(self.base_path, 'coords', '%s.coords' % self.dir_list[i]), sep='\t',
53 | header=None).to_numpy()
54 |
55 | # template = np.fromfunction(lambda i, j, k: (i - r) * (i - r) + (j - r) * (j - r) + (k - r) * (k - r) <= r * r,
56 | # (2 * r + 1, 2 * r + 1, 2 * r + 1), dtype=int).astype(int)
57 |
58 | z_max, y_max, x_max = data_file.data.shape
59 |         # np.float was removed in NumPy 1.24
60 |         label_data = np.zeros(data_file.data.shape, dtype=np.float32)
63 |
64 | for pos_idx, a_pos in enumerate(label_positions):
65 | if self.num_cls == 1 and len(a_pos) == 3:
66 | x, y, z = a_pos
67 | cls_idx_ = 1
68 | else:
69 | cls_idx_, x, y, z = a_pos
70 |
71 | if 'data_ocp' in self.label_type.lower():
72 | dim = int(self.label_diameter[cls_idx_ - 1])
73 | else:
74 | dim = int(self.label_diameter[0])
75 | radius = int(dim / 2)
76 | r = radius
77 |
78 | template = gaussian3D((dim, dim, dim), dim)
79 |
80 | cls_idx = pos_idx+1 if 'data_ocp' in self.label_type else cls_idx_
81 | # print(self.label_type, dim, cls_idx)
82 | z_start = 0 if z - r < 0 else z - r
83 | z_end = z_max if z + r + 1 > z_max else z + r + 1
84 | y_start = 0 if y - r < 0 else y - r
85 | y_end = y_max if y + r + 1 > y_max else y + r + 1
86 | x_start = 0 if x - r < 0 else x - r
87 | x_end = x_max if x + r + 1 > x_max else x + r + 1
88 |
89 | t_z_start = r - z if z - r < 0 else 0
90 | t_z_end = (r + z_max - z) if z + r + 1 > z_max else 2 * r + 1
91 | t_y_start = r - y if y - r < 0 else 0
92 | t_y_end = (r + y_max - y) if y + r + 1 > y_max else 2 * r + 1
93 | t_x_start = r - x if x - r < 0 else 0
94 | t_x_end = (r + x_max - x) if x + r + 1 > x_max else 2 * r + 1
95 |
96 | # print(z_start, z_end, y_start, y_end, x_start, x_end)
97 | # check border
98 | # print(label_data.shape)
99 | # print(z_start, z_end, y_start, y_end, x_start, x_end)
100 | tmp1 = label_data[z_start:z_end, y_start:y_end, x_start:x_end]
101 | tmp2 = template[t_z_start:t_z_end, t_y_start:t_y_end, t_x_start:t_x_end]
102 |
103 | larger_index = tmp1 < tmp2
104 | tmp1[larger_index] = tmp2[larger_index]
105 |
106 | if 'cubic' in self.label_type.lower():
107 | tg = 0.223 # exp(-1.5)
108 | elif 'sphere' in self.label_type.lower() or 'ocp' in self.label_type.lower():
109 | tg = 0.60653 # exp(-0.5)
110 | else:
111 | tg = 0.367879 # exp(-1)
112 | tmp1[tmp1 <= tg] = 0
113 | tmp1 = np.where(tmp1 > 0, cls_idx, 0)
114 | label_data[z_start:z_end, y_start:y_end, x_start:x_end] = tmp1
115 |
116 | label_file.set_data(label_data)
117 |
118 | data_file.close()
119 | label_file.close()
120 | # print('work %s done' % i)
121 | # return 'work %s done' % i
122 |
123 | def gen_labels(self):
124 | if len(self.dir_list) == 1:
125 | self.single_handle(0)
126 | else:
127 | with Pool(len(self.dir_list)) as p:
128 | p.map(self.single_handle, np.arange(len(self.dir_list)).tolist())
129 |
130 |
131 | def label_gen_show(args):
132 | base_path, coord_path, coord_format, tomo_path, tomo_format, \
133 | num_cls, label_type, label_diameter, stdout = args
134 | if stdout is not None:
135 | save_stdout = sys.stdout
136 | save_stderr = sys.stderr
137 | sys.stdout = stdout
138 | sys.stderr = stdout
139 |
140 | try:
141 | label_gen = Coord_to_Label(base_path,
142 | coord_path,
143 | coord_format,
144 | tomo_path,
145 | tomo_format,
146 | num_cls,
147 | label_type,
148 | label_diameter)
149 | label_gen.gen_labels()
150 | if 'ocp' not in label_type:
151 | print('Label generation finished!')
152 | print('*' * 100)
153 | else:
154 | print('Occupancy generation finished!')
155 | print('*' * 100)
156 | except Exception as ex:
157 | term = 'Occupancy' if 'ocp' in label_type else 'Label'
158 | if stdout is not None:
159 | stdout.flush()
160 | stdout.write(f"{ex}")
161 | stdout.write(f'{term} Generation Exception!')
162 | print('*' * 100)
163 | else:
164 | traceback.print_exc()
165 | #print(f"{ex}")
166 | print(f'{term} Generation Exception!')
167 | print('*' * 100)
168 | return 0
169 | if stdout is not None:
170 | sys.stderr = save_stderr
171 | sys.stdout = save_stdout
172 |
173 |
174 | class Coord_to_Label_v1():
175 | def __init__(self, tomo_file, coord_file, num_cls, label_diameter, label_type):
176 |
177 | self.tomo_file = tomo_file
178 | self.coord_file = coord_file
179 | self.num_cls = num_cls
180 | self.label_type = label_type
181 | self.label_diameter = label_diameter
182 |
183 | def gen_labels(self):
184 | if '.coords' in self.coord_file or '.txt' in self.coord_file:
185 | data_file = mrcfile.open(self.tomo_file, permissive=True)
186 |
187 | label_positions = pd.read_csv(self.coord_file, sep='\t', header=None).to_numpy()
188 | if self.label_type == 'Coords':
189 | return label_positions
190 |
191 | dim = int(self.label_diameter)
192 | radius = int(dim / 2)
193 | r = radius
194 |
195 | template = gaussian3D((dim, dim, dim), dim)
196 |
197 | z_max, y_max, x_max = data_file.data.shape
198 |             # np.float was removed in NumPy 1.24
199 |             label_data = np.zeros(data_file.data.shape, dtype=np.float32)
202 |
203 | for pos_idx, a_pos in enumerate(label_positions):
204 | if self.num_cls == 1 and len(a_pos) == 3:
205 | x, y, z = a_pos
206 | cls_idx = 1
207 | else:
208 | cls_idx, x, y, z = a_pos
209 | cls_idx = pos_idx+1 if 'ocp' in self.label_type else cls_idx
210 | z_start = 0 if z - r < 0 else z - r
211 | z_end = z_max if z + r + 1 > z_max else z + r + 1
212 | y_start = 0 if y - r < 0 else y - r
213 | y_end = y_max if y + r + 1 > y_max else y + r + 1
214 | x_start = 0 if x - r < 0 else x - r
215 | x_end = x_max if x + r + 1 > x_max else x + r + 1
216 |
217 | t_z_start = r - z if z - r < 0 else 0
218 | t_z_end = (r + z_max - z) if z + r + 1 > z_max else 2 * r + 1
219 | t_y_start = r - y if y - r < 0 else 0
220 | t_y_end = (r + y_max - y) if y + r + 1 > y_max else 2 * r + 1
221 | t_x_start = r - x if x - r < 0 else 0
222 | t_x_end = (r + x_max - x) if x + r + 1 > x_max else 2 * r + 1
223 |
224 | # print(z_start, z_end, y_start, y_end, x_start, x_end)
225 | # check border
226 | tmp1 = label_data[z_start:z_end, y_start:y_end, x_start:x_end]
227 | tmp2 = template[t_z_start:t_z_end, t_y_start:t_y_end, t_x_start:t_x_end]
228 |
229 | larger_index = tmp1 < tmp2
230 | tmp1[larger_index] = tmp2[larger_index]
231 | tmp1[tmp1 <= 0.36788] = 0
232 |
233 | tmp1 = np.where(tmp1 > 0, cls_idx, 0)
234 |
235 | label_data[z_start:z_end, y_start:y_end, x_start:x_end] = tmp1
236 |
237 | data_file.close()
238 | return label_data
239 | elif '.mrc' in self.tomo_file or '.rec' in self.tomo_file:
240 | label_data = mrcfile.open(self.coord_file, permissive=True)
241 | return label_data.data
242 |
243 | if __name__ == "__main__":
244 | import sys
245 | sys.path.append("..")
246 | from configs.c2l_10045_New_bin8_mask3 import pre_config
247 | from utils.coord_gen import coords_gen
248 | from utils.normalization import InputNorm
249 |
250 |     # Initialize the coordinate files as integers
251 |     coords_gen(pre_config["coord_path"],
252 |                pre_config["coord_format"],
253 |                pre_config["base_path"])
253 |
254 |     # Normalization
255 | pre_norm = InputNorm(pre_config["tomo_path"],
256 | pre_config["tomo_format"],
257 | pre_config["base_path"],
258 | pre_config["norm_type"])
259 | pre_norm.handle_parallel()
260 |
261 |     # Generate labels from coords
262 | c2l = Coord_to_Label(base_path=pre_config["base_path"],
263 | coord_path=pre_config["coord_path"],
264 | coord_format=pre_config["coord_format"],
265 | tomo_path=pre_config["tomo_path"],
266 | tomo_format=pre_config["tomo_format"],
267 | num_cls=pre_config["num_cls"],
268 | label_type=pre_config["label_type"],
269 | label_diameter=pre_config["label_diameter"],
270 | )
271 | c2l.gen_labels()
272 |
273 |     # Generate occupancy maps (ocp) from coords
274 |     if pre_config["label_diameter"] != pre_config["ocp_diameter"]:
275 | c2l = Coord_to_Label(base_path=pre_config["base_path"],
276 | coord_path=pre_config["coord_path"],
277 | coord_format=pre_config["coord_format"],
278 | tomo_path=pre_config["tomo_path"],
279 | tomo_format=pre_config["tomo_format"],
280 | num_cls=pre_config["num_cls"],
281 | label_type=pre_config["ocp_type"],
282 | label_diameter=pre_config["ocp_diameter"],
283 | )
284 | c2l.gen_labels()
285 |
--------------------------------------------------------------------------------
/utils/coords_to_relion4.py:
--------------------------------------------------------------------------------
1 | from io import StringIO
2 | import pandas as pd
3 | import numpy as np
4 | from pathlib import Path
5 |
6 |
7 | def c2w(input_dir="coord_path/Coords_All",
8 | ouput_dir="./output",
9 | ouput_name="coords2relion4.star"):
10 | input_dir = Path(input_dir)
11 | ouput_dir = Path(ouput_dir)
12 | ouput_dir.mkdir(exist_ok=True, parents=True)
13 | ouput_path = ouput_dir / ouput_name
14 | coords_paths = sorted(list(input_dir.glob("*.txt")) + list(input_dir.glob("*.coords")))
15 |
16 | dfs = []
17 | for coords_path in coords_paths:
18 | coords_data = np.loadtxt(coords_path)
19 | XYZ = coords_data[:, -3:]
20 |
21 | TomoName = np.array([str(coords_path).split('/')[-1].split('.')[0]] * coords_data.shape[0], dtype=str).reshape(-1, 1)
22 | TomoParticleId = np.arange(1, coords_data.shape[0]+1).reshape(-1, 1)
23 | originXYZang = np.zeros_like(XYZ)
24 | angle = np.zeros_like(XYZ)
25 | if coords_data.shape[1] == 4:
26 | ClassNumber = coords_data[:, 0].reshape(-1, 1)
27 | else:
28 | ClassNumber = np.ones_like(TomoParticleId)
29 |         randomsubset = np.array([1, 2] * ((coords_data.shape[0] + 1) // 2)).reshape(-1, 1)[:coords_data.shape[0]]
30 | df = pd.DataFrame(np.c_[TomoName, TomoParticleId, XYZ.astype(np.int32), originXYZang, angle, ClassNumber.astype(np.int32), randomsubset])
31 | dfs.append(df)
32 | dfs = pd.concat(dfs, axis=0)
33 |
34 | with StringIO() as buffer:
35 | dfs.to_csv(buffer, sep="\t", index=False, header=None)
36 | lines = buffer.getvalue()
37 |
38 | with open(ouput_path, "w") as ofile:
39 | ofile.write("relion4" + "\n")
40 | ofile.write("data_particles" + "\n")
41 | ofile.write("loop_" + "\n")
42 | ofile.write("_rlnTomoName #1" + "\n")
43 | ofile.write("_rlnTomoParticleId #2" + "\n")
44 | ofile.write("_rlnCoordinateX #3" + "\n")
45 | ofile.write("_rlnCoordinateY #4" + "\n")
46 | ofile.write("_rlnCoordinateZ #5" + "\n")
47 | ofile.write("_rlnOriginXAngst #6" + "\n")
48 | ofile.write("_rlnOriginYAngst #7" + "\n")
49 | ofile.write("_rlnOriginZAngst #8" + "\n")
50 | ofile.write("_rlnAngleRot #9" + "\n")
51 | ofile.write("_rlnAngleTilt #10" + "\n")
52 | ofile.write("_rlnAnglePsi #11" + "\n")
53 | ofile.write("_rlnClassNumber #12" + "\n")
54 | ofile.write("_rlnRandomSubset #13" +"\n")
55 | ofile.write(lines)
56 | print(f"Save: {ouput_path}")
57 |
58 |
59 | if __name__ == "__main__":
60 | c2w()
61 |
--------------------------------------------------------------------------------
/utils/loss.py:
--------------------------------------------------------------------------------
1 | from torch import nn as nn
2 |
3 |
4 | def flatten(tensor):
5 | """Flattens a given tensor such that the channel axis is first.
6 | The shapes are transformed as follows:
7 | (N, C, D, H, W) -> (C, N * D * H * W)
8 | """
9 | # number of channels
10 | C = tensor.size(1)
11 | # new axis order
12 | axis_order = (1, 0) + tuple(range(2, tensor.dim()))
13 | # Transpose: (N, C, D, H, W) -> (C, N, D, H, W)
14 | transposed = tensor.permute(axis_order)
15 | # Flatten: (C, N, D, H, W) -> (C, N * D * H * W)
16 | return transposed.contiguous().view(C, -1)
17 |
18 |
19 | # Dice loss
20 | class DiceLoss(nn.Module):
21 | def __init__(self, smooth=1, args=None):
22 | super(DiceLoss, self).__init__()
23 | self.smooth = smooth
24 | self.use_softmax = args.use_softmax
25 | self.use_sigmoid = args.use_sigmoid
26 |
27 |
28 | def forward(self, outputs, targets):
29 | # flatten label and prediction tensors
30 | outputs = flatten(outputs)
31 | targets = flatten(targets)
32 |
33 | intersection = (outputs * targets).sum(-1)
34 | dice = (2. * intersection + self.smooth) / (outputs.sum(-1) + targets.sum(-1) + self.smooth)
35 | return 1 - dice.mean()
36 |
--------------------------------------------------------------------------------
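The Dice loss above flattens predictions and targets to `(C, -1)` and computes per-channel overlap. A torch-free NumPy sketch of the same math (names `flatten_np` / `dice_loss_np` are illustrative): a perfect prediction yields a loss of 0, and the `smooth` term keeps the ratio finite when both masks are empty.

```python
import numpy as np

def flatten_np(t):
    """(N, C, D, H, W) -> (C, N*D*H*W), mirroring flatten() above."""
    C = t.shape[1]
    axis_order = (1, 0) + tuple(range(2, t.ndim))
    return t.transpose(axis_order).reshape(C, -1)

def dice_loss_np(outputs, targets, smooth=1):
    # per-channel Dice coefficient, averaged, turned into a loss
    outputs, targets = flatten_np(outputs), flatten_np(targets)
    intersection = (outputs * targets).sum(-1)
    dice = (2.0 * intersection + smooth) / (outputs.sum(-1) + targets.sum(-1) + smooth)
    return 1 - dice.mean()

pred = np.ones((1, 2, 2, 2, 2))        # perfect overlap with the target
print(float(dice_loss_np(pred, pred.copy())))  # → 0.0
```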
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | def flatten(tensor):
6 | """Flattens a given tensor such that the channel axis is first.
7 | The shapes are transformed as follows:
8 | (N, C, D, H, W) -> (C, N * D * H * W)
9 | """
10 | # number of channels
11 | C = tensor.size(1)
12 | # new axis order
13 | axis_order = (1, 0) + tuple(range(2, tensor.dim()))
14 | # Transpose: (N, C, D, H, W) -> (C, N, D, H, W)
15 | transposed = tensor.permute(axis_order)
16 | # Flatten: (C, N, D, H, W) -> (C, N * D * H * W)
17 | return transposed.contiguous().view(C, -1)
18 |
19 |
20 | # segmentation metrics
21 | def seg_metrics(y_pred, y_true, smooth=1e-7, isTrain=True, threshold=0.5, use_sigmoid=False):
22 | # apply a sigmoid here only if the model has no final activation layer
23 | if use_sigmoid:
24 | y_pred = torch.sigmoid(y_pred)
25 | y_pred = (y_pred >= threshold).float()
26 |
27 | #flatten label and prediction tensors
28 | y_pred = flatten(y_pred)
29 | y_true = flatten(y_true)
30 |
31 | tp = (y_true * y_pred).sum(-1)
32 | fp = ((1 - y_true) * y_pred).sum(-1)
33 | fn = (y_true * (1 - y_pred)).sum(-1)
34 |
35 | precision = (tp + smooth) / (tp + fp + smooth)
36 | recall = (tp + smooth) / (tp + fn + smooth)
37 | iou = (tp + smooth) / (tp + fn + fp + smooth)
38 | f1 = 2 * (precision*recall) / (precision + recall + smooth)
39 |
40 | mean_precision = precision.mean()
41 | mean_recall = recall.mean()
42 | mean_iou = iou.mean()
43 | mean_f1 = f1.mean()
44 |
45 | # for training, output mean metrics
46 | if isTrain:
47 | return mean_precision.item(), mean_recall.item(), mean_iou.item(), mean_f1.item()
48 | # for testing, output metrics array by class with threshold
49 | else:
50 | return precision.detach().cpu().numpy(), recall.detach().cpu().numpy(), iou.detach().cpu().numpy(), f1.detach().cpu().numpy()
51 |
52 |
53 | def seg_metrics_2d(y_pred, y_true, smooth=1e-7, isTrain=True, threshold=0.5, use_sigmoid=False):
54 | # apply a sigmoid here only if the model has no final activation layer
55 | if use_sigmoid:
56 | y_pred = torch.sigmoid(y_pred)
57 | y_pred = (y_pred >= threshold).float()
58 |
59 | # flatten label and prediction tensors
60 | y_pred = y_pred.flatten()
61 | y_true = y_true.flatten()
62 |
63 | tp = (y_true * y_pred).sum(-1)
64 | fp = ((1 - y_true) * y_pred).sum(-1)
65 | fn = (y_true * (1 - y_pred)).sum(-1)
66 |
67 | precision = (tp + smooth) / (tp + fp + smooth)
68 | recall = (tp + smooth) / (tp + fn + smooth)
69 | iou = (tp + smooth) / (tp + fn + fp + smooth)
70 | f1 = 2 * (precision * recall) / (precision + recall + smooth)
71 |
72 | # for training, output mean metrics
73 | if isTrain:
74 | return precision.item(), recall.item(), iou.item(), f1.item()
75 | # for testing, output metrics array by class with threshold
76 | else:
77 | return precision.detach().cpu().numpy(), recall.detach().cpu().numpy(), iou.detach().cpu().numpy(), f1.detach().cpu().numpy()
78 |
--------------------------------------------------------------------------------
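The segmentation metrics above reduce to counting true positives, false positives, and false negatives per channel. A NumPy sketch of the same formulas (the name `seg_metrics_np` is illustrative), applied to a toy single-channel mask with one TP, one FP, and one FN:

```python
import numpy as np

def seg_metrics_np(y_pred, y_true, smooth=1e-7):
    """Per-channel precision/recall/IoU/F1 on (C, -1) binary masks,
    following the same formulas as seg_metrics() above."""
    tp = (y_true * y_pred).sum(-1)
    fp = ((1 - y_true) * y_pred).sum(-1)
    fn = (y_true * (1 - y_pred)).sum(-1)
    precision = (tp + smooth) / (tp + fp + smooth)
    recall = (tp + smooth) / (tp + fn + smooth)
    iou = (tp + smooth) / (tp + fn + fp + smooth)
    f1 = 2 * precision * recall / (precision + recall + smooth)
    return precision, recall, iou, f1

y_true = np.array([[1, 1, 0, 0]])
y_pred = np.array([[1, 0, 1, 0]])  # 1 TP, 1 FP, 1 FN
p, r, i, f = seg_metrics_np(y_pred, y_true)
print(round(float(p[0]), 3), round(float(r[0]), 3), round(float(i[0]), 3))  # → 0.5 0.5 0.333
```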
/utils/normalization.py:
--------------------------------------------------------------------------------
1 | from multiprocessing import Pool
2 | import mrcfile
3 | import numpy as np
4 | import warnings
5 | import os
6 | import glob
7 | import sys
8 |
9 | warnings.simplefilter('ignore')
10 | class InputNorm():
11 | def __init__(self, tomo_path, tomo_format, base_dir, norm_type):
12 | self.tomo_path = tomo_path
13 | self.tomo_format = tomo_format
14 | self.base_dir = base_dir
15 | self.norm_type = norm_type
16 |
17 |
18 | if self.norm_type == 'standardization':
19 | self.save_dir = os.path.join(self.base_dir, 'data_std')
20 | elif self.norm_type == 'normalization':
21 | self.save_dir = os.path.join(self.base_dir, 'data_norm')
22 | os.makedirs(self.save_dir, exist_ok=True)
23 |
24 | self.dir_list = [os.path.basename(i) for i in glob.glob(self.tomo_path + '/*%s' % self.tomo_format)]
25 | print(self.dir_list)
26 |
27 | def single_handle(self, i):
28 | dir_name = self.dir_list[i]
29 | with mrcfile.open(os.path.join(self.tomo_path, dir_name),
30 | permissive=True) as gm:
31 | # np.float was removed from modern NumPy;
32 | # cast explicitly to float32 instead of
33 | # relying on a try/except fallback
34 | data = np.array(gm.data).astype(np.float32)
35 | # print(data.shape)
36 | if self.norm_type == 'standardization':
37 | data -= data.mean()
38 | data /= data.std()
39 | elif self.norm_type == 'normalization':
40 | data -= data.min()
41 | data /= (data.max() - data.min())
42 |
43 | reconstruction_norm = mrcfile.new(
44 | os.path.join(self.save_dir, dir_name), overwrite=True)
45 | # write as explicit float32 (np.float
46 | # no longer exists in modern NumPy, so
47 | # the old try/except fallback would fail)
48 | reconstruction_norm.set_data(data.astype(np.float32))
49 |
50 | reconstruction_norm.close()
51 | print('%d/%d finished.' % (i + 1, len(self.dir_list)))
52 |
53 | def handle_parallel(self):
54 | with Pool(min(len(self.dir_list), os.cpu_count() or 1)) as p:
55 | p.map(self.single_handle, np.arange(len(self.dir_list)).tolist())
56 |
57 |
58 | def norm_show(args):
59 | tomo_path, tomo_format, base_dir, norm_type, stdout = args
60 | if stdout is not None:
61 | save_stdout = sys.stdout
62 | save_stderr = sys.stderr
63 | sys.stdout = stdout
64 | sys.stderr = stdout
65 |
66 | pre_norm = InputNorm(tomo_path, tomo_format, base_dir, norm_type)
67 | pre_norm.handle_parallel()
68 | print('Standardization finished!')
69 | print('*' * 100)
70 | """
71 | try:
72 | pre_norm = InputNorm(tomo_path, tomo_format, base_dir, norm_type)
73 | pre_norm.handle_parallel()
74 | print('Standardization finished!')
75 | print('*' * 100)
76 | except:
77 | stdout.flush()
78 | stdout.write('Normalization Exception!')
79 | return 0
80 | """
81 | if stdout is not None:
82 | sys.stderr = save_stderr
83 | sys.stdout = save_stdout
--------------------------------------------------------------------------------
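`InputNorm` applies one of two schemes: standardization (zero mean, unit standard deviation) or min-max normalization (scale to [0, 1]). A minimal sketch of both branches on a toy array (function names are illustrative):

```python
import numpy as np

def standardize(data):
    """Zero-mean, unit-std, as in InputNorm's 'standardization' branch."""
    return (data - data.mean()) / data.std()

def min_max_normalize(data):
    """Scale to [0, 1], as in InputNorm's 'normalization' branch."""
    return (data - data.min()) / (data.max() - data.min())

tomo = np.array([0.0, 2.0, 4.0, 6.0])
out = min_max_normalize(tomo)
print(out.min(), out.max())  # → 0.0 1.0
```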
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import cv2
4 | import math
5 | import matplotlib.pyplot as plt
6 | from PyQt5.QtCore import QThread, QObject, pyqtSignal
7 | import threading, inspect
8 | import ctypes
9 |
10 | def rescale(arr):
11 | arr_min = arr.min()
12 | arr_max = arr.max()
13 | return (arr - arr_min) / (arr_max - arr_min)
14 |
15 |
16 | def hist_equ(tomo, min_val=0.01, max_val=0.99):
17 | # cumulative histogram without the plotting side effect of plt.hist
18 | counts, xx = np.histogram(tomo.flatten(), bins=256)
19 | fre_num = np.cumsum(counts) / counts.sum()
20 | min_idx, max_idx = 0, len(fre_num)
21 | for idx, x in enumerate(fre_num):
22 | if x > min_val:
23 | min_idx = int(idx)
24 | break
25 | for idx, x in enumerate(fre_num[::-1]):
26 | if x < max_val:
27 | max_idx = int(len(fre_num) - idx)
28 | break
29 | # clip intensities to the chosen percentile range
30 | tomo = np.clip(tomo, xx[min_idx], xx[max_idx])
31 | return tomo
32 |
33 |
34 | def gauss_filter(kernel_size=3, sigma=1):
35 | max_idx = kernel_size // 2
36 | idx = np.linspace(-max_idx, max_idx, kernel_size)
37 | Y, X = np.meshgrid(idx, idx)
38 | gauss_filter = np.exp(-(X ** 2 + Y ** 2) / (2 * sigma ** 2))
39 | gauss_filter /= gauss_filter.sum()
40 | return gauss_filter
41 |
42 |
43 |
44 | def stretch(tomo):
45 | tomo = (tomo - np.min(tomo)) / (np.max(tomo) - np.min(tomo)) * 255
46 | return np.array(tomo).astype(np.uint8)
47 |
48 |
49 | class myThread(threading.Thread):
50 | def __init__(self, thread_id, func, args, emit_str):
51 | threading.Thread.__init__(self)
52 | self.threadID = thread_id
53 | self.func = func
54 | self.args = args
55 | self.emit_str = emit_str
56 | self.n = 1
57 |
58 | def run(self):
59 | while self.n:
60 | self.pid_num = self.func(self.args, self.emit_str)
61 | self.n -= 1
62 |
63 | def get_n(self):
64 | return self.n
65 |
66 |
67 | def make_video(tomo, save_path, fps, size):
68 | if 'mp4' in save_path:
69 | fourcc = cv2.VideoWriter_fourcc('M', 'P', '4', 'V')
70 | elif 'avi' in save_path:
71 | fourcc = cv2.VideoWriter_fourcc(*'MJPG')
72 |
73 | videowriter = cv2.VideoWriter(save_path,
74 | fourcc,
75 | fps,
76 | size)
77 | for i in range(tomo.shape[0]):
78 | img = tomo[i]
79 | videowriter.write(img)
80 | videowriter.release()  # flush buffered frames to disk
81 |
82 | class Concur(threading.Thread):
83 | """
84 | 停止thread
85 | """
86 | def __init__(self, job, args, stdout):
87 | super(Concur, self).__init__()
88 | self.__flag = threading.Event()
89 | self.__flag.set()
90 | self.__running = threading.Event()
91 | self.__running.set()
92 | self.job = job
93 | self.args = args
94 | self.stdout = stdout
95 |
96 | def run(self):
97 | while self.__running.is_set():
98 | self.__flag.wait()
99 | try:
100 | self.job(self.args, self.stdout)
101 | self.pause()
102 | except Exception as e:
103 | self.stop()
104 | print(e)
105 |
106 | def pause(self):
107 | self.__flag.clear()
108 |
109 | def resume(self):
110 | self.__flag.set()
111 |
112 | def stop(self):
113 | self.__flag.set()
114 | self.__running.clear()
115 |
116 |
117 | def _async_raise(tid, exctype):
118 | """
119 | stop thread
120 | """
121 | """raises the exception, performs cleanup if needed"""
122 | tid = ctypes.c_long(tid)
123 | if not inspect.isclass(exctype):
124 | exctype = type(exctype)
125 | res = ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, ctypes.py_object(exctype))
126 | if res == 0:
127 | raise ValueError("invalid thread id")
128 | elif res != 1:
129 | # """if it returns a number greater than one, you're in trouble,
130 | # and you should call it again with exc=NULL to revert the effect"""
131 | ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, None)
132 | raise SystemError("PyThreadState_SetAsyncExc failed")
133 |
134 |
135 | def stop_thread(thread):
136 | _async_raise(thread.ident, SystemExit)
137 |
138 |
139 | class EmittingStr(QObject):
140 | textWritten = pyqtSignal(str)
141 |
142 | def __init__(self):
143 | super(EmittingStr, self).__init__()
144 |
145 | def write(self, text):
146 | try:
147 | if len(str(text)) >= 2:
148 | self.textWritten.emit(str(text))
149 | except:
150 | pass
151 |
152 | def flush(self, text=None):
153 | pass
154 |
155 |
156 | class ThreadShowInfo(QThread):
157 | def __init__(self, func, args):
158 | super(ThreadShowInfo, self).__init__()
159 | self.func = func
160 | self.args = args
161 |
162 | def run(self):
163 | self.func(self.args)
164 |
165 |
166 | def add_transparency(img, label, factor, color, thresh):
167 | img = np.array(255.0 * rescale(img), dtype=np.uint8)
168 |
169 | alpha_channel = np.ones(img.shape, dtype=img.dtype) * 255
170 |
171 | img_BGRA = cv2.merge((img, img, img, alpha_channel))
172 |
173 | img = img[:, :, np.newaxis]
174 | img1 = img.repeat([3], axis=2)
175 | img1[label.astype(int) == 1] = (255, 0, 0)
176 | img1[label.astype(int) == 2] = (0, 255, 0)
177 | img1[label.astype(int) == 3] = (0, 0, 255)
178 | img1[label.astype(int) == 4] = (0, 255, 255)
179 | # img1[label > factor] = color
180 | c_b, c_g, c_r = cv2.split(img1)
181 | mask = np.where(label > factor, 1, 0)
182 | img1_alpha = np.array(mask * 255 * factor, dtype=np.uint8)
183 | img1_alpha[img1_alpha == 0] = 255
184 | img1_BGRA = cv2.merge((c_b, c_g, c_r, img1_alpha))
185 |
186 | out = cv2.addWeighted(img_BGRA, 1 - factor, img1_BGRA, factor, 0)
187 |
188 | return np.array(out)
189 |
190 |
191 | def annotate_particle(img, coords, diameter, zz, idx, circle_width, color):
192 | """Plot circle centered at the particle coordinates directly on the tomogram."""
193 | img = 255.0 * rescale(img)
194 | img = img[:, :, np.newaxis]
195 | rgb_uint8 = img.repeat([3], axis=2).astype(np.uint8)
196 |
197 | if idx == 0:
198 | columns = ['x', 'y', 'z']
199 | elif idx == 1:
200 | columns = ['z', 'y', 'x']
201 | else:
202 | columns = ['x', 'z', 'y']
203 |
204 | df = pd.DataFrame(data=coords,
205 | columns=columns)
206 | color = np.array(color)
207 | # color is expected to have shape (n_particles, 3)
208 | df["R"] = color[:, 0]
209 | df["G"] = color[:, 1]
210 | df["B"] = color[:, 2]
211 |
212 | r = diameter / 2
213 | coords_xy = df[df['z'] >= (zz - r)]
214 | coords_xy = coords_xy[coords_xy['z'] <= (zz + r)]
215 | rr2 = r ** 2 - (zz - coords_xy['z']) ** 2
216 | rr = [math.sqrt(i) for i in rr2]
217 | for x, y, rad, r, g, b in zip(coords_xy['x'], coords_xy['y'], rr, coords_xy['R'], coords_xy['G'], coords_xy['B']):
218 | cv2.circle(rgb_uint8, (int(y), int(x)), int(rad), (r, g, b), circle_width)
219 | return rgb_uint8
--------------------------------------------------------------------------------
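The `rr` computation in `annotate_particle` is the radius of the circle where a sphere of radius `r` (half the particle diameter) intersects the slice at depth `zz`. A standalone sketch of that geometry (the function name `slice_circle_radius` is illustrative), which also guards against slices that miss the sphere entirely:

```python
import math

def slice_circle_radius(r, z0, zz):
    """Radius of the circle where a sphere of radius r centered at
    depth z0 intersects the slice at depth zz."""
    dz = zz - z0
    if abs(dz) > r:
        return 0.0  # slice does not cut the sphere
    return math.sqrt(r * r - dz * dz)

print(slice_circle_radius(5.0, 10, 10))  # → 5.0 (equatorial slice)
print(slice_circle_radius(5.0, 10, 13))  # → 4.0 (3-4-5 triangle)
```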