├── vignettes
│   ├── .gitignore
│   └── carvana.Rmd
├── LICENSE
├── .Rbuildignore
├── NAMESPACE
├── .gitignore
├── unet.Rproj
├── README.md
├── DESCRIPTION
├── man
│   └── unet.Rd
├── LICENSE.md
└── R
    └── model.R

/vignettes/.gitignore:
--------------------------------------------------------------------------------
1 | *.html
2 | *.R
3 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | YEAR: 2019
2 | COPYRIGHT HOLDER: Daniel Falbel
3 | 
--------------------------------------------------------------------------------
/.Rbuildignore:
--------------------------------------------------------------------------------
1 | ^unet\.Rproj$
2 | ^\.Rproj\.user$
3 | ^LICENSE\.md$
4 | 
--------------------------------------------------------------------------------
/NAMESPACE:
--------------------------------------------------------------------------------
1 | # Generated by roxygen2: do not edit by hand
2 | 
3 | export(unet)
4 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .Rproj.user
2 | .Rhistory
3 | .RData
4 | .Ruserdata
5 | data-raw/*
6 | logs_r/*
7 | unet.R
8 | inst/doc
9 | 
--------------------------------------------------------------------------------
/unet.Rproj:
--------------------------------------------------------------------------------
1 | Version: 1.0
2 | 
3 | RestoreWorkspace: No
4 | SaveWorkspace: No
5 | AlwaysSaveHistory: Default
6 | 
7 | EnableCodeIndexing: Yes
8 | UseSpacesForTab: Yes
9 | NumSpacesForTab: 2
10 | Encoding: UTF-8
11 | 
12 | RnwWeave: Sweave
13 | LaTeX: pdfLaTeX
14 | 
15 | AutoAppendNewline: Yes
16 | StripTrailingWhitespace: Yes
17 | 
18 | BuildType: Package
19 | PackageUseDevtools: Yes
20 | PackageInstallArgs: --no-multiarch --with-keep.source
21 | PackageRoxygenize: rd,collate,namespace
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # unet
2 | 
3 | 
4 | 
5 | 
6 | A Keras/R implementation of [U-Net](https://arxiv.org/abs/1505.04597).
7 | 
8 | ![U-Net Architecture](https://user-images.githubusercontent.com/4706822/63275620-3c987800-c278-11e9-9d92-66d1264eb05c.png)
9 | 
10 | ## Installation
11 | 
12 | Install U-Net from GitHub with:
13 | 
14 | ``` r
15 | remotes::install_github("r-tensorflow/unet")
16 | ```
17 | 
18 | ## Example
19 | 
20 | A full example can be found [here](https://github.com/dfalbel/unet/blob/master/vignettes/carvana.Rmd).
21 | 
--------------------------------------------------------------------------------
/DESCRIPTION:
--------------------------------------------------------------------------------
1 | Package: unet
2 | Title: U-Net: Convolutional Networks for Biomedical Image Segmentation
3 | Version: 0.1
4 | Authors@R: c(
5 |     person(given = "Daniel", family = "Falbel", role = c("aut", "cre"), email = "daniel@rstudio.com"),
6 |     person(given = "Karol", family = "Żak", role = c("aut"), comment = "Original Python implementation")
7 |     )
8 | Description: Implementation of U-Net: Convolutional Networks for Biomedical Image
9 |     Segmentation (Ronneberger et al., 2015) using Keras.
10 | License: MIT + file LICENSE
11 | Encoding: UTF-8
12 | LazyData: true
13 | Imports:
14 |     keras
15 | RoxygenNote: 6.1.1
16 | Suggests:
17 |     knitr,
18 |     rmarkdown
19 | VignetteBuilder: knitr
20 | 
--------------------------------------------------------------------------------
/man/unet.Rd:
--------------------------------------------------------------------------------
1 | % Generated by roxygen2: do not edit by hand
2 | % Please edit documentation in R/model.R
3 | \name{unet}
4 | \alias{unet}
5 | \title{U-Net: Convolutional Networks for Biomedical Image Segmentation}
6 | \usage{
7 | unet(input_shape, num_classes = 1, dropout = 0.5, filters = 64,
8 |   num_layers = 4, output_activation = "sigmoid")
9 | }
10 | \arguments{
11 | \item{input_shape}{Dimensionality of the input (integer) not including the
12 | samples axis. Must be length 3 numeric vector.}
13 | 
14 | \item{num_classes}{Number of classes.}
15 | 
16 | \item{dropout}{Dropout rate applied between downsampling and upsampling phases.}
17 | 
18 | \item{filters}{Number of filters of the first convolution.}
19 | 
20 | \item{num_layers}{Number of downsizing blocks in the encoder.}
21 | 
22 | \item{output_activation}{Activation in the output layer.}
23 | }
24 | \description{
25 | U-Net: Convolutional Networks for Biomedical Image Segmentation
26 | }
27 | 
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | # MIT License
2 | 
3 | Copyright (c) 2019 Daniel Falbel
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /R/model.R: -------------------------------------------------------------------------------- 1 | conv2d_block <- function(inputs, use_batch_norm = TRUE, dropout = 0.3, 2 | filters = 16, kernel_size = c(3, 3), activation = "relu", 3 | kernel_initializer = "he_normal", padding = "same") { 4 | 5 | x <- keras::layer_conv_2d( 6 | inputs, 7 | filters = filters, 8 | kernel_size = kernel_size, 9 | activation = activation, 10 | kernel_initializer = kernel_initializer, 11 | padding = padding 12 | ) 13 | 14 | if (use_batch_norm) { 15 | x <- keras::layer_batch_normalization(x) 16 | } 17 | 18 | if (dropout > 0) { 19 | x <- keras::layer_dropout(x, rate = dropout) 20 | } 21 | 22 | x <- keras::layer_conv_2d( 23 | x, 24 | filters = filters, 25 | kernel_size = kernel_size, 26 | activation = activation, 27 | kernel_initializer = kernel_initializer, 28 | padding = padding 29 | ) 30 | 31 | if (use_batch_norm) { 32 | x <- keras::layer_batch_normalization(x) 33 | } 34 | 35 | x 36 | } 37 | 38 | #' U-Net: Convolutional Networks for Biomedical Image Segmentation 39 | #' 40 | #' @param input_shape Dimensionality of the input (integer) not including the 41 | #' samples axis. Must be length 3 numeric vector. 42 | #' @param num_classes Number of classes. 43 | #' @param dropout Dropout rate applied between downsampling and upsampling phases. 44 | #' @param filters Number of filters of the first convolution. 45 | #' @param num_layers Number of downsizing blocks in the encoder. 46 | #' @param output_activation Activation in the output layer. 47 | #' 48 | #' @export 49 | unet <- function(input_shape, num_classes = 1, dropout = 0.5, filters = 64, 50 | num_layers = 4, output_activation = "sigmoid") { 51 | 52 | 53 | input <- keras::layer_input(shape = input_shape) 54 | 55 | x <- input 56 | down_layers <- list() 57 | 58 | for (i in seq_len(num_layers)) { 59 | 60 | x <- conv2d_block( 61 | inputs = x, 62 | filters = filters, 63 | use_batch_norm = FALSE, 64 | dropout = 0, 65 | padding = "same" 66 | ) 67 | 68 | down_layers[[i]] <- x 69 | 70 | x <- keras::layer_max_pooling_2d(x, pool_size = c(2,2), strides = c(2,2)) 71 | 72 | filters <- filters * 2 73 | 74 | } 75 | 76 | if (dropout > 0) { 77 | x <- keras::layer_dropout(x, rate = dropout) 78 | } 79 | 80 | x <- conv2d_block( 81 | inputs = x, 82 | filters = filters, 83 | use_batch_norm = FALSE, 84 | dropout = 0.0, 85 | padding = 'same' 86 | ) 87 | 88 | for (conv in rev(down_layers)) { 89 | 90 | filters <- filters / 2L 91 | 92 | x <- keras::layer_conv_2d_transpose( 93 | x, 94 | filters = filters, 95 | kernel_size = c(2,2), 96 | padding = "same", 97 | strides = c(2,2) 98 | ) 99 | 100 | x <- keras::layer_concatenate(list(conv, x)) 101 | x <- conv2d_block( 102 | inputs = x, 103 | filters = filters, 104 | use_batch_norm = FALSE, 105 | dropout = 0.0, 106 | padding = 'same' 107 | ) 108 | 109 | } 110 | 111 | output <- keras::layer_conv_2d( 112 | x, 113 | filters = num_classes, 114 | kernel_size = c(1,1), 115 | activation = output_activation 116 | ) 117 | 118 | model <- keras::keras_model(input, output) 119 | 120 | model 121 | } 122 | 123 | -------------------------------------------------------------------------------- /vignettes/carvana.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Carvana Image Masking Challenge" 3 | output: rmarkdown::html_vignette 4 | vignette: > 5 | %\VignetteIndexEntry{Carvana Image Masking Challenge} 6 | 
%\VignetteEngine{knitr::rmarkdown}
7 |   %\VignetteEncoding{UTF-8}
8 | ---
9 | 
10 | ```{r, include = FALSE}
11 | knitr::opts_chunk$set(
12 |   collapse = TRUE,
13 |   comment = "#>",
14 |   eval = FALSE
15 | )
16 | ```
17 | 
18 | ```{r setup}
19 | knitr::opts_chunk$set(warning = TRUE, message = TRUE)
20 | library(unet)
21 | library(keras)
22 | library(tfdatasets)
23 | library(tidyverse)
24 | library(rsample)
25 | library(reticulate)
26 | ```
27 | 
28 | In this example we will use the `unet` package to create a U-Net model that
29 | could be used to remove the background from images in the Carvana dataset.
30 | 
31 | [U-Net](https://arxiv.org/abs/1505.04597) is a kind of convolutional neural network that was first developed for biomedical image segmentation but has shown good results in many other fields.
32 | 
33 | ![U-Net architecture](https://user-images.githubusercontent.com/4706822/63275620-3c987800-c278-11e9-9d92-66d1264eb05c.png)
34 | 
35 | 
36 | 
37 | The dataset we are going to use appeared first in the [Carvana Image Masking Challenge](https://www.kaggle.com/c/carvana-image-masking-challenge/overview)
38 | on Kaggle. You can find more information on the [competition page](https://www.kaggle.com/c/carvana-image-masking-challenge/overview).
39 | 
40 | Before running the script, download [the data](https://www.kaggle.com/c/carvana-image-masking-challenge/data). You will only
41 | need the `train.zip` and `train_masks.zip` files.
42 | 
43 | The `train.zip` file contains images of the cars taken by Carvana and `train_masks.zip`
44 | contains the respective masks.
45 | 
46 | Here are some examples of what you can find in the dataset. On the left is the
47 | original image and on the right is the corresponding mask.
48 | 
49 | ```{r echo=FALSE}
50 | images <- tibble(
51 |   img = list.files(here::here("data-raw/train"), full.names = TRUE),
52 |   mask = list.files(here::here("data-raw/train_masks"), full.names = TRUE)
53 | ) %>%
54 |   sample_n(2) %>%
55 |   map(. %>% magick::image_read() %>% magick::image_resize("128x128"))
56 | 
57 | out <- magick::image_append(c(
58 |   magick::image_append(images$img, stack = TRUE),
59 |   magick::image_append(images$mask, stack = TRUE)
60 |   )
61 | )
62 | 
63 | plot(out)
64 | ```
65 | 
66 | Now let's start building our model. We will use `tfdatasets` to build our
67 | data loading and pre-processing pipeline.
68 | 
69 | First we will define which images we are going to use for training and which images
70 | we will use for validation. I am assuming we extracted both folders into the `data-raw` directory.
71 | 
72 | ```{r}
73 | data <- tibble(
74 |   img = list.files(here::here("data-raw/train"), full.names = TRUE),
75 |   mask = list.files(here::here("data-raw/train_masks"), full.names = TRUE)
76 | )
77 | 
78 | data <- initial_split(data, prop = 0.8)
79 | ```
80 | 
81 | Ok, now let's define a pipeline to read the files and decode them as images. In
82 | this case the images are `.jpeg` files and the masks are `.gif` files.
83 | 
84 | ```{r}
85 | training_dataset <- training(data) %>%
86 |   tensor_slices_dataset() %>%
87 |   dataset_map(~.x %>% list_modify(
88 |     img = tf$image$decode_jpeg(tf$io$read_file(.x$img)),
89 |     mask = tf$image$decode_gif(tf$io$read_file(.x$mask))[1,,,][,,1,drop=FALSE]
90 |   ))
91 | ```
92 | 
93 | The `[` calls wouldn't be necessary if `tf$image$decode_gif` returned a 3-D Tensor the way `tf$image$decode_jpeg` does, and if it could read a single color channel, since the masks are black and white and we only need one channel.
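To make that indexing concrete, here is a small sketch of what the two `[` calls do, one step at a time. The `mask_path` variable is introduced only for this illustration and assumes the masks were extracted to `data-raw/train_masks` as above:

```{r}
# Illustration only: inspect the slicing on a single mask file.
mask_path <- list.files(here::here("data-raw/train_masks"), full.names = TRUE)[1]

mask <- tf$image$decode_gif(tf$io$read_file(mask_path))
dim(mask)                          # (1, height, width, 3): frames x rows x cols x channels
mask <- mask[1, , , ]              # drop the single GIF frame -> (height, width, 3)
mask <- mask[, , 1, drop = FALSE]  # keep only one channel     -> (height, width, 1)
```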
94 | 
95 | If you are running this code interactively, you can easily take a look at what
96 | `training_dataset` yields with:
97 | 
98 | ```{r}
99 | example <- training_dataset %>% as_iterator() %>% iter_next()
100 | ```
101 | 
102 | The above loads the images into `uint8` Tensors, which is great for reading
103 | as it uses less memory. However, for modelling we prefer `float32` Tensors
104 | with values in the [0,1] range. That's what we will fix now:
105 | 
106 | ```{r}
107 | training_dataset <- training_dataset %>%
108 |   dataset_map(~.x %>% list_modify(
109 |     img = tf$image$convert_image_dtype(.x$img, dtype = tf$float32),
110 |     mask = tf$image$convert_image_dtype(.x$mask, dtype = tf$float32)
111 |   ))
112 | ```
113 | 
114 | The images in our dataset are fairly high resolution (1280x1918), so we will resize them to 128x128 to reduce the computational cost of the model. This size is completely arbitrary.
115 | 
116 | ```{r}
117 | training_dataset <- training_dataset %>%
118 |   dataset_map(~.x %>% list_modify(
119 |     img = tf$image$resize(.x$img, size = shape(128, 128)),
120 |     mask = tf$image$resize(.x$mask, size = shape(128, 128))
121 |   ))
122 | ```
123 | 
124 | We can plot the resulting images:
125 | 
126 | ```{r}
127 | example <- training_dataset %>% as_iterator() %>% iter_next()
128 | example$img %>% as.array() %>% as.raster() %>% plot()
129 | ```
130 | 
131 | It's usual when fitting a U-Net to use some kind of data augmentation strategy.
132 | In this example we are going to apply random changes in brightness, contrast and
133 | saturation to each image. Let's encapsulate this into an R function:
134 | 
135 | ```{r}
136 | random_bsh <- function(img) {
137 |   img %>%
138 |     tf$image$random_brightness(max_delta = 0.3) %>%
139 |     tf$image$random_contrast(lower = 0.5, upper = 0.7) %>%
140 |     tf$image$random_saturation(lower = 0.5, upper = 0.7) %>%
141 |     tf$clip_by_value(0, 1) # clip the values into the [0,1] range
142 | }
143 | ```
144 | 
145 | We can now map this function over the images:
146 | 
147 | ```{r}
148 | training_dataset <- training_dataset %>%
149 |   dataset_map(~.x %>% list_modify(
150 |     img = random_bsh(.x$img)
151 |   ))
152 | ```
153 | 
154 | Again, we can plot the resulting image:
155 | 
156 | ```{r}
157 | example <- training_dataset %>% as_iterator() %>% iter_next()
158 | example$img %>% as.array() %>% as.raster() %>% plot()
159 | ```
160 | 
161 | Of course, we could wrap the above code in a function and reuse it to create
162 | the validation dataset, and that's what we are going to do.
163 | 
164 | ```{r}
165 | create_dataset <- function(data, train, batch_size = 32L) {
166 | 
167 |   dataset <- data %>%
168 |     tensor_slices_dataset() %>%
169 |     dataset_map(~.x %>% list_modify(
170 |       img = tf$image$decode_jpeg(tf$io$read_file(.x$img)),
171 |       mask = tf$image$decode_gif(tf$io$read_file(.x$mask))[1,,,][,,1,drop=FALSE]
172 |     )) %>%
173 |     dataset_map(~.x %>% list_modify(
174 |       img = tf$image$convert_image_dtype(.x$img, dtype = tf$float32),
175 |       mask = tf$image$convert_image_dtype(.x$mask, dtype = tf$float32)
176 |     )) %>%
177 |     dataset_map(~.x %>% list_modify(
178 |       img = tf$image$resize(.x$img, size = shape(128, 128)),
179 |       mask = tf$image$resize(.x$mask, size = shape(128, 128))
180 |     ))
181 | 
182 |   if (train) {
183 |     dataset <- dataset %>%
184 |       dataset_map(~.x %>% list_modify(
185 |         img = random_bsh(.x$img)
186 |       ))
187 |   }
188 | 
189 |   if (train) {
190 |     dataset <- dataset %>%
191 |       dataset_shuffle(buffer_size = batch_size*128)
192 |   }
193 | 
194 |   dataset <- dataset %>%
195 |     dataset_batch(batch_size)
196 | 
197 | 
198 | 
199 |   dataset %>%
200 |     dataset_map(unname) # Keras needs an unnamed output.
201 | }
202 | ```
203 | 
204 | 
205 | Note that we added 3 steps to the `create_dataset` function:
206 | 
207 | 1. `dataset_shuffle` to shuffle the training data.
208 | 2. `dataset_batch` to group the observations into batches.
209 | 3. `dataset_map(unname)` since Keras expects unnamed inputs.
210 | 
211 | Now we can create our training and validation datasets:
212 | 
213 | ```{r}
214 | training_dataset <- create_dataset(training(data), train = TRUE)
215 | validation_dataset <- create_dataset(testing(data), train = FALSE)
216 | ```
217 | 
218 | Great! We have prepared our data pipeline. Now we need to build the model.
219 | 
220 | Luckily, building the model is the easiest part if you use `unet`.
221 | 
222 | ```{r}
223 | model <- unet(input_shape = c(128, 128, 3))
224 | ```
225 | 
226 | That's all. The model is built. You can see the summary if you want with:
227 | 
228 | ```{r}
229 | summary(model)
230 | ```
231 | 
232 | Finally, let's compile and fit our model. The competition uses a metric
233 | called the Dice coefficient, which can be implemented like this:
234 | 
235 | ```{r}
236 | dice <- custom_metric("dice", function(y_true, y_pred, smooth = 1.0) {
237 |   y_true_f <- k_flatten(y_true)
238 |   y_pred_f <- k_flatten(y_pred)
239 |   intersection <- k_sum(y_true_f * y_pred_f)
240 |   (2 * intersection + smooth) / (k_sum(y_true_f) + k_sum(y_pred_f) + smooth)
241 | })
242 | ```
243 | 
244 | We can now compile our model:
245 | 
246 | ```{r}
247 | model %>% compile(
248 |   optimizer = optimizer_rmsprop(lr = 1e-5),
249 |   loss = "binary_crossentropy",
250 |   metrics = list(dice, metric_binary_accuracy)
251 | )
252 | ```
253 | 
254 | We could use a different loss, one tuned to make the Dice coefficient higher, but let's just use
255 | binary crossentropy.
256 | 
257 | ```{r}
258 | model %>% fit(
259 |   training_dataset,
260 |   epochs = 5,
261 |   validation_data = validation_dataset
262 | )
263 | ```
264 | 
265 | Fitting this model takes ~1500s per epoch on my MacBook Pro CPU. With a good GPU
266 | it takes around 120s per epoch.
267 | 
268 | That's it. Now you have trained a U-Net using `unet`.
269 | 
270 | We can now make predictions for the validation data and see what the results look
271 | like. Let's take the first batch of images in the validation data.
272 | 273 | ```{r} 274 | batch <- validation_dataset %>% as_iterator() %>% iter_next() 275 | predictions <- predict(model, batch) 276 | ``` 277 | 278 | In the image below you can see the original mask, the original picture and 279 | the predicted mask. 280 | 281 | ```{r, echo = FALSE} 282 | images <- tibble( 283 | image = batch[[1]] %>% array_branch(1), 284 | predicted_mask = predictions[,,,1] %>% array_branch(1), 285 | mask = batch[[2]][,,,1] %>% array_branch(1) 286 | ) %>% 287 | sample_n(2) %>% 288 | map_depth(2, function(x) { 289 | as.raster(x) %>% magick::image_read() 290 | }) %>% 291 | map(~do.call(c, .x)) 292 | 293 | 294 | out <- magick::image_append(c( 295 | magick::image_append(images$mask, stack = TRUE), 296 | magick::image_append(images$image, stack = TRUE), 297 | magick::image_append(images$predicted_mask, stack = TRUE) 298 | ) 299 | ) 300 | 301 | plot(out) 302 | ``` 303 | 304 | 305 | 306 | 307 | 308 | 309 | 310 | --------------------------------------------------------------------------------