├── src ├── __init__.py ├── unet │ ├── __init__.py │ ├── unet.py │ ├── blocks.py │ ├── rprops.py │ └── utils.py ├── measure.py ├── train.py └── preprocessing │ └── mask.py ├── .gitignore ├── docs └── img │ ├── step_02.png │ ├── step_03.png │ ├── step_04.png │ ├── step_05.png │ ├── step_06.png │ ├── step_07.png │ ├── step_08.png │ ├── step_09.png │ ├── step_10.png │ ├── vis_bbox.png │ ├── vis_seg.png │ └── vis_skl.png ├── LICENSE └── README.md /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/unet/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | __pycache__ 3 | -------------------------------------------------------------------------------- /docs/img/step_02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biomag-lab/hypocotyl-UNet/HEAD/docs/img/step_02.png -------------------------------------------------------------------------------- /docs/img/step_03.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biomag-lab/hypocotyl-UNet/HEAD/docs/img/step_03.png -------------------------------------------------------------------------------- /docs/img/step_04.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biomag-lab/hypocotyl-UNet/HEAD/docs/img/step_04.png -------------------------------------------------------------------------------- /docs/img/step_05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biomag-lab/hypocotyl-UNet/HEAD/docs/img/step_05.png -------------------------------------------------------------------------------- /docs/img/step_06.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biomag-lab/hypocotyl-UNet/HEAD/docs/img/step_06.png -------------------------------------------------------------------------------- /docs/img/step_07.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biomag-lab/hypocotyl-UNet/HEAD/docs/img/step_07.png -------------------------------------------------------------------------------- /docs/img/step_08.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biomag-lab/hypocotyl-UNet/HEAD/docs/img/step_08.png -------------------------------------------------------------------------------- /docs/img/step_09.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biomag-lab/hypocotyl-UNet/HEAD/docs/img/step_09.png -------------------------------------------------------------------------------- /docs/img/step_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biomag-lab/hypocotyl-UNet/HEAD/docs/img/step_10.png -------------------------------------------------------------------------------- /docs/img/vis_bbox.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/biomag-lab/hypocotyl-UNet/HEAD/docs/img/vis_bbox.png -------------------------------------------------------------------------------- /docs/img/vis_seg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biomag-lab/hypocotyl-UNet/HEAD/docs/img/vis_seg.png -------------------------------------------------------------------------------- /docs/img/vis_skl.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biomag-lab/hypocotyl-UNet/HEAD/docs/img/vis_skl.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 BIOMAG - Szeged 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /src/measure.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from argparse import ArgumentParser 4 | 5 | from unet.utils import * 6 | from unet.unet import UNet 7 | 8 | parser = ArgumentParser() 9 | parser.add_argument("--images_path", type=str, required=True) 10 | parser.add_argument("--model", type=str, required=True) 11 | parser.add_argument("--result_folder", type=str, required=True) 12 | parser.add_argument("--device", type=str, default='cpu') 13 | parser.add_argument("--min_object_size", type=float, default=0) 14 | parser.add_argument("--max_object_size", type=float, default=np.inf) 15 | parser.add_argument("--dpi", type=float, default=False) 16 | parser.add_argument("--dpm", type=float, default=False) 17 | parser.add_argument("--visualize", type=bool, default=False) 18 | args = parser.parse_args() 19 | 20 | # determining dpm 21 | dpm = args.dpm if not args.dpi else dpi_to_dpm(args.dpi) 22 | 23 | print("Loading dataset...") 24 | predict_dataset = ReadTestDataset(args.images_path) 25 | device = torch.device(args.device) 26 | print("Dataset loaded") 27 | 28 | print("Loading model...") 29 | unet = UNet(3, 3) 30 | unet.load_state_dict(torch.load(args.model, map_location=device)) 31 | model = ModelWrapper(unet, args.result_folder, cuda_device=device) 32 | print("Model loaded") 33 | print("Measuring images...") 34 | model.measure_large_images(predict_dataset, export_path=args.result_folder, 35 | visualize_bboxes=args.visualize, filter=[args.min_object_size, args.max_object_size], 36 | dpm=dpm, verbose=True, tile_res=(256, 256)) 37 | -------------------------------------------------------------------------------- /src/unet/unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .blocks import * 6 | 7 | 8 | class UNet(nn.Module): 9 | def __init__(self, in_channels, out_channels, conv_depths=[64, 128, 256, 512, 1024]): 10 | assert len(conv_depths) > 2, 'conv_depths must have at least 3 members' 11 | 12 | super(UNet, self).__init__() 13 | 14 | # defining encoder layers 15 | encoder_layers = [] 16 | encoder_layers.append(First2D(in_channels, conv_depths[0], conv_depths[0])) 17 | encoder_layers.extend([Encoder2D(conv_depths[i], conv_depths[i + 1], conv_depths[i + 1]) 18 | for i in range(len(conv_depths)-2)]) 19 | 20 | # defining decoder layers 21 | decoder_layers = [] 22 | decoder_layers.extend([Decoder2D(2 * conv_depths[i + 1], 2 * conv_depths[i], 2 * conv_depths[i], conv_depths[i]) 23 | for i in reversed(range(len(conv_depths)-2))]) 24 | decoder_layers.append(Last2D(conv_depths[1], conv_depths[0], out_channels)) 25 | 26 | # encoder, center and decoder layers 27 | self.encoder_layers = nn.Sequential(*encoder_layers) 28 | self.center = Center2D(conv_depths[-2], conv_depths[-1], conv_depths[-1], conv_depths[-2]) 29 | self.decoder_layers = nn.Sequential(*decoder_layers) 30 | 31 | def forward(self, x): 32 | x_enc = [x] 33 | for enc_layer in self.encoder_layers: 34 | x_enc.append(enc_layer(x_enc[-1])) 35 | 36 | x_dec = [self.center(x_enc[-1])] 37 | for dec_layer_idx, dec_layer in enumerate(self.decoder_layers): 38 | x_opposite = x_enc[-1-dec_layer_idx] 39 | x_cat = torch.cat( 40 | [pad_to_shape(x_dec[-1], x_opposite.shape), x_opposite], 41 | dim=1 42 | ) 43 | x_dec.append(dec_layer(x_cat)) 44 | 45 | return x_dec[-1] 46 | 47 | 48 | 
def pad_to_shape(this, shp): 49 |     """ 50 |     Pads this image with zeroes to shp. 51 |     Args: 52 |         this: image tensor to pad 53 |         shp: desired output shape 54 | 55 |     Returns: 56 |         Zero-padded tensor of shape shp. 57 |     """ 58 |     if len(shp) == 4: 59 |         pad = (0, shp[3] - this.shape[3], 0, shp[2] - this.shape[2]) 60 |     elif len(shp) == 5: 61 |         pad = (0, shp[4] - this.shape[4], 0, shp[3] - this.shape[3], 0, shp[2] - this.shape[2]) 62 |     return F.pad(this, pad) 63 | 64 | 65 | if __name__ == '__main__': 66 |     # quick smoke test of the 2D UNet defined above 67 |     device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 68 |     unet = UNet(3, 2, conv_depths=[10, 20, 30, 40]).to(device) 69 |     x = torch.rand(1, 3, 128, 128).to(device) 70 |     y = unet(x) 71 |     print(y.shape) -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import ArgumentParser 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | 8 | from unet.unet import UNet 9 | from unet.utils import * 10 | 11 | 12 | parser = ArgumentParser() 13 | parser.add_argument("--train_dataset", type=str, required=True) 14 | parser.add_argument("--val_dataset", type=str, default=None) 15 | parser.add_argument("--epochs", type=int, default=1000) 16 | parser.add_argument("--batch_size", type=int, default=1) 17 | parser.add_argument("--model_name", type=str, default="UNet-hypocotyl") 18 | parser.add_argument("--trained_model_path", type=str, default=None) 19 | parser.add_argument("--initial_lr", type=float, default=1e-3) 20 | parser.add_argument("--device", type=str, default="cpu") 21 | parser.add_argument("--model_save_freq", type=int, default=200) 22 | 23 | args = parser.parse_args() 24 | 25 | tf_train = make_transform(crop=(512, 512), long_mask=True) 26 | tf_validate = make_transform(crop=(512, 512), long_mask=True, rotate_range=False, 27 |                              p_flip=0.0, normalize=False, color_jitter_params=None) 28 | 29 | # load dataset 30 | train_dataset_path = args.train_dataset 31 | train_dataset = ReadTrainDataset(train_dataset_path, transform=tf_train) 32 | 33 | if args.val_dataset is not None: 34 |     validate_dataset_path = args.val_dataset 35 |     validate_dataset = ReadTrainDataset(validate_dataset_path, transform=tf_validate) 36 | else: 37 |     validate_dataset = None 38 | 39 | tf_train = make_transform(crop=(512, 512), long_mask=True, p_random_affine=0.0) 40 | tf_validate = make_transform(crop=(512, 512), long_mask=True, rotate_range=False, 41 |                              p_flip=0.0, normalize=False, color_jitter_params=None) 42 | 43 | # creating checkpoint folder 44 | model_name = args.model_name 45 | file_dir = os.path.split(os.path.realpath(__file__))[0] 46 | results_folder = os.path.join(file_dir, '..', 'checkpoints', model_name) 47 | if not os.path.exists(results_folder): 48 |     os.makedirs(results_folder) 49 | 50 | # load model 51 | unet = UNet(3, 3) 52 | if args.trained_model_path is not None: 53 |     unet.load_state_dict(torch.load(args.trained_model_path)) 54 | 55 | loss = SoftDiceLoss(weight=torch.Tensor([1, 5, 5])) 56 | optimizer = optim.Adam(unet.parameters(), lr=args.initial_lr) 57 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=50, verbose=True) 58 | 59 | cuda_device = torch.device(args.device) 60 | model = ModelWrapper(unet, loss=loss, optimizer=optimizer, scheduler=scheduler, 61 |                      results_folder=results_folder, cuda_device=cuda_device) 62 | 63 | model.train_model(train_dataset, validation_dataset=validate_dataset, 64 |                   n_batch=args.batch_size,
n_epochs=args.epochs, 65 | verbose=False, save_freq=args.model_save_freq) 66 | -------------------------------------------------------------------------------- /src/unet/blocks.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class First2D(nn.Module): 5 | def __init__(self, in_channels, middle_channels, out_channels, dropout=False): 6 | super(First2D, self).__init__() 7 | 8 | layers = [ 9 | nn.Conv2d(in_channels, middle_channels, kernel_size=3, padding=1), 10 | nn.BatchNorm2d(middle_channels), 11 | nn.ReLU(inplace=True), 12 | nn.Conv2d(middle_channels, out_channels, kernel_size=3, padding=1), 13 | nn.BatchNorm2d(out_channels), 14 | nn.ReLU(inplace=True) 15 | ] 16 | 17 | if dropout: 18 | assert 0 <= dropout <= 1, 'dropout must be between 0 and 1' 19 | layers.append(nn.Dropout2d(p=dropout)) 20 | 21 | self.first = nn.Sequential(*layers) 22 | 23 | def forward(self, x): 24 | return self.first(x) 25 | 26 | 27 | class Encoder2D(nn.Module): 28 | def __init__( 29 | self, in_channels, middle_channels, out_channels, 30 | dropout=False, downsample_kernel=2 31 | ): 32 | super(Encoder2D, self).__init__() 33 | 34 | layers = [ 35 | nn.MaxPool2d(kernel_size=downsample_kernel), 36 | nn.Conv2d(in_channels, middle_channels, kernel_size=3, padding=1), 37 | nn.BatchNorm2d(middle_channels), 38 | nn.ReLU(inplace=True), 39 | nn.Conv2d(middle_channels, out_channels, kernel_size=3, padding=1), 40 | nn.BatchNorm2d(out_channels), 41 | nn.ReLU(inplace=True) 42 | ] 43 | 44 | if dropout: 45 | assert 0 <= dropout <= 1, 'dropout must be between 0 and 1' 46 | layers.append(nn.Dropout2d(p=dropout)) 47 | 48 | self.encoder = nn.Sequential(*layers) 49 | 50 | def forward(self, x): 51 | return self.encoder(x) 52 | 53 | 54 | class Center2D(nn.Module): 55 | def __init__(self, in_channels, middle_channels, out_channels, deconv_channels, dropout=False): 56 | super(Center2D, self).__init__() 57 | 58 | layers = [ 59 | nn.MaxPool2d(kernel_size=2), 60 | nn.Conv2d(in_channels, middle_channels, kernel_size=3, padding=1), 61 | nn.BatchNorm2d(middle_channels), 62 | nn.ReLU(inplace=True), 63 | nn.Conv2d(middle_channels, out_channels, kernel_size=3, padding=1), 64 | nn.BatchNorm2d(out_channels), 65 | nn.ReLU(inplace=True), 66 | nn.ConvTranspose2d(out_channels, deconv_channels, kernel_size=2, stride=2) 67 | ] 68 | 69 | if dropout: 70 | assert 0 <= dropout <= 1, 'dropout must be between 0 and 1' 71 | layers.append(nn.Dropout2d(p=dropout)) 72 | 73 | self.center = nn.Sequential(*layers) 74 | 75 | def forward(self, x): 76 | return self.center(x) 77 | 78 | 79 | class Decoder2D(nn.Module): 80 | def __init__(self, in_channels, middle_channels, out_channels, deconv_channels, dropout=False): 81 | super(Decoder2D, self).__init__() 82 | 83 | layers = [ 84 | nn.Conv2d(in_channels, middle_channels, kernel_size=3, padding=1), 85 | nn.BatchNorm2d(middle_channels), 86 | nn.ReLU(inplace=True), 87 | nn.Conv2d(middle_channels, out_channels, kernel_size=3, padding=1), 88 | nn.BatchNorm2d(out_channels), 89 | nn.ReLU(inplace=True), 90 | nn.ConvTranspose2d(out_channels, deconv_channels, kernel_size=2, stride=2) 91 | ] 92 | 93 | if dropout: 94 | assert 0 <= dropout <= 1, 'dropout must be between 0 and 1' 95 | layers.append(nn.Dropout2d(p=dropout)) 96 | 97 | self.decoder = nn.Sequential(*layers) 98 | 99 | def forward(self, x): 100 | return self.decoder(x) 101 | 102 | 103 | class Last2D(nn.Module): 104 | def __init__(self, in_channels, middle_channels, out_channels, softmax=False): 105 | 
super(Last2D, self).__init__() 106 | 107 | layers = [ 108 | nn.Conv2d(in_channels, middle_channels, kernel_size=3, padding=1), 109 | nn.BatchNorm2d(middle_channels), 110 | nn.ReLU(inplace=True), 111 | nn.Conv2d(middle_channels, middle_channels, kernel_size=3, padding=1), 112 | nn.BatchNorm2d(middle_channels), 113 | nn.ReLU(inplace=True), 114 | nn.Conv2d(middle_channels, out_channels, kernel_size=1), 115 | nn.Softmax(dim=1) 116 | ] 117 | 118 | self.first = nn.Sequential(*layers) 119 | 120 | def forward(self, x): 121 | return self.first(x) 122 | -------------------------------------------------------------------------------- /src/preprocessing/mask.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from skimage import io 4 | from itertools import product 5 | from collections import defaultdict 6 | from shutil import copyfile 7 | from argparse import ArgumentParser 8 | 9 | 10 | def chk_mkdir(*args): 11 | for path in args: 12 | if not os.path.exists(path): 13 | os.makedirs(path) 14 | 15 | 16 | def make_mask(hypo_mask_path, nonhypo_mask_path, export_path=None): 17 | hypo, nonhypo = io.imread(hypo_mask_path), io.imread(nonhypo_mask_path) 18 | hypo[hypo == 255] = 2 19 | nonhypo[nonhypo == 255] = 1 20 | mask = np.maximum(hypo, nonhypo) 21 | if export_path: 22 | io.imsave(export_path, mask) 23 | else: 24 | return mask 25 | 26 | 27 | def make_patches(dataset_path, export_path, patch_size=(512, 512), no_overlap=False): 28 | """ 29 | Takes the data folder CONTAINING MERGED MASKS and slices the 30 | images and masks into patches. 31 | """ 32 | # make output directories 33 | dataset_images_path = os.path.join(dataset_path, 'images') 34 | dataset_masks_path = os.path.join(dataset_path, 'masks') 35 | new_images_path = os.path.join(export_path, 'images') 36 | new_masks_path = os.path.join(export_path, 'masks') 37 | 38 | chk_mkdir(new_masks_path, new_images_path) 39 | 40 | for image_filename in os.listdir(dataset_images_path): 41 | # reading images 42 | im = io.imread(os.path.join(dataset_images_path, image_filename)) 43 | masked_im = io.imread(os.path.join(dataset_masks_path, image_filename)) 44 | # make new folders 45 | 46 | x_start = list() 47 | y_start = list() 48 | 49 | if no_overlap: 50 | x_step = patch_size[0] 51 | y_step = patch_size[1] 52 | else: 53 | x_step = patch_size[0] // 2 54 | y_step = patch_size[1] // 2 55 | 56 | for x_idx in range(0, im.shape[0] - patch_size[0] + 1, x_step): 57 | x_start.append(x_idx) 58 | 59 | if im.shape[0] - patch_size[0] - 1 > 0: 60 | x_start.append(im.shape[0] - patch_size[0] - 1) 61 | 62 | for y_idx in range(0, im.shape[1] - patch_size[1] + 1, y_step): 63 | y_start.append(y_idx) 64 | 65 | if im.shape[1] - patch_size[1] - 1 > 0: 66 | y_start.append(im.shape[1] - patch_size[1] - 1) 67 | 68 | for num, (x_idx, y_idx) in enumerate(product(x_start, y_start)): 69 | new_image_filename = os.path.splitext(image_filename)[0] + '_%d.png' % num 70 | # saving a patch of the original image 71 | io.imsave( 72 | os.path.join(new_images_path, new_image_filename), 73 | im[x_idx:x_idx + patch_size[0], y_idx:y_idx + patch_size[1], :] 74 | ) 75 | # saving the corresponding patch of the mask 76 | io.imsave( 77 | os.path.join(new_masks_path, new_image_filename), 78 | masked_im[x_idx:x_idx + patch_size[0], y_idx:y_idx + patch_size[1]] 79 | ) 80 | 81 | 82 | def train_test_validate_split(data_path, export_path, ratios=[0.6, 0.2, 0.2]): 83 | dst_path = defaultdict(dict) 84 | for dataset, data_type in product(['train', 
'test', 'validate'], ['images', 'masks']): 85 | set_type_path = os.path.join(export_path, dataset, data_type) 86 | dst_path[dataset][data_type] = set_type_path 87 | chk_mkdir(set_type_path) 88 | 89 | for image_filename in os.listdir(os.path.join(data_path, 'images')): 90 | src_path = { 91 | 'images': os.path.join(data_path, 'images', image_filename), 92 | 'masks': os.path.join(data_path, 'masks', image_filename) 93 | } 94 | 95 | dataset = np.random.choice(['train', 'test', 'validate'], p=ratios) 96 | 97 | for data_type in ['images', 'masks']: 98 | copyfile(src_path[data_type], os.path.join(dst_path[dataset][data_type], image_filename)) 99 | 100 | 101 | def imageJ_elementwise_mask_to_png(elementwise_mask_path): 102 | img = io.imread(elementwise_mask_path) 103 | folder, fname = os.path.split(elementwise_mask_path) 104 | fname_root, ext = os.path.splitext(fname) 105 | io.imsave(os.path.join(folder, fname_root + '.png'), img) 106 | 107 | 108 | if __name__ == '__main__': 109 | 110 | parser = ArgumentParser() 111 | parser.add_argument("--images_folder", required=True, type=str) 112 | parser.add_argument("--export_folder", required=True, type=str) 113 | parser.add_argument("--make_patches", type=(lambda x: str(x).lower() == 'true'), default=False) 114 | 115 | args = parser.parse_args() 116 | 117 | images_folder = args.images_folder 118 | export_root = args.export_folder 119 | 120 | unet_ready_folder = os.path.join(export_root, 'converted') 121 | 122 | chk_mkdir(export_root, os.path.join(unet_ready_folder, 'images'), 123 | os.path.join(unet_ready_folder, 'masks')) 124 | 125 | for image_name in os.listdir(images_folder): 126 | hypo_mask_path = os.path.join(images_folder, image_name, '%s-hypo.png' % image_name) 127 | nonhypo_mask_path = os.path.join(images_folder, image_name, '%s-nonhypo.png' % image_name) 128 | export_path = os.path.join(unet_ready_folder, 'masks', '%s.png' % image_name) 129 | make_mask(hypo_mask_path, nonhypo_mask_path, export_path) 130 | copyfile(os.path.join(images_folder, image_name, image_name + '.png'), 131 | os.path.join(unet_ready_folder, 'images', image_name + '.png')) 132 | 133 | if args.make_patches: 134 | patches_export_folder = os.path.join(export_root, 'patched_images') 135 | chk_mkdir(patches_export_folder) 136 | make_patches(unet_ready_folder, patches_export_folder, (800, 800), no_overlap=True) 137 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A deep learning-based approach for high throughput plant phenotyping 2 | 3 | This repository is the companion for the paper [A deep learning-based approach for high throughput plant phenotyping, 4 | Dobos et al.](http://www.plantphysiol.org/content/181/4/1415). 5 | 6 | [The dataset used in the article can be found at this page](https://www.kaggle.com/tivadardanka/plant-segmentation/). 
7 | 8 | [The pretrained model used in the research article can be downloaded here.](https://drive.google.com/open?id=1SlUui64l-k63vxysl0YAflKaECfpj8Rr) 9 | 10 | ## Contents 11 | - [Usage, dependencies](#usage) 12 | - [Using a trained model for measuring hypocotyls](#measuring) 13 | - [Generating your own training data](#annotating) 14 | - [Training a model on your own images](#training) 15 | 16 | ## Usage, dependencies 17 | For measuring hypocotyls and training a custom model, the following are required: 18 | - Python >= 3.5 19 | - PyTorch >= 0.4 20 | - NumPy >= 1.13 21 | - Pandas >= 0.23 22 | - scikit-image >= 0.14 23 | - matplotlib >= 3.0 24 | 25 | To use the hypocotyl segmentation tool, clone the repository to the local machine: 26 | ```bash 27 | git clone https://github.com/biomag-lab/hypocotyl-UNet 28 | ``` 29 | `src/measure.py` can be used to apply the measuring algorithm to custom images, while `src/train.py` is for training the UNet model on custom annotated data. (Detailed descriptions of both can be found below.) 30 | 31 | ## Using a trained model for measuring hypocotyls 32 | To apply the algorithm on custom images, the folder containing the images should be organized into the following directory structure: 33 | ```bash 34 | images_folder 35 | |-- images 36 | |-- img001.png 37 | |-- img002.png 38 | |-- ... 39 | ``` 40 | The `src/measure.py` script can be used to run the algorithm. The required arguments are 41 | - `--images_path`: path to the images folder, which must have the structure outlined above. 42 | - `--model`: path to the UNet model used in the algorithm. 43 | - `--result_folder`: path to the folder where results will be exported. 44 | 45 | [The model used in the research article can be found here.](https://drive.google.com/open?id=1SlUui64l-k63vxysl0YAflKaECfpj8Rr) 46 | 47 | Additionally, you can specify the following: 48 | - `--min_object_size`: the expected minimum object size in pixels. Default is 0. Detected objects below this size will be filtered out. 49 | - `--max_object_size`: the expected maximum object size in pixels. Default is `np.inf`. Detected objects above this size will be filtered out. 50 | - `--dpi` and `--dpm`: to export the lengths in *mm*, a *dpi* (dots per inch) or *dpm* (dots per millimeter) value must be provided. If either of these is given, the pixel units will be converted to *mm* during measurement. If both *dpi* and *dpm* are set, only *dpi* is taken into account (see `src/measure.py`; a sketch of the conversion is shown after the example images below). 51 | - `--device`: device to be used for the UNet prediction. Default is `cpu`, but if a GPU with the CUDA framework installed is available, `cuda:$ID` can be used, where `$ID` is the ID of the GPU. For example, `cuda:0`. (For PyTorch users: this argument is passed directly to `torch.device` during initialization, which is then used for the rest of the workflow.) 52 | - `--visualize`: set to True to export a visualization of the results. For each measured image, the following images are exported along with the length measurements. 53 |

54 | ![](docs/img/vis_bbox.png) 55 | ![](docs/img/vis_seg.png) 56 | ![](docs/img/vis_skl.png) 57 |
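The *dpm* value used internally is derived from *dpi* by the `dpi_to_dpm` helper in `src/unet/utils.py`, which simply divides by 25.4 (millimeters per inch); the measured skeleton lengths are then divided by *dpm* to obtain millimeters. Below is a minimal standalone sketch of that arithmetic, shown only for illustration; the `pixels_to_mm` helper is hypothetical and not part of the repository.

```python
# Sketch of the pixel-to-millimeter conversion used when --dpi or --dpm is given.
# dpi = dots (pixels) per inch, dpm = dots (pixels) per millimeter.

def dpi_to_dpm(dpi):
    # mirrors dpi_to_dpm in src/unet/utils.py: one inch is 25.4 mm
    return dpi / 25.4

def pixels_to_mm(length_px, dpm):
    # measured skeleton lengths (in pixels) are divided by dpm to obtain mm
    return length_px / dpm

dpm = dpi_to_dpm(300)          # e.g. images scanned at 300 dpi
print(pixels_to_mm(118, dpm))  # a 118 px long skeleton is roughly 10 mm
```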

58 | 59 | For example: 60 | ```bash 61 | python3 measure.py --images_path path_to_images \ 62 | --model ../models/unet \ 63 | --result_folder path_to_results \ 64 | --device cuda:0 65 | ``` 66 | 67 | ## Generating your own training data 68 | In case the algorithm performs poorly, for instance if the images were taken under very different conditions than the ones the available model was trained on, the UNet backbone model can be retrained on custom data. Preparing this data is called *annotation*, and it can be done easily with ImageJ. 69 | 70 | Step 1. Organize the images to be annotated such that each image is contained in a separate folder under a common root, with each folder named after its image. For example: 71 | ```bash 72 | images_folder 73 | |-- image_1 74 | |-- image_1.png 75 | |-- image_2 76 | |-- image_2.png 77 | |-- ... 78 | ``` 79 | 80 | Step 2. Open the image to be annotated in ImageJ. 81 | ![](docs/img/step_02.png) 82 | 83 | Step 3. Open the *ROI manager* tool. 84 | ![](docs/img/step_03.png) 85 | 86 | Step 4. Select an appropriate selection tool, for example *Freehand selections*. 87 | ![](docs/img/step_04.png) 88 | 89 | Step 5. Draw the outline of a plant part that should be included in the measurements. Press *t* or click the *Add [t]* button to add the selection to the ROI manager. 90 | ![](docs/img/step_05.png) 91 | 92 | Step 6. Repeat the outlining for all parts to be measured, adding the selections one by one. When done, select all of them and click *More > OR (Combine)*. 93 | ![](docs/img/step_06.png) 94 | 95 | Step 7. Press *Edit > Selection > Create Mask*. This will open a new image. 96 | ![](docs/img/step_07.png) 97 | 98 | Step 8. Invert the mask image by pressing *Edit > Invert* or *Ctrl + Shift + I*. 99 | ![](docs/img/step_08.png) 100 | 101 | Step 9. Save the mask image to the folder containing the original. Add the suffix *-hypo* to the name. For example, if the original name was *image_1.png*, this mask should be named *image_1-hypo.png*. 102 | ![](docs/img/step_09.png) 103 | 104 | Step 10. Repeat the same process to annotate the parts of the plant which should not be included in the measurements. Save the mask to the same folder, adding the suffix *-nonhypo* to the name. 105 | ![](docs/img/step_10.png) 106 | 107 | Step 11. Create the training data for the algorithm by running the `src/preprocessing/mask.py` script. The required arguments are 108 | - `--images_folder`: path to the root folder containing the per-image folders with the images and masks. 109 | - `--export_folder`: path to the folder where the results should be exported. 110 | - `--make_patches`: set to True to split the images and masks into smaller patches. Recommended for training. 111 | 112 | ## Training a model on your own images 113 | If custom annotated data is available, the containing folder should be organized into the following directory structure: 114 | ```bash 115 | images_folder 116 | |-- images 117 | |-- img001.png 118 | |-- img002.png 119 | |-- ... 120 | |-- masks 121 | |-- img001.png 122 | |-- img002.png 123 | |-- ... 124 | ``` 125 | Each mask image should have the same filename as its corresponding image. After the training data is organized, the `src/train.py` script can be used to train a custom UNet model. The required arguments are 126 | - `--train_dataset`: path to the folder where the training data is located. This should be the output written by the `src/preprocessing/mask.py` script (the *converted* or, if patching was used, the *patched_images* subfolder of its `--export_folder`) during Step 11 
of the previous point. 127 | 128 | The optional arguments are 129 | - `--epochs`: the number of epochs during training. Default is 1000, but this is very dependent on the dataset and data augmentation method used. 130 | - `--batch_size`: the size of the batch during training. Default is 1. If GPU is used, it is recommended to select batch size as large as GPU memory allows. 131 | - `--initial_lr`: initial learning rate. Default is 1e-3, which proved to be the best for the training dataset used. (Different datasets might have more optimal initial learning rates.) 132 | - `--model_name`: name of the model. Default is *UNet-hypocotyl*. 133 | - `--trained_model_path`: to continue training of a previously trained model, its path can be given. 134 | - `--val_dataset`: path to the folder where the validation data is located. During training, validation data is used to catch overfitting and apply early stopping. 135 | - `--model_save_freq`: frequency of model saving. Default is 200, which means that the model is saved after every 200th epoch. 136 | - `--device`: device to be used for the UNet training. Default is `cpu`, but if a GPU with the CUDA framework installed is available, `cuda:$ID` can be used, where `$ID` is the ID of the GPU. For example, `cuda:0`. (For PyTorch users: this argument is passed directly to the `torch.Tensor.device` object during initialization, which will be used for the rest of the workflow.) 137 | 138 | -------------------------------------------------------------------------------- /src/unet/rprops.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | 5 | import matplotlib.pyplot as plt 6 | import matplotlib.patches as patches 7 | 8 | from skimage import io, img_as_uint 9 | from skimage.morphology import skeletonize, medial_axis, skeletonize_3d 10 | from skimage.measure import regionprops, label 11 | from skimage.filters import threshold_otsu 12 | from skimage.measure._regionprops import _RegionProperties 13 | 14 | from typing import Container 15 | from numbers import Number 16 | 17 | 18 | class BBox: 19 | def __init__(self, rprops_bbox): 20 | min_row, min_col, max_row, max_col = rprops_bbox 21 | # regionprops bbox representation 22 | self.min_row = min_row 23 | self.min_col = min_col 24 | self.max_row = max_row 25 | self.max_col = max_col 26 | self.bbox = rprops_bbox 27 | 28 | # rectangle representation 29 | self.x, self.y = min_col, min_row 30 | self.width = max_col - min_col 31 | self.height = max_row - min_row 32 | 33 | # coordinate representation 34 | self.P1 = (min_col, min_row) 35 | self.P2 = (max_col, min_row) 36 | self.P3 = (min_col, max_row) 37 | self.P4 = (max_col, max_row) 38 | 39 | def __repr__(self): 40 | return str(self.bbox) 41 | 42 | def __getitem__(self, item): 43 | return self.bbox[item] 44 | 45 | def IOU(self, other_bbox): 46 | # determining the intersection coordinates 47 | P1_int = (max(self.P1[0], other_bbox.P1[0]), 48 | max(self.P1[1], other_bbox.P1[1])) 49 | P4_int = (min(self.P4[0], other_bbox.P4[0]), 50 | min(self.P4[1], other_bbox.P4[1])) 51 | 52 | # check for intersections 53 | if (P1_int[0] > P4_int[0]) or (P1_int[1] > P4_int[1]): 54 | return 0 55 | 56 | intersection_area = (P4_int[0] - P1_int[0]) * (P4_int[1] - P1_int[1]) 57 | union_area = self.area() + other_bbox.area() - intersection_area 58 | 59 | return intersection_area / union_area 60 | 61 | def area(self): 62 | return self.width * self.height 63 | 64 | 65 | class Hypo: 66 | def 
__init__(self, rprops, dpm=False): 67 | self.length = rprops.area 68 | if dpm: 69 | self.length /= dpm 70 | self.bbox = BBox(rprops.bbox) 71 | 72 | def __repr__(self): 73 | return "[%d, %s]" % (self.length, self.bbox) 74 | 75 | def IOU(self, other_hypo): 76 | return self.bbox.IOU(other_hypo.bbox) 77 | 78 | 79 | class HypoResult: 80 | def __init__(self, rprops_or_hypos, dpm=False): 81 | if isinstance(rprops_or_hypos[0], Hypo): 82 | self.hypo_list = rprops_or_hypos 83 | elif isinstance(rprops_or_hypos[0], _RegionProperties): 84 | self.hypo_list = [Hypo(rprops, dpm) for rprops in rprops_or_hypos] 85 | self.gt_match = None 86 | 87 | def __getitem__(self, item): 88 | if isinstance(item, Number): 89 | return self.hypo_list[item] 90 | if isinstance(item, Container): 91 | # check the datatype of the list 92 | if isinstance(item[0], np.bool_): 93 | item = [idx for idx, val in enumerate(item) if val] 94 | return HypoResult([self.hypo_list[idx] for idx in item]) 95 | 96 | def __len__(self): 97 | return len(self.hypo_list) 98 | 99 | def mean(self): 100 | return np.mean([hypo.length for hypo in self.hypo_list]) 101 | 102 | def std(self): 103 | return np.std([hypo.length for hypo in self.hypo_list]) 104 | 105 | def score(self, gt_hyporesult, match_threshold=0.5): 106 | scores = [] 107 | hypo_ious = np.zeros((len(self), len(gt_hyporesult))) 108 | objectwise_df = pd.DataFrame(columns=['algorithm', 'ground truth'], index=range(len(gt_hyporesult))) 109 | 110 | for hypo_idx, hypo in enumerate(self.hypo_list): 111 | hypo_ious[hypo_idx] = np.array([hypo.IOU(gt_hypo) for gt_hypo in gt_hyporesult]) 112 | best_match = np.argmax(hypo_ious[hypo_idx]) 113 | 114 | # a match is found if the intersection over union metric is 115 | # larger than the given threshold 116 | if hypo_ious[hypo_idx][best_match] > match_threshold: 117 | # calculate the accuracy of the measurement 118 | gt_hypo = gt_hyporesult[best_match] 119 | error = abs(hypo.length - gt_hypo.length) 120 | scores.append(1 - error/gt_hypo.length) 121 | 122 | gt_hypo_ious = hypo_ious.T 123 | for gt_hypo_idx, gt_hypo in enumerate(gt_hyporesult): 124 | objectwise_df.loc[gt_hypo_idx, 'ground truth'] = gt_hypo.length 125 | best_match = np.argmax(gt_hypo_ious[gt_hypo_idx]) 126 | if gt_hypo_ious[gt_hypo_idx][best_match] > match_threshold: 127 | objectwise_df.loc[gt_hypo_idx, 'algorithm'] = self.hypo_list[best_match].length 128 | 129 | # precision, recall 130 | self.gt_match = np.apply_along_axis(np.any, 0, hypo_ious > match_threshold) 131 | self.match = np.apply_along_axis(np.any, 1, hypo_ious > match_threshold) 132 | # identified_objects = self[self.match] 133 | true_positives = self.gt_match.sum() 134 | precision = true_positives/len(self) 135 | recall = true_positives/len(gt_hyporesult) 136 | 137 | score_dict = {'accuracy': np.mean(scores), 138 | 'precision': precision, 139 | 'recall': recall, 140 | 'gt_mean': gt_hyporesult.mean(), 141 | 'result_mean': self.mean(), 142 | 'gt_std': gt_hyporesult.std(), 143 | 'result_std': self.std()} 144 | 145 | return score_dict, objectwise_df 146 | 147 | def make_df(self): 148 | result_df = pd.DataFrame( 149 | [[hypo.length, *hypo.bbox] for hypo in self.hypo_list], 150 | columns=['length', 'min_row', 'min_col', 'max_row', 'max_col'], 151 | index=range(1, len(self)+1) 152 | ) 153 | 154 | return result_df 155 | 156 | def hist(self, gt_hyporesult, export_path): 157 | lengths = [hypo.length for hypo in self.hypo_list] 158 | gt_lengths = [hypo.length for hypo in gt_hyporesult] 159 | histogram_bins = range(0, 500, 10) 160 | 161 | with 
plt.style.context('seaborn-white'): 162 | plt.figure(figsize=(10, 15)) 163 | plt.hist(lengths, bins=histogram_bins, color='r', alpha=0.2, label='result') 164 | plt.hist(gt_lengths, bins=histogram_bins, color='b', alpha=0.2, label='ground truth') 165 | plt.legend() 166 | plt.savefig(export_path) 167 | plt.close('all') 168 | 169 | def filter(self, flt): 170 | if isinstance(flt, Container): 171 | min_length, max_length = flt 172 | self.hypo_list = [h for h in self.hypo_list if min_length <= h.length <= max_length] 173 | elif isinstance(flt, bool) and flt: 174 | otsu_thresh = threshold_otsu(np.array([h.length for h in self.hypo_list])) 175 | self.hypo_list = [h for h in self.hypo_list if otsu_thresh <= h.length] 176 | 177 | 178 | def bbox_to_rectangle(bbox): 179 | # bbox format: 'min_row', 'min_col', 'max_row', 'max_col' 180 | # Rectangle format: bottom left (x, y), width, height 181 | min_row, min_col, max_row, max_col = bbox 182 | x, y = min_col, min_row 183 | width = max_col - min_col 184 | height = max_row - min_row 185 | 186 | return (x, y), width, height 187 | 188 | 189 | def get_hypo_rprops(hypo, filter=True, already_skeletonized=False, skeleton_method=skeletonize_3d, 190 | return_skeleton=False, dpm=False): 191 | """ 192 | Args: 193 | hypo: segmented hypocotyl image 194 | filter: boolean or list of [min_length, max_length] 195 | """ 196 | hypo_thresh = (hypo > 0.5) 197 | if not already_skeletonized: 198 | hypo_skeleton = label(img_as_uint(skeleton_method(hypo_thresh))) 199 | else: 200 | hypo_skeleton = label(img_as_uint(hypo_thresh)) 201 | 202 | hypo_rprops = regionprops(hypo_skeleton) 203 | # filter out small regions 204 | hypo_result = HypoResult(hypo_rprops, dpm) 205 | hypo_result.filter(flt=filter) 206 | 207 | if return_skeleton: 208 | return hypo_result, hypo_skeleton > 0 209 | 210 | return hypo_result 211 | 212 | 213 | def visualize_regions(hypo_img, hypo_result, export_path=None, bbox_color='r', dpi=800): 214 | with plt.style.context('seaborn-white'): 215 | # parameters 216 | fontsize = 3.0 * (800.0 / dpi) 217 | linewidth = fontsize / 10.0 218 | 219 | figsize = (hypo_img.shape[1]/dpi, hypo_img.shape[0]/dpi) 220 | fig = plt.figure(figsize=figsize, dpi=dpi) 221 | ax = plt.Axes(fig, [0,0,1,1]) #plt.subplot(111) 222 | fig.add_axes(ax) 223 | ax.imshow(hypo_img) 224 | for hypo_idx, hypo in enumerate(hypo_result): 225 | rectangle = patches.Rectangle((hypo.bbox.x, hypo.bbox.y), hypo.bbox.width, hypo.bbox.height, 226 | linewidth=linewidth, edgecolor=bbox_color, facecolor='none') 227 | ax.add_patch(rectangle) 228 | ax.text(hypo.bbox.x, hypo.bbox.y - linewidth - 24, "N.%d." 
% (hypo_idx+1), fontsize=fontsize, color='k') 229 | ax.text(hypo.bbox.x, hypo.bbox.y - linewidth, str(hypo.length)[:4], fontsize=fontsize, color=bbox_color) 230 | 231 | fig.axes[0].get_xaxis().set_visible(False) 232 | fig.axes[0].get_yaxis().set_visible(False) 233 | 234 | if export_path is None: 235 | plt.show() 236 | else: 237 | plt.savefig(export_path, dpi=dpi) 238 | 239 | plt.close('all') 240 | -------------------------------------------------------------------------------- /src/unet/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | import torch 5 | 6 | from skimage import io, img_as_uint 7 | from skimage.morphology import skeletonize_3d 8 | 9 | from numbers import Number 10 | from itertools import product 11 | 12 | from torch.autograd import Variable 13 | from torch.utils.data import DataLoader, Dataset 14 | from torch.nn.functional import cross_entropy 15 | from torch.nn.modules.loss import _WeightedLoss 16 | from torchvision import transforms as T 17 | from torchvision.transforms import functional as F 18 | 19 | from .rprops import get_hypo_rprops, visualize_regions 20 | 21 | 22 | def chk_mkdir(*args): 23 | for path in args: 24 | if path is not None and not os.path.exists(path): 25 | os.makedirs(path) 26 | 27 | 28 | def dpi_to_dpm(dpi): 29 | # small hack, default value for dpi is False 30 | if not dpi: 31 | return False 32 | 33 | return dpi/25.4 34 | 35 | 36 | def dpm_to_dpi(dpm): 37 | if not dpm: 38 | return False 39 | 40 | return dpm * 25.4 41 | 42 | 43 | def to_long_tensor(pic): 44 | # handle numpy array 45 | img = torch.from_numpy(np.array(pic, np.uint8)) 46 | # backward compatibility 47 | return img.long() 48 | 49 | 50 | def joint_to_long_tensor(image, mask): 51 | return to_long_tensor(image), to_long_tensor(mask) 52 | 53 | 54 | def make_transform( 55 | crop=(256, 256), p_flip=0.5, p_color=0.0, color_jitter_params=(0.1, 0.1, 0.1, 0.1), 56 | p_random_affine=0.0, rotate_range=False, normalize=False, long_mask=False 57 | ): 58 | 59 | if color_jitter_params is not None: 60 | color_tf = T.ColorJitter(*color_jitter_params) 61 | else: 62 | color_tf = None 63 | 64 | if normalize: 65 | tf_normalize = T.Normalize(mean=(0.5, 0.5, 0.5), std=(1, 1, 1)) 66 | 67 | def joint_transform(image, mask): 68 | # transforming to PIL image 69 | image, mask = F.to_pil_image(image), F.to_pil_image(mask) 70 | 71 | # random crop 72 | if crop: 73 | i, j, h, w = T.RandomCrop.get_params(image, crop) 74 | image, mask = F.crop(image, i, j, h, w), F.crop(mask, i, j, h, w) 75 | if np.random.rand() < p_flip: 76 | image, mask = F.hflip(image), F.hflip(mask) 77 | 78 | # color transforms || ONLY ON IMAGE 79 | if color_tf is not None: 80 | if np.random.rand() < p_color: 81 | image = color_tf(image) 82 | 83 | # random rotation 84 | if rotate_range and not p_random_affine: 85 | if np.random.rand() < 0.5: 86 | angle = rotate_range * (np.random.rand() - 0.5) 87 | image, mask = F.rotate(image, angle), F.rotate(mask, angle) 88 | 89 | # random affine 90 | if np.random.rand() < p_random_affine: 91 | affine_params = T.RandomAffine(180).get_params((-90, 90), (1, 1), (2, 2), (-45, 45), crop) 92 | image, mask = F.affine(image, *affine_params), F.affine(mask, *affine_params) 93 | 94 | # transforming to tensor 95 | image = F.to_tensor(image) 96 | if not long_mask: 97 | mask = F.to_tensor(mask) 98 | else: 99 | mask = to_long_tensor(mask) 100 | 101 | # normalizing image 102 | if normalize: 103 | image = tf_normalize(image) 104 | 
105 | return image, mask 106 | 107 | return joint_transform 108 | 109 | 110 | def confusion_matrix(prediction, target, n_classes): 111 | """ 112 | prediction, target: torch.Tensor objects 113 | """ 114 | prediction = torch.argmax(prediction, dim=0).long() 115 | target = torch.squeeze(target, dim=0) 116 | 117 | conf_mtx = torch.zeros(n_classes, n_classes).long() 118 | for i, j in product(range(n_classes), range(n_classes)): 119 | conf_mtx[i, j] = torch.sum((prediction == j) * (target == i)) 120 | 121 | return conf_mtx 122 | 123 | 124 | class SoftDiceLoss(_WeightedLoss): 125 | __constants__ = ['weight', 'reduction'] 126 | 127 | def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean'): 128 | if weight is None: 129 | weight = torch.tensor(1) 130 | else: 131 | # creating tensor if needed 132 | if not isinstance(weight, torch.Tensor): 133 | weight = torch.tensor(weight) 134 | # normalizing weights 135 | weight /= torch.sum(weight) 136 | 137 | super(SoftDiceLoss, self).__init__(weight, size_average, reduce, reduction) 138 | 139 | def forward(self, y_pred, y_gt): 140 | """ 141 | Args: 142 | y_pred: torch.Tensor of shape (n_batch, n_classes, image.shape) 143 | y_gt: torch.LongTensor of shape (n_batch, image.shape) 144 | """ 145 | dims = (0, *range(2, len(y_pred.shape))) 146 | 147 | y_gt = torch.zeros_like(y_pred).scatter_(1, y_gt[:, None, :], 1) 148 | numerator = 2 * torch.sum(y_pred * y_gt, dim=dims) 149 | denominator = torch.sum(y_pred * y_pred + y_gt * y_gt, dim=dims) 150 | return torch.sum((1 - numerator / denominator)*self.weight) 151 | 152 | 153 | class LogNLLLoss(_WeightedLoss): 154 | __constants__ = ['weight', 'reduction', 'ignore_index'] 155 | 156 | def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean', 157 | ignore_index=-100): 158 | super(LogNLLLoss, self).__init__(weight, size_average, reduce, reduction) 159 | self.ignore_index = ignore_index 160 | 161 | def forward(self, input, target): 162 | input = torch.log(input) 163 | return cross_entropy(input, target, weight=self.weight, reduction=self.reduction, 164 | ignore_index=self.ignore_index) 165 | 166 | 167 | class ReadTrainDataset(Dataset): 168 | """ 169 | Structure of the dataset should be: 170 | 171 | dataset_path 172 | |--images 173 | |--img001.png 174 | |--img002.png 175 | |--masks 176 | |--img001.png 177 | |--img002.png 178 | 179 | """ 180 | 181 | def __init__(self, dataset_path, transform=None, one_hot_mask=False, long_mask=True): 182 | self.dataset_path = dataset_path 183 | self.images_path = os.path.join(dataset_path, 'images') 184 | self.masks_path = os.path.join(dataset_path, 'masks') 185 | self.images_list = os.listdir(self.images_path) 186 | 187 | self.transform = transform 188 | self.one_hot_mask = one_hot_mask 189 | self.long_mask = long_mask 190 | 191 | def __len__(self): 192 | return len(os.listdir(self.images_path)) 193 | 194 | def __getitem__(self, idx): 195 | image_filename = self.images_list[idx] 196 | image = io.imread(os.path.join(self.images_path, image_filename)) 197 | mask = io.imread(os.path.join(self.masks_path, image_filename)) 198 | if len(mask.shape) == 2: 199 | mask = np.expand_dims(mask, axis=2) 200 | 201 | if self.transform: 202 | image, mask = self.transform(image, mask) 203 | else: 204 | image = F.to_tensor(image) 205 | if self.long_mask: 206 | mask = to_long_tensor(F.to_pil_image(mask)) 207 | else: 208 | mask = F.to_tensor(mask) 209 | 210 | if self.one_hot_mask: 211 | assert self.one_hot_mask >= 0, 'one_hot_mask must be nonnegative' 212 | mask = 
torch.zeros((self.one_hot_mask, mask.shape[1], mask.shape[2])).scatter_(0, mask.long(), 1) 213 | 214 | return image, mask, image_filename 215 | 216 | 217 | class ReadTestDataset(Dataset): 218 | """ 219 | Structure of the dataset should be: 220 | 221 | dataset_path 222 | |--images 223 | |--img001.png 224 | |--img002.png 225 | 226 | """ 227 | 228 | def __init__(self, dataset_path, transform=None): 229 | self.dataset_path = dataset_path 230 | self.images_path = os.path.join(dataset_path, 'images') 231 | self.images_list = os.listdir(self.images_path) 232 | 233 | self.transform = transform 234 | 235 | def __len__(self): 236 | return len(os.listdir(self.images_path)) 237 | 238 | def __getitem__(self, idx): 239 | image_filename = self.images_list[idx] 240 | image = io.imread(os.path.join(self.images_path, image_filename)) 241 | 242 | if self.transform: 243 | image = self.transform(image) 244 | else: 245 | image = F.to_tensor(image) 246 | 247 | return image, image_filename 248 | 249 | 250 | class ModelWrapper: 251 | def __init__( 252 | self, model, results_folder, loss=None, optimizer=None, 253 | scheduler=None, cuda_device=None 254 | ): 255 | self.model = model 256 | self.loss = loss 257 | self.optimizer = optimizer 258 | self.scheduler = scheduler 259 | self.results_folder = results_folder 260 | chk_mkdir(self.results_folder) 261 | 262 | self.cuda_device = cuda_device 263 | if self.cuda_device: 264 | self.model.to(device=self.cuda_device) 265 | try: 266 | self.loss.to(device=self.cuda_device) 267 | except AttributeError: 268 | pass 269 | 270 | def train_model(self, dataset, n_epochs, n_batch=1, verbose=False, 271 | validation_dataset=None, prediction_dataset=None, 272 | save_freq=100): 273 | self.model.train(True) 274 | 275 | # logging losses 276 | loss_df = pd.DataFrame(np.zeros(shape=(n_epochs, 2)), columns=['train', 'validate'], index=range(n_epochs)) 277 | 278 | min_loss = np.inf 279 | total_running_loss = 0 280 | for epoch_idx in range(n_epochs): 281 | 282 | epoch_running_loss = 0 283 | for batch_idx, (X_batch, y_batch, name) in enumerate(DataLoader(dataset, batch_size=n_batch, shuffle=True)): 284 | if self.cuda_device: 285 | X_batch = Variable(X_batch.to(device=self.cuda_device)) 286 | y_batch = Variable(y_batch.to(device=self.cuda_device)) 287 | else: 288 | X_batch, y_batch = Variable(X_batch), Variable(y_batch) 289 | 290 | # training 291 | self.optimizer.zero_grad() 292 | y_out = self.model(X_batch) 293 | training_loss = self.loss(y_out, y_batch) 294 | training_loss.backward() 295 | self.optimizer.step() 296 | 297 | epoch_running_loss += training_loss.item() 298 | 299 | if verbose: 300 | print('(Epoch no. %d, batch no. %d) loss: %f' % (epoch_idx, batch_idx, training_loss.item())) 301 | 302 | total_running_loss += epoch_running_loss/(batch_idx + 1) 303 | print('(Epoch no. 
%d) loss: %f' % (epoch_idx, epoch_running_loss/(batch_idx + 1))) 304 | loss_df.loc[epoch_idx, 'train'] = epoch_running_loss/(batch_idx + 1) 305 | 306 | if validation_dataset is not None: 307 | validation_error = self.validate(validation_dataset, n_batch=1) 308 | loss_df.loc[epoch_idx, 'validate'] = validation_error 309 | if validation_error < min_loss: 310 | torch.save(self.model.state_dict(), os.path.join(self.results_folder, 'model')) 311 | print('Validation loss improved from %f to %f, model saved to %s' 312 | % (min_loss, validation_error, self.results_folder)) 313 | min_loss = validation_error 314 | 315 | if self.scheduler is not None: 316 | self.scheduler.step(validation_error) 317 | 318 | else: 319 | if epoch_running_loss/(batch_idx + 1) < min_loss: 320 | torch.save(self.model.state_dict(), os.path.join(self.results_folder, 'model')) 321 | print('Training loss improved from %f to %f, model saved to %s' 322 | % (min_loss, epoch_running_loss / (batch_idx + 1), self.results_folder)) 323 | min_loss = epoch_running_loss / (batch_idx + 1) 324 | 325 | if self.scheduler is not None: 326 | self.scheduler.step(epoch_running_loss / (batch_idx + 1)) 327 | 328 | # saving model and logs 329 | loss_df.to_csv(os.path.join(self.results_folder, 'loss.csv')) 330 | if epoch_idx % save_freq == 0: 331 | epoch_save_path = os.path.join(self.results_folder, '%d' % epoch_idx) 332 | chk_mkdir(epoch_save_path) 333 | torch.save(self.model.state_dict(), os.path.join(epoch_save_path, 'model')) 334 | if prediction_dataset: 335 | self.predict_large_images(prediction_dataset, epoch_save_path) 336 | 337 | self.model.train(False) 338 | 339 | del X_batch, y_batch 340 | 341 | return total_running_loss/n_batch 342 | 343 | def validate(self, dataset, n_batch=1): 344 | self.model.train(False) 345 | 346 | total_running_loss = 0 347 | for batch_idx, (X_batch, y_batch, name) in enumerate(DataLoader(dataset, batch_size=n_batch, shuffle=False)): 348 | 349 | if self.cuda_device: 350 | X_batch = Variable(X_batch.to(device=self.cuda_device)) 351 | y_batch = Variable(y_batch.to(device=self.cuda_device)) 352 | else: 353 | X_batch, y_batch = Variable(X_batch), Variable(y_batch) 354 | 355 | y_out = self.model(X_batch) 356 | training_loss = self.loss(y_out, y_batch) 357 | 358 | total_running_loss += training_loss.item() 359 | 360 | print('Validation loss: %f' % (total_running_loss / (batch_idx + 1))) 361 | self.model.train(True) 362 | 363 | del X_batch, y_batch 364 | 365 | return total_running_loss/(batch_idx + 1) 366 | 367 | def predict(self, dataset, export_path, channel=None): 368 | self.model.train(False) 369 | chk_mkdir(export_path) 370 | 371 | for batch_idx, (X_batch, image_filename) in enumerate(DataLoader(dataset, batch_size=1)): 372 | if self.cuda_device: 373 | X_batch = Variable(X_batch.to(device=self.cuda_device)) 374 | y_out = self.model(X_batch).cpu().data.numpy() 375 | else: 376 | X_batch = Variable(X_batch) 377 | y_out = self.model(X_batch).data.numpy() 378 | 379 | if channel: 380 | try: 381 | io.imsave(os.path.join(export_path, image_filename[0]), y_out[0, channel, :, :]) 382 | except: 383 | print('something went wrong upon prediction') 384 | 385 | else: 386 | try: 387 | io.imsave(os.path.join(export_path, image_filename[0]), y_out[0, :, :, :].transpose((1, 2, 0))) 388 | except: 389 | print('something went wrong upon prediction') 390 | 391 | def predict_large_images(self, dataset, export_path=None, channel=None, tile_res=(512, 512)): 392 | self.model.train(False) 393 | if export_path: 394 | chk_mkdir(export_path) 395 | 
else: 396 | results = [] 397 | 398 | for batch_idx, (X_batch, image_filename) in enumerate(DataLoader(dataset, batch_size=1)): 399 | out = self.predict_single_large_image(X_batch, channel=channel, tile_res=tile_res) 400 | 401 | if export_path: 402 | io.imsave(os.path.join(export_path, image_filename[0]), out) 403 | else: 404 | results.append(out) 405 | 406 | if not export_path: 407 | return results 408 | 409 | def predict_single_large_image(self, X_image, channel=None, tile_res=(512, 512)): 410 | image_res = X_image.shape 411 | # placeholder for output 412 | y_out_full = np.zeros(shape=(1, 3, image_res[2], image_res[3])) 413 | # generate tile coordinates 414 | tile_x = list(range(0, image_res[2], tile_res[0]))[:-1] + [image_res[2] - tile_res[0]] 415 | tile_y = list(range(0, image_res[3], tile_res[1]))[:-1] + [image_res[3] - tile_res[1]] 416 | tile = product(tile_x, tile_y) 417 | # predictions 418 | for slice in tile: 419 | 420 | if self.cuda_device: 421 | X_in = X_image[:, :, slice[0]:slice[0] + tile_res[0], slice[1]:slice[1] + tile_res[1]].to( 422 | device=self.cuda_device) 423 | X_in = Variable(X_in) 424 | else: 425 | X_in = X_image[:, :, slice[0]:slice[0] + tile_res[0], slice[1]:slice[1] + tile_res[1]] 426 | X_in = Variable(X_in) 427 | 428 | y_out = self.model(X_in).cpu().data.numpy() 429 | y_out_full[0, :, slice[0]:slice[0] + tile_res[0], slice[1]:slice[1] + tile_res[1]] = y_out 430 | 431 | # save image 432 | if channel: 433 | out = y_out_full[0, channel, :, :] 434 | else: 435 | out = y_out_full[0, :, :, :].transpose((1, 2, 0)) 436 | 437 | return out 438 | 439 | def measure_large_images(self, dataset, visualize_bboxes=False, filter=True, export_path=None, 440 | skeleton_method=skeletonize_3d, dpm=False, verbose=False, tile_res=(512, 512)): 441 | hypocotyl_lengths = dict() 442 | chk_mkdir(export_path) 443 | 444 | assert any(isinstance(dpm, tp) for tp in [str, bool, Number]), 'dpm must be string, bool or Number' 445 | 446 | for batch_idx, (X_batch, image_filename) in enumerate(DataLoader(dataset, batch_size=1)): 447 | 448 | if verbose: 449 | print("Measuring %s" % image_filename[0]) 450 | 451 | hypo_segmented = self.predict_single_large_image(X_batch, tile_res=tile_res) 452 | hypo_segmented_mask = hypo_segmented[:, :, 2] 453 | hypo_result, hypo_skeleton = get_hypo_rprops(hypo_segmented_mask, filter=filter, return_skeleton=True, 454 | skeleton_method=skeleton_method, 455 | dpm=dpm) 456 | hypo_df = hypo_result.make_df() 457 | 458 | hypocotyl_lengths[image_filename] = hypo_df 459 | 460 | if export_path: 461 | if visualize_bboxes: 462 | hypo_img = X_batch[0].cpu().data.numpy().transpose((1, 2, 0)) 463 | # original image 464 | visualize_regions(hypo_img, hypo_result, 465 | os.path.join(export_path, image_filename[0][:-4] + '.png')) 466 | # segmentation 467 | visualize_regions(hypo_segmented, hypo_result, 468 | os.path.join(export_path, image_filename[0][:-4] + '_segmentation.png'), 469 | bbox_color='0.5') 470 | # skeletonization 471 | visualize_regions(hypo_skeleton, hypo_result, 472 | os.path.join(export_path, image_filename[0][:-4] + '_skeleton.png')) 473 | 474 | hypocotyl_lengths[image_filename].to_csv(os.path.join(export_path, image_filename[0][:-4] + '.csv'), 475 | header=True, index=True) 476 | 477 | return hypocotyl_lengths 478 | 479 | def score_large_images(self, dataset, export_path, visualize_bboxes=False, visualize_histograms=False, 480 | visualize_segmentation=False, 481 | filter=True, skeletonized_gt=False, match_threshold=0.5, tile_res=(512, 512), 482 | dpm=False): 483 | 
chk_mkdir(export_path) 484 | 485 | scores = {} 486 | 487 | assert any(isinstance(dpm, tp) for tp in [str, bool, Number]), 'dpm must be string, bool or Number' 488 | 489 | if isinstance(dpm, str): 490 | dpm_df = pd.read_csv(dpm, header=None, index_col=0) 491 | 492 | for batch_idx, (X_batch, y_batch, image_filename) in enumerate(DataLoader(dataset, batch_size=1)): 493 | if isinstance(dpm, str): 494 | dpm_val = dpm_df.loc[image_filename].values[0] 495 | elif isinstance(dpm, Number) or dpm == False: 496 | dpm_val = dpm 497 | else: 498 | raise ValueError('dpm must be str, Number or False') 499 | 500 | # getting filter range 501 | if isinstance(filter, dict): 502 | filter_val = filter[image_filename[0]] 503 | else: 504 | filter_val = filter 505 | 506 | segmented_img = self.predict_single_large_image(X_batch, tile_res=tile_res) 507 | hypo_result_mask = segmented_img[:, :, 2] 508 | hypo_result, hypo_result_skeleton = get_hypo_rprops(hypo_result_mask, filter=filter_val, 509 | return_skeleton=True, dpm=dpm_val) 510 | hypo_result.make_df().to_csv(os.path.join(export_path, image_filename[0][:-4] + '_result.csv')) 511 | 512 | if visualize_segmentation: 513 | io.imsave(os.path.join(export_path, image_filename[0][:-4] + '_segmentation_skeletons.png'), 514 | img_as_uint(hypo_result_skeleton)) 515 | io.imsave(os.path.join(export_path, image_filename[0][:-4] + '_segmentation_hypo.png'), 516 | hypo_result_mask) 517 | io.imsave(os.path.join(export_path, image_filename[0][:-4] + '_segmentation_full.png'), 518 | segmented_img) 519 | 520 | if not skeletonized_gt: 521 | hypo_gt_mask = y_batch[0].data.numpy() == 2 522 | else: 523 | hypo_gt_mask = y_batch[0].data.numpy() > 0 524 | 525 | hypo_result_gt = get_hypo_rprops(hypo_gt_mask, filter=[20/dpm_val, np.inf], 526 | already_skeletonized=skeletonized_gt, dpm=dpm_val) 527 | hypo_result_gt.make_df().to_csv(os.path.join(export_path, image_filename[0][:-4] + '_gt.csv')) 528 | 529 | scores[image_filename[0]], objectwise_df = hypo_result.score(hypo_result_gt, 530 | match_threshold=match_threshold) 531 | objectwise_df.to_csv(os.path.join(export_path, image_filename[0][:-4] + '_matched.csv')) 532 | 533 | # visualization 534 | # histograms 535 | if visualize_histograms: 536 | hypo_result.hist(hypo_result_gt, 537 | os.path.join(export_path, image_filename[0][:-4] + '_hist.png')) 538 | 539 | # bounding boxes 540 | if visualize_bboxes: 541 | visualize_regions(hypo_gt_mask, hypo_result_gt, 542 | export_path=os.path.join(export_path, image_filename[0][:-4] + '_gt.png')) 543 | visualize_regions(hypo_result_skeleton, hypo_result, 544 | export_path=os.path.join(export_path, image_filename[0][:-4] + '_result.png')) 545 | 546 | score_df = pd.DataFrame(scores).T 547 | score_df.to_csv(os.path.join(export_path, 'scores.csv')) 548 | 549 | return scores 550 | 551 | def visualize_workflow(self, dataset, export_path, filter=False): 552 | for image, mask, image_filename in DataLoader(dataset, batch_size=1): 553 | image_filename_root = image_filename[0].split('.')[0] 554 | hypo_full_result = self.predict_single_large_image(image, tile_res=(512, 512)) 555 | hypo_result_mask = self.predict_single_large_image(image, channel=2, tile_res=(512, 512)) 556 | # save multiclass mask 557 | io.imsave(os.path.join(export_path, image_filename_root + '_1.png'), hypo_full_result) 558 | hypo_result, hypo_result_skeleton = get_hypo_rprops(hypo_result_mask, return_skeleton=True, filter=filter) 559 | io.imsave(os.path.join(export_path, image_filename_root + '_2.png'), img_as_uint(hypo_result_skeleton)) 560 | 
visualize_regions(image.data.cpu().numpy()[0].transpose((1, 2, 0)), hypo_result, 561 | os.path.join(export_path, image_filename_root + '_3.png')) 562 | 563 | 564 | if __name__ == '__main__': 565 | hypo_mask_img = io.imread('/home/namazu/Data/hypocotyl/measurement_test/masks/140925 8-4 050.png') 566 | hypo_result = get_hypo_rprops(hypo_mask_img, filter=False, already_skeletonized=True) 567 | hypo_result.score(hypo_result) 568 | --------------------------------------------------------------------------------