├── src
│   ├── __init__.py
│   ├── unet
│   │   ├── __init__.py
│   │   ├── unet.py
│   │   ├── blocks.py
│   │   ├── rprops.py
│   │   └── utils.py
│   ├── measure.py
│   ├── train.py
│   └── preprocessing
│       └── mask.py
├── .gitignore
├── docs
│   └── img
│       ├── step_02.png
│       ├── step_03.png
│       ├── step_04.png
│       ├── step_05.png
│       ├── step_06.png
│       ├── step_07.png
│       ├── step_08.png
│       ├── step_09.png
│       ├── step_10.png
│       ├── vis_bbox.png
│       ├── vis_seg.png
│       └── vis_skl.png
├── LICENSE
└── README.md
/src/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/unet/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | __pycache__
3 |
--------------------------------------------------------------------------------
/docs/img/step_02.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/biomag-lab/hypocotyl-UNet/HEAD/docs/img/step_02.png
--------------------------------------------------------------------------------
/docs/img/step_03.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/biomag-lab/hypocotyl-UNet/HEAD/docs/img/step_03.png
--------------------------------------------------------------------------------
/docs/img/step_04.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/biomag-lab/hypocotyl-UNet/HEAD/docs/img/step_04.png
--------------------------------------------------------------------------------
/docs/img/step_05.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/biomag-lab/hypocotyl-UNet/HEAD/docs/img/step_05.png
--------------------------------------------------------------------------------
/docs/img/step_06.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/biomag-lab/hypocotyl-UNet/HEAD/docs/img/step_06.png
--------------------------------------------------------------------------------
/docs/img/step_07.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/biomag-lab/hypocotyl-UNet/HEAD/docs/img/step_07.png
--------------------------------------------------------------------------------
/docs/img/step_08.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/biomag-lab/hypocotyl-UNet/HEAD/docs/img/step_08.png
--------------------------------------------------------------------------------
/docs/img/step_09.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/biomag-lab/hypocotyl-UNet/HEAD/docs/img/step_09.png
--------------------------------------------------------------------------------
/docs/img/step_10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/biomag-lab/hypocotyl-UNet/HEAD/docs/img/step_10.png
--------------------------------------------------------------------------------
/docs/img/vis_bbox.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/biomag-lab/hypocotyl-UNet/HEAD/docs/img/vis_bbox.png
--------------------------------------------------------------------------------
/docs/img/vis_seg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/biomag-lab/hypocotyl-UNet/HEAD/docs/img/vis_seg.png
--------------------------------------------------------------------------------
/docs/img/vis_skl.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/biomag-lab/hypocotyl-UNet/HEAD/docs/img/vis_skl.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 BIOMAG - Szeged
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/src/measure.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from argparse import ArgumentParser
4 |
5 | from unet.utils import *
6 | from unet.unet import UNet
7 |
8 | parser = ArgumentParser()
9 | parser.add_argument("--images_path", type=str, required=True)
10 | parser.add_argument("--model", type=str, required=True)
11 | parser.add_argument("--result_folder", type=str, required=True)
12 | parser.add_argument("--device", type=str, default='cpu')
13 | parser.add_argument("--min_object_size", type=float, default=0)
14 | parser.add_argument("--max_object_size", type=float, default=np.inf)
15 | parser.add_argument("--dpi", type=float, default=False)
16 | parser.add_argument("--dpm", type=float, default=False)
17 | parser.add_argument("--visualize", type=bool, default=False)
18 | args = parser.parse_args()
19 |
20 | # determining dpm
21 | dpm = args.dpm if not args.dpi else dpi_to_dpm(args.dpi)
22 |
23 | print("Loading dataset...")
24 | predict_dataset = ReadTestDataset(args.images_path)
25 | device = torch.device(args.device)
26 | print("Dataset loaded")
27 |
28 | print("Loading model...")
29 | unet = UNet(3, 3)
30 | unet.load_state_dict(torch.load(args.model, map_location=device))
31 | model = ModelWrapper(unet, args.result_folder, cuda_device=device)
32 | print("Model loaded")
33 | print("Measuring images...")
34 | model.measure_large_images(predict_dataset, export_path=args.result_folder,
35 | visualize_bboxes=args.visualize, filter=[args.min_object_size, args.max_object_size],
36 | dpm=dpm, verbose=True, tile_res=(256, 256))
37 |
--------------------------------------------------------------------------------
/src/unet/unet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .blocks import *
6 |
7 |
8 | class UNet(nn.Module):
9 | def __init__(self, in_channels, out_channels, conv_depths=[64, 128, 256, 512, 1024]):
10 | assert len(conv_depths) > 2, 'conv_depths must have at least 3 members'
11 |
12 | super(UNet, self).__init__()
13 |
14 | # defining encoder layers
15 | encoder_layers = []
16 | encoder_layers.append(First2D(in_channels, conv_depths[0], conv_depths[0]))
17 | encoder_layers.extend([Encoder2D(conv_depths[i], conv_depths[i + 1], conv_depths[i + 1])
18 | for i in range(len(conv_depths)-2)])
19 |
20 | # defining decoder layers
21 | decoder_layers = []
22 | decoder_layers.extend([Decoder2D(2 * conv_depths[i + 1], 2 * conv_depths[i], 2 * conv_depths[i], conv_depths[i])
23 | for i in reversed(range(len(conv_depths)-2))])
24 | decoder_layers.append(Last2D(conv_depths[1], conv_depths[0], out_channels))
25 |
26 | # encoder, center and decoder layers
27 | self.encoder_layers = nn.Sequential(*encoder_layers)
28 | self.center = Center2D(conv_depths[-2], conv_depths[-1], conv_depths[-1], conv_depths[-2])
29 | self.decoder_layers = nn.Sequential(*decoder_layers)
30 |
31 | def forward(self, x):
32 | x_enc = [x]
33 | for enc_layer in self.encoder_layers:
34 | x_enc.append(enc_layer(x_enc[-1]))
35 |
36 | x_dec = [self.center(x_enc[-1])]
37 | for dec_layer_idx, dec_layer in enumerate(self.decoder_layers):
38 | x_opposite = x_enc[-1-dec_layer_idx]
39 | x_cat = torch.cat(
40 | [pad_to_shape(x_dec[-1], x_opposite.shape), x_opposite],
41 | dim=1
42 | )
43 | x_dec.append(dec_layer(x_cat))
44 |
45 | return x_dec[-1]
46 |
47 |
48 | def pad_to_shape(this, shp):
49 | """
50 | Pads this image with zeroes to shp.
51 | Args:
52 | this: image tensor to pad
53 | shp: desired output shape
54 |
55 | Returns:
56 | Zero-padded tensor of shape shp.
57 | """
58 | if len(shp) == 4:
59 | pad = (0, shp[3] - this.shape[3], 0, shp[2] - this.shape[2])
60 | elif len(shp) == 5:
61 | pad = (0, shp[4] - this.shape[4], 0, shp[3] - this.shape[3], 0, shp[2] - this.shape[2])
62 | return F.pad(this, pad)
63 |
64 |
65 | if __name__ == '__main__':
66 |     device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
67 |     unet = UNet(3, 2, conv_depths=[10, 20, 30, 40]).to(device)
68 |     x = torch.rand(1, 3, 128, 128).to(device)
69 |     y = unet(x)
70 |     print(y.shape)
71 |
--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | from argparse import ArgumentParser
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.optim as optim
7 |
8 | from unet.unet import UNet
9 | from unet.utils import *
10 |
11 |
12 | parser = ArgumentParser()
13 | parser.add_argument("--train_dataset", type=str, required=True)
14 | parser.add_argument("--val_dataset", type=str, default=None)
15 | parser.add_argument("--epochs", type=int, default=1000)
16 | parser.add_argument("--batch_size", type=int, default=1)
17 | parser.add_argument("--model_name", type=str, default="UNet-hypocotyl")
18 | parser.add_argument("--trained_model_path", type=str, default=None)
19 | parser.add_argument("--initial_lr", type=float, default=1e-3)
20 | parser.add_argument("--device", type=str, default="cpu")
21 | parser.add_argument("--model_save_freq", type=int, default=200)
22 |
23 | args = parser.parse_args()
24 |
25 | tf_train = make_transform(crop=(512, 512), long_mask=True)
26 | tf_validate = make_transform(crop=(512, 512), long_mask=True, rotate_range=False,
27 | p_flip=0.0, normalize=False, color_jitter_params=None)
28 |
29 | # load dataset
30 | train_dataset_path = args.train_dataset
31 | train_dataset = ReadTrainDataset(train_dataset_path, transform=tf_train)
32 |
33 | if args.val_dataset is not None:
34 | validate_dataset_path = args.val_dataset
35 | validate_dataset = ReadTrainDataset(validate_dataset_path, transform=tf_validate)
36 | else:
37 | validate_dataset = None
38 |
43 | # creating checkpoint folder
44 | model_name = args.model_name
45 | file_dir = os.path.split(os.path.realpath(__file__))[0]
46 | results_folder = os.path.join(file_dir, '..', 'checkpoints', model_name)
47 | if not os.path.exists(results_folder):
48 | os.makedirs(results_folder)
49 |
50 | # load model
51 | unet = UNet(3, 3)
52 | if args.trained_model_path is not None:
53 | unet.load_state_dict(torch.load(args.trained_model_path))
54 |
55 | loss = SoftDiceLoss(weight=torch.Tensor([1, 5, 5]))
56 | optimizer = optim.Adam(unet.parameters(), lr=args.initial_lr)
57 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=50, verbose=True)
58 |
59 | cuda_device = torch.device(args.device)
60 | model = ModelWrapper(unet, loss=loss, optimizer=optimizer, scheduler=scheduler,
61 | results_folder=results_folder, cuda_device=cuda_device)
62 |
63 | model.train_model(train_dataset, validation_dataset=validate_dataset,
64 | n_batch=args.batch_size, n_epochs=args.epochs,
65 | verbose=False, save_freq=args.model_save_freq)
66 |
--------------------------------------------------------------------------------
/src/unet/blocks.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class First2D(nn.Module):
5 | def __init__(self, in_channels, middle_channels, out_channels, dropout=False):
6 | super(First2D, self).__init__()
7 |
8 | layers = [
9 | nn.Conv2d(in_channels, middle_channels, kernel_size=3, padding=1),
10 | nn.BatchNorm2d(middle_channels),
11 | nn.ReLU(inplace=True),
12 | nn.Conv2d(middle_channels, out_channels, kernel_size=3, padding=1),
13 | nn.BatchNorm2d(out_channels),
14 | nn.ReLU(inplace=True)
15 | ]
16 |
17 | if dropout:
18 | assert 0 <= dropout <= 1, 'dropout must be between 0 and 1'
19 | layers.append(nn.Dropout2d(p=dropout))
20 |
21 | self.first = nn.Sequential(*layers)
22 |
23 | def forward(self, x):
24 | return self.first(x)
25 |
26 |
27 | class Encoder2D(nn.Module):
28 | def __init__(
29 | self, in_channels, middle_channels, out_channels,
30 | dropout=False, downsample_kernel=2
31 | ):
32 | super(Encoder2D, self).__init__()
33 |
34 | layers = [
35 | nn.MaxPool2d(kernel_size=downsample_kernel),
36 | nn.Conv2d(in_channels, middle_channels, kernel_size=3, padding=1),
37 | nn.BatchNorm2d(middle_channels),
38 | nn.ReLU(inplace=True),
39 | nn.Conv2d(middle_channels, out_channels, kernel_size=3, padding=1),
40 | nn.BatchNorm2d(out_channels),
41 | nn.ReLU(inplace=True)
42 | ]
43 |
44 | if dropout:
45 | assert 0 <= dropout <= 1, 'dropout must be between 0 and 1'
46 | layers.append(nn.Dropout2d(p=dropout))
47 |
48 | self.encoder = nn.Sequential(*layers)
49 |
50 | def forward(self, x):
51 | return self.encoder(x)
52 |
53 |
54 | class Center2D(nn.Module):
55 | def __init__(self, in_channels, middle_channels, out_channels, deconv_channels, dropout=False):
56 | super(Center2D, self).__init__()
57 |
58 | layers = [
59 | nn.MaxPool2d(kernel_size=2),
60 | nn.Conv2d(in_channels, middle_channels, kernel_size=3, padding=1),
61 | nn.BatchNorm2d(middle_channels),
62 | nn.ReLU(inplace=True),
63 | nn.Conv2d(middle_channels, out_channels, kernel_size=3, padding=1),
64 | nn.BatchNorm2d(out_channels),
65 | nn.ReLU(inplace=True),
66 | nn.ConvTranspose2d(out_channels, deconv_channels, kernel_size=2, stride=2)
67 | ]
68 |
69 | if dropout:
70 | assert 0 <= dropout <= 1, 'dropout must be between 0 and 1'
71 | layers.append(nn.Dropout2d(p=dropout))
72 |
73 | self.center = nn.Sequential(*layers)
74 |
75 | def forward(self, x):
76 | return self.center(x)
77 |
78 |
79 | class Decoder2D(nn.Module):
80 | def __init__(self, in_channels, middle_channels, out_channels, deconv_channels, dropout=False):
81 | super(Decoder2D, self).__init__()
82 |
83 | layers = [
84 | nn.Conv2d(in_channels, middle_channels, kernel_size=3, padding=1),
85 | nn.BatchNorm2d(middle_channels),
86 | nn.ReLU(inplace=True),
87 | nn.Conv2d(middle_channels, out_channels, kernel_size=3, padding=1),
88 | nn.BatchNorm2d(out_channels),
89 | nn.ReLU(inplace=True),
90 | nn.ConvTranspose2d(out_channels, deconv_channels, kernel_size=2, stride=2)
91 | ]
92 |
93 | if dropout:
94 | assert 0 <= dropout <= 1, 'dropout must be between 0 and 1'
95 | layers.append(nn.Dropout2d(p=dropout))
96 |
97 | self.decoder = nn.Sequential(*layers)
98 |
99 | def forward(self, x):
100 | return self.decoder(x)
101 |
102 |
103 | class Last2D(nn.Module):
104 | def __init__(self, in_channels, middle_channels, out_channels, softmax=False):
105 | super(Last2D, self).__init__()
106 |
107 | layers = [
108 | nn.Conv2d(in_channels, middle_channels, kernel_size=3, padding=1),
109 | nn.BatchNorm2d(middle_channels),
110 | nn.ReLU(inplace=True),
111 | nn.Conv2d(middle_channels, middle_channels, kernel_size=3, padding=1),
112 | nn.BatchNorm2d(middle_channels),
113 | nn.ReLU(inplace=True),
114 | nn.Conv2d(middle_channels, out_channels, kernel_size=1),
115 | nn.Softmax(dim=1)
116 | ]
117 |
118 | self.first = nn.Sequential(*layers)
119 |
120 | def forward(self, x):
121 | return self.first(x)
122 |
--------------------------------------------------------------------------------
/src/preprocessing/mask.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from skimage import io
4 | from itertools import product
5 | from collections import defaultdict
6 | from shutil import copyfile
7 | from argparse import ArgumentParser
8 |
9 |
10 | def chk_mkdir(*args):
11 | for path in args:
12 | if not os.path.exists(path):
13 | os.makedirs(path)
14 |
15 |
16 | def make_mask(hypo_mask_path, nonhypo_mask_path, export_path=None):
17 | hypo, nonhypo = io.imread(hypo_mask_path), io.imread(nonhypo_mask_path)
18 |     hypo[hypo == 255] = 2        # hypocotyl pixels -> label 2
19 |     nonhypo[nonhypo == 255] = 1  # non-hypocotyl plant pixels -> label 1
20 |     mask = np.maximum(hypo, nonhypo)  # merged mask: 0 = background, overlaps keep the larger label
21 | if export_path:
22 | io.imsave(export_path, mask)
23 | else:
24 | return mask
25 |
26 |
27 | def make_patches(dataset_path, export_path, patch_size=(512, 512), no_overlap=False):
28 | """
29 | Takes the data folder CONTAINING MERGED MASKS and slices the
30 | images and masks into patches.
31 | """
32 | # make output directories
33 | dataset_images_path = os.path.join(dataset_path, 'images')
34 | dataset_masks_path = os.path.join(dataset_path, 'masks')
35 | new_images_path = os.path.join(export_path, 'images')
36 | new_masks_path = os.path.join(export_path, 'masks')
37 |
38 | chk_mkdir(new_masks_path, new_images_path)
39 |
40 | for image_filename in os.listdir(dataset_images_path):
41 | # reading images
42 | im = io.imread(os.path.join(dataset_images_path, image_filename))
43 | masked_im = io.imread(os.path.join(dataset_masks_path, image_filename))
44 | # make new folders
45 |
46 | x_start = list()
47 | y_start = list()
48 |
49 | if no_overlap:
50 | x_step = patch_size[0]
51 | y_step = patch_size[1]
52 | else:
53 | x_step = patch_size[0] // 2
54 | y_step = patch_size[1] // 2
55 |
56 | for x_idx in range(0, im.shape[0] - patch_size[0] + 1, x_step):
57 | x_start.append(x_idx)
58 |
59 | if im.shape[0] - patch_size[0] - 1 > 0:
60 | x_start.append(im.shape[0] - patch_size[0] - 1)
61 |
62 | for y_idx in range(0, im.shape[1] - patch_size[1] + 1, y_step):
63 | y_start.append(y_idx)
64 |
65 | if im.shape[1] - patch_size[1] - 1 > 0:
66 | y_start.append(im.shape[1] - patch_size[1] - 1)
67 |
68 | for num, (x_idx, y_idx) in enumerate(product(x_start, y_start)):
69 | new_image_filename = os.path.splitext(image_filename)[0] + '_%d.png' % num
70 | # saving a patch of the original image
71 | io.imsave(
72 | os.path.join(new_images_path, new_image_filename),
73 | im[x_idx:x_idx + patch_size[0], y_idx:y_idx + patch_size[1], :]
74 | )
75 | # saving the corresponding patch of the mask
76 | io.imsave(
77 | os.path.join(new_masks_path, new_image_filename),
78 | masked_im[x_idx:x_idx + patch_size[0], y_idx:y_idx + patch_size[1]]
79 | )
80 |
81 |
82 | def train_test_validate_split(data_path, export_path, ratios=[0.6, 0.2, 0.2]):
83 | dst_path = defaultdict(dict)
84 | for dataset, data_type in product(['train', 'test', 'validate'], ['images', 'masks']):
85 | set_type_path = os.path.join(export_path, dataset, data_type)
86 | dst_path[dataset][data_type] = set_type_path
87 | chk_mkdir(set_type_path)
88 |
89 | for image_filename in os.listdir(os.path.join(data_path, 'images')):
90 | src_path = {
91 | 'images': os.path.join(data_path, 'images', image_filename),
92 | 'masks': os.path.join(data_path, 'masks', image_filename)
93 | }
94 |
95 | dataset = np.random.choice(['train', 'test', 'validate'], p=ratios)
96 |
97 | for data_type in ['images', 'masks']:
98 | copyfile(src_path[data_type], os.path.join(dst_path[dataset][data_type], image_filename))
99 |
100 |
101 | def imageJ_elementwise_mask_to_png(elementwise_mask_path):
102 | img = io.imread(elementwise_mask_path)
103 | folder, fname = os.path.split(elementwise_mask_path)
104 | fname_root, ext = os.path.splitext(fname)
105 | io.imsave(os.path.join(folder, fname_root + '.png'), img)
106 |
107 |
108 | if __name__ == '__main__':
109 |
110 | parser = ArgumentParser()
111 | parser.add_argument("--images_folder", required=True, type=str)
112 | parser.add_argument("--export_folder", required=True, type=str)
113 | parser.add_argument("--make_patches", type=(lambda x: str(x).lower() == 'true'), default=False)
114 |
115 | args = parser.parse_args()
116 |
117 | images_folder = args.images_folder
118 | export_root = args.export_folder
119 |
120 | unet_ready_folder = os.path.join(export_root, 'converted')
121 |
122 | chk_mkdir(export_root, os.path.join(unet_ready_folder, 'images'),
123 | os.path.join(unet_ready_folder, 'masks'))
124 |
125 | for image_name in os.listdir(images_folder):
126 | hypo_mask_path = os.path.join(images_folder, image_name, '%s-hypo.png' % image_name)
127 | nonhypo_mask_path = os.path.join(images_folder, image_name, '%s-nonhypo.png' % image_name)
128 | export_path = os.path.join(unet_ready_folder, 'masks', '%s.png' % image_name)
129 | make_mask(hypo_mask_path, nonhypo_mask_path, export_path)
130 | copyfile(os.path.join(images_folder, image_name, image_name + '.png'),
131 | os.path.join(unet_ready_folder, 'images', image_name + '.png'))
132 |
133 | if args.make_patches:
134 | patches_export_folder = os.path.join(export_root, 'patched_images')
135 | chk_mkdir(patches_export_folder)
136 | make_patches(unet_ready_folder, patches_export_folder, (800, 800), no_overlap=True)
137 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # A deep learning-based approach for high throughput plant phenotyping
2 |
3 | This repository is the companion for the paper [A deep learning-based approach for high throughput plant phenotyping,
4 | Dobos et al.](http://www.plantphysiol.org/content/181/4/1415).
5 |
6 | [The dataset used in the article can be found at this page](https://www.kaggle.com/tivadardanka/plant-segmentation/).
7 |
8 | [The pretrained model used in the research article can be downloaded here.](https://drive.google.com/open?id=1SlUui64l-k63vxysl0YAflKaECfpj8Rr)
9 |
10 | ## Contents
11 | - [Usage, dependencies](#usage-dependencies)
12 | - [Using a trained model for measuring hypocotyls](#using-a-trained-model-for-measuring-hypocotyls)
13 | - [Generating your own training data](#generating-your-own-training-data)
14 | - [Training a model on your own images](#training-a-model-on-your-own-images)
15 |
16 | ## Usage, dependencies
17 | For measuring hypocotyls and training a custom model, it is required to have
18 | - Python >= 3.5
19 | - PyTorch >= 0.4
20 | - NumPy >= 1.13
21 | - Pandas >= 0.23
22 | - scikit-image >= 0.14
23 | - matplotlib >= 3.0
24 |
25 | To use the hypocotyl segmentation tool, clone the repository to the local machine:
26 | ```bash
27 | git clone https://github.com/biomag-lab/hypocotyl-UNet
28 | ```
29 | `src/measure.py` can be used to apply the measuring algorithm to custom images, while `src/train.py` is for training the UNet model on custom annotated data. (Detailed descriptions of both can be found below.)
30 |
31 | ## Using a trained model for measuring hypocotyls
32 | To apply the algorithm on custom images, the folder containing the images should be organized into the following directory structure:
33 | ```bash
34 | images_folder
35 | |-- images
36 | |-- img001.png
37 | |-- img002.png
38 | |-- ...
39 | ```
40 | The `src/measure.py` script can be used to run the algorithm. The required arguments are
41 | - `--images_path`: path to the images folder, which must have the structure outlined above.
42 | - `--model`: path to the UNet model used in the algorithm.
43 | - `--result_folder`: path to the folder where results will be exported.
44 |
45 | [The model used in the research article can be found here.](https://drive.google.com/open?id=1SlUui64l-k63vxysl0YAflKaECfpj8Rr)
46 |
47 | Additionally, you can specify the following (an extended example using these options follows the basic example below):
48 | - `--min_object_size`: the expected minimum object size in pixels. Default is 0. Detected objects below this size will be filtered out.
49 | - `--max_object_size`: the expected maximum object size in pixels. Default is `np.inf`. Detected objects above this size will be filtered out.
50 | - `--dpi` and `--dpm`: to export the lengths in *mm*, it is required to provide a *dpi* (dots per inch) or *dpm* (dots per millimeter) value. If either of these is provided, the pixel units will be converted to *mm* during measurement (1 inch = 25.4 mm, so dpm = dpi / 25.4). If both *dpi* and *dpm* are set, `measure.py` gives precedence to the *dpi* value.
51 | - `--device`: device to be used for the UNet prediction. Default is `cpu`, but if a GPU with the CUDA framework installed is available, `cuda:$ID` can be used, where `$ID` is the ID of the GPU. For example, `cuda:0`. (For PyTorch users: this argument is passed directly to `torch.device()` during initialization, and the resulting device is used for the rest of the workflow.)
52 | - `--visualize`: set to True to export a visualization of the results. For each measured image, the following images are exported along with the length measurements.
53 | 
54 | 
55 | 
56 |
59 | For instance, a basic invocation looks like this:
60 | ```bash
61 | python3 measure.py --images_path path_to_images \
62 | --model ../models/unet \
63 | --result_folder path_to_results \
64 | --device cuda:0
65 | ```
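
To also filter detections by size and export the visualizations described above, the optional arguments can be appended. A hypothetical example (paths and values are placeholders to adapt to your own data):
```bash
python3 measure.py --images_path path_to_images \
    --model ../models/unet \
    --result_folder path_to_results \
    --device cuda:0 \
    --dpm 20 \
    --min_object_size 50 \
    --max_object_size 5000 \
    --visualize True
```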
66 |
67 | ## Generating your own training data
68 | In case the algorithm performs poorly, for instance if the images were taken under very different conditions than the ones seen by the available model during training, the UNet backbone model can be retrained on custom data. To do this, the images first have to be *annotated*, which can be done easily with ImageJ.
69 |
70 | Step 1. Organize the images to be annotated such that each image is contained in a separate folder under a common root, and each folder is named after its image. For example:
71 | ```bash
72 | images_folder
73 | |-- image_1
74 | |-- image_1.png
75 | |-- image_2
76 | |-- image_2.png
77 | |-- ...
78 | ```
79 |
80 | Step 2. Open the image to be annotated in ImageJ.
81 | 
82 |
83 | Step 3. Open the *ROI manager* tool.
84 | 
85 |
86 | Step 4. Select an appropriate selection tool, for example *Freehand selections*.
87 | 
88 |
89 | Step 5. Draw the outline of the part of the plant which should be included in the measurements. Press *t* or click the *Add [t]* button to add the selection to the ROI manager.
90 | 
91 |
92 | Step 6. Repeat the outlining for every object to be measured, adding the selections one by one. When done, select all of them in the ROI manager and click *More > OR (Combine)*.
93 | 
94 |
95 | Step 7. Press *Edit > Selection > Create Mask*. This will open up a new image.
96 | 
97 |
98 | Step 8. Invert the mask image by pressing *Edit > Invert* or pressing *Ctrl + Shift + I*.
99 | 
100 |
101 | Step 9. Save the mask image to the folder with the original. Add the suffix *-hypo* to the name. For example, if the original name was *image_1.png*, this image should be named *image_1-hypo.png*.
102 | 
103 |
104 | Step 10. Repeat the same process to annotate the parts of the plant which should not be included in the measurements. Save the mask to the same folder. Add the suffix *-nonhypo* to the name.
105 | 
106 |
107 | Step 11. Create the training data for the algorithm by running the `src/preprocessing/mask.py` script (an example command follows the argument list). The required arguments are
108 | - `--images_folder`: path to the folder where the folders containing the images and masks are located.
109 | - `--export_folder`: path to the folder where the results should be exported.
110 | - `--make_patches`: set to True to slice the images and masks into smaller patches. Recommended for training.
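
A minimal example invocation (run from the `src` folder; the paths are placeholders), assuming the annotated folders follow the layout from Step 1:
```bash
python3 preprocessing/mask.py --images_folder path_to_annotated_images \
    --export_folder path_to_training_data \
    --make_patches True
```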
111 |
112 | ## Training a model on your own images
113 | If custom annotated data is available, the containing folder should be organized into the following directory structure:
114 | ```bash
115 | images_folder
116 | |-- images
117 | |-- img001.png
118 | |-- img002.png
119 | |-- ...
120 | |-- masks
121 | |-- img001.png
122 | |-- img002.png
123 | |-- ...
124 | ```
125 | The mask images should have names identical to their corresponding images. After the training data is organized, the `src/train.py` script can be used to train a custom UNet model. The required argument is
126 | - `--train_dataset`: path to the folder where the training data is located. This should match the `--export_folder` argument given to the `src/preprocessing/mask.py` script in Step 11 of the previous section.
127 |
128 | The optional arguments are (an example command follows the list)
129 | - `--epochs`: the number of epochs during training. Default is 1000, but this is very dependent on the dataset and data augmentation method used.
130 | - `--batch_size`: the batch size used during training. Default is 1. If a GPU is used, it is recommended to use the largest batch size that fits into GPU memory.
131 | - `--initial_lr`: initial learning rate. Default is 1e-3, which proved to be the best for the training dataset used. (Different datasets might have more optimal initial learning rates.)
132 | - `--model_name`: name of the model. Default is *UNet-hypocotyl*.
133 | - `--trained_model_path`: to continue training of a previously trained model, its path can be given.
134 | - `--val_dataset`: path to the folder where the validation data is located. During training, the validation loss is used to monitor overfitting, drive the learning rate schedule and save the best-performing model.
135 | - `--model_save_freq`: frequency of model saving. Default is 200, which means that the model is saved after every 200th epoch.
136 | - `--device`: device to be used for the UNet training. Default is `cpu`, but if a GPU with the CUDA framework installed is available, `cuda:$ID` can be used, where `$ID` is the ID of the GPU. For example, `cuda:0`. (For PyTorch users: this argument is passed directly to `torch.device()` during initialization, and the resulting device is used for the rest of the workflow.)
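
A minimal training command might look like the following (run from the `src` folder; the paths are placeholders, and the device and batch size should be adapted to your hardware):
```bash
python3 train.py --train_dataset path_to_training_data \
    --val_dataset path_to_validation_data \
    --batch_size 4 \
    --device cuda:0
```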
137 |
138 |
--------------------------------------------------------------------------------
/src/unet/rprops.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import pandas as pd
4 |
5 | import matplotlib.pyplot as plt
6 | import matplotlib.patches as patches
7 |
8 | from skimage import io, img_as_uint
9 | from skimage.morphology import skeletonize, medial_axis, skeletonize_3d
10 | from skimage.measure import regionprops, label
11 | from skimage.filters import threshold_otsu
12 | from skimage.measure._regionprops import _RegionProperties
13 |
14 | from typing import Container
15 | from numbers import Number
16 |
17 |
18 | class BBox:
19 | def __init__(self, rprops_bbox):
20 | min_row, min_col, max_row, max_col = rprops_bbox
21 | # regionprops bbox representation
22 | self.min_row = min_row
23 | self.min_col = min_col
24 | self.max_row = max_row
25 | self.max_col = max_col
26 | self.bbox = rprops_bbox
27 |
28 | # rectangle representation
29 | self.x, self.y = min_col, min_row
30 | self.width = max_col - min_col
31 | self.height = max_row - min_row
32 |
33 | # coordinate representation
34 | self.P1 = (min_col, min_row)
35 | self.P2 = (max_col, min_row)
36 | self.P3 = (min_col, max_row)
37 | self.P4 = (max_col, max_row)
38 |
39 | def __repr__(self):
40 | return str(self.bbox)
41 |
42 | def __getitem__(self, item):
43 | return self.bbox[item]
44 |
45 | def IOU(self, other_bbox):
46 | # determining the intersection coordinates
47 | P1_int = (max(self.P1[0], other_bbox.P1[0]),
48 | max(self.P1[1], other_bbox.P1[1]))
49 | P4_int = (min(self.P4[0], other_bbox.P4[0]),
50 | min(self.P4[1], other_bbox.P4[1]))
51 |
52 | # check for intersections
53 | if (P1_int[0] > P4_int[0]) or (P1_int[1] > P4_int[1]):
54 | return 0
55 |
56 | intersection_area = (P4_int[0] - P1_int[0]) * (P4_int[1] - P1_int[1])
57 | union_area = self.area() + other_bbox.area() - intersection_area
58 |
59 | return intersection_area / union_area
60 |
61 | def area(self):
62 | return self.width * self.height
63 |
64 |
65 | class Hypo:
66 | def __init__(self, rprops, dpm=False):
67 | self.length = rprops.area
68 | if dpm:
69 | self.length /= dpm
70 | self.bbox = BBox(rprops.bbox)
71 |
72 | def __repr__(self):
73 | return "[%d, %s]" % (self.length, self.bbox)
74 |
75 | def IOU(self, other_hypo):
76 | return self.bbox.IOU(other_hypo.bbox)
77 |
78 |
79 | class HypoResult:
80 | def __init__(self, rprops_or_hypos, dpm=False):
81 | if isinstance(rprops_or_hypos[0], Hypo):
82 | self.hypo_list = rprops_or_hypos
83 | elif isinstance(rprops_or_hypos[0], _RegionProperties):
84 | self.hypo_list = [Hypo(rprops, dpm) for rprops in rprops_or_hypos]
85 | self.gt_match = None
86 |
87 | def __getitem__(self, item):
88 | if isinstance(item, Number):
89 | return self.hypo_list[item]
90 | if isinstance(item, Container):
91 | # check the datatype of the list
92 | if isinstance(item[0], np.bool_):
93 | item = [idx for idx, val in enumerate(item) if val]
94 | return HypoResult([self.hypo_list[idx] for idx in item])
95 |
96 | def __len__(self):
97 | return len(self.hypo_list)
98 |
99 | def mean(self):
100 | return np.mean([hypo.length for hypo in self.hypo_list])
101 |
102 | def std(self):
103 | return np.std([hypo.length for hypo in self.hypo_list])
104 |
105 | def score(self, gt_hyporesult, match_threshold=0.5):
106 | scores = []
107 | hypo_ious = np.zeros((len(self), len(gt_hyporesult)))
108 | objectwise_df = pd.DataFrame(columns=['algorithm', 'ground truth'], index=range(len(gt_hyporesult)))
109 |
110 | for hypo_idx, hypo in enumerate(self.hypo_list):
111 | hypo_ious[hypo_idx] = np.array([hypo.IOU(gt_hypo) for gt_hypo in gt_hyporesult])
112 | best_match = np.argmax(hypo_ious[hypo_idx])
113 |
114 | # a match is found if the intersection over union metric is
115 | # larger than the given threshold
116 | if hypo_ious[hypo_idx][best_match] > match_threshold:
117 | # calculate the accuracy of the measurement
118 | gt_hypo = gt_hyporesult[best_match]
119 | error = abs(hypo.length - gt_hypo.length)
120 | scores.append(1 - error/gt_hypo.length)
121 |
122 | gt_hypo_ious = hypo_ious.T
123 | for gt_hypo_idx, gt_hypo in enumerate(gt_hyporesult):
124 | objectwise_df.loc[gt_hypo_idx, 'ground truth'] = gt_hypo.length
125 | best_match = np.argmax(gt_hypo_ious[gt_hypo_idx])
126 | if gt_hypo_ious[gt_hypo_idx][best_match] > match_threshold:
127 | objectwise_df.loc[gt_hypo_idx, 'algorithm'] = self.hypo_list[best_match].length
128 |
129 | # precision, recall
130 | self.gt_match = np.apply_along_axis(np.any, 0, hypo_ious > match_threshold)
131 | self.match = np.apply_along_axis(np.any, 1, hypo_ious > match_threshold)
132 | # identified_objects = self[self.match]
133 | true_positives = self.gt_match.sum()
134 | precision = true_positives/len(self)
135 | recall = true_positives/len(gt_hyporesult)
136 |
137 | score_dict = {'accuracy': np.mean(scores),
138 | 'precision': precision,
139 | 'recall': recall,
140 | 'gt_mean': gt_hyporesult.mean(),
141 | 'result_mean': self.mean(),
142 | 'gt_std': gt_hyporesult.std(),
143 | 'result_std': self.std()}
144 |
145 | return score_dict, objectwise_df
146 |
147 | def make_df(self):
148 | result_df = pd.DataFrame(
149 | [[hypo.length, *hypo.bbox] for hypo in self.hypo_list],
150 | columns=['length', 'min_row', 'min_col', 'max_row', 'max_col'],
151 | index=range(1, len(self)+1)
152 | )
153 |
154 | return result_df
155 |
156 | def hist(self, gt_hyporesult, export_path):
157 | lengths = [hypo.length for hypo in self.hypo_list]
158 | gt_lengths = [hypo.length for hypo in gt_hyporesult]
159 | histogram_bins = range(0, 500, 10)
160 |
161 | with plt.style.context('seaborn-white'):
162 | plt.figure(figsize=(10, 15))
163 | plt.hist(lengths, bins=histogram_bins, color='r', alpha=0.2, label='result')
164 | plt.hist(gt_lengths, bins=histogram_bins, color='b', alpha=0.2, label='ground truth')
165 | plt.legend()
166 | plt.savefig(export_path)
167 | plt.close('all')
168 |
169 | def filter(self, flt):
170 | if isinstance(flt, Container):
171 | min_length, max_length = flt
172 | self.hypo_list = [h for h in self.hypo_list if min_length <= h.length <= max_length]
173 | elif isinstance(flt, bool) and flt:
174 | otsu_thresh = threshold_otsu(np.array([h.length for h in self.hypo_list]))
175 | self.hypo_list = [h for h in self.hypo_list if otsu_thresh <= h.length]
176 |
177 |
178 | def bbox_to_rectangle(bbox):
179 | # bbox format: 'min_row', 'min_col', 'max_row', 'max_col'
180 | # Rectangle format: bottom left (x, y), width, height
181 | min_row, min_col, max_row, max_col = bbox
182 | x, y = min_col, min_row
183 | width = max_col - min_col
184 | height = max_row - min_row
185 |
186 | return (x, y), width, height
187 |
188 |
189 | def get_hypo_rprops(hypo, filter=True, already_skeletonized=False, skeleton_method=skeletonize_3d,
190 | return_skeleton=False, dpm=False):
191 | """
192 | Args:
193 | hypo: segmented hypocotyl image
194 | filter: boolean or list of [min_length, max_length]
195 | """
196 | hypo_thresh = (hypo > 0.5)
197 | if not already_skeletonized:
198 | hypo_skeleton = label(img_as_uint(skeleton_method(hypo_thresh)))
199 | else:
200 | hypo_skeleton = label(img_as_uint(hypo_thresh))
201 |
202 | hypo_rprops = regionprops(hypo_skeleton)
203 | # filter out small regions
204 | hypo_result = HypoResult(hypo_rprops, dpm)
205 | hypo_result.filter(flt=filter)
206 |
207 | if return_skeleton:
208 | return hypo_result, hypo_skeleton > 0
209 |
210 | return hypo_result
211 |
212 |
213 | def visualize_regions(hypo_img, hypo_result, export_path=None, bbox_color='r', dpi=800):
214 | with plt.style.context('seaborn-white'):
215 | # parameters
216 | fontsize = 3.0 * (800.0 / dpi)
217 | linewidth = fontsize / 10.0
218 |
219 | figsize = (hypo_img.shape[1]/dpi, hypo_img.shape[0]/dpi)
220 | fig = plt.figure(figsize=figsize, dpi=dpi)
221 | ax = plt.Axes(fig, [0,0,1,1]) #plt.subplot(111)
222 | fig.add_axes(ax)
223 | ax.imshow(hypo_img)
224 | for hypo_idx, hypo in enumerate(hypo_result):
225 | rectangle = patches.Rectangle((hypo.bbox.x, hypo.bbox.y), hypo.bbox.width, hypo.bbox.height,
226 | linewidth=linewidth, edgecolor=bbox_color, facecolor='none')
227 | ax.add_patch(rectangle)
228 | ax.text(hypo.bbox.x, hypo.bbox.y - linewidth - 24, "N.%d." % (hypo_idx+1), fontsize=fontsize, color='k')
229 | ax.text(hypo.bbox.x, hypo.bbox.y - linewidth, str(hypo.length)[:4], fontsize=fontsize, color=bbox_color)
230 |
231 | fig.axes[0].get_xaxis().set_visible(False)
232 | fig.axes[0].get_yaxis().set_visible(False)
233 |
234 | if export_path is None:
235 | plt.show()
236 | else:
237 | plt.savefig(export_path, dpi=dpi)
238 |
239 | plt.close('all')
240 |
--------------------------------------------------------------------------------
/src/unet/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import pandas as pd
4 | import torch
5 |
6 | from skimage import io, img_as_uint
7 | from skimage.morphology import skeletonize_3d
8 |
9 | from numbers import Number
10 | from itertools import product
11 |
12 | from torch.autograd import Variable
13 | from torch.utils.data import DataLoader, Dataset
14 | from torch.nn.functional import cross_entropy
15 | from torch.nn.modules.loss import _WeightedLoss
16 | from torchvision import transforms as T
17 | from torchvision.transforms import functional as F
18 |
19 | from .rprops import get_hypo_rprops, visualize_regions
20 |
21 |
22 | def chk_mkdir(*args):
23 | for path in args:
24 | if path is not None and not os.path.exists(path):
25 | os.makedirs(path)
26 |
27 |
28 | def dpi_to_dpm(dpi):
29 | # small hack, default value for dpi is False
30 | if not dpi:
31 | return False
32 |
33 | return dpi/25.4
34 |
35 |
36 | def dpm_to_dpi(dpm):
37 | if not dpm:
38 | return False
39 |
40 | return dpm * 25.4
41 |
42 |
43 | def to_long_tensor(pic):
44 | # handle numpy array
45 | img = torch.from_numpy(np.array(pic, np.uint8))
46 | # backward compatibility
47 | return img.long()
48 |
49 |
50 | def joint_to_long_tensor(image, mask):
51 | return to_long_tensor(image), to_long_tensor(mask)
52 |
53 |
54 | def make_transform(
55 | crop=(256, 256), p_flip=0.5, p_color=0.0, color_jitter_params=(0.1, 0.1, 0.1, 0.1),
56 | p_random_affine=0.0, rotate_range=False, normalize=False, long_mask=False
57 | ):
58 |
59 | if color_jitter_params is not None:
60 | color_tf = T.ColorJitter(*color_jitter_params)
61 | else:
62 | color_tf = None
63 |
64 | if normalize:
65 | tf_normalize = T.Normalize(mean=(0.5, 0.5, 0.5), std=(1, 1, 1))
66 |
67 | def joint_transform(image, mask):
68 | # transforming to PIL image
69 | image, mask = F.to_pil_image(image), F.to_pil_image(mask)
70 |
71 | # random crop
72 | if crop:
73 | i, j, h, w = T.RandomCrop.get_params(image, crop)
74 | image, mask = F.crop(image, i, j, h, w), F.crop(mask, i, j, h, w)
75 | if np.random.rand() < p_flip:
76 | image, mask = F.hflip(image), F.hflip(mask)
77 |
78 | # color transforms || ONLY ON IMAGE
79 | if color_tf is not None:
80 | if np.random.rand() < p_color:
81 | image = color_tf(image)
82 |
83 | # random rotation
84 | if rotate_range and not p_random_affine:
85 | if np.random.rand() < 0.5:
86 | angle = rotate_range * (np.random.rand() - 0.5)
87 | image, mask = F.rotate(image, angle), F.rotate(mask, angle)
88 |
89 | # random affine
90 | if np.random.rand() < p_random_affine:
91 | affine_params = T.RandomAffine(180).get_params((-90, 90), (1, 1), (2, 2), (-45, 45), crop)
92 | image, mask = F.affine(image, *affine_params), F.affine(mask, *affine_params)
93 |
94 | # transforming to tensor
95 | image = F.to_tensor(image)
96 | if not long_mask:
97 | mask = F.to_tensor(mask)
98 | else:
99 | mask = to_long_tensor(mask)
100 |
101 | # normalizing image
102 | if normalize:
103 | image = tf_normalize(image)
104 |
105 | return image, mask
106 |
107 | return joint_transform
108 |
109 |
110 | def confusion_matrix(prediction, target, n_classes):
111 | """
112 | prediction, target: torch.Tensor objects
113 | """
114 | prediction = torch.argmax(prediction, dim=0).long()
115 | target = torch.squeeze(target, dim=0)
116 |
117 | conf_mtx = torch.zeros(n_classes, n_classes).long()
118 | for i, j in product(range(n_classes), range(n_classes)):
119 | conf_mtx[i, j] = torch.sum((prediction == j) * (target == i))
120 |
121 | return conf_mtx
122 |
123 |
124 | class SoftDiceLoss(_WeightedLoss):
125 | __constants__ = ['weight', 'reduction']
126 |
127 | def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean'):
128 | if weight is None:
129 | weight = torch.tensor(1)
130 | else:
131 | # creating tensor if needed
132 | if not isinstance(weight, torch.Tensor):
133 | weight = torch.tensor(weight)
134 | # normalizing weights
135 | weight /= torch.sum(weight)
136 |
137 | super(SoftDiceLoss, self).__init__(weight, size_average, reduce, reduction)
138 |
139 | def forward(self, y_pred, y_gt):
140 | """
141 | Args:
142 | y_pred: torch.Tensor of shape (n_batch, n_classes, image.shape)
143 | y_gt: torch.LongTensor of shape (n_batch, image.shape)
144 | """
145 | dims = (0, *range(2, len(y_pred.shape)))
146 |
147 | y_gt = torch.zeros_like(y_pred).scatter_(1, y_gt[:, None, :], 1)
148 | numerator = 2 * torch.sum(y_pred * y_gt, dim=dims)
149 | denominator = torch.sum(y_pred * y_pred + y_gt * y_gt, dim=dims)
150 | return torch.sum((1 - numerator / denominator)*self.weight)
151 |
152 |
153 | class LogNLLLoss(_WeightedLoss):
154 | __constants__ = ['weight', 'reduction', 'ignore_index']
155 |
156 | def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean',
157 | ignore_index=-100):
158 | super(LogNLLLoss, self).__init__(weight, size_average, reduce, reduction)
159 | self.ignore_index = ignore_index
160 |
161 | def forward(self, input, target):
162 | input = torch.log(input)
163 | return cross_entropy(input, target, weight=self.weight, reduction=self.reduction,
164 | ignore_index=self.ignore_index)
165 |
166 |
167 | class ReadTrainDataset(Dataset):
168 | """
169 | Structure of the dataset should be:
170 |
171 | dataset_path
172 | |--images
173 | |--img001.png
174 | |--img002.png
175 | |--masks
176 | |--img001.png
177 | |--img002.png
178 |
179 | """
180 |
181 | def __init__(self, dataset_path, transform=None, one_hot_mask=False, long_mask=True):
182 | self.dataset_path = dataset_path
183 | self.images_path = os.path.join(dataset_path, 'images')
184 | self.masks_path = os.path.join(dataset_path, 'masks')
185 | self.images_list = os.listdir(self.images_path)
186 |
187 | self.transform = transform
188 | self.one_hot_mask = one_hot_mask
189 | self.long_mask = long_mask
190 |
191 | def __len__(self):
192 | return len(os.listdir(self.images_path))
193 |
194 | def __getitem__(self, idx):
195 | image_filename = self.images_list[idx]
196 | image = io.imread(os.path.join(self.images_path, image_filename))
197 | mask = io.imread(os.path.join(self.masks_path, image_filename))
198 | if len(mask.shape) == 2:
199 | mask = np.expand_dims(mask, axis=2)
200 |
201 | if self.transform:
202 | image, mask = self.transform(image, mask)
203 | else:
204 | image = F.to_tensor(image)
205 | if self.long_mask:
206 | mask = to_long_tensor(F.to_pil_image(mask))
207 | else:
208 | mask = F.to_tensor(mask)
209 |
210 | if self.one_hot_mask:
211 | assert self.one_hot_mask >= 0, 'one_hot_mask must be nonnegative'
212 | mask = torch.zeros((self.one_hot_mask, mask.shape[1], mask.shape[2])).scatter_(0, mask.long(), 1)
213 |
214 | return image, mask, image_filename
215 |
216 |
217 | class ReadTestDataset(Dataset):
218 | """
219 | Structure of the dataset should be:
220 |
221 | dataset_path
222 | |--images
223 | |--img001.png
224 | |--img002.png
225 |
226 | """
227 |
228 | def __init__(self, dataset_path, transform=None):
229 | self.dataset_path = dataset_path
230 | self.images_path = os.path.join(dataset_path, 'images')
231 | self.images_list = os.listdir(self.images_path)
232 |
233 | self.transform = transform
234 |
235 | def __len__(self):
236 | return len(os.listdir(self.images_path))
237 |
238 | def __getitem__(self, idx):
239 | image_filename = self.images_list[idx]
240 | image = io.imread(os.path.join(self.images_path, image_filename))
241 |
242 | if self.transform:
243 | image = self.transform(image)
244 | else:
245 | image = F.to_tensor(image)
246 |
247 | return image, image_filename
248 |
249 |
250 | class ModelWrapper:
251 | def __init__(
252 | self, model, results_folder, loss=None, optimizer=None,
253 | scheduler=None, cuda_device=None
254 | ):
255 | self.model = model
256 | self.loss = loss
257 | self.optimizer = optimizer
258 | self.scheduler = scheduler
259 | self.results_folder = results_folder
260 | chk_mkdir(self.results_folder)
261 |
262 | self.cuda_device = cuda_device
263 | if self.cuda_device:
264 | self.model.to(device=self.cuda_device)
265 | try:
266 | self.loss.to(device=self.cuda_device)
267 | except AttributeError:
268 | pass
269 |
270 | def train_model(self, dataset, n_epochs, n_batch=1, verbose=False,
271 | validation_dataset=None, prediction_dataset=None,
272 | save_freq=100):
273 | self.model.train(True)
274 |
275 | # logging losses
276 | loss_df = pd.DataFrame(np.zeros(shape=(n_epochs, 2)), columns=['train', 'validate'], index=range(n_epochs))
277 |
278 | min_loss = np.inf
279 | total_running_loss = 0
280 | for epoch_idx in range(n_epochs):
281 |
282 | epoch_running_loss = 0
283 | for batch_idx, (X_batch, y_batch, name) in enumerate(DataLoader(dataset, batch_size=n_batch, shuffle=True)):
284 | if self.cuda_device:
285 | X_batch = Variable(X_batch.to(device=self.cuda_device))
286 | y_batch = Variable(y_batch.to(device=self.cuda_device))
287 | else:
288 | X_batch, y_batch = Variable(X_batch), Variable(y_batch)
289 |
290 | # training
291 | self.optimizer.zero_grad()
292 | y_out = self.model(X_batch)
293 | training_loss = self.loss(y_out, y_batch)
294 | training_loss.backward()
295 | self.optimizer.step()
296 |
297 | epoch_running_loss += training_loss.item()
298 |
299 | if verbose:
300 | print('(Epoch no. %d, batch no. %d) loss: %f' % (epoch_idx, batch_idx, training_loss.item()))
301 |
302 | total_running_loss += epoch_running_loss/(batch_idx + 1)
303 | print('(Epoch no. %d) loss: %f' % (epoch_idx, epoch_running_loss/(batch_idx + 1)))
304 | loss_df.loc[epoch_idx, 'train'] = epoch_running_loss/(batch_idx + 1)
305 |
306 | if validation_dataset is not None:
307 | validation_error = self.validate(validation_dataset, n_batch=1)
308 | loss_df.loc[epoch_idx, 'validate'] = validation_error
309 | if validation_error < min_loss:
310 | torch.save(self.model.state_dict(), os.path.join(self.results_folder, 'model'))
311 | print('Validation loss improved from %f to %f, model saved to %s'
312 | % (min_loss, validation_error, self.results_folder))
313 | min_loss = validation_error
314 |
315 | if self.scheduler is not None:
316 | self.scheduler.step(validation_error)
317 |
318 | else:
319 | if epoch_running_loss/(batch_idx + 1) < min_loss:
320 | torch.save(self.model.state_dict(), os.path.join(self.results_folder, 'model'))
321 | print('Training loss improved from %f to %f, model saved to %s'
322 | % (min_loss, epoch_running_loss / (batch_idx + 1), self.results_folder))
323 | min_loss = epoch_running_loss / (batch_idx + 1)
324 |
325 | if self.scheduler is not None:
326 | self.scheduler.step(epoch_running_loss / (batch_idx + 1))
327 |
328 | # saving model and logs
329 | loss_df.to_csv(os.path.join(self.results_folder, 'loss.csv'))
330 | if epoch_idx % save_freq == 0:
331 | epoch_save_path = os.path.join(self.results_folder, '%d' % epoch_idx)
332 | chk_mkdir(epoch_save_path)
333 | torch.save(self.model.state_dict(), os.path.join(epoch_save_path, 'model'))
334 | if prediction_dataset:
335 | self.predict_large_images(prediction_dataset, epoch_save_path)
336 |
337 | self.model.train(False)
338 |
339 | del X_batch, y_batch
340 |
341 | return total_running_loss/n_batch
342 |
343 | def validate(self, dataset, n_batch=1):
344 | self.model.train(False)
345 |
346 | total_running_loss = 0
347 | for batch_idx, (X_batch, y_batch, name) in enumerate(DataLoader(dataset, batch_size=n_batch, shuffle=False)):
348 |
349 | if self.cuda_device:
350 | X_batch = Variable(X_batch.to(device=self.cuda_device))
351 | y_batch = Variable(y_batch.to(device=self.cuda_device))
352 | else:
353 | X_batch, y_batch = Variable(X_batch), Variable(y_batch)
354 |
355 | y_out = self.model(X_batch)
356 | training_loss = self.loss(y_out, y_batch)
357 |
358 | total_running_loss += training_loss.item()
359 |
360 | print('Validation loss: %f' % (total_running_loss / (batch_idx + 1)))
361 | self.model.train(True)
362 |
363 | del X_batch, y_batch
364 |
365 | return total_running_loss/(batch_idx + 1)
366 |
367 | def predict(self, dataset, export_path, channel=None):
368 | self.model.train(False)
369 | chk_mkdir(export_path)
370 |
371 | for batch_idx, (X_batch, image_filename) in enumerate(DataLoader(dataset, batch_size=1)):
372 | if self.cuda_device:
373 | X_batch = Variable(X_batch.to(device=self.cuda_device))
374 | y_out = self.model(X_batch).cpu().data.numpy()
375 | else:
376 | X_batch = Variable(X_batch)
377 | y_out = self.model(X_batch).data.numpy()
378 |
379 | if channel:
380 | try:
381 | io.imsave(os.path.join(export_path, image_filename[0]), y_out[0, channel, :, :])
382 | except:
383 | print('something went wrong upon prediction')
384 |
385 | else:
386 | try:
387 | io.imsave(os.path.join(export_path, image_filename[0]), y_out[0, :, :, :].transpose((1, 2, 0)))
388 | except:
389 | print('something went wrong upon prediction')
390 |
391 | def predict_large_images(self, dataset, export_path=None, channel=None, tile_res=(512, 512)):
392 | self.model.train(False)
393 | if export_path:
394 | chk_mkdir(export_path)
395 | else:
396 | results = []
397 |
398 | for batch_idx, (X_batch, image_filename) in enumerate(DataLoader(dataset, batch_size=1)):
399 | out = self.predict_single_large_image(X_batch, channel=channel, tile_res=tile_res)
400 |
401 | if export_path:
402 | io.imsave(os.path.join(export_path, image_filename[0]), out)
403 | else:
404 | results.append(out)
405 |
406 | if not export_path:
407 | return results
408 |
409 | def predict_single_large_image(self, X_image, channel=None, tile_res=(512, 512)):
410 | image_res = X_image.shape
411 | # placeholder for output
412 | y_out_full = np.zeros(shape=(1, 3, image_res[2], image_res[3]))
413 | # generate tile coordinates
414 | tile_x = list(range(0, image_res[2], tile_res[0]))[:-1] + [image_res[2] - tile_res[0]]
415 | tile_y = list(range(0, image_res[3], tile_res[1]))[:-1] + [image_res[3] - tile_res[1]]
416 | tile = product(tile_x, tile_y)
417 | # predictions
418 | for slice in tile:
419 |
420 | if self.cuda_device:
421 | X_in = X_image[:, :, slice[0]:slice[0] + tile_res[0], slice[1]:slice[1] + tile_res[1]].to(
422 | device=self.cuda_device)
423 | X_in = Variable(X_in)
424 | else:
425 | X_in = X_image[:, :, slice[0]:slice[0] + tile_res[0], slice[1]:slice[1] + tile_res[1]]
426 | X_in = Variable(X_in)
427 |
428 | y_out = self.model(X_in).cpu().data.numpy()
429 | y_out_full[0, :, slice[0]:slice[0] + tile_res[0], slice[1]:slice[1] + tile_res[1]] = y_out
430 |
431 | # save image
432 | if channel:
433 | out = y_out_full[0, channel, :, :]
434 | else:
435 | out = y_out_full[0, :, :, :].transpose((1, 2, 0))
436 |
437 | return out
438 |
439 | def measure_large_images(self, dataset, visualize_bboxes=False, filter=True, export_path=None,
440 | skeleton_method=skeletonize_3d, dpm=False, verbose=False, tile_res=(512, 512)):
441 | hypocotyl_lengths = dict()
442 | chk_mkdir(export_path)
443 |
444 | assert any(isinstance(dpm, tp) for tp in [str, bool, Number]), 'dpm must be string, bool or Number'
445 |
446 | for batch_idx, (X_batch, image_filename) in enumerate(DataLoader(dataset, batch_size=1)):
447 |
448 | if verbose:
449 | print("Measuring %s" % image_filename[0])
450 |
451 | hypo_segmented = self.predict_single_large_image(X_batch, tile_res=tile_res)
452 | hypo_segmented_mask = hypo_segmented[:, :, 2]
453 | hypo_result, hypo_skeleton = get_hypo_rprops(hypo_segmented_mask, filter=filter, return_skeleton=True,
454 | skeleton_method=skeleton_method,
455 | dpm=dpm)
456 | hypo_df = hypo_result.make_df()
457 |
458 | hypocotyl_lengths[image_filename] = hypo_df
459 |
460 | if export_path:
461 | if visualize_bboxes:
462 | hypo_img = X_batch[0].cpu().data.numpy().transpose((1, 2, 0))
463 | # original image
464 | visualize_regions(hypo_img, hypo_result,
465 | os.path.join(export_path, image_filename[0][:-4] + '.png'))
466 | # segmentation
467 | visualize_regions(hypo_segmented, hypo_result,
468 | os.path.join(export_path, image_filename[0][:-4] + '_segmentation.png'),
469 | bbox_color='0.5')
470 | # skeletonization
471 | visualize_regions(hypo_skeleton, hypo_result,
472 | os.path.join(export_path, image_filename[0][:-4] + '_skeleton.png'))
473 |
474 | hypocotyl_lengths[image_filename].to_csv(os.path.join(export_path, image_filename[0][:-4] + '.csv'),
475 | header=True, index=True)
476 |
477 | return hypocotyl_lengths
478 |
479 | def score_large_images(self, dataset, export_path, visualize_bboxes=False, visualize_histograms=False,
480 | visualize_segmentation=False,
481 | filter=True, skeletonized_gt=False, match_threshold=0.5, tile_res=(512, 512),
482 | dpm=False):
483 | chk_mkdir(export_path)
484 |
485 | scores = {}
486 |
487 | assert any(isinstance(dpm, tp) for tp in [str, bool, Number]), 'dpm must be string, bool or Number'
488 |
489 | if isinstance(dpm, str):
490 | dpm_df = pd.read_csv(dpm, header=None, index_col=0)
491 |
492 | for batch_idx, (X_batch, y_batch, image_filename) in enumerate(DataLoader(dataset, batch_size=1)):
493 | if isinstance(dpm, str):
494 | dpm_val = dpm_df.loc[image_filename].values[0]
495 | elif isinstance(dpm, Number) or dpm == False:
496 | dpm_val = dpm
497 | else:
498 | raise ValueError('dpm must be str, Number or False')
499 |
500 | # getting filter range
501 | if isinstance(filter, dict):
502 | filter_val = filter[image_filename[0]]
503 | else:
504 | filter_val = filter
505 |
506 | segmented_img = self.predict_single_large_image(X_batch, tile_res=tile_res)
507 | hypo_result_mask = segmented_img[:, :, 2]
508 | hypo_result, hypo_result_skeleton = get_hypo_rprops(hypo_result_mask, filter=filter_val,
509 | return_skeleton=True, dpm=dpm_val)
510 | hypo_result.make_df().to_csv(os.path.join(export_path, image_filename[0][:-4] + '_result.csv'))
511 |
512 | if visualize_segmentation:
513 | io.imsave(os.path.join(export_path, image_filename[0][:-4] + '_segmentation_skeletons.png'),
514 | img_as_uint(hypo_result_skeleton))
515 | io.imsave(os.path.join(export_path, image_filename[0][:-4] + '_segmentation_hypo.png'),
516 | hypo_result_mask)
517 | io.imsave(os.path.join(export_path, image_filename[0][:-4] + '_segmentation_full.png'),
518 | segmented_img)
519 |
520 | if not skeletonized_gt:
521 | hypo_gt_mask = y_batch[0].data.numpy() == 2
522 | else:
523 | hypo_gt_mask = y_batch[0].data.numpy() > 0
524 |
525 | hypo_result_gt = get_hypo_rprops(hypo_gt_mask, filter=[20/dpm_val, np.inf],
526 | already_skeletonized=skeletonized_gt, dpm=dpm_val)
527 | hypo_result_gt.make_df().to_csv(os.path.join(export_path, image_filename[0][:-4] + '_gt.csv'))
528 |
529 | scores[image_filename[0]], objectwise_df = hypo_result.score(hypo_result_gt,
530 | match_threshold=match_threshold)
531 | objectwise_df.to_csv(os.path.join(export_path, image_filename[0][:-4] + '_matched.csv'))
532 |
533 | # visualization
534 | # histograms
535 | if visualize_histograms:
536 | hypo_result.hist(hypo_result_gt,
537 | os.path.join(export_path, image_filename[0][:-4] + '_hist.png'))
538 |
539 | # bounding boxes
540 | if visualize_bboxes:
541 | visualize_regions(hypo_gt_mask, hypo_result_gt,
542 | export_path=os.path.join(export_path, image_filename[0][:-4] + '_gt.png'))
543 | visualize_regions(hypo_result_skeleton, hypo_result,
544 | export_path=os.path.join(export_path, image_filename[0][:-4] + '_result.png'))
545 |
546 | score_df = pd.DataFrame(scores).T
547 | score_df.to_csv(os.path.join(export_path, 'scores.csv'))
548 |
549 | return scores
550 |
551 | def visualize_workflow(self, dataset, export_path, filter=False):
552 | for image, mask, image_filename in DataLoader(dataset, batch_size=1):
553 | image_filename_root = image_filename[0].split('.')[0]
554 | hypo_full_result = self.predict_single_large_image(image, tile_res=(512, 512))
555 | hypo_result_mask = self.predict_single_large_image(image, channel=2, tile_res=(512, 512))
556 | # save multiclass mask
557 | io.imsave(os.path.join(export_path, image_filename_root + '_1.png'), hypo_full_result)
558 | hypo_result, hypo_result_skeleton = get_hypo_rprops(hypo_result_mask, return_skeleton=True, filter=filter)
559 | io.imsave(os.path.join(export_path, image_filename_root + '_2.png'), img_as_uint(hypo_result_skeleton))
560 | visualize_regions(image.data.cpu().numpy()[0].transpose((1, 2, 0)), hypo_result,
561 | os.path.join(export_path, image_filename_root + '_3.png'))
562 |
563 |
564 | if __name__ == '__main__':
565 | hypo_mask_img = io.imread('/home/namazu/Data/hypocotyl/measurement_test/masks/140925 8-4 050.png')
566 | hypo_result = get_hypo_rprops(hypo_mask_img, filter=False, already_skeletonized=True)
567 | hypo_result.score(hypo_result)
568 |
--------------------------------------------------------------------------------