├── MANIFEST.in
├── README.md
├── example_images
│   ├── 3VV.tiff
│   ├── 4CH.tiff
│   ├── ABDOMINAL.tiff
│   ├── BRAIN-CB.tiff
│   ├── BRAIN-TV.tiff
│   ├── FEMUR.tiff
│   ├── KIDNEYS.tiff
│   ├── LIPS.tiff
│   ├── LVOT.tiff
│   ├── PROFILE.tiff
│   ├── RVOT.tiff
│   ├── SPINE-CORONAL.tiff
│   └── SPINE-SAGITTAL.tiff
├── setup.py
├── sononet
│   ├── SonoNet16.pth
│   ├── SonoNet32.pth
│   ├── SonoNet64.pth
│   ├── __init__.py
│   └── sononet.py
└── test.py

/MANIFEST.in:
--------------------------------------------------------------------------------
include sononet/SonoNet16.pth
include sononet/SonoNet32.pth
include sononet/SonoNet64.pth
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# PyTorch SonoNet

**Disclaimer**
These files come without any warranty!
In particular, there might be unforeseen differences from the original implementation.

### About this repository

This is a PyTorch implementation of SonoNet:

Baumgartner et al., "Real-Time Detection and Localisation of Fetal Standard Scan Planes in 2D Freehand Ultrasound", arXiv preprint:1612.05601 (2016)

This repository is based on https://github.com/baumgach/SonoNet-weights, which provides a Theano+Lasagne implementation.

### Files
sononet/sononet.py:
PyTorch implementation of the original models.py file.

sononet/SonoNet16.pth, sononet/SonoNet32.pth, sononet/SonoNet64.pth:
The original pretrained weights converted into PyTorch format.

test.py:
Modified version of the original example.py file. It runs classification on the example images.

### Dependencies
NumPy, Pillow, Matplotlib, PyTorch.
Tested with PyTorch 0.4.0 and 1.3.1.

### Installing as a Python module

After installing the dependencies, run
```
cd SonoNet_PyTorch
pip install .
```
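
As a quick sanity check of the install (this one-liner is an illustration, not part of the original instructions), a randomly initialized network can be constructed and printed without touching the bundled weight files:
```
python -c "from sononet import SonoNet; print(SonoNet('SN16', weights=False))"
```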

### Usage
After installing the dependencies, classify the example images by running the following from the repository root (test.py reads images from ./example_images):
```
python test.py
```

After installing as a Python module (see above), import SonoNet with:
```
from sononet import SonoNet
```
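
A minimal inference sketch (the random input below stands in for a real image; real images should be cropped, resized to 224x288 and normalised as in test.py):
```
import torch
from sononet import SonoNet

net = SonoNet('SN64').eval()     # loads the bundled SonoNet64 weights
x = torch.randn(1, 1, 224, 288)  # (batch, channels, height, width)
with torch.no_grad():
    probs = net(x)               # softmax over the 14 scan-plane labels
print(probs.argmax(dim=1))

# For feature extraction instead of classification:
encoder = SonoNet('SN64', features_only=True).eval()
features = encoder(x)            # output of the last convolutional block
```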
--------------------------------------------------------------------------------
/example_images/3VV.tiff:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rdroste/SonoNet_PyTorch/58e9f636be6fbfd87fbd1a27520b872828afa218/example_images/3VV.tiff
--------------------------------------------------------------------------------
/example_images/4CH.tiff:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rdroste/SonoNet_PyTorch/58e9f636be6fbfd87fbd1a27520b872828afa218/example_images/4CH.tiff
--------------------------------------------------------------------------------
/example_images/ABDOMINAL.tiff:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rdroste/SonoNet_PyTorch/58e9f636be6fbfd87fbd1a27520b872828afa218/example_images/ABDOMINAL.tiff
--------------------------------------------------------------------------------
/example_images/BRAIN-CB.tiff:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rdroste/SonoNet_PyTorch/58e9f636be6fbfd87fbd1a27520b872828afa218/example_images/BRAIN-CB.tiff
--------------------------------------------------------------------------------
/example_images/BRAIN-TV.tiff:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rdroste/SonoNet_PyTorch/58e9f636be6fbfd87fbd1a27520b872828afa218/example_images/BRAIN-TV.tiff
--------------------------------------------------------------------------------
/example_images/FEMUR.tiff:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rdroste/SonoNet_PyTorch/58e9f636be6fbfd87fbd1a27520b872828afa218/example_images/FEMUR.tiff
--------------------------------------------------------------------------------
/example_images/KIDNEYS.tiff:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rdroste/SonoNet_PyTorch/58e9f636be6fbfd87fbd1a27520b872828afa218/example_images/KIDNEYS.tiff
--------------------------------------------------------------------------------
/example_images/LIPS.tiff:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rdroste/SonoNet_PyTorch/58e9f636be6fbfd87fbd1a27520b872828afa218/example_images/LIPS.tiff
--------------------------------------------------------------------------------
/example_images/LVOT.tiff:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rdroste/SonoNet_PyTorch/58e9f636be6fbfd87fbd1a27520b872828afa218/example_images/LVOT.tiff
--------------------------------------------------------------------------------
/example_images/PROFILE.tiff:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rdroste/SonoNet_PyTorch/58e9f636be6fbfd87fbd1a27520b872828afa218/example_images/PROFILE.tiff
--------------------------------------------------------------------------------
/example_images/RVOT.tiff:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rdroste/SonoNet_PyTorch/58e9f636be6fbfd87fbd1a27520b872828afa218/example_images/RVOT.tiff
--------------------------------------------------------------------------------
/example_images/SPINE-CORONAL.tiff:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rdroste/SonoNet_PyTorch/58e9f636be6fbfd87fbd1a27520b872828afa218/example_images/SPINE-CORONAL.tiff
--------------------------------------------------------------------------------
/example_images/SPINE-SAGITTAL.tiff:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rdroste/SonoNet_PyTorch/58e9f636be6fbfd87fbd1a27520b872828afa218/example_images/SPINE-SAGITTAL.tiff
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
# include_package_data requires setuptools; distutils ignores the option
from setuptools import setup

setup(
    name='SonoNet_PyTorch',
    version='0.1',
    packages=['sononet'],
    include_package_data=True,
    long_description=open('README.md').read(),
)
--------------------------------------------------------------------------------
/sononet/SonoNet16.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rdroste/SonoNet_PyTorch/58e9f636be6fbfd87fbd1a27520b872828afa218/sononet/SonoNet16.pth
--------------------------------------------------------------------------------
/sononet/SonoNet32.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rdroste/SonoNet_PyTorch/58e9f636be6fbfd87fbd1a27520b872828afa218/sononet/SonoNet32.pth
--------------------------------------------------------------------------------
/sononet/SonoNet64.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rdroste/SonoNet_PyTorch/58e9f636be6fbfd87fbd1a27520b872828afa218/sononet/SonoNet64.pth
--------------------------------------------------------------------------------
/sononet/__init__.py:
--------------------------------------------------------------------------------
from .sononet import SonoNet
--------------------------------------------------------------------------------
/sononet/sononet.py:
--------------------------------------------------------------------------------
"""PyTorch implementation of SonoNet.

Baumgartner et al., "Real-Time Detection and Localisation of Fetal Standard
Scan Planes in 2D Freehand Ultrasound", arXiv preprint:1612.05601 (2016)

This repository is based on https://github.com/baumgach/SonoNet-weights which
provides a Theano+Lasagne implementation.
"""
import math
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class SonoNet(nn.Module):
    """Real-Time Detection of Freehand Fetal Ultrasound Standard Scan Planes

    PyTorch implementation of the original models.py file, plus functions to
    load and convert the Lasagne weights to a PyTorch state_dict.

    Args:
        config (str): Selects the architecture.
            Options are 'SN16', 'SN32' or 'SN64'.
        num_labels (int, optional): Length of the output vector after
            adaption. Default is 14. Ignored if features_only=True.
        weights (bool, str or None, optional): Selects weight initialization.
            True: Load weights from the default *.pth weight file.
            str: Load weights from the given weight file (*.pth, or the
                original Lasagne *.npz, which is converted on the fly).
            False or 0: Standard random weight initialization.
            None: Weights are not initialized at all.
            Default is True.
        features_only (bool, optional): If True, only the feature layers are
            initialized and the forward method returns the features.
            Default is False.
        in_channels (int, optional): Number of input channels. Default is 1.

    Attributes:
        feature_channels (int): Number of feature channels.
        features (torch.nn.Sequential): Feature extraction CNN.
        adaption (torch.nn.Sequential): Adaption layers for classification.

    Examples::
        >>> net = sononet.SonoNet('SN64').eval().cuda()
        >>> outputs = net(x)

        >>> encoder = sononet.SonoNet('SN64', features_only=True).eval().cuda()
        >>> features = encoder(x)

    Note:
        Inputs to the forward method must be preprocessed as shown in test.py.
    """

    # Conv output channels per layer; 'M' marks a 2x2 max-pooling layer.
    feature_cfg_dict = {
        'SN16': [16, 16, 'M', 32, 32, 'M', 64, 64, 64, 'M',
                 128, 128, 128, 'M', 128, 128, 128],
        'SN32': [32, 32, 'M', 64, 64, 'M', 128, 128, 128, 'M',
                 256, 256, 256, 'M', 256, 256, 256],
        'SN64': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
                 512, 512, 512, 'M', 512, 512, 512]
    }

    def __init__(self, config, num_labels=14, weights=True,
                 features_only=False, in_channels=1):
        super().__init__()
        self.config = config
        self.feature_cfg = self.feature_cfg_dict[config]
        self.feature_channels = self.feature_cfg[-1]
        self.weights = weights
        self.features_only = features_only
        self.features = self._make_feature_layers(self.feature_cfg, in_channels)
        if not features_only:
            self.adaption_channels = self.feature_channels // 2
            self.num_labels = num_labels
            self.adaption = self._make_adaption_layer(
                self.feature_channels, self.adaption_channels, self.num_labels)
        self.set_weights(weights)

    def forward(self, x):
        x = self.features(x)
        if not self.features_only:
            x = self.adaption(x)
            x = F.avg_pool2d(x, x.size()[2:]).view(x.size(0), -1)
            x = F.softmax(x, dim=1)
        return x
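
    # The classification head is fully convolutional: two 1x1 convolutions map
    # the feature maps to one activation map per label, and forward() applies
    # global average pooling followed by a softmax, so no fixed input size is
    # baked into the network.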

    @staticmethod
    def _make_adaption_layer(feature_channels, adaption_channels, num_labels):
        return nn.Sequential(
            nn.Conv2d(feature_channels, adaption_channels, 1, bias=False),
            nn.BatchNorm2d(adaption_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(adaption_channels, num_labels, 1, bias=False),
            nn.BatchNorm2d(num_labels),
        )

    def set_weights(self, weights):
        if weights is not None:
            if weights:  # True or a file path: load pretrained weights
                if not isinstance(weights, str):
                    weights = os.path.join(
                        os.path.dirname(__file__),
                        'SonoNet{}.pth'.format(self.config[2:]))
                self.load_weights(weights)
            else:  # False or 0: random initialization
                self.apply(self._initialize_weights)

    @staticmethod
    def _initialize_weights(m):
        if isinstance(m, nn.Conv2d):
            # He initialization for conv layers
            n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            m.weight.data.normal_(0, math.sqrt(2. / n))
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()
        elif isinstance(m, nn.Linear):
            m.weight.data.normal_(0, 0.01)
            m.bias.data.zero_()

    @staticmethod
    def _conv_layer(in_channels, out_channels):
        layer = [nn.Conv2d(in_channels, out_channels,
                           kernel_size=3, padding=1, bias=False),
                 nn.BatchNorm2d(out_channels, eps=1e-4),
                 nn.ReLU(inplace=True)]
        return nn.Sequential(*layer)

    @classmethod
    def _make_feature_layers(cls, feature_cfg, in_channels):
        # Group conv layers into blocks, each closed by a max-pooling layer.
        layers = []
        conv_layers = []
        for v in feature_cfg:
            if v == 'M':
                conv_layers.append(nn.MaxPool2d(2))
                layers.append(nn.Sequential(*conv_layers))
                conv_layers = []
            else:
                conv_layers.append(cls._conv_layer(in_channels, v))
                in_channels = v
        layers.append(nn.Sequential(*conv_layers))
        return nn.Sequential(*layers)
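
    # --- Lasagne weight conversion ---
    # Lasagne lists each conv block's parameters as (conv W, BN beta,
    # BN gamma, BN mean, BN inv_std), whereas the PyTorch state_dict expects
    # (conv W, BN gamma, BN beta, mean, var); hence the [0, 2, 1, 3, 4]
    # reordering below. The running variance is recovered from the inverse
    # standard deviation via var = inv_std**-2 - eps, with eps = 1e-4
    # matching the BatchNorm2d layers above.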

    @staticmethod
    def process_lasagne_weights(weights):
        order = [0, 2, 1, 3, 4]
        weights = [weights[5 * (idx // 5) + order[idx % 5]]
                   for idx in range(len(weights))]
        weights[4::5] = [np.power(w, -2) - 1e-4 for w in weights[4::5]]
        return weights

    @classmethod
    def load_lasagne_weights(cls, filename, state):
        with np.load(filename) as f:
            weight_data = [f['arr_%d' % i] for i in range(len(f.files))]
        weight_data = cls.process_lasagne_weights(weight_data)
        offset = 0
        for idx, layer in enumerate(state):
            if 'num_batches_tracked' in layer:
                # num_batches_tracked buffers (PyTorch >= 1.0) have no
                # Lasagne counterpart; skip them and keep weight_data aligned.
                offset -= 1
                continue
            # Sanity check, disabled to allow non-default input channels:
            # assert tuple(state[layer].shape) == weight_data[idx + offset].shape
            state[layer] = torch.from_numpy(weight_data[idx + offset].copy())
        return state

    @staticmethod
    def save_state(state, filename):
        if (not os.path.isfile(filename) or
                input('Overwrite state file?\nHit [y] to continue: ') == 'y'):
            torch.save(state, filename)
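
    # Accepts either a converted PyTorch *.pth state_dict or the original
    # Lasagne *.npz parameter file; *.npz weights are converted and cached
    # next to this module as SonoNet*.pth.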
    def load_weights(self, weights):
        _, extension = os.path.splitext(weights)
        if extension == '.npz':
            state = self.state_dict()
            state = self.load_lasagne_weights(weights, state)
            self.save_state(
                state, os.path.join(os.path.dirname(__file__),
                                    'SonoNet{}.pth'.format(self.config[2:])))
        elif extension == '.pth':
            state = torch.load(weights)
        else:
            raise ValueError('Unknown weight file extension {}'
                             .format(extension))
        # Check if input channels match
        # for key in self.state_dict():
        #     size = self.state_dict()[key].size()
        #     if state[key].size() != size:
        #         expand = state[key].expand(size)
        #         state[key] = expand * state[key].norm() / expand.norm()
        self.load_state_dict(state, strict=True)
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
'''
test.py: Modified version of the original example.py file.
This file runs classification on the example images.
'''

import glob
import os

import matplotlib.pyplot as plt
from matplotlib.pyplot import imread
import numpy as np
from PIL import Image
import sononet
import torch
from torch.autograd import Variable


disclaimer = '''
This is a PyTorch implementation of SonoNet:

Baumgartner et al., "Real-Time Detection and Localisation of Fetal Standard
Scan Planes in 2D Freehand Ultrasound", arXiv preprint:1612.05601 (2016)

This repository is based on
https://github.com/baumgach/SonoNet-weights
which provides a Theano+Lasagne implementation.
'''
print(disclaimer)

# Configuration
network_name = 'SN64'   # 'SN16', 'SN32' or 'SN64'
display_images = False  # Whether or not to show the images during inference
GPU_NR = 0              # Choose the device number of your GPU

# If you provide the original Lasagne parameter file, it will be converted to
# a PyTorch state_dict and saved as *.pth.
# In this repository, the converted parameters are already provided.
weights = True
# weights = ('/local/ball4916/dphil/SonoNet/SonoNet-weights/SonoNet{}.npz'
#            .format(network_name[2:]))


# Other parameters
crop_range = [(115, 734), (81, 874)]  # [(top, bottom), (left, right)]
input_size = [224, 288]               # [height, width]
image_path = './example_images/*.tiff'
label_names = ['3VV',
               '4CH',
               'Abdominal',
               'Background',
               'Brain (Cb.)',
               'Brain (Tv.)',
               'Femur',
               'Kidneys',
               'Lips',
               'LVOT',
               'Profile',
               'RVOT',
               'Spine (cor.)',
               'Spine (sag.)']


def imcrop(image, crop_range):
    """Crop an image to a crop range."""
    return image[crop_range[0][0]:crop_range[0][1],
                 crop_range[1][0]:crop_range[1][1], ...]


def prepare_inputs():
    input_list = []
    for filename in glob.glob(image_path):

        # prepare images
        image = imread(filename)           # read
        image = imcrop(image, crop_range)  # crop
        # PIL's resize expects (width, height), hence the reversed size
        image = np.array(Image.fromarray(image).resize(
            input_size[::-1], resample=Image.BICUBIC))
        image = np.mean(image, axis=2)     # convert to grayscale

        # convert to 4D tensor of type float32
        image_data = np.float32(np.reshape(image,
                                           (1, 1, image.shape[0],
                                            image.shape[1])))

        # normalise images by subtracting mean and dividing by standard dev.
        mean = image_data.mean()
        std = image_data.std()
        image_data = np.array(255.0 * np.divide(image_data - mean, std),
                              dtype=np.float32)
        # Note that the 255.0 scale factor is arbitrary. It is necessary
        # because the network was trained like this, but the same results
        # would have been achieved without this factor during training.

        input_list.append(image_data)

    return input_list


def main():

    print('Loading network')
    net = sononet.SonoNet(network_name, weights=weights)
    net.eval()

    print('Moving to GPU:')
    torch.cuda.set_device(GPU_NR)
    print(torch.cuda.get_device_name(torch.cuda.current_device()))
    net.cuda()

    print("\nPredictions using {}:".format(network_name))
    input_list = prepare_inputs()
    for image, file_name in zip(input_list, glob.glob(image_path)):

        # Variable is a no-op wrapper on PyTorch >= 0.4, kept for compatibility
        x = Variable(torch.from_numpy(image).cuda())
        outputs = net(x)
        confidence, prediction = torch.max(outputs.data, 1)
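        # outputs holds softmax probabilities (see SonoNet.forward), so
        # confidence is the probability assigned to the predicted label.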

        # True labels are obtained from the file name.
        true_label = os.path.basename(file_name)[0:-5]  # strip '.tiff'
        print(" - {} (conf: {:.2f}, true label: {})"
              .format(label_names[prediction[0]],
                      confidence[0], true_label))

        if display_images:
            plt.imshow(np.squeeze(image), cmap='gray')
            plt.show()


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------