├── models
│   └── UNet.h5
├── Images
│   ├── Input.png
│   ├── unet.png
│   ├── BottleNeck.png
│   ├── Expansive_Path.PNG
│   ├── Skip_Connection.PNG
│   ├── Contracting_Path.png
│   ├── Output_Feature_Map.PNG
│   ├── Predicted_Mask_1.PNG
│   ├── Predicted_Mask_2.PNG
│   ├── Predicted_Mask_3.PNG
│   └── Unet_Architecture.PNG
├── UNet for Image Segmentation.pdf
├── __pycache__
│   ├── Unet.cpython-37.pyc
│   └── UnetUtils.cpython-37.pyc
├── UNetDataGenerator.py
├── Unet.py
├── README.md
└── UnetUtils.py
/models/UNet.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sauravmishra1710/U-Net---Biomedical-Image-Segmentation/HEAD/models/UNet.h5
--------------------------------------------------------------------------------
/Images/Input.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sauravmishra1710/U-Net---Biomedical-Image-Segmentation/HEAD/Images/Input.png
--------------------------------------------------------------------------------
/Images/unet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sauravmishra1710/U-Net---Biomedical-Image-Segmentation/HEAD/Images/unet.png
--------------------------------------------------------------------------------
/Images/BottleNeck.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sauravmishra1710/U-Net---Biomedical-Image-Segmentation/HEAD/Images/BottleNeck.png
--------------------------------------------------------------------------------
/Images/Expansive_Path.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sauravmishra1710/U-Net---Biomedical-Image-Segmentation/HEAD/Images/Expansive_Path.PNG
--------------------------------------------------------------------------------
/Images/Skip_Connection.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sauravmishra1710/U-Net---Biomedical-Image-Segmentation/HEAD/Images/Skip_Connection.PNG
--------------------------------------------------------------------------------
/Images/Contracting_Path.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sauravmishra1710/U-Net---Biomedical-Image-Segmentation/HEAD/Images/Contracting_Path.png
--------------------------------------------------------------------------------
/Images/Output_Feature_Map.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sauravmishra1710/U-Net---Biomedical-Image-Segmentation/HEAD/Images/Output_Feature_Map.PNG
--------------------------------------------------------------------------------
/Images/Predicted_Mask_1.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sauravmishra1710/U-Net---Biomedical-Image-Segmentation/HEAD/Images/Predicted_Mask_1.PNG
--------------------------------------------------------------------------------
/Images/Predicted_Mask_2.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sauravmishra1710/U-Net---Biomedical-Image-Segmentation/HEAD/Images/Predicted_Mask_2.PNG
--------------------------------------------------------------------------------
/Images/Predicted_Mask_3.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sauravmishra1710/U-Net---Biomedical-Image-Segmentation/HEAD/Images/Predicted_Mask_3.PNG
--------------------------------------------------------------------------------
/Images/Unet_Architecture.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sauravmishra1710/U-Net---Biomedical-Image-Segmentation/HEAD/Images/Unet_Architecture.PNG
--------------------------------------------------------------------------------
/UNet for Image Segmentation.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sauravmishra1710/U-Net---Biomedical-Image-Segmentation/HEAD/UNet for Image Segmentation.pdf
--------------------------------------------------------------------------------
/__pycache__/Unet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sauravmishra1710/U-Net---Biomedical-Image-Segmentation/HEAD/__pycache__/Unet.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/UnetUtils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sauravmishra1710/U-Net---Biomedical-Image-Segmentation/HEAD/__pycache__/UnetUtils.cpython-37.pyc
--------------------------------------------------------------------------------
/UNetDataGenerator.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | import numpy as np
5 | import cv2
6 |
7 | import tensorflow as tf
8 |
9 | class NucleiDataGenerator(tf.keras.utils.Sequence):
10 |
11 | """
12 | The custom data generator class generates and feeds data to
13 | the model dynamically in batches during the training phase.
14 |
 15 |     This generator generates batches of data for the dataset available @
16 | Find the nuclei in divergent images to advance medical discovery -
17 | https://www.kaggle.com/c/data-science-bowl-2018
18 |
19 | **
 20 |     tf.keras.utils.Sequence is the base class for
 21 |     custom data generators.
22 | **
23 |
24 | Args:
25 | image_ids: the ids of the image.
26 | img_path: the full path of the image directory.
27 | batch_size: no. of images to be included in a batch feed. Default is set to 8.
28 | image_size: size of the image. Default is set to 128 as per the data available.
29 |
30 | Ref: https://dzlab.github.io/dltips/en/keras/data-generator/
31 |
32 | """
33 | def __init__(self, image_ids, img_path, batch_size = 8, image_size = 128):
34 |
35 | self.ids = image_ids
36 | self.path = img_path
37 | self.batch_size = batch_size
38 | self.image_size = image_size
39 | self.on_epoch_end()
40 |
41 | def __load__(self, item):
42 |
43 | """
44 | loads the specified image.
45 |
46 | """
47 |
 48 |         # the image file shares its name with its grandparent directory (the image id).
 49 |         # an example directory layout is shown below -
 50 |         # - data-science-bowl-2018/
 51 |         #     - stage1_train/
 52 |         #         - abc/
 53 |         #             - images/
 54 |         #                 - abc.png
 55 |         #             - masks/
56 | full_image_path = os.path.join(self.path, item, "images", item) + ".png"
57 | mask_dir_path = os.path.join(self.path, item, "masks/")
58 | all_masks = os.listdir(mask_dir_path)
59 |
60 | # load the images
61 | image = cv2.imread(full_image_path, 1)
62 | image = cv2.resize(image, (self.image_size, self.image_size))
63 |
64 | masked_img = np.zeros((self.image_size, self.image_size, 1))
65 |
66 | # load and prepare the corresponding mask.
67 | for mask in all_masks:
68 | fullPath = mask_dir_path + mask
69 | _masked_img = cv2.imread(fullPath, -1)
70 | _masked_img = cv2.resize(_masked_img, (self.image_size, self.image_size))
71 | _masked_img = np.expand_dims(_masked_img, axis = -1)
72 | masked_img = np.maximum(masked_img, _masked_img)
73 |
 74 |         # normalize the mask and the image.
75 | image = image/255.0
76 | masked_img = masked_img/255.0
77 |
78 | return image, masked_img
79 |
80 | def __getitem__(self, index):
81 |
82 | """
83 | Returns a single batch of data.
84 |
85 | Args:
86 | index: the batch index.
87 |
88 | """
89 |
 90 |         # edge case: when the dataset size is not an exact multiple of
 91 |         # batch_size, the leftover items form one smaller batch at the end.
 92 |         # python list slicing clamps to the end of the list, so the slice
 93 |         # below handles this naturally; mutating self.batch_size here would
 94 |         # corrupt the offsets of later batches and of subsequent epochs.
 95 |
 96 |         # group the items into a batch.
 97 |         batch = self.ids[index * self.batch_size : (index + 1) * self.batch_size]
98 |
99 | image = []
100 | mask = []
101 |
102 | # load the items in the current batch
103 | for item in batch:
104 | img, masked_img = self.__load__(item)
105 | image.append(img)
106 | mask.append(masked_img)
107 |
108 | image = np.array(image)
109 | mask = np.array(mask)
110 |
111 | return image, mask
112 |
113 | def on_epoch_end(self):
114 |
115 | """
116 | optional method to run some logic at the end of each epoch: e.g. reshuffling
117 |
118 | """
119 |
120 | pass
121 |
122 | def __len__(self):
123 |
124 | """
125 | Returns the number of batches
126 | """
127 | return int(np.ceil(len(self.ids)/float(self.batch_size)))
128 |
--------------------------------------------------------------------------------
/Unet.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.keras.models import Model
3 | from tensorflow.keras.layers import Input, Conv2D
4 |
5 | from UnetUtils import UnetUtils
  6 | unetUtils = UnetUtils()
7 |
8 | class Unet():
9 |
10 | """
11 | Unet Model design.
12 |
 13 |     This module consumes the UnetUtils utilities module and builds the Unet network.
14 | It consists of a contracting path and an expansive path. Both these paths are joined
15 | by a bottleneck block.
16 |
17 | The different blocks involved in the design of the network can be referenced @
18 | U-Net: Convolutional Networks for Biomedical Image Segmentation
19 |
20 | Source:
21 | https://arxiv.org/pdf/1505.04597
22 | """
23 |
24 | def __init__(self, input_shape = (572, 572, 1), filters = [64, 128, 256, 512, 1024], padding = "valid"):
25 | """
26 |
27 | Initialize the Unet framework and the model parameters - input_shape,
28 | filters and padding type.
29 |
30 | Args:
 31 |             input_shape: The shape of the input to the network. A tuple of (img_height, img_width, channels).
32 | Original paper implementation is (572, 572, 1).
 33 |             filters: a collection denoting the number of filters to be used at each block along the
 34 |                 contracting and expansive paths. The original paper implementation uses
 35 |                 [64, 128, 256, 512, 1024] filters along the contracting and expansive paths.
 36 |             padding: the padding type to be used during the convolution step. The original paper used
 37 |                 unpadded convolutions, i.e., padding of type "valid".
38 |
39 | **Remarks: The default values are as per the implementation in the original paper @ https://arxiv.org/pdf/1505.04597
40 |
41 | """
42 | self.input_shape = input_shape
43 | self.filters = filters
44 | self.padding = padding
45 |
46 | def Build_UNetwork(self):
47 |
48 | """
49 | Builds the Unet Model network.
50 |
51 | Args:
52 | None
53 |
54 | Return:
55 | The Unet Model.
56 |
57 | """
58 |
59 |
60 | UnetInput = Input(self.input_shape)
61 |
 62 |         # the contracting path.
 63 |         # the last item in the filters collection is the number of filters used in the bottleneck block.
 64 |         conv1, pool1 = unetUtils.contracting_block(input_layer = UnetInput, filters = self.filters[0], padding = self.padding)
 65 |         conv2, pool2 = unetUtils.contracting_block(input_layer = pool1, filters = self.filters[1], padding = self.padding)
 66 |         conv3, pool3 = unetUtils.contracting_block(input_layer = pool2, filters = self.filters[2], padding = self.padding)
 67 |         conv4, pool4 = unetUtils.contracting_block(input_layer = pool3, filters = self.filters[3], padding = self.padding)
 68 |
 69 |         # bottleneck block connecting the contracting and the expansive paths.
 70 |         bottleNeck = unetUtils.bottleneck_block(pool4, filters = self.filters[4], padding = self.padding)
 71 |
 72 |         # the expansive path.
 73 |         upConv1 = unetUtils.expansive_block(bottleNeck, conv4, filters = self.filters[3], padding = self.padding)
 74 |         upConv2 = unetUtils.expansive_block(upConv1, conv3, filters = self.filters[2], padding = self.padding)
 75 |         upConv3 = unetUtils.expansive_block(upConv2, conv2, filters = self.filters[1], padding = self.padding)
 76 |         upConv4 = unetUtils.expansive_block(upConv3, conv1, filters = self.filters[0], padding = self.padding)
77 |
78 | UnetOutput = Conv2D(1, (1, 1), padding = self.padding, activation = tf.math.sigmoid)(upConv4)
79 |
80 | model = Model(UnetInput, UnetOutput, name = "UNet")
81 |
82 | return model
83 |
84 | def CompileAndSummarizeModel(self, model, optimizer = "adam", loss = "binary_crossentropy"):
85 |
86 | """
87 | Compiles and displays the model summary of the Unet model.
88 |
89 | Args:
90 | model: The Unet model.
91 | optimizer: model optimizer. Default is the adam optimizer.
92 | loss: the loss function. Default is the binary cross entropy loss.
93 |
94 | Return:
95 | None
96 |
97 | """
98 | model.compile(optimizer = optimizer, loss = loss, metrics = ["acc"])
99 | model.summary()
100 |
101 | def plotModel(self, model, to_file = 'unet.png', show_shapes = True, dpi = 96):
102 |
103 | """
104 |         Saves a plot of the Unet model architecture to an image file.
105 |
106 | Args:
107 | model: the Unet model.
108 |             to_file: the file name to save the plot to. Default name - 'unet.png'.
109 | show_shapes: whether to display shape information. Default = True.
110 | dpi: dots per inch. Default value is 96.
111 |
112 | Return:
113 | None
114 |
115 | """
116 |
117 | tf.keras.utils.plot_model(model, to_file = to_file, show_shapes = show_shapes, dpi = dpi)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # U-Net---Biomedical-Image-Segmentation
2 |
3 | **Implementation of the paper titled - U-Net: Convolutional Networks for Biomedical Image Segmentation**
4 |
5 | ## Original Paper
6 | The original paper can be accessed @ https://arxiv.org/abs/1505.04597
7 |
8 | ## Why Unet?
  9 | *UNet, a convolutional neural network dedicated to biomedical image segmentation, was first designed and applied in 2015. The typical use case for a convolutional neural network is an image classification task, where
 10 | the output for an image is a single class label. Biomedical image tasks, however, require not only distinguishing whether a medical condition is present, but also localizing the affected area, i.e., assigning a class label to each pixel.*
11 |
12 | ## UNet Architecture
13 |
 14 | The complete architecture of the UNet network is shown in the image below -
15 | 
16 | Image Reference - https://arxiv.org/pdf/1505.04597.pdf
17 |
 18 | **The Unet network model has 3 parts:**
19 |
20 | - The Contracting/Downsampling Path.
21 | - Bottleneck Block.
22 | - The Expansive/Upsampling Path.
23 |
24 | ### Contracting Path:
 25 | It consists of the repeated application of two 3x3 unpadded convolutions, each followed by a rectified linear unit (ReLU), and a 2x2 max pooling operation with stride 2 for downsampling. After each downsampling operation, the number of feature channels is doubled. The following images show the input block and the contracting path, followed by a short code sketch -
26 |
27 | **Input Block -**
28 |
29 | 
30 |
31 | **Contracting Path -**
32 |
33 | 
34 |
35 | ### Bottleneck Block:
 36 | The bottleneck block connects the contracting and the expansive paths. It performs two unpadded convolutions, each with 1024 filters and each followed by a ReLU, and prepares the feature maps for the expansive path. The image below shows the bottleneck block -
37 |
38 | 
39 |
40 | ### Expansive Path:
 41 | Every step in the expansive path consists of an upsampling of the feature map followed by a 2x2 convolution (“up-convolution”) implemented here with transposed convolutions, a concatenation with the correspondingly cropped feature map from the contracting path, and two 3x3 convolutions, each followed by a ReLU. Transposed convolution is an upsampling technique that expands the spatial size of feature maps. The image below shows the expansive path of the network -
42 |
43 | 
44 |
45 | ### Skip Connections:
46 | The skip connections from the contracting path are concatenated with the corresponding feature maps in the expansive path. These skip connections provide higher resolution features to better localize and learn representations from the input image. They also help in recovering any spatial information that could have been lost during downsampling. The image below shows one skip connection between the contracting path and the expansive path -
47 |
48 | 
49 |
50 | ### Final Layer:
 51 | At the final layer a 1x1 convolution is used to map each 64-component feature vector to the desired number of classes, as sketched below.
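
A minimal runnable sketch of this final mapping; the 388x388x64 input is the decoder output size from the paper, and binary segmentation (one sigmoid channel) is assumed, as in this repository:

```python
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D
from tensorflow.keras.models import Model

decoder_out = Input(shape=(388, 388, 64))
# a 1x1 convolution maps each 64-component feature vector to a single
# class score per pixel; a sigmoid yields a binary segmentation mask.
mask = Conv2D(1, (1, 1), activation="sigmoid")(decoder_out)
model = Model(decoder_out, mask)
```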
52 |
 53 | ***The entire network consists of a total of 23 convolutional layers.***
54 |
55 | ## Original Implementation
56 |
57 | The original UNet model's implementation as described in the [paper](https://arxiv.org/pdf/1505.04597.pdf) can be found @ [UNet - Biomedical_Segmentation](https://github.com/sauravmishra1710/U-Net---Biomedical-Image-Segmentation/blob/main/UNet%20-%20Biomedical_Segmentation.ipynb)
58 |
59 | ## Application of UNet
60 |
 61 | An application of the UNet model is implemented @ [UNet In Action](https://github.com/sauravmishra1710/U-Net---Biomedical-Image-Segmentation/blob/main/UNet%20In%20Action.ipynb). The objective of the task is to segment and identify the cell nuclei. The dataset is taken from a Kaggle challenge - [2018 Data Science Bowl - Find the nuclei in divergent images to advance medical discovery](https://www.kaggle.com/c/data-science-bowl-2018/). A minimal wiring sketch is shown below.
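
The sketch wires `Unet.py` and `UNetDataGenerator.py` together. `TRAIN_PATH`, `train_ids`, and the epoch count are placeholders, and the 128x128 input with "same" padding is an assumption for this dataset rather than the paper's 572x572 "valid" setup:

```python
import os
from Unet import Unet
from UNetDataGenerator import NucleiDataGenerator

TRAIN_PATH = "data-science-bowl-2018/stage1_train/"  # placeholder path
train_ids = os.listdir(TRAIN_PATH)                   # one directory per image id

# "same" padding keeps the encoder/decoder shapes aligned for 128x128 inputs.
unet = Unet(input_shape=(128, 128, 3), padding="same")
model = unet.Build_UNetwork()
unet.CompileAndSummarizeModel(model)

train_gen = NucleiDataGenerator(train_ids, TRAIN_PATH, batch_size=8, image_size=128)
model.fit(train_gen, epochs=25)
```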
62 |
63 | ## Results and Conclusion
64 |
65 | Comparing the microscopic image, original mask and the predicted mask, it looks like the model is correctly able to segment the cell nuclei and generate the masks. The predictions for the masks as generated by the trained UNet model on the cellular nuclei images are as seen in the images below -
66 |
67 | ### Prediction 1
68 |
69 | 
70 |
71 | ### Prediction 2
72 |
73 | 
74 |
75 | ### Prediction 3
76 |
77 | 
78 |
 79 | Though UNet was originally designed for biomedical images, this model can be applied to any computer vision segmentation task.
80 |
81 | ## Further Reading
82 | 1. UNet++: A Nested U-Net Architecture for Medical Image Segmentation. The original paper is available at https://arxiv.org/abs/1807.10165.
83 | 2. UNet++: Redesigning Skip Connections to Exploit Multiscale Features in Image Segmentation. The original paper is available at https://arxiv.org/abs/1912.05074v2.
84 | 3. Attention U-Net: Learning Where to Look for the Pancreas. The original paper is available at https://arxiv.org/abs/1804.03999.
85 | 4. TernausNet: U-Net with VGG11 Encoder Pre-Trained on ImageNet for Image Segmentation. The original paper is available at https://arxiv.org/abs/1801.05746.
86 | 5. U-Net and its variants for medical image segmentation: theory and applications. The original paper is available at https://arxiv.org/abs/2011.01118.
87 |
88 |
89 | ## References
90 | 1. Ronneberger, O., Fischer, P. and Brox, T. (2015) ‘U-Net: Convolutional Networks for Biomedical Image Segmentation’, CoRR, abs/1505.0. Available at: http://arxiv.org/abs/1505.04597.
 91 | 2. Implementing original U-Net from scratch using PyTorch by Abhishek Thakur. Available at: https://www.youtube.com/watch?v=u1loyDCoGbE.
92 |
--------------------------------------------------------------------------------
/UnetUtils.py:
--------------------------------------------------------------------------------
  1 | import tensorflow as tf
  2 | from tensorflow.keras.layers import Conv2D, Conv2DTranspose, MaxPooling2D, Concatenate
  3 |
4 |
5 | class UnetUtils():
6 |
7 | """
  8 |     Unet Model design utilities framework.
9 |
10 | This module provides a convenient way to create different layers/blocks
11 | which the UNet network is based upon. It consists of a contracting
12 | path and an expansive path. Both these paths are joined by a bottleneck block.
13 |
14 | The different blocks involved in the design of the network can be referenced @
15 | U-Net: Convolutional Networks for Biomedical Image Segmentation
16 |
17 | Source:
18 | https://arxiv.org/pdf/1505.04597
19 | """
20 |
21 | def __init__(self):
22 | pass
23 |
24 | def contracting_block(self, input_layer, filters, padding, kernel_size = 3):
25 |
26 | """
27 | UNet Contracting block
 28 |         Performs two unpadded convolutions with a specified number of filters and downsamples
 29 |         through max pooling.
30 |
31 | Args:
32 | input_layer: the input layer on which the current layers should work upon.
33 | filters (int): Number of filters in convolution.
 34 |             kernel_size (int/tuple): Size of the convolution kernel. Default is 3.
35 | padding ("valid" or "same"): Default is "valid" (no padding involved).
36 |
37 | Return:
 38 |             Tuple of the convolved ``inputs`` before and after downsampling.
39 | """
40 |
41 | # two 3x3 convolutions (unpadded convolutions), each followed by
42 | # a rectified linear unit (ReLU) and a 2x2 max pooling operation with stride 2
43 | # for downsampling.
44 | conv = Conv2D(filters = filters,
45 | kernel_size = kernel_size,
46 | activation = tf.nn.relu,
47 | padding = padding)(input_layer)
48 |
49 | conv = Conv2D(filters = filters,
50 | kernel_size = kernel_size,
51 | activation = tf.nn.relu,
52 | padding = padding)(conv)
53 |
54 | pool = MaxPooling2D(pool_size = 2,
55 | strides = 2)(conv)
56 |
57 | return conv, pool
58 |
59 | def bottleneck_block(self, input_layer, filters, padding, kernel_size = 3, strides = 1):
60 |
61 | """
62 | UNet bottleneck block
63 |
64 | Performs 2 unpadded convolutions with a specified number of filters.
65 |
66 | Args:
67 | input_layer: the input layer on which the current layers should work upon.
68 | filters (int): Number of filters in convolution.
 69 |             kernel_size (int/tuple): Size of the convolution kernel. Default is 3.
70 | padding ("valid" or "same"): Default is "valid" (no padding involved).
71 | strides: An integer or tuple/list of 2 integers, specifying the strides
72 | of the convolution along the height and width. Default is 1.
73 | Return:
74 | The convolved ``inputs``.
75 | """
76 |
77 | # two 3x3 convolutions (unpadded convolutions), each followed by
78 | # a rectified linear unit (ReLU)
79 | conv = Conv2D(filters = filters,
80 | kernel_size = kernel_size,
81 | padding = padding,
82 | strides = strides,
83 | activation = tf.nn.relu)(input_layer)
84 |
85 | conv = Conv2D(filters = filters,
86 | kernel_size = kernel_size,
87 | padding = padding,
88 | strides = strides,
89 | activation = tf.nn.relu)(conv)
90 |
91 | return conv
92 |
93 | def expansive_block(self, input_layer, skip_conn_layer, filters, padding, kernel_size = 3, strides = 1):
94 |
95 | """
96 | UNet expansive (upsample) block.
97 |
 98 |         A transposed convolution doubles the spatial dimensions (height and width)
 99 |         of the incoming feature maps, and the result is concatenated with the corresponding
100 |         feature maps from the contracting (downsample) path. These skip connections bring in
101 |         feature maps from earlier layers, helping the network generate better semantic feature maps.
102 |
103 |         Upsamples the incoming feature map and then performs two unpadded
104 |         convolutions with a specified number of filters.
105 |
106 | Args:
107 | input_layer: the input layer on which the current layers should work upon.
108 |             skip_conn_layer: The feature map from the contracting (downsample) path from which the
109 | skip connection has to be created.
110 | filters (int): Number of filters in convolution.
111 |             kernel_size (int/tuple): Size of the convolution kernel. Default is 3.
112 | padding ("valid" or "same"): Default is "valid" (no padding involved).
113 | strides: An integer or tuple/list of 2 integers, specifying the strides
114 | of the convolution along the height and width. Default is 1.
115 |
116 | Return:
117 | The upsampled feature map.
118 | """
119 |
120 | # up sample the feature map using transpose convolution operations.
121 | transConv = Conv2DTranspose(filters = filters,
122 | kernel_size = (2, 2),
123 | strides = 2,
124 | padding = padding)(input_layer)
125 |
126 | # crop the source feature map so that the skip connection can be established.
127 | # the original paper implemented unpadded convolutions. So cropping is necessary
128 | # due to the loss of border pixels in every convolution.
129 | # establish the skip connections.
130 | if padding == "valid":
131 | cropped = self.crop_tensor(skip_conn_layer, transConv)
132 | concat = Concatenate()([transConv, cropped])
133 | else:
134 | concat = Concatenate()([transConv, skip_conn_layer])
135 |
136 | # two 3x3 convolutions, each followed by a ReLU
137 | up_conv = Conv2D(filters = filters,
138 | kernel_size = kernel_size,
139 | padding = padding,
140 | activation = tf.nn.relu)(concat)
141 |
142 | up_conv = Conv2D(filters = filters,
143 | kernel_size = kernel_size,
144 | padding = padding,
145 | activation = tf.nn.relu)(up_conv)
146 |
147 | return up_conv
148 |
149 | def crop_tensor(self, source_tensor, target_tensor):
150 |
151 | """
152 |         Center crops the source tensor to the size of the target tensor.
153 | The tensor shape format is [batchsize, height, width, channels]
154 |
155 | Args:
156 | source_tensor: the tensor that is to be cropped.
157 | target_tensor: the tensor to whose size the
158 | source needs to be cropped to.
159 |
160 | Return:
161 | the cropped version of the source tensor.
162 |
163 | """
164 |
165 | target_tensor_size = target_tensor.shape[2]
166 | source_tensor_size = source_tensor.shape[2]
167 |
168 |         # calculate the crop margin per side; assumes square feature maps with an even size difference.
169 | delta = source_tensor_size - target_tensor_size
170 | delta = delta // 2
171 |
172 | cropped_source = source_tensor[:, delta:source_tensor_size - delta, delta:source_tensor_size - delta, :]
173 |
174 | return cropped_source
--------------------------------------------------------------------------------