├── .gitignore ├── LICENSE ├── README.md ├── burden ├── pred_burden.py └── validate_burden.py ├── patches ├── __init__.py ├── cal_patch_mean_std.py └── gen_patches.py ├── preprocess ├── check_mask.py ├── check_property.py ├── locate_tissue.py └── unzip_slides.py ├── requirements.txt ├── seg ├── __init__.py ├── dataload.py ├── loss.py ├── pred_test_slide.py ├── segnet │ ├── __init__.py │ ├── pspnet │ │ ├── __init__.py │ │ ├── caffe_pb2.py │ │ ├── layers.py │ │ └── pspnet.py │ └── unet.py ├── train_seg.py ├── train_seg.sh └── utils.py └── viable_whole_burden.jpg /.gitignore: -------------------------------------------------------------------------------- 1 | # Folders 2 | data/ 3 | 4 | # Files 5 | data 6 | .remote-sync.json 7 | 8 | 9 | 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # celery beat schedule file 88 | celerybeat-schedule 89 | 90 | # SageMath parsed files 91 | *.sage.py 92 | 93 | # Environments 94 | .env 95 | .venv 96 | env/ 97 | venv/ 98 | ENV/ 99 | env.bak/ 100 | venv.bak/ 101 | 102 | # Spyder project settings 103 | .spyderproject 104 | .spyproject 105 | 106 | # Rope project settings 107 | .ropeproject 108 | 109 | # mkdocs documentation 110 | /site 111 | 112 | # mypy 113 | .mypy_cache/ 114 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Pingjun Chen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [PAIP2019: Liver Cancer Segmentation](https://paip2019.grand-challenge.org/) 2 | 3 | To use the code, the user needs to pre-install a few packages. 4 | ``` 5 | $ sudo apt-get install openslide-tools 6 | $ sudo apt-get install libgeos-dev 7 | $ pip install -r requirements.txt 8 | ``` 9 | 10 | ## Preprocessing: 11 | ### 1. Download slides and unzip 12 | Download all 50 zipped slides and the two csv files, put them inside `./data/SourceData`, and unzip them by running 13 | ``` 14 | $ cd preprocess 15 | $ python unzip_slides.py 16 | ``` 17 | All slides will be unzipped into `./data/LiverImages`. 18 | 19 | ### 2. Check the segmentation masks 20 | Visualizing the `whole` and `viable` masks of a slide can give the user an intuitive feel for how the tumor looks. Run the following code to generate a side-by-side view of the masks with the corresponding slide image. 21 | ``` 22 | $ python check_mask.py 23 | ``` 24 | Moreover, [`tissueloc`](https://github.com/PingjunChen/tissueloc) provides an algorithm to locate the boundary of real tissue in the slide. Run 25 | ``` 26 | $ python locate_tissue.py 27 | ``` 28 | The located tissue results may help in the slide-level prediction stage. Both mask comparison and tissue localization results are saved in the `./data/Visualization` directory. 29 | 30 | ### 3. Check the viable tumor burden 31 | If you want to check the provided `viable tumor burden` against the result calculated from the provided masks, run the following code: 32 | ``` 33 | $ cd ../burden 34 | $ python validate_burden.py 35 | ``` 36 | 37 | ## Patch-based Slide Segmentation: 38 | ### 1. Patch sample generation 39 | Currently, we use two splitting manners: half-overlap and self-overlap. ``half-overlap`` has more overlap between neighboring patches and thus yields more patches. ``self-overlap`` usually has little overlap between neighboring patches, so the number of generated patches is much smaller than with ``half-overlap``. For both viable and whole tumor types, we remove patches that lie entirely in non-tissue regions. 40 | 41 | **viable tumor patch splitting:** We combine half-overlap with self-overlap. In half-overlap splitting, however, we additionally filter patches by their mask foreground ratio. 42 | 43 | **whole tumor patch splitting:** We combine half-overlap with self-overlap with no constraint on the mask foreground ratio. Compared with viable tumor splitting, whole tumor patch generation obtains more patches with a higher background ratio, which helps avoid false positives in whole tumor prediction. 44 | 45 | Use the following commands to generate patches for `viable` and `whole` by setting the parameter `tumor_type`. 46 | ``` 47 | $ cd patches 48 | $ python gen_patches.py 49 | ``` 50 | 51 | ### 2. Segmentation model training 52 | #### 2.1 Model selection 53 | We explore UNet and PSPNet on liver patch segmentation. Experimental results show that PSPNet achieves superior performance. 54 | #### 2.2 Optimizer 55 | We compare SGD with initial learning rate 1.0e-2 and Adam with initial learning rate 1.0e-3. On both PSPNet and UNet, SGD presents superior performance. We train the segmentation model for 50 epochs and decay the learning rate epoch-wise, stepping down to 0.0.
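As a reference, a minimal sketch of this schedule in PyTorch (assuming `model` is the segmentation network and a momentum of 0.9; the exact code in `train_seg.py` may differ):
```
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR

optimizer = optim.SGD(model.parameters(), lr=1.0e-2, momentum=0.9)
# epoch-wise down-stepping: lr decays linearly from 1.0e-2 to 0.0 at epoch 50
scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: 1.0 - epoch / 50.0)
for epoch in range(50):
    # ... train one epoch ...
    scheduler.step()
```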
56 | #### 2.3 Loss function 57 | Binary cross-entropy (BCE) and dice loss are combined as the overall loss. Using a BCE weight of 0.1 achieves the most promising results. 58 | #### 2.4 Patch normalization 59 | We compare applying patch normalization against applying none. The prediction performance on validation patches shows that skipping normalization performs slightly better; it is also more convenient to implement. 60 | 61 | Patch training can be run as follows; we train `viable` and `whole` with the same settings: 62 | ``` 63 | $ cd seg 64 | $ python train_seg.py 65 | ``` 66 | 67 | The Caffe-pretrained PSPNet needs to be downloaded from [here](https://drive.google.com/open?id=0BzaU285cX7TCT1M3TmNfNjlUeEU) and put in `seg/segnet/pspnet/`. 68 | 69 | 70 | ### 3. Slide tumor prediction 71 | The slide-level segmentation is also conducted in a patch-wise manner. Specifically, we first split the whole slide image into patches, then predict each patch, and finally merge all patch predictions to generate the final tumor segmentation result. 72 | 73 | Here the main issue is how to split the whole slide image. To make slide-level segmentation more robust, we adopt a stride-wise patch splitting method and set the stride to be small (64 is used). When the stride is small, each pixel lies in more patches and is thus predicted more times. Since we average the predictions to get the final result, each pixel's segmentation prediction becomes more robust when it is predicted multiple times in different contexts. However, the time cost increases linearly with the number of patches; in the current application, we take segmentation accuracy as the priority.
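The sketch below illustrates the idea with a 512-pixel patch length; the actual implementations live in `seg/utils.py` (`wsi_stride_splitting` and `gen_patch_wmap`), so treat this as a simplified stand-in rather than the exact code:
```
import numpy as np

def stride_coords(height, width, patch_len=512, stride_len=64):
    # top-left corners of overlapping patches that fully cover the image
    hs = list(range(0, height - patch_len, stride_len)) + [height - patch_len]
    ws = list(range(0, width - patch_len, stride_len)) + [width - patch_len]
    return [(h, w) for h in hs for w in ws]

def merge_predictions(patch_preds, coords, height, width, patch_len=512):
    # sum patch probabilities, then divide by how many patches cover each pixel
    pred_map = np.zeros((height, width), dtype=np.float32)
    weight = np.zeros((height, width), dtype=np.float32)
    for (h, w), pred in zip(coords, patch_preds):
        pred_map[h:h+patch_len, w:w+patch_len] += pred
        weight[h:h+patch_len, w:w+patch_len] += 1.0
    return pred_map / np.maximum(weight, 1.0)
```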
74 | 75 | Before predicting on test slides, we copy the best-performing model into the `BestModel` folder for both `viable` and `whole`, then run 76 | ``` 77 | $ cd seg 78 | $ python pred_test_slide.py 79 | ``` 80 | 81 | After `viable` and `whole` tumor regions are predicted, we calculate the tumor burden with 82 | ``` 83 | $ cd burden 84 | $ python pred_burden.py 85 | ``` 86 | 87 | ## Acknowledgements 88 | - [shahabty/PSPNet-Pytorch](https://github.com/shahabty/PSPNet-Pytorch) 89 | - [hszhao/PSPNet](https://github.com/hszhao/PSPNet) 90 | - [usuyama/pytorch-unet](https://github.com/usuyama/pytorch-unet) 91 | -------------------------------------------------------------------------------- /burden/pred_burden.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os, sys 4 | import warnings 5 | warnings.filterwarnings("ignore") 6 | 7 | import argparse 8 | import numpy as np 9 | from skimage import io 10 | from pydaily import format 11 | 12 | def set_args(): 13 | parser = argparse.ArgumentParser(description = 'Liver Burden Prediction') 14 | 15 | parser.add_argument("--model_name", type=str, default="PSP") 16 | parser.add_argument("--result_dir", type=str, default="../data/TestResults") 17 | parser.add_argument("--seed", type=int, default=1234) 18 | 19 | args = parser.parse_args() 20 | return args 21 | 22 | 23 | if __name__ == "__main__": 24 | args = set_args() 25 | viable_pred_dir = os.path.join(args.result_dir, "viable") 26 | whole_pred_dir = os.path.join(args.result_dir, "whole") 27 | if not os.path.exists(viable_pred_dir): 28 | raise Exception("No viable tumor prediction yet") 29 | if not os.path.exists(whole_pred_dir): 30 | raise Exception("No whole tumor prediction yet") 31 | 32 | viable_pred_list = [ele for ele in os.listdir(viable_pred_dir) if "viable" in ele] 33 | whole_pred_list = [ele for ele in os.listdir(whole_pred_dir) if "whole" in ele] 34 | if len(viable_pred_list) != len(whole_pred_list): 35 | print("The number of viable and whole is not equal.") 36 | 37 | 38 | slide_list, ratio_list = [], [] 39 | id_list = [ele[7:10] for ele in viable_pred_list] 40 | id_list.sort() 41 | for id in id_list: 42 | viable_pred_name = "01_01_0" + id + "_viable.tif" 43 | whole_pred_name = "01_01_0" + id + "_whole.tif" 44 | viable_pred_path = os.path.join(viable_pred_dir, viable_pred_name) 45 | whole_pred_path = os.path.join(whole_pred_dir, whole_pred_name) 46 | viable_pred_img = io.imread(viable_pred_path) / 255 47 | whole_pred_img = io.imread(whole_pred_path) / 255 48 | pred_burden = np.sum(viable_pred_img) * 100.0 / (np.sum(whole_pred_img) + 1.0e-8) 49 | print("{} {:.3f}".format(id, pred_burden)) 50 | slide_list.append(id) 51 | pred_burden = min(pred_burden, 99.999) 52 | ratio_list.append(round(pred_burden, 3)) 53 | 54 | # save the prediction for submission 55 | burden_dict = {} 56 | burden_dict["wsi_id"] = slide_list 57 | burden_dict["ratio"] = ratio_list 58 | pred_csv_path = os.path.join(args.result_dir, "prediction.csv") 59 | format.dict_to_csv(burden_dict, pred_csv_path) 60 | -------------------------------------------------------------------------------- /burden/validate_burden.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os, sys 4 | import numpy as np 5 | import pandas as pd 6 | from skimage import io 7 | from pydaily import format 8 | 9 | 10 | def cal_train_burden(slides_dir): 11 | slide_list = [] 12 | slide_list.extend([ele[6:-4]
for ele in os.listdir(slides_dir) if "svs" in ele]) 13 | slide_list.extend([ele[6:-4] for ele in os.listdir(slides_dir) if "SVS" in ele]) 14 | burden_dict = {} 15 | for ind, cur_slide in enumerate(slide_list): 16 | cur_slide = str(cur_slide) 17 | print("Processing {}/{}".format(ind+1, len(slide_list))) 18 | cur_whole_path = os.path.join(slides_dir, "01_01_"+cur_slide+"_whole.tif") 19 | whole_mask = io.imread(cur_whole_path) 20 | cur_viable_path = os.path.join(slides_dir, "01_01_"+cur_slide+"_viable.tif") 21 | viable_mask = io.imread(cur_viable_path) 22 | cur_burden = np.sum(viable_mask) * 1.0 / np.sum(whole_mask) 23 | burden_dict[cur_slide] = cur_burden 24 | save_json_path = os.path.join(os.path.dirname(slides_dir), "SourceData", "calculated_tumor_burden.json") 25 | format.dict_to_json(burden_dict, save_json_path) 26 | 27 | 28 | def extract_csv_burden(csv_path, case_num): 29 | df = pd.read_csv(csv_path) 30 | slide_ids = df['wsi_id'].values.tolist()[:case_num] 31 | slide_burden = df['pixel ratio'].values.tolist()[:case_num] 32 | burden_dict = {} 33 | for id, burden in zip(slide_ids, slide_burden): 34 | burden_dict[str(int(id)).zfill(4)] = burden 35 | 36 | return burden_dict 37 | 38 | 39 | if __name__ == "__main__": 40 | # extract prepared ground truth viable tumor burden 41 | source_slides_dir = "../data/SourceData" 42 | phase1_path = os.path.join(source_slides_dir, "Phase_1_tumor_burden.csv") 43 | phase2_path = os.path.join(source_slides_dir, "Phase_2_tumor_burden.csv") 44 | gt_burden_dict = {} 45 | phase1_burden_dict = extract_csv_burden(phase1_path, case_num=20) 46 | gt_burden_dict.update(phase1_burden_dict) 47 | phase2_burden_dict = extract_csv_burden(phase2_path, case_num=30) 48 | gt_burden_dict.update(phase2_burden_dict) 49 | 50 | # calculate the viable tumor burden from the masks 51 | slides_dir = os.path.join(os.path.dirname(source_slides_dir), "LiverImages") 52 | cal_train_burden(slides_dir) 53 | 54 | # load calculated burden 55 | cal_burden_path = os.path.join(source_slides_dir, "calculated_tumor_burden.json") 56 | cal_burden_dict = format.json_to_dict(cal_burden_path) 57 | 58 | # compare gt & cal 59 | for ind, key in enumerate(gt_burden_dict): 60 | if key not in cal_burden_dict: 61 | print("Error: {}".format(key)) 62 | gt_burden = gt_burden_dict[key] 63 | cal_burden = cal_burden_dict[key] 64 | if np.absolute(gt_burden-cal_burden) > 0.001: 65 | print("{}/{} {} gt:{:.3f}, cal:{:.3f}".format(ind+1, len(gt_burden_dict), key, 66 | gt_burden, cal_burden)) 67 | -------------------------------------------------------------------------------- /patches/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PingjunChen/LiverCancerSeg/0d83353af285f0fe05f2e3d65cf86b72140496b2/patches/__init__.py -------------------------------------------------------------------------------- /patches/cal_patch_mean_std.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os, sys 4 | import numpy as np 5 | from skimage import io 6 | import pydaily 7 | 8 | data_root = "../data/Patches" 9 | tumor_type = "whole" 10 | train_dir = os.path.join(data_root, tumor_type, 'train') 11 | val_dir = os.path.join(data_root, tumor_type, 'val') 12 | 13 | 14 | def get_mean_and_std(img_dir, suffix): 15 | mean, std = np.zeros(3), np.zeros(3) 16 | filelist = pydaily.filesystem.find_ext_files(img_dir, suffix) 17 | 18 | for idx, filepath in enumerate(filelist): 19 | cur_img = io.imread(filepath) / 255.0
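        # note: averaging the per-image channel means/stds over all files (below)
        # approximates, but is not identical to, the exact pooled statistics over
        # all pixels; it is adequate for deriving normalization constants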
20 | for i in range(3): 21 | mean[i] += np.mean(cur_img[:,:,i]) 22 | std[i] += cur_img[:,:,i].std() 23 | mean = [ele * 1.0 / len(filelist) for ele in mean] 24 | std = [ele * 1.0 / len(filelist) for ele in std] 25 | return mean, std 26 | 27 | rgb_mean, rgb_std = get_mean_and_std(os.path.join(train_dir, "imgs"), suffix=".jpg") 28 | print("mean rgb: {}".format(rgb_mean)) 29 | print("std rgb: {}".format(rgb_std)) 30 | -------------------------------------------------------------------------------- /patches/gen_patches.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os, sys 4 | import warnings 5 | warnings.filterwarnings("ignore") 6 | 7 | import numpy as np 8 | import cv2 9 | import uuid, shutil 10 | from skimage import io, transform 11 | import matplotlib.pyplot as plt 12 | from sklearn.model_selection import train_test_split 13 | 14 | from pydaily import filesystem 15 | import tissueloc as tl 16 | from pyslide import pyramid, contour, patch 17 | from pycontour.cv2_transform import cv_cnt_to_np_arr 18 | from pycontour.poly_transform import np_arr_to_poly, poly_to_np_arr 19 | 20 | 21 | def gen_samples(slides_dir, patch_level, patch_size, tumor_type, slide_list, dset, overlap_mode): 22 | # prepare saving directory 23 | patch_path = os.path.join(os.path.dirname(slides_dir), "Patches", tumor_type) 24 | patch_img_dir = os.path.join(patch_path, dset, "imgs") 25 | if not os.path.exists(patch_img_dir): 26 | os.makedirs(patch_img_dir) 27 | patch_mask_dir = os.path.join(patch_path, dset, "masks") 28 | if not os.path.exists(patch_mask_dir): 29 | os.makedirs(patch_mask_dir) 30 | 31 | # process slides one-by-one 32 | ttl_patch = 0 33 | slide_list.sort() 34 | for ind, ele in enumerate(slide_list): 35 | print("Processing {} {}/{}".format(ele, ind+1, len(slide_list))) 36 | cur_slide_path = os.path.join(slides_dir, ele+".svs") 37 | if not os.path.exists(cur_slide_path): 38 | cur_slide_path = os.path.join(slides_dir, ele+".SVS") 39 | 40 | # locate contours and generate patches based on tissue contours 41 | cnts, d_factor = tl.locate_tissue_cnts(cur_slide_path, max_img_size=2048, smooth_sigma=13, 42 | thresh_val=0.88, min_tissue_size=120000) 43 | select_level, select_factor = tl.select_slide_level(cur_slide_path, max_size=2048) 44 | cnts = sorted(cnts, key=lambda x: cv2.contourArea(x), reverse=True) 45 | 46 | # scale contour to slide level 2 47 | wsi_head = pyramid.load_wsi_head(cur_slide_path) 48 | cnt_scale = select_factor / int(wsi_head.level_downsamples[patch_level]) 49 | tissue_arr = cv_cnt_to_np_arr(cnts[0] * cnt_scale).astype(np.int32) 50 | # convert tissue_arr to convex if poly is not valid 51 | tissue_poly = np_arr_to_poly(tissue_arr) 52 | if tissue_poly.is_valid == False: 53 | tissue_arr = poly_to_np_arr(tissue_poly.convex_hull).astype(int) 54 | 55 | coors_arr = None 56 | if overlap_mode == "half_overlap": 57 | level_w, level_h = wsi_head.level_dimensions[patch_level] 58 | coors_arr = contour.contour_patch_splitting_half_overlap(tissue_arr, level_h, level_w, patch_size, inside_ratio=0.80) 59 | elif overlap_mode == "self_overlap": 60 | coors_arr = contour.contour_patch_splitting_self_overlap(tissue_arr, patch_size, inside_ratio=0.80) 61 | else: 62 | raise NotImplementedError("unknown overlapping mode") 63 | 64 | wsi_img = wsi_head.read_region((0, 0), patch_level, wsi_head.level_dimensions[patch_level]) 65 | wsi_img = np.asarray(wsi_img)[:,:,:3] 66 | mask_path = os.path.join(slides_dir, "_".join([ele, tumor_type+".tif"])) 67 |
mask_img = io.imread(mask_path) 68 | wsi_mask = (transform.resize(mask_img, wsi_img.shape[:2], order=0) * 255).astype(np.uint8) 69 | 70 | if dset == "val": 71 | test_slides_dir = os.path.join(os.path.dirname(slides_dir), "TestSlides") 72 | if not os.path.exists(os.path.join(test_slides_dir, cur_slide_path)): 73 | shutil.copy(cur_slide_path, test_slides_dir) 74 | if not os.path.exists(os.path.join(test_slides_dir, mask_path)): 75 | shutil.copy(mask_path, test_slides_dir) 76 | 77 | for cur_arr in coors_arr: 78 | cur_h, cur_w = cur_arr[0], cur_arr[1] 79 | cur_patch = wsi_img[cur_h:cur_h+patch_size, cur_w:cur_w+patch_size] 80 | if cur_patch.shape[0] != patch_size or cur_patch.shape[1] != patch_size: 81 | continue 82 | cur_mask = wsi_mask[cur_h:cur_h+patch_size, cur_w:cur_w+patch_size] 83 | # background RGB (235, 210, 235) * [0.299, 0.587, 0.114] 84 | if patch.patch_bk_ratio(cur_patch, bk_thresh=0.864) > 0.88: 85 | continue 86 | 87 | if overlap_mode == "half_overlap" and tumor_type == "viable": 88 | pixel_ratio = np.sum(cur_mask > 0) * 1.0 / cur_mask.size 89 | if pixel_ratio < 0.05: 90 | continue 91 | 92 | patch_name = ele + "_" + str(uuid.uuid1())[:8] 93 | io.imsave(os.path.join(patch_img_dir, patch_name+".jpg"), cur_patch) 94 | io.imsave(os.path.join(patch_mask_dir, patch_name+".png"), cur_mask) 95 | ttl_patch += 1 96 | 97 | print("There are {} patches in total.".format(ttl_patch)) 98 | 99 | 100 | 101 | if __name__ == "__main__": 102 | # prepare train and validation slide list 103 | mask_dir = os.path.join("../data", "Visualization", "TissueLoc") 104 | slide_list = [os.path.splitext(ele)[0] for ele in os.listdir(mask_dir) if "png" in ele] 105 | train_slide_list, val_slide_list = train_test_split(slide_list, test_size=0.20, random_state=1234) 106 | 107 | # generate patches for segmentation model training 108 | slides_dir = os.path.join("../data", "LiverImages") 109 | patch_level, patch_size = 2, 512 110 | # tumor_type = "viable" 111 | tumor_types = ["viable", "whole"] 112 | for cur_type in tumor_types: 113 | print("Generating {} tumor patches.".format(cur_type)) 114 | patch_modes = [(val_slide_list, "val", "half_overlap"), (val_slide_list, "val", "self_overlap"), 115 | (train_slide_list, "train", "half_overlap"), (train_slide_list, "train", "self_overlap")] 116 | for mode in patch_modes: 117 | gen_samples(slides_dir, patch_level, patch_size, cur_type, mode[0], mode[1], overlap_mode=mode[2]) 118 | -------------------------------------------------------------------------------- /preprocess/check_mask.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os, sys 4 | import numpy as np 5 | from skimage import io, transform 6 | import scipy.misc as misc 7 | import matplotlib.pyplot as plt 8 | 9 | from pydaily import filesystem 10 | from pyslide import pyramid 11 | 12 | 13 | 14 | def get_slide_filenames(slides_dir): 15 | slide_list = [] 16 | svs_file_list = filesystem.find_ext_files(slides_dir, "svs") 17 | slide_list.extend([os.path.basename(ele) for ele in svs_file_list]) 18 | SVS_file_list = filesystem.find_ext_files(slides_dir, "SVS") 19 | slide_list.extend([os.path.basename(ele) for ele in SVS_file_list]) 20 | slide_filenames = [os.path.splitext(ele)[0] for ele in slide_list] 21 | 22 | return slide_filenames 23 | 24 | 25 | def save_mask_compare(slides_dir, slide_filenames): 26 | slide_num = len(slide_filenames) 27 | mask_save_dir = os.path.join(os.path.dirname(slides_dir), "Visualization/Masks") 28 |
filesystem.overwrite_dir(mask_save_dir) 29 | for ind in np.arange(slide_num): 30 | print("processing {}/{}".format(ind+1, slide_num)) 31 | check_slide_mask(slides_dir, slide_filenames, ind) 32 | 33 | 34 | def check_slide_mask(slides_dir, slide_filenames, slide_index, display_level=2): 35 | """ Load slide segmentation mask. 36 | 37 | """ 38 | 39 | slide_path = os.path.join(slides_dir, slide_filenames[slide_index]+".svs") 40 | if not os.path.exists(slide_path): 41 | slide_path = os.path.join(slides_dir, slide_filenames[slide_index]+".SVS") 42 | wsi_head = pyramid.load_wsi_head(slide_path) 43 | new_size = (wsi_head.level_dimensions[display_level][1], wsi_head.level_dimensions[display_level][0]) 44 | slide_img = wsi_head.read_region((0, 0), display_level, wsi_head.level_dimensions[display_level]) 45 | slide_img = np.asarray(slide_img)[:,:,:3] 46 | 47 | # load and resize whole mask 48 | whole_mask_path = os.path.join(slides_dir, slide_filenames[slide_index]+"_whole.tif") 49 | whole_mask_img = io.imread(whole_mask_path) 50 | resize_whole_mask = (transform.resize(whole_mask_img, new_size, order=0) * 255).astype(np.uint8) 51 | # load and resize viable mask 52 | viable_mask_path = os.path.join(slides_dir, slide_filenames[slide_index]+"_viable.tif") 53 | viable_mask_img = io.imread(viable_mask_path) 54 | resize_viable_mask = (transform.resize(viable_mask_img, new_size, order=0) * 255).astype(np.uint8) 55 | 56 | # show the mask 57 | fig, (ax1, ax2, ax3) = plt.subplots(nrows=1, ncols=3, figsize=(16, 5)) 58 | ax1.imshow(slide_img) 59 | ax1.set_title('Slide Image') 60 | ax2.imshow(resize_whole_mask) 61 | ax2.set_title('Whole Tumor Mask') 62 | ax3.imshow(resize_viable_mask) 63 | ax3.set_title('Viable Tumor Mask') 64 | plt.tight_layout() 65 | # plt.show() 66 | save_path = os.path.join(os.path.dirname(slides_dir), "Visualization/Masks", slide_filenames[slide_index]+".png") 67 | fig.savefig(save_path) 68 | 69 | if __name__ == "__main__": 70 | slides_dir = os.path.join("../data", "LiverImages") 71 | slide_filenames = get_slide_filenames(slides_dir) 72 | # check_slide_mask(slides_dir, slide_filenames, 28) 73 | save_mask_compare(slides_dir, slide_filenames) 74 | -------------------------------------------------------------------------------- /preprocess/check_property.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os, sys 4 | import numpy as np 5 | 6 | from pydaily import filesystem 7 | from pyslide import pyramid 8 | 9 | 10 | def get_slide_list(slides_dir): 11 | slide_list = [] 12 | svs_file_list = filesystem.find_ext_files(slides_dir, "svs") 13 | slide_list.extend(svs_file_list) 14 | SVS_file_list = filesystem.find_ext_files(slides_dir, "SVS") 15 | slide_list.extend(SVS_file_list) 16 | 17 | return slide_list 18 | 19 | 20 | def check_slide_properties(slide_path): 21 | wsi_head = pyramid.load_wsi_head(slide_path) 22 | flag = True 23 | # if wsi_head.level_count <= 2: 24 | # print("{} has {} levels".format(wsi_head._filename, wsi_head.level_count)) 25 | # flag = False 26 | # print(wsi_head.level_downsamples) 27 | if np.absolute(wsi_head.level_downsamples[2] - 16) > 0.01: 28 | print("{} scale is not {}".format(wsi_head._filename, 4)) 29 | flag = False 30 | 31 | return flag 32 | 33 | 34 | def check_slide_level(slide_list): 35 | fail_num = 0 36 | for index, slide_path in enumerate(slide_list): 37 | check_flag = check_slide_properties(slide_path) 38 | if check_flag == False: 39 | fail_num += 1 40 | print("There are {} slides not 
satisfying properties.".format(fail_num)) 41 | 42 | 43 | if __name__ == "__main__": 44 | slides_dir = os.path.join("../data", "LiverImages") 45 | slide_list = get_slide_list(slides_dir) 46 | check_slide_level(slide_list) 47 | -------------------------------------------------------------------------------- /preprocess/locate_tissue.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os, sys 4 | import numpy as np 5 | import cv2 6 | from skimage import io 7 | import matplotlib.pyplot as plt 8 | 9 | from pydaily import filesystem 10 | import tissueloc as tl 11 | from pyslide import pyramid 12 | from pycontour.cv2_transform import cv_cnt_to_np_arr, np_arr_to_cv_cnt 13 | from pycontour.poly_transform import np_arr_to_poly, poly_to_np_arr 14 | 15 | 16 | def locate_tissue(slides_dir): 17 | slide_list = [] 18 | svs_file_list = filesystem.find_ext_files(slides_dir, "svs") 19 | slide_list.extend(svs_file_list) 20 | SVS_file_list = filesystem.find_ext_files(slides_dir, "SVS") 21 | slide_list.extend(SVS_file_list) 22 | 23 | tissue_dir = os.path.join(os.path.dirname(slides_dir), "Visualization/TissueLoc") 24 | filesystem.overwrite_dir(tissue_dir) 25 | for ind, slide_path in enumerate(slide_list): 26 | print("processing {}/{}".format(ind+1, len(slide_list))) 27 | # locate tissue contours with default parameters 28 | cnts, d_factor = tl.locate_tissue_cnts(slide_path, max_img_size=2048, smooth_sigma=13, 29 | thresh_val=0.88, min_tissue_size=120000) 30 | cnts = sorted(cnts, key=lambda x: cv2.contourArea(x), reverse=True) 31 | 32 | # if len(cnts) != 1: 33 | # print("There are {} contours in {}".format(len(cnts), os.path.basename(slide_path))) 34 | 35 | # load slide 36 | select_level, select_factor = tl.select_slide_level(slide_path, max_size=2048) 37 | wsi_head = pyramid.load_wsi_head(slide_path) 38 | slide_img = wsi_head.read_region((0, 0), select_level, wsi_head.level_dimensions[select_level]) 39 | slide_img = np.asarray(slide_img)[:,:,:3] 40 | slide_img = np.ascontiguousarray(slide_img, dtype=np.uint8) 41 | 42 | # change not valid poly to convex_hull 43 | cnt_arr = cv_cnt_to_np_arr(cnts[0]) 44 | cnt_poly = np_arr_to_poly(cnt_arr) 45 | if cnt_poly.is_valid == True: 46 | valid_cnt = cnts[0].astype(int) 47 | else: 48 | valid_arr = poly_to_np_arr(cnt_poly.convex_hull) 49 | valid_cnt = np_arr_to_cv_cnt(valid_arr).astype(int) 50 | cv2.drawContours(slide_img, [valid_cnt], 0, (0, 255, 0), 8) 51 | 52 | # overlay and save 53 | # cv2.drawContours(slide_img, cnts, 0, (0, 255, 0), 8) 54 | tissue_save_name = os.path.splitext(os.path.basename(slide_path))[0] + ".png" 55 | tissue_save_path = os.path.join(tissue_dir, tissue_save_name) 56 | io.imsave(tissue_save_path, slide_img) 57 | 58 | 59 | if __name__ == "__main__": 60 | slides_dir = os.path.join("../data", "LiverImages") 61 | locate_tissue(slides_dir) 62 | -------------------------------------------------------------------------------- /preprocess/unzip_slides.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os, sys 4 | import zipfile 5 | from pydaily import filesystem 6 | 7 | 8 | def unzip_slides(slides_dir): 9 | """ Unzip all slide files 10 | """ 11 | 12 | unzip_dir = os.path.join(os.path.dirname(slides_dir), "LiverImages") 13 | filesystem.overwrite_dir(unzip_dir) 14 | 15 | zip_list = [ele for ele in os.listdir(slides_dir) if "zip" in ele] 16 | for ind, ele in enumerate(zip_list): 17 | print("processing 
{}/{}".format(ind+1, len(zip_list))) 18 | zip_ref = zipfile.ZipFile(os.path.join(slides_dir, ele), 'r') 19 | zip_ref.extractall(unzip_dir) 20 | zip_ref.close() 21 | 22 | 23 | if __name__ == "__main__": 24 | # put all download zipped files here 25 | source_slides_dir = "../data/SourceData" 26 | unzip_slides(source_slides_dir) 27 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | openslide-python==1.1.1 2 | Pillow==6.2.0 3 | opencv-python==3.4.4.19 4 | scikit-image==0.14.2 5 | tissueloc==2.0.1 6 | pyslide==0.4.0 7 | pycontour==1.4.0 8 | pydaily==0.4.0 9 | -------------------------------------------------------------------------------- /seg/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PingjunChen/LiverCancerSeg/0d83353af285f0fe05f2e3d65cf86b72140496b2/seg/__init__.py -------------------------------------------------------------------------------- /seg/dataload.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os, sys 4 | import numpy as np 5 | from skimage import io 6 | from torchvision import transforms 7 | from torch.utils.data import Dataset 8 | from torch.utils.data import DataLoader 9 | 10 | 11 | viable_rgb_mean = (0.790, 0.614, 0.739) 12 | viable_rgb_std = (0.093, 0.128, 0.102) 13 | whole_rgb_mean = (0.765, 0.547, 0.692) 14 | whole_rgb_std = (0.092, 0.119, 0.098) 15 | 16 | 17 | class LiverPatchDataset(Dataset): 18 | def __init__(self, data_dir, transform=None): 19 | self.data_dir = data_dir 20 | self.patch_dir = os.path.join(self.data_dir, "imgs") 21 | self.mask_dir = os.path.join(self.data_dir, "masks") 22 | self.patch_list = [ele for ele in os.listdir(self.patch_dir) if "jpg" in ele] 23 | self.transform = transform 24 | self.cur_img_name = None 25 | 26 | 27 | def __len__(self): 28 | return len(self.patch_list) 29 | 30 | 31 | def __getitem__(self, idx): 32 | self.cur_img_name = os.path.splitext(self.patch_list[idx])[0] 33 | patch_path = os.path.join(self.patch_dir, self.patch_list[idx]) 34 | mask_path = os.path.join(self.mask_dir, self.cur_img_name + ".png") 35 | 36 | image = (io.imread(patch_path) / 255.0).astype(np.float32) 37 | mask = (io.imread(mask_path) / 255.0).astype(np.float32) 38 | mask = np.expand_dims(mask, axis=0) 39 | 40 | if self.transform: 41 | image = self.transform(image) 42 | 43 | return [image, mask] 44 | 45 | 46 | def gen_dloader(data_dir, batch_size, mode="train", normalize=True, tumor_type="viable"): 47 | transform_list = [transforms.ToTensor(), ] 48 | if tumor_type == "viable": 49 | rgb_mean, rgb_std = viable_rgb_mean, viable_rgb_std 50 | elif tumor_type == "whole": 51 | rgb_mean, rgb_std = whole_rgb_mean, whole_rgb_std 52 | else: 53 | raise AssertionError("Unknow tumor type: {}".format(tumor_type)) 54 | if normalize == True: 55 | transform_list.append(transforms.Normalize(rgb_mean, rgb_std)) 56 | trans = transforms.Compose(transform_list) 57 | 58 | 59 | dset = LiverPatchDataset(data_dir, transform=trans) 60 | if mode == "train": 61 | dloader = DataLoader(dset, batch_size=batch_size, shuffle=True, 62 | num_workers=4, drop_last=True) 63 | elif mode == "val": 64 | dloader = DataLoader(dset, batch_size=batch_size, shuffle=False, 65 | num_workers=4, drop_last=False) 66 | else: 67 | raise Exception("Unknow mode: {}".format(mode)) 68 | 69 | 70 | return dloader 71 | 72 | 73 | 
74 | class PatchDataset(Dataset): 75 | """ 76 | Dataset for slide testing. Each would be splitted into multiple patches. 77 | Prediction is made on these splitted patches. 78 | """ 79 | 80 | def __init__(self, patch_arr, mask_arr=None, normalize=True, tumor_type="viable"): 81 | self.patches = patch_arr 82 | self.masks = mask_arr 83 | transform_list = [transforms.ToTensor(), ] 84 | if tumor_type == "viable": 85 | rgb_mean, rgb_std = viable_rgb_mean, viable_rgb_std 86 | elif tumor_type == "whole": 87 | rgb_mean, rgb_std = whole_rgb_mean, whole_rgb_std 88 | else: 89 | raise AssertionError("Unknow tumor type: {}".format(tumor_type)) 90 | if normalize == True: 91 | transform_list.append(transforms.Normalize(rgb_mean, rgb_std)) 92 | self.transform = transforms.Compose(transform_list) 93 | 94 | def __len__(self): 95 | return self.patches.shape[0] 96 | 97 | def __getitem__(self, idx): 98 | patch = self.patches[idx,...] 99 | if self.transform: 100 | patch = self.transform(patch) 101 | if isinstance(self.masks, np.ndarray): 102 | mask = np.expand_dims(self.masks[idx,...], axis=0) 103 | return patch, mask 104 | else: 105 | return patch 106 | -------------------------------------------------------------------------------- /seg/loss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os, sys 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | 8 | def dice_loss(pred, target, smooth = 1.): 9 | pred = pred.contiguous() 10 | target = target.contiguous() 11 | 12 | intersection = (pred * target).sum(dim=2).sum(dim=2) 13 | dice = ((2. * intersection + smooth) / (pred.sum(dim=2).sum(dim=2) + target.sum(dim=2).sum(dim=2) + smooth)) 14 | 15 | return dice.mean() 16 | 17 | 18 | def calc_loss(pred, target, metrics, bce_weight=0.2): 19 | bce = F.binary_cross_entropy_with_logits(pred, target) 20 | 21 | pred = torch.sigmoid(pred) 22 | dice = dice_loss(pred, target) 23 | loss = bce * bce_weight + (1.0 - dice) * (1.0 - bce_weight) 24 | 25 | metrics['bce'] += bce.data.cpu().numpy() * target.size(0) 26 | metrics['dice'] += dice.data.cpu().numpy() * target.size(0) 27 | metrics['loss'] += loss.data.cpu().numpy() * target.size(0) 28 | 29 | return loss 30 | 31 | 32 | def print_metrics(metrics, epoch_samples, phase="test"): 33 | outputs = [] 34 | for k in metrics.keys(): 35 | outputs.append("{}: {:.3f}".format(k, metrics[k] / epoch_samples)) 36 | print("{}: {}".format(phase, ", ".join(outputs))) 37 | -------------------------------------------------------------------------------- /seg/pred_test_slide.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os, sys 4 | import warnings 5 | warnings.filterwarnings("ignore") 6 | 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | import argparse, uuid, time 10 | from skimage import io, transform 11 | from tifffile import imsave 12 | from collections import defaultdict 13 | import torch 14 | import torch.nn as nn 15 | from torch.autograd import Variable 16 | from torch.utils.data import DataLoader 17 | import torch.backends.cudnn as cudnn 18 | import torch.nn.functional as F 19 | from pydaily import filesystem 20 | from pyslide import patch, pyramid 21 | 22 | from segnet import UNet, pspnet 23 | from utils import get_slide_filenames, gen_patch_wmap 24 | from utils import wsi_stride_splitting 25 | from loss import calc_loss 26 | from dataload import PatchDataset 27 | 28 | 29 | def set_args(): 30 | parser = 
argparse.ArgumentParser(description = 'Liver Tumor Slide Segmentation') 31 | parser.add_argument("--class_num", type=int, default=1) 32 | parser.add_argument("--in_channels", type=int, default=3) 33 | parser.add_argument("--batch_size", type=int, default=48) 34 | parser.add_argument("--stride_len", type=int, default=128) 35 | parser.add_argument("--patch_len", type=int, default=512) 36 | parser.add_argument("--slide_level", type=int, default=2) 37 | parser.add_argument("--model_name", type=str, default="PSP") 38 | parser.add_argument("--gpu", type=str, default="3,5,6,7") 39 | parser.add_argument("--split", type=str, default="BestModel") 40 | parser.add_argument("--tumor_type", type=str, default="viable") 41 | parser.add_argument("--best_model", type=str, default="PSP-050-0.755.pth") 42 | # parser.add_argument("--tumor_type", type=str, default="whole") 43 | # parser.add_argument("--best_model", type=str, default="PSP-049-0.682.pth") 44 | 45 | parser.add_argument("--model_dir", type=str, default="../data/Models") 46 | parser.add_argument("--slides_dir", type=str, default="../data/TestSlides") 47 | parser.add_argument("--result_dir", type=str, default="../data/TestResults") 48 | parser.add_argument("--normalize", type=bool, default=False) 49 | parser.add_argument("--save_org", type=bool, default=True) 50 | parser.add_argument("--seed", type=int, default=1234) 51 | 52 | args = parser.parse_args() 53 | return args 54 | 55 | 56 | def test_slide_seg(args): 57 | model = None 58 | if args.model_name == "UNet": 59 | model = UNet(n_channels=args.in_channels, n_classes=args.class_num) 60 | elif args.model_name == "PSP": 61 | model = pspnet.PSPNet(n_classes=19, input_size=(512, 512)) 62 | model.classification = nn.Conv2d(512, args.class_num, kernel_size=1) 63 | else: 64 | raise AssertionError("Unknow modle: {}".format(args.model_name)) 65 | model_path = os.path.join(args.model_dir, args.tumor_type, args.split, args.best_model) 66 | model = nn.DataParallel(model) 67 | model.load_state_dict(torch.load(model_path)) 68 | model.cuda() 69 | model.eval() 70 | 71 | since = time.time() 72 | result_dir = os.path.join(args.result_dir, args.tumor_type) 73 | filesystem.overwrite_dir(result_dir) 74 | slide_names = get_slide_filenames(args.slides_dir) 75 | if args.save_org and args.tumor_type == "viable": 76 | org_result_dir = os.path.join(result_dir, "Level0") 77 | filesystem.overwrite_dir(org_result_dir) 78 | 79 | for num, cur_slide in enumerate(slide_names): 80 | print("--{:02d}/{:02d} Slide:{}".format(num+1, len(slide_names), cur_slide)) 81 | metrics = defaultdict(float) 82 | # load level-2 slide 83 | slide_path = os.path.join(args.slides_dir, cur_slide+".svs") 84 | if not os.path.exists(slide_path): 85 | slide_path = os.path.join(args.slides_dir, cur_slide+".SVS") 86 | wsi_head = pyramid.load_wsi_head(slide_path) 87 | p_level = args.slide_level 88 | pred_h, pred_w = (wsi_head.level_dimensions[p_level][1], wsi_head.level_dimensions[p_level][0]) 89 | slide_img = wsi_head.read_region((0, 0), p_level, wsi_head.level_dimensions[p_level]) 90 | slide_img = np.asarray(slide_img)[:,:,:3] 91 | 92 | coors_arr = wsi_stride_splitting(pred_h, pred_w, patch_len=args.patch_len, stride_len=args.stride_len) 93 | patch_arr, wmap = gen_patch_wmap(slide_img, coors_arr, plen=args.patch_len) 94 | patch_dset = PatchDataset(patch_arr, mask_arr=None, normalize=args.normalize, tumor_type=args.tumor_type) 95 | patch_loader = DataLoader(patch_dset, batch_size=args.batch_size, shuffle=False, num_workers=4, drop_last=False) 96 | ttl_samples = 
0 97 | pred_map = np.zeros_like(wmap).astype(np.float32) 98 | for ind, patches in enumerate(patch_loader): 99 | inputs = Variable(patches.cuda()) 100 | with torch.no_grad(): 101 | outputs = model(inputs) 102 | preds = F.sigmoid(outputs) 103 | preds = torch.squeeze(preds, dim=1).data.cpu().numpy() 104 | if (ind+1)*args.batch_size <= len(coors_arr): 105 | patch_coors = coors_arr[ind*args.batch_size:(ind+1)*args.batch_size] 106 | else: 107 | patch_coors = coors_arr[ind*args.batch_size:] 108 | for ind, coor in enumerate(patch_coors): 109 | ph, pw = coor[0], coor[1] 110 | pred_map[ph:ph+args.patch_len, pw:pw+args.patch_len] += preds[ind] 111 | ttl_samples += inputs.size(0) 112 | 113 | prob_pred = np.divide(pred_map, wmap) 114 | slide_pred = (prob_pred > 0.5).astype(np.uint8) 115 | pred_save_path = os.path.join(result_dir, cur_slide + "_" + args.tumor_type + ".tif") 116 | io.imsave(pred_save_path, slide_pred*255) 117 | 118 | if args.save_org and args.tumor_type == "viable": 119 | org_w, org_h = wsi_head.level_dimensions[0] 120 | org_pred = transform.resize(prob_pred, (org_h, org_w)) 121 | org_pred = (org_pred > 0.5).astype(np.uint8) 122 | org_save_path = os.path.join(org_result_dir, cur_slide[-3:] + ".tif") 123 | imsave(org_save_path, org_pred, compress=9) 124 | 125 | time_elapsed = time.time() - since 126 | print('Testing takes {:.0f}m {:.2f}s'.format(time_elapsed // 60, time_elapsed % 60)) 127 | 128 | if __name__ == '__main__': 129 | args = set_args() 130 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 131 | torch.cuda.manual_seed(args.seed) 132 | cudnn.benchmark = True 133 | 134 | # train model 135 | print("{} prediction using: {}, model: {}".format(args.tumor_type.upper(), args.model_name, args.best_model)) 136 | test_slide_seg(args) 137 | -------------------------------------------------------------------------------- /seg/segnet/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .unet import UNet 4 | from . 
import pspnet 5 | -------------------------------------------------------------------------------- /seg/segnet/pspnet/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .pspnet import PSPNet 3 | -------------------------------------------------------------------------------- /seg/segnet/pspnet/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class conv2DBatchNorm(nn.Module): 7 | def __init__(self, in_channels, n_filters, k_size, stride, padding, bias=True, dilation=1): 8 | super(conv2DBatchNorm, self).__init__() 9 | 10 | if dilation > 1: 11 | conv_mod = nn.Conv2d(int(in_channels), int(n_filters), kernel_size=k_size, 12 | padding=padding, stride=stride, bias=bias, dilation=dilation) 13 | 14 | else: 15 | conv_mod = nn.Conv2d(int(in_channels), int(n_filters), kernel_size=k_size, 16 | padding=padding, stride=stride, bias=bias, dilation=1) 17 | 18 | 19 | self.cb_unit = nn.Sequential(conv_mod, 20 | nn.BatchNorm2d(int(n_filters)),) 21 | 22 | def forward(self, inputs): 23 | outputs = self.cb_unit(inputs) 24 | return outputs 25 | 26 | 27 | class deconv2DBatchNorm(nn.Module): 28 | def __init__(self, in_channels, n_filters, k_size, stride, padding, bias=True): 29 | super(deconv2DBatchNorm, self).__init__() 30 | 31 | self.dcb_unit = nn.Sequential(nn.ConvTranspose2d(int(in_channels), int(n_filters), kernel_size=k_size, 32 | padding=padding, stride=stride, bias=bias), 33 | nn.BatchNorm2d(int(n_filters)),) 34 | 35 | def forward(self, inputs): 36 | outputs = self.dcb_unit(inputs) 37 | return outputs 38 | 39 | 40 | class conv2DBatchNormRelu(nn.Module): 41 | def __init__(self, in_channels, n_filters, k_size, stride, padding, bias=True, dilation=1): 42 | super(conv2DBatchNormRelu, self).__init__() 43 | 44 | if dilation > 1: 45 | conv_mod = nn.Conv2d(int(in_channels), int(n_filters), kernel_size=k_size, 46 | padding=padding, stride=stride, bias=bias, dilation=dilation) 47 | 48 | else: 49 | conv_mod = nn.Conv2d(int(in_channels), int(n_filters), kernel_size=k_size, 50 | padding=padding, stride=stride, bias=bias, dilation=1) 51 | 52 | self.cbr_unit = nn.Sequential(conv_mod, 53 | nn.BatchNorm2d(int(n_filters)), 54 | nn.ReLU(inplace=True),) 55 | 56 | def forward(self, inputs): 57 | outputs = self.cbr_unit(inputs) 58 | return outputs 59 | 60 | 61 | class deconv2DBatchNormRelu(nn.Module): 62 | def __init__(self, in_channels, n_filters, k_size, stride, padding, bias=True): 63 | super(deconv2DBatchNormRelu, self).__init__() 64 | 65 | self.dcbr_unit = nn.Sequential(nn.ConvTranspose2d(int(in_channels), int(n_filters), kernel_size=k_size, 66 | padding=padding, stride=stride, bias=bias), 67 | nn.BatchNorm2d(int(n_filters)), 68 | nn.ReLU(inplace=True),) 69 | 70 | def forward(self, inputs): 71 | outputs = self.dcbr_unit(inputs) 72 | return outputs 73 | 74 | 75 | class unetConv2(nn.Module): 76 | def __init__(self, in_size, out_size, is_batchnorm): 77 | super(unetConv2, self).__init__() 78 | 79 | if is_batchnorm: 80 | self.conv1 = nn.Sequential(nn.Conv2d(in_size, out_size, 3, 1, 0), 81 | nn.BatchNorm2d(out_size), 82 | nn.ReLU(),) 83 | self.conv2 = nn.Sequential(nn.Conv2d(out_size, out_size, 3, 1, 0), 84 | nn.BatchNorm2d(out_size), 85 | nn.ReLU(),) 86 | else: 87 | self.conv1 = nn.Sequential(nn.Conv2d(in_size, out_size, 3, 1, 0), 88 | nn.ReLU(),) 89 | self.conv2 = nn.Sequential(nn.Conv2d(out_size, out_size, 3, 1, 0), 90 | nn.ReLU(),) 91 | 
def forward(self, inputs): 92 | outputs = self.conv1(inputs) 93 | outputs = self.conv2(outputs) 94 | return outputs 95 | 96 | 97 | class unetUp(nn.Module): 98 | def __init__(self, in_size, out_size, is_deconv): 99 | super(unetUp, self).__init__() 100 | self.conv = unetConv2(in_size, out_size, False) 101 | if is_deconv: 102 | self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2) 103 | else: 104 | self.up = nn.UpsamplingBilinear2d(scale_factor=2) 105 | 106 | def forward(self, inputs1, inputs2): 107 | outputs2 = self.up(inputs2) 108 | offset = outputs2.size()[2] - inputs1.size()[2] 109 | padding = 2 * [offset // 2, offset // 2] 110 | outputs1 = F.pad(inputs1, padding) 111 | return self.conv(torch.cat([outputs1, outputs2], 1)) 112 | 113 | 114 | class segnetDown2(nn.Module): 115 | def __init__(self, in_size, out_size): 116 | super(segnetDown2, self).__init__() 117 | self.conv1 = conv2DBatchNormRelu(in_size, out_size, 3, 1, 1) 118 | self.conv2 = conv2DBatchNormRelu(out_size, out_size, 3, 1, 1) 119 | self.maxpool_with_argmax = nn.MaxPool2d(2, 2, return_indices=True) 120 | 121 | def forward(self, inputs): 122 | outputs = self.conv1(inputs) 123 | outputs = self.conv2(outputs) 124 | unpooled_shape = outputs.size() 125 | outputs, indices = self.maxpool_with_argmax(outputs) 126 | return outputs, indices, unpooled_shape 127 | 128 | 129 | class segnetDown3(nn.Module): 130 | def __init__(self, in_size, out_size): 131 | super(segnetDown3, self).__init__() 132 | self.conv1 = conv2DBatchNormRelu(in_size, out_size, 3, 1, 1) 133 | self.conv2 = conv2DBatchNormRelu(out_size, out_size, 3, 1, 1) 134 | self.conv3 = conv2DBatchNormRelu(out_size, out_size, 3, 1, 1) 135 | self.maxpool_with_argmax = nn.MaxPool2d(2, 2, return_indices=True) 136 | 137 | def forward(self, inputs): 138 | outputs = self.conv1(inputs) 139 | outputs = self.conv2(outputs) 140 | outputs = self.conv3(outputs) 141 | unpooled_shape = outputs.size() 142 | outputs, indices = self.maxpool_with_argmax(outputs) 143 | return outputs, indices, unpooled_shape 144 | 145 | 146 | class segnetUp2(nn.Module): 147 | def __init__(self, in_size, out_size): 148 | super(segnetUp2, self).__init__() 149 | self.unpool = nn.MaxUnpool2d(2, 2) 150 | self.conv1 = conv2DBatchNormRelu(in_size, in_size, 3, 1, 1) 151 | self.conv2 = conv2DBatchNormRelu(in_size, out_size, 3, 1, 1) 152 | 153 | def forward(self, inputs, indices, output_shape): 154 | outputs = self.unpool(input=inputs, indices=indices, output_size=output_shape) 155 | outputs = self.conv1(outputs) 156 | outputs = self.conv2(outputs) 157 | return outputs 158 | 159 | 160 | class segnetUp3(nn.Module): 161 | def __init__(self, in_size, out_size): 162 | super(segnetUp3, self).__init__() 163 | self.unpool = nn.MaxUnpool2d(2, 2) 164 | self.conv1 = conv2DBatchNormRelu(in_size, in_size, 3, 1, 1) 165 | self.conv2 = conv2DBatchNormRelu(in_size, in_size, 3, 1, 1) 166 | self.conv3 = conv2DBatchNormRelu(in_size, out_size, 3, 1, 1) 167 | 168 | def forward(self, inputs, indices, output_shape): 169 | outputs = self.unpool(input=inputs, indices=indices, output_size=output_shape) 170 | outputs = self.conv1(outputs) 171 | outputs = self.conv2(outputs) 172 | outputs = self.conv3(outputs) 173 | return outputs 174 | 175 | 176 | class residualBlock(nn.Module): 177 | expansion = 1 178 | 179 | def __init__(self, in_channels, n_filters, stride=1, downsample=None): 180 | super(residualBlock, self).__init__() 181 | 182 | self.convbnrelu1 = conv2DBatchNormRelu(in_channels, n_filters, 3, stride, 1, bias=False) 183 | 
self.convbn2 = conv2DBatchNorm(n_filters, n_filters, 3, 1, 1, bias=False) 184 | self.downsample = downsample 185 | self.stride = stride 186 | self.relu = nn.ReLU(inplace=True) 187 | 188 | def forward(self, x): 189 | residual = x 190 | 191 | out = self.convbnrelu1(x) 192 | out = self.convbn2(out) 193 | 194 | if self.downsample is not None: 195 | residual = self.downsample(x) 196 | 197 | out += residual 198 | out = self.relu(out) 199 | return out 200 | 201 | 202 | class residualBottleneck(nn.Module): 203 | expansion = 4 204 | 205 | def __init__(self, in_channels, n_filters, stride=1, downsample=None): 206 | super(residualBottleneck, self).__init__() 207 | self.convbn1 = conv2DBatchNorm(in_channels, n_filters, k_size=1, stride=1, padding=0, bias=False) 208 | self.convbn2 = conv2DBatchNorm(n_filters, n_filters, k_size=3, padding=1, stride=stride, bias=False) 209 | self.convbn3 = conv2DBatchNorm(n_filters, n_filters * 4, k_size=1, stride=1, padding=0, bias=False) 210 | self.relu = nn.ReLU(inplace=True) 211 | self.downsample = downsample 212 | self.stride = stride 213 | 214 | def forward(self, x): 215 | residual = x 216 | 217 | out = self.convbn1(x) 218 | out = self.convbn2(out) 219 | out = self.convbn3(out) 220 | 221 | if self.downsample is not None: 222 | residual = self.downsample(x) 223 | 224 | out += residual 225 | out = self.relu(out) 226 | 227 | return out 228 | 229 | 230 | class linknetUp(nn.Module): 231 | def __init__(self, in_channels, n_filters): 232 | super(linknetUp, self).__init__() 233 | 234 | # B, 2C, H, W -> B, C/2, H, W 235 | self.convbnrelu1 = conv2DBatchNormRelu(in_channels, n_filters/2, k_size=1, stride=1, padding=1) 236 | 237 | # B, C/2, H, W -> B, C/2, H, W 238 | self.deconvbnrelu2 = deconv2DBatchNormRelu(n_filters/2, n_filters/2, k_size=3, stride=2, padding=0) 239 | 240 | # B, C/2, H, W -> B, C, H, W 241 | self.convbnrelu3 = conv2DBatchNormRelu(n_filters/2, n_filters, k_size=1, stride=1, padding=1) 242 | 243 | def forward(self, x): 244 | x = self.convbnrelu1(x) 245 | x = self.deconvbnrelu2(x) 246 | x = self.convbnrelu3(x) 247 | return x 248 | 249 | 250 | class FRRU(nn.Module): 251 | """ 252 | Full Resolution Residual Unit for FRRN 253 | """ 254 | def __init__(self, prev_channels, out_channels, scale): 255 | super(FRRU, self).__init__() 256 | self.scale = scale 257 | self.prev_channels = prev_channels 258 | self.out_channels = out_channels 259 | 260 | self.conv1 = conv2DBatchNormRelu(prev_channels + 32, out_channels, k_size=3, stride=1, padding=1) 261 | self.conv2 = conv2DBatchNormRelu(out_channels, out_channels, k_size=3, stride=1, padding=1) 262 | self.conv_res = nn.Conv2d(out_channels, 32, kernel_size=1, stride=1, padding=0) 263 | 264 | def forward(self, y, z): 265 | x = torch.cat([y, nn.MaxPool2d(self.scale, self.scale)(z)], dim=1) 266 | y_prime = self.conv1(x) 267 | y_prime = self.conv2(y_prime) 268 | 269 | x = self.conv_res(y_prime) 270 | upsample_size = torch.Size([_s*self.scale for _s in y_prime.shape[-2:]]) 271 | x = F.upsample(x, size=upsample_size, mode='nearest') 272 | z_prime = z + x 273 | 274 | return y_prime, z_prime 275 | 276 | 277 | class RU(nn.Module): 278 | """ 279 | Residual Unit for FRRN 280 | """ 281 | def __init__(self, channels, kernel_size=3, strides=1): 282 | super(RU, self).__init__() 283 | 284 | self.conv1 = conv2DBatchNormRelu(channels, channels, k_size=kernel_size, stride=strides, padding=1) 285 | self.conv2 = conv2DBatchNorm(channels, channels, k_size=kernel_size, stride=strides, padding=1) 286 | 287 | def forward(self, x): 288 | incoming = x 289 | x = self.conv1(x)
x = self.conv2(x) 291 | return x + incoming 292 | 293 | 294 | class residualConvUnit(nn.Module): 295 | def __init__(self, channels, kernel_size=3): 296 | super(residualConvUnit, self).__init__() 297 | 298 | self.residual_conv_unit = nn.Sequential(nn.ReLU(inplace=True), 299 | nn.Conv2d(channels, channels, kernel_size=kernel_size), 300 | nn.ReLU(inplace=True), 301 | nn.Conv2d(channels, channels, kernel_size=kernel_size),) 302 | def forward(self, x): 303 | input = x 304 | x = self.residual_conv_unit(x) 305 | return x + input 306 | 307 | class multiResolutionFusion(nn.Module): 308 | def __init__(self, channels, up_scale_high, up_scale_low, high_shape, low_shape): 309 | super(multiResolutionFusion, self).__init__() 310 | 311 | self.up_scale_high = up_scale_high 312 | self.up_scale_low = up_scale_low 313 | 314 | self.conv_high = nn.Conv2d(high_shape[1], channels, kernel_size=3) 315 | 316 | if low_shape is not None: 317 | self.conv_low = nn.Conv2d(low_shape[1], channels, kernel_size=3) 318 | 319 | def forward(self, x_high, x_low): 320 | high_upsampled = F.upsample(self.conv_high(x_high), 321 | scale_factor=self.up_scale_high, 322 | mode='bilinear') 323 | 324 | if x_low is None: 325 | return high_upsampled 326 | 327 | low_upsampled = F.upsample(self.conv_low(x_low), 328 | scale_factor=self.up_scale_low, 329 | mode='bilinear') 330 | 331 | return low_upsampled + high_upsampled 332 | 333 | class chainedResidualPooling(nn.Module): 334 | def __init__(self, channels, input_shape): 335 | super(chainedResidualPooling, self).__init__() 336 | 337 | self.chained_residual_pooling = nn.Sequential(nn.ReLU(inplace=True), 338 | nn.MaxPool2d(5, 1, 2), 339 | nn.Conv2d(input_shape[1], channels, kernel_size=3),) 340 | 341 | def forward(self, x): 342 | input = x 343 | x = self.chained_residual_pooling(x) 344 | return x + input 345 | 346 | 347 | class pyramidPooling(nn.Module): 348 | 349 | def __init__(self, in_channels, pool_sizes): 350 | super(pyramidPooling, self).__init__() 351 | 352 | self.paths = [] 353 | for i in range(len(pool_sizes)): 354 | self.paths.append(conv2DBatchNormRelu(in_channels, int(in_channels / len(pool_sizes)), 1, 1, 0, bias=False)) 355 | 356 | self.path_module_list = nn.ModuleList(self.paths) 357 | self.pool_sizes = pool_sizes 358 | 359 | def forward(self, x): 360 | output_slices = [x] 361 | h, w = x.shape[2:] 362 | 363 | for module, pool_size in zip(self.path_module_list, self.pool_sizes): 364 | out = F.avg_pool2d(x, int(h/pool_size), int(h/pool_size), 0) 365 | out = module(out) 366 | out = F.upsample(out, size=(h,w), mode='bilinear') 367 | output_slices.append(out) 368 | 369 | return torch.cat(output_slices, dim=1) 370 | 371 | 372 | class bottleNeckPSP(nn.Module): 373 | 374 | def __init__(self, in_channels, mid_channels, out_channels, 375 | stride, dilation=1): 376 | super(bottleNeckPSP, self).__init__() 377 | 378 | self.cbr1 = conv2DBatchNormRelu(in_channels, mid_channels, 1, 1, 0, bias=False) 379 | if dilation > 1: 380 | self.cbr2 = conv2DBatchNormRelu(mid_channels, mid_channels, 3, 1, 381 | padding=dilation, bias=False, 382 | dilation=dilation) 383 | else: 384 | self.cbr2 = conv2DBatchNormRelu(mid_channels, mid_channels, 3, 385 | stride=stride, padding=1, 386 | bias=False, dilation=1) 387 | self.cb3 = conv2DBatchNorm(mid_channels, out_channels, 1, 1, 0, bias=False) 388 | self.cb4 = conv2DBatchNorm(in_channels, out_channels, 1, stride, 0, bias=False) 389 | 390 | def forward(self, x): 391 | conv = self.cb3(self.cbr2(self.cbr1(x))) 392 | residual = self.cb4(x) 393 | return 

class bottleNeckIdentifyPSP(nn.Module):

    def __init__(self, in_channels, mid_channels, stride, dilation=1):
        super(bottleNeckIdentifyPSP, self).__init__()

        self.cbr1 = conv2DBatchNormRelu(in_channels, mid_channels, 1, 1, 0, bias=False)
        if dilation > 1:
            self.cbr2 = conv2DBatchNormRelu(mid_channels, mid_channels, 3, 1,
                                            padding=dilation, bias=False,
                                            dilation=dilation)
        else:
            self.cbr2 = conv2DBatchNormRelu(mid_channels, mid_channels, 3,
                                            stride=1, padding=1,
                                            bias=False, dilation=1)
        self.cb3 = conv2DBatchNorm(mid_channels, in_channels, 1, 1, 0, bias=False)

    def forward(self, x):
        residual = x
        x = self.cb3(self.cbr2(self.cbr1(x)))
        return F.relu(x + residual, inplace=True)


class residualBlockPSP(nn.Module):

    def __init__(self, n_blocks, in_channels, mid_channels, out_channels, stride, dilation=1):
        super(residualBlockPSP, self).__init__()

        if dilation > 1:
            stride = 1

        # the projection bottleneck counts as the first of the n_blocks layers, so
        # only n_blocks - 1 identity bottlenecks follow (range(n_blocks) would append
        # an extra identity block that load_pretrained_model never fills with weights)
        layers = [bottleNeckPSP(in_channels, mid_channels, out_channels, stride, dilation)]
        for i in range(n_blocks - 1):
            layers.append(bottleNeckIdentifyPSP(out_channels, mid_channels, stride, dilation))

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)
--------------------------------------------------------------------------------
/seg/segnet/pspnet/pspnet.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
import torch.nn as nn

from math import ceil
from torch.autograd import Variable

from .layers import *
from . 
import caffe_pb2 10 | 11 | 12 | pspnet_specs = { 13 | 'pascalvoc': 14 | { 15 | 'n_classes': 21, 16 | 'input_size': (473, 473), 17 | 'block_config': [3, 4, 23, 3], 18 | }, 19 | 20 | 'cityscapes': 21 | { 22 | 'n_classes': 19, 23 | 'input_size': (713, 713), 24 | 'block_config': [3, 4, 23, 3], 25 | }, 26 | 27 | 'ade20k': 28 | { 29 | 'n_classes': 150, 30 | 'input_size': (473, 473), 31 | 'block_config': [3, 4, 6, 3], 32 | }, 33 | } 34 | 35 | 36 | class PSPNet(nn.Module): 37 | """ 38 | Pyramid Scene Parsing Network 39 | 40 | """ 41 | 42 | def __init__(self, 43 | n_classes=19, 44 | block_config=[3, 4, 23, 3], 45 | input_size=(473,473), 46 | version=None): 47 | 48 | super(PSPNet, self).__init__() 49 | 50 | self.block_config = pspnet_specs[version]['block_config'] if version is not None else block_config 51 | self.n_classes = pspnet_specs[version]['n_classes'] if version is not None else n_classes 52 | self.input_size = pspnet_specs[version]['input_size'] if version is not None else input_size 53 | 54 | # Encoder 55 | self.convbnrelu1_1 = conv2DBatchNormRelu(in_channels=3, k_size=3, n_filters=64, 56 | padding=1, stride=2, bias=False) 57 | self.convbnrelu1_2 = conv2DBatchNormRelu(in_channels=64, k_size=3, n_filters=64, 58 | padding=1, stride=1, bias=False) 59 | self.convbnrelu1_3 = conv2DBatchNormRelu(in_channels=64, k_size=3, n_filters=128, 60 | padding=1, stride=1, bias=False) 61 | 62 | # Vanilla Residual Blocks 63 | self.res_block2 = residualBlockPSP(self.block_config[0], 128, 64, 256, 1, 1) 64 | self.res_block3 = residualBlockPSP(self.block_config[1], 256, 128, 512, 2, 1) 65 | 66 | # Dilated Residual Blocks 67 | self.res_block4 = residualBlockPSP(self.block_config[2], 512, 256, 1024, 1, 2) 68 | self.res_block5 = residualBlockPSP(self.block_config[3], 1024, 512, 2048, 1, 4) 69 | 70 | # Pyramid Pooling Module 71 | self.pyramid_pooling = pyramidPooling(2048, [6, 3, 2, 1]) 72 | 73 | # Final conv layers 74 | self.cbr_final = conv2DBatchNormRelu(4096, 512, 3, 1, 1, False) 75 | self.dropout = nn.Dropout2d(p=0.1, inplace=False) 76 | self.classification = nn.Conv2d(512, self.n_classes, 1, 1, 0) 77 | 78 | 79 | def forward(self, x): 80 | inp_shape = x.shape[2:] 81 | 82 | # H, W -> H/2, W/2 83 | x = self.convbnrelu1_1(x) 84 | x = self.convbnrelu1_2(x) 85 | x = self.convbnrelu1_3(x) 86 | 87 | # H/2, W/2 -> H/4, W/4 88 | x = F.max_pool2d(x, 3, 2, 1) 89 | 90 | # H/4, W/4 -> H/8, W/8 91 | x = self.res_block2(x) 92 | x = self.res_block3(x) 93 | x = self.res_block4(x) 94 | x = self.res_block5(x) 95 | 96 | x = self.pyramid_pooling(x) 97 | 98 | x = self.cbr_final(x) 99 | x = self.dropout(x) 100 | 101 | x = self.classification(x) 102 | x = F.upsample(x, size=inp_shape, mode='bilinear') 103 | return x 104 | 105 | def load_pretrained_model(self, model_path): 106 | """ 107 | Load weights from caffemodel w/o caffe dependency 108 | and plug them in corresponding modules 109 | """ 110 | # My eyes and my heart both hurt when writing this method 111 | 112 | # Only care about layer_types that have trainable parameters 113 | ltypes = ['BNData', 'ConvolutionData', 'HoleConvolutionData'] 114 | 115 | def _get_layer_params(layer, ltype): 116 | 117 | if ltype == 'BNData': 118 | gamma = np.array(layer.blobs[0].data) 119 | beta = np.array(layer.blobs[1].data) 120 | mean = np.array(layer.blobs[2].data) 121 | var = np.array(layer.blobs[3].data) 122 | return [mean, var, gamma, beta] 123 | 124 | elif ltype in ['ConvolutionData', 'HoleConvolutionData']: 125 | is_bias = layer.convolution_param.bias_term 126 | weights = 
np.array(layer.blobs[0].data)
                bias = []
                if is_bias:
                    bias = np.array(layer.blobs[1].data)
                return [weights, bias]

            elif ltype == 'InnerProduct':
                raise Exception("Fully connected layers {}, not supported".format(ltype))

            else:
                raise Exception("Unknown layer type {}".format(ltype))


        net = caffe_pb2.NetParameter()
        with open(model_path, 'rb') as model_file:
            net.MergeFromString(model_file.read())

        # dict formatted as -> key: <layer_name> :: value: <layer_type>
        layer_types = {}
        # dict formatted as -> key: <layer_name> :: value: [<layer_params>]
        layer_params = {}

        for l in net.layer:
            lname = l.name
            ltype = l.type
            if ltype in ltypes:
                #print("Processing layer {}".format(lname))
                layer_types[lname] = ltype
                layer_params[lname] = _get_layer_params(l, ltype)

        # Set affine=False for all batchnorm modules
        def _no_affine_bn(module=None):
            if isinstance(module, nn.BatchNorm2d):
                module.affine = False

            if len([m for m in module.children()]) > 0:
                for child in module.children():
                    _no_affine_bn(child)

        #_no_affine_bn(self)


        def _transfer_conv(layer_name, module):
            weights, bias = layer_params[layer_name]
            w_shape = np.array(module.weight.size())

            #print("CONV {}: Original {} and trans weights {}".format(layer_name,
            #                                                         w_shape,
            #                                                         weights.shape))

            module.weight.data.copy_(torch.from_numpy(weights).view_as(module.weight))

            if len(bias) != 0:
                b_shape = np.array(module.bias.size())
                #print("CONV {}: Original {} and trans bias {}".format(layer_name,
                #                                                      b_shape,
                #                                                      bias.shape))
                module.bias.data.copy_(torch.from_numpy(bias))


        def _transfer_conv_bn(conv_layer_name, mother_module):
            conv_module = mother_module[0]
            bn_module = mother_module[1]

            _transfer_conv(conv_layer_name, conv_module)

            mean, var, gamma, beta = layer_params[conv_layer_name + '/bn']
            #print("BN {}: Original {} and trans weights {}".format(conv_layer_name,
            #                                                       bn_module.running_mean.size(),
            #                                                       mean.shape))
            bn_module.running_mean.copy_(torch.from_numpy(mean).view_as(bn_module.running_mean))
            bn_module.running_var.copy_(torch.from_numpy(var).view_as(bn_module.running_var))
            bn_module.weight.data.copy_(torch.from_numpy(gamma).view_as(bn_module.weight))
            bn_module.bias.data.copy_(torch.from_numpy(beta).view_as(bn_module.bias))


        def _transfer_residual(prefix, block):
            block_module, n_layers = block[0], block[1]

            bottleneck = block_module.layers[0]
            bottleneck_conv_bn_dic = {prefix + '_1_1x1_reduce': bottleneck.cbr1.cbr_unit,
                                      prefix + '_1_3x3': bottleneck.cbr2.cbr_unit,
                                      prefix + '_1_1x1_proj': bottleneck.cb4.cb_unit,
                                      prefix + '_1_1x1_increase': bottleneck.cb3.cb_unit,}

            for k, v in bottleneck_conv_bn_dic.items():
                _transfer_conv_bn(k, v)

            for layer_idx in range(2, n_layers + 1):
                residual_layer = block_module.layers[layer_idx - 1]
                residual_conv_bn_dic = {'_'.join(map(str, [prefix, layer_idx, '1x1_reduce'])): residual_layer.cbr1.cbr_unit,
                                        '_'.join(map(str, [prefix, layer_idx, '3x3'])): residual_layer.cbr2.cbr_unit,
                                        '_'.join(map(str, [prefix, layer_idx, '1x1_increase'])): residual_layer.cb3.cb_unit,}

                for k, v in residual_conv_bn_dic.items():
                    _transfer_conv_bn(k, v)

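        # The mapping below uses the layer names from the released PSPNet caffe
        # prototxt (e.g. 'conv1_1_3x3_s2'); each value is the wrapper's inner
        # Sequential, so _transfer_conv_bn can index conv = unit[0], bn = unit[1].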
        convbn_layer_mapping = {'conv1_1_3x3_s2': self.convbnrelu1_1.cbr_unit,
                                'conv1_2_3x3': self.convbnrelu1_2.cbr_unit,
                                'conv1_3_3x3': self.convbnrelu1_3.cbr_unit,
                                'conv5_3_pool6_conv': self.pyramid_pooling.paths[0].cbr_unit,
                                'conv5_3_pool3_conv': self.pyramid_pooling.paths[1].cbr_unit,
                                'conv5_3_pool2_conv': self.pyramid_pooling.paths[2].cbr_unit,
                                'conv5_3_pool1_conv': self.pyramid_pooling.paths[3].cbr_unit,
                                'conv5_4': self.cbr_final.cbr_unit,}

        residual_layers = {'conv2': [self.res_block2, self.block_config[0]],
                           'conv3': [self.res_block3, self.block_config[1]],
                           'conv4': [self.res_block4, self.block_config[2]],
                           'conv5': [self.res_block5, self.block_config[3]],}

        # Transfer weights for all non-residual conv+bn layers
        for k, v in convbn_layer_mapping.items():
            _transfer_conv_bn(k, v)

        # Transfer weights for final non-bn conv layer
        _transfer_conv('conv6', self.classification)

        # Transfer weights for all residual layers
        for k, v in residual_layers.items():
            _transfer_residual(k, v)


    def tile_predict(self, img):
        """
        Predict by taking overlapping tiles from the image.

        Strides are adaptively computed from the img shape
        and input size

        :param img: np.ndarray with shape [C, H, W] in BGR format
        :param side: int with side length of model input
        :param n_classes: int with number of classes in seg output.
        """

        side = self.input_size[0]
        n_classes = self.n_classes
        h, w = img.shape[1:]
        n = int(max(h, w) / float(side) + 1)
        stride_x = (h - side) / float(n)
        stride_y = (w - side) / float(n)

        x_ends = [[int(i*stride_x), int(i*stride_x) + side] for i in range(n+1)]
        y_ends = [[int(i*stride_y), int(i*stride_y) + side] for i in range(n+1)]

        pred = np.zeros([1, n_classes, h, w])
        count = np.zeros([h, w])

        slice_count = 0
        for sx, ex in x_ends:
            for sy, ey in y_ends:
                slice_count += 1

                img_slice = img[:, sx:ex, sy:ey]
                # horizontally flipped copy for simple test-time augmentation
                img_slice_flip = np.copy(img_slice[:, :, ::-1])

                is_model_on_cuda = next(self.parameters()).is_cuda

                inp = Variable(torch.unsqueeze(torch.from_numpy(img_slice).float(), 0), volatile=True)
                flp = Variable(torch.unsqueeze(torch.from_numpy(img_slice_flip).float(), 0), volatile=True)

                if is_model_on_cuda:
                    inp = inp.cuda()
                    flp = flp.cuda()

                psub1 = F.softmax(self.forward(inp), dim=1).data.cpu().numpy()
                psub2 = F.softmax(self.forward(flp), dim=1).data.cpu().numpy()
                # average the flipped prediction with its un-flipped counterpart
                psub = (psub1 + psub2[:, :, :, ::-1]) / 2.0

                pred[:, :, sx:ex, sy:ey] = psub
                count[sx:ex, sy:ey] += 1.0

        # normalize by per-pixel tile coverage, then renormalize to probabilities
        score = (pred / count[None, None, ...]).astype(np.float32)[0]
        return score / score.sum(axis=0)
--------------------------------------------------------------------------------
/seg/segnet/unet.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
import torch.nn.functional as F


class double_conv(nn.Module):
    '''(conv => BN => ReLU) * 2'''
    def __init__(self, in_ch, out_ch):
        super(double_conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
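            # the BN + ReLU below complete the second (conv => BN => ReLU) block;
            # padding=1 keeps H, W unchanged through both convs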
nn.BatchNorm2d(out_ch), 18 | nn.ReLU(inplace=True) 19 | ) 20 | 21 | def forward(self, x): 22 | x = self.conv(x) 23 | return x 24 | 25 | 26 | class inconv(nn.Module): 27 | def __init__(self, in_ch, out_ch): 28 | super(inconv, self).__init__() 29 | self.conv = double_conv(in_ch, out_ch) 30 | 31 | def forward(self, x): 32 | x = self.conv(x) 33 | return x 34 | 35 | 36 | class down(nn.Module): 37 | def __init__(self, in_ch, out_ch): 38 | super(down, self).__init__() 39 | self.mpconv = nn.Sequential( 40 | nn.MaxPool2d(2), 41 | double_conv(in_ch, out_ch) 42 | ) 43 | 44 | def forward(self, x): 45 | x = self.mpconv(x) 46 | return x 47 | 48 | 49 | class up(nn.Module): 50 | def __init__(self, in_ch, out_ch, bilinear=True): 51 | super(up, self).__init__() 52 | 53 | if bilinear: 54 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 55 | else: 56 | self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2) 57 | 58 | self.conv = double_conv(in_ch, out_ch) 59 | 60 | def forward(self, x1, x2): 61 | x1 = self.up(x1) 62 | 63 | # input is CHW 64 | diffY = x2.size()[2] - x1.size()[2] 65 | diffX = x2.size()[3] - x1.size()[3] 66 | 67 | x1 = F.pad(x1, (diffX // 2, diffX - diffX//2, 68 | diffY // 2, diffY - diffY//2)) 69 | 70 | x = torch.cat([x2, x1], dim=1) 71 | x = self.conv(x) 72 | return x 73 | 74 | 75 | class outconv(nn.Module): 76 | def __init__(self, in_ch, out_ch): 77 | super(outconv, self).__init__() 78 | self.conv = nn.Conv2d(in_ch, out_ch, 1) 79 | 80 | def forward(self, x): 81 | x = self.conv(x) 82 | return x 83 | 84 | 85 | class UNet(nn.Module): 86 | def __init__(self, n_channels, n_classes): 87 | super(UNet, self).__init__() 88 | self.inc = inconv(n_channels, 64) 89 | self.down1 = down(64, 128) 90 | self.down2 = down(128, 256) 91 | self.down3 = down(256, 512) 92 | self.down4 = down(512, 512) 93 | self.up1 = up(1024, 256) 94 | self.up2 = up(512, 128) 95 | self.up3 = up(256, 64) 96 | self.up4 = up(128, 64) 97 | self.outc = outconv(64, n_classes) 98 | 99 | def forward(self, x): 100 | x1 = self.inc(x) 101 | x2 = self.down1(x1) 102 | x3 = self.down2(x2) 103 | x4 = self.down3(x3) 104 | x5 = self.down4(x4) 105 | x = self.up1(x5, x4) 106 | x = self.up2(x, x3) 107 | x = self.up3(x, x2) 108 | x = self.up4(x, x1) 109 | x = self.outc(x) 110 | 111 | return x 112 | -------------------------------------------------------------------------------- /seg/train_seg.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os, sys 4 | import warnings 5 | warnings.filterwarnings("ignore") 6 | 7 | import numpy as np 8 | import argparse 9 | import time, copy 10 | import torch 11 | import torch.nn as nn 12 | import torch.optim as optim 13 | from torch.autograd import Variable 14 | from torch.optim import lr_scheduler 15 | from torchvision import transforms 16 | import torch.backends.cudnn as cudnn 17 | 18 | from collections import defaultdict 19 | from segnet import UNet, pspnet 20 | 21 | from dataload import gen_dloader 22 | from loss import calc_loss, print_metrics 23 | from utils import LambdaLR 24 | 25 | 26 | def set_args(): 27 | parser = argparse.ArgumentParser(description = 'Liver Tumor Patch Segmentation') 28 | parser.add_argument("--class_num", type=int, default=1) 29 | parser.add_argument("--batch_size", type=int, default=8, help="batch size") 30 | parser.add_argument("--in_channels", type=int, default=3, help="input channel number") 31 | parser.add_argument("--maxepoch", type=int, default=50, help="number of epochs to 
train") 32 | parser.add_argument("--init_lr", type=float, default=1.0e-3, help="init learning rate for optimization") 33 | parser.add_argument("--bce_weight", type=float, default=0.1, help="weight of bce loss") 34 | parser.add_argument("--data_dir", type=str, default="../data/Patches") 35 | parser.add_argument("--model_dir", type=str, default="../data/Models") 36 | parser.add_argument("--tumor_type", type=str, default="whole") 37 | parser.add_argument("--normalize", type=bool, default=False) 38 | parser.add_argument("--model_name", type=str, default="PSP") 39 | parser.add_argument("--optim_name", type=str, default="SGD") 40 | parser.add_argument("--gpu", type=str, default="3,5,6,7", help="training gpu") 41 | parser.add_argument("--seed", type=int, default=1234, help="training seed") 42 | parser.add_argument("--session", type=str, default="Step02", help="training session") 43 | 44 | args = parser.parse_args() 45 | return args 46 | 47 | 48 | def train_seg_model(args): 49 | # model 50 | model = None 51 | if args.model_name == "UNet": 52 | model = UNet(n_channels=args.in_channels, n_classes=args.class_num) 53 | elif args.model_name == "PSP": 54 | model = pspnet.PSPNet(n_classes=19, input_size=(512, 512)) 55 | model.load_pretrained_model(model_path="./segnet/pspnet/pspnet101_cityscapes.caffemodel") 56 | model.classification = nn.Conv2d(512, args.class_num, kernel_size=1) 57 | else: 58 | raise AssertionError("Unknow modle: {}".format(args.model_name)) 59 | model = nn.DataParallel(model) 60 | model.cuda() 61 | # optimizer 62 | optimizer = None 63 | if args.optim_name == "Adam": 64 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1.0e-3) 65 | elif args.optim_name == "SGD": 66 | optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), 67 | lr=args.init_lr, momentum=0.9, weight_decay=0.0005) 68 | else: 69 | raise AssertionError("Unknow optimizer: {}".format(args.optim_name)) 70 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=LambdaLR(args.maxepoch, 0, 0).step) 71 | # dataloader 72 | train_data_dir = os.path.join(args.data_dir, args.tumor_type, "train") 73 | train_dloader = gen_dloader(train_data_dir, args.batch_size, mode="train", normalize=args.normalize, tumor_type=args.tumor_type) 74 | test_data_dir = os.path.join(args.data_dir, args.tumor_type, "val") 75 | val_dloader = gen_dloader(test_data_dir, args.batch_size, mode="val", normalize=args.normalize, tumor_type=args.tumor_type) 76 | 77 | # training 78 | save_model_dir = os.path.join(args.model_dir, args.tumor_type, args.session) 79 | if not os.path.exists(save_model_dir): 80 | os.makedirs(save_model_dir) 81 | best_dice = 0.0 82 | for epoch in np.arange(0, args.maxepoch): 83 | print('Epoch {}/{}'.format(epoch+1, args.maxepoch)) 84 | print('-' * 10) 85 | since = time.time() 86 | for phase in ['train', 'val']: 87 | if phase == 'train': 88 | dloader = train_dloader 89 | scheduler.step() 90 | for param_group in optimizer.param_groups: 91 | print("Current LR: {:.8f}".format(param_group['lr'])) 92 | model.train() # Set model to training mode 93 | else: 94 | dloader = val_dloader 95 | model.eval() # Set model to evaluate mode 96 | 97 | metrics = defaultdict(float) 98 | epoch_samples = 0 99 | for batch_ind, (imgs, masks) in enumerate(dloader): 100 | inputs = Variable(imgs.cuda()) 101 | masks = Variable(masks.cuda()) 102 | optimizer.zero_grad() 103 | 104 | with torch.set_grad_enabled(phase=='train'): 105 | outputs = model(inputs) 106 | loss = calc_loss(outputs, masks, metrics, 
bce_weight=args.bce_weight) 107 | if phase == 'train': 108 | loss.backward() 109 | optimizer.step() 110 | # statistics 111 | epoch_samples += inputs.size(0) 112 | print_metrics(metrics, epoch_samples, phase) 113 | epoch_dice = metrics['dice'] / epoch_samples 114 | 115 | # deep copy the model 116 | if phase == 'val' and (epoch_dice > best_dice or epoch > args.maxepoch-5): 117 | best_dice = epoch_dice 118 | best_model = copy.deepcopy(model.state_dict()) 119 | best_model_name = "-".join([args.model_name, "{:03d}-{:.3f}.pth".format(epoch, best_dice)]) 120 | torch.save(best_model, os.path.join(save_model_dir, best_model_name)) 121 | time_elapsed = time.time() - since 122 | print('Epoch {:2d} takes {:.0f}m {:.0f}s'.format(epoch, time_elapsed // 60, time_elapsed % 60)) 123 | print("================================================================================") 124 | print("Training finished...") 125 | 126 | 127 | if __name__ == '__main__': 128 | args = set_args() 129 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 130 | torch.cuda.manual_seed(args.seed) 131 | cudnn.benchmark = True 132 | 133 | # train model 134 | print("Training {} with {} on {}".format(args.model_name, args.optim_name, args.tumor_type)) 135 | train_seg_model(args) 136 | -------------------------------------------------------------------------------- /seg/train_seg.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | # sleep 3h 4 | python train_seg.py 5 | -------------------------------------------------------------------------------- /seg/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os, sys 4 | import numpy as np 5 | import itertools 6 | import matplotlib.pyplot as plt 7 | import torch 8 | import torch.nn.functional as F 9 | from pydaily import filesystem 10 | 11 | 12 | def get_slide_filenames(slides_dir): 13 | slide_list = [] 14 | svs_file_list = filesystem.find_ext_files(slides_dir, "svs") 15 | slide_list.extend([os.path.basename(ele) for ele in svs_file_list]) 16 | SVS_file_list = filesystem.find_ext_files(slides_dir, "SVS") 17 | slide_list.extend([os.path.basename(ele) for ele in SVS_file_list]) 18 | slide_filenames = [os.path.splitext(ele)[0] for ele in slide_list] 19 | 20 | slide_filenames.sort() 21 | 22 | return slide_filenames 23 | 24 | 25 | def mask2color(mask): 26 | # colors = np.asarray([(201, 58, 64), (242, 207, 1), (0, 152, 75), (101, 172, 228),(56, 34, 132), (160, 194, 56), 27 | # (0, 0, 117), (128, 128, 0), (191, 239, 69), (145, 30, 180)]) 28 | color_img = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8) 29 | color_img[mask==255] = (191, 239, 69) 30 | 31 | return color_img 32 | 33 | 34 | def gen_patch_pred(inputs, masks, preds): 35 | imgs = inputs.permute(0, 2, 3, 1).data.cpu().numpy() 36 | masks = torch.squeeze(masks, dim=1).data.cpu().numpy() 37 | preds = torch.squeeze(F.sigmoid(preds), dim=1).data.cpu().numpy() 38 | imgs = (imgs * 255).astype(np.uint8) 39 | masks = ((masks > 0.5) * 255).astype(np.uint8) 40 | preds = ((preds > 0.5) * 255).astype(np.uint8) 41 | 42 | img_num, img_size = imgs.shape[0], imgs.shape[1] 43 | result_img = np.zeros((img_num*img_size, img_size*3, imgs.shape[3]), dtype=np.uint8) 44 | for ind in np.arange(img_num): 45 | result_img[ind*img_size:(ind+1)*img_size, :img_size] = imgs[ind] 46 | result_img[ind*img_size:(ind+1)*img_size, img_size:img_size*2] = mask2color(masks[ind]) 47 | result_img[ind*img_size:(ind+1)*img_size, 
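                   # montage columns are [input | ground-truth mask | prediction];
                   # this assignment fills the third (prediction) column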
img_size*2:img_size*3] = mask2color(preds[ind])

    return result_img


def gen_patch_mask_wmap(slide_img, mask_img, coors_arr, plen):
    patch_list, mask_list = [], []
    wmap = np.zeros((slide_img.shape[0], slide_img.shape[1]), dtype=np.int32)
    for coor in coors_arr:
        ph, pw = coor[0], coor[1]
        patch_list.append(slide_img[ph:ph+plen, pw:pw+plen] / 255.0)
        mask_list.append(mask_img[ph:ph+plen, pw:pw+plen])
        wmap[ph:ph+plen, pw:pw+plen] += 1
    patch_arr = np.asarray(patch_list).astype(np.float32)
    mask_arr = np.asarray(mask_list).astype(np.float32)

    return patch_arr, mask_arr, wmap


def gen_patch_wmap(slide_img, coors_arr, plen):
    patch_list = []
    wmap = np.zeros((slide_img.shape[0], slide_img.shape[1]), dtype=np.int32)
    for coor in coors_arr:
        ph, pw = coor[0], coor[1]
        patch_list.append(slide_img[ph:ph+plen, pw:pw+plen] / 255.0)
        wmap[ph:ph+plen, pw:pw+plen] += 1
    patch_arr = np.asarray(patch_list).astype(np.float32)

    return patch_arr, wmap


def wsi_stride_splitting(wsi_h, wsi_w, patch_len, stride_len):
    """ Splitting whole slide image into patches by stride.

    Parameters
    -------
    wsi_h: int
        height of whole slide image
    wsi_w: int
        width of whole slide image
    patch_len: int
        length of the patch image
    stride_len: int
        length of the stride

    Returns
    -------
    coors_arr: list
        list of starting coordinates of patches ([0]-h, [1]-w)

    """

    coors_arr = []
    def stride_split(ttl_len, patch_len, stride_len):
        p_sets = []
        if patch_len > ttl_len:
            raise AssertionError("patch length larger than total length")
        elif patch_len == ttl_len:
            p_sets.append(0)
        else:
            # re-space the strides evenly so the last patch ends exactly at ttl_len
            stride_num = int(np.ceil((ttl_len - patch_len) * 1.0 / stride_len))
            for ind in range(stride_num+1):
                cur_pos = int(((ttl_len - patch_len) * 1.0 / stride_num) * ind)
                p_sets.append(cur_pos)

        return p_sets

    h_sets = stride_split(wsi_h, patch_len, stride_len)
    w_sets = stride_split(wsi_w, patch_len, stride_len)

    # combine points in both w and h direction
    if len(w_sets) > 0 and len(h_sets) > 0:
        coors_arr = list(itertools.product(h_sets, w_sets))

    return coors_arr


class LambdaLR():
    def __init__(self, n_epochs, offset, decay_start_epoch):
        assert ((n_epochs - decay_start_epoch) > 0), "Decay must start before the training session ends!"
        self.n_epochs = n_epochs
        self.offset = offset
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch):
        # linear decay of the LR factor from 1.0 to 0.0 once decay_start_epoch is reached
        return 1.0 - max(0, epoch + self.offset - self.decay_start_epoch) / (self.n_epochs - self.decay_start_epoch)
--------------------------------------------------------------------------------
/viable_whole_burden.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PingjunChen/LiverCancerSeg/0d83353af285f0fe05f2e3d65cf86b72140496b2/viable_whole_burden.jpg
--------------------------------------------------------------------------------
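
A minimal usage sketch of the stride-splitting and LR-decay helpers in `seg/utils.py` (illustrative values only; the `from seg.utils import ...` path assumes the repository root is on `PYTHONPATH`):
```python
# Illustrative sketch, not part of the repository.
from seg.utils import wsi_stride_splitting, LambdaLR  # assumed import path

# Patch start coordinates for a 1000 x 1500 slide, 512-px patches, 256-px stride;
# starts are re-spaced evenly so the last patch ends exactly at the slide border.
coors = wsi_stride_splitting(wsi_h=1000, wsi_w=1500, patch_len=512, stride_len=256)
print(len(coors), coors[0], coors[-1])  # 15 (0, 0) (488, 988)

# Linear LR decay used by train_seg.py via LambdaLR(args.maxepoch, 0, 0).step
decay = LambdaLR(n_epochs=50, offset=0, decay_start_epoch=0)
print([round(decay.step(e), 2) for e in (0, 25, 49)])  # [1.0, 0.5, 0.02]
```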