├── .gitignore
├── LICENSE
├── MainWindow_3D_PP.py
├── README.md
├── bin
│   ├── bash_command.md
│   ├── generate_train_config.py
│   ├── mk_images
│   │   └── 20240520163736.png
│   ├── preprocess.py
│   ├── test_bash.py
│   └── train_bash.py
├── configs
│   ├── EMPIAR_10045_preprocess.py
│   └── SHREC_2021_preprocess.py
├── dataset
│   └── dataloader_DynamicLoad.py
├── main.py
├── model
│   ├── __init__.py
│   ├── model_loader.py
│   └── residual_unet_att.py
├── options
│   └── option.py
├── requirement.txt
├── test.py
├── train.py
├── tutorials
│   ├── A_tutorial_of_particlePicking_on_EMPIAR10045_dataset.md
│   ├── A_tutorial_of_particlePicking_on_SHREC2021_dataset.md
│   └── images
│       ├── EMPIAR_10045_GIF
│       │   ├── Inference.gif
│       │   ├── Preprocessing.gif
│       │   ├── Training.gif
│       │   └── Visualization.gif
│       ├── Inference.gif
│       ├── Preprocessing.gif
│       ├── Training.gif
│       └── Visualization.gif
└── utils
    ├── __init__.py
    ├── colors.py
    ├── coordFormatConvert.py
    ├── coord_gen.py
    ├── coordconv_torch.py
    ├── coords2labels.py
    ├── coords_to_relion4.py
    ├── loss.py
    ├── metrics.py
    ├── misc.py
    ├── normalization.py
    └── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | *.DS_Store
2 | *.git
3 | *.idea
4 | bin/*.sh
5 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # **DeepETPicker**
2 | 
3 | A deep-learning-based open-source software with a friendly user interface for picking 3D particles rapidly and accurately from cryo-electron tomograms. Thanks to weak labels, a lightweight architecture and GPU-accelerated pooling operations, the annotation cost and inference time are significantly reduced, while the accuracy is greatly improved by applying a Gaussian-type mask and using a customized architecture design.
4 | 
5 | [Guole Liu, Tongxin Niu, Mengxuan Qiu, Yun Zhu, Fei Sun, and Ge Yang, “DeepETPicker: Fast and accurate 3D particle picking for cryo-electron tomography using weakly supervised deep learning,” Nature Communications, vol. 15, no. 1, pp. 2090, 2024.](https://www.nature.com/articles/s41467-024-46041-0)
6 | 
7 | 
8 | **Note**: DeepETPicker is a PyTorch implementation.
9 | 
10 | ## **Setup**
11 | 
12 | ### **Prerequisites**
13 | 
14 | - Linux (Ubuntu 18.04.5 LTS; CentOS)
15 | - NVIDIA GPU
16 | 
17 | ### **Installation**
18 | 
19 | #### **Option 1: Using conda**
20 | 
21 | The following instructions assume that `pip` and `anaconda` or `miniconda` are available. In case you have an old deepetpicker environment installed, first remove it with:
22 | 
23 | ```bash
24 | conda env remove --name deepetpicker
25 | ```
26 | 
27 | The first step is to create a new conda virtual environment:
28 | 
29 | ```bash
30 | conda create -n deepetpicker -c conda-forge python=3.8.3 -y
31 | ```
32 | 
33 | Activate the environment:
34 | 
35 | ```bash
36 | conda activate deepetpicker
37 | ```
38 | 
39 | To download the code, please do:
40 | ```
41 | git clone https://github.com/cbmi-group/DeepETPicker
42 | cd DeepETPicker
43 | ```
44 | 
45 | Next, install the specific PyTorch build and the related packages needed by DeepETPicker:
46 | 
47 | ```bash
48 | conda install pytorch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 cudatoolkit=11.0 -c pytorch -y
49 | 
50 | pip install -r requirement.txt
51 | ```
52 | 
53 | - **If using the GUI**
54 | 
55 | To use the GUI packages on Linux, you will need to install the following extended dependencies for Qt.
56 | 1. 
For `CentOS`, to install the packages, please do:
57 | ```bash
58 | sudo yum install -y mesa-libGL libXext libSM libXrender fontconfig xcb-util-wm xcb-util-image xcb-util-keysyms xcb-util-renderutil libxkbcommon-x11
59 | ```
60 | 
61 | 2. For `Ubuntu`, to install the packages, please do:
62 | ```bash
63 | sudo apt-get install -y libgl1-mesa-glx libglib2.0-dev libsm6 libxrender1 libfontconfig1 libxcb-icccm4 libxcb-image0 libxcb-keysyms1 libxcb-render-util0 libxcb-shape0 libxcb-xinerama0 libxcb-xkb1 libxkbcommon-x11-dev libdbus-1-3
64 | ```
65 | 
66 | 
67 | To run DeepETPicker, please do:
68 | 
69 | ```bash
70 | conda activate deepetpicker
71 | python PATH_TO_DEEPETPICKER/main.py
72 | ```
73 | 
74 | Note: `PATH_TO_DEEPETPICKER` is the directory where the code is located.
75 | - **Non-GUI version**
76 | 
77 | In addition to the `GUI version` of DeepETPicker, we also provide a `non-GUI version` for users familiar with Python and deep learning. It consists of four steps: `preprocessing`, `train config generation`, `training` and `testing`. A sample tutorial can be found in `bin/bash_command.md`.
78 | 
79 | #### **Option 2: Using Docker**
80 | 
81 | The following steps are required in order to run DeepETPicker:
82 | 1. Install [Docker](https://www.docker.com/)
83 | 
84 | Note: the Docker engine version should be >= 19.03. The Docker image of DeepETPicker is 7.21 GB; please ensure that there is enough disk space.
85 | 
86 | 2. Install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) for GPU support.
87 | 
88 | 3. Download the Docker image of DeepETPicker.
89 | 
90 | ```bash
91 | docker pull docker.io/lgl603/deepetpicker:latest
92 | ```
93 | 
94 | 4. Run the Docker image of DeepETPicker.
95 | 
96 | ```bash
97 | docker run --gpus=all -itd \
98 |     --restart=always \
99 |     --shm-size=100G \
100 |     -e DISPLAY=unix$DISPLAY \
101 |     --name deepetpicker \
102 |     -p 50022:22 \
103 |     --mount type=bind,source='/host_path/to/data',target='/container_path/to/data' \
104 |     lgl603/deepetpicker:latest
105 | ```
106 | 
107 | - The option `--shm-size` sets the size of shared memory available to the Docker container.
108 | - The option `--mount` mounts a file or directory of the host machine into the Docker container: `source='/host_path/to/data'` denotes the data directory that actually exists on the host machine, and `target='/container_path/to/data'` is the directory inside the container where `'/host_path/to/data'` is mounted.
109 | 
110 | **Note: `'/host_path/to/data'` should be made writable, e.g. by running the bash command `chmod -R 777 '/host_path/to/data'`, and should be replaced by the data directory that actually exists on the host machine. For convenience, `'/container_path/to/data'` can be set the same as `'/host_path/to/data'`.**
111 | image
112 | 
113 | 
114 | 5. DeepETPicker can be used directly on this machine, and it can also be used from a machine in the same LAN. 
115 | - Directly open DeepETPicker on this machine:
116 | ```bash
117 | ssh -X test@'ip_address' DeepETPicker
118 | # where the 'ip_address' of the DeepETPicker container can be obtained as follows:
119 | docker inspect --format='{{.NetworkSettings.IPAddress}}' deepetpicker
120 | ```
121 | image
122 | 
123 | 
124 | - Connect to this server remotely and open the DeepETPicker software from a client machine:
125 | ```bash
126 | ssh -X -p 50022 test@ip DeepETPicker
127 | ```
128 | Here `ip` is the IP address of the server machine, and the password is `password`.
129 | 
130 | `Installation time`: the Docker image of DeepETPicker is 7.21 GB, so the installation time depends on your network speed. With a fast enough connection, it can be set up within a few minutes.
131 | 
132 | ## **Particle picking tutorial**
133 | 
134 | Detailed tutorials for the two sample datasets [SHREC2021](https://github.com/cbmi-group/DeepETPicker/blob/main/tutorials/A_tutorial_of_particlePicking_on_SHREC2021_dataset.md) and [EMPIAR-10045](https://github.com/cbmi-group/DeepETPicker/blob/main/tutorials/A_tutorial_of_particlePicking_on_EMPIAR10045_dataset.md) are provided. The main steps of DeepETPicker are preprocessing, training, inference, and particle visualization.
135 | 
136 | ### **Preprocessing**
137 | - Data preparation
138 | 
139 | Before launching the graphical user interface, we recommend creating a single base folder to hold the inputs and outputs of DeepETPicker. Inside this base folder you should make a subfolder to store the raw data. This raw_data folder should contain:
140 | - tomograms (with extension .mrc or .rec)
141 | - coordinate files with the same names as the tomograms except for the extension (*.csv, *.coords or *.txt; *.coords is generally recommended)
142 | 
143 | 
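For reference, a minimal sketch of what a single-class `tomo1.coords` file can look like — plain text, one particle per line, with tab-separated x, y, z voxel coordinates (this matches the tab-separated layout read by `dataset/dataloader_DynamicLoad.py`; the values below are made up for illustration). For multi-class data, a class id column precedes the coordinates:

```
520	487	102
130	255	98
640	300	145
```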
144 | 
145 | Here we provide two sample datasets, EMPIAR-10045 and SHREC_2021, for particle picking so that you can learn the processing flow of DeepETPicker better and faster. The sample datasets can be downloaded in one of two ways:
146 | - Baidu Netdisk Link: [https://pan.baidu.com/s/1aijM4IgGSRMwBvBk5XbBmw](https://pan.baidu.com/s/1aijM4IgGSRMwBvBk5XbBmw); verification code: cbmi
147 | - Zenodo Link: [https://zenodo.org/records/12512436](https://zenodo.org/records/12512436)
148 | 
149 | - Data structure
150 | 
151 | The data should be organized as follows:
152 | ```
153 | ├── /base/path
154 | │   ├── raw_data
155 | │   │   ├── tomo1.coords
156 | │   │   └── tomo1.mrc
157 | │   │   └── tomo2.coords
158 | │   │   └── tomo2.mrc
159 | │   │   └── tomo3.mrc
160 | │   │   └── tomo4.mrc
161 | │   │   └── ...
162 | ```
163 | 
164 | For the above data, `tomo1.mrc` and `tomo2.mrc` can be used as train/val datasets, since they both have coordinate files (manual annotation). A tomogram without manual annotation (such as `tomo3.mrc`) can only be used as a test dataset.
165 | 
166 | 
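If you prefer to check this pairing programmatically, here is a minimal sketch (assuming the `/base/path/raw_data` layout above; the path and extensions are just the examples used in this section) that reports which tomograms have coordinate files and can therefore serve as train/val data:

```python
import os

raw_data = "/base/path/raw_data"  # example path from the layout above
tomos = [f for f in os.listdir(raw_data) if f.endswith((".mrc", ".rec"))]
for tomo in sorted(tomos):
    stem = os.path.splitext(tomo)[0]
    # a tomogram is usable for train/val only if a coordinate file with
    # the same stem exists (.coords, .csv or .txt)
    has_coords = any(
        os.path.exists(os.path.join(raw_data, stem + ext))
        for ext in (".coords", ".csv", ".txt")
    )
    print(f"{tomo}: {'train/val' if has_coords else 'test only'}")
```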
167 | 168 | 169 | - Input & Output 170 | 171 |
172 | 173 |
174 | 
175 | Launch the graphical user interface of DeepETPicker. On the `Preprocessing` page, set the key parameters as follows:
176 | - `Input`
177 |   - Dataset name: e.g. SHREC_2021_preprocess
178 |   - Base path: path to the base folder
179 |   - Coords path: path to the raw_data folder
180 |   - Coords format: .csv, .coords or .txt
181 |   - Tomogram path: path to the raw_data folder
182 |   - Tomogram format: .mrc or .rec
183 |   - Number of classes: multiple classes of macromolecules can also be localized separately
184 | - `Output`
185 |   - Label diameter (in voxels): the diameter of the generated weak labels, which is usually smaller than the average diameter of the particles. Empirically, you can set it as large as possible, as long as it stays smaller than the real diameter.
186 |   - Ocp diameter (in voxels): the real diameter of the particles. Empirically, to obtain good picking results, we recommend adjusting the particle size to the range of 20~30 voxels by binning. For particles of multiple classes, their diameters should be separated by commas.
187 |   - Configs: if you click `Save configs`, this is the path of the saved file containing all the parameters filled in on this page
188 | 
189 | 
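After clicking `Save configs`, the parameters above are written to a small Python file. For reference, this is the sample `configs/EMPIAR_10045_preprocess.py` shipped with the repository (only the paths need to be adapted to your own machine):

```python
pre_config={"dset_name": "EMPIAR_10045_preprocess",
            "base_path": "/mnt/data1/ET/DeepETPicker_test/SampleDatasets/EMPIAR_10045",
            "coord_path": "/mnt/data1/ET/DeepETPicker_test/SampleDatasets/EMPIAR_10045/raw_data",
            "coord_format": ".coords",
            "tomo_path": "/mnt/data1/ET/DeepETPicker_test/SampleDatasets/EMPIAR_10045/raw_data",
            "tomo_format": ".mrc",
            "num_cls": 1,
            "label_type": "gaussian",
            "label_diameter": 11,
            "ocp_type": "sphere",
            "ocp_diameter": "23",
            "norm_type": "standardization"}
```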
190 | ### **Training of DeepETPicker**
191 | 
192 | 
193 | 
194 | Note: Before `Training of DeepETPicker`, please do `Preprocessing` first to ensure that the basic parameters required for training are provided.
195 | 
196 | 
197 | 198 |
199 | 
200 | In practice, the default parameters can give you a good enough result.
201 | 
202 | *Training parameter description:*
203 | 
204 | - Dataset name: e.g. SHREC_2021_train
205 | - Dataset list: get the list of train/val tomograms. The first column denotes the particle number, the second column the tomogram name, and the third column the tomogram id. If you have n tomograms, the ids will be {0, 1, 2, ..., n-1}.
206 | - Train dataset ids: tomograms used for training. Click `Dataset list` first to obtain the dataset ids. One or multiple tomograms can be used for training, but make sure that the `train dataset ids` are selected from {0, 1, 2, ..., n-1}, where n is the total number of tomograms obtained from `Dataset list`. Two ways of setting the dataset ids are provided:
207 |   - 0, 2, ...: different tomogram ids separated by commas.
208 |   - 0-m: the ids {0, 1, 2, ..., m} will be selected. Note: this form can only be used for tomograms with continuous ids.
209 | - Val dataset ids: tomograms used for validation. Click `Dataset list` first to obtain the dataset ids. Note: only one tomogram can be selected as the val dataset.
210 | - Number of classes: the particle classes you want to pick
211 | - Batch size: the number of samples processed before the model is updated. It is limited by your GPU memory; reducing this parameter might help if you encounter an out-of-memory error.
212 | - Patch size: the size of the subtomograms. It needs to be a multiple of 8. It is recommended that this value is not less than 64; the default value is 72.
213 | - Padding size: a hyperparameter of the overlap-tile strategy. Usually it lies between 6 and 12; the default value is 12.
214 | - Learning rate: the step size at each iteration while moving toward a minimum of the loss function.
215 | - Max epoch: the total number of training epochs. The default value of 60 is usually sufficient.
216 | - GPU id: the GPUs used for training, e.g. 0,1,2,3 denotes using GPUs 0-3. You can run the following command to get the information of available GPUs: `nvidia-smi`.
217 | - Save Configs: save the configs listed above. The saved configuration contains all the parameters filled in on this page; it can be loaded directly via *`Load configs`* next time instead of filling them in again.
218 | 
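For reference, `Save Configs` writes a training configuration file analogous to the preprocessing config. A sketch of the fields it contains, following `bin/generate_train_config.py` (all preprocessing fields are carried over as well; the values and the `<base_path>` placeholder below are illustrative):

```python
train_configs={"dset_name": "SHREC_2021_train",
               "coord_path": "<base_path>/coords",
               "tomo_path": "<base_path>/data_std",
               "label_name": "gaussian9",           # label_type + label_diameter
               "label_path": "<base_path>/gaussian9",
               "ocp_name": "data_ocp",
               "ocp_path": "<base_path>/data_ocp",
               "model_name": "ResUNet",
               "train_set_ids": "0-7",
               "val_set_ids": "8",
               "batch_size": 8,
               "patch_size": 72,
               "padding_size": 12,
               "lr": 0.001,
               "max_epochs": 60,
               "seg_thresh": 0.5,
               "gpu_ids": "0,1"}
```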
219 | ### **Inference of DeepETPicker**
220 | 
221 | 
222 | 
223 | 
224 | 
225 | In practice, the default parameters can give you a good enough result.
226 | 
227 | *Inference parameter description:*
228 | 
229 | - Train Configs: path to the configuration file saved in the `Training` step
230 | - Network weights: path to the model generated in the `Training` step
231 | - Patch size & Pad_size: the tomogram is scanned with a specific stride S and a patch size of N in this stage, where S = N - 2*Pad_size.
232 | - GPU id: the GPUs used for inference, e.g. 0,1,2,3 denotes using GPUs 0-3. You can run the following command to get the information of available GPUs: `nvidia-smi`.
233 | 
234 | *Coord format conversion*
235 | 
236 | The predicted coordinate files with extension `*.coords` have four columns: `class_id, x, y, z`. To facilitate the subsequent subtomogram averaging, conversion of the coordinate file format is provided.
237 | 
238 | - Coords path: the path of the coordinate files predicted by the well-trained DeepETPicker.
239 | - Output format: three output formats are supported, including `*.box` for EMAN2, `*.star` for RELION, and `*.coords` for RELION.
240 | 
241 | 
242 | Note: After converting the coordinates to `*.coords`, one can directly get the coordinates in RELION 4 format using the script at `https://github.com/cbmi-group/DeepETPicker/blob/main/utils/coords_to_relion4.py`.
243 | 
244 | 
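For reference, a minimal sketch of a predicted `*.coords` file with the four tab-separated columns `class_id, x, y, z` (the values are made up for illustration):

```
1	520	487	102
1	130	255	98
3	640	300	145
```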
245 | ### **Particle visualization and manual picking**
246 | 
247 | 
248 | 
249 | 
250 | 
251 | - *Showing Tomogram*
252 | 
253 | You can click the `Load tomogram` button on this page to load the tomogram.
254 | 
255 | - *Showing Labels*
256 | 
257 | After loading the coordinates file by clicking `Load labels`, you can click `Show result` to visualize the labels. The labels' diameter and width can also be tuned in the GUI.
258 | 
259 | - *Parameter Adjustment*
260 | 
261 | To improve the visibility of particles, Gaussian filtering and histogram equalization are provided:
262 | - Filter: when choosing Gaussian, a Gaussian filter can be applied to the displayed tomogram; kernel_size and sigma (standard deviation) can be tuned to adjust the visual effect
263 | - Contrast: when choosing hist-equ, histogram equalization can be performed
264 | 
265 | 
266 | - *Manual picking*
267 | 
268 | After loading the tomogram and pressing ‘enable’, you can pick particles manually by double-clicking the left mouse button on the slices. To delete a wrong label, simply right-click it. You can specify a different category id per class. Always remember to save the results when you finish.
269 | 
270 | - *Position Slider*
271 | 
272 | You can scan through the volume in the x, y and z directions by changing their values. For z-axis scanning, the Up/Down arrow keys can be used as shortcuts.
273 | 
274 | # Troubleshooting
275 | 
276 | If you encounter any problems during the installation or use of DeepETPicker, please contact us by email at [guole.liu@ia.ac.cn](mailto:guole.liu@ia.ac.cn). We will help you as soon as possible.
277 | 
278 | # Citation
279 | 
280 | If you use this code for your research, please cite our paper [DeepETPicker: Fast and accurate 3D particle picking for cryo-electron tomography using weakly supervised deep learning](https://www.nature.com/articles/s41467-024-46041-0).
281 | 
282 | ```
283 | @article{DeepETPicker,
284 |   title={DeepETPicker: Fast and accurate 3D particle picking for cryo-electron tomography using weakly supervised deep learning},
285 |   author={Guole Liu, Tongxin Niu, Mengxuan Qiu, Yun Zhu, Fei Sun, and Ge Yang},
286 |   journal={Nature Communications},
287 |   year={2024}
288 | }
289 | ```
290 | 
--------------------------------------------------------------------------------
/bin/bash_command.md:
--------------------------------------------------------------------------------
1 | 
2 | In addition to the GUI version of DeepETPicker, we also provide a non-GUI version for users familiar with Python and deep learning. It consists of four steps: `preprocessing`, `train config generation`, `training` and `testing`.
3 | 
4 | Note: If you are not familiar with Python and deep learning, the GUI version is recommended.
5 | 
6 | First, enter the bash command directory by
7 | ```bash
8 | cd PATH_TO_DEEPETPICKER/bin
9 | ```
10 | 
11 | where `PATH_TO_DEEPETPICKER` is the directory where the code is located.
12 | 
13 | ## Preprocessing
14 | 
15 | ```bash
16 | python preprocess.py \
17 |     --pre_configs 'PATH_TO_Preprocess_CONFIG'
18 | ```
19 | 
20 | where `pre_configs` is the configuration file for preprocessing. We have provided a sample for the EMPIAR-10045 dataset in `configs/EMPIAR_10045_preprocess.py`. The items of this configuration file are the same as those in the file generated by the `Preprocessing` panel of the GUI version. More details can be found in the section `Preprocessing` of https://github.com/cbmi-group/DeepETPicker/tree/main. 
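Assuming the sample configuration `configs/EMPIAR_10045_preprocess.py` (label type `gaussian`, label diameter 11, standardization), preprocessing populates `base_path` roughly as follows — a sketch inferred from `bin/preprocess.py` and the path conventions in `bin/generate_train_config.py`:

```
base_path/
├── coords/       # parsed coordinate files plus the num_name.csv index
├── data_std/     # standardized tomograms
├── gaussian11/   # weak labels, named label_type + label_diameter
└── data_ocp/     # occupancy maps built from the real particle diameters
```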
21 | 
22 | Note: `pre_configs` can also directly load the configuration file generated by the `Preprocessing` panel of the DeepETPicker GUI.
23 | 
24 | ## Generate configuration file for training and testing
25 | 
26 | ```bash
27 | python generate_train_config.py \
28 |     --pre_configs 'PATH_TO_Preprocess_CONFIG' \
29 |     --dset_name 'Train_Config_Name' \
30 |     --cfg_save_path 'Save_Path_for_Config_Name' \
31 |     --train_set_ids '0' \
32 |     --val_set_ids '0' \
33 |     --batch_size 8 \
34 |     --block_size 72 \
35 |     --pad_size 12 \
36 |     --learning_rate 0.001 \
37 |     --max_epoch 60 \
38 |     --threshold 0.5 \
39 |     --gpu_id 0 1
40 | ```
41 | 
42 | where `dset_name` and `cfg_save_path` are the name and the save path of the training configuration file, respectively. The other parameters are the same as the inputs of the `Training` panel of the GUI version. More details can be found in the section `Training of DeepETPicker` of https://github.com/cbmi-group/DeepETPicker/tree/main.
43 | 
44 | ## Training
45 | 
46 | ```bash
47 | python train_bash.py \
48 |     --train_configs 'Train_Config_Name'
49 | ```
50 | 
51 | where `train_configs` is the training configuration file generated by `generate_train_config.py`.
52 | 
53 | Note: `train_configs` can also directly load the configuration file generated by the `Training` panel of the DeepETPicker GUI.
54 | 
55 | The checkpoints are saved in the sub-folder `runs` of `base_path`.
56 | 
57 | ## Testing
58 | 
59 | ```bash
60 | python test_bash.py \
61 |     --train_configs 'Train_Config_Name' \
62 |     --checkpoints 'Checkpoint_Path'
63 | ```
64 | 
65 | The predicted results are saved in the sub-folder `runs` of `base_path`.
66 | 
67 | ![](mk_images/20240520163736.png)
68 | 
69 | 
70 | 
71 | 
--------------------------------------------------------------------------------
/bin/generate_train_config.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import json
4 | from os.path import dirname, abspath
5 | import importlib
6 | 
7 | DeepETPickerHome = dirname(abspath(__file__))
8 | DeepETPickerHome = os.path.split(DeepETPickerHome)[0]
9 | sys.path.append(DeepETPickerHome)
10 | sys.path.append(os.path.split(DeepETPickerHome)[0])
11 | option = importlib.import_module(f".options.option", package=os.path.split(DeepETPickerHome)[1])
12 | 
13 | if __name__ == "__main__":
14 |     options = option.BaseOptions()
15 |     args = options.gather_options()
16 | 
17 |     with open(args.pre_configs, 'r') as f:
18 |         train_config = json.loads(''.join(f.readlines()).lstrip('pre_config='))
19 |     train_config['dset_name'] = args.dset_name
20 |     train_config['coord_path'] = train_config['base_path'] + '/coords'
21 |     train_config['tomo_path'] = train_config['base_path'] + '/data_std'
22 |     train_config['label_name'] = train_config['label_type'] + f"{train_config['label_diameter']}"
23 |     train_config['label_path'] = train_config['base_path'] + f"/{train_config['label_name']}"
24 |     train_config['ocp_name'] = 'data_ocp'
25 |     train_config['ocp_path'] = train_config['base_path'] + '/data_ocp'
26 |     train_config['model_name'] = 'ResUNet'
27 |     train_config['train_set_ids'] = args.train_set_ids
28 |     train_config['val_set_ids'] = args.val_set_ids
29 |     train_config['batch_size'] = args.batch_size
30 |     train_config['patch_size'] = args.block_size
31 |     train_config['padding_size'] = args.pad_size[0]
32 |     train_config['lr'] = args.learning_rate
33 |     train_config['max_epochs'] = args.max_epoch
34 |     train_config['seg_thresh'] = args.threshold
35 |     train_config['gpu_ids'] = ','.join([str(i) for i in args.gpu_id])
36 |     
print(train_config['gpu_ids']) 37 | 38 | 39 | with open(f"{args.cfg_save_path}/{args.dset_name}.py", 'w') as f: 40 | f.write("train_configs=") 41 | json.dump(train_config, f, separators=(',\n' + ' ' * len('train_configs={'), ': ')) 42 | 43 | 44 | 45 | 46 | -------------------------------------------------------------------------------- /bin/mk_images/20240520163736.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cbmi-group/DeepETPicker/0b3f8cb298128805a23cfd2909f81010006d2390/bin/mk_images/20240520163736.png -------------------------------------------------------------------------------- /bin/preprocess.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | from os.path import dirname, abspath 4 | import importlib 5 | import json 6 | 7 | DeepETPickerHome = dirname(abspath(__file__)) 8 | DeepETPickerHome = os.path.split(DeepETPickerHome)[0] 9 | sys.path.append(DeepETPickerHome) 10 | sys.path.append(os.path.split(DeepETPickerHome)[0]) 11 | coords2labels = importlib.import_module(".utils.coords2labels", package=os.path.split(DeepETPickerHome)[1]) 12 | coord_gen = importlib.import_module(f".utils.coord_gen", package=os.path.split(DeepETPickerHome)[1]) 13 | norm = importlib.import_module(f".utils.normalization", package=os.path.split(DeepETPickerHome)[1]) 14 | option = importlib.import_module(f".options.option", package=os.path.split(DeepETPickerHome)[1]) 15 | 16 | if __name__ == "__main__": 17 | options = option.BaseOptions() 18 | args = options.gather_options() 19 | 20 | with open(args.pre_configs, 'r') as f: 21 | pre_config = json.loads(''.join(f.readlines()).lstrip('pre_config=')) 22 | 23 | # initial coords 24 | coord_gen.coords_gen_show(args=(pre_config["coord_path"], 25 | pre_config["coord_format"], 26 | pre_config["base_path"], 27 | None, 28 | ) 29 | ) 30 | 31 | # normalization 32 | norm.norm_show(args=(pre_config["tomo_path"], 33 | pre_config["tomo_format"], 34 | pre_config["base_path"], 35 | pre_config["norm_type"], 36 | None, 37 | ) 38 | ) 39 | 40 | # generate labels based on coords 41 | coords2labels.label_gen_show(args=(pre_config["base_path"], 42 | pre_config["coord_path"], 43 | pre_config["coord_format"], 44 | pre_config["tomo_path"], 45 | pre_config["tomo_format"], 46 | pre_config["num_cls"], 47 | pre_config["label_type"], 48 | pre_config["label_diameter"], 49 | None, 50 | ) 51 | ) 52 | 53 | # generate occupancy abased on coords 54 | coords2labels.label_gen_show(args=(pre_config["base_path"], 55 | pre_config["coord_path"], 56 | pre_config["coord_format"], 57 | pre_config["tomo_path"], 58 | pre_config["tomo_format"], 59 | pre_config["num_cls"], 60 | 'data_ocp', 61 | pre_config["ocp_diameter"], 62 | None, 63 | ) 64 | ) 65 | -------------------------------------------------------------------------------- /bin/test_bash.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | import os 4 | import importlib 5 | from os.path import dirname, abspath 6 | import numpy as np 7 | DeepETPickerHome = dirname(abspath(__file__)) 8 | DeepETPickerHome = os.path.split(DeepETPickerHome)[0] 9 | sys.path.append(DeepETPickerHome) 10 | sys.path.append(os.path.split(DeepETPickerHome)[0]) 11 | test = importlib.import_module(".test", package=os.path.split(DeepETPickerHome)[1]) 12 | option = importlib.import_module(f".options.option", package=os.path.split(DeepETPickerHome)[1]) 13 | 14 | if __name__ == 
'__main__': 15 | options = option.BaseOptions() 16 | args = options.gather_options() 17 | 18 | # cofig 19 | with open(args.train_configs, 'r') as f: 20 | cfg = json.loads(''.join(f.readlines()).lstrip('train_configs=')) 21 | 22 | # parameters 23 | args.use_bg = True 24 | args.use_IP = True 25 | args.use_coord = True 26 | args.test_use_pad = True 27 | args.use_seg = True 28 | args.meanPool_NMS = True 29 | args.f_maps = [24, 48, 72, 108] 30 | args.num_classes = cfg['num_cls'] 31 | train_cls_num = cfg['num_cls'] 32 | if args.num_classes == 1: 33 | args.use_sigmoid = True 34 | args.use_softmax = False 35 | else: 36 | train_cls_num = train_cls_num + 1 37 | args.use_sigmoid = False 38 | args.use_softmax = True 39 | args.batch_size = cfg['batch_size'] 40 | args.block_size = cfg['patch_size'] 41 | args.val_batch_size = args.batch_size 42 | args.val_block_size = args.block_size 43 | args.pad_size = [cfg['padding_size']] 44 | args.learning_rate = cfg['lr'] 45 | args.max_epoch = cfg['max_epochs'] 46 | args.threshold = cfg['seg_thresh'] 47 | args.gpu_id = [int(i) for i in cfg['gpu_ids'].split(',')] 48 | args.test_mode = 'test_only' 49 | args.out_name = 'PredictedLabels' 50 | args.de_duplication = True 51 | args.de_dup_fmt = 'fmt4' 52 | args.mini_dist = sorted([int(i) // 2 + 1 for i in cfg['ocp_diameter'].split(',')])[0] 53 | args.data_split = [0, 1, 0, 1, 0, 1] 54 | args.configs = args.train_configs 55 | args.num_classes = train_cls_num 56 | 57 | # test_idxs 58 | dset_list = np.array( 59 | [i[:-(len(i.split('.')[-1]) + 1)] for i in os.listdir(cfg['tomo_path']) if cfg['tomo_format'] in i]) 60 | dset_num = dset_list.shape[0] 61 | num_name = np.concatenate([np.arange(dset_num).reshape(-1, 1), dset_list.reshape(-1, 1)], axis=1) 62 | np.savetxt(os.path.join(cfg['tomo_path'], 'num_name.csv'), 63 | num_name, 64 | delimiter='\t', 65 | fmt='%s', 66 | newline='\n') 67 | 68 | # tomo_list = [i for i in os.listdir(cfg[f"{cfg['base_path']}/data_std"]) if cfg['tomo_format'] in i] 69 | tomo_list = np.loadtxt(f"{cfg['base_path']}/data_std/num_name.csv", 70 | delimiter='\t', 71 | dtype=str) 72 | args.test_idxs = np.arange(len(tomo_list)) 73 | 74 | for k, v in sorted(vars(args).items()): 75 | print(k, '=', v) 76 | 77 | # Testing 78 | test.test_func(args, stdout=None) -------------------------------------------------------------------------------- /bin/train_bash.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | import os 4 | from os.path import dirname, abspath 5 | import importlib 6 | import numpy as np 7 | 8 | DeepETPickerHome = dirname(abspath(__file__)) 9 | DeepETPickerHome = os.path.split(DeepETPickerHome)[0] 10 | sys.path.append(DeepETPickerHome) 11 | sys.path.append(os.path.split(DeepETPickerHome)[0]) 12 | train = importlib.import_module(".train", package=os.path.split(DeepETPickerHome)[1]) 13 | option = importlib.import_module(f".options.option", package=os.path.split(DeepETPickerHome)[1]) 14 | 15 | 16 | if __name__ == '__main__': 17 | options = option.BaseOptions() 18 | args = options.gather_options() 19 | 20 | # cofig 21 | with open(args.train_configs, 'r') as f: 22 | cfg = json.loads(''.join(f.readlines()).lstrip('train_configs=')) 23 | 24 | # parameters 25 | args.use_bg = True 26 | args.use_IP = True 27 | args.use_coord = True 28 | args.test_use_pad = True 29 | args.meanPool_NMS = True 30 | args.f_maps = [24, 48, 72, 108] 31 | args.num_classes = cfg['num_cls'] 32 | train_cls_num = cfg['num_cls'] 33 | if args.num_classes == 1: 34 | 
args.use_sigmoid = True 35 | args.use_softmax = False 36 | else: 37 | train_cls_num = train_cls_num + 1 38 | args.use_sigmoid = False 39 | args.use_softmax = True 40 | args.batch_size = cfg['batch_size'] 41 | args.block_size = cfg['patch_size'] 42 | args.val_batch_size = args.batch_size 43 | args.val_block_size = args.block_size 44 | args.pad_size = [cfg['padding_size']] 45 | args.learning_rate = cfg['lr'] 46 | args.max_epoch = cfg['max_epochs'] 47 | args.threshold = cfg['seg_thresh'] 48 | args.gpu_id = [int(i) for i in cfg['gpu_ids'].split(',')] 49 | args.configs = args.train_configs 50 | args.test_mode = 'val' 51 | args.train_set_ids = cfg['train_set_ids'] 52 | args.val_set_ids = cfg['val_set_ids'] 53 | args.num_classes = train_cls_num 54 | 55 | train_list = [] 56 | for item in args.train_set_ids.split(','): 57 | if '-' in item: 58 | tmp = [int(i) for i in item.split('-')] 59 | train_list.extend(np.arange(tmp[0], tmp[1] + 1).tolist()) 60 | else: 61 | train_list.append(int(item)) 62 | 63 | val_list = [] 64 | for item in args.val_set_ids.split(','): 65 | if '-' in item: 66 | tmp = [int(i) for i in item.split('-')] 67 | val_list.extend(np.arange(tmp[0], tmp[1] + 1).tolist()) 68 | else: 69 | val_list.append(int(item)) 70 | val_first = len(train_list) if val_list[0] not in train_list else len(train_list) - 1 71 | args.data_split = [0, len(train_list), # train 72 | val_first, val_first + 1, # val 73 | val_first, val_first + 1] # test_val 74 | 75 | for k, v in sorted(vars(args).items()): 76 | print(k, '=', v) 77 | 78 | # Training 79 | train.train_func(args, stdout=None) -------------------------------------------------------------------------------- /configs/EMPIAR_10045_preprocess.py: -------------------------------------------------------------------------------- 1 | pre_config={"dset_name": "EMPIAR_10045_preprocess", 2 | "base_path": "/mnt/data1/ET/DeepETPicker_test/SampleDatasets/EMPIAR_10045", 3 | "coord_path": "/mnt/data1/ET/DeepETPicker_test/SampleDatasets/EMPIAR_10045/raw_data", 4 | "coord_format": ".coords", 5 | "tomo_path": "/mnt/data1/ET/DeepETPicker_test/SampleDatasets/EMPIAR_10045/raw_data", 6 | "tomo_format": ".mrc", 7 | "num_cls": 1, 8 | "label_type": "gaussian", 9 | "label_diameter": 11, 10 | "ocp_type": "sphere", 11 | "ocp_diameter": "23", 12 | "norm_type": "standardization"} -------------------------------------------------------------------------------- /configs/SHREC_2021_preprocess.py: -------------------------------------------------------------------------------- 1 | pre_config={"dset_name": "SHREC_2021_preprocess", 2 | "base_path": "/mnt/data1/ET/DeepETPicker_test/SampleDatasets/SHREC_2021", 3 | "coord_path": "/mnt/data1/ET/DeepETPicker_test/SampleDatasets/SHREC_2021/raw_data", 4 | "coord_format": ".coords", 5 | "tomo_path": "/mnt/data1/ET/DeepETPicker_test/SampleDatasets/SHREC_2021/raw_data", 6 | "tomo_format": ".mrc", 7 | "num_cls": 14, 8 | "label_type": "gaussian", 9 | "label_diameter": 9, 10 | "ocp_type": "sphere", 11 | "ocp_diameter": "7,7,7,7,7,9,9,9,11,11,13,13,13,17", 12 | "norm_type": "standardization"} -------------------------------------------------------------------------------- /dataset/dataloader_DynamicLoad.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch.utils.data as data 3 | import numpy as np 4 | import mrcfile 5 | import pandas as pd 6 | import torch 7 | import warnings 8 | from batchgenerators.transforms.spatial_transforms import SpatialTransform_2, MirrorTransform 9 | from 
torch.utils.data import DataLoader 10 | 11 | 12 | class Dataset_ClsBased(data.Dataset): 13 | def __init__(self, 14 | mode='train', 15 | block_size=72, 16 | num_class=1, 17 | random_num=0, 18 | use_bg=True, 19 | data_split=[7, 7, 7], 20 | test_use_pad=False, 21 | pad_size=18, 22 | use_paf=False, 23 | cfg=None, 24 | args=None): 25 | 26 | self.args = args 27 | self.mode = mode 28 | # use_CL CL = cnotrastive learning 29 | self.use_CL = args.use_CL 30 | if args.use_CL: 31 | self.radius = 0 32 | else: 33 | self.radius = block_size // 5 34 | self.use_bg = use_bg 35 | self.use_paf = use_paf 36 | self.use_CL_DA = args.use_CL_DA 37 | self.use_bg_part = args.use_bg_part 38 | self.use_ice_part = args.use_ice_part 39 | self.Sel_Referance = args.Sel_Referance 40 | 41 | pad_size = pad_size[0] if isinstance(pad_size, list) else pad_size 42 | base_dir = cfg['base_path'] 43 | label_name = cfg['label_name'], 44 | coord_format = cfg['coord_format'] 45 | tomo_format = cfg['tomo_format'], 46 | norm_type = cfg['norm_type'] 47 | 48 | base_dir = base_dir[0] if isinstance(base_dir, tuple) else base_dir 49 | label_name = label_name[0] if isinstance(label_name, tuple) else label_name 50 | coord_format = coord_format[0] if isinstance(coord_format, tuple) else coord_format 51 | tomo_format = tomo_format[0] if isinstance(tomo_format, tuple) else tomo_format 52 | norm_type = norm_type[0] if isinstance(norm_type, tuple) else norm_type 53 | 54 | if 'label_path' not in cfg.keys(): 55 | label_path = os.path.join(base_dir, label_name) 56 | else: 57 | label_path = cfg['label_path'] 58 | label_path = label_path[0] if isinstance(label_path, tuple) else label_path 59 | 60 | if 'coord_path' not in cfg.keys(): 61 | coord_path = os.path.join(base_dir, 'coords') 62 | else: 63 | coord_path = cfg['coord_path'] 64 | coord_path = coord_path[0] if isinstance(coord_path, tuple) else coord_path 65 | 66 | if 'tomo_path' not in cfg.keys(): 67 | if norm_type == 'standardization': 68 | tomo_path = base_dir + '/data_std' 69 | elif norm_type == 'normalization': 70 | tomo_path = base_dir + '/data_norm' 71 | else: 72 | tomo_path = cfg['tomo_path'] 73 | tomo_path = tomo_path[0] if isinstance(tomo_path, tuple) else tomo_path 74 | 75 | ocp_name = cfg["ocp_name"] 76 | 77 | if 'ocp_path' not in cfg.keys(): 78 | ocp_path = os.path.join(base_dir, ocp_name) 79 | else: 80 | ocp_path = cfg['ocp_path'] 81 | ocp_path = ocp_path[0] if isinstance(ocp_path, tuple) else ocp_path 82 | 83 | print('*' * 100) 84 | print('num_name:', os.path.join(coord_path, 'num_name.csv')) 85 | print(f'base_path:{base_dir}') 86 | print(f"coord_path:{coord_path}") 87 | print(f"tomo_path:{tomo_path}") 88 | print(f"label_path:{label_path}") 89 | print(f"ocp_path:{ocp_path}") 90 | if self.use_paf: 91 | print(f"paf_path:{cfg['paf_path']}") 92 | print(f"label_name:{label_name}") 93 | print(f"coord_format:{coord_format}") 94 | print(f"tomo_format:{tomo_format}") 95 | print(f"norm_type:{norm_type}") 96 | print(f"ocp_name:{ocp_name}") 97 | print('*' * 100) 98 | 99 | if self.mode == 'test_only': 100 | num_name = pd.read_csv(os.path.join(tomo_path, 'num_name.csv'), sep='\t', header=None) 101 | else: 102 | num_name = pd.read_csv(os.path.join(coord_path, 'num_name.csv'), sep='\t', header=None) 103 | 104 | dir_names = num_name.iloc[:, 1].to_numpy().tolist() 105 | print(num_name) 106 | # print(dir_names) 107 | 108 | if self.mode == 'train': 109 | self.data_range = np.arange(data_split[0], data_split[1]) 110 | elif self.mode == 'val': 111 | self.data_range = np.arange(data_split[2], data_split[3]) 
112 | else: # test or test_val or val_v1 113 | self.data_range = np.arange(data_split[4], data_split[5]) 114 | print(f"data_range:{self.data_range}") 115 | 116 | # print(f'data_range:{self.data_range}') 117 | self.shift = block_size // 2 # bigger than self.radius to cover full particle 118 | self.num_class = num_class 119 | 120 | self.ground_truth_volume = [] 121 | self.class_mask = [] 122 | self.location = [] 123 | self.origin = [] 124 | self.label = [] 125 | 126 | # inital Data Augmentation 127 | if self.use_bg and self.mode == 'train': 128 | patch_size = [block_size] * 3 129 | self.st = SpatialTransform_2( 130 | patch_size, [i // 2 for i in patch_size], 131 | do_elastic_deform=True, deformation_scale=(0, 0.05), 132 | do_rotation=True, 133 | angle_x=(- 15 / 360. * 2 * np.pi, 15 / 360. * 2 * np.pi), 134 | angle_y=(- 15 / 360. * 2 * np.pi, 15 / 360. * 2 * np.pi), 135 | angle_z=(- 15 / 360. * 2 * np.pi, 15 / 360. * 2 * np.pi), 136 | do_scale=True, scale=(0.95, 1.05), 137 | border_mode_data='constant', border_cval_data=0, 138 | border_mode_seg='constant', border_cval_seg=0, 139 | order_seg=0, order_data=3, 140 | random_crop=True, 141 | p_el_per_sample=0.1, p_rot_per_sample=0.1, p_scale_per_sample=0.1 142 | ) 143 | 144 | self.mt = MirrorTransform(axes=(0, 1, 2)) 145 | # to avoid mrcfile warnings 146 | warnings.simplefilter('ignore') 147 | 148 | if self.mode == 'train' or self.mode == 'test_val' or self.mode == 'val': 149 | self.position = [ 150 | pd.read_csv(os.path.join(coord_path, dir_names[i] + coord_format), 151 | sep='\t', header=None).to_numpy() for i in self.data_range] 152 | 153 | # load Tomo 154 | if self.mode == 'test' or self.mode == 'test_val' or self.mode == 'val_v1' \ 155 | or self.mode == 'test_only': 156 | self.tomo_shape = [] 157 | for idx in self.data_range: 158 | if not args.input_cat: 159 | with mrcfile.open(os.path.join(tomo_path, dir_names[idx] + tomo_format), 160 | permissive=True) as gm: 161 | shape = gm.data.shape 162 | self.data_shape = gm.data.shape 163 | shape_pad = [i + 2 * pad_size for i in shape] 164 | try: 165 | temp = np.zeros(shape_pad).astype(np.float) 166 | except: 167 | temp = np.zeros(shape_pad).astype(np.float32) 168 | temp[pad_size:shape_pad[0] - pad_size, 169 | pad_size:shape_pad[1] - pad_size, 170 | pad_size:shape_pad[2] - pad_size] = gm.data 171 | self.origin.append(temp) 172 | self.tomo_shape.append(temp.shape) 173 | else: 174 | for idx, p_suffix in enumerate(args.input_cat_items): 175 | p_suffix = p_suffix.rstrip(',') 176 | p_suffix = '' if p_suffix == 'None' else p_suffix 177 | with mrcfile.open( 178 | os.path.join(tomo_path + p_suffix, dir_names[self.data_range[0]] + tomo_format), permissive=True) as tmp: 179 | if idx == 0: 180 | gm = np.array(tmp.data)[None, ...] 
181 | else: 182 | gm = np.concatenate([gm, np.array(tmp.data)[None, ...]], axis=0) 183 | 184 | shape = gm.shape 185 | self.data_shape = gm.shape 186 | shape_pad = [shape[0]] 187 | shape_pad.extend([i + 2 * pad_size for i in shape[1:]]) 188 | try: 189 | temp = np.zeros(shape_pad).astype(np.float) 190 | except: 191 | temp = np.zeros(shape_pad).astype(np.float32) 192 | temp[:, pad_size:shape_pad[1] - pad_size, 193 | pad_size:shape_pad[2] - pad_size, 194 | pad_size:shape_pad[3] - pad_size] = gm 195 | self.origin.append(temp) 196 | self.tomo_shape.append(temp.shape) 197 | 198 | else: 199 | if not args.input_cat: 200 | print([os.path.join(tomo_path, dir_names[i] + tomo_format) for i in self.data_range]) 201 | self.origin = [mrcfile.open(os.path.join(tomo_path, dir_names[i] + tomo_format)) for i 202 | in self.data_range] 203 | else: 204 | for idx, p_suffix in enumerate(args.input_cat_items): 205 | p_suffix = p_suffix.rstrip(',') 206 | p_suffix = '' if p_suffix == 'None' else p_suffix 207 | with mrcfile.open(os.path.join(tomo_path + p_suffix, dir_names[self.data_range[0]] + tomo_format), permissive=True) as tmp: 208 | if idx == 0: 209 | self.origin = np.array(tmp.data)[None, ...] 210 | else: 211 | self.origin = np.concatenate([self.origin, np.array(tmp.data)[None, ...]], axis=0) 212 | 213 | # load Labels 214 | if self.mode == 'test' or self.mode == 'test_val' or self.mode == 'val_v1': 215 | for idx in self.data_range: 216 | if os.path.exists(os.path.join(label_path, dir_names[idx] + tomo_format)): 217 | with mrcfile.open(os.path.join(label_path, dir_names[idx] + tomo_format), 218 | permissive=True) as cm: 219 | shape = cm.data.shape 220 | shape_pad = [i + 2 * pad_size if i > pad_size else i for i in shape] 221 | try: 222 | temp = np.zeros(shape_pad).astype(np.float) 223 | except: 224 | temp = np.zeros(shape_pad).astype(np.float32) 225 | 226 | if len(shape) == 3: 227 | temp[pad_size:shape_pad[-3] - pad_size, 228 | pad_size:shape_pad[-2] - pad_size, 229 | pad_size:shape_pad[-1] - pad_size] = cm.data 230 | elif len(shape) == 4: 231 | print(temp.shape) 232 | temp[:, pad_size:shape_pad[-3] - pad_size, 233 | pad_size:shape_pad[-2] - pad_size, 234 | pad_size:shape_pad[-1] - pad_size] = cm.data 235 | self.label.append(temp) 236 | elif self.mode == 'test_val' and args.use_cluster: 237 | self.label = [ 238 | np.zeros_like(self.origin[idx]) for idx, _ in enumerate(self.data_range)] 239 | elif self.mode == 'test_only': 240 | self.label = [ 241 | np.zeros_like(self.origin[idx]) for idx, _ in enumerate(self.data_range)] 242 | else: 243 | self.label = [ 244 | mrcfile.open(os.path.join(label_path, dir_names[idx] + tomo_format)) for idx 245 | in self.data_range] 246 | 247 | # load paf 248 | if self.use_paf: 249 | paf_path = cfg["paf_path"] 250 | paf_path = paf_path[0] if isinstance(paf_path, tuple) else paf_path 251 | 252 | self.paf_label = [] 253 | if self.mode == 'test' or self.mode == 'test_val' or self.mode == 'val_v1': 254 | for idx in self.data_range: 255 | with mrcfile.open(os.path.join(paf_path, dir_names[idx] + tomo_format), 256 | permissive=True) as cm: 257 | shape = cm.data.shape 258 | shape_pad = [i + 2 * pad_size for i in shape] 259 | try: 260 | temp = np.zeros(shape_pad).astype(np.float) 261 | except: 262 | temp = np.zeros(shape_pad).astype(np.float32) 263 | temp[pad_size:shape_pad[0] - pad_size, 264 | pad_size:shape_pad[1] - pad_size, 265 | pad_size:shape_pad[2] - pad_size] = cm.data 266 | self.paf_label.append(temp) 267 | else: 268 | self.paf_label = [ 269 | mrcfile.open(os.path.join(paf_path, 
dir_names[idx] + tomo_format)) for idx 270 | in self.data_range] 271 | 272 | # Generate BlockData 273 | self.coords = [] 274 | self.data = [] 275 | if self.mode == 'train' or self.mode == 'val': 276 | for i in range(len(self.data_range)): 277 | for j, point1 in enumerate(self.position[i]): 278 | # if sel_train_num > 0 and j >= sel_train_num: 279 | # continue 280 | if args.Sel_Referance: 281 | if j in args.sel_train_num: 282 | self.coords.append([i, point1[-3], point1[-2], point1[-1]]) 283 | else: 284 | if point1[0] == 15: 285 | for _ in range(13): 286 | self.coords.append([i, point1[-3], point1[-2], point1[-1]]) 287 | else: 288 | self.coords.append([i, point1[-3], point1[-2], point1[-1]]) 289 | else: 290 | if test_use_pad: 291 | step_size = block_size - 2 * pad_size 292 | else: 293 | step_size = int(self.shift * 2) 294 | print(self.shift, step_size) 295 | for i in range(len(self.data_range)): 296 | shape = self.origin[i].shape[-3:] 297 | for j in range((shape[0] - 2 * pad_size) // step_size + ( 298 | 1 if (shape[0] - 2 * pad_size) % step_size > 0 else 0)): 299 | for k in range((shape[1] - 2 * pad_size) // step_size + ( 300 | 1 if (shape[1] - 2 * pad_size) % step_size > 0 else 0)): 301 | for l in range((shape[2] - 2 * pad_size) // step_size + ( 302 | 1 if (shape[2] - 2 * pad_size) % step_size > 0 else 0)): 303 | if j == (shape[0] - 2 * pad_size) // step_size + ( 304 | 1 if (shape[0] - 2 * pad_size) % step_size > 0 else 0) - 1: 305 | z = shape[0] - block_size // 2 306 | else: 307 | z = j * step_size + block_size // 2 308 | 309 | if k == (shape[1] - 2 * pad_size) // step_size + ( 310 | 1 if (shape[1] - 2 * pad_size) % step_size > 0 else 0) - 1: 311 | y = shape[1] - block_size // 2 312 | else: 313 | y = k * step_size + block_size // 2 314 | 315 | if l == (shape[2] - 2 * pad_size) // step_size + ( 316 | 1 if (shape[2] - 2 * pad_size) % step_size > 0 else 0) - 1: 317 | x = shape[2] - block_size // 2 318 | else: 319 | x = l * step_size + block_size // 2 320 | self.coords.append([i, x, y, z]) 321 | 322 | if len(self.origin[i].shape) == 4: 323 | img = self.origin[i][:, z - self.shift: z + self.shift, 324 | y - self.shift: y + self.shift, 325 | x - self.shift: x + self.shift] 326 | else: 327 | img = self.origin[i][z - self.shift: z + self.shift, 328 | y - self.shift: y + self.shift, 329 | x - self.shift: x + self.shift] 330 | 331 | if len(self.label[i].shape) == 4: 332 | lab = self.label[i][:, z - self.shift: z + self.shift, 333 | y - self.shift: y + self.shift, 334 | x - self.shift: x + self.shift] 335 | else: 336 | lab = self.label[i][z - self.shift: z + self.shift, 337 | y - self.shift: y + self.shift, 338 | x - self.shift: x + self.shift] 339 | 340 | if self.use_paf: 341 | paf = self.paf_label[i][z - self.shift: z + self.shift, 342 | y - self.shift: y + self.shift, 343 | x - self.shift: x + self.shift] 344 | self.data.append([img, lab, paf, [z, y, x]]) 345 | else: 346 | self.data.append([img, lab, [z, y, x]]) 347 | 348 | # add random samples 349 | if self.mode == 'train' and random_num > 0: 350 | print('random samples num:', random_num) 351 | for j in range(random_num): 352 | i = np.random.randint(len(self.data_range)) 353 | data_shape = self.origin[i].data.shape 354 | z = np.random.randint(self.shift + 1, data_shape[0] - self.shift) 355 | y = np.random.randint(self.shift + 1, data_shape[1] - self.shift) 356 | x = np.random.randint(self.shift + 1, data_shape[2] - self.shift) 357 | self.coords.append([i, x, y, z]) 358 | 359 | if self.mode == 'train': 360 | print("Training dataset contains {} 
samples".format((len(self.coords)))) 361 | if self.mode == 'val': 362 | print("Validation dataset contains {} samples".format((len(self.coords)))) 363 | if self.mode == 'test' or self.mode == 'test_val' or self.mode == 'val_v1' or self.mode == 'test_only': 364 | print("Test dataset contains {} samples".format((len(self.coords)))) 365 | self.test_len = len(self.coords) 366 | 367 | if self.mode == 'test' or self.mode == 'test_val' or self.mode == 'test_only': 368 | self.dir_name = dir_names[self.data_range[0]] 369 | if self.mode == 'test' or self.mode == 'test_val': 370 | print(os.path.join(ocp_path, dir_names[self.data_range[0]] + tomo_format)) 371 | if os.path.exists(os.path.join(ocp_path, dir_names[self.data_range[0]] + tomo_format)): 372 | with mrcfile.open(os.path.join(ocp_path, dir_names[self.data_range[0]] + tomo_format), permissive=True) as f: 373 | self.occupancy_map = f.data 374 | elif self.mode == 'test_val' and args.use_cluster: 375 | self.occupancy_map = np.zeros_like(self.origin[0]) 376 | self.gt_coords = pd.read_csv(os.path.join(coord_path, "%s.coords" % dir_names[self.data_range[0]]), 377 | sep='\t', header=None).to_numpy() 378 | 379 | if args.use_bg_part and self.Sel_Referance: 380 | self.coords_bg = pd.read_csv(os.path.join(coord_path, dir_names[self.data_range[0]] + '_bg' + coord_format), 381 | sep='\t', header=None).to_numpy()[:len(self.coords)] 382 | if args.use_ice_part and self.Sel_Referance: 383 | self.coords_ice = pd.read_csv(os.path.join(coord_path, dir_names[self.data_range[0]] + '_ice' + coord_format), 384 | sep='\t', header=None).to_numpy()[:len(self.coords)] 385 | 386 | if self.mode == 'test_val': 387 | for i in range(len(self.data_range)): 388 | for j, point1 in enumerate(self.position[i]): 389 | x, y, z = point1[-3] + pad_size, point1[-2] + pad_size, point1[-1] + pad_size 390 | z_max, y_max, x_max = self.origin[i].data.shape 391 | x, y, z = self.__sample(np.array([x, y, z]), 392 | np.array([x_max, y_max, z_max])) 393 | img = self.origin[i][z - self.shift: z + self.shift, 394 | y - self.shift: y + self.shift, 395 | x - self.shift: x + self.shift] 396 | if len(self.label[i].shape) == 4: 397 | lab = self.label[i][:, z - self.shift: z + self.shift, 398 | y - self.shift: y + self.shift, 399 | x - self.shift: x + self.shift] 400 | else: 401 | lab = self.label[i][z - self.shift: z + self.shift, 402 | y - self.shift: y + self.shift, 403 | x - self.shift: x + self.shift] 404 | if self.use_paf: 405 | paf = self.paf_label[i][z - self.shift: z + self.shift, 406 | y - self.shift: y + self.shift, 407 | x - self.shift: x + self.shift] 408 | self.data.append([img, lab, paf, [z, y, x]]) 409 | else: 410 | self.data.append([img, lab, [z, y, x]]) 411 | if self.mode == 'test_val' and args.use_cluster: 412 | self.data = self.data[-len(self.position[0]):] 413 | 414 | def __getitem__(self, index): 415 | if self.mode == 'test' or self.mode == 'test_val' or self.mode == 'val_v1' or self.mode =='test_only': 416 | if self.use_paf: 417 | img, label, paf_label, position = self.data[index] 418 | else: 419 | img, label, position = self.data[index] 420 | 421 | else: 422 | idx, x, y, z = self.coords[index] 423 | z_max, y_max, x_max = self.origin[idx].data.shape 424 | 425 | point = self.__sample(np.array([x, y, z]), 426 | np.array([x_max, y_max, z_max])) 427 | 428 | if self.args.input_cat: 429 | img = self.origin[:, point[2] - self.shift:point[2] + self.shift, 430 | point[1] - self.shift:point[1] + self.shift, 431 | point[0] - self.shift:point[0] + self.shift] 432 | else: 433 | img = 
self.origin[idx].data[point[2] - self.shift:point[2] + self.shift, 434 | point[1] - self.shift:point[1] + self.shift, 435 | point[0] - self.shift:point[0] + self.shift] 436 | 437 | if len(self.label[idx].data.shape) == 4: 438 | label = self.label[idx].data[:, point[2] - self.shift:point[2] + self.shift, 439 | point[1] - self.shift:point[1] + self.shift, 440 | point[0] - self.shift:point[0] + self.shift] 441 | else: 442 | label = self.label[idx].data[point[2] - self.shift:point[2] + self.shift, 443 | point[1] - self.shift:point[1] + self.shift, 444 | point[0] - self.shift:point[0] + self.shift] 445 | position = [point[2], point[1], point[0]] 446 | if self.use_paf: 447 | paf_label = self.paf_label[idx].data[point[2] - self.shift:point[2] + self.shift, 448 | point[1] - self.shift:point[1] + self.shift, 449 | point[0] - self.shift:point[0] + self.shift] 450 | # print(img.shape, label.shape) 451 | img = np.array(img) 452 | try: 453 | label = np.array(label).astype(np.float) 454 | except: 455 | label = np.array(label).astype(np.float32) 456 | 457 | if self.num_class > 1 and len(label.shape) == 3: 458 | label = multiclass_label(label, 459 | num_classes=self.num_class, 460 | first_idx=1 if self.use_paf else 0) 461 | else: 462 | if self.mode == 'test' and label.shape != (self.shift * 2, self.shift * 2, self.shift * 2): 463 | label = np.zeros((1, self.shift * 2, self.shift * 2, self.shift * 2)) 464 | else: 465 | label = label.reshape(-1, self.shift * 2, self.shift * 2, self.shift * 2) 466 | 467 | if self.use_paf: 468 | try: 469 | paf_label = np.array(paf_label).astype(np.float).reshape(-1, self.shift * 2, self.shift * 2, self.shift * 2) 470 | except: 471 | paf_label = np.array(paf_label).astype(np.float32).reshape(-1, self.shift * 2, self.shift * 2, self.shift * 2) 472 | 473 | label = np.concatenate([label, paf_label], axis=0) 474 | 475 | if self.use_CL_DA: 476 | img = self.__DA_SelReference(img) 477 | 478 | # random 3D rotation 479 | if self.mode == 'train' and not self.use_CL: 480 | if self.use_bg: 481 | img_label = {'data': img.reshape(1, -1, self.shift * 2, self.shift * 2, self.shift * 2), 482 | 'seg': label.reshape(1, -1, self.shift * 2, self.shift * 2, self.shift * 2)} 483 | if torch.rand(1) < 0.5: 484 | img_label = self.st(**img_label) 485 | else: 486 | img_label = self.mt(**img_label) 487 | img = img_label['data'].reshape(-1, self.shift * 2, self.shift * 2, self.shift * 2) 488 | label = img_label['seg'].reshape(-1, self.shift * 2, self.shift * 2, self.shift * 2) 489 | else: 490 | img = np.array(img).reshape(-1, self.shift * 2, self.shift * 2, self.shift * 2) 491 | label = np.array(label).reshape(-1, self.shift * 2, self.shift * 2, self.shift * 2) 492 | # degree = np.random.randint(4, size=3) 493 | # img = self.__rotation3D(img, degree) 494 | # label = self.__rotation3D(label, degree) 495 | # img = np.array(img) 496 | # label = np.array(label) 497 | else: 498 | img = img.reshape(-1, self.shift * 2, self.shift * 2, self.shift * 2) 499 | label = label.reshape(-1, self.shift * 2, self.shift * 2, self.shift * 2) 500 | 501 | img = torch.as_tensor(img).float() 502 | label = torch.as_tensor(label).float() 503 | 504 | if self.use_bg_part and self.Sel_Referance: 505 | idx, x, y, z = self.coords_bg[index] 506 | z_max, y_max, x_max = self.origin[0].data.shape 507 | 508 | # point = self.__sample(np.array([x, y, z]), 509 | # np.array([x_max, y_max, z_max])) 510 | point = [x, y, z] 511 | 512 | img_bg = self.origin[0].data[point[2] - self.shift:point[2] + self.shift, 513 | point[1] - self.shift:point[1] + 
self.shift, 514 | point[0] - self.shift:point[0] + self.shift] 515 | 516 | if self.use_CL_DA: 517 | img_bg = self.__DA_SelReference(np.array(img_bg)) 518 | 519 | img_bg = np.array(img_bg).reshape(-1, self.shift * 2, self.shift * 2, self.shift * 2) 520 | img_bg = torch.as_tensor(img_bg).float() 521 | 522 | if self.use_ice_part and self.Sel_Referance: 523 | idx, x, y, z = self.coords_ice[index] 524 | z_max, y_max, x_max = self.origin[0].data.shape 525 | 526 | # point = self.__sample(np.array([x, y, z]), 527 | # np.array([x_max, y_max, z_max])) 528 | point = [x, y, z] 529 | img_ice = self.origin[0].data[point[2] - self.shift:point[2] + self.shift, 530 | point[1] - self.shift:point[1] + self.shift, 531 | point[0] - self.shift:point[0] + self.shift] 532 | 533 | if self.use_CL_DA: 534 | img_ice = self.__DA_SelReference(np.array(img_ice)) 535 | 536 | img_ice = np.array(img_ice).reshape(-1, self.shift * 2, self.shift * 2, self.shift * 2) 537 | img_ice = torch.as_tensor(img_ice).float() 538 | 539 | if self.use_bg_part and self.Sel_Referance: 540 | if self.use_ice_part: 541 | return img, img_bg, img_ice 542 | else: 543 | return img, img_bg, position 544 | else: 545 | return img, label, position 546 | 547 | def __len__(self): 548 | return max(len(self.coords), len(self.data)) 549 | # return min(len(self.coords), len(self.data)) 550 | 551 | def __rotation3D(self, data, degree): 552 | data = np.rot90(data, degree[0], (0, 1)) 553 | data = np.rot90(data, degree[1], (1, 2)) 554 | data = np.rot90(data, degree[2], (0, 2)) 555 | return data 556 | 557 | def __DA_SelReference(self, data): 558 | D, H, W = data.shape 559 | out = data.reshape(1, D, H, W) 560 | out = np.concatenate([out, np.rot90(data, 1, (0, 2)).reshape(1, D, H, W)], axis=0) 561 | out = np.concatenate([out, np.rot90(data, 2, (0, 2)).reshape(1, D, H, W)], axis=0) 562 | out = np.concatenate([out, np.rot90(data, 3, (0, 2)).reshape(1, D, H, W)], axis=0) 563 | out = np.concatenate([out, np.rot90(data, 1, (1, 2)).reshape(1, D, H, W)], axis=0) 564 | out = np.concatenate([out, np.rot90(data, 3, (1, 2)).reshape(1, D, H, W)], axis=0) 565 | for axis_idx in range(out.shape[0]): 566 | data = out[axis_idx] 567 | out = np.concatenate([out, data[::-1, :, :].reshape(1, D, H, W)], axis=0) 568 | for idx in range(1, 4): 569 | out = np.concatenate([out, np.rot90(data, idx, (0, 1)).reshape(1, D, H, W)], axis=0) 570 | out = np.concatenate([out, np.rot90(data, idx, (0, 1)).reshape(1, D, H, W)[:, ::-1, :, :]], axis=0) 571 | return out 572 | 573 | def __DA_SelReference_inital(self, data): 574 | out = data 575 | for idx in range(1, 4): 576 | out = np.concatenate([out, np.rot90(data, idx, (0, 1))], axis=0) 577 | out = np.concatenate([out, np.rot90(data, idx, (1, 2))], axis=0) 578 | out = np.concatenate([out, np.rot90(data, idx, (0, 2))], axis=0) 579 | out = np.concatenate([out, data[::-1, :, :]], axis=0) 580 | out = np.concatenate([out, data[:, ::-1, :]], axis=0) 581 | out = np.concatenate([out, data[:, :, ::-1]], axis=0) 582 | return out 583 | 584 | 585 | def __sample(self, point, bound): 586 | # point: z, y, x 587 | new_point = point + np.random.randint(-self.radius, self.radius + 1, size=3) 588 | new_point[new_point < self.shift] = self.shift 589 | new_point[new_point + self.shift > bound] = bound[new_point + self.shift > bound] - self.shift 590 | return new_point 591 | 592 | 593 | # transform label to multichannel 594 | def multiclass_label(x, num_classes, first_idx=0): 595 | for i in range(first_idx, first_idx + num_classes): 596 | label_temp = x 597 | label_temp = 
np.where(label_temp == i, 1, 0) 598 | if i == first_idx: 599 | label_new = label_temp 600 | else: 601 | label_new = np.concatenate((label_new, label_temp)) 602 | 603 | return label_new 604 | 605 | 606 | if __name__ == '__main__': 607 | num_cls = 13 608 | dataloader = DataLoader(Dataset_ClsBased(mode='val', 609 | block_size=32, 610 | num_class=num_cls, 611 | random_num=0, 612 | use_bg=False, 613 | test_use_pad=False, pad_size=18, 614 | data_split=[6, 6, 6], 615 | base_dir="/ldap_shared/synology_shared/shrec_2020/shrec2020_new", 616 | label_name="sphere7", 617 | coord_format=".coords", 618 | tomo_format='.mrc', 619 | norm_type='normalization'), 620 | batch_size=1, 621 | num_workers=1, 622 | shuffle=False, 623 | pin_memory=False) 624 | import matplotlib.pyplot as plt 625 | 626 | rows = 2 627 | cols = 11 628 | plt.figure(1, figsize=(cols * 7, rows * 7)) 629 | for idx, (img, label, position) in enumerate(dataloader): 630 | if idx == 100: 631 | print(position) 632 | print(img.shape) 633 | print(img.max(), img.min()) 634 | print(label.max(), label.min()) 635 | print(label.shape) 636 | 637 | imgs = img[0, 0, 0:31:3, ...] 638 | labels = label[0, 0, 0:31:3, ...] 639 | z, h, w = imgs.shape 640 | for i in range(z): 641 | plt.subplot(rows, cols, i + 1) 642 | plt.imshow(imgs[i, ...], cmap=plt.cm.gray) 643 | plt.axis('off') 644 | 645 | plt.subplot(rows, cols, i + 1 + cols) 646 | plt.imshow(labels[i, ...], cmap=plt.cm.gray) 647 | plt.axis('off') 648 | 649 | plt.tight_layout() 650 | plt.savefig('temp.png') 651 | print(position) 652 | break 653 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from model.model_loader import * 2 | from model.residual_unet_att import * -------------------------------------------------------------------------------- /model/model_loader.py: -------------------------------------------------------------------------------- 1 | from model.residual_unet_att import ResidualUNet3D 2 | 3 | def get_model(args): 4 | 5 | if args.network == 'ResUNet': 6 | model = ResidualUNet3D(f_maps=args.f_maps, out_channels=args.num_classes, 7 | args=args, in_channels=args.in_channels, use_att=args.use_att, 8 | use_paf=args.use_paf, use_uncert=args.use_uncert) 9 | 10 | return model 11 | -------------------------------------------------------------------------------- /model/residual_unet_att.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.nn import functional as F 3 | import torch 4 | from functools import partial 5 | import sys 6 | from torch.nn import Conv3d, Module, Linear, BatchNorm3d, ReLU 7 | from torch.nn.modules.utils import _pair, _triple 8 | 9 | sys.path.append("..") 10 | from utils.coordconv_torch import AddCoords 11 | 12 | # from SoftPool import SoftPool3d 13 | try: 14 | from model.sync_batchnorm import SynchronizedBatchNorm3d 15 | except: 16 | pass 17 | 18 | 19 | def normalization(planes, norm='bn'): 20 | if norm == 'bn': 21 | m = nn.BatchNorm3d(planes) 22 | elif norm == 'gn': 23 | m = nn.GroupNorm(4, planes) 24 | elif norm == 'in': 25 | m = nn.InstanceNorm3d(planes) 26 | elif norm == 'sync_bn': 27 | m = SynchronizedBatchNorm3d(planes) 28 | else: 29 | raise ValueError('normalization type {} is not supported'.format(norm)) 30 | return m 31 | 32 | 33 | # Residual 3D UNet 34 | class ResidualUNet3D(nn.Module): 35 | def __init__(self, f_maps=[32, 64, 128, 256], in_channels=1, 
out_channels=13, 36 | args=None, use_att=False, use_paf=None, use_uncert=None): 37 | super(ResidualUNet3D, self).__init__() 38 | if use_att: 39 | norm = BatchNorm3d 40 | else: 41 | norm = args.norm 42 | act = args.act 43 | use_lw = args.use_lw 44 | lw_kernel = args.lw_kernel 45 | 46 | self.use_aspp = args.use_aspp 47 | self.pif_sigmoid = args.pif_sigmoid 48 | self.paf_sigmoid = args.paf_sigmoid 49 | self.use_tanh = args.use_tanh 50 | self.use_IP = args.use_IP 51 | self.out_channels = out_channels 52 | if self.out_channels > 1: 53 | self.use_softmax = args.use_softmax 54 | else: 55 | self.use_sigmoid = args.use_sigmoid 56 | self.use_coord = args.use_coord 57 | 58 | self.use_softpool = args.use_softpool 59 | 60 | self.use_paf = use_paf 61 | self.use_uncert = use_uncert 62 | 63 | if self.use_softpool: 64 | # pool_layer = SoftPool3d 65 | pass 66 | else: 67 | pool_layer = nn.AvgPool3d 68 | 69 | if self.use_IP: 70 | pools = [] 71 | for _ in range(len(f_maps) - 1): 72 | pools.append(pool_layer(2)) 73 | self.pools = nn.ModuleList(pools) 74 | 75 | # create encoder path consisting of Encoder modules. Depth of the encoder is equal to `len(f_maps)` 76 | encoders = [] 77 | for i, out_feature_num in enumerate(f_maps): 78 | if i == 0: 79 | encoder = Encoder(in_channels, out_feature_num, apply_pooling=False, use_IP=False, 80 | use_coord=self.use_coord, 81 | pool_layer=pool_layer, norm=norm, act=act, use_att=use_att, 82 | use_lw=use_lw, lw_kernel=lw_kernel) 83 | else: 84 | # TODO: adapt for anisotropy in the data, i.e. use proper pooling kernel to make the data isotropic after 1-2 pooling operations 85 | encoder = Encoder(f_maps[i - 1], out_feature_num, use_IP=self.use_IP, use_coord=self.use_coord, 86 | pool_layer=pool_layer, norm=norm, act=act, use_att=use_att, 87 | use_lw=use_lw, lw_kernel=lw_kernel) 88 | 89 | encoders.append(encoder) 90 | 91 | self.encoders = nn.ModuleList(encoders) 92 | 93 | # use ASPP to further extract features 94 | if self.use_aspp: 95 | self.aspp = ASPP(in_channels=f_maps[-1], inter_channels=f_maps[-1], out_channels=f_maps[-1]) 96 | 97 | self.se_loss = args.use_se_loss 98 | if self.se_loss: 99 | self.avgpool = nn.AdaptiveAvgPool3d(1) 100 | self.fc1 = nn.Linear(f_maps[-1], f_maps[-1]) 101 | self.fc2 = nn.Linear(f_maps[-1], out_channels) 102 | 103 | # create decoder path consisting of the Decoder modules. 
The length of the decoder is equal to `len(f_maps) - 1` 104 | decoders = [] 105 | reversed_f_maps = list(reversed(f_maps)) 106 | for i in range(len(reversed_f_maps) - 1): 107 | in_feature_num = reversed_f_maps[i] 108 | out_feature_num = reversed_f_maps[i + 1] 109 | # TODO: if non-standard pooling was used, make sure to use correct striding for transpose conv 110 | # currently strides with a constant stride: (2, 2, 2) 111 | decoder = Decoder(in_feature_num, out_feature_num, use_coord=self.use_coord, norm=norm, act=act, 112 | use_att=use_att, use_lw=use_lw, lw_kernel=lw_kernel) 113 | decoders.append(decoder) 114 | 115 | self.decoders = nn.ModuleList(decoders) 116 | 117 | # in the last layer a 1×1 convolution reduces the number of output 118 | # channels to the number of labels 119 | if args.final_double: 120 | self.final_conv = nn.Sequential( 121 | nn.Conv3d(f_maps[0], f_maps[0] // 2, kernel_size=1), 122 | nn.Conv3d(f_maps[0] // 2, out_channels, 1) 123 | ) 124 | if self.use_paf: 125 | self.paf_conv = nn.Sequential( 126 | nn.Conv3d(f_maps[0], f_maps[0] // 2, kernel_size=1), 127 | nn.Conv3d(f_maps[0] // 2, 1, 1) 128 | ) 129 | self.dropout = nn.Dropout3d 130 | else: 131 | self.final_conv = nn.Conv3d(f_maps[0], out_channels, 1) 132 | if self.use_paf: 133 | self.paf_conv = nn.Conv3d(f_maps[0], 1, 1) 134 | self.dropout = nn.Dropout3d 135 | 136 | if self.use_paf: 137 | if self.use_uncert: 138 | self.logsigma = nn.Parameter(torch.FloatTensor([0.5] * 2)) 139 | else: 140 | self.logsigma = torch.FloatTensor([0.5] * 2) 141 | 142 | def forward(self, x): 143 | if self.use_IP: 144 | img_pyramid = [] 145 | img_d = x 146 | for pool in self.pools: 147 | img_d = pool(img_d) 148 | img_pyramid.append(img_d) 149 | 150 | encoders_features = [] 151 | for idx, encoder in enumerate(self.encoders): 152 | if self.use_IP and idx > 0: 153 | x = encoder(x, img_pyramid[idx - 1]) 154 | else: 155 | x = encoder(x) 156 | encoders_features.insert(0, x) 157 | 158 | if self.use_aspp: 159 | x = self.aspp(x) 160 | # remove last 161 | encoders_features = encoders_features[1:] 162 | 163 | if self.se_loss: 164 | se_out = self.avgpool(x) 165 | se_out = se_out.view(se_out.size(0), -1) 166 | se_out = self.fc1(se_out) 167 | se_out = self.fc2(se_out) 168 | 169 | for decoder, encoder_features in zip(self.decoders, encoders_features): 170 | x = decoder(encoder_features, x) 171 | 172 | out = self.final_conv(x) 173 | 174 | if self.out_channels > 1: 175 | if self.use_softmax: 176 | out = torch.softmax(out, dim=1) 177 | elif self.pif_sigmoid: 178 | out = torch.sigmoid(out) 179 | elif self.use_tanh: 180 | out = torch.tanh(out) 181 | else: 182 | if self.use_sigmoid: 183 | out = torch.sigmoid(out) 184 | elif self.use_tanh: 185 | out = torch.tanh(out) 186 | 187 | if self.use_paf: 188 | paf_out = self.paf_conv(x) 189 | if self.paf_sigmoid: 190 | paf_out = torch.sigmoid(paf_out) 191 | 192 | if self.se_loss: 193 | return [out, se_out] 194 | else: 195 | if self.use_paf: 196 | return [out, paf_out, self.logsigma] 197 | else: 198 | return out 199 | 200 | 201 | class Encoder(nn.Module): 202 | def __init__(self, in_channels, out_channels, apply_pooling=True, use_IP=False, use_coord=False, 203 | pool_layer=nn.MaxPool3d, norm='bn', act='relu', use_att=False, 204 | use_lw=False, lw_kernel=3, input_channels=1): 205 | super(Encoder, self).__init__() 206 | if apply_pooling: 207 | self.pooling = pool_layer(kernel_size=2) 208 | else: 209 | self.pooling = None 210 | 211 | self.use_IP = use_IP 212 | self.use_coord = use_coord 213 | inplaces = in_channels + 
input_channels if self.use_IP else in_channels 214 | inplaces = inplaces + 3 if self.use_coord else inplaces 215 | 216 | if use_att: 217 | self.basic_module = ExtResNetBlock_att(inplaces, out_channels, norm=norm, act=act) 218 | else: 219 | if use_lw: 220 | self.basic_module = ExtResNetBlock_lightWeight(inplaces, out_channels, lw_kernel=lw_kernel) 221 | else: 222 | self.basic_module = ExtResNetBlock(inplaces, out_channels, norm=norm, act=act) 223 | if self.use_coord: 224 | self.coord_conv = AddCoords(rank=3, with_r=False) 225 | 226 | def forward(self, x, scaled_img=None): 227 | if self.pooling is not None: 228 | x = self.pooling(x) 229 | if self.use_IP: 230 | x = torch.cat([x, scaled_img], dim=1) 231 | if self.use_coord: 232 | x = self.coord_conv(x) 233 | x = self.basic_module(x) 234 | return x 235 | 236 | 237 | class Decoder(nn.Module): 238 | def __init__(self, in_channels, out_channels, scale_factor=(2, 2, 2), mode='nearest', 239 | padding=1, use_coord=False, norm='bn', act='relu', use_att=False, 240 | use_lw=False, lw_kernel=3): 241 | super(Decoder, self).__init__() 242 | self.use_coord = use_coord 243 | if self.use_coord: 244 | self.coord_conv = AddCoords(rank=3, with_r=False) 245 | 246 | # if basic_module=ExtResNetBlock use transposed convolution upsampling and summation joining 247 | self.upsampling = Upsampling(transposed_conv=True, in_channels=in_channels, out_channels=out_channels, 248 | scale_factor=scale_factor, mode=mode) 249 | # sum joining 250 | self.joining = partial(self._joining, concat=False) 251 | # adapt the number of in_channels for the ExtResNetBlock 252 | in_channels = out_channels + 3 if self.use_coord else out_channels 253 | 254 | if use_att: 255 | self.basic_module = ExtResNetBlock_att(in_channels, out_channels, norm=norm, act=act) 256 | else: 257 | if use_lw: 258 | self.basic_module = ExtResNetBlock_lightWeight(in_channels, out_channels, lw_kernel=lw_kernel) 259 | else: 260 | self.basic_module = ExtResNetBlock(in_channels, out_channels, norm='bn', act=act) 261 | 262 | def forward(self, encoder_features, x, ReturnInput=False): 263 | x = self.upsampling(encoder_features, x) 264 | x = self.joining(encoder_features, x) 265 | if self.use_coord: 266 | x = self.coord_conv(x) 267 | if ReturnInput: 268 | x1 = self.basic_module(x) 269 | return x1, x 270 | x = self.basic_module(x) 271 | return x 272 | 273 | @staticmethod 274 | def _joining(encoder_features, x, concat): 275 | if concat: 276 | return torch.cat((encoder_features, x), dim=1) 277 | else: 278 | return encoder_features + x 279 | 280 | 281 | class ExtResNetBlock(nn.Module): 282 | def __init__(self, in_channels, out_channels, norm='bn', act='relu'): 283 | super(ExtResNetBlock, self).__init__() 284 | # first convolution 285 | self.conv1 = SingleConv(in_channels, out_channels, norm=norm, act=act) 286 | # residual block 287 | self.conv2 = SingleConv(out_channels, out_channels, norm=norm, act=act) 288 | # remove non-linearity from the 3rd convolution since it's going to be applied after adding the residual 289 | self.conv3 = SingleConv(out_channels, out_channels, norm=norm, act=act) 290 | self.non_linearity = nn.ELU(inplace=False) 291 | 292 | def forward(self, x): 293 | # apply first convolution and save the output as a residual 294 | out = self.conv1(x) 295 | residual = out 296 | # residual block 297 | out = self.conv2(out) 298 | out = self.conv3(out) 299 | 300 | out += residual 301 | out = self.non_linearity(out) 302 | return out 303 | 304 | 305 | class ExtResNetBlock_att(nn.Module): 306 | def __init__(self, 
in_channels, out_channels, norm='bn', act='relu'): 307 | super(ExtResNetBlock_att, self).__init__() 308 | # first convolution 309 | self.conv1 = SingleConv(in_channels, out_channels, norm=norm, act=act) 310 | # residual block 311 | self.conv2 = SplAtConv3d(out_channels, out_channels // 2, norm_layer=norm) 312 | # remove non-linearity from the 3rd convolution since it's going to be applied after adding the residual 313 | self.conv3 = SplAtConv3d(out_channels, out_channels // 2, norm_layer=norm) 314 | self.non_linearity = nn.ELU(inplace=False) 315 | 316 | def forward(self, x): 317 | # apply first convolution and save the output as a residual 318 | out = self.conv1(x) 319 | residual = out 320 | # residual block 321 | out = self.conv2(out) 322 | out = self.conv3(out) 323 | 324 | out += residual 325 | out = self.non_linearity(out) 326 | return out 327 | 328 | 329 | class SingleConv(nn.Sequential): 330 | def __init__(self, in_channels, out_channels, norm='bn', act='relu'): 331 | super(SingleConv, self).__init__() 332 | self.add_module('conv', nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1)) 333 | self.add_module('batchnorm', normalization(out_channels, norm=norm)) 334 | if act == 'relu': 335 | self.add_module('relu', nn.ReLU(inplace=False)) 336 | elif act == 'lrelu': 337 | self.add_module('lrelu', nn.LeakyReLU(negative_slope=0.1, inplace=False)) 338 | elif act == 'elu': 339 | self.add_module('elu', nn.ELU(inplace=False)) 340 | elif act == 'gelu': 341 | self.add_module('gelu', nn.GELU())  # nn.GELU takes no `inplace` argument; register under the correct module name 342 | 343 | 344 | class ExtResNetBlock_lightWeight(nn.Module): 345 | def __init__(self, in_channels, out_channels, lw_kernel=3): 346 | super(ExtResNetBlock_lightWeight, self).__init__() 347 | # first convolution 348 | self.conv1 = SingleConv_lightWeight(in_channels, out_channels, lw_kernel=lw_kernel) 349 | # residual block 350 | self.conv2 = SingleConv_lightWeight(out_channels, out_channels, lw_kernel=lw_kernel) 351 | # remove non-linearity from the 3rd convolution since it's going to be applied after adding the residual 352 | self.conv3 = SingleConv_lightWeight(out_channels, out_channels, lw_kernel=lw_kernel) 353 | self.non_linearity = nn.ELU(inplace=False) 354 | 355 | def forward(self, x): 356 | # apply first convolution and save the output as a residual 357 | out = self.conv1(x) 358 | residual = out 359 | # residual block 360 | out = self.conv2(out) 361 | out = self.conv3(out) 362 | 363 | out += residual 364 | out = self.non_linearity(out) 365 | return out 366 | 367 | 368 | class SingleConv_lightWeight(nn.Sequential): 369 | def __init__(self, in_channels, out_channels, lw_kernel=3, layer_scale_init_value=1e-6): 370 | super(SingleConv_lightWeight, self).__init__() 371 | 372 | self.dwconv = nn.Conv3d(in_channels, in_channels, kernel_size=lw_kernel, padding=lw_kernel//2, groups=in_channels) 373 | self.norm = nn.LayerNorm(in_channels, eps=1e-6) 374 | self.pwconv1 = nn.Linear(in_channels, 2 * in_channels) 375 | self.act = nn.GELU() 376 | self.pwconv2 = nn.Linear(2 * in_channels, out_channels) 377 | self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((out_channels)), 378 | requires_grad=True) if layer_scale_init_value > 0 else None 379 | if in_channels != out_channels: 380 | self.skip = nn.Conv3d(in_channels, out_channels, 1) 381 | self.in_channels = in_channels 382 | self.out_channels = out_channels 383 | 384 | def forward(self, x): 385 | input = x 386 | x = self.dwconv(x) 387 | x = x.permute(0, 2, 3, 4, 1) # (N, C, D, H, W) -> (N, D, H, W, C) 388 | x = self.norm(x) 389 | x = 
self.pwconv1(x) 390 | x = self.act(x) 391 | x = self.pwconv2(x) 392 | if self.gamma is not None: 393 | x = self.gamma * x 394 | x = x.permute(0, 4, 1, 2, 3) # (N, H, W, C) -> (N, C, H, W) 395 | 396 | x = x + (input if self.in_channels == self.out_channels else self.skip(input)) 397 | return x 398 | 399 | 400 | class Upsampling(nn.Module): 401 | def __init__(self, transposed_conv, in_channels=None, out_channels=None, scale_factor=(2, 2, 2), mode='nearest'): 402 | super(Upsampling, self).__init__() 403 | 404 | if transposed_conv: 405 | # make sure that the output size reverses the MaxPool3d from the corresponding encoder 406 | # (D_out = (D_in − 1) ×  stride[0] − 2 ×  padding[0] +  kernel_size[0] +  output_padding[0]) 407 | self.upsample = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=3, stride=scale_factor, padding=1) 408 | else: 409 | self.upsample = partial(self._interpolate, mode=mode) 410 | 411 | def forward(self, encoder_features, x): 412 | output_size = encoder_features.size()[2:] 413 | return self.upsample(x, output_size) 414 | 415 | @staticmethod 416 | def _interpolate(x, size, mode): 417 | return F.interpolate(x, size=size, mode=mode) 418 | 419 | 420 | class SplAtConv3d(Module): 421 | """Split-Attention Conv2d 422 | """ 423 | 424 | def __init__(self, in_channels, channels, kernel_size=3, stride=(1, 1, 1), padding=(1, 1, 1), 425 | dilation=(1, 1, 1), groups=1, bias=True, 426 | radix=2, reduction_factor=4, 427 | rectify=False, rectify_avg=False, norm_layer=BatchNorm3d, 428 | dropblock_prob=0.0, **kwargs): 429 | super(SplAtConv3d, self).__init__() 430 | padding = _triple(padding) 431 | self.rectify = rectify and (padding[0] > 0 or padding[1] > 0) 432 | self.rectify_avg = rectify_avg 433 | inter_channels = max(in_channels * radix // reduction_factor, 32) 434 | self.radix = radix 435 | self.cardinality = groups 436 | self.channels = channels 437 | self.dropblock_prob = dropblock_prob 438 | if self.rectify: 439 | pass 440 | # from rfconv import RFConv2d 441 | # self.conv = RFConv2d(in_channels, channels*radix, kernel_size, stride, padding, dilation, 442 | # groups=groups*radix, bias=bias, average_mode=rectify_avg, **kwargs) 443 | else: 444 | self.conv = Conv3d(in_channels, channels * radix, kernel_size, stride, padding, dilation, 445 | groups=groups * radix, bias=bias, **kwargs) 446 | self.use_bn = norm_layer is not None 447 | if self.use_bn: 448 | self.bn0 = norm_layer(channels * radix) 449 | self.relu = ReLU(inplace=False) 450 | self.fc1 = Conv3d(channels, inter_channels, 1, groups=self.cardinality) 451 | if self.use_bn: 452 | self.bn1 = norm_layer(inter_channels) 453 | self.fc2 = Conv3d(inter_channels, channels * radix, 1, groups=self.cardinality) 454 | # if dropblock_prob > 0.0: 455 | # self.dropblock = DropBlock2D(dropblock_prob, 3) 456 | self.rsoftmax = rSoftMax(radix, groups) 457 | self.conv3 = Conv3d(channels, channels * radix, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False) 458 | self.bn3 = BatchNorm3d(channels * radix, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) 459 | self.relu3 = ReLU(inplace=False) 460 | 461 | def forward(self, x): 462 | x = self.conv(x) 463 | if self.use_bn: 464 | x = self.bn0(x) 465 | # if self.dropblock_prob > 0.0: 466 | # x = self.dropblock(x) 467 | x = self.relu(x) 468 | 469 | batch, rchannel = x.shape[:2] 470 | if self.radix > 1: 471 | if torch.__version__ < '1.5': 472 | splited = torch.split(x, int(rchannel // self.radix), dim=1) 473 | else: 474 | splited = torch.split(x, rchannel // self.radix, dim=1) 475 | gap = 
sum(splited) 476 | else: 477 | gap = x 478 | gap = F.adaptive_avg_pool3d(gap, 1) 479 | gap = self.fc1(gap) 480 | 481 | if self.use_bn: 482 | gap = self.bn1(gap) 483 | gap = self.relu(gap) 484 | 485 | atten = self.fc2(gap) 486 | atten = self.rsoftmax(atten).view(batch, -1, 1, 1, 1) 487 | 488 | if self.radix > 1: 489 | if torch.__version__ < '1.5': 490 | attens = torch.split(atten, int(rchannel // self.radix), dim=1) 491 | else: 492 | attens = torch.split(atten, rchannel // self.radix, dim=1) 493 | out = sum([att * split for (att, split) in zip(attens, splited)]) 494 | else: 495 | out = atten * x 496 | 497 | out = self.relu3(self.bn3(self.conv3(out))) 498 | return out.contiguous() 499 | 500 | 501 | class rSoftMax(nn.Module): 502 | def __init__(self, radix, cardinality): 503 | super().__init__() 504 | self.radix = radix 505 | self.cardinality = cardinality 506 | 507 | def forward(self, x): 508 | batch = x.size(0) 509 | if self.radix > 1: 510 | x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2) 511 | x = F.softmax(x, dim=1) 512 | x = x.reshape(batch, -1) 513 | else: 514 | x = torch.sigmoid(x) 515 | return x 516 | 517 | 518 | if __name__ == "__main__": 519 | import argparse 520 | 521 | parser = argparse.ArgumentParser(description='Training for 3D U-Net models') 522 | parser.add_argument('--use_IP', type=bool, help='whether use image pyramid', default=False) 523 | parser.add_argument('--use_DS', type=bool, help='whether use deep supervision', default=False) 524 | parser.add_argument('--use_Res', type=bool, help='whether use residual connectivity', default=False) 525 | parser.add_argument('--use_bg', type=bool, help='whether use batch generator', default=False) 526 | parser.add_argument('--use_coord', type=bool, help='whether use coord conv', default=False) 527 | parser.add_argument('--use_softmax', type=bool, help='whether use softmax', default=False) 528 | parser.add_argument('--use_softpool', type=bool, help='whether use softpool', default=False) 529 | parser.add_argument('--use_aspp', type=bool, help='whether use aspp', default=False) 530 | parser.add_argument('--use_att', type=bool, help='whether use aspp', default=False) 531 | parser.add_argument('--use_se_loss', type=bool, help='whether use aspp', default=False) 532 | parser.add_argument('--pif_sigmoid', type=bool, help='whether use aspp', default=False) 533 | parser.add_argument('--paf_sigmoid', type=bool, help='whether use aspp', default=False) 534 | parser.add_argument('--final_double', type=bool, help='whether use aspp', default=False) 535 | parser.add_argument('--use_tanh', type=bool, help='whether use aspp', default=False) 536 | parser.add_argument('--norm', help='type of normalization', type=str, default='sync_bn', 537 | choices=['bn', 'gn', 'in', 'sync_bn']) 538 | parser.add_argument('--use_lw', type=bool, help='whether use lightweight', default=True) 539 | parser.add_argument('--lw_kernel', type=int, default=5) 540 | parser.add_argument('--act', help='type of activation function', type=str, default='relu', 541 | choices=['relu', 'lrelu', 'elu', 'gelu']) 542 | args = parser.parse_args() 543 | 544 | net = ResidualUNet3D(args=args, use_att=args.use_att, f_maps=[24, 48, 72, 108]) 545 | print(net) 546 | 547 | # conv = SplAtConv3d(64, 32, 3) 548 | # print(conv) 549 | data = torch.rand([2, 1, 56, 56, 56]) 550 | out = net(data) 551 | print(out.shape) 552 | -------------------------------------------------------------------------------- /options/option.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def str2bool(v): 5 | if isinstance(v, bool): 6 | return v 7 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 8 | return True 9 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 10 | return False 11 | else: 12 | raise argparse.ArgumentTypeError('Boolean value expected.') 13 | 14 | 15 | class BaseOptions(): 16 | def __init__(self): 17 | self.parser = argparse.ArgumentParser(description='Parameters for 3D particle picking') 18 | 19 | # dataloader parameters 20 | self.parser.add_argument('--block_size', help='block size', type=int, default=72) 21 | self.parser.add_argument('--val_block_size', help='block size', type=int, default=0) 22 | self.parser.add_argument('--random_num', help='random number', type=int, default=0) 23 | self.parser.add_argument('--num_classes', help='number of classes', type=int, default=13) 24 | self.parser.add_argument('--use_bg', type=str2bool, help='whether use batch generator', default=False) 25 | self.parser.add_argument('--test_use_pad', type=str2bool, help='whether use coord conv', default=False) 26 | self.parser.add_argument('--pad_size', nargs='+', type=int, default=[12]) 27 | self.parser.add_argument('--data_split', nargs='+', type=int, default=[0, 1, 0, 1, 0, 1]) 28 | self.parser.add_argument('--configs', type=str, default='') 29 | self.parser.add_argument('--pre_configs', type=str, default='') 30 | self.parser.add_argument('--train_configs', type=str, default='') 31 | self.parser.add_argument('--ck_mode', type=str, default='') 32 | self.parser.add_argument('--val_configs', type=str, default='') 33 | self.parser.add_argument('--loader_type', type=str, default='dataloader_DynamicLoad', help="whether use DynamicLoad", 34 | # choices=["dataloader", "dataloader_DynamicLoad", 35 | # 'dataloader_DynamicLoad_CellSeg', 36 | # "dataloader_DynamicLoad_Semi"] 37 | ) 38 | self.parser.add_argument('--sel_train_num', nargs='+', type=int) 39 | self.parser.add_argument('--rand_min', type=int, default=0) 40 | self.parser.add_argument('--rand_max', type=int, default=99) 41 | self.parser.add_argument('--rand_scale', type=int, default=100) 42 | self.parser.add_argument('--input_cat', type=str2bool, help='whether use input cat', default=False) 43 | self.parser.add_argument('--input_cat_items', nargs='+', type=str, default='None') 44 | 45 | # model parameters 46 | self.parser.add_argument('--network', help='network type', type=str, default='ResUNet', 47 | # choices=['unet', 'UMC', 'ResUnet', 'DoubleUnet', 'MFNet', 'DMFNet', 'DMFNet_down3', 48 | # 'NestUnet', 'VoxResNet', 'HighRes3DNet', 'HRNetv1'] 49 | ) 50 | self.parser.add_argument('--in_channels', help='input channels of the network', type=int, default=1) 51 | self.parser.add_argument('--f_maps', nargs='+', type=int, help="Feature numbers of ResUnet") 52 | self.parser.add_argument('--use_LAM', type=str2bool, help='whether use LAM', default=False) 53 | self.parser.add_argument('--use_IP', type=str2bool, help='whether use image pyramid', default=False) 54 | self.parser.add_argument('--use_DS', type=str2bool, help='whether use deep supervision', default=False) 55 | self.parser.add_argument('--use_wDS', type=str2bool, help='whether use deep supervision', default=False) 56 | self.parser.add_argument('--use_Res', type=str2bool, help='whether use residual connectivity', default=False) 57 | self.parser.add_argument('--use_coord', type=str2bool, help='whether use coord conv', default=False) 58 | 
self.parser.add_argument('--use_softmax', type=str2bool, help='whether use softmax', default=False) 59 | self.parser.add_argument('--use_sigmoid', type=str2bool, help='whether use sigmoid', default=False) 60 | self.parser.add_argument('--use_tanh', type=str2bool, help='whether use tanh', default=False) 61 | self.parser.add_argument('--use_softpool', type=str2bool, help='whether use softpool', default=False) 62 | self.parser.add_argument('--use_aspp', type=str2bool, help='whether use aspp', default=False) 63 | self.parser.add_argument('--use_se_loss', type=str2bool, help='whether use SE loss', default=False) 64 | self.parser.add_argument('--use_att', type=str2bool, help='whether use aspp', default=False) 65 | self.parser.add_argument('--initial_channels', help='initial_channels of NestUnet', type=int, default=16) 66 | self.parser.add_argument('--mf_groups', help='number of groups', type=int, default=16) 67 | self.parser.add_argument('--norm', help='type of normalization', type=str, default='bn', 68 | choices=['bn', 'gn', 'in', 'sync_bn']) 69 | self.parser.add_argument('--act', help='type of activation function', type=str, default='relu', 70 | choices=['relu', 'lrelu', 'elu', 'gelu']) 71 | self.parser.add_argument('--use_wds', type=str2bool, help='whether use weighted deep supervision', 72 | default=False) 73 | self.parser.add_argument('--add_dropout_layer', type=str2bool, help='whether use dropout layer in HighResNet', 74 | default=False) 75 | self.parser.add_argument('--dimensions', type=int, default=3, help='Dimensions of HighResNet') 76 | self.parser.add_argument('--use_paf', type=str2bool, help='PostFusion_orit: whether use part affinity field', 77 | default=False) 78 | self.parser.add_argument('--paf_sigmoid', type=str2bool, help='whether use sigmoid for the branch of ' 79 | 'part affinity field', default=False) 80 | self.parser.add_argument('--pif_sigmoid', type=str2bool, help='whether use sigmoid for the branch of ' 81 | 'part intensity field', default=False) 82 | self.parser.add_argument('--final_double', type=str2bool, help='whether use sigmoid for the branch of ' 83 | 'part affinity field', default=False) 84 | self.parser.add_argument('--HRNet_c', type=int, default=12) 85 | self.parser.add_argument('--n_block', type=int, default=2) 86 | self.parser.add_argument('--reduce_ratio', type=int, default=1) 87 | self.parser.add_argument('--n_stages', nargs='+', type=int, default=[1, 1, 1, 1]) 88 | self.parser.add_argument('--use_uncert', type=str2bool, help='whether use uncert for loss weights', 89 | default=False) 90 | self.parser.add_argument('--Gau_num', type=int, default=2) 91 | self.parser.add_argument('--use_seg_gau', type=str2bool, help='whether use seg and gau', default=False) 92 | self.parser.add_argument('--gau_thresh', type=float, default=0.5) 93 | self.parser.add_argument('--use_lw', type=str2bool, help='whether use lightweight', default=False) 94 | self.parser.add_argument('--lw_kernel', type=int, default=3) 95 | 96 | # training hyper-parameters 97 | self.parser.add_argument('--learning_rate', type=float, default=5e-5) 98 | self.parser.add_argument('--batch_size', help='batch size', type=int, default=32) 99 | self.parser.add_argument('--val_batch_size', help='batch size', type=int, default=0) 100 | self.parser.add_argument('--max_epoch', help='number of epochs', type=int, default=100) 101 | self.parser.add_argument('--loss_func_seg', help='seg loss function type', type=str, default='Dice') 102 | self.parser.add_argument('--loss_func_dn', help='denoising loss function type', 
type=str, default='MSE') 103 | self.parser.add_argument('--loss_func_paf', help='paf loss function type', type=str, default='MSE') 104 | self.parser.add_argument('--pred2d_3d', type=str2bool, help='whether use LAM', default=False) 105 | self.parser.add_argument('--threshold', type=float, default=0.5, help="calculate seg_metrics") 106 | self.parser.add_argument('--others', help='others', type=str, default='') 107 | self.parser.add_argument('--paf_weight', type=int, default=1, help='Weight for Paf branch') 108 | self.parser.add_argument('--border_value', type=int, help='border width', default=0) 109 | self.parser.add_argument('--dset_name', type=str, help="the name of dataset") 110 | self.parser.add_argument('--train_mode', type=str, default='train', help='train mode') 111 | self.parser.add_argument('--gpu_id', nargs='+', type=int, default=[0, 1, 2, 3], help='gpu id') 112 | self.parser.add_argument('--prf1_alpha', type=float, default=3, help="calculate seg_metrics") 113 | self.parser.add_argument('--running', type=str2bool, help='whether use LAM', default=False) 114 | 115 | # Contrastive Learning hyper-parameters 116 | self.parser.add_argument('--checkpoints', type=str, help='Checkpoint directory', 117 | default=None) 118 | self.parser.add_argument('--checkpoints_version', type=str, help='Checkpoint directory', 119 | default=None) 120 | self.parser.add_argument('--cent_feats', type=str, help='Checkpoint directory', 121 | default=None) 122 | self.parser.add_argument('--particle_idx', type=int, default=70, help='Index of reference particle') 123 | self.parser.add_argument('--sel_particle_num', type=int, default=100, help='Index of reference particle') 124 | self.parser.add_argument('--iteration_idx', type=int, default=0, help='Iteration index') 125 | self.parser.add_argument('--cent_kernel', type=int, default=1, help='Iteration index') 126 | self.parser.add_argument('--Sel_Referance', type=str2bool, default=False, help='Select Reference Particle') 127 | self.parser.add_argument('--step1', type=str2bool, default=False, help='Select Reference Particle') 128 | self.parser.add_argument('--step2', type=str2bool, default=False, help='Select Reference Particle') 129 | self.parser.add_argument('--dir_name', type=str, help='Directory name', 130 | default=None) 131 | self.parser.add_argument('--stride', type=int, default=8, help='Select Reference Particle') 132 | self.parser.add_argument('--seg_tau', type=float, default=0.95, help='Segmentation threshold') 133 | self.parser.add_argument('--use_mask', type=str2bool, help='use mask to cal loss for SSL', 134 | default=False) 135 | self.parser.add_argument('--use_ema', type=str2bool, default=False, 136 | help='use EMA model') 137 | self.parser.add_argument('--use_bg_part', type=str2bool, default=False, 138 | help='use background particle') 139 | self.parser.add_argument('--use_ice_part', type=str2bool, default=False, 140 | help='use ice bg') 141 | self.parser.add_argument('--use_SimSeg_iteration', type=str2bool, default=False, 142 | help='use SimSeg_iteration') 143 | self.parser.add_argument('--ema_decay', default=0.999, type=float, 144 | help='EMA decay rate') 145 | self.parser.add_argument('--T', type=float, default=0.5, help='Segmentation threshold') 146 | self.parser.add_argument('--coord_path', type=str, help='Coordiate path name', 147 | default=None) 148 | 149 | # test_parameters 150 | self.parser.add_argument('--test_idxs', nargs='+', type=int, default=[0]) 151 | self.parser.add_argument('--save_pred', type=str2bool, help='whether use segmentation', 
default=False) 152 | self.parser.add_argument('--max_pxs', type=int, help='dilation pixel numbers', default=18) 153 | self.parser.add_argument('--de_duplication', type=str2bool, default=False, help='Whether to de-duplicate picked coordinates') 154 | self.parser.add_argument('--test_mode', type=str, default='test_val', help='test mode') 155 | self.parser.add_argument('--paf_connect', type=str2bool, default=False, help='Whether to use paf-connect') 156 | self.parser.add_argument('--Gau_nms', type=str2bool, default=False, help='Whether use gaussian NMS') 157 | self.parser.add_argument('--save_mrc', type=str2bool, default=False, help='Whether save .mrc file') 158 | self.parser.add_argument('--nms_kernel', type=int, help='kernel size for Gaussian NMS', default=3) 159 | self.parser.add_argument('--nms_topK', type=int, help='topK for Gaussian NMS', default=3) 160 | self.parser.add_argument('--pif_model', type=str, default='', help='pif model for paf-connect') 161 | self.parser.add_argument('--first_idx', type=int, help='first_idx', default=0) 162 | self.parser.add_argument('--use_CL', type=str2bool, default=False, help='Whether use Contrastive Learning') 163 | self.parser.add_argument('--use_cluster', type=str2bool, default=False, help='Whether use clustering') 164 | self.parser.add_argument('--use_CL_DA', type=str2bool, default=False, 165 | help='Whether use DA for reference particle of Contrastive Learning') 166 | self.parser.add_argument('--CL_DA_fmt', type=str, default='mean', 167 | help='format of calculating similarity map under CL_DA') 168 | self.parser.add_argument('--ResearchTitle', type=str, default='None', 169 | help='free-form research title tag') 170 | self.parser.add_argument('--skip_4v94', type=str2bool, default=False, 171 | help='Whether to skip 4V94 evaluation or not. True in SHREC Cryo-ET 2021 results.') 172 | self.parser.add_argument('--skip_vesicles', type=str2bool, default=False, 173 | help='Whether to skip vesicles or not. 
True in SHREC Cryo-ET 2021 results.') 174 | self.parser.add_argument('--out_name', type=str, default='TestRes', 175 | help='file name for saving the predicted coordinates') 176 | 177 | self.parser.add_argument('--train_set_ids', type=str, default="0") 178 | self.parser.add_argument('--val_set_ids', type=str, default="0") 179 | self.parser.add_argument('--cfg_save_path', type=str, default=".") 180 | # optim parameters 181 | self.parser.add_argument('--optim', type=str, default='AdamW') 182 | self.parser.add_argument('--scheduler', type=str, default='OneCycleLR') 183 | self.parser.add_argument('--weight_decay', type=float, default=0.01, help="torch.optim: weight decay") 184 | self.parser.add_argument('--use_dilation', type=str2bool, default=False, help='Whether use dilation') 185 | self.parser.add_argument('--use_seg', type=str2bool, default=False, help='Whether use dilation') 186 | self.parser.add_argument('--use_eval', type=str2bool, default=False, help='Whether use dilation') 187 | 188 | # loss parameters 189 | self.parser.add_argument('--use_weight', type=str2bool, help='whether use different weights for cls losses', 190 | default=False) 191 | self.parser.add_argument('--NoBG', type=str2bool, 192 | help='whether calculate BG loss (the 0th dim of softmax outputs)', 193 | default=False) 194 | self.parser.add_argument('--pad_loss', type=str2bool, help='whether use padding loss', default=False) 195 | self.parser.add_argument('--alpha', type=float, default=0.7, help="Focal Tversky Loss: alpha * FP") 196 | self.parser.add_argument('--beta', type=float, default=0.3, help="Focal Tversky Loss: beta * FN") 197 | self.parser.add_argument('--gamma', type=float, default=0.75, help="Focal Tversky Loss: focal gamma") 198 | self.parser.add_argument('--eta', type=float, default=0.3, help="Dice_SE_Loss: weight of SE loss") 199 | self.parser.add_argument('--FL_a0', type=float, default=0.1, help="Soft_FL Loss: weight of a0") 200 | self.parser.add_argument('--FL_a1', type=float, default=0.9, help="Soft_FL Loss: weight of a1") 201 | 202 | # eval parameters 203 | self.parser.add_argument('--JudgeInDilation', type=str2bool, default=False) 204 | self.parser.add_argument('--save_FPsTPs', type=str2bool, default=False, 205 | help='Whether save the results of FP and TP') 206 | self.parser.add_argument('--de_dup_fmt', type=str, default='fmt4', help='de-duplication format') 207 | self.parser.add_argument('--eval_str', type=str, default='class', help='Whether use dilation') 208 | self.parser.add_argument('--min_vol', type=int, default=100, help='Minimum volume') 209 | self.parser.add_argument('--mini_dist', type=int, default=10, help='Minimum volume') 210 | self.parser.add_argument('--min_dist', type=int, default=10, help='Minimum volume') 211 | self.parser.add_argument('--eval_cls', type=str2bool, default=False, help='Minimum volume') 212 | self.parser.add_argument('--class_checkpoints', type=str, help='Checkpoint directory', 213 | default=None) 214 | self.parser.add_argument('--meanPool_NMS', type=str2bool, default=False, help='mean_pool NMS') 215 | self.parser.add_argument('--meanPool_kernel', type=int, default=5, help='mean_pool NMS') 216 | 217 | def gather_options(self): 218 | args = self.parser.parse_args() 219 | return args 220 | -------------------------------------------------------------------------------- /requirement.txt: -------------------------------------------------------------------------------- 1 | matplotlib 2 | pandas 3 | mrcfile 4 | scikit-image 5 | pillow==8.3.2 6 | pycm 7 | scikit-plot 8 | 
pyqtgraph==0.12.1 9 | PyQt5==5.15.4 10 | batchgenerators==0.21 11 | tinyaes 12 | pytorch_lightning==1.1.0 13 | opencv-python==4.2.0.34 14 | numpy==1.24.0 -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import time 4 | import torch 5 | from torch import nn 6 | import mrcfile 7 | import pytorch_lightning as pl 8 | from torch.utils.data import DataLoader 9 | import sys 10 | # from dataset.dataloader import Dataset_ClsBased 11 | import pandas as pd 12 | import importlib 13 | from glob import glob 14 | import matplotlib.pyplot as plt 15 | from pytorch_lightning import Trainer 16 | from options.option import BaseOptions 17 | from model.model_loader import get_model 18 | from utils.misc import combine, get_centroids, de_dup, cal_metrics_NMS_OneCls 19 | import json 20 | from dataset.dataloader_DynamicLoad import Dataset_ClsBased 21 | 22 | 23 | def test_func(args, stdout=None): 24 | if stdout is not None: 25 | save_stdout = sys.stdout 26 | save_stderr = sys.stderr 27 | sys.stdout = stdout 28 | sys.stderr = stdout 29 | for test_idx in args.test_idxs: 30 | model_name = args.checkpoints.split('/')[-4] + '_' + args.checkpoints.split('/')[-1].split('-')[0] 31 | # load config parameters 32 | if len(args.configs) > 0: 33 | with open(args.configs, 'r') as f: 34 | cfg = json.loads(''.join(f.readlines()).lstrip('train_configs=')) 35 | 36 | start_time = time.time() 37 | args.data_split[-2] = test_idx 38 | args.data_split[-1] = test_idx + 1 39 | 40 | num_name = pd.read_csv(os.path.join(cfg["tomo_path"], 'num_name.csv'), sep='\t', header=None) 41 | dir_list = num_name.iloc[:, 1] 42 | dir_name = dir_list[args.data_split[-2]] 43 | print(dir_name) 44 | 45 | tomo_file = glob(cfg["tomo_path"] + "/*%s" % cfg["tomo_format"])[0] 46 | data_file = mrcfile.open(tomo_file, permissive=True) 47 | data_shape = data_file.data.shape 48 | print(data_shape) 49 | dataset = cfg["dset_name"] 50 | 51 | if args.use_seg: 52 | for pad_size in args.pad_size: 53 | 54 | class UNetTest(pl.LightningModule): 55 | def __init__(self): 56 | super(UNetTest, self).__init__() 57 | self.model = get_model(args) 58 | self.partical_volume = 4 / 3 * np.pi * (cfg["label_diameter"] / 2) ** 3 59 | self.num_classes = args.num_classes 60 | 61 | def forward(self, x): 62 | return self.model(x) 63 | 64 | def test_step(self, test_batch, batch_idx): 65 | with torch.no_grad(): 66 | img, label, index = test_batch 67 | index = torch.cat([i.view(1, -1) for i in index], dim=0).permute(1, 0) 68 | if args.use_paf: 69 | seg_output, paf_output, logsigma1 = self.forward(img) 70 | else: 71 | seg_output = self.forward(img) 72 | 73 | if args.test_use_pad: 74 | mp_num = int(sorted([int(i) for i in cfg["ocp_diameter"].split(',')])[-1] / (args.meanPool_kernel - 1) + 1) 75 | if args.num_classes > 1: 76 | return self._nms_v2(seg_output[:, 1:], kernel=args.meanPool_kernel, 77 | mp_num=mp_num, positions=index) 78 | else: 79 | return self._nms_v2(seg_output[:, :], kernel=args.meanPool_kernel, 80 | mp_num=mp_num, positions=index) 81 | 82 | def test_step_end(self, outputs): 83 | return outputs 84 | 85 | def test_epoch_end(self, epoch_output): 86 | with torch.no_grad(): 87 | if args.meanPool_NMS: 88 | coords_out = torch.cat(epoch_output, dim=0).detach().cpu().numpy() 89 | print('coords_out:', coords_out.shape) 90 | if args.de_duplication: 91 | centroids = de_dup(coords_out, args) 92 | out_dir = 
'/'.join(args.checkpoints.split('/')[:-2]) + f'/{args.out_name}' 93 | os.makedirs(os.path.join(out_dir, 'Coords_withArea'), exist_ok=True) 94 | np.savetxt(os.path.join(out_dir, 'Coords_withArea', dir_name + '.coords'), 95 | centroids.astype(float), 96 | fmt='%s', 97 | delimiter='\t') 98 | 99 | coords = centroids[:, 0:4] 100 | os.makedirs(os.path.join(out_dir, 'Coords_All'), exist_ok=True) 101 | np.savetxt(os.path.join(out_dir, 'Coords_All', dir_name + '.coords'), 102 | coords.astype(int), 103 | fmt='%s', 104 | delimiter='\t') 105 | 106 | 107 | def test_dataloader(self): 108 | if args.test_mode == 'test': 109 | test_dataset = Dataset_ClsBased(mode='test', 110 | block_size=args.block_size, 111 | num_class=args.num_classes, 112 | random_num=args.random_num, 113 | use_bg=args.use_bg, 114 | data_split=args.data_split, 115 | test_use_pad=args.test_use_pad, 116 | pad_size=pad_size, 117 | cfg=cfg, 118 | args=args) 119 | test_dataloader = DataLoader(test_dataset, 120 | shuffle=False, 121 | batch_size=args.batch_size, 122 | num_workers=8 if args.batch_size >= 32 else 4, 123 | pin_memory=False) 124 | 125 | self.len_block = test_dataset.test_len 126 | self.data_shape = test_dataset.data_shape 127 | self.occupancy_map = test_dataset.occupancy_map 128 | self.gt_coords = test_dataset.gt_coords 129 | self.dir_name = test_dataset.dir_name 130 | return test_dataloader 131 | elif args.test_mode == 'test_only': 132 | test_dataset = Dataset_ClsBased(mode='test_only', 133 | block_size=args.block_size, 134 | num_class=args.num_classes, 135 | random_num=args.random_num, 136 | use_bg=args.use_bg, 137 | data_split=args.data_split, 138 | test_use_pad=args.test_use_pad, 139 | pad_size=pad_size, 140 | cfg=cfg, 141 | args=args) 142 | if args.batch_size <= 32: 143 | num_work = 4 144 | elif args.batch_size <= 64: 145 | num_work = 8 146 | elif args.batch_size <= 128: 147 | num_work = 8 148 | else: 149 | num_work = 16 150 | test_dataloader = DataLoader(test_dataset, 151 | shuffle=False, 152 | batch_size=args.batch_size, 153 | num_workers=num_work, 154 | pin_memory=False) 155 | self.len_block = test_dataset.test_len 156 | self.data_shape = test_dataset.data_shape 157 | self.dir_name = test_dataset.dir_name 158 | return test_dataloader 159 | 160 | def _nms_v2(self, pred, kernel=3, mp_num=5, positions=None): 161 | pred = torch.where(pred > 0.5, 1, 0) 162 | meanPool = nn.AvgPool3d(kernel, 1, kernel // 2).cuda() 163 | maxPool = nn.MaxPool3d(kernel, 1, kernel // 2).cuda() 164 | hmax = pred.clone().float() 165 | for _ in range(mp_num): 166 | hmax = meanPool(hmax) 167 | pred = hmax.clone() 168 | hmax = maxPool(hmax) 169 | keep = ((hmax == pred).float()) * ((pred > 0.1).float()) 170 | coords = keep.nonzero() # [N, 5] 171 | coords = coords[coords[:, 2] >= args.pad_size[0]] 172 | coords = coords[coords[:, 2] <= args.block_size - args.pad_size[0]] 173 | coords = coords[coords[:, 3] >= args.pad_size[0]] 174 | coords = coords[coords[:, 3] <= args.block_size - args.pad_size[0]] 175 | coords = coords[coords[:, 4] >= args.pad_size[0]] 176 | coords = coords[coords[:, 4] <= args.block_size - args.pad_size[0]] 177 | 178 | try: 179 | h_val = torch.cat( 180 | [hmax[item[0], item[1], item[2], item[3]:item[3] + 1, item[4]:item[4] + 1] for item in 181 | coords], dim=0) 182 | leftTop_coords = positions[coords[:, 0]] - (args.block_size // 2) - args.pad_size[0] 183 | coords[:, 2:5] = coords[:, 2:5] + leftTop_coords 184 | 185 | pred_final = torch.cat( 186 | [coords[:, 1:2] + 1, coords[:, 4:5], coords[:, 3:4], coords[:, 2:3], h_val], 187 | dim=1) 188 | 
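# pred_final columns: [class_id (1-based channel index), x, y, z, smoothed heatmap value];
# block-local voxel indices were shifted to global tomogram coordinates via leftTop_coords above.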
189 | return pred_final 190 | except: 191 | # print('haha') 192 | return torch.zeros([0, 5]).cuda() 193 | 194 | # load trained checkpoints to model 195 | model = UNetTest.load_from_checkpoint(args.checkpoints) 196 | 197 | # model = UNetTest().model 198 | model.eval() 199 | runner = Trainer(gpus=args.gpu_id, # 200 | accelerator='dp' 201 | ) 202 | os.makedirs(f'result/{dataset}/{model_name}/', exist_ok=True) 203 | 204 | runner.test(model=model) 205 | 206 | end_time = time.time() 207 | used_time = end_time - start_time 208 | save_path = '/'.join(args.checkpoints.split('/')[:-2]) + f'/{args.out_name}' 209 | os.makedirs(save_path, exist_ok=True) 210 | pad_size = args.pad_size[0] 211 | 212 | print('*' * 100) 213 | print('Testing Finished!') 214 | print('*' * 100) 215 | if stdout is not None: 216 | sys.stdout = save_stdout 217 | sys.stderr = save_stderr 218 | 219 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import os 5 | import sys 6 | from dataset.dataloader_DynamicLoad import Dataset_ClsBased 7 | from torch.utils.data import DataLoader 8 | from torchvision.utils import make_grid 9 | import pytorch_lightning as pl 10 | from pytorch_lightning import Trainer 11 | from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor 12 | from pytorch_lightning import loggers 13 | from pytorch_lightning.callbacks.early_stopping import EarlyStopping 14 | from utils.loss import DiceLoss 15 | from utils.metrics import seg_metrics 16 | from utils.colors import COLORS 17 | from model.model_loader import get_model 18 | from utils.misc import combine, cal_metrics_NMS_OneCls, get_centroids, cal_metrics_MultiCls, combine_torch 19 | from sklearn.metrics import precision_recall_fscore_support 20 | import time 21 | import json 22 | 23 | if not sys.warnoptions: 24 | import warnings 25 | 26 | warnings.simplefilter("ignore") 27 | 28 | 29 | class UNetExperiment(pl.LightningModule): 30 | def __init__(self, args): 31 | if args.f_maps is None: 32 | args.f_maps = [32, 64, 128, 256] 33 | print(args.pad_size) 34 | 35 | if len(args.configs) > 0: 36 | with open(args.configs, 'r') as f: 37 | self.cfg = json.loads(''.join(f.readlines()).lstrip('train_configs=')) 38 | else: 39 | self.cfg = {} 40 | if len(args.train_configs) > 0: 41 | with open(args.train_configs, 'r') as f: 42 | self.train_cfg = json.loads(''.join(f.readlines()).lstrip('train_configs=')) 43 | else: 44 | self.train_cfg = self.cfg 45 | 46 | if len(args.val_configs) > 0: 47 | with open(args.val_configs, 'r') as f: 48 | self.val_config = json.loads(''.join(f.readlines()).lstrip('train_configs=')) 49 | else: 50 | self.val_cfg = self.cfg 51 | 52 | super(UNetExperiment, self).__init__() 53 | self.save_hyperparameters() 54 | self.model = get_model(args) 55 | print(self.model) 56 | 57 | if args.loss_func_seg == 'Dice': 58 | self.loss_function_seg = DiceLoss(args=args) 59 | 60 | if 'gaussian' in self.val_cfg["label_type"]: 61 | self.thresholds = np.linspace(0.15, 0.45, 7) 62 | elif 'sphere' in self.val_cfg["label_type"]: 63 | self.thresholds = np.linspace(0.2, 0.80, 13) 64 | self.partical_volume = 4 / 3 * np.pi * (self.val_cfg["label_diameter"] / 2) ** 3 65 | self.args = args 66 | 67 | def forward(self, x): 68 | return self.model(x) 69 | 70 | def training_step(self, train_batch, batch_idx): 71 | args = self.args 72 | img, label, index = train_batch 73 | img = 
img.to(torch.float32) 74 | seg_output = self.forward(img) 75 | if args.use_mask: 76 | mask = label.clone().detach() 77 | mask[mask > 0] = 1 78 | label[label < 255] = 0 79 | label[label > 0] = 1 80 | 81 | # update label and mask according to label-threshold 82 | label[seg_output > args.seg_tau] = 1 83 | mask[seg_output > args.seg_tau] = 1 84 | mask[seg_output < (1 - args.seg_tau)] = 1 85 | 86 | seg_output = seg_output * mask 87 | loss_seg = self.loss_function_seg(seg_output, label) 88 | self.log('train_loss', loss_seg, on_step=False, on_epoch=True) 89 | return loss_seg 90 | 91 | def validation_step(self, val_batch, batch_idx): 92 | args = self.args 93 | with torch.no_grad(): 94 | img, label, index = val_batch 95 | index = torch.cat([i.view(1, -1) for i in index], dim=0).permute(1, 0) 96 | img = img.to(torch.float32) 97 | self.seg_output = self.forward(img) 98 | 99 | if (batch_idx >= self.len_block // args.batch_size and args.test_mode == "test_val") or \ 100 | args.test_mode == "test" or args.test_mode == "val" or args.test_mode == "val_v1": 101 | loss_seg = self.loss_function_seg(self.seg_output, label) 102 | 103 | precision, recall, f1_score, iou = seg_metrics(self.seg_output, label, threshold=args.threshold) 104 | 105 | self.log('val_loss', loss_seg, on_step=False, on_epoch=True) 106 | self.log('val_precision', precision, on_step=False, on_epoch=True) 107 | self.log('val_recall', recall, on_step=False, on_epoch=True) 108 | self.log('val_f1', f1_score, on_step=False, on_epoch=True) 109 | self.log('val_iou', iou, on_step=False, on_epoch=True) 110 | 111 | # return loss_seg 112 | tensorboard = self.logger.experiment 113 | 114 | if (batch_idx == (self.len_block // args.batch_size + 1) and args.test_mode == 'test_val') or \ 115 | (batch_idx == 0 and args.test_mode == 'test') or \ 116 | (batch_idx == 0 and args.test_mode == 'val') or \ 117 | ( 118 | batch_idx == 0 and args.test_mode == 'val_v1'): # and True == False and self.current_epoch % 1 == 0 119 | img /= img.abs().max() # [-1,1] 120 | img = img * 0.5 + 0.5 # [0, 1] 121 | img_ = img[0, :, 0:(args.block_size - 1):5, :, :].permute(1, 0, 2, 3).repeat( 122 | (1, 3, 1, 1)) # sample0: [5, 3, y, x] 123 | 124 | label_ = label[0, :, 0:(args.block_size - 1):5, :, :] # sample0 [15, y, x] 125 | temp = torch.zeros( 126 | (len(np.arange(0, args.block_size - 1, 5)), args.block_size, args.block_size, 3)).float() 127 | # print(label.shape, temp.shape) 128 | for idx in np.arange(label_.shape[0]): 129 | temp[label_[idx] > 0.5] = torch.tensor( 130 | COLORS[(idx + 1) if (args.num_classes == 1 131 | or args.use_paf or 132 | label_.shape[0] == 1) else idx]).float() 133 | label__ = temp.permute(0, 3, 1, 2).contiguous().cuda() # [15, 3, y, x] 134 | 135 | seg_output_ = self.seg_output[0, :, 0:(args.block_size - 1):5, :, :] # sample0 [15, y, x] 136 | seg_threshes = [0.5, 0.3, 0.2, 0.15, 0.1, 0.05] 137 | seg_preds = [] 138 | for thresh in seg_threshes: 139 | temp = torch.zeros( 140 | (len(np.arange(0, args.block_size - 1, 5)), args.block_size, args.block_size, 3)).float() 141 | for idx in np.arange(seg_output_.shape[0]): 142 | temp[seg_output_[idx] > thresh] = torch.tensor( 143 | COLORS[(idx + 1) if (args.num_classes == 1 144 | or args.use_paf or 145 | seg_output_.shape[0] == 1) else idx]).float() 146 | seg_preds.append(temp.permute(0, 3, 1, 2).contiguous().cuda()) # [15, 3, y, x] 147 | 148 | seg_preds = torch.cat(seg_preds, dim=0) 149 | 150 | img_label_seg = torch.cat([img_, label__, seg_preds], dim=0) 151 | img_label_seg = make_grid(img_label_seg, (args.block_size 
- 1) // 5 + 1, padding=2, pad_value=120) 152 | 153 | tensorboard.add_image('img_label_seg', img_label_seg, self.current_epoch, dataformats="CHW") 154 | 155 | if args.num_classes > 1: 156 | return self._nms_v2(self.seg_output[:, 1:], kernel=args.meanPool_kernel, mp_num=6, positions=index) 157 | else: 158 | return self._nms_v2(self.seg_output[:, :], kernel=args.meanPool_kernel, mp_num=6, positions=index) 159 | 160 | def validation_step_end(self, outputs): 161 | args = self.args 162 | if 'test' in args.test_mode: 163 | return outputs 164 | 165 | def validation_epoch_end(self, epoch_output): 166 | args = self.args 167 | with torch.no_grad(): 168 | if 'test' in args.test_mode: 169 | if args.meanPool_NMS: 170 | if args.num_classes == 1: 171 | # coords_out: [N, 5] 172 | coords_out = torch.cat(epoch_output, dim=0).detach().cpu().numpy() 173 | if coords_out.shape[0] > 50000: 174 | loc_p, loc_r, loc_f1, avg_dist = 1e-10, 1e-10, 1e-10, 100 175 | else: 176 | loc_p, loc_r, loc_f1, avg_dist = \ 177 | cal_metrics_NMS_OneCls(coords_out, 178 | self.gt_coords, 179 | self.occupancy_map, 180 | self.cfg, 181 | ) 182 | print("*" * 100) 183 | print(f"Precision:{loc_p}") 184 | print(f"Recall:{loc_r}") 185 | print(f"F1-score:{loc_f1}") 186 | print(f"Avg-dist:{avg_dist}") 187 | print("*" * 100) 188 | self.log('cls_precision', loc_p, on_step=False, on_epoch=True) 189 | self.log('cls_recall', loc_r, on_step=False, on_epoch=True) 190 | self.log('cls_f1', loc_f1, on_step=False, on_epoch=True) 191 | self.log('cls_dist', avg_dist, on_step=False, on_epoch=True) 192 | pr = (loc_p * (loc_r ** args.prf1_alpha)) / (loc_p + (loc_r ** args.prf1_alpha) + 1e-10) 193 | self.log(f'cls_pr_alpha{args.prf1_alpha:.1f}', pr, on_step=False, on_epoch=True) 194 | time.sleep(0.5) 195 | else: 196 | coords_out = torch.cat(epoch_output, dim=0).detach().cpu().numpy() 197 | loc_p, loc_r, loc_f1, loc_miss, avg_dist, gt_classes, pred_classes, self.num2pdb, cls_f1 = \ 198 | cal_metrics_MultiCls(coords_out, self.gt_coords, self.occupancy_map, self.cfg, args, 199 | args.pad_size, self.dir_name, self.partical_volume) 200 | self.log('cls_f1', cls_f1, on_step=False, on_epoch=True) 201 | 202 | def train_dataloader(self): 203 | args = self.args 204 | train_dataset = Dataset_ClsBased(mode=args.train_mode, 205 | block_size=args.block_size, 206 | num_class=args.num_classes, 207 | random_num=args.random_num, 208 | use_bg=args.use_bg, 209 | data_split=args.data_split, 210 | use_paf=args.use_paf, 211 | cfg=self.train_cfg, 212 | args=args) 213 | return DataLoader(train_dataset, 214 | batch_size=args.batch_size, 215 | num_workers=8 if args.batch_size >= 32 else 4, 216 | shuffle=True, 217 | pin_memory=False) 218 | 219 | def val_dataloader(self): 220 | args = self.args 221 | val_dataset = Dataset_ClsBased(mode=args.test_mode, 222 | block_size=args.val_block_size, 223 | num_class=args.num_classes, 224 | random_num=args.random_num, 225 | use_bg=args.use_bg, 226 | data_split=args.data_split, 227 | test_use_pad=args.test_use_pad, 228 | pad_size=args.pad_size, 229 | use_paf=args.use_paf, 230 | cfg=self.val_cfg, 231 | args=args) 232 | 233 | self.len_block = val_dataset.test_len 234 | if 'test' in args.test_mode: 235 | self.data_shape = val_dataset.data_shape 236 | self.occupancy_map = val_dataset.occupancy_map 237 | self.gt_coords = val_dataset.gt_coords 238 | self.dir_name = val_dataset.dir_name 239 | 240 | val_dataloader1 = DataLoader(val_dataset, 241 | batch_size=args.val_batch_size, 242 | num_workers=8 if args.batch_size >= 32 else 4, 243 | shuffle=False, 244 | 
pin_memory=False) 245 | return val_dataloader1 246 | 247 | def _nms_v2(self, pred, kernel=3, mp_num=5, positions=None): 248 | args = self.args 249 | pred = torch.where(pred > 0.5, 1, 0) 250 | meanPool = nn.AvgPool3d(kernel, 1, kernel // 2).cuda() 251 | maxPool = nn.MaxPool3d(kernel, 1, kernel // 2).cuda() 252 | hmax = pred.clone().float() 253 | for _ in range(mp_num): 254 | hmax = meanPool(hmax) 255 | pred = hmax.clone() 256 | hmax = maxPool(hmax) 257 | keep = ((hmax == pred).float()) * ((pred > 0.1).float()) 258 | coords = keep.nonzero() # [N, 5] 259 | if coords.shape[0] > 2000: 260 | return torch.zeros([1, 5]).cuda() 261 | coords = coords[coords[:, 2] >= args.pad_size] 262 | coords = coords[coords[:, 2] < args.block_size - args.pad_size] 263 | coords = coords[coords[:, 3] >= args.pad_size] 264 | coords = coords[coords[:, 3] < args.block_size - args.pad_size] 265 | coords = coords[coords[:, 4] >= args.pad_size] 266 | coords = coords[coords[:, 4] < args.block_size - args.pad_size] 267 | 268 | try: 269 | h_val = torch.cat( 270 | [hmax[item[0], item[1], item[2], item[3]:item[3] + 1, item[4]:item[4] + 1] for item in 271 | coords], dim=0) 272 | leftTop_coords = positions[coords[:, 0]] - (args.block_size // 2) - args.pad_size 273 | coords[:, 2:5] = coords[:, 2:5] + leftTop_coords 274 | 275 | pred_final = torch.cat( 276 | [coords[:, 1:2] + 1, coords[:, 4:5], coords[:, 3:4], coords[:, 2:3], h_val], 277 | dim=1) 278 | 279 | return pred_final 280 | except: 281 | return torch.zeros([0, 5]).cuda() 282 | 283 | def configure_optimizers(self): 284 | args = self.args 285 | if args.optim == 'SGD': 286 | optimizer = torch.optim.SGD(self.parameters(), 287 | lr=args.learning_rate, 288 | momentum=0.9, weight_decay=0.001 289 | ) 290 | elif args.optim == 'Adam': 291 | optimizer = torch.optim.Adam(self.parameters(), 292 | lr=args.learning_rate, 293 | betas=(0.9, 0.99) 294 | ) 295 | elif args.optim == 'AdamW': 296 | optimizer = torch.optim.AdamW(self.parameters(), 297 | lr=args.learning_rate, 298 | betas=(0.9, 0.99), 299 | weight_decay=args.weight_decay 300 | ) 301 | 302 | if args.scheduler == 'OneCycleLR': 303 | sched = torch.optim.lr_scheduler.OneCycleLR(optimizer, 304 | max_lr=args.learning_rate, 305 | total_steps=args.max_epoch, 306 | pct_start=0.1, 307 | anneal_strategy='cos', 308 | div_factor=30, 309 | final_div_factor=100) 310 | lr_dict = { 311 | "scheduler": sched, 312 | "interval": "epoch", 313 | "frequency": 1 314 | } 315 | 316 | if args.scheduler is None: 317 | return [optimizer] 318 | else: 319 | return [optimizer], [lr_dict] 320 | 321 | 322 | def train_func(args, stdout=None): 323 | if stdout is not None: 324 | save_stdout = sys.stdout 325 | save_stderr = sys.stderr 326 | sys.stdout = stdout 327 | sys.stderr = stdout 328 | 329 | args.pad_size = args.pad_size[0] 330 | if 'test' in args.test_mode: 331 | checkpoint_callback = ModelCheckpoint(save_top_k=1, 332 | monitor=f'cls_pr_alpha{args.prf1_alpha:.1f}' if args.num_classes == 1 else 'cls_f1', 333 | mode='max') 334 | else: 335 | checkpoint_callback = ModelCheckpoint(save_top_k=1, 336 | monitor='val_loss', 337 | mode='min') 338 | 339 | model = UNetExperiment(args) 340 | logger_name = "{}_{}_BlockSize{}_{}Loss_MaxEpoch{}_bs{}_lr{}_IP{}_bg{}_coord{}_Softmax{}_{}_{}_TN{}".format( 341 | model.train_cfg["dset_name"], args.network, args.block_size, args.loss_func_seg, args.max_epoch, 342 | args.batch_size, 343 | args.learning_rate, 344 | int(args.use_IP), int(args.use_bg), int(args.use_coord), 345 | int(args.use_softmax), args.norm, args.others, 
        args.sel_train_num)

    os.makedirs(f"{model.train_cfg['base_path']}/runs/{model.train_cfg['dset_name']}", exist_ok=True)
    tb_logger = loggers.TensorBoardLogger(f"{model.train_cfg['base_path']}/runs/{model.train_cfg['dset_name']}",
                                          name=logger_name)
    lr_monitor = LearningRateMonitor(logging_interval='step')

    runner = Trainer(min_epochs=min(50, args.max_epoch),
                     max_epochs=args.max_epoch,
                     logger=tb_logger,
                     gpus=args.gpu_id,
                     checkpoint_callback=checkpoint_callback,
                     callbacks=[lr_monitor],
                     accelerator='dp',
                     precision=32,
                     profiler=True,
                     sync_batchnorm=True,
                     resume_from_checkpoint=args.checkpoints)

    try:
        runner.fit(model)
        print('*' * 100)
        print('Training Finished')
        print(f'Training pid:{os.getpid()}')
        print('*' * 100)
        torch.cuda.empty_cache()
        if stdout is not None:
            sys.stderr = save_stderr
            sys.stdout = save_stdout
        return os.getpid()
    except Exception:  # keep the GUI alive and restore stdio on any training failure
        torch.cuda.empty_cache()
        if stdout is not None:
            stdout.flush()
            stdout.write('Training Exception!')
            sys.stderr = save_stderr
            sys.stdout = save_stdout
        return os.getpid()
--------------------------------------------------------------------------------
/tutorials/A_tutorial_of_particlePicking_on_EMPIAR10045_dataset.md:
--------------------------------------------------------------------------------
# **A tutorial of single-class particle picking on EMPIAR-10045 dataset**
## **Step 1. Preprocessing**

- Data preparation

    The sample dataset of EMPIAR-10045 can be downloaded in one of two ways:
    - Baidu Netdisk link: [https://pan.baidu.com/s/1aijM4IgGSRMwBvBk5XbBmw](https://pan.baidu.com/s/1aijM4IgGSRMwBvBk5XbBmw); verification code: cbmi
    - Microsoft OneDrive link: [https://1drv.ms/u/s!AmcdnIXL3Vo4hWf05lhsQWZWWSV3?e=dCWvew](https://1drv.ms/u/s!AmcdnIXL3Vo4hWf05lhsQWZWWSV3?e=dCWvew); verification code: cbmi

- Data structure

    Before launching the graphical user interface, we recommend creating a single folder to save the inputs and outputs of DeepETPicker. Inside this base folder you should make a subfolder to store raw data. This raw_data folder should contain:
    - tomograms (with extension .mrc or .rec)
    - coordinate files with the same names as the tomograms except for the extension (*.csv, *.coords or *.txt; generally, *.coords is recommended)
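    For reference, a single-class `.coords` file is just a plain-text table with one particle per row and three whitespace/tab-separated integer voxel coordinates in `x y z` order (this matches how DeepETPicker parses coordinate files; the values below are illustrative):

    ```
    120	245	57
    301	88	64
    212	190	71
    ```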
The data should be organized as follows:
```
├── /base/path
│   ├── raw_data
│   │   ├── IS002_291013_005_iconmask2_norm_rot_cutZ.coords
│   │   └── IS002_291013_005_iconmask2_norm_rot_cutZ.mrc
│   │   └── IS002_291013_006_iconmask2_norm_rot_cutZ.mrc
│   │   └── IS002_291013_007_iconmask2_norm_rot_cutZ.mrc
│   │   └── IS002_291013_008_iconmask2_norm_rot_cutZ.mrc
│   │   └── IS002_291013_009_iconmask2_norm_rot_cutZ.mrc
│   │   └── IS002_291013_010_iconmask2_norm_rot_cutZ.mrc
│   │   └── IS002_291013_011_iconmask2_norm_rot_cutZ.mrc
```

For the above data, `IS002_291013_005_iconmask2_norm_rot_cutZ.mrc` can be used as a train/val dataset, since it has a coordinate file (`IS002_291013_005_iconmask2_norm_rot_cutZ.coords`). If a tomogram (e.g. `IS002_291013_006_iconmask2_norm_rot_cutZ.mrc`) has no manual annotation (`IS002_291013_006_iconmask2_norm_rot_cutZ.coords`), it cannot be used as a train/val dataset.

- Input & Output
Launch the graphical user interface of DeepETPicker. On the `Preprocessing` page, please set the key parameters as follows:
- `input`
    - Dataset name: e.g. EMPIAR_10045_preprocess
    - Base path: path to the base folder
    - Coords path: path to the raw_data folder
    - Coords format: .csv, .coords or .txt
    - Tomogram path: path to the raw_data folder
    - Tomogram format: .mrc or .rec
    - Number of classes: multiple classes of macromolecules can also be localized separately
- `Output`
    - Label diameter (in voxels): the diameter of the generated weak labels, which is usually smaller than the average diameter of the particles. Empirically, you can set it as large as possible, but it should remain smaller than the real diameter.
    - Ocp diameter (in voxels): the real diameter of the particles. Empirically, to obtain good picking results, we recommend adjusting the particle size to the range of 20~30 voxels by binning. For multiple classes of particles, the diameters should be separated by commas.
    - Configs: after clicking `Save configs`, this is the path to the file that contains all the parameters filled in on this page (a sketch of such a configuration is shown below)
### **Step 2. Training of DeepETPicker**

Note: Before `Training of DeepETPicker`, please complete `Step 1. Preprocessing` first, to ensure that the basic parameters required for training are available.
In practice, the default parameters usually give a good enough result.

*Training parameter description:*

- Dataset name: e.g. EMPIAR_10045_train
- Dataset list: get the list of train/val tomograms. The first column denotes the particle number, the second column the tomogram name, and the third column the tomogram id. If you have n tomograms, the ids will be {0, 1, 2, ..., n-1}.
- Train dataset ids: tomograms used for training. You can click `Dataset list` to obtain the dataset ids first. One or multiple tomograms can be used for training, but make sure that the training dataset ids are selected from {0, 1, 2, ..., n-1}, where n is the total number of tomograms obtained from `Dataset list`. Two ways of setting dataset ids are provided (see the sketch after this list):
    - 0,2, ...: different tomogram ids separated by commas.
    - 0-m: the ids {0, 1, 2, ..., m-1} will be selected. Note: this form can only be used for tomograms with continuous ids.
- Val dataset ids: tomograms used for validation. You can click `Dataset list` to obtain the dataset ids first. Note: only one tomogram can be selected as the val dataset.
- Number of classes: the particle classes you want to pick
- Batch size: the number of samples processed before the model is updated. It is limited by your GPU memory; reducing this parameter may help if you encounter an out-of-memory error.
- Patch size: the size of each subtomogram. It must be a multiple of 8; it is recommended that this value is not less than 64, and the default value is 72.
- Padding size: a hyperparameter of the overlap-tile strategy. Usually it lies between 6 and 12, and the default value is 12.
- Learning rate: the step size at each iteration while moving toward a minimum of the loss function.
- Max epoch: the total number of training epochs. The default value of 60 is usually sufficient.
- GPU id: the GPUs used for training, e.g. 0,1,2,3 denotes using GPUs 0-3. You can run `nvidia-smi` to get the information of available GPUs.
- Save Configs: save the configs listed above. The saved configuration contains all the parameters filled in on this page and can be loaded directly via *`Load configs`* next time instead of filling them in again.
### **Step 3. Inference of DeepETPicker**
89 | 90 | In practice, default parameters can give you a good enough result. 91 | 92 | *Inference parameter description:* 93 | 94 | - Train Configs: path to the configuration file which has been saved in the `Training` step 95 | - Networks weights: path to the model which has be generated in the `Training` step 96 | - Patch size & Pad_size: tomogram is scanned with a specific stride S and a patch size of N in this stage, where `S = N - 2*Pad_size`. 97 | - GPU id: the GPUs used for inference, e.g. 0,1,2,3 denotes using GPUs of 0-4. You can run the following command to get the information of available GPUs: nvidia-smi. 98 | 99 | ### **Step 4. Particle visualization and mantual picking** 100 | 101 |
### **Step 4. Particle visualization and manual picking**
104 | 105 | - *Showing Tomogram* 106 | 107 | You can click the `Load tomogram` button on this page to load the tomogram. 108 | 109 | - *Showing Labels* 110 | 111 | After loading the coordinates file by clicking `Load labels`, you can click `Show result` to visualize the label. The label's diameter and width also can be tuned on the GUI. 112 | 113 | - *Parameter Adjustment* 114 | 115 | In order to increase the visualization of particles, Gaussian filtering and histogram equalization are provided: 116 | - Filter: when choosing Gaussian, a Gaussian filter can be applied to the displayed tomogram, kernel_size and sigma(standard deviation) can be tuned to adjust the visual effects 117 | - Contrast: when choosing hist-equ, histogram equalization can be performed 118 | 119 | 120 | - *Position Slider* 121 | 122 | You can scan through the volume in x, y and z directions by changing their values. For z-axis scanning, shortcut keys of Up/Down arrow can be used. 123 | 124 | 125 | -------------------------------------------------------------------------------- /tutorials/A_tutorial_of_particlePicking_on_SHREC2021_dataset.md: -------------------------------------------------------------------------------- 1 | # **A tutorial of multiple-class particle picking on SHREC2021 dataset** 2 | ## **Step 1. Preprocessing** 3 | 4 | - Data preparation 5 | 6 | 7 | The sample dataset of SHREC_2021 can be download in one of two ways: 8 | - Baidu Netdisk Link: [https://pan.baidu.com/s/1aijM4IgGSRMwBvBk5XbBmw](https://pan.baidu.com/s/1aijM4IgGSRMwBvBk5XbBmw ); verification code: cbmi 9 | - Microsoft onedrive Link: [https://1drv.ms/u/s!AmcdnIXL3Vo4hWf05lhsQWZWWSV3?e=dCWvew](https://1drv.ms/u/s!AmcdnIXL3Vo4hWf05lhsQWZWWSV3?e=dCWvew); verification code: cbmi 10 | 11 | - Data structure 12 | 13 | Before launching the graphical user interface, we recommend creating a single folder to save inputs and outputs of DeepETpicker. Inside this base folder you should make a subfolder to store raw data. This raw_data folder should contain: 14 | - tomograms(with extension .mrc or .rec) 15 | - coordinates file with the same name as tomograms except for extension. (with extension *.csv, *.coords or *.txt. Generally, *.coords is recoommand). 16 | 17 |
The complete data should be organized as follows:
```
├── /base/path
│   ├── raw_data
│   │   ├── model_0.coords
│   │   └── model_0.mrc
│   │   ├── model_1.coords
│   │   └── model_1.mrc
│   │   └── model_2.coords
│   │   └── model_2.mrc
│   │   ├── model_3.coords
│   │   └── model_3.mrc
│   │   └── model_4.coords
│   │   └── model_4.mrc
│   │   ├── model_5.coords
│   │   └── model_5.mrc
│   │   └── model_6.coords
│   │   └── model_6.mrc
│   │   ├── model_7.coords
│   │   └── model_7.mrc
│   │   └── model_8.coords
│   │   └── model_8.mrc
│   │   └── model_9.mrc
```

For the above data, `model_0.mrc` to `model_8.mrc` can be used as train/val datasets, since they all have coordinate files (`model_0.coords` to `model_8.coords`). If a tomogram (e.g. `model_9.mrc`) has no manual annotation (`model_9.coords`), it cannot be used as a train/val dataset.

- Input & Output
Launch the graphical user interface of DeepETPicker. On the `Preprocessing` page, please set the key parameters as follows:
- `input`
    - Dataset name: e.g. SHREC_2021_preprocess
    - Base path: path to the base folder
    - Coords path: path to the raw_data folder
    - Coords format: .csv, .coords or .txt
    - Tomogram path: path to the raw_data folder
    - Tomogram format: .mrc or .rec
    - Number of classes: multiple classes of macromolecules can also be localized separately
- `Output`
    - Label diameter (in voxels): the diameter of the generated weak labels, which is usually smaller than the average diameter of the particles. Empirically, you can set it as large as possible, but it should remain smaller than the real diameter.
    - Ocp diameter (in voxels): the real diameter of the particles. Empirically, to obtain good picking results, we recommend adjusting the particle size to the range of 20~30 voxels by binning. For multiple classes of particles, the diameters should be separated by commas.
    - Configs: after clicking `Save configs`, this is the path to the file that contains all the parameters filled in on this page


### **Step 2. Training of DeepETPicker**
Note: Before `Training of DeepETPicker`, please complete `Step 1. Preprocessing` first, to ensure that the basic parameters required for training are available.
In practice, the default parameters usually give a good enough result.

*Training parameter description:*

- Dataset name: e.g. SHREC_2021_train
- Model name: the name of the segmentation model.
- Dataset list: get the list of train/val tomograms. The first column denotes the particle number, the second column the tomogram name, and the third column the tomogram id. If you have n tomograms, the ids will be {0, 1, 2, ..., n-1}.
- Train dataset ids: tomograms used for training. You can click `Dataset list` to obtain the dataset ids first. One or multiple tomograms can be used for training, but make sure that the training dataset ids are selected from {0, 1, 2, ..., n-1}, where n is the total number of tomograms obtained from `Dataset list`. Two ways of setting dataset ids are provided:
    - 0,2, ...: different tomogram ids separated by commas.
    - 0-m: the ids {0, 1, 2, ..., m-1} will be selected. Note: this form can only be used for tomograms with continuous ids.
- Val dataset ids: tomograms used for validation. You can click `Dataset list` to obtain the dataset ids first. Note: only one tomogram can be selected as the val dataset.
- Number of classes: the particle classes you want to pick
- Batch size: the number of samples processed before the model is updated. It is limited by your GPU memory; reducing this parameter may help if you encounter an out-of-memory error.
- Patch size: the size of each subtomogram. It must be a multiple of 8; it is recommended that this value is not less than 64, and the default value is 72.
- Padding size: a hyperparameter of the overlap-tile strategy. Usually it lies between 6 and 12, and the default value is 12.
- Learning rate: the step size at each iteration while moving toward a minimum of the loss function.
- Max epoch: the total number of training epochs. The default value of 60 is usually sufficient.
- GPU id: the GPUs used for training, e.g. 0,1,2,3 denotes using GPUs 0-3. You can run `nvidia-smi` to get the information of available GPUs.
- Save Configs: save the configs listed above. The saved configuration contains all the parameters filled in on this page and can be loaded directly via *`Load configs`* next time instead of filling them in again.

### **Step 3. Inference of DeepETPicker**
In practice, the default parameters usually give a good enough result.

*Inference parameter description:*

- Train Configs: path to the configuration file saved in the `Training` step
- Network weights: path to the model generated in the `Training` step
- Patch size & Pad_size: in this stage the tomogram is scanned with a patch size of N and a stride S, where `S = N - 2*Pad_size` (e.g. N = 72 and Pad_size = 12 give a stride of 48 voxels).
- GPU id: the GPUs used for inference, e.g. 0,1,2,3 denotes using GPUs 0-3. You can run `nvidia-smi` to get the information of available GPUs.
### **Step 4. Particle visualization and manual picking**
121 | 122 | - *Showing Tomogram* 123 | 124 | You can click the `Load tomogram` button on this page to load the tomogram. 125 | 126 | - *Showing Labels* 127 | 128 | After loading the coordinates file by clicking `Load labels`, you can click `Show result` to visualize the label. The label's diameter and width also can be tuned on the GUI. 129 | 130 | - *Parameter Adjustment* 131 | 132 | In order to increase the visualization of particles, Gaussian filtering and histogram equalization are provided: 133 | - Filter: when choosing Gaussian, a Gaussian filter can be applied to the displayed tomogram, kernel_size and sigma(standard deviation) can be tuned to adjust the visual effects 134 | - Contrast: when choosing hist-equ, histogram equalization can be performed 135 | 136 | 137 | - *Position Slider* 138 | 139 | You can scan through the volume in x, y and z directions by changing their values. For z-axis scanning, shortcut keys of Up/Down arrow can be used. 140 | 141 | 142 | -------------------------------------------------------------------------------- /tutorials/images/EMPIAR_10045_GIF/Inference.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cbmi-group/DeepETPicker/0b3f8cb298128805a23cfd2909f81010006d2390/tutorials/images/EMPIAR_10045_GIF/Inference.gif -------------------------------------------------------------------------------- /tutorials/images/EMPIAR_10045_GIF/Preprocessing.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cbmi-group/DeepETPicker/0b3f8cb298128805a23cfd2909f81010006d2390/tutorials/images/EMPIAR_10045_GIF/Preprocessing.gif -------------------------------------------------------------------------------- /tutorials/images/EMPIAR_10045_GIF/Training.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cbmi-group/DeepETPicker/0b3f8cb298128805a23cfd2909f81010006d2390/tutorials/images/EMPIAR_10045_GIF/Training.gif -------------------------------------------------------------------------------- /tutorials/images/EMPIAR_10045_GIF/Visualization.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cbmi-group/DeepETPicker/0b3f8cb298128805a23cfd2909f81010006d2390/tutorials/images/EMPIAR_10045_GIF/Visualization.gif -------------------------------------------------------------------------------- /tutorials/images/Inference.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cbmi-group/DeepETPicker/0b3f8cb298128805a23cfd2909f81010006d2390/tutorials/images/Inference.gif -------------------------------------------------------------------------------- /tutorials/images/Preprocessing.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cbmi-group/DeepETPicker/0b3f8cb298128805a23cfd2909f81010006d2390/tutorials/images/Preprocessing.gif -------------------------------------------------------------------------------- /tutorials/images/Training.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cbmi-group/DeepETPicker/0b3f8cb298128805a23cfd2909f81010006d2390/tutorials/images/Training.gif -------------------------------------------------------------------------------- /tutorials/images/Visualization.gif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/cbmi-group/DeepETPicker/0b3f8cb298128805a23cfd2909f81010006d2390/tutorials/images/Visualization.gif -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from utils.colors import * 2 | from utils.coord_gen import * 3 | from utils.coords2labels import * 4 | from utils.loss import * 5 | from utils.metrics import * 6 | from utils.misc import * 7 | from utils.normalization import * 8 | from utils.coordconv_torch import * -------------------------------------------------------------------------------- /utils/colors.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import cv2 4 | 5 | COLORS = ( 6 | (255, 255, 255), # 0 7 | (244, 67, 54), # 1 8 | (233, 30, 99), # 2 9 | (156, 39, 176), # 3 10 | (103, 58, 183), # 4 11 | (63, 81, 181), # 5 12 | (33, 150, 243), # 6 13 | (3, 169, 244), # 7 14 | (0, 188, 212), # 8 15 | (0, 150, 136), # 9 16 | (76, 175, 80), # 10 17 | (139, 195, 74), # 11 18 | (205, 220, 57), # 12 19 | (255, 235, 59), # 13 20 | (255, 193, 7), # 14 21 | (255, 152, 0), # 15 22 | (255, 87, 34), 23 | (121, 85, 72), 24 | (158, 158, 158), 25 | (96, 125, 139)) 26 | 27 | 28 | def plot_legend(COLORS): 29 | plt.figure(figsize=(6, 0.8), dpi=100) 30 | for idx in np.arange(1, 13): 31 | plt.subplot(1, 12, idx) 32 | data = np.array(COLORS[idx] * 30000).reshape(100, 300, 3) 33 | plt.imshow(data) 34 | plt.axis('off') 35 | plt.title('%d' % idx, fontdict={'size': 15, 'weight': 'bold'}) 36 | plt.tight_layout() 37 | plt.savefig('temp.png') 38 | 39 | img = cv2.imread('temp.png') 40 | img_sum = np.sum(np.array(img), axis=2) 41 | img[img_sum == 765] = [0, 0, 0] 42 | img[img_sum == 0] = [255, 255, 255] 43 | cv2.imwrite('temp.png', img) -------------------------------------------------------------------------------- /utils/coordFormatConvert.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def coords2star(data, save_path): 4 | """Relion star""" 5 | string = """ 6 | data_ 7 | 8 | loop_ 9 | _rlnCoordinateX #1 10 | _rlnCoordinateY #2 11 | _rlnCoordinateZ #3""" 12 | 13 | data = np.round(data, 1).astype(str) 14 | 15 | with open(save_path, 'w') as f: 16 | f.writelines(string + '\n') 17 | for item in data: 18 | line = ' '.join(item) 19 | f.write(line + '\n') 20 | 21 | 22 | def coords2box(data, save_path): 23 | """EMAN2 box""" 24 | with open(save_path, 'w') as f: 25 | for item in data: 26 | line = f"{item[0]:.1f}\t{item[1]:.1f}\t{item[2]:.0f}" 27 | f.write(line + '\n') 28 | 29 | def coords2coords(data, save_path): 30 | """EMAN2 box""" 31 | with open(save_path, 'w') as f: 32 | for item in data: 33 | line = f"{item[0]:.0f}\t{item[1]:.0f}\t{item[2]:.0f}" 34 | f.write(line + '\n') 35 | -------------------------------------------------------------------------------- /utils/coord_gen.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import os 3 | import glob 4 | import numpy as np 5 | import pandas as pd 6 | import sys 7 | 8 | warnings.simplefilter('ignore') 9 | base_dir = "/ldap_shared/synology_shared/IBP_ribosome/liver_NewConstruct/Pick/reconstruction_2400" 10 | 11 | def coords_gen(coord_path, coord_format, base_dir): 12 | os.makedirs(os.path.join(base_dir, 
'coords'), exist_ok=True) 13 | 14 | coords_list = [i.split('/')[-1] for i in glob.glob(coord_path + f'/*{coord_format}')] 15 | coords_list = sorted(coords_list) 16 | 17 | num_all = [] 18 | dir_names = [] 19 | for dir in coords_list: 20 | data = [] 21 | with open(os.path.join(coord_path, dir), 'r') as f: 22 | for idx, item in enumerate(f): 23 | data.append(item.rstrip('\n').split()) 24 | try: 25 | data = np.array(data).astype(np.float).astype(int) 26 | except: 27 | data = np.array(data).astype(np.float32).astype(int) 28 | 29 | np.savetxt(os.path.join(base_dir, "coords", dir), 30 | data, delimiter='\t', newline="\n", fmt="%s") 31 | num_all.append(data.shape[0]) 32 | 33 | dir_name = dir[:-len(coord_format)] 34 | dir_names.append(dir_name) 35 | 36 | str_ = "|" 37 | for i in np.arange(0, len(num_all), 1): 38 | tmp = np.array(num_all)[:i+1].sum() 39 | print("0 to %d:" % (i+1), tmp) 40 | str_ += "%d|" % tmp 41 | # print(str_) 42 | # print(list(enumerate(num_all))) 43 | 44 | # gen num_name.csv 45 | num_name = np.array([num_all]).transpose() 46 | try: 47 | df = pd.DataFrame(num_name).astype(np.float) 48 | except: 49 | df = pd.DataFrame(num_name).astype(np.float32) 50 | df['dir_names'] = np.array([dir_names]).transpose() 51 | df['idx'] = np.arange(len(dir_names)).reshape(-1, 1) 52 | df.to_csv(os.path.join(base_dir, "coords", "num_name.csv"), sep='\t', header=False, index=False) 53 | 54 | 55 | def coords_gen_show(args): 56 | coord_path, coord_format, base_dir, stdout = args 57 | if stdout is not None: 58 | save_stdout = sys.stdout 59 | save_stderr = sys.stderr 60 | sys.stdout = stdout 61 | sys.stderr = stdout 62 | 63 | try: 64 | coords_gen(coord_path, coord_format, base_dir) 65 | print('Coord generation finished!') 66 | print('*' * 100) 67 | except: 68 | stdout.flush() 69 | stdout.write('Coordinates Generation Exception!') 70 | print('*' * 100) 71 | return 0 72 | 73 | if stdout is not None: 74 | sys.stderr = save_stderr 75 | sys.stdout = save_stdout -------------------------------------------------------------------------------- /utils/coordconv_torch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.modules.conv as conv 4 | 5 | 6 | class AddCoords(nn.Module): 7 | def __init__(self, rank, with_r=False): 8 | super(AddCoords, self).__init__() 9 | self.rank = rank 10 | self.with_r = with_r 11 | 12 | def forward(self, input_tensor): 13 | """ 14 | :param input_tensor: shape (N, C_in, H, W) 15 | :return: 16 | """ 17 | if self.rank == 1: 18 | batch_size_shape, channel_in_shape, dim_x = input_tensor.shape 19 | x_range = torch.linspace(-1, 1, dim_x, device=input_tensor.device) 20 | xx_channel = x_range.expand([batch_size_shape, 1, -1]) 21 | 22 | out = torch.cat([input_tensor, xx_channel], dim=1) 23 | 24 | if self.with_r: 25 | rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2)) 26 | out = torch.cat([out, rr], dim=1) 27 | 28 | elif self.rank == 2: 29 | batch_size_shape, channel_in_shape, dim_y, dim_x = input_tensor.shape 30 | x_range = torch.linspace(-1, 1, dim_x, device=input_tensor.device) 31 | y_range = torch.linspace(-1, 1, dim_y, device=input_tensor.device) 32 | yy_channel, xx_channel = torch.meshgrid(y_range, x_range) 33 | yy_channel = yy_channel.expand([batch_size_shape, 1, -1, -1]) 34 | xx_channel = xx_channel.expand([batch_size_shape, 1, -1, -1]) 35 | 36 | out = torch.cat([input_tensor, xx_channel, yy_channel], dim=1) 37 | 38 | if self.with_r: 39 | rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2) + 
torch.pow(yy_channel - 0.5, 2))
                out = torch.cat([out, rr], dim=1)

        elif self.rank == 3:
            batch_size_shape, channel_in_shape, dim_z, dim_y, dim_x = input_tensor.shape
            x_range = torch.linspace(-1, 1, dim_x, device=input_tensor.device)
            y_range = torch.linspace(-1, 1, dim_y, device=input_tensor.device)
            z_range = torch.linspace(-1, 1, dim_z, device=input_tensor.device)
            zz_channel, yy_channel, xx_channel = torch.meshgrid(z_range, y_range, x_range)
            zz_channel = zz_channel.expand([batch_size_shape, 1, -1, -1, -1])
            yy_channel = yy_channel.expand([batch_size_shape, 1, -1, -1, -1])
            xx_channel = xx_channel.expand([batch_size_shape, 1, -1, -1, -1])
            out = torch.cat([input_tensor, xx_channel, yy_channel, zz_channel], dim=1)

            if self.with_r:
                rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2) +
                                torch.pow(yy_channel - 0.5, 2) +
                                torch.pow(zz_channel - 0.5, 2))
                out = torch.cat([out, rr], dim=1)
        else:
            raise NotImplementedError

        return out


class CoordConv1d(conv.Conv1d):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, with_r=False, use_cuda=True):
        # use_cuda is kept for backward compatibility but is no longer passed on:
        # AddCoords takes only (rank, with_r) and builds its coordinate channels
        # on the input tensor's device.
        super(CoordConv1d, self).__init__(in_channels, out_channels, kernel_size,
                                          stride, padding, dilation, groups, bias)
        self.rank = 1
        self.addcoords = AddCoords(self.rank, with_r)
        self.conv = nn.Conv1d(in_channels + self.rank + int(with_r), out_channels,
                              kernel_size, stride, padding, dilation, groups, bias)

    def forward(self, input_tensor):
        """
        input_tensor shape: (N, C_in, L)
        output tensor shape: (N, C_out, L_out)
        :return: CoordConv1d result
        """
        out = self.addcoords(input_tensor)
        out = self.conv(out)

        return out


class CoordConv2d(conv.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, with_r=False, use_cuda=True):
        super(CoordConv2d, self).__init__(in_channels, out_channels, kernel_size,
                                          stride, padding, dilation, groups, bias)
        self.rank = 2
        self.addcoords = AddCoords(self.rank, with_r)
        self.conv = nn.Conv2d(in_channels + self.rank + int(with_r), out_channels,
                              kernel_size, stride, padding, dilation, groups, bias)

    def forward(self, input_tensor):
        """
        input_tensor shape: (N, C_in, H, W)
        output tensor shape: (N, C_out, H_out, W_out)
        :return: CoordConv2d result
        """
        out = self.addcoords(input_tensor)
        out = self.conv(out)

        return out


class CoordConv3d(conv.Conv3d):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, with_r=False, use_cuda=True):
        super(CoordConv3d, self).__init__(in_channels, out_channels, kernel_size,
                                          stride, padding, dilation, groups, bias)
        self.rank = 3
        self.addcoords = AddCoords(self.rank, with_r)
        self.conv = nn.Conv3d(in_channels + self.rank + int(with_r), out_channels,
                              kernel_size, stride, padding, dilation, groups, bias)

    def forward(self, input_tensor):
        """
        input_tensor shape: (N, C_in, D, H, W)
        output tensor shape: (N, C_out, D_out, H_out, W_out)
        :return: CoordConv3d result
        """
        out = self.addcoords(input_tensor)
        out = self.conv(out)

        return out
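# Illustrative sanity check of the interface above (a sketch, not part of the
# original module): AddCoords with rank 3 appends three normalized coordinate
# channels, so CoordConv3d internally convolves in_channels + 3 channels.
if __name__ == "__main__":
    conv3 = CoordConv3d(in_channels=1, out_channels=8, kernel_size=3, padding=1)
    x = torch.randn(2, 1, 16, 16, 16)  # (N, C, D, H, W)
    y = conv3(x)                       # coordinates are concatenated before the conv
    print(y.shape)                     # torch.Size([2, 8, 16, 16, 16])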
-------------------------------------------------------------------------------- /utils/coords2labels.py: -------------------------------------------------------------------------------- 1 | import mrcfile 2 | from multiprocessing import Pool 3 | import pandas as pd 4 | import os 5 | import numpy as np 6 | from glob import glob 7 | import sys 8 | import traceback 9 | 10 | def gaussian3D(shape, sigma=1): 11 | l, m, n = [(ss - 1.) / 2. for ss in shape] 12 | z, y, x = np.ogrid[-l:l + 1, -m:m + 1, -n:n + 1] 13 | sigma = (sigma - 1.) / 2. 14 | h = np.exp(-(x * x + y * y + z * z) / (2 * sigma * sigma)) 15 | # h[h < np.finfo(float).eps * h.max()] = 0 16 | return h 17 | 18 | 19 | class Coord_to_Label(): 20 | def __init__(self, base_path, coord_path, coord_format, tomo_path, tomo_format, 21 | num_cls, label_type, label_diameter): 22 | 23 | self.base_path = base_path 24 | self.coord_path = coord_path 25 | self.coord_format = coord_format 26 | self.tomo_path = tomo_path 27 | self.tomo_format = tomo_format 28 | self.num_cls = num_cls 29 | self.label_type = label_type 30 | if not isinstance(label_diameter, int): 31 | self.label_diameter = [int(i) for i in label_diameter.split(',')] 32 | else: 33 | self.label_diameter = [label_diameter] 34 | 35 | if 'ocp' in self.label_type.lower(): 36 | self.label_path = os.path.join(self.base_path, self.label_type) 37 | else: 38 | self.label_path = os.path.join(self.base_path, 39 | self.label_type + str(self.label_diameter[0])) 40 | os.makedirs(self.label_path, exist_ok=True) 41 | 42 | self.dir_list = [i[:-len(self.coord_format)] for i in os.listdir(self.coord_path) if self.coord_format in i] 43 | self.names = [i + self.tomo_format for i in self.dir_list] # if self.tomo_format not in i else i 44 | 45 | def single_handle(self, i): 46 | self.tomo_file = f"{self.tomo_path}/{self.names[i]}" 47 | data_file = mrcfile.open(self.tomo_file, permissive=True) 48 | # print(os.path.join(self.label_path, self.names[i])) 49 | label_file = mrcfile.new(os.path.join(self.label_path, self.names[i]), 50 | overwrite=True) 51 | 52 | label_positions = pd.read_csv(os.path.join(self.base_path, 'coords', '%s.coords' % self.dir_list[i]), sep='\t', 53 | header=None).to_numpy() 54 | 55 | # template = np.fromfunction(lambda i, j, k: (i - r) * (i - r) + (j - r) * (j - r) + (k - r) * (k - r) <= r * r, 56 | # (2 * r + 1, 2 * r + 1, 2 * r + 1), dtype=int).astype(int) 57 | 58 | z_max, y_max, x_max = data_file.data.shape 59 | try: 60 | label_data = np.zeros(data_file.data.shape, dtype=np.float) 61 | except: 62 | label_data = np.zeros(data_file.data.shape, dtype=np.float32) 63 | 64 | for pos_idx, a_pos in enumerate(label_positions): 65 | if self.num_cls == 1 and len(a_pos) == 3: 66 | x, y, z = a_pos 67 | cls_idx_ = 1 68 | else: 69 | cls_idx_, x, y, z = a_pos 70 | 71 | if 'data_ocp' in self.label_type.lower(): 72 | dim = int(self.label_diameter[cls_idx_ - 1]) 73 | else: 74 | dim = int(self.label_diameter[0]) 75 | radius = int(dim / 2) 76 | r = radius 77 | 78 | template = gaussian3D((dim, dim, dim), dim) 79 | 80 | cls_idx = pos_idx+1 if 'data_ocp' in self.label_type else cls_idx_ 81 | # print(self.label_type, dim, cls_idx) 82 | z_start = 0 if z - r < 0 else z - r 83 | z_end = z_max if z + r + 1 > z_max else z + r + 1 84 | y_start = 0 if y - r < 0 else y - r 85 | y_end = y_max if y + r + 1 > y_max else y + r + 1 86 | x_start = 0 if x - r < 0 else x - r 87 | x_end = x_max if x + r + 1 > x_max else x + r + 1 88 | 89 | t_z_start = r - z if z - r < 0 else 0 90 | t_z_end = (r + z_max - z) if z + r + 1 > z_max 
else 2 * r + 1 91 | t_y_start = r - y if y - r < 0 else 0 92 | t_y_end = (r + y_max - y) if y + r + 1 > y_max else 2 * r + 1 93 | t_x_start = r - x if x - r < 0 else 0 94 | t_x_end = (r + x_max - x) if x + r + 1 > x_max else 2 * r + 1 95 | 96 | # print(z_start, z_end, y_start, y_end, x_start, x_end) 97 | # check border 98 | # print(label_data.shape) 99 | # print(z_start, z_end, y_start, y_end, x_start, x_end) 100 | tmp1 = label_data[z_start:z_end, y_start:y_end, x_start:x_end] 101 | tmp2 = template[t_z_start:t_z_end, t_y_start:t_y_end, t_x_start:t_x_end] 102 | 103 | larger_index = tmp1 < tmp2 104 | tmp1[larger_index] = tmp2[larger_index] 105 | 106 | if 'cubic' in self.label_type.lower(): 107 | tg = 0.223 # exp(-1.5) 108 | elif 'sphere' in self.label_type.lower() or 'ocp' in self.label_type.lower(): 109 | tg = 0.60653 # exp(-0.5) 110 | else: 111 | tg = 0.367879 # exp(-1) 112 | tmp1[tmp1 <= tg] = 0 113 | tmp1 = np.where(tmp1 > 0, cls_idx, 0) 114 | label_data[z_start:z_end, y_start:y_end, x_start:x_end] = tmp1 115 | 116 | label_file.set_data(label_data) 117 | 118 | data_file.close() 119 | label_file.close() 120 | # print('work %s done' % i) 121 | # return 'work %s done' % i 122 | 123 | def gen_labels(self): 124 | if len(self.dir_list) == 1: 125 | self.single_handle(0) 126 | else: 127 | with Pool(len(self.dir_list)) as p: 128 | p.map(self.single_handle, np.arange(len(self.dir_list)).tolist()) 129 | 130 | 131 | def label_gen_show(args): 132 | base_path, coord_path, coord_format, tomo_path, tomo_format, \ 133 | num_cls, label_type, label_diameter, stdout = args 134 | if stdout is not None: 135 | save_stdout = sys.stdout 136 | save_stderr = sys.stderr 137 | sys.stdout = stdout 138 | sys.stderr = stdout 139 | 140 | try: 141 | label_gen = Coord_to_Label(base_path, 142 | coord_path, 143 | coord_format, 144 | tomo_path, 145 | tomo_format, 146 | num_cls, 147 | label_type, 148 | label_diameter) 149 | label_gen.gen_labels() 150 | if 'ocp' not in label_type: 151 | print('Label generation finished!') 152 | print('*' * 100) 153 | else: 154 | print('Occupancy generation finished!') 155 | print('*' * 100) 156 | except Exception as ex: 157 | term = 'Occupancy' if 'ocp' in label_type else 'Label' 158 | if stdout is not None: 159 | stdout.flush() 160 | stdout.write(f"{ex}") 161 | stdout.write(f'{term} Generation Exception!') 162 | print('*' * 100) 163 | else: 164 | traceback.print_exc() 165 | #print(f"{ex}") 166 | print(f'{term} Generation Exception!') 167 | print('*' * 100) 168 | return 0 169 | if stdout is not None: 170 | sys.stderr = save_stderr 171 | sys.stdout = save_stdout 172 | 173 | 174 | class Coord_to_Label_v1(): 175 | def __init__(self, tomo_file, coord_file, num_cls, label_diameter, label_type): 176 | 177 | self.tomo_file = tomo_file 178 | self.coord_file = coord_file 179 | self.num_cls = num_cls 180 | self.label_type = label_type 181 | self.label_diameter = label_diameter 182 | 183 | def gen_labels(self): 184 | if '.coords' in self.coord_file or '.txt' in self.coord_file: 185 | data_file = mrcfile.open(self.tomo_file, permissive=True) 186 | 187 | label_positions = pd.read_csv(self.coord_file, sep='\t', header=None).to_numpy() 188 | if self.label_type == 'Coords': 189 | return label_positions 190 | 191 | dim = int(self.label_diameter) 192 | radius = int(dim / 2) 193 | r = radius 194 | 195 | template = gaussian3D((dim, dim, dim), dim) 196 | 197 | z_max, y_max, x_max = data_file.data.shape 198 | try: 199 | label_data = np.zeros(data_file.data.shape, dtype=np.float) 200 | except: 201 | label_data = 
np.zeros(data_file.data.shape, dtype=np.float32)

            for pos_idx, a_pos in enumerate(label_positions):
                if self.num_cls == 1 and len(a_pos) == 3:
                    x, y, z = a_pos
                    cls_idx = 1
                else:
                    cls_idx, x, y, z = a_pos
                cls_idx = pos_idx+1 if 'ocp' in self.label_type else cls_idx
                z_start = 0 if z - r < 0 else z - r
                z_end = z_max if z + r + 1 > z_max else z + r + 1
                y_start = 0 if y - r < 0 else y - r
                y_end = y_max if y + r + 1 > y_max else y + r + 1
                x_start = 0 if x - r < 0 else x - r
                x_end = x_max if x + r + 1 > x_max else x + r + 1

                t_z_start = r - z if z - r < 0 else 0
                t_z_end = (r + z_max - z) if z + r + 1 > z_max else 2 * r + 1
                t_y_start = r - y if y - r < 0 else 0
                t_y_end = (r + y_max - y) if y + r + 1 > y_max else 2 * r + 1
                t_x_start = r - x if x - r < 0 else 0
                t_x_end = (r + x_max - x) if x + r + 1 > x_max else 2 * r + 1

                # clip the template against the tomogram borders
                tmp1 = label_data[z_start:z_end, y_start:y_end, x_start:x_end]
                tmp2 = template[t_z_start:t_z_end, t_y_start:t_y_end, t_x_start:t_x_end]

                larger_index = tmp1 < tmp2
                tmp1[larger_index] = tmp2[larger_index]
                tmp1[tmp1 <= 0.36788] = 0

                tmp1 = np.where(tmp1 > 0, cls_idx, 0)

                label_data[z_start:z_end, y_start:y_end, x_start:x_end] = tmp1

            data_file.close()
            return label_data
        elif '.mrc' in self.tomo_file or '.rec' in self.tomo_file:
            label_data = mrcfile.open(self.coord_file, permissive=True)
            return label_data.data

if __name__ == "__main__":
    import sys
    sys.path.append("..")
    from configs.c2l_10045_New_bin8_mask3 import pre_config
    from utils.coord_gen import coords_gen
    from utils.normalization import InputNorm

    # Convert the coordinate files to integer coordinates
    coords_gen(pre_config["coord_path"],
               pre_config["coord_format"],
               pre_config["base_path"])

    # Normalization
    pre_norm = InputNorm(pre_config["tomo_path"],
                         pre_config["tomo_format"],
                         pre_config["base_path"],
                         pre_config["norm_type"])
    pre_norm.handle_parallel()

    # Generate labels from the coords
    c2l = Coord_to_Label(base_path=pre_config["base_path"],
                         coord_path=pre_config["coord_path"],
                         coord_format=pre_config["coord_format"],
                         tomo_path=pre_config["tomo_path"],
                         tomo_format=pre_config["tomo_format"],
                         num_cls=pre_config["num_cls"],
                         label_type=pre_config["label_type"],
                         label_diameter=pre_config["label_diameter"],
                         )
    c2l.gen_labels()

    # Generate occupancy maps from the coords
    if pre_config["label_diameter"] != pre_config["ocp_diameter"]:
        c2l = Coord_to_Label(base_path=pre_config["base_path"],
                             coord_path=pre_config["coord_path"],
                             coord_format=pre_config["coord_format"],
                             tomo_path=pre_config["tomo_path"],
                             tomo_format=pre_config["tomo_format"],
                             num_cls=pre_config["num_cls"],
                             label_type=pre_config["ocp_type"],
                             label_diameter=pre_config["ocp_diameter"],
                             )
        c2l.gen_labels()
--------------------------------------------------------------------------------
/utils/coords_to_relion4.py:
--------------------------------------------------------------------------------
from io import StringIO
import pandas as pd
import numpy as np
from pathlib import Path


def c2w(input_dir="coord_path/Coords_All",
        output_dir="./output",
        output_name="coords2relion4.star"):
    input_dir = Path(input_dir)
    output_dir = Path(output_dir)
    output_dir.mkdir(exist_ok=True, parents=True)
    output_path = output_dir / output_name
    coords_paths = sorted(list(input_dir.glob("*.txt")) + list(input_dir.glob("*.coords")))

    dfs = []
    for coords_path in coords_paths:
        coords_data = np.loadtxt(coords_path)
        XYZ = coords_data[:, -3:]

        TomoName = np.array([str(coords_path).split('/')[-1].split('.')[0]] * coords_data.shape[0], dtype=str).reshape(-1, 1)
        TomoParticleId = np.arange(1, coords_data.shape[0]+1).reshape(-1, 1)
        originXYZang = np.zeros_like(XYZ)
        angle = np.zeros_like(XYZ)
        if coords_data.shape[1] == 4:
            ClassNumber = coords_data[:, 0].reshape(-1, 1)
        else:
            ClassNumber = np.ones_like(TomoParticleId)
        # alternate subset labels 1/2; the parentheses ensure (N + 1) // 2 pairs
        # are generated before truncating to N rows
        randomsubset = np.array([1, 2] * ((coords_data.shape[0] + 1) // 2)).reshape(-1, 1)[:coords_data.shape[0]]
        df = pd.DataFrame(np.c_[TomoName, TomoParticleId, XYZ.astype(np.int32), originXYZang, angle, ClassNumber.astype(np.int32), randomsubset])
        dfs.append(df)
    dfs = pd.concat(dfs, axis=0)

    with StringIO() as buffer:
        dfs.to_csv(buffer, sep="\t", index=False, header=None)
        lines = buffer.getvalue()

    with open(output_path, "w") as ofile:
        ofile.write("relion4" + "\n")
        ofile.write("data_particles" + "\n")
        ofile.write("loop_" + "\n")
        ofile.write("_rlnTomoName #1" + "\n")
        ofile.write("_rlnTomoParticleId #2" + "\n")
        ofile.write("_rlnCoordinateX #3" + "\n")
        ofile.write("_rlnCoordinateY #4" + "\n")
        ofile.write("_rlnCoordinateZ #5" + "\n")
        ofile.write("_rlnOriginXAngst #6" + "\n")
        ofile.write("_rlnOriginYAngst #7" + "\n")
        ofile.write("_rlnOriginZAngst #8" + "\n")
        ofile.write("_rlnAngleRot #9" + "\n")
        ofile.write("_rlnAngleTilt #10" + "\n")
        ofile.write("_rlnAnglePsi #11" + "\n")
        ofile.write("_rlnClassNumber #12" + "\n")
        ofile.write("_rlnRandomSubset #13" + "\n")
        ofile.write(lines)
    print(f"Save: {output_path}")


if __name__ == "__main__":
    c2w()
--------------------------------------------------------------------------------
/utils/loss.py:
--------------------------------------------------------------------------------
from torch import nn as nn


def flatten(tensor):
    """Flattens a given tensor such that the channel axis is first.
    The shapes are transformed as follows:
    (N, C, D, H, W) -> (C, N * D * H * W)
    """
    # number of channels
    C = tensor.size(1)
    # new axis order
    axis_order = (1, 0) + tuple(range(2, tensor.dim()))
    # Transpose: (N, C, D, H, W) -> (C, N, D, H, W)
    transposed = tensor.permute(axis_order)
    # Flatten: (C, N, D, H, W) -> (C, N * D * H * W)
    return transposed.contiguous().view(C, -1)


# Dice loss
class DiceLoss(nn.Module):
    def __init__(self, smooth=1, args=None):
        super(DiceLoss, self).__init__()
        self.smooth = smooth
        self.use_softmax = args.use_softmax
        self.use_sigmoid = args.use_sigmoid

    def forward(self, outputs, targets):
        # flatten label and prediction tensors
        outputs = flatten(outputs)
        targets = flatten(targets)

        intersection = (outputs * targets).sum(-1)
        dice = (2.
* intersection + self.smooth) / (outputs.sum(-1) + targets.sum(-1) + self.smooth) 35 | return 1 - dice.mean() 36 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def flatten(tensor): 6 | """Flattens a given tensor such that the channel axis is first. 7 | The shapes are transformed as follows: 8 | (N, C, D, H, W) -> (C, N * D * H * W) 9 | """ 10 | # number of channels 11 | C = tensor.size(1) 12 | # new axis order 13 | axis_order = (1, 0) + tuple(range(2, tensor.dim())) 14 | # Transpose: (N, C, D, H, W) -> (C, N, D, H, W) 15 | transposed = tensor.permute(axis_order) 16 | # Flatten: (C, N, D, H, W) -> (C, N * D * H * W) 17 | return transposed.contiguous().view(C, -1) 18 | 19 | 20 | # segmentation metrics 21 | def seg_metrics(y_pred, y_true, smooth=1e-7, isTrain=True, threshold=0.5, use_sigmoid=False): 22 | #comment out if your model contains a sigmoid or equivalent activation layer 23 | if use_sigmoid: 24 | y_pred = F.sigmoid(y_pred) 25 | y_pred = torch.where(y_pred < threshold, torch.zeros(1).cuda(), torch.ones(1).cuda()) 26 | 27 | #flatten label and prediction tensors 28 | y_pred = flatten(y_pred) 29 | y_true = flatten(y_true) 30 | 31 | tp = (y_true * y_pred).sum(-1) 32 | fp = ((1 - y_true) * y_pred).sum(-1) 33 | fn = (y_true * (1 - y_pred)).sum(-1) 34 | 35 | precision = (tp + smooth) / (tp + fp + smooth) 36 | recall = (tp + smooth) / (tp + fn + smooth) 37 | iou = (tp + smooth) / (tp + fn + fp + smooth) 38 | f1 = 2 * (precision*recall) / (precision + recall + smooth) 39 | 40 | mean_precision = precision.mean() 41 | mean_recall = recall.mean() 42 | mean_iou = iou.mean() 43 | mean_f1 = f1.mean() 44 | 45 | # for training, ouput mean metrics 46 | if isTrain: 47 | return mean_precision.item(), mean_recall.item(), mean_iou.item(), mean_f1.item() 48 | # for testing, output metrics array by class with threshold 49 | else: 50 | return precision.detach().cpu().numpy(), recall.detach().cpu().numpy(), iou.detach().cpu().numpy(), f1.detach().cpu().numpy() 51 | 52 | 53 | def seg_metrics_2d(y_pred, y_true, smooth=1e-7, isTrain=True, threshold=0.5, use_sigmoid=False): 54 | # comment out if your model contains a sigmoid or equivalent activation layer 55 | if use_sigmoid: 56 | y_pred = F.sigmoid(y_pred) 57 | y_pred = torch.where(y_pred < threshold, torch.zeros(1).cuda(), torch.ones(1).cuda()) 58 | 59 | # flatten label and prediction tensors 60 | y_pred = y_pred.flatten() 61 | y_true = y_true.flatten() 62 | 63 | tp = (y_true * y_pred).sum(-1) 64 | fp = ((1 - y_true) * y_pred).sum(-1) 65 | fn = (y_true * (1 - y_pred)).sum(-1) 66 | 67 | precision = (tp + smooth) / (tp + fp + smooth) 68 | recall = (tp + smooth) / (tp + fn + smooth) 69 | iou = (tp + smooth) / (tp + fn + fp + smooth) 70 | f1 = 2 * (precision * recall) / (precision + recall + smooth) 71 | 72 | # for training, ouput mean metrics 73 | if isTrain: 74 | return precision.item(), recall.item(), iou.item(), f1.item() 75 | # for testing, output metrics array by class with threshold 76 | else: 77 | return precision.detach().cpu().numpy(), recall.detach().cpu().numpy(), iou.detach().cpu().numpy(), f1.detach().cpu().numpy() 78 | -------------------------------------------------------------------------------- /utils/normalization.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import Pool 2 | import 
mrcfile 3 | import numpy as np 4 | import warnings 5 | import os 6 | import glob 7 | import sys 8 | 9 | warnings.simplefilter('ignore') 10 | class InputNorm(): 11 | def __init__(self, tomo_path, tomo_format, base_dir, norm_type): 12 | self.tomo_path = tomo_path 13 | self.tomo_format = tomo_format 14 | self.base_dir = base_dir 15 | self.norm_type = norm_type 16 | 17 | 18 | if self.norm_type == 'standardization': 19 | self.save_dir = os.path.join(self.base_dir, 'data_std') 20 | elif self.norm_type == 'normalization': 21 | self.save_dir = os.path.join(self.base_dir, 'data_norm') 22 | os.makedirs(self.save_dir, exist_ok=True) 23 | 24 | self.dir_list = [i.split('/')[-1] for i in glob.glob(self.tomo_path + '/*%s' % self.tomo_format)] 25 | print(self.dir_list) 26 | 27 | def single_handle(self, i): 28 | dir_name = self.dir_list[i] 29 | with mrcfile.open(os.path.join(self.tomo_path, dir_name), 30 | permissive=True) as gm: 31 | try: 32 | data = np.array(gm.data).astype(np.float) 33 | except: 34 | data = np.array(gm.data).astype(np.float32) 35 | # print(data.shape) 36 | if self.norm_type == 'standardization': 37 | data -= data.mean() 38 | data /= data.std() 39 | elif self.norm_type == 'normalization': 40 | data -= data.min() 41 | data /= (data.max() - data.min()) 42 | 43 | reconstruction_norm = mrcfile.new( 44 | os.path.join(self.save_dir, dir_name), overwrite=True) 45 | try: 46 | reconstruction_norm.set_data(data.astype(np.float32)) 47 | except: 48 | reconstruction_norm.set_data(data.astype(np.float)) 49 | 50 | reconstruction_norm.close() 51 | print('%d/%d finished.' % (i + 1, len(self.dir_list))) 52 | 53 | def handle_parallel(self): 54 | with Pool(len(self.dir_list)) as p: 55 | p.map(self.single_handle, np.arange(len(self.dir_list)).tolist()) 56 | 57 | 58 | def norm_show(args): 59 | tomo_path, tomo_format, base_dir, norm_type, stdout = args 60 | if stdout is not None: 61 | save_stdout = sys.stdout 62 | save_stderr = sys.stderr 63 | sys.stdout = stdout 64 | sys.stderr = stdout 65 | 66 | pre_norm = InputNorm(tomo_path, tomo_format, base_dir, norm_type) 67 | pre_norm.handle_parallel() 68 | print('Standardization finished!') 69 | print('*' * 100) 70 | """ 71 | try: 72 | pre_norm = InputNorm(tomo_path, tomo_format, base_dir, norm_type) 73 | pre_norm.handle_parallel() 74 | print('Standardization finished!') 75 | print('*' * 100) 76 | except: 77 | stdout.flush() 78 | stdout.write('Normalization Exception!') 79 | return 0 80 | """ 81 | if stdout is not None: 82 | sys.stderr = save_stderr 83 | sys.stdout = save_stdout -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import cv2 4 | import math 5 | import matplotlib.pyplot as plt 6 | from PyQt5.QtCore import QThread, QObject, pyqtSignal 7 | import threading, inspect 8 | import ctypes 9 | 10 | def rescale(arr): 11 | arr_min = arr.min() 12 | arr_max = arr.max() 13 | return (arr - arr_min) / (arr_max - arr_min) 14 | 15 | 16 | def hist_equ(tomo, min_val=0.01, max_val=0.99): 17 | fre_num, xx, _ = plt.hist(tomo.flatten(), bins=256, cumulative=True) 18 | fre_num /= fre_num[-1] 19 | for idx, x in enumerate(fre_num): 20 | if x > min_val: 21 | min_idx = int(idx) 22 | break 23 | 24 | for idx, x in enumerate(fre_num[::-1]): 25 | if x < max_val: 26 | max_idx = int(len(fre_num) - idx) 27 | break 28 | tomo = np.clip(tomo, xx[min_idx], xx[max_idx]) 29 | # tomo = np.clip(tomo, 52, 150) 30 | # 
print(xx[min_idx], xx[max_idx])
    return tomo


def gauss_filter(kernel_size=3, sigma=1):
    max_idx = kernel_size // 2
    idx = np.linspace(-max_idx, max_idx, kernel_size)
    Y, X = np.meshgrid(idx, idx)
    gauss_filter = np.exp(-(X ** 2 + Y ** 2) / (2 * sigma ** 2))
    gauss_filter /= np.sum(np.sum(gauss_filter))
    return gauss_filter


def stretch(tomo):
    tomo = (tomo - np.min(tomo)) / (np.max(tomo) - np.min(tomo)) * 255
    return np.array(tomo).astype(np.uint8)


class myThread(threading.Thread):
    def __init__(self, thread_id, func, args, emit_str):
        threading.Thread.__init__(self)
        self.threadID = thread_id
        self.func = func
        self.args = args
        self.emit_str = emit_str
        self.n = 1

    def run(self):
        while self.n:
            self.pid_num = self.func(self.args, self.emit_str)
            self.n -= 1

    def get_n(self):
        return self.n


def make_video(tomo, save_path, fps, size):
    if 'mp4' in save_path:
        fourcc = cv2.VideoWriter_fourcc('M', 'P', '4', 'V')
    elif 'avi' in save_path:
        fourcc = cv2.VideoWriter_fourcc(*'MJPG')

    videowriter = cv2.VideoWriter(save_path,
                                  fourcc,
                                  fps,
                                  size)
    for i in range(tomo.shape[0]):
        img = tomo[i]
        videowriter.write(img)


class Concur(threading.Thread):
    """
    A worker thread that can be paused, resumed, and stopped.
    """
    def __init__(self, job, args, stdout):
        super(Concur, self).__init__()
        self.__flag = threading.Event()
        self.__flag.set()
        self.__running = threading.Event()
        self.__running.set()
        self.job = job
        self.args = args
        self.stdout = stdout

    def run(self):
        while self.__running.isSet():
            self.__flag.wait()
            try:
                self.job(self.args, self.stdout)
                self.pause()
            except Exception as e:
                self.stop()
                print(e)

    def pause(self):
        self.__flag.clear()

    def resume(self):
        self.__flag.set()

    def stop(self):
        self.__flag.set()
        self.__running.clear()


def _async_raise(tid, exctype):
    """Stop a thread: raises the exception in it, performs cleanup if needed."""
    tid = ctypes.c_long(tid)
    if not inspect.isclass(exctype):
        exctype = type(exctype)
    res = ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, ctypes.py_object(exctype))
    if res == 0:
        raise ValueError("invalid thread id")
    elif res != 1:
        # if it returns a number greater than one, you're in trouble,
        # and you should call it again with exc=NULL to revert the effect
        ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, None)
        raise SystemError("PyThreadState_SetAsyncExc failed")


def stop_thread(thread):
    _async_raise(thread.ident, SystemExit)


class EmittingStr(QObject):
    textWritten = pyqtSignal(str)

    def __init__(self):
        super(EmittingStr, self).__init__()

    def write(self, text):
        try:
            if len(str(text)) >= 2:
                self.textWritten.emit(str(text))
        except Exception:
            pass

    def flush(self, text=None):
        pass


class ThreadShowInfo(QThread):
    def __init__(self, func, args):
        super(ThreadShowInfo, self).__init__()
        self.func = func
        self.args = args

    def run(self):
        self.func(self.args)


def add_transparency(img, label, factor, color, thresh):
    img = np.array(255.0 *
rescale(img), dtype=np.uint8) 168 | 169 | alpha_channel = np.ones(img.shape, dtype=img.dtype) * 255 170 | 171 | img_BGRA = cv2.merge((img, img, img, alpha_channel)) 172 | 173 | img = img[:, :, np.newaxis] 174 | img1 = img.repeat([3], axis=2) 175 | img1[label.astype(int) == 1] = (255, 0, 0) 176 | img1[label.astype(int) == 2] = (0, 255, 0) 177 | img1[label.astype(int) == 3] = (0, 0, 255) 178 | img1[label.astype(int) == 4] = (0, 255, 255) 179 | # img1[label > factor] = color 180 | c_b, c_g, c_r = cv2.split(img1) 181 | mask = np.where(label > factor, 1, 0) 182 | img1_alpha = np.array(mask * 255 * factor, dtype=np.uint8) 183 | img1_alpha[img1_alpha == 0] = 255 184 | img1_BGRA = cv2.merge((c_b, c_g, c_r, img1_alpha)) 185 | 186 | out = cv2.addWeighted(img_BGRA, 1 - factor, img1_BGRA, factor, 0) 187 | 188 | return np.array(out) 189 | 190 | 191 | def annotate_particle(img, coords, diameter, zz, idx, circle_width, color): 192 | """Plot circle centered at the particle coordinates directly on the tomogram.""" 193 | img = 255.0 * rescale(img) 194 | img = img[:, :, np.newaxis] 195 | rgb_uint8 = img.repeat([3], axis=2).astype(np.uint8) 196 | 197 | if idx == 0: 198 | columns = ['x', 'y', 'z'] 199 | elif idx == 1: 200 | columns = ['z', 'y', 'x'] 201 | else: 202 | columns = ['x', 'z', 'y'] 203 | 204 | df = pd.DataFrame(data=coords, 205 | columns=columns) 206 | color = np.array(color) 207 | print(color.shape) 208 | df["R"] = color[:, 0] 209 | df["G"] = color[:, 1] 210 | df["B"] = color[:, 2] 211 | 212 | r = diameter / 2 213 | coords_xy = df[df['z'] >= (zz - r)] 214 | coords_xy = coords_xy[df['z'] <= (zz + r)] 215 | rr2 = r ** 2 - (zz - coords_xy['z']) ** 2 216 | rr = [math.sqrt(i) for i in rr2] 217 | for x, y, rad, r, g, b in zip(coords_xy['x'], coords_xy['y'], rr, coords_xy['R'], coords_xy['G'], coords_xy['B']): 218 | cv2.circle(rgb_uint8, (int(y), int(x)), int(rad), (r, g, b), circle_width) 219 | return rgb_uint8 --------------------------------------------------------------------------------