├── p2ch07 ├── __init__.py ├── vis.py └── dsets.py ├── p2ch08 ├── __init__.py ├── prepcache.py ├── model.py ├── vis.py ├── dsets.py └── training.py ├── p2ch09 ├── __init__.py ├── prepcache.py ├── model.py ├── vis.py ├── dsets.py └── training.py ├── p2ch12 ├── __init__.py ├── prepcache.py ├── model.py ├── vis.py ├── dsets.py └── training.py ├── p2ch13 ├── __init__.py ├── model_seg.py ├── prepcache.py ├── model_cls.py ├── vis.py └── diagnose.py ├── util ├── __init__.py ├── logconf.py ├── disk.py ├── unet.py └── util.py ├── README.md ├── .gitignore ├── p1ch2 ├── 1_making_sure_things_work.ipynb └── 4_mnist.ipynb ├── p1ch4 ├── 5_image_dog.ipynb ├── 7_video_cockatoo.ipynb ├── 2_time_series_bikes.ipynb ├── 3_text_jane_austin.ipynb ├── 4_audio_chirp.ipynb └── 1_tabular_wine.ipynb ├── p1ch5 ├── 2_autograd.ipynb └── 3_optimizers.ipynb └── p1ch6 └── 3_nn_module_subclassing.ipynb /p2ch07/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /p2ch08/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /p2ch09/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /p2ch12/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /p2ch13/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # deep-learning-with-pytorch 2 | Code related to the book Deep Learning with PyTorch 3 | -------------------------------------------------------------------------------- /util/logconf.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import logging.handlers 3 | 4 | root_logger = logging.getLogger() 5 | root_logger.setLevel(logging.INFO) 6 | 7 | # Some libraries attempt to add their own root logger handlers. This is 8 | # annoying and so we get rid of them. 9 | # I commented this part. Is better to not modify the iterable object 10 | # of a for loop inside it, the code below should do the trick. 11 | # for handler in list(root_logger.handlers): 12 | # root_logger.removeHandler(handler) 13 | root_logger.handlers = [] 14 | 15 | logfmt_str = ("%(asctime)s %(levelname)-8s pid:%(process)d %(name)s" 16 | ":%(lineno)03d:%(funcName)s %(message)s") 17 | formatter = logging.Formatter(logfmt_str) 18 | 19 | # StreamHandler is located in the core logging package, sends logging 20 | # output to streams 21 | streamHandler = logging.StreamHandler() 22 | streamHandler.setFormatter(formatter) 23 | streamHandler.setLevel(logging.DEBUG) 24 | 25 | root_logger.addHandler(streamHandler) 26 | -------------------------------------------------------------------------------- /p2ch13/model_seg.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | from torch import nn 4 | 5 | from util.logconf import logging 6 | from util.unet import UNet 7 | 8 | log = logging.getLogger(__name__) 9 | # log = log.setLevel(logging.WARN) 10 | # log = log.setLevel(logging.INFO) 11 | log.setLevel(logging.DEBUG) 12 | 13 | 14 | class UnetWrapper(nn.Module): 15 | def __init__(self, **kwargs): 16 | super().__init__() 17 | 18 | self.input_batchnorm = nn.BatchNorm2d(kwargs['in_channels']) 19 | self.unet = UNet(**kwargs) 20 | 21 | self.final = nn.Sigmoid() 22 | 23 | self._init_weights() 24 | 25 | def _init_weights(self): 26 | init_set = {nn.Linear, nn.Conv3d, nn.Conv2d, nn.ConvTranspose2d, 27 | nn.ConvTranspose3d} 28 | for m in self.modules(): 29 | if type(m) in init_set: 30 | nn.init.kaiming_normal_(m.weight.data, mode='fan_out', 31 | nonlinearity='relu', a=0) 32 | if m.bias is not None: 33 | _, fan_out = nn.init._calculate_fan_in_and_fan_out( 34 | m.weight.data) 35 | bound = 1 / math.sqrt(fan_out) 36 | nn.init.normal_(m.bias, -bound, bound) 37 | 38 | def forward(self, input_batch): 39 | bn_output = self.input_batchnorm(input_batch) 40 | un_output = self.unet(bn_output) 41 | fn_output = self.final(un_output) 42 | 43 | return fn_output 44 | -------------------------------------------------------------------------------- /p2ch13/prepcache.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | 4 | from torch.utils.data import DataLoader 5 | 6 | from util.logconf import logging 7 | from util.util import enumerateWithEstimate 8 | from .dsets import PrepcacheLunaDataset 9 | 10 | 11 | log = logging.getLogger(__name__) 12 | # log.setLevel(logging.WARN) 13 | log.setLevel(logging.INFO) 14 | # log.setLevel(logging.DEBUG) 15 | 16 | 17 | class LunaPrepCacheApp(): 18 | def __init__(self, sys_argv=None): 19 | if sys_argv is None: 20 | sys_argv = sys.argv[1:] 21 | 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument( 24 | '--batch-size', 25 | help="Batch size to use for training", 26 | default=1024, 27 | type=int 28 | ) 29 | parser.add_argument( 30 | '--num-workers', 31 | help="Number of worker processes for background data loading", 32 | default=8, 33 | type=int 34 | ) 35 | self.cli_args = parser.parse_args(sys_argv) 36 | 37 | def main(self): 38 | log.info("Starting {}, {}".format(type(self).__name__, self.cli_args)) 39 | 40 | self.prep_dl = DataLoader( 41 | PrepcacheLunaDataset(sortby_str='series_uid'), 42 | batch_size=self.cli_args.batch_size, 43 | num_workers=self.cli_args.num_workers 44 | ) 45 | 46 | batch_iter = enumerateWithEstimate( 47 | self.prep_dl, 48 | "Stuffing cache", 49 | start_ndx=self.prep_dl.num_workers 50 | ) 51 | 52 | for _, _ in batch_iter: 53 | pass 54 | 55 | if __name__ == '__main__': 56 | LunaPrepCacheApp().main() 57 | -------------------------------------------------------------------------------- /p2ch08/prepcache.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | 4 | import numpy as np 5 | 6 | import torch.nn as nn 7 | from torch.autograd import Variable 8 | from torch.optim import SGD 9 | from torch.utils.data import DataLoader 10 | 11 | from util.util import enumerateWithEstimate 12 | from .dsets import LunaDataset 13 | from util.logconf import logging 14 | from .model import LunaModel 15 | 16 | log = logging.getLogger(__name__) 17 | log.setLevel(logging.INFO) 18 | 19 | 20 | class LunaPrepCacheApp(): 21 | @classmethod 22 | def __init__(self, sys_argv=None): 23 | if sys_argv is None: 24 | sys_argv = sys.argv[1:] 25 | 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument('--batch-size', 28 | help="Batch size to use for training", 29 | default=1024, 30 | type=int) 31 | parser.add_argument('--num-workers', 32 | help="Number of worker processes for background data loading", 33 | default=8, 34 | type=int) 35 | self.cli_args = parser.parse_args(sys_argv) 36 | 37 | def main(self): 38 | log.info("Starting {}, {}".format(type(self).__name__, self.cli_args)) 39 | 40 | self.prep_dl = DataLoader( 41 | LunaDataset(sortby_str='series_uid'), 42 | batch_size=self.cli_args.batch_size, 43 | num_workers=self.cli_args.num_workers 44 | ) 45 | 46 | batch_iter = enumerateWithEstimate( 47 | self.prep_dl, 48 | "Stuffing cache", 49 | start_ndx=self.prep_dl.num_workers 50 | ) 51 | 52 | for _ in batch_iter: 53 | pass 54 | 55 | if __name__ == '__main__': 56 | sys.exit(LunaPrepCacheApp().main() or 0) 57 | -------------------------------------------------------------------------------- /p2ch09/prepcache.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | 4 | import numpy as np 5 | 6 | import torch.nn as nn 7 | from torch.autograd import Variable 8 | from torch.optim import SGD 9 | from torch.utils.data import DataLoader 10 | 11 | from util.util import enumerateWithEstimate 12 | from .dsets import LunaDataset 13 | from util.logconf import logging 14 | from .model import LunaModel 15 | 16 | log = logging.getLogger(__name__) 17 | log.setLevel(logging.INFO) 18 | 19 | 20 | class LunaPrepCacheApp(): 21 | @classmethod 22 | def __init__(self, sys_argv=None): 23 | if sys_argv is None: 24 | sys_argv = sys.argv[1:] 25 | 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument( 28 | '--batch-size', 29 | help="Batch size to use for training", 30 | default=1024, 31 | type=int 32 | ) 33 | parser.add_argument( 34 | '--num-workers', 35 | help="Number of worker processes for background data loading", 36 | default=8, 37 | type=int 38 | ) 39 | self.cli_args = parser.parse_args(sys_argv) 40 | 41 | def main(self): 42 | log.info("Starting {}, {}".format(type(self).__name__, self.cli_args)) 43 | 44 | self.prep_dl = DataLoader( 45 | LunaDataset(sortby_str='series_uid'), 46 | batch_size=self.cli_args.batch_size, 47 | num_workers=self.cli_args.num_workers 48 | ) 49 | 50 | batch_iter = enumerateWithEstimate( 51 | self.prep_dl, 52 | "Stuffing cache", 53 | start_ndx=self.prep_dl.num_workers 54 | ) 55 | 56 | for _ in batch_iter: 57 | pass 58 | 59 | if __name__ == '__main__': 60 | sys.exit(LunaPrepCacheApp().main() or 0) 61 | -------------------------------------------------------------------------------- /p2ch12/prepcache.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | 4 | import numpy as np 5 | 6 | import torch.nn as nn 7 | from torch.autograd import Variable 8 | from torch.optim import SGD 9 | from torch.utils.data import DataLoader 10 | 11 | from util.util import enumerateWithEstimate 12 | from .dsets import LunaDataset 13 | from util.logconf import logging 14 | from .model import LunaModel 15 | 16 | log = logging.getLogger(__name__) 17 | # log.setLevel(logging.WARN) 18 | log.setLevel(logging.INFO) 19 | # log.setLevel(logging.DEBUG) 20 | 21 | 22 | class LunaPrepCacheApp(): 23 | @classmethod 24 | def __init__(self, sys_argv=None): 25 | if sys_argv is None: 26 | sys_argv = sys.argv[1:] 27 | 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument( 30 | '--batch-size', 31 | help="Batch size to use for training", 32 | default=1024, 33 | type=int 34 | ) 35 | parser.add_argument( 36 | '--num-workers', 37 | help="Number of worker processes for background data loading", 38 | default=8, 39 | type=int 40 | ) 41 | self.cli_args = parser.parse_args(sys_argv) 42 | 43 | def main(self): 44 | log.info("Starting {}, {}".format(type(self).__name__, self.cli_args)) 45 | 46 | self.prep_dl = DataLoader( 47 | LunaDataset(sortby_str='series_uid'), 48 | batch_size=self.cli_args.batch_size, 49 | num_workers=self.cli_args.num_workers 50 | ) 51 | 52 | batch_iter = enumerateWithEstimate( 53 | self.prep_dl, 54 | "Stuffing cache", 55 | start_ndx=self.prep_dl.num_workers 56 | ) 57 | 58 | for _ in batch_iter: 59 | pass 60 | 61 | if __name__ == '__main__': 62 | LunaPrepCacheApp().main() 63 | -------------------------------------------------------------------------------- /p2ch09/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from util.logconf import logging 5 | 6 | log = logging.getLogger(__name__) 7 | log.setLevel(logging.INFO) 8 | 9 | class LunaModel(nn.Module): 10 | def __init__(self, layer_count=4, in_channels=1, conv_channels=8): 11 | super().__init__() 12 | 13 | layer_list = [] 14 | for _ in range(layer_count): 15 | layer_list += [ 16 | nn.Conv3d(in_channels, conv_channels, kernel_size=3, 17 | padding=1, bias=False), 18 | nn.BatchNorm3d(conv_channels), # eli: will assume that p1ch6 doesn't use this 19 | nn.LeakyReLU(inplace=True), # will assume plan ReLU 20 | nn.Dropout3d(p=0.2), # eli: will assume that p1ch6 doesn't use this 21 | 22 | nn.Conv3d(conv_channels, conv_channels, kernel_size=3, 23 | padding=1, bias=False), 24 | nn.BatchNorm3d(conv_channels), 25 | nn.LeakyReLU(inplace=True), 26 | nn.Dropout3d(p=0.2), 27 | 28 | nn.MaxPool3d(2, 2), 29 | # tag::model_init[] 30 | ] 31 | 32 | in_channels = conv_channels 33 | conv_channels *= 2 34 | 35 | self.convAndPool_seq = nn.Sequential(*layer_list) 36 | self.fullyConnected_layer = nn.Linear(512, 1) 37 | self.final = nn.Hardtanh(min_val=0.0, max_val=1.0) 38 | 39 | def forward(self, input_batch): 40 | conv_output = self.convAndPool_seq(input_batch) 41 | conv_flat = conv_output.view(conv_output.size(0), -1) 42 | 43 | try: 44 | classifier_output = self.fullyConnected_layer(conv_flat) 45 | except: 46 | log.debug(conv_flat.size()) 47 | raise 48 | 49 | classifier_output = self.final(classifier_output) 50 | 51 | return classifier_output 52 | -------------------------------------------------------------------------------- /p2ch08/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from util.logconf import logging 5 | 6 | log = logging.getLogger(__name__) 7 | log.setLevel(logging.DEBUG) 8 | 9 | class LunaModel(nn.Module): 10 | def __init__(self, layer_count=4, in_channels=1, conv_channels=8): 11 | super().__init__() 12 | 13 | layer_list = [] 14 | for _ in range(layer_count): 15 | layer_list += [ 16 | nn.Conv3d(in_channels, conv_channels, kernel_size=3, 17 | padding=1, bias=False), 18 | nn.BatchNorm3d(conv_channels), # eli: will assume that p1ch6 doesn't use this 19 | nn.LeakyReLU(inplace=True), # will assume plan ReLU 20 | nn.Dropout3d(p=0.2), # eli: will assume that p1ch6 doesn't use this 21 | 22 | nn.Conv3d(conv_channels, conv_channels, kernel_size=3, 23 | padding=1, bias=False), 24 | nn.BatchNorm3d(conv_channels), 25 | nn.LeakyReLU(inplace=True), 26 | nn.Dropout3d(p=0.2), 27 | 28 | nn.MaxPool3d(2, 2), 29 | # tag::model_init[] 30 | ] 31 | 32 | in_channels = conv_channels 33 | conv_channels *= 2 34 | 35 | self.convAndPool_seq = nn.Sequential(*layer_list) 36 | self.fullyConnected_layer = nn.Linear(512, 1) 37 | self.final = nn.Hardtanh(min_val=0.0, max_val=1.0) 38 | 39 | def forward(self, input_batch): 40 | conv_output = self.convAndPool_seq(input_batch) 41 | conv_flat = conv_output.view(conv_output.size(0), -1) 42 | 43 | try: 44 | classifier_output = self.fullyConnected_layer(conv_flat) 45 | except: 46 | log.debug(conv_flat.size()) 47 | raise 48 | 49 | classifier_output = self.final(classifier_output) 50 | 51 | return classifier_output 52 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # data folder 107 | data 108 | data-unversioned 109 | 110 | # book 111 | Deep_Learning_with_PyTorch_*.pdf 112 | 113 | # original code 114 | dlwpt-code* 115 | -------------------------------------------------------------------------------- /p1ch2/1_making_sure_things_work.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 2, 15 | "metadata": {}, 16 | "outputs": [ 17 | { 18 | "data": { 19 | "text/plain": [ 20 | "'1.3.1'" 21 | ] 22 | }, 23 | "execution_count": 2, 24 | "metadata": {}, 25 | "output_type": "execute_result" 26 | } 27 | ], 28 | "source": [ 29 | "torch.version.__version__" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 3, 35 | "metadata": {}, 36 | "outputs": [ 37 | { 38 | "data": { 39 | "text/plain": [ 40 | "tensor([[2., 2., 2.],\n", 41 | " [2., 2., 2.],\n", 42 | " [2., 2., 2.]])" 43 | ] 44 | }, 45 | "execution_count": 3, 46 | "metadata": {}, 47 | "output_type": "execute_result" 48 | } 49 | ], 50 | "source": [ 51 | "a = torch.ones(3, 3)\n", 52 | "b = torch.ones(3, 3)\n", 53 | "\n", 54 | "a + b" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 4, 60 | "metadata": {}, 61 | "outputs": [ 62 | { 63 | "name": "stdout", 64 | "output_type": "stream", 65 | "text": [ 66 | "Cuda is not available.\n" 67 | ] 68 | } 69 | ], 70 | "source": [ 71 | "if torch.cuda.is_available():\n", 72 | " a = a.to('cuda')\n", 73 | " b = b.to('cuda')\n", 74 | " a + b\n", 75 | "else:\n", 76 | " print(\"Cuda is not available.\")" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [] 85 | } 86 | ], 87 | "metadata": { 88 | "kernelspec": { 89 | "display_name": "Python 3", 90 | "language": "python", 91 | "name": "python3" 92 | }, 93 | "language_info": { 94 | "codemirror_mode": { 95 | "name": "ipython", 96 | "version": 3 97 | }, 98 | "file_extension": ".py", 99 | "mimetype": "text/x-python", 100 | "name": "python", 101 | "nbconvert_exporter": "python", 102 | "pygments_lexer": "ipython3", 103 | "version": "3.6.7" 104 | } 105 | }, 106 | "nbformat": 4, 107 | "nbformat_minor": 2 108 | } 109 | -------------------------------------------------------------------------------- /p2ch12/model.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | from torch import nn 4 | 5 | from util.logconf import logging 6 | 7 | log = logging.getLogger(__name__) 8 | # log.setLevel(logging.WARN) 9 | # log.setLevel(logging.INFO) 10 | log.setLevel(logging.DEBUG) 11 | 12 | class LunaModel(nn.Module): 13 | def __init__(self, in_channels=1, conv_channels=8): 14 | super().__init__() 15 | 16 | self.tail_batchnorm = nn.BatchNorm3d(1) 17 | 18 | self.block1 = LunaBlock(in_channels, conv_channels) 19 | self.block2 = LunaBlock(conv_channels, conv_channels * 2) 20 | self.block3 = LunaBlock(conv_channels * 2, conv_channels * 4) 21 | self.block4 = LunaBlock(conv_channels * 4, conv_channels * 8) 22 | 23 | self.head_linear = nn.Linear(1152, 2) 24 | self.head_softmax = nn.Softmax(dim=1) 25 | 26 | self._init_weights() 27 | 28 | def _init_weights(self): 29 | for m in self.modules(): 30 | if type(m) in {nn.Linear, nn.Conv3d, nn.Conv2d, 31 | nn.ConvTranspose2d, nn.ConvTranspose3d}: 32 | nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_out', 33 | nonlinearity='relu') 34 | if m.bias is not None: 35 | fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out( 36 | m.weight.data) 37 | bound = 1 / math.sqrt(fan_out) 38 | nn.init.normal_(m.bias, -bound, bound) 39 | 40 | def forward(self, input_batch): 41 | bn_output = self.tail_batchnorm(input_batch) 42 | 43 | block_out = self.block1(bn_output) 44 | block_out = self.block2(block_out) 45 | block_out = self.block3(block_out) 46 | block_out = self.block4(block_out) 47 | 48 | conv_flat = block_out.view(block_out.size(0), -1) 49 | linear_output = self.head_linear(conv_flat) 50 | 51 | return linear_output, self.head_softmax(linear_output) 52 | 53 | 54 | class LunaBlock(nn.Module): 55 | def __init__(self, in_channels, conv_channels): 56 | super().__init__() 57 | 58 | self.conv1 = nn.Conv3d(in_channels, conv_channels, kernel_size=3, 59 | padding=1, bias=True) 60 | self.relu1 = nn.ReLU(inplace=True) 61 | self.conv2 = nn.Conv3d(conv_channels, conv_channels, kernel_size=3, 62 | padding=1, bias=True) 63 | self.relu2 = nn.ReLU(inplace=True) 64 | 65 | self.maxpool = nn.MaxPool3d(2, 2) 66 | 67 | def forward(self, input_batch): 68 | block_out = self.conv1(input_batch) 69 | block_out = self.relu1(block_out) 70 | block_out = self.conv2(block_out) 71 | block_out = self.relu2(block_out) 72 | 73 | return self.maxpool(block_out 74 | -------------------------------------------------------------------------------- /p2ch09/vis.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | 5 | from p2ch07.dsets import Ct, LunaDataset 6 | 7 | clim = (-1000.0, 300) 8 | 9 | 10 | def findMalignantSamples(start_ndx=0, limit=100): 11 | ds = LunaDataset() 12 | 13 | malignantSample_list = [] 14 | for sample_tup in ds.noduleInfo_list: 15 | if sample_tup[0]: 16 | malignantSample_list.append(sample_tup) 17 | 18 | if len(malignantSample_list) >= limit: 19 | break 20 | 21 | return malignantSample_list 22 | 23 | 24 | def showNodule(series_uid, batch_ndx=None): 25 | ds = LunaDataset(series_uid=series_uid) 26 | malignant_list = [i for i, x in enumerate(ds.noduleInfo_list) if x[0]] 27 | 28 | if batch_ndx is None: 29 | if malignant_list: 30 | batch_ndx = malignant_list[0] 31 | else: 32 | print("Warining: no malignant samples found; using first " 33 | "non-malignant sample.") 34 | batch_ndx = 0 35 | 36 | ct = Ct(series_uid) 37 | ct_t, malignant_t, series_uid, center_irc = ds[batch_ndx] 38 | ct_a = ct_t[0].numpy() 39 | 40 | fig = plt.figure(figsize=(15, 25)) 41 | 42 | group_list = [ 43 | [9, 11, 13], 44 | [15, 16, 17], 45 | [19, 21, 23] 46 | ] 47 | 48 | subplot = fig.add_subplot(len(group_list) + 2, 3, 1) 49 | subplot.set_title("index {}".format(int(center_irc.index))) 50 | plt.imshow(ct.ary[int(center_irc.index)], clim=clim, cmap='gray') 51 | 52 | subplot = fig.add_subplot(len(group_list) + 2, 3, 2) 53 | subplot.set_title("row {}".format(int(center_irc.row))) 54 | plt.imshow(ct.ary[:, int(center_irc.row)], clim=clim, cmap='gray') 55 | 56 | subplot = fig.add_subplot(len(group_list) + 2, 3, 3) 57 | subplot.set_title("col {}".format(int(center_irc.col))) 58 | plt.imshow(ct.ary[:, :, int(center_irc.col)], clim=clim, cmap='gray') 59 | 60 | subplot = fig.add_subplot(len(group_list) + 2, 3, 4) 61 | subplot.set_title("index {}".format(int(center_irc.index))) 62 | plt.imshow(ct_a[ct_a.shape[0] // 2], clim=clim, cmap='gray') 63 | 64 | subplot = fig.add_subplot(len(group_list) + 2, 3, 5) 65 | subplot.set_title("row {}".format(int(center_irc.row))) 66 | plt.imshow(ct_a[:, ct_a.shape[1] // 2], clim=clim, cmap='gray') 67 | 68 | subplot = fig.add_subplot(len(group_list) + 2, 3, 6) 69 | subplot.set_title("col {}".format(int(center_irc.col))) 70 | plt.imshow(ct_a[:, :, ct_a.shape[2] // 2], clim=clim, cmap='gray') 71 | 72 | for row, index_list in enumerate(group_list): 73 | for col, index in enumerate(index_list): 74 | subplot = fig.add_subplot(len(group_list) + 2, 3, 75 | row * 3 + col + 7) 76 | subplot.set_title("slice {}".format(index)) 77 | plt.imshow(ct_a[index], clim=clim, cmap='gray') 78 | 79 | print(series_uid, batch_ndx, bool(malignant_t[0]), malignant_list) 80 | -------------------------------------------------------------------------------- /p2ch07/vis.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | 5 | from p2ch07.dsets import Ct, LunaDataset 6 | 7 | clim = (0.0, 1.3) 8 | 9 | 10 | def findMalignantSamples(start_ndx=0, limit=100): 11 | ds = LunaDataset() 12 | 13 | malignantSample_list = [] 14 | for sample_tup in ds.noduleInfo_list: 15 | if sample_tup[0]: 16 | malignantSample_list.append(sample_tup) 17 | 18 | if len(malignantSample_list) >= limit: 19 | break 20 | 21 | return malignantSample_list 22 | 23 | 24 | def showNodule(series_uid, batch_ndx=None): 25 | ds = LunaDataset(series_uid=series_uid) 26 | malignant_list = [i for i, x in enumerate(ds.noduleInfo_list) if x[0]] 27 | 28 | if batch_ndx is None: 29 | if malignant_list: 30 | batch_ndx = malignant_list[0] 31 | else: 32 | print("Warining: no malignant samples found; using first " 33 | "non-malignant sample.") 34 | batch_ndx = 0 35 | 36 | ct = Ct(series_uid) 37 | ct_tensor, malignant_tensor, series_uid, center_irc = ds[batch_ndx] 38 | ct_ary = ct_tensor[0].numpy() 39 | 40 | fig = plt.figure(figsize=(15, 25)) 41 | 42 | group_list = [ 43 | [9, 11, 13], 44 | [15, 16, 17], 45 | [19, 21, 23] 46 | ] 47 | 48 | subplot = fig.add_subplot(len(group_list) + 2, 3, 1) 49 | subplot.set_title("index {}".format(int(center_irc.index))) 50 | plt.imshow(ct.ary[int(center_irc.index)], clim=clim, cmap='gray') 51 | 52 | subplot = fig.add_subplot(len(group_list) + 2, 3, 2) 53 | subplot.set_title("row {}".format(int(center_irc.row))) 54 | plt.imshow(ct.ary[:, int(center_irc.row)], clim=clim, cmap='gray') 55 | 56 | subplot = fig.add_subplot(len(group_list) + 2, 3, 3) 57 | subplot.set_title("col {}".format(int(center_irc.col))) 58 | plt.imshow(ct.ary[:, :, int(center_irc.col)], clim=clim, cmap='gray') 59 | 60 | subplot = fig.add_subplot(len(group_list) + 2, 3, 4) 61 | subplot.set_title("index {}".format(int(center_irc.index))) 62 | plt.imshow(ct_ary[ct_ary.shape[0] // 2], clim=clim, cmap='gray') 63 | 64 | subplot = fig.add_subplot(len(group_list) + 2, 3, 5) 65 | subplot.set_title("row {}".format(int(center_irc.row))) 66 | plt.imshow(ct_ary[:, ct_ary.shape[1] // 2], clim=clim, cmap='gray') 67 | 68 | subplot = fig.add_subplot(len(group_list) + 2, 3, 6) 69 | subplot.set_title("col {}".format(int(center_irc.col))) 70 | plt.imshow(ct_ary[:, :, ct_ary.shape[2] // 2], clim=clim, cmap='gray') 71 | 72 | for row, index_list in enumerate(group_list): 73 | for col, index in enumerate(index_list): 74 | subplot = fig.add_subplot(len(group_list) + 2, 3, 75 | row * 3 + col + 7) 76 | subplot.set_title("slice {}".format(index)) 77 | plt.imshow(ct_ary[index], clim=clim, cmap='gray') 78 | 79 | print(series_uid, batch_ndx, bool(malignant_tensor[0]), malignant_list) 80 | -------------------------------------------------------------------------------- /p2ch08/vis.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | 5 | from p2ch07.dsets import Ct, LunaDataset 6 | 7 | clim = (0.0, 1.3) 8 | 9 | 10 | def findMalignantSamples(start_ndx=0, limit=100): 11 | ds = LunaDataset() 12 | 13 | malignantSample_list = [] 14 | for sample_tup in ds.noduleInfo_list: 15 | if sample_tup[0]: 16 | malignantSample_list.append(sample_tup) 17 | 18 | if len(malignantSample_list) >= limit: 19 | break 20 | 21 | return malignantSample_list 22 | 23 | 24 | def showNodule(series_uid, batch_ndx=None): 25 | ds = LunaDataset(series_uid=series_uid) 26 | malignant_list = [i for i, x in enumerate(ds.noduleInfo_list) if x[0]] 27 | 28 | if batch_ndx is None: 29 | if malignant_list: 30 | batch_ndx = malignant_list[0] 31 | else: 32 | print("Warining: no malignant samples found; using first " 33 | "non-malignant sample.") 34 | batch_ndx = 0 35 | 36 | ct = Ct(series_uid) 37 | ct_tensor, malignant_tensor, series_uid, center_irc = ds[batch_ndx] 38 | ct_ary = ct_tensor[0].numpy() 39 | 40 | fig = plt.figure(figsize=(15, 25)) 41 | 42 | group_list = [ 43 | [9, 11, 13], 44 | [15, 16, 17], 45 | [19, 21, 23] 46 | ] 47 | 48 | subplot = fig.add_subplot(len(group_list) + 2, 3, 1) 49 | subplot.set_title("index {}".format(int(center_irc.index))) 50 | plt.imshow(ct.ary[int(center_irc.index)], clim=clim, cmap='gray') 51 | 52 | subplot = fig.add_subplot(len(group_list) + 2, 3, 2) 53 | subplot.set_title("row {}".format(int(center_irc.row))) 54 | plt.imshow(ct.ary[:, int(center_irc.row)], clim=clim, cmap='gray') 55 | 56 | subplot = fig.add_subplot(len(group_list) + 2, 3, 3) 57 | subplot.set_title("col {}".format(int(center_irc.col))) 58 | plt.imshow(ct.ary[:, :, int(center_irc.col)], clim=clim, cmap='gray') 59 | 60 | subplot = fig.add_subplot(len(group_list) + 2, 3, 4) 61 | subplot.set_title("index {}".format(int(center_irc.index))) 62 | plt.imshow(ct_ary[ct_ary.shape[0] // 2], clim=clim, cmap='gray') 63 | 64 | subplot = fig.add_subplot(len(group_list) + 2, 3, 5) 65 | subplot.set_title("row {}".format(int(center_irc.row))) 66 | plt.imshow(ct_ary[:, ct_ary.shape[1] // 2], clim=clim, cmap='gray') 67 | 68 | subplot = fig.add_subplot(len(group_list) + 2, 3, 6) 69 | subplot.set_title("col {}".format(int(center_irc.col))) 70 | plt.imshow(ct_ary[:, :, ct_ary.shape[2] // 2], clim=clim, cmap='gray') 71 | 72 | for row, index_list in enumerate(group_list): 73 | for col, index in enumerate(index_list): 74 | subplot = fig.add_subplot(len(group_list) + 2, 3, 75 | row * 3 + col + 7) 76 | subplot.set_title("slice {}".format(index)) 77 | plt.imshow(ct_ary[index], clim=clim, cmap='gray') 78 | 79 | print(series_uid, batch_ndx, bool(malignant_tensor[0]), malignant_list) 80 | -------------------------------------------------------------------------------- /p1ch4/5_image_dog.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import torch\n", 11 | "torch.set_printoptions(edgeitems=2, threshold=50)" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 2, 17 | "metadata": {}, 18 | "outputs": [ 19 | { 20 | "data": { 21 | "text/plain": [ 22 | "(720, 1280, 3)" 23 | ] 24 | }, 25 | "execution_count": 2, 26 | "metadata": {}, 27 | "output_type": "execute_result" 28 | } 29 | ], 30 | "source": [ 31 | "import imageio\n", 32 | "\n", 33 | "img_arr = imageio.imread(\"../data/p1ch4/image-dog/bobby.jpg\")\n", 34 | "img_arr.shape" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 3, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "img = torch.from_numpy(img_arr)\n", 44 | "out = torch.transpose(img, 0, 2)" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 4, 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "batch_size = 100\n", 54 | "batch = torch.zeros(100, 3, 256, 256, dtype=torch.uint8)" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 5, 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "import os\n", 64 | "\n", 65 | "data_dir = \"../data/p1ch4/image-cats/\"\n", 66 | "filenames = [name for name in os.listdir(data_dir) if os.path.splitext(name) == '.png']\n", 67 | "\n", 68 | "for i, filename in enumerate(filenames):\n", 69 | " img_arr = imageio.imread(filename)\n", 70 | " batch[i] = torch.transpose(torch.from_numpy(img_arr), 0, 2)" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 6, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "batch = batch.float()\n", 80 | "batch /= 255.0" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": 7, 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "n_channels = batch.shape[1]\n", 90 | "\n", 91 | "for c in range(n_channels):\n", 92 | " mean = torch.mean(batch[:, c])\n", 93 | " std = torch.std(batch[:, c])\n", 94 | " batch[:, c] = (batch[:, c] - mean) / std" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [] 103 | } 104 | ], 105 | "metadata": { 106 | "kernelspec": { 107 | "display_name": "Python 3", 108 | "language": "python", 109 | "name": "python3" 110 | }, 111 | "language_info": { 112 | "codemirror_mode": { 113 | "name": "ipython", 114 | "version": 3 115 | }, 116 | "file_extension": ".py", 117 | "mimetype": "text/x-python", 118 | "name": "python", 119 | "nbconvert_exporter": "python", 120 | "pygments_lexer": "ipython3", 121 | "version": "3.6.9" 122 | } 123 | }, 124 | "nbformat": 4, 125 | "nbformat_minor": 2 126 | } 127 | -------------------------------------------------------------------------------- /p2ch13/model_cls.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | from torch import nn 4 | 5 | from util.logconf import logging 6 | 7 | log = logging.getLogger(__name__) 8 | # log.setLevel(logging.WARN) 9 | # log.setLevel(logging.INFO) 10 | log.setLevel(logging.DEBUG) 11 | 12 | 13 | class LunaModel(nn.Module): 14 | def __init__(self, in_channels=1, conv_channels=8): 15 | super().__init__() 16 | 17 | self.tail_batchnorm = nn.BatchNorm3d(1) 18 | 19 | self.block1 = LunaBlock(in_channels, conv_channels) 20 | self.block2 = LunaBlock(conv_channels, conv_channels * 2) 21 | self.block3 = LunaBlock(conv_channels * 2, conv_channels * 4) 22 | self.block4 = LunaBlock(conv_channels * 4, conv_channels * 8) 23 | 24 | self.head_linear = nn.Linear(1152, 2) 25 | self.head_softmax = nn.Softmax(dim=1) 26 | 27 | self._init_weights() 28 | 29 | # see also https://github.com/pytorch/pytorch/issues/18182 30 | def _init_weights(self): 31 | for m in self.modules(): 32 | if type(m) in {nn.Linear, nn.Conv3d, nn.Conv2d, 33 | nn.ConvTranspose2d, nn.ConvTranspose3d}: 34 | nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_out', 35 | nonlinearity='relu') 36 | if m.bias is not None: 37 | _, fan_out = nn.init._calculate_fan_in_and_fan_out( 38 | m.weight.data) 39 | bound = 1 / math.sqrt(fan_out) 40 | nn.init.normal_(m.bias, -bound, bound) 41 | 42 | def forward(self, input_batch): 43 | bn_output = self.tail_batchnorm(input_batch) 44 | 45 | block_out = self.block1(bn_output) 46 | block_out = self.block2(block_out) 47 | block_out = self.block3(block_out) 48 | block_out = self.block4(block_out) 49 | 50 | conv_flat = block_out.view(block_out.size(0), -1) 51 | linear_output = self.head_linear(conv_flat) 52 | 53 | return linear_output, self.head_softmax(linear_output) 54 | 55 | 56 | class LunaBlock(nn.Module): 57 | def __init__(self, in_channels, conv_channels): 58 | super().__init__() 59 | 60 | self.conv1 = nn.Conv3d(in_channels, conv_channels, kernel_size=3, 61 | padding=1, bias=True) 62 | self.relu1 = nn.ReLU(inplace=True) 63 | self.conv2 = nn.Conv3d(conv_channels, conv_channels, kernel_size=3, 64 | padding=1, bias=True) 65 | self.relu2 = nn.ReLU(inplace=True) 66 | 67 | self.maxpool = nn.MaxPool3d(2, 2) 68 | 69 | def forward(self, input_batch): 70 | block_out = self.conv1(input_batch) 71 | block_out = self.relu1(block_out) 72 | block_out = self.conv2(block_out) 73 | block_out = self.relu2(block_out) 74 | 75 | return self.maxpool(block_out) 76 | 77 | 78 | class AlternateLunaModel(LunaModel): 79 | def __init__(self, in_channels, conv_channels=64): 80 | super().__init__() 81 | 82 | self.block1 = LunaBlock(in_channels, conv_channels) 83 | self.block2 = LunaBlock(conv_channels, conv_channels // 2) 84 | self.block3 = LunaBlock(conv_channels // 2, conv_channels // 4) 85 | self.block4 = LunaBlock(conv_channels // 4, conv_channels // 8) 86 | 87 | self.head_linear = nn.Linear(144, 2) 88 | -------------------------------------------------------------------------------- /p2ch12/vis.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use('nbagg') 3 | 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | 7 | from p2ch12.dsets import Ct, LunaDataset 8 | 9 | clim = (-1000.0, 300) 10 | 11 | 12 | def findMalignantSamples(start_ndx=0, limit=10): 13 | ds = LunaDataset(sortby_str='malignancy_size') 14 | 15 | malignantSample_list = [] 16 | for sample_tup in ds.noduleInfo_list: 17 | if sample_tup.isMalignant_bool: 18 | print(len(malignantSample_list), sample_tup) 19 | malignantSample_list.append(sample_tup) 20 | 21 | if len(malignantSample_list) >= limit: 22 | break 23 | 24 | return malignantSample_list 25 | 26 | 27 | def showNodule(series_uid, batch_ndx=None, **kwargs): 28 | ds = LunaDataset(series_uid=series_uid, **kwargs) 29 | malignant_list = [i for i, x in enumerate(ds.noduleInfo_list) 30 | if x.isMalignant_bool]] 31 | 32 | if batch_ndx is None: 33 | if malignant_list: 34 | batch_ndx = malignant_list[0] 35 | else: 36 | print("Warining: no malignant samples found; using first " 37 | "non-malignant sample.") 38 | batch_ndx = 0 39 | 40 | ct = Ct(series_uid) 41 | ct_t, malignant_t, series_uid, center_irc = ds[batch_ndx] 42 | ct_a = ct_t[0].numpy() 43 | 44 | fig = plt.figure(figsize=(30, 50)) 45 | 46 | group_list = [ 47 | [9, 11, 13], 48 | [15, 16, 17], 49 | [19, 21, 23] 50 | ] 51 | 52 | subplot = fig.add_subplot(len(group_list) + 2, 3, 1) 53 | subplot.set_title("index {}".format(int(center_irc.index)), fontsize=30) 54 | for label in (subplot.get_xticklabels() + subplot.get_yticklabels()): 55 | label.set_fontsize(20) 56 | plt.imshow(ct.hu_a[int(center_irc.index)], clim=clim, cmap='gray') 57 | 58 | subplot = fig.add_subplot(len(group_list) + 2, 3, 2) 59 | subplot.set_title("row {}".format(int(center_irc.row)), fontsize=30) 60 | for label in (subplot.get_xticklabels() + subplot.get_yticklabels()): 61 | label.set_fontsize(20) 62 | plt.imshow(ct.hu_a[:, int(center_irc.row)], clim=clim, cmap='gray') 63 | 64 | subplot = fig.add_subplot(len(group_list) + 2, 3, 3) 65 | subplot.set_title("col {}".format(int(center_irc.col)), fontsize=30) 66 | for label in (subplot.get_xticklabels() + subplot.get_yticklabels()): 67 | label.set_fontsize(20) 68 | plt.imshow(ct.hu_a[:, :, int(center_irc.col)], clim=clim, cmap='gray') 69 | 70 | subplot = fig.add_subplot(len(group_list) + 2, 3, 4) 71 | subplot.set_title("index {}".format(int(center_irc.index)), fontsize=30 72 | for label in (subplot.get_xticklabels() + subplot.get_yticklabels()): 73 | label.set_fontsize(20) 74 | plt.imshow(ct_a[ct_a.shape[0] // 2], clim=clim, cmap='gray') 75 | 76 | subplot = fig.add_subplot(len(group_list) + 2, 3, 5) 77 | subplot.set_title("row {}".format(int(center_irc.row)), fontsize=30) 78 | for label in (subplot.get_xticklabels() + subplot.get_yticklabels()): 79 | label.set_fontsize(20) 80 | plt.imshow(ct_a[:, ct_a.shape[1] // 2], clim=clim, cmap='gray') 81 | 82 | subplot = fig.add_subplot(len(group_list) + 2, 3, 6) 83 | subplot.set_title("col {}".format(int(center_irc.col)), fontsize=30) 84 | for label in (subplot.get_xticklabels() + subplot.get_yticklabels()): 85 | label.set_fontsize(20) 86 | plt.imshow(ct_a[:, :, ct_a.shape[2] // 2], clim=clim, cmap='gray') 87 | 88 | for row, index_list in enumerate(group_list): 89 | for col, index in enumerate(index_list): 90 | subplot = fig.add_subplot(len(group_list) + 2, 3, 91 | row * 3 + col + 7) 92 | subplot.set_title("slice {}".format(index), fontsize=30) 93 | for label in (subplot.get_xticklabels() + 94 | subplot.get_yticklabels()): 95 | label.set_fontsize(20) 96 | plt.imshow(ct_a[index], clim=clim, cmap='gray') 97 | 98 | print(series_uid, batch_ndx, bool(malignant_t[0]), malignant_list) 99 | -------------------------------------------------------------------------------- /util/disk.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | 3 | from diskcache import FanoutCache, Disk 4 | from diskcache.core import BytesType, MODE_BINARY, BytesIO 5 | 6 | from util.logconf import logging 7 | log = logging.getLogger(__name__) 8 | # log.setLevel(logging.WARN) 9 | log.setLevel(logging.INFO) 10 | # log.setLevel(logging.DEBUG) 11 | 12 | 13 | # diskcache.Disk objects are responsible for serializing and 14 | # deserializing data stored in the cache 15 | class GzipDisk(Disk): 16 | def store(self, value, read, key=None): 17 | """ 18 | Override from base class diskcache.Disk. 19 | 20 | Chunking is due to needing to work on pythons < 2.7.13: 21 | - Issue #27130: In the "zlib" module, fix handling of large 22 | buffers (typically 2 or 4 GiB). Previously, inputs were 23 | limited to 2 GiB, compression and decompression operations 24 | did not properly handle results of 2 or 4 GiB. 25 | 26 | :param value: value to convert 27 | :param bool read: True when value is file-like object 28 | :return: (size, mode, filename, value) tuple for Cache table 29 | """ 30 | # pylint: disable=unidiomatic-typecheck 31 | if type(value) is BytesType: 32 | if read: 33 | value = value.read() 34 | read = False 35 | 36 | str_io = BytesIO() 37 | gz_file = gzip.GzipFile(mode='wb', compresslevel=1, fileobj=str_io) 38 | 39 | for offset in range(0, len(value), 2 ** 30): 40 | # Chunking 41 | gz_file.write(value[offset:offset + 2 ** 30]) 42 | gz_file.close() 43 | 44 | value = str_io.getvalue() 45 | 46 | return super().store(value, read) 47 | 48 | def fetch(self, mode, filename, value, read): 49 | """ 50 | Override from base class diskcache.Disk. 51 | 52 | Chunking is due to needing to work on pythons < 2.7.13: 53 | - Issue #27130: In the "zlib" module, fix handling of large 54 | buffers (typically 2 or 4 GiB). Previously, inputs were 55 | limited to 2 GiB, compression and decompression operations 56 | did not properly handle results of 2 or 4 GiB. 57 | 58 | :param int mode: value mode raw, binary, text or pickle 59 | :param str filename: filename or corresponding value 60 | :param value: database value 61 | :param bool read: when True, return an open file handle 62 | :return: corresponding Python value 63 | """ 64 | value = super().fetch(mode, filename, value, read) 65 | 66 | if mode == MODE_BINARY: 67 | str_io = BytesIO(value) 68 | gz_file = gzip.GzipFile(mode='rb', fileobj=str_io) 69 | read_csio = BytesIO() 70 | 71 | while True: 72 | # Note: 2 ** 30 = 1 GB 73 | uncompressed_data = gz_file.read(2 ** 30) 74 | if uncompressed_data: 75 | read_csio.write(uncompressed_data) 76 | else: 77 | break 78 | 79 | value = read_csio.getvalue() 80 | 81 | return value 82 | 83 | 84 | def getCache(scope_str): 85 | # Built atop Cache is diskcache.FanoutCache which automatically 86 | # shards the underlying database. Sharding is the practice of 87 | # horizontally partitioning data. Here it is used to decrease 88 | # blocking writes. While readers and writers do not block each 89 | # other, writers block other writers. Therefore a shard for every 90 | # concurrent writer is suggested. This will depend on your scenario. 91 | # The default value is 8. 92 | # timeout sets a limit on how long to wait for database 93 | # transactions. 94 | # size_limit is used as the total size of the cache. The size limit 95 | # of individual cache shards is the total size divided by the number 96 | # of shards. 97 | return FanoutCache("data-unversioned/cache/" + scope_str, 98 | disk=GzipDisk, 99 | shards=64, 100 | timeout=1, 101 | size_limit=2e11) 102 | -------------------------------------------------------------------------------- /p2ch13/vis.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib 3 | import matplotlib.pyplot as plt 4 | 5 | from p2ch13.dsets import Ct, LunaDataset 6 | 7 | # enables interactive figures in a live IPython notebook session 8 | matplotlib.use('nbagg') 9 | 10 | clim = (-1000.0, 300) 11 | 12 | 13 | def findMalignantSamples(limit=100): 14 | ds = LunaDataset(sortby_str='malignancy_size') 15 | 16 | malignantSample_list = [] 17 | for sample_tup in ds.noduleInfo_list: 18 | if sample_tup.isMalignant_bool: 19 | print(len(malignantSample_list), sample_tup) 20 | malignantSample_list.append(sample_tup) 21 | 22 | if len(malignantSample_list) >= limit: 23 | break 24 | 25 | return malignantSample_list 26 | 27 | 28 | def showNodule(series_uid, batch_ndx=None, **kwargs): 29 | ds = LunaDataset(series_uid=series_uid, **kwargs) 30 | malignant_list = [i for i, x in enumerate(ds.noduleInfo_list) 31 | if x.isMalignant_bool] 32 | 33 | if batch_ndx is None: 34 | if malignant_list: 35 | batch_ndx = malignant_list[0] 36 | else: 37 | print("Warining: no malignant samples found; using first " 38 | "non-malignant sample.") 39 | batch_ndx = 0 40 | 41 | ct = Ct(series_uid) 42 | ct_t, malignant_t, series_uid, center_irc = ds[batch_ndx] 43 | ct_a = ct_t[0].numpy() 44 | 45 | fig = plt.figure(figsize=(30, 50)) 46 | 47 | group_list = [ 48 | [9, 11, 13], 49 | [15, 16, 17], 50 | [19, 21, 23] 51 | ] 52 | 53 | subplot = fig.add_subplot(len(group_list) + 2, 3, 1) 54 | subplot.set_title("index {}".format(int(center_irc.index)), fontsize=30) 55 | for label in subplot.get_xticklabels() + subplot.get_yticklabels(): 56 | label.set_fontsize(20) 57 | 58 | plt.imshow(ct.hu_a[int(center_irc.index)], clim=clim, cmap='gray') 59 | 60 | subplot = fig.add_subplot(len(group_list) + 2, 3, 2) 61 | subplot.set_title("row {}".format(int(center_irc.row)), fontsize=30) 62 | for label in subplot.get_xticklabels() + subplot.get_yticklabels(): 63 | label.set_fontsize(20) 64 | 65 | plt.imshow(ct.hu_a[:, int(center_irc.row)], clim=clim, cmap='gray') 66 | plt.gca().invert_yaxis() 67 | 68 | subplot = fig.add_subplot(len(group_list) + 2, 3, 3) 69 | subplot.set_title("col {}".format(int(center_irc.col)), fontsize=30) 70 | for label in subplot.get_xticklabels() + subplot.get_yticklabels(): 71 | label.set_fontsize(20) 72 | 73 | plt.imshow(ct.hu_a[:, :, int(center_irc.col)], clim=clim, cmap='gray') 74 | plt.gca().invert_yaxis() 75 | 76 | subplot = fig.add_subplot(len(group_list) + 2, 3, 4) 77 | subplot.set_title("index {}".format(int(center_irc.index)), fontsize=30) 78 | for label in subplot.get_xticklabels() + subplot.get_yticklabels(): 79 | label.set_fontsize(20) 80 | 81 | plt.imshow(ct_a[ct_a.shape[0] // 2], clim=clim, cmap='gray') 82 | 83 | subplot = fig.add_subplot(len(group_list) + 2, 3, 5) 84 | subplot.set_title("row {}".format(int(center_irc.row)), fontsize=30) 85 | for label in subplot.get_xticklabels() + subplot.get_yticklabels(): 86 | label.set_fontsize(20) 87 | 88 | plt.imshow(ct_a[:, ct_a.shape[1] // 2], clim=clim, cmap='gray') 89 | plt.gca().invert_yaxis() 90 | 91 | subplot = fig.add_subplot(len(group_list) + 2, 3, 6) 92 | subplot.set_title("col {}".format(int(center_irc.col)), fontsize=30) 93 | for label in subplot.get_xticklabels() + subplot.get_yticklabels(): 94 | label.set_fontsize(20) 95 | 96 | plt.imshow(ct_a[:, :, ct_a.shape[2] // 2], clim=clim, cmap='gray') 97 | plt.gca().invert_yaxis() 98 | 99 | for row, index_list in enumerate(group_list): 100 | for col, index in enumerate(index_list): 101 | subplot = fig.add_subplot(len(group_list) + 2, 3, 102 | row * 3 + col + 7) 103 | subplot.set_title("slice {}".format(index), fontsize=30) 104 | for label in (subplot.get_xticklabels() + 105 | subplot.get_yticklabels()): 106 | label.set_fontsize(20) 107 | plt.imshow(ct_a[index], clim=clim, cmap='gray') 108 | 109 | print(series_uid, batch_ndx, bool(malignant_t[0]), malignant_list) 110 | -------------------------------------------------------------------------------- /p1ch5/2_autograd.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%matplotlib inline\n", 10 | "\n", 11 | "import numpy as np\n", 12 | "import torch\n", 13 | "torch.set_printoptions(edgeitems=2)" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 2, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "t_c = torch.tensor([0.5, 14.0, 15.0, 28.0, 11.0, 8.0, 3.0, -4.0, 6.0, 13.0, 21.0])\n", 23 | "t_u = torch.tensor([35.7, 55.9, 58.2, 81.9, 56.3, 48.9, 33.9, 21.8, 48.4, 60.4, 68.4])\n", 24 | "t_un = 0.1 * t_u" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 3, 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "def model(t_u, w, b):\n", 34 | " return w * t_u + b" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 4, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "def loss_fn(t_p, t_c):\n", 44 | " squared_diffs = (t_p - t_c) ** 2\n", 45 | " return squared_diffs.mean()" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 5, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "params = torch.tensor([1.0, 0.0], requires_grad=True)" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 6, 60 | "metadata": {}, 61 | "outputs": [ 62 | { 63 | "data": { 64 | "text/plain": [ 65 | "True" 66 | ] 67 | }, 68 | "execution_count": 6, 69 | "metadata": {}, 70 | "output_type": "execute_result" 71 | } 72 | ], 73 | "source": [ 74 | "params.grad is None" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 7, 80 | "metadata": {}, 81 | "outputs": [ 82 | { 83 | "data": { 84 | "text/plain": [ 85 | "tensor([4517.2969, 82.6000])" 86 | ] 87 | }, 88 | "execution_count": 7, 89 | "metadata": {}, 90 | "output_type": "execute_result" 91 | } 92 | ], 93 | "source": [ 94 | "loss = loss_fn(model(t_u, *params), t_c)\n", 95 | "loss.backward()\n", 96 | "\n", 97 | "params.grad" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": 8, 103 | "metadata": {}, 104 | "outputs": [], 105 | "source": [ 106 | "if params.grad is not None:\n", 107 | " params.grad.zero_()" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": 9, 113 | "metadata": {}, 114 | "outputs": [], 115 | "source": [ 116 | "def training_loop(n_epochs, learning_rate, params, t_u, t_c):\n", 117 | " for epoch in range(1, n_epochs + 1):\n", 118 | "\n", 119 | " if params.grad is not None: # <1>\n", 120 | " params.grad.zero_()\n", 121 | " \n", 122 | " t_p = model(t_u, *params)\n", 123 | " loss = loss_fn(t_p, t_c)\n", 124 | " loss.backward()\n", 125 | " \n", 126 | " params = (params - learning_rate * params.grad).detach().requires_grad_()\n", 127 | " \n", 128 | " if epoch % 500 == 0:\n", 129 | " print(\"Epoch %d, Loss %f\" % (epoch, float(loss)))\n", 130 | " \n", 131 | " return params" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": 10, 137 | "metadata": {}, 138 | "outputs": [ 139 | { 140 | "name": "stdout", 141 | "output_type": "stream", 142 | "text": [ 143 | "Epoch 500, Loss 7.860116\n", 144 | "Epoch 1000, Loss 3.828538\n", 145 | "Epoch 1500, Loss 3.092191\n", 146 | "Epoch 2000, Loss 2.957697\n", 147 | "Epoch 2500, Loss 2.933134\n", 148 | "Epoch 3000, Loss 2.928648\n", 149 | "Epoch 3500, Loss 2.927830\n", 150 | "Epoch 4000, Loss 2.927679\n", 151 | "Epoch 4500, Loss 2.927652\n", 152 | "Epoch 5000, Loss 2.927647\n" 153 | ] 154 | }, 155 | { 156 | "data": { 157 | "text/plain": [ 158 | "tensor([ 5.3671, -17.3012], requires_grad=True)" 159 | ] 160 | }, 161 | "execution_count": 10, 162 | "metadata": {}, 163 | "output_type": "execute_result" 164 | } 165 | ], 166 | "source": [ 167 | "training_loop(\n", 168 | " n_epochs=5000,\n", 169 | " learning_rate=1e-2,\n", 170 | " params=torch.tensor([1.0, 0.0], requires_grad=True), # <1>\n", 171 | " t_u=t_un, # <2>\n", 172 | " t_c=t_c)" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": null, 178 | "metadata": {}, 179 | "outputs": [], 180 | "source": [] 181 | } 182 | ], 183 | "metadata": { 184 | "kernelspec": { 185 | "display_name": "Python 3", 186 | "language": "python", 187 | "name": "python3" 188 | }, 189 | "language_info": { 190 | "codemirror_mode": { 191 | "name": "ipython", 192 | "version": 3 193 | }, 194 | "file_extension": ".py", 195 | "mimetype": "text/x-python", 196 | "name": "python", 197 | "nbconvert_exporter": "python", 198 | "pygments_lexer": "ipython3", 199 | "version": "3.6.9" 200 | } 201 | }, 202 | "nbformat": 4, 203 | "nbformat_minor": 2 204 | } 205 | -------------------------------------------------------------------------------- /p1ch2/4_mnist.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from torch import nn \n", 10 | "from torch import optim\n", 11 | "from torchvision import datasets, transforms\n", 12 | "import torch\n", 13 | "import torch.nn.functional as F" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 2, 19 | "metadata": {}, 20 | "outputs": [ 21 | { 22 | "data": { 23 | "text/plain": [ 24 | "" 25 | ] 26 | }, 27 | "execution_count": 2, 28 | "metadata": {}, 29 | "output_type": "execute_result" 30 | } 31 | ], 32 | "source": [ 33 | "torch.manual_seed(4242)" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 3, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "train_loader = torch.utils.data.DataLoader(\n", 43 | " datasets.MNIST(\n", 44 | " \"../data/p1ch2/mnist\", train=True, download=True,\n", 45 | " transform=transforms.Compose([\n", 46 | " transforms.ToTensor(),\n", 47 | " transforms.Normalize((0.1307,), (0.3081,))\n", 48 | " ])\n", 49 | " ),\n", 50 | " batch_size=64,\n", 51 | " shuffle=True\n", 52 | ")" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 4, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "class Net(nn.Module):\n", 62 | " def __init__(self):\n", 63 | " super().__init__()\n", 64 | " self.conv1 = nn.Conv2d(1, 10, kernel_size=5)\n", 65 | " self.conv2 = nn.Conv2d(10, 20, kernel_size=5)\n", 66 | " self.conv2_drop = nn.Dropout2d()\n", 67 | " self.fc1 = nn.Linear(320, 50)\n", 68 | " self.fc2 = nn.Linear(50, 10)\n", 69 | " \n", 70 | " def forward(self, x):\n", 71 | " x = F.relu(F.max_pool2d(self.conv1(x), 2))\n", 72 | " x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))\n", 73 | " x = x.view(-1, 320)\n", 74 | " x = F.relu(self.fc1(x))\n", 75 | " x = F.dropout(x, training=self.training)\n", 76 | " x = self.fc2(x)\n", 77 | " return F.log_softmax(x, dim=1)" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 5, 83 | "metadata": {}, 84 | "outputs": [], 85 | "source": [ 86 | "model = Net()" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 6, 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 7, 101 | "metadata": {}, 102 | "outputs": [ 103 | { 104 | "name": "stdout", 105 | "output_type": "stream", 106 | "text": [ 107 | "Current loss 0.43217700719833374\n", 108 | "Current loss 0.24119044840335846\n", 109 | "Current loss 0.37073448300361633\n", 110 | "Current loss 0.37481623888015747\n", 111 | "Current loss 0.28295284509658813\n", 112 | "Current loss 0.16025568544864655\n", 113 | "Current loss 0.09627309441566467\n", 114 | "Current loss 0.1352904587984085\n", 115 | "Current loss 0.3498668670654297\n", 116 | "Current loss 0.23658804595470428\n" 117 | ] 118 | } 119 | ], 120 | "source": [ 121 | "for epoch in range(10):\n", 122 | " for batch_idx, (data, target) in enumerate(train_loader):\n", 123 | " optimizer.zero_grad()\n", 124 | " output = model(data)\n", 125 | " loss = F.nll_loss(output, target)\n", 126 | " loss.backward()\n", 127 | " optimizer.step()\n", 128 | " print(\"Current loss\", float(loss))" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": 8, 134 | "metadata": {}, 135 | "outputs": [], 136 | "source": [ 137 | "torch.save(model.state_dict(), \"../data/p1ch2/mnist/mnist.pth\")" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 9, 143 | "metadata": {}, 144 | "outputs": [ 145 | { 146 | "data": { 147 | "text/plain": [ 148 | "" 149 | ] 150 | }, 151 | "execution_count": 9, 152 | "metadata": {}, 153 | "output_type": "execute_result" 154 | } 155 | ], 156 | "source": [ 157 | "pretrained_model = Net()\n", 158 | "pretrained_model.load_state_dict(torch.load(\"../data/p1ch2/mnist/mnist.pth\"))" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": null, 164 | "metadata": {}, 165 | "outputs": [], 166 | "source": [] 167 | } 168 | ], 169 | "metadata": { 170 | "kernelspec": { 171 | "display_name": "Python 3", 172 | "language": "python", 173 | "name": "python3" 174 | }, 175 | "language_info": { 176 | "codemirror_mode": { 177 | "name": "ipython", 178 | "version": 3 179 | }, 180 | "file_extension": ".py", 181 | "mimetype": "text/x-python", 182 | "name": "python", 183 | "nbconvert_exporter": "python", 184 | "pygments_lexer": "ipython3", 185 | "version": "3.6.7" 186 | } 187 | }, 188 | "nbformat": 4, 189 | "nbformat_minor": 2 190 | } 191 | -------------------------------------------------------------------------------- /p1ch4/7_video_cockatoo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "Video\n", 8 | "====" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "metadata": {}, 15 | "outputs": [], 16 | "source": [ 17 | "import numpy as np\n", 18 | "import torch\n", 19 | "torch.set_printoptions(edgeitems=2, threshold=50)" 20 | ] 21 | }, 22 | { 23 | "cell_type": "markdown", 24 | "metadata": {}, 25 | "source": [ 26 | "When it comes to the shape of tensors, video data can be seen as equivalent to volumetric data, with `depth` replaced by the `time` dimension. The result is again a 5D tensor with shape `N x C x T x H x W`.\n", 27 | "\n", 28 | "There are several formats for video, especially geared towards compression by exploiting redundancies in space and time. Luckily for us, `imageio` reads video data as well. Suppose we'd like to retain 100 consecutive frames in our 512 x 512 RBG video for classifying an action using a convolutional neural network. We first create a reader instance for the video, that will allow us to get information about the video and iterate over the frames in time.\n", 29 | "Let's see what the meta data for the video looks like:" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 2, 35 | "metadata": {}, 36 | "outputs": [ 37 | { 38 | "data": { 39 | "text/plain": [ 40 | "{'plugin': 'ffmpeg',\n", 41 | " 'nframes': 280,\n", 42 | " 'ffmpeg_version': '4.1-tessus https://evermeet.cx/ffmpeg/ built with Apple LLVM version 10.0.0 (clang-1000.11.45.5)',\n", 43 | " 'codec': 'h264',\n", 44 | " 'pix_fmt': 'yuv444p',\n", 45 | " 'fps': 20.0,\n", 46 | " 'source_size': (1280, 720),\n", 47 | " 'size': (1280, 720),\n", 48 | " 'duration': 14.0}" 49 | ] 50 | }, 51 | "execution_count": 2, 52 | "metadata": {}, 53 | "output_type": "execute_result" 54 | } 55 | ], 56 | "source": [ 57 | "import imageio\n", 58 | "\n", 59 | "reader = imageio.get_reader(\"../data/p1ch4/video-cockatoo/cockatoo.mp4\")\n", 60 | "meta = reader.get_meta_data()\n", 61 | "meta['nframes'] = reader.count_frames()\n", 62 | "meta" 63 | ] 64 | }, 65 | { 66 | "cell_type": "markdown", 67 | "metadata": {}, 68 | "source": [ 69 | "We now have all the information to size the tensor that will store the video frames:" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 3, 75 | "metadata": {}, 76 | "outputs": [ 77 | { 78 | "data": { 79 | "text/plain": [ 80 | "torch.Size([3, 280, 1280, 720])" 81 | ] 82 | }, 83 | "execution_count": 3, 84 | "metadata": {}, 85 | "output_type": "execute_result" 86 | } 87 | ], 88 | "source": [ 89 | "n_channels = 3\n", 90 | "n_frames = meta['nframes']\n", 91 | "video = torch.FloatTensor(n_channels, n_frames, *meta['size'])\n", 92 | "\n", 93 | "video.shape" 94 | ] 95 | }, 96 | { 97 | "cell_type": "markdown", 98 | "metadata": {}, 99 | "source": [ 100 | "Now we just iterate over the reader and set the values for all three channels into in the proper `i`-th time slice.\n", 101 | "This might take a few seconds to finish!" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": 4, 107 | "metadata": {}, 108 | "outputs": [], 109 | "source": [ 110 | "for i, frame_arr in enumerate(reader):\n", 111 | " frame = torch.from_numpy(frame_arr).float()\n", 112 | " video[:, i] = torch.transpose(frame, 0, 2)" 113 | ] 114 | }, 115 | { 116 | "cell_type": "markdown", 117 | "metadata": {}, 118 | "source": [ 119 | "In the above, we iterate over individual frames and set each frame in the `C x T x H x W` video tensor, after transposing the channel. We can then obtain a batch by stacking multiple 4D tensors or pre-allocating a 5D tensor with a known batch size and filling it iteratively, clip by clip, assuming clips are trimmed to a fixed number of frames.\n", 120 | "\n", 121 | "Equating video data to volumetric data is not the only way to represent video for training purposes. This is a valid strategy if we deal with video bursts of fixed length. An alternative strategy is to resort to network architectures capable of processing long sequences and exploiting short and long-term relationships in time, just like for text or audio. We'll see this kind of architectures when we take on recurrent networks.\n", 122 | "\n", 123 | "This next approach accounts for time along the batch dimension. Hence, we'll build our dataset as a 4D tensor, stacking frame by frame in the batch:" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": 5, 129 | "metadata": {}, 130 | "outputs": [ 131 | { 132 | "data": { 133 | "text/plain": [ 134 | "torch.Size([280, 3, 1280, 720])" 135 | ] 136 | }, 137 | "execution_count": 5, 138 | "metadata": {}, 139 | "output_type": "execute_result" 140 | } 141 | ], 142 | "source": [ 143 | "time_video = torch.FloatTensor(n_frames, n_channels, *meta['size'])\n", 144 | "\n", 145 | "for i, frame in enumerate(reader):\n", 146 | " frame = torch.from_numpy(frame).float()\n", 147 | " time_video[i] = torch.transpose(frame, 0, 2)\n", 148 | " \n", 149 | "time_video.shape" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": null, 155 | "metadata": {}, 156 | "outputs": [], 157 | "source": [] 158 | } 159 | ], 160 | "metadata": { 161 | "kernelspec": { 162 | "display_name": "Python 3", 163 | "language": "python", 164 | "name": "python3" 165 | }, 166 | "language_info": { 167 | "codemirror_mode": { 168 | "name": "ipython", 169 | "version": 3 170 | }, 171 | "file_extension": ".py", 172 | "mimetype": "text/x-python", 173 | "name": "python", 174 | "nbconvert_exporter": "python", 175 | "pygments_lexer": "ipython3", 176 | "version": "3.6.9" 177 | } 178 | }, 179 | "nbformat": 4, 180 | "nbformat_minor": 2 181 | } 182 | -------------------------------------------------------------------------------- /p2ch08/dsets.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import csv 3 | import functools 4 | import glob 5 | import os 6 | import random 7 | 8 | import SimpleITK as sitk 9 | 10 | import numpy as np 11 | import torch 12 | import torch.cuda 13 | from torch.utils.data import Dataset 14 | 15 | from util.disk import getCache 16 | from util.util import XyzTuple, xyz2irc 17 | from util.logconf import logging 18 | 19 | log = logging.getLogger(__name__) 20 | log.setLevel(logging.INFO) 21 | 22 | raw_cache = getCache('part2ch8_raw') 23 | 24 | 25 | @functools.lru_cache(1) 26 | def getNoduleInfoList(requireDataOnDisk_bool=True): 27 | # We construct a set with all series_uids that are present on disk. 28 | # This will let us use the data, even if we haven't downloaded all 29 | # of the subsets yet. 30 | mhd_list = glob.glob("data-unversioned/part2/luna/subset*/*.mhd") 31 | dataPresentOnDisk_set = {os.path.split(p)[-1][:-4] for p in mhd_list} 32 | 33 | diameter_dict = {} 34 | with open("data/part2/luna/annotations.csv", 'r') as f: 35 | for row in list(csv.reader(f))[1:]: 36 | series_uid = row[0] 37 | annotationCenter_xyz = tuple([float(x) for x in row[1:4]]) 38 | annotationDiameter_mm = float(row[4]) 39 | 40 | diameter_dict.setdefault(series_uid, []).append( 41 | (annotationCenter_xyz, annotationDiameter_mm)) 42 | 43 | noduleInfo_list = [] 44 | with open("data/part2/luna/candidates.csv", 'r') as f: 45 | for row in list(csv.reader(f))[1:]: 46 | series_uid = row[0] 47 | 48 | if series_uid not in dataPresentOnDisk_set and \ 49 | requireDataOnDisk_bool: 50 | continue 51 | 52 | isMalignant_bool = bool(int(row[4])) 53 | candidateCenter_xyz = tuple([float(x) for x in row[1:4]]) 54 | 55 | candidateDiameter_mm = 0.0 56 | for annotationDiameter_mm, annotationDiameter_mm in \ 57 | diameter_dict.get(series_uid, []): 58 | for i in range(3): 59 | delta_mm = abs(candidateCenter_xyz[i] - 60 | annotationCenter_xyz[i]) 61 | if delta_mm > annotationDiameter_mm / 4: 62 | break 63 | else: 64 | candidateDiameter_mm = annotationDiameter_mm 65 | break 66 | 67 | noduleInfo_list.append((isMalignant_bool, 68 | candidateDiameter_mm, 69 | series_uid, 70 | candidateCenter_xyz)) 71 | noduleInfo_list.sort(reverse=True) 72 | 73 | return noduleInfo_list 74 | 75 | 76 | class Ct: 77 | def __init__(self, series_uid): 78 | mhd_path = glob.glob("data-unversioned/part2/luna/subset*/{}.mhd"\ 79 | .format(series_uid))[0] 80 | 81 | ct_mhd = sitk.ReadImage(mhd_path) 82 | ct_ary = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32) 83 | 84 | # CTs are natively expressed in 85 | # https://en.wikipedia.org/wiki/Hounsfield_scale HU are scaled 86 | # oddly, with 0 g/cc (air, approximately) being -1000 and 1 g/cc 87 | # (water) being 0. 88 | # This converts HU to g/cc 89 | ct_ary += 1000 90 | ct_ary /= 1000 91 | 92 | # This gets rid of negative density stuff used to indicate 93 | # out-of-FOV 94 | ct_ary[ct_ary < 0] = 0 95 | 96 | # This nukes any weird hotspots and clamps bone down 97 | ct_ary[ct_ary > 2] = 2 98 | 99 | self.series_uid = series_uid 100 | self.ary = ct_ary 101 | 102 | self.origin_xyz = XyzTuple(*ct_mhd.GetOrigin()) 103 | self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing()) 104 | self.direction_tup = tuple(int(round(x)) for x in 105 | ct_mhd.GetDirection()) 106 | 107 | def getRawNodule(self, center_xyz, width_irc): 108 | center_irc = xyz2irc(center_xyz, self.origin_xyz, self.vxSize_xyz, 109 | self.direction_tup) 110 | 111 | slice_list = [] 112 | for axis, center_val in enumerate(center_irc): 113 | start_ndx = int(round(center_val - width_irc[axis] / 2)) 114 | end_ndx = int(start_ndx + width_irc[axis]) 115 | 116 | assert center_val >= 0 and center_val < self.ary.shape[axis], \ 117 | repr([self.series_uid, center_xyz, self.origin_xyz, 118 | self.vxSize_xyz, center_xyz, axis]) 119 | 120 | if start_ndx < 0: 121 | start_ndx = 0 122 | end_ndx = int(width_irc[axis]) 123 | 124 | if end_ndx > self.ary.shape[axis]: 125 | end_ndx = self.ary.shape[axis] 126 | start_ndx = int(self.ary.shape[axis] - width_irc[axis]) 127 | 128 | slice_list.append(slice(start_ndx, end_ndx)) 129 | 130 | ct_chunk = self.ary[tuple(slice_list)] 131 | 132 | return ct_chunk, center_irc 133 | 134 | 135 | @functools.lru_cache(1, typed=True) 136 | def getCt(series_uid): return Ct(series_uid) 137 | 138 | @raw_cache.memoize(typed=True) 139 | def getCtRawNodule(series_uid, center_xyz, width_irc): 140 | ct = getCt(series_uid) 141 | ct_chunk, center_irc = ct.getRawNodule(center_xyz, width_irc) 142 | 143 | return ct_chunk, center_irc 144 | 145 | 146 | class LunaDataset(Dataset): 147 | def __init__(self, test_stride=0, isTestSet_bool=None, series_uid=None, 148 | sortby_str='random'): 149 | self.noduleInfo_list = copy.copy(getNoduleInfoList()) 150 | 151 | if series_uid: 152 | self.noduleInfo_list = [x for x in self.noduleInfo_list 153 | if x[2] == series_uid] 154 | 155 | # __init__ continued... 156 | if test_stride > 1: 157 | if isTestSet_bool: 158 | self.noduleInfo_list = self.noduleInfo_list[::test_stride] 159 | else: 160 | del self.noduleInfo_list[::test_stride] 161 | 162 | log.info("{!r}: {} {} samples".format( 163 | self, 164 | len(self.noduleInfo_list), 165 | "testing" if isTestSet_bool else "training" 166 | )) 167 | 168 | def __len__(self): return len(self.noduleInfo_list) 169 | 170 | def __getitem__(self, ndx): 171 | sample_ndx = ndx 172 | 173 | isMalignant_bool, _diameter_mm, series_uid, center_xyz = \ 174 | self.noduleInfo_list[sample_ndx] 175 | 176 | nodule_ary, center_irc = getCtRawNodule(series_uid, center_xyz, 177 | (32, 32, 32)) 178 | 179 | nodule_tensor = torch.from_numpy(nodule_ary) 180 | nodule_tensor = nodule_tensor.unsqueeze(0) 181 | 182 | malignant_tensor = torch.tensor([isMalignant_bool], 183 | dtype=torch.float32) 184 | 185 | return nodule_tensor, malignant_tensor, series_uid, center_irc 186 | -------------------------------------------------------------------------------- /p1ch6/3_nn_module_subclassing.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%matplotlib inline\n", 10 | "\n", 11 | "import numpy as np\n", 12 | "import torch\n", 13 | "from torch import optim\n", 14 | "from torch import nn\n", 15 | "\n", 16 | "torch.set_printoptions(edgeitems=2)" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 2, 22 | "metadata": {}, 23 | "outputs": [ 24 | { 25 | "data": { 26 | "text/plain": [ 27 | "Sequential(\n", 28 | " (0): Linear(in_features=1, out_features=11, bias=True)\n", 29 | " (1): Tanh()\n", 30 | " (2): Linear(in_features=11, out_features=1, bias=True)\n", 31 | ")" 32 | ] 33 | }, 34 | "execution_count": 2, 35 | "metadata": {}, 36 | "output_type": "execute_result" 37 | } 38 | ], 39 | "source": [ 40 | "seq_model = nn.Sequential(\n", 41 | " nn.Linear(1, 11), # <1>\n", 42 | " nn.Tanh(),\n", 43 | " nn.Linear(11, 1)) # <2>\n", 44 | "\n", 45 | "seq_model" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 3, 51 | "metadata": {}, 52 | "outputs": [ 53 | { 54 | "data": { 55 | "text/plain": [ 56 | "Sequential(\n", 57 | " (hidden_linear): Linear(in_features=1, out_features=12, bias=True)\n", 58 | " (hidden_activation): Tanh()\n", 59 | " (output_linear): Linear(in_features=12, out_features=1, bias=True)\n", 60 | ")" 61 | ] 62 | }, 63 | "execution_count": 3, 64 | "metadata": {}, 65 | "output_type": "execute_result" 66 | } 67 | ], 68 | "source": [ 69 | "from collections import OrderedDict\n", 70 | "\n", 71 | "namedseq_model = nn.Sequential(OrderedDict([\n", 72 | " ('hidden_linear', nn.Linear(1, 12)),\n", 73 | " ('hidden_activation', nn.Tanh()),\n", 74 | " ('output_linear', nn.Linear(12, 1))\n", 75 | "]))\n", 76 | "\n", 77 | "namedseq_model" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 4, 83 | "metadata": {}, 84 | "outputs": [ 85 | { 86 | "data": { 87 | "text/plain": [ 88 | "SubclassModel(\n", 89 | " (hidden_linear): Linear(in_features=1, out_features=13, bias=True)\n", 90 | " (hidden_activation): Tanh()\n", 91 | " (output_linear): Linear(in_features=13, out_features=1, bias=True)\n", 92 | ")" 93 | ] 94 | }, 95 | "execution_count": 4, 96 | "metadata": {}, 97 | "output_type": "execute_result" 98 | } 99 | ], 100 | "source": [ 101 | "class SubclassModel(nn.Module):\n", 102 | " def __init__(self):\n", 103 | " super().__init__()\n", 104 | " \n", 105 | " self.hidden_linear = nn.Linear(1, 13)\n", 106 | " self.hidden_activation = nn.Tanh()\n", 107 | " self.output_linear = nn.Linear(13, 1)\n", 108 | " \n", 109 | " def forward(self, input):\n", 110 | " hidden_t = self.hidden_linear(input)\n", 111 | " activated_t = self.hidden_activation(hidden_t)\n", 112 | " output_t = self.output_linear(activated_t)\n", 113 | " \n", 114 | " return output_t\n", 115 | " \n", 116 | "subclass_model = SubclassModel()\n", 117 | "subclass_model" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": 5, 123 | "metadata": {}, 124 | "outputs": [ 125 | { 126 | "name": "stdout", 127 | "output_type": "stream", 128 | "text": [ 129 | "seq\n", 130 | "0.weight torch.Size([11, 1]) 11\n", 131 | "0.bias torch.Size([11]) 11\n", 132 | "2.weight torch.Size([1, 11]) 11\n", 133 | "2.bias torch.Size([1]) 1\n", 134 | "\n", 135 | "namedseq\n", 136 | "hidden_linear.weight torch.Size([12, 1]) 12\n", 137 | "hidden_linear.bias torch.Size([12]) 12\n", 138 | "output_linear.weight torch.Size([1, 12]) 12\n", 139 | "output_linear.bias torch.Size([1]) 1\n", 140 | "\n", 141 | "subclass\n", 142 | "hidden_linear.weight torch.Size([13, 1]) 13\n", 143 | "hidden_linear.bias torch.Size([13]) 13\n", 144 | "output_linear.weight torch.Size([1, 13]) 13\n", 145 | "output_linear.bias torch.Size([1]) 1\n", 146 | "\n" 147 | ] 148 | } 149 | ], 150 | "source": [ 151 | "for type_str, model in [('seq', seq_model), ('namedseq', namedseq_model), ('subclass', subclass_model)]:\n", 152 | " print(type_str)\n", 153 | " \n", 154 | " for name_str, param in model.named_parameters():\n", 155 | " print(\"{:21} {:19} {}\".format(name_str, str(param.shape), param.numel()))\n", 156 | " \n", 157 | " print()" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": 6, 163 | "metadata": {}, 164 | "outputs": [ 165 | { 166 | "data": { 167 | "text/plain": [ 168 | "SubclassFunctionalModel(\n", 169 | " (hidden_linear): Linear(in_features=1, out_features=14, bias=True)\n", 170 | " (output_linear): Linear(in_features=14, out_features=1, bias=True)\n", 171 | ")" 172 | ] 173 | }, 174 | "execution_count": 6, 175 | "metadata": {}, 176 | "output_type": "execute_result" 177 | } 178 | ], 179 | "source": [ 180 | "class SubclassFunctionalModel(nn.Module):\n", 181 | " def __init__(self):\n", 182 | " super().__init__()\n", 183 | " \n", 184 | " self.hidden_linear = nn.Linear(1, 14)\n", 185 | " # <1> \n", 186 | " self.output_linear = nn.Linear(14, 1)\n", 187 | " \n", 188 | " def forward(self, input):\n", 189 | " hidden_t = self.hidden_linear(input)\n", 190 | " activated_t = torch.tanh(hidden_t) # <2>\n", 191 | " output_t = self.output_linear(activated_t)\n", 192 | " \n", 193 | " return output_t\n", 194 | " \n", 195 | "func_model = SubclassFunctionalModel()\n", 196 | "func_model " 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": null, 202 | "metadata": {}, 203 | "outputs": [], 204 | "source": [] 205 | } 206 | ], 207 | "metadata": { 208 | "kernelspec": { 209 | "display_name": "Python 3", 210 | "language": "python", 211 | "name": "python3" 212 | }, 213 | "language_info": { 214 | "codemirror_mode": { 215 | "name": "ipython", 216 | "version": 3 217 | }, 218 | "file_extension": ".py", 219 | "mimetype": "text/x-python", 220 | "name": "python", 221 | "nbconvert_exporter": "python", 222 | "pygments_lexer": "ipython3", 223 | "version": "3.6.9" 224 | } 225 | }, 226 | "nbformat": 4, 227 | "nbformat_minor": 2 228 | } 229 | -------------------------------------------------------------------------------- /p2ch07/dsets.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import csv 3 | import functools 4 | import glob 5 | import os 6 | import random 7 | 8 | import SimpleITK as sitk 9 | 10 | import numpy as np 11 | import torch 12 | import torch.cuda 13 | from torch.utils.data import Dataset 14 | 15 | from util.disk import getCache 16 | from util.util import XyzTuple, xyz2irc 17 | from util.logconf import logging 18 | 19 | log = logging.getLogger(__name__) 20 | log.setLevel(logging.INFO) 21 | 22 | raw_cache = getCache('part2ch8_raw') 23 | 24 | 25 | @functools.lru_cache(1) 26 | def getNoduleInfoList(requireDataOnDisk_bool=True): 27 | # We construct a set with all series_uids that are present on disk. 28 | # This will let us use the data, even if we haven't downloaded all 29 | # of the subsets yet. 30 | mhd_list = glob.glob("data-unversioned/part2/luna/subset*/*.mhd") 31 | dataPresentOnDisk_set = {os.path.split(p)[-1][:-4] for p in mhd_list} 32 | 33 | diameter_dict = {} 34 | with open("data/part2/luna/annotations.csv", 'r') as f: 35 | for row in list(csv.reader(f))[1:]: 36 | series_uid = row[0] 37 | annotationCenter_xyz = tuple([float(x) for x in row[1:4]]) 38 | annotationDiameter_mm = float(row[4]) 39 | 40 | diameter_dict.setdefault(series_uid, []).append( 41 | (annotationCenter_xyz, annotationDiameter_mm)) 42 | 43 | noduleInfo_list = [] 44 | with open("data/part2/luna/candidates.csv", 'r') as f: 45 | for row in list(csv.reader(f))[1:]: 46 | series_uid = row[0] 47 | 48 | if series_uid not in dataPresentOnDisk_set and \ 49 | requireDataOnDisk_bool: 50 | continue 51 | 52 | isMalignant_bool = bool(int(row[4])) 53 | candidateCenter_xyz = tuple([float(x) for x in row[1:4]]) 54 | 55 | candidateDiameter_mm = 0.0 56 | for annotationDiameter_mm, annotationDiameter_mm in \ 57 | diameter_dict.get(series_uid, []): 58 | for i in range(3): 59 | delta_mm = abs(candidateCenter_xyz[i] - 60 | annotationCenter_xyz[i]) 61 | if delta_mm > annotationDiameter_mm / 4: 62 | break 63 | else: 64 | candidateDiameter_mm = annotationDiameter_mm 65 | break 66 | 67 | noduleInfo_list.append((isMalignant_bool, 68 | candidateDiameter_mm, 69 | series_uid, 70 | candidateCenter_xyz)) 71 | noduleInfo_list.sort(reverse=True) 72 | 73 | return noduleInfo_list 74 | 75 | 76 | class Ct: 77 | def __init__(self, series_uid): 78 | mhd_path = glob.glob("data-unversioned/part2/luna/subset*/{}.mhd"\ 79 | .format(series_uid))[0] 80 | 81 | ct_mhd = sitk.ReadImage(mhd_path) 82 | ct_ary = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32) 83 | 84 | # CTs are natively expressed in 85 | # https://en.wikipedia.org/wiki/Hounsfield_scale HU are scaled 86 | # oddly, with 0 g/cc (air, approximately) being -1000 and 1 g/cc 87 | # (water) being 0. 88 | # This converts HU to g/cc 89 | ct_ary += 1000 90 | ct_ary /= 1000 91 | 92 | # This gets rid of negative density stuff used to indicate 93 | # out-of-FOV 94 | ct_ary[ct_ary < 0] = 0 95 | 96 | # This nukes any weird hotspots and clamps bone down 97 | ct_ary[ct_ary > 2] = 2 98 | 99 | self.series_uid = series_uid 100 | self.ary = ct_ary 101 | 102 | self.origin_xyz = XyzTuple(*ct_mhd.GetOrigin()) 103 | self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing()) 104 | self.direction_tup = tuple(int(round(x)) for x in 105 | ct_mhd.GetDirection()) 106 | 107 | def getRawNodule(self, center_xyz, width_irc): 108 | center_irc = xyz2irc(center_xyz, self.origin_xyz, self.vxSize_xyz, 109 | self.direction_tup) 110 | 111 | slice_list = [] 112 | for axis, center_val in enumerate(center_irc): 113 | start_ndx = int(round(center_val - width_irc[axis] / 2)) 114 | end_idx = int(start_ndx + width_irc[axis]) 115 | 116 | assert center_val >= 0 and center_val < self.ary.shape[axis], \ 117 | repr([self.series_uid, center_xyz, self.origin_xyz, 118 | self.vxSize_xyz, center_xyz, axis]) 119 | 120 | if start_ndx < 0: 121 | start_ndx = 0 122 | end_ndx = int(width_irc[axis]) 123 | 124 | if end_idx > self.ary.shape[axis]: 125 | end_idx = self.ary.shape[axis] 126 | start_ndx = int(self.ary.shape[axis] - width_irc[axis]) 127 | 128 | slice_list.append(slice(start_ndx, end_idx)) 129 | 130 | ct_chunk = self.ary[tuple(slice_list)] 131 | 132 | return ct_chunk, center_irc 133 | 134 | 135 | @functools.lru_cache(1, typed=True) 136 | def getCt(series_uid): return Ct(series_uid) 137 | 138 | @raw_cache.memoize(typed=True) 139 | def getCtRawNodule(series_uid, center_xyz, width_irc): 140 | ct = getCt(series_uid) 141 | ct_chunk, center_irc = ct.getRawNodule(center_xyz, width_irc) 142 | 143 | return ct_chunk, center_irc 144 | 145 | 146 | class LunaDataset(Dataset): 147 | def __init__(self, test_stride=0, isTestSet_bool=None, series_uid=None): 148 | self.noduleInfo_list = copy.copy(getNoduleInfoList()) 149 | 150 | if series_uid: 151 | self.noduleInfo_list = [x for x in self.noduleInfo_list 152 | if x[2] == series_uid] 153 | 154 | # __init__ continued... 155 | if test_stride > 1: 156 | if isTestSet_bool: 157 | self.noduleInfo_list = self.noduleInfo_list[::test_stride] 158 | else: 159 | del self.noduleInfo_list[::test_stride] 160 | 161 | log.info("{!r}: {} {} samples".format( 162 | self, 163 | len(self.noduleInfo_list), 164 | "testing" if isTestSet_bool else "training" 165 | )) 166 | 167 | def __len__(self): return len(self.noduleInfo_list) 168 | 169 | def __getitem__(self, ndx): 170 | sample_ndx = ndx 171 | 172 | isMalignant_bool, diameter_mm, series_uid, center_xyz = \ 173 | self.noduleInfo_list[sample_ndx] 174 | 175 | nodule_ary, center_irc = getCtRawNodule(series_uid, center_xyz, 176 | (32, 32, 32)) 177 | 178 | nodule_tensor = torch.from_numpy(nodule_ary) 179 | nodule_tensor = nodule_tensor.unsqueeze(0) 180 | 181 | malignant_tensor = torch.tensor([isMalignant_bool], 182 | dtype=torch.float32) 183 | 184 | return nodule_tensor, malignant_tensor, series_uid, center_irc 185 | -------------------------------------------------------------------------------- /p2ch09/dsets.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import csv 3 | import functools 4 | import glob 5 | import os 6 | import random 7 | 8 | from collections import namedtuple 9 | 10 | import SimpleITK as sitk 11 | 12 | import numpy as np 13 | import torch 14 | import torch.cuda 15 | from torch.utils.data import Dataset 16 | 17 | from util.disk import getCache 18 | from util.util import XyzTuple, xyz2irc 19 | from util.logconf import logging 20 | 21 | log = logging.getLogger(__name__) 22 | log.setLevel(logging.INFO) 23 | 24 | raw_cache = getCache('part2ch9_raw') 25 | 26 | NoduleInfoTuple = namedtuple('NoduleInfoTuple', 'isMalignant_bool', 27 | 'diameter_mm', 'series_uid', 'center_xyz') 28 | 29 | @functools.lru_cache(1) 30 | def getNoduleInfoList(requireDataOnDisk_bool=True): 31 | # We construct a set with all series_uids that are present on disk. 32 | # This will let us use the data, even if we haven't downloaded all 33 | # of the subsets yet. 34 | mhd_list = glob.glob("data-unversioned/part2/luna/subset*/*.mhd") 35 | dataPresentOnDisk_set = {os.path.split(p)[-1][:-4] for p in mhd_list} 36 | 37 | diameter_dict = {} 38 | with open("data/part2/luna/annotations.csv", 'r') as f: 39 | for row in list(csv.reader(f))[1:]: 40 | series_uid = row[0] 41 | annotationCenter_xyz = tuple([float(x) for x in row[1:4]]) 42 | annotationDiameter_mm = float(row[4]) 43 | 44 | diameter_dict.setdefault(series_uid, []).append( 45 | (annotationCenter_xyz, annotationDiameter_mm)) 46 | 47 | noduleInfo_list = [] 48 | with open("data/part2/luna/candidates.csv", 'r') as f: 49 | for row in list(csv.reader(f))[1:]: 50 | series_uid = row[0] 51 | 52 | if series_uid not in dataPresentOnDisk_set and \ 53 | requireDataOnDisk_bool: 54 | continue 55 | 56 | isMalignant_bool = bool(int(row[4])) 57 | candidateCenter_xyz = tuple([float(x) for x in row[1:4]]) 58 | 59 | candidateDiameter_mm = 0.0 60 | for annotationDiameter_mm, annotationDiameter_mm in \ 61 | diameter_dict.get(series_uid, []): 62 | for i in range(3): 63 | delta_mm = abs(candidateCenter_xyz[i] - 64 | annotationCenter_xyz[i]) 65 | if delta_mm > annotationDiameter_mm / 4: 66 | break 67 | else: 68 | candidateDiameter_mm = annotationDiameter_mm 69 | break 70 | 71 | noduleInfo_list.append(NoduleInfoTuple(isMalignant_bool, 72 | candidateDiameter_mm, 73 | series_uid, 74 | candidateCenter_xyz)) 75 | noduleInfo_list.sort(reverse=True) 76 | 77 | return noduleInfo_list 78 | 79 | 80 | class Ct: 81 | def __init__(self, series_uid): 82 | mhd_path = glob.glob("data-unversioned/part2/luna/subset*/{}.mhd"\ 83 | .format(series_uid))[0] 84 | 85 | ct_mhd = sitk.ReadImage(mhd_path) 86 | ct_a = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32) 87 | 88 | # CTs are natively expressed in 89 | # https://en.wikipedia.org/wiki/Hounsfield_scale HU are scaled 90 | # oddly, with 0 g/cc (air, approximately) being -1000 and 1 g/cc 91 | # (water) being 0. 92 | # This gets rid of negative density stuff used to indicate 93 | # out-of-FOV 94 | ct_a[ct_a < -1000] = -1000 95 | 96 | # This nukes any weird hotspots and clamps bone down 97 | ct_a[ct_a > 1000] = 1000 98 | 99 | self.series_uid = series_uid 100 | self.hu_a = ct_a 101 | 102 | self.origin_xyz = XyzTuple(*ct_mhd.GetOrigin()) 103 | self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing()) 104 | self.direction_tup = tuple(int(round(x)) for x in 105 | ct_mhd.GetDirection()) 106 | 107 | def getRawNodule(self, center_xyz, width_irc): 108 | center_irc = xyz2irc(center_xyz, self.origin_xyz, self.vxSize_xyz, 109 | self.direction_tup) 110 | 111 | slice_list = [] 112 | for axis, center_val in enumerate(center_irc): 113 | start_ndx = int(round(center_val - width_irc[axis] / 2)) 114 | end_ndx = int(start_ndx + width_irc[axis]) 115 | 116 | assert center_val >= 0 and center_val < self.hu_a.shape[axis], \ 117 | repr([self.series_uid, center_xyz, self.origin_xyz, 118 | self.vxSize_xyz, center_xyz, axis]) 119 | 120 | if start_ndx < 0: 121 | start_ndx = 0 122 | end_ndx = int(width_irc[axis]) 123 | 124 | if end_ndx > self.hu_a.shape[axis]: 125 | end_ndx = self.hu_a.shape[axis] 126 | start_ndx = int(self.hu_a.shape[axis] - width_irc[axis]) 127 | 128 | slice_list.append(slice(start_ndx, end_ndx)) 129 | 130 | ct_chunk = self.hu_a[tuple(slice_list)] 131 | 132 | return ct_chunk, center_irc 133 | 134 | 135 | @functools.lru_cache(1, typed=True) 136 | def getCt(series_uid): return Ct(series_uid) 137 | 138 | @raw_cache.memoize(typed=True) 139 | def getCtRawNodule(series_uid, center_xyz, width_irc): 140 | ct = getCt(series_uid) 141 | ct_chunk, center_irc = ct.getRawNodule(center_xyz, width_irc) 142 | 143 | return ct_chunk, center_irc 144 | 145 | 146 | class LunaDataset(Dataset): 147 | def __init__(self, val_stride=0, isValSet_bool=None, series_uid=None): 148 | self.noduleInfo_list = copy.copy(getNoduleInfoList()) 149 | 150 | if series_uid: 151 | self.noduleInfo_list = [x for x in self.noduleInfo_list 152 | if x[2] == series_uid] 153 | 154 | if val_stride > 1: 155 | if isValSet_bool: 156 | self.noduleInfo_list = self.noduleInfo_list[::val_stride] 157 | else: 158 | del self.noduleInfo_list[::val_stride] 159 | 160 | log.info("{!r}: {} {} samples".format( 161 | self, 162 | len(self.noduleInfo_list), 163 | "validation" if isValSet_bool else "training" 164 | )) 165 | 166 | def __len__(self): return len(self.noduleInfo_list) 167 | 168 | def __getitem__(self, ndx): 169 | nodule_tup = self.noduleInfo_list[ndx] 170 | width_irc = (24, 48, 48) 171 | 172 | nodule_a, center_irc = getCtRawNodule( 173 | nodule_tup.series_uid, 174 | nodule_tup.center_xyz, 175 | width_irc 176 | ) 177 | 178 | nodule_t = torch.from_numpy(nodule_a) 179 | nodule_t = nodule_t.to(torch.float32) 180 | nodule_t = nodule_t.unsqueeze(0) 181 | 182 | cls_t = torch.tensor([ 183 | not nodule_tup.isMalignant_bool, 184 | nodule_tup.isMalignant_bool 185 | ]) 186 | 187 | return nodule_t, cls_t, nodule_tup.series_uid, center_irc 188 | -------------------------------------------------------------------------------- /util/unet.py: -------------------------------------------------------------------------------- 1 | # From https://github.com/jvanvugt/pytorch-unet 2 | # https://raw.githubusercontent.com/jvanvugt/pytorch-unet/master/unet.py 3 | 4 | # MIT License 5 | # 6 | # Copyright (c) 2018 Joris 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to the following conditions: 14 | # 15 | # The above copyright notice and this permission notice shall be included in all 16 | # copies or substantial portions of the Software. 17 | # 18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | # SOFTWARE. 25 | 26 | # Adapted from https://discuss.pytorch.org/t/unet-implementation/426 27 | 28 | import torch 29 | from torch import nn 30 | import torch.nn.functional as F 31 | 32 | 33 | class UNet(nn.Module): 34 | def __init__(self, in_channels=1, n_classes=2, depth=5, wf=6, 35 | padding=False, batch_norm=False, up_mode='upconv'): 36 | """ 37 | Implementation of 38 | U-Net: Convolutional Networks for Biomedical Image Segmentation 39 | (Ronneberger et al., 2015) 40 | 41 | Using the default arguments will yield the exact version used 42 | in the original paper 43 | 44 | Args: 45 | in_channels (int): number of input channels 46 | n_classes (int): number of output channels 47 | depth (int): depth of the network 48 | wf (int): number of filters in the first layer is 2 ** wf 49 | padding (bool): if True, apply padding such that the input shape 50 | is the same as the output. 51 | This may introduce artifacts 52 | batch_norm (bool): Use BatchNorm after layers with an 53 | activation function 54 | up_mode (str): one of 'upconv' or 'upsample'. 55 | 'upconv' will use transposed convolutions for 56 | learned upsampling. 57 | 'upsample' will use bilinear upsampling. 58 | """ 59 | super().__init__() 60 | assert up_mode in ('upconv', 'upsample') 61 | self.padding = padding 62 | self.depth = depth 63 | prev_channels = in_channels 64 | # nn.ModuleList is like a regular Python list, the difference 65 | # is that pytorch is "aware" of the existence of the nn.Module's 66 | # inside the nn.ModuleList, which is not the case for Python 67 | # lists. 68 | # nn.ModuleList, differently from nn.Sequential, doesn't have a 69 | # forward() method and is not needed that the output size of a 70 | # block matches the input size of the following block. 71 | self.down_path = nn.ModuleList() 72 | for i in range(depth): 73 | self.down_path.append(UNetConvBlock(prev_channels, 2 ** (wf + i), 74 | padding, batch_norm)) 75 | prev_channels = 2 ** (wf + i) 76 | 77 | self.up_path = nn.ModuleList() 78 | for i in reversed(range(depth - 1)): 79 | self.up_path.append(UNetUpBlock(prev_channels, 2 ** (wf + i), 80 | up_mode, padding, batch_norm)) 81 | prev_channels = 2 ** (wf + 1) 82 | 83 | self.last = nn.Conv2d(prev_channels, n_classes, kernel_size=1) 84 | 85 | def forward(self, x): 86 | blocks = [] 87 | for i, down in enumerate(self.down_path): 88 | x = down(x) 89 | if i != len(self.down_path) - 1: 90 | blocks.append(x) 91 | x = F.avg_pool2d(x, 2) 92 | 93 | for i, up in enumerate(self.up_path): 94 | # blocks[-i - 1] is concatenation along the channel 95 | # dimension to the "up" intermediate output 96 | x = up(x, blocks[-i - 1]) 97 | 98 | return self.last(x) 99 | 100 | 101 | class UNetConvBlock(nn.Module): 102 | def __init__(self, in_size, out_size, padding, batch_norm): 103 | super().__init__() 104 | block = [] 105 | 106 | block.append(nn.Conv2d(in_size, out_size, kernel_size=3, 107 | padding=int(padding))) 108 | block.append(nn.ReLU()) 109 | 110 | if batch_norm: 111 | block.append(nn.BatchNorm2d(out_size)) 112 | 113 | block.append(nn.Conv2d(out_size, out_size, kernel_size=3, 114 | padding=int(padding))) 115 | block.append(nn.ReLU()) 116 | 117 | if batch_norm: 118 | block.append(nn.BatchNorm2d(out_size)) 119 | 120 | self.block = nn.Sequential(*block) 121 | 122 | def forward(self, x): 123 | out = self.block(x) 124 | return out 125 | 126 | 127 | class UNetUpBlock(nn.Module): 128 | def __init__(self, in_size, out_size, up_mode, padding, batch_norm): 129 | super().__init__() 130 | 131 | if up_mode == 'upconv': 132 | # See for example: https://medium.com/apache-mxnet/transposed-convolutions-explained-with-ms-excel-52d13030c7e8 133 | self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, 134 | stride=2) 135 | elif up_mode == 'upsample': 136 | self.up = nn.Sequential( 137 | nn.Upsample(mode='bilinear', scale_factor=2), 138 | nn.Conv2d(in_size, out_size, kernel_size=1) 139 | ) 140 | 141 | self.conv_block = UNetConvBlock(in_size, out_size, padding, batch_norm) 142 | 143 | def center_crop(self, layer, target_size): 144 | # Note: target_size is defined as (height, width) 145 | 146 | # Note: .size() is equivalent to .shape. In this case, it return 147 | # batch_size, channels, height, width. 148 | _, _, layer_height, layer_width = layer.size() 149 | diff_y = (layer_height - target_size[0]) // 2 150 | diff_x = (layer_width - target_size[1]) // 2 151 | 152 | return layer[:, :, diff_y:diff_y + target_size[0], 153 | diff_x:diff_x + target_size[1]] 154 | 155 | def forward(self, x, bridge): 156 | up = self.up(x) 157 | crop1 = self.center_crop(bridge, up.shape[2:]) 158 | # Concatenation along the channel dimension. This "skip 159 | # connection" combines the output of the "down" fine detail 160 | # layer and "up" wide receptive field layers. 161 | out = torch.cat([up, crop1], 1) 162 | out = self.conv_block(out) 163 | 164 | return out 165 | -------------------------------------------------------------------------------- /util/util.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import copy 3 | import datetime 4 | import gc 5 | import time 6 | 7 | import numpy as np 8 | 9 | from util.logconf import logging 10 | 11 | log = logging.getLogger(__name__) 12 | # log.setLevel(logging.WARN) 13 | # log.setLevel(logging.INFO) 14 | log.setLevel(logging.DEBUG) 15 | 16 | # Irc stay for index, row and column. Xyz are the patient 17 | # coordinates. x represent the right to left direction, y the 18 | # anterior to posterior direction and z the inferior to superior 19 | # direction. 20 | # Usually the row and column dimensions have voxel sizes that are 21 | # the same, and the index dimension has a larger value. 22 | # Commonly, CTs are 512 rows by 512 columns, with the index 23 | # dimension ranging from around 100 total slices up to perhaps 250 24 | # slices. 25 | IrcTuple = collections.namedtuple('IrcTuple', ['index', 'row', 'col']) 26 | XyzTuple = collections.namedtuple('XyzTuple', ['x', 'y', 'z']) 27 | 28 | # The following function apply a scaling factor to produce images with 29 | # realistic proportions. 30 | def xyz2irc(coord_xyz, origin_xyz, vxSize_xyz, direction_tup): 31 | if direction_tup == (1, 0, 0, 0, 1, 0, 0, 0, 1): 32 | direction_ary = np.ones((3,)) 33 | elif direction_tup == (-1, 0, 0, 0, -1, 0, 0, 0, 1): 34 | direction_ary = np.array((-1, -1, 1)) 35 | else: 36 | raise Exception("Unsupported direction_tup: {}".format(direction_tup)) 37 | 38 | coord_cri = (np.array(coord_xyz) - np.array(origin_xyz)) / \ 39 | np.array(vxSize_xyz) 40 | coord_cri *= direction_ary 41 | 42 | return IrcTuple(*list(reversed(coord_cri.tolist()))) 43 | 44 | 45 | def irc2xyz(coord_irc, origin_xyz, vxSize_xyz, direction_tup): 46 | coord_cri = np.array(list(reversed(coord_irc))) 47 | 48 | if direction_tup == (1, 0, 0, 0, 1, 0, 0, 0, 1): 49 | direction_ary = np.ones((3,)) 50 | elif direction_tup == (-1, 0, 0, 0, -1, 0, 0, 0, 1): 51 | direction_ary = np.array((-1, -1, 1)) 52 | else: 53 | raise Exception("Unsupported direction_tup: {}".format(direction_tup)) 54 | 55 | coord_xyz = coord_cri * direction_ary * np.array(vxSize_xyz) \ 56 | + np.array(origin_xyz) 57 | 58 | return XyzTuple(*coord_xyz.tolist()) 59 | 60 | 61 | def importstr(module_str, from_=None): 62 | """ 63 | >>> importstr('os') 64 | 65 | >>> importstr('math', 'fabs') 66 | 67 | """ 68 | if from_ is None and ':' in module_str: 69 | # In this case rsplit is equivalment to split 70 | module_str, from_ = module_str.rsplit(':') 71 | 72 | module = __import__(module_str) 73 | for sub_str in module_str.split('.')[1:]: 74 | # getattr(module, sub_str) is equivalent to module.sub_str 75 | module = getattr(module, sub_str) 76 | 77 | if from_: 78 | try: 79 | return getattr(module, from_) 80 | except: 81 | raise ImportError('{}.{}'.format(module_str, from_)) 82 | 83 | return module 84 | 85 | 86 | def prhist(ary, prefix_str=None, **kwargs): 87 | if prefix_str is None: 88 | prefix_str = '' 89 | else: 90 | prefix_str += ' ' 91 | 92 | count_ary, bins_ary = np.histogram(ary, **kwargs) 93 | for i in range(count_ary.shape[0]): 94 | print("{}{:-8.2f}".format(prefix_str, bins_ary[i]), 95 | "{:-10}".format(count_ary[i])) 96 | 97 | print("{}{:-8.2f}".format(prefix_str, bins_ary[-1])) 98 | 99 | 100 | def enumerateWithEstimate(iter, desc_str, start_ndx=0, print_ndx=4, 101 | backoff=2, iter_len=None): 102 | """ 103 | In terms of behavior, `enumerateWithEstimate` is almost identical 104 | to the standard `enumerate` (the differences are things like how 105 | our function returns a generator, while `enumerate` returns a 106 | specialized ``). 107 | 108 | However, the side effects (logging, specifically) are what make the 109 | function interesting. 110 | 111 | :param iter: `iter` is the iterable that will be passed into 112 | `enumerate`. Required. 113 | 114 | :param desc_str: This is a human-readable string that describes 115 | what the loop is doing. The value is arbitrary, but should be 116 | kept reasonably short. Things like `"epoch 4 training"` or 117 | `"deleting temp files"` or similar would all make sense. 118 | 119 | :param start_ndx: This parameter defines how many iterations of the 120 | loop should be skipped before timing actually starts. Skipping 121 | a few iterations can be useful if there are startup costs like 122 | caching that are only paid early on, resulting in a skewed 123 | average when those early iterations dominate the average time 124 | per iteration. 125 | 126 | NOTE: Using `start_ndx` to skip some iterations makes the time 127 | spent performing those iterations not be included in the 128 | displayed duration. Please account for this if you use the 129 | displayed duration for anything formal. 130 | 131 | This parameter defaults to `0`. 132 | 133 | :param print_ndx: determines which loop interation that the timing 134 | logging will start on. The intent is that we don't start 135 | logging until we've given the loop a few iterations to let the 136 | average time-per-iteration a chance to stablize a bit. We 137 | require that `print_ndx` not be less than `start_ndx` times 138 | `backoff`, since `start_ndx` greater than `0` implies that the 139 | early N iterations are unstable from a timing perspective. 140 | 141 | `print_ndx` defaults to `4`. 142 | 143 | :param backoff: This is used to how many iterations to skip before 144 | logging again. Frequent logging is less interesting later on, 145 | so by default we double the gap between logging messages each 146 | time after the first. 147 | 148 | `backoff` defaults to `2`. 149 | 150 | :param iter_len: Since we need to know the number of items to 151 | estimate when the loop will finish, that can be provided by 152 | passing in a value for `iter_len`. If a value isn't provided, 153 | then it will be set by using the value of `len(iter)`. 154 | 155 | :return: 156 | """ 157 | if iter_len is None: 158 | iter_len = len(iter) 159 | 160 | assert backoff >= 2 161 | while print_ndx < start_ndx * backoff: 162 | print_ndx *= backoff 163 | 164 | log.warning("{} ----/{}, starting".format( 165 | desc_str, 166 | iter_len, 167 | )) 168 | start_ts = time.time() 169 | for (current_ndx, item) in enumerate(iter): 170 | yield (current_ndx, item) 171 | if current_ndx == print_ndx: 172 | duration_sec = ((time.time() - start_ts) / 173 | (current_ndx - start_ndx + 1) 174 | * (iter_len - start_ndx)) 175 | 176 | done_dt = datetime.datetime.fromtimestamp(start_ts + duration_sec) 177 | done_td = datetime.timedelta(seconds=duration_sec) 178 | 179 | log.info("{} {:-4}/{}, done at {}, {}".format( 180 | desc_str, 181 | current_ndx, 182 | iter_len, 183 | str(done_dt).rsplit('.', 1)[0], 184 | str(done_td).rsplit('.', 1)[0], 185 | )) 186 | 187 | print_ndx *= backoff 188 | 189 | if current_ndx + 1 == start_ndx: 190 | start_ts = time.time() 191 | 192 | log.warning("{} ----/{}, done at {}".format( 193 | desc_str, 194 | iter_len, 195 | str(datetime.datetime.now()).rsplit('.', 1)[0], 196 | )) 197 | -------------------------------------------------------------------------------- /p1ch4/2_time_series_bikes.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import torch\n", 11 | "torch.set_printoptions(edgeitems=2, threshold=50)" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 2, 17 | "metadata": {}, 18 | "outputs": [ 19 | { 20 | "data": { 21 | "text/plain": [ 22 | "tensor([[1.0000e+00, 1.0000e+00, ..., 1.3000e+01, 1.6000e+01],\n", 23 | " [2.0000e+00, 1.0000e+00, ..., 3.2000e+01, 4.0000e+01],\n", 24 | " ...,\n", 25 | " [1.7378e+04, 3.1000e+01, ..., 4.8000e+01, 6.1000e+01],\n", 26 | " [1.7379e+04, 3.1000e+01, ..., 3.7000e+01, 4.9000e+01]])" 27 | ] 28 | }, 29 | "execution_count": 2, 30 | "metadata": {}, 31 | "output_type": "execute_result" 32 | } 33 | ], 34 | "source": [ 35 | "bikes_numpy = np.loadtxt(\"../data/p1ch4/bike-sharing-dataset/hour-fixed.csv\",\n", 36 | " dtype=np.float32,\n", 37 | " delimiter=',',\n", 38 | " skiprows=1,\n", 39 | " # extract the day\n", 40 | " converters={1: lambda x: float(x[8:10])}) # <1>\n", 41 | "bikes = torch.from_numpy(bikes_numpy)\n", 42 | "bikes" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 3, 48 | "metadata": {}, 49 | "outputs": [ 50 | { 51 | "data": { 52 | "text/plain": [ 53 | "(torch.Size([17520, 17]), (17, 1))" 54 | ] 55 | }, 56 | "execution_count": 3, 57 | "metadata": {}, 58 | "output_type": "execute_result" 59 | } 60 | ], 61 | "source": [ 62 | "bikes.shape, bikes.stride()" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 4, 68 | "metadata": {}, 69 | "outputs": [ 70 | { 71 | "data": { 72 | "text/plain": [ 73 | "(torch.Size([730, 24, 17]), (408, 17, 1))" 74 | ] 75 | }, 76 | "execution_count": 4, 77 | "metadata": {}, 78 | "output_type": "execute_result" 79 | } 80 | ], 81 | "source": [ 82 | "daily_bikes = bikes.view(-1, 24, bikes.shape[1])\n", 83 | "daily_bikes.shape, daily_bikes.stride()" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 5, 89 | "metadata": {}, 90 | "outputs": [ 91 | { 92 | "data": { 93 | "text/plain": [ 94 | "(torch.Size([730, 17, 24]), (408, 1, 17))" 95 | ] 96 | }, 97 | "execution_count": 5, 98 | "metadata": {}, 99 | "output_type": "execute_result" 100 | } 101 | ], 102 | "source": [ 103 | "daily_bikes = daily_bikes.transpose(1, 2)\n", 104 | "daily_bikes.shape, daily_bikes.stride()" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": 6, 110 | "metadata": {}, 111 | "outputs": [ 112 | { 113 | "data": { 114 | "text/plain": [ 115 | "tensor([1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 2, 2, 2, 2])" 116 | ] 117 | }, 118 | "execution_count": 6, 119 | "metadata": {}, 120 | "output_type": "execute_result" 121 | } 122 | ], 123 | "source": [ 124 | "first_day = bikes[:24].long()\n", 125 | "weather_onehot = torch.zeros(first_day.shape[0], 4)\n", 126 | "first_day[:, 9]" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 7, 132 | "metadata": {}, 133 | "outputs": [ 134 | { 135 | "data": { 136 | "text/plain": [ 137 | "tensor([[1., 0., 0., 0.],\n", 138 | " [1., 0., 0., 0.],\n", 139 | " ...,\n", 140 | " [0., 1., 0., 0.],\n", 141 | " [0., 1., 0., 0.]])" 142 | ] 143 | }, 144 | "execution_count": 7, 145 | "metadata": {}, 146 | "output_type": "execute_result" 147 | } 148 | ], 149 | "source": [ 150 | "weather_onehot.scatter_(\n", 151 | " dim=1, \n", 152 | " index=first_day[:, 9].unsqueeze(1) - 1, # <1>\n", 153 | " value=1.0)" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": 8, 159 | "metadata": {}, 160 | "outputs": [ 161 | { 162 | "data": { 163 | "text/plain": [ 164 | "tensor([[ 1.0000, 1.0000, 1.0000, 0.0000, 1.0000, 0.0000, 0.0000, 6.0000,\n", 165 | " 0.0000, 1.0000, 0.2400, 0.2879, 0.8100, 0.0000, 3.0000, 13.0000,\n", 166 | " 16.0000, 1.0000, 0.0000, 0.0000, 0.0000]])" 167 | ] 168 | }, 169 | "execution_count": 8, 170 | "metadata": {}, 171 | "output_type": "execute_result" 172 | } 173 | ], 174 | "source": [ 175 | "torch.cat((bikes[:24], weather_onehot), 1)[:1]" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": 9, 181 | "metadata": {}, 182 | "outputs": [ 183 | { 184 | "data": { 185 | "text/plain": [ 186 | "torch.Size([730, 4, 24])" 187 | ] 188 | }, 189 | "execution_count": 9, 190 | "metadata": {}, 191 | "output_type": "execute_result" 192 | } 193 | ], 194 | "source": [ 195 | "daily_weather_onehot = torch.zeros(daily_bikes.shape[0], 4, daily_bikes.shape[2])\n", 196 | "daily_weather_onehot.shape" 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": 10, 202 | "metadata": {}, 203 | "outputs": [ 204 | { 205 | "data": { 206 | "text/plain": [ 207 | "torch.Size([730, 4, 24])" 208 | ] 209 | }, 210 | "execution_count": 10, 211 | "metadata": {}, 212 | "output_type": "execute_result" 213 | } 214 | ], 215 | "source": [ 216 | "daily_weather_onehot.scatter_(1, daily_bikes[:, 9, :].long().unsqueeze(1) - 1, 1.0)\n", 217 | "daily_weather_onehot.shape" 218 | ] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "execution_count": 11, 223 | "metadata": {}, 224 | "outputs": [], 225 | "source": [ 226 | "daily_bikes = torch.cat((daily_bikes, daily_weather_onehot), dim=1)" 227 | ] 228 | }, 229 | { 230 | "cell_type": "code", 231 | "execution_count": 12, 232 | "metadata": {}, 233 | "outputs": [], 234 | "source": [ 235 | "daily_bikes[:, 9, :] = (daily_bikes[:, 9, :] - 1.0) / 3.0" 236 | ] 237 | }, 238 | { 239 | "cell_type": "code", 240 | "execution_count": 13, 241 | "metadata": {}, 242 | "outputs": [], 243 | "source": [ 244 | "temp = daily_bikes[: 10, :]\n", 245 | "temp_min = torch.min(temp)\n", 246 | "temp_max = torch.max(temp)\n", 247 | "daily_bikes[:, 10, :] = (daily_bikes[:, 10, :] - temp_min) / (temp_max - temp_min)" 248 | ] 249 | }, 250 | { 251 | "cell_type": "code", 252 | "execution_count": 14, 253 | "metadata": {}, 254 | "outputs": [], 255 | "source": [ 256 | "temp = daily_bikes[:, 10, :]\n", 257 | "daily_bikes[:, 10, :] = (daily_bikes[:, 10, :] - torch.mean(temp)) / torch.std(temp)" 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": null, 263 | "metadata": {}, 264 | "outputs": [], 265 | "source": [] 266 | } 267 | ], 268 | "metadata": { 269 | "kernelspec": { 270 | "display_name": "Python 3", 271 | "language": "python", 272 | "name": "python3" 273 | }, 274 | "language_info": { 275 | "codemirror_mode": { 276 | "name": "ipython", 277 | "version": 3 278 | }, 279 | "file_extension": ".py", 280 | "mimetype": "text/x-python", 281 | "name": "python", 282 | "nbconvert_exporter": "python", 283 | "pygments_lexer": "ipython3", 284 | "version": "3.6.9" 285 | } 286 | }, 287 | "nbformat": 4, 288 | "nbformat_minor": 2 289 | } 290 | -------------------------------------------------------------------------------- /p2ch08/training.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import sys 4 | 5 | import numpy as np 6 | 7 | import torch 8 | from torch import nn 9 | from torch.optim import SGD 10 | from torch.utils.data import DataLoader 11 | 12 | from util.util import enumerateWithEstimate 13 | from util.logconf import logging 14 | from .dsets import LunaDataset 15 | from .model import LunaModel 16 | 17 | log = logging.getLogger(__name__) 18 | log.setLevel(logging.INFO) 19 | 20 | # Used for computeBatchLoss and logMetrics to index into 21 | # metrics_tensor/metrics_ary 22 | METRICS_LABEL_NDX = 0 23 | METRICS_PRED_NDX = 1 24 | METRICS_LOSS_NDX = 2 25 | 26 | class LunaTrainingApp(): 27 | def __init__(self, sys_argv=None): 28 | if sys_argv is None: 29 | sys_argv = sys.argv[1:] 30 | 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument( 33 | '--batch-size', 34 | help="Batch size to use for training", 35 | default=32, 36 | type=int 37 | ) 38 | parser.add_argument( 39 | '--num-workers', 40 | help="Number of worker processes for background data loading", 41 | default=8, 42 | type=int 43 | ) 44 | parser.add_argument( 45 | '--epochs', 46 | help="Number of epochs to train for", 47 | default=1, 48 | type=int 49 | ) 50 | 51 | self.cli_args = parser.parse_args(sys_argv) 52 | self.time_str = datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S") 53 | 54 | def main(self): 55 | log.info("Starting {}, {}".format(type(self).__name__, self.cli_args)) 56 | 57 | self.use_cuda = torch.cuda.is_available() 58 | self.device = torch.device("cuda" if self.use_cuda else "cpu") 59 | 60 | self.model = LunaModel() 61 | if self.use_cuda: 62 | if torch.cuda.device_count() > 1: 63 | self.model = nn.DataParallel(self.model) 64 | 65 | self.model = self.model.to(self.device) 66 | self.optimizer = SGD(self.model.parameters(), lr=0.01, momentum=0.9) 67 | 68 | train_dl = DataLoader( 69 | LunaDataset(test_stride=10, isTestSet_bool=False), 70 | batch_size=self.cli_args.batch_size * (torch.cuda.device_count() 71 | if self.use_cuda else 1), 72 | num_workers=self.cli_args.num_workers, 73 | pin_memory=self.use_cuda 74 | ) 75 | 76 | test_dl = DataLoader( 77 | LunaDataset(test_stride=10, isTestSet_bool=True), 78 | batch_size=self.cli_args.batch_size * (torch.cuda.device_count() 79 | if self.use_cuda else 1), 80 | num_workers=self.cli_args.num_workers, 81 | pin_memory=self.use_cuda 82 | ) 83 | 84 | for epoch_ndx in range(1, self.cli_args.epochs + 1): 85 | log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format( 86 | epoch_ndx, 87 | self.cli_args.epochs, 88 | len(train_dl), 89 | len(test_dl), 90 | self.cli_args.batch_size, 91 | (torch.cuda.device_count() if self.use_cuda else 1) 92 | )) 93 | 94 | # Trainig loop, very similar to below 95 | self.model.train() 96 | trainingMetrics_tensor = torch.zeros(3, len(train_dl.dataset), 1) 97 | batch_iter = enumerateWithEstimate( 98 | train_dl, 99 | "E{} Traning".format(epoch_ndx), 100 | start_ndx=train_dl.num_workers 101 | ) 102 | 103 | for batch_ndx, batch_tup in batch_iter: 104 | self.optimizer.zero_grad() 105 | loss_var = self.computeBatchLoss(batch_ndx, batch_tup, 106 | train_dl.batch_size, 107 | trainingMetrics_tensor) 108 | loss_var.backward() 109 | self.optimizer.step() 110 | del loss_var 111 | 112 | # Testing loop, very similar to above, but simplified 113 | with torch.no_grad(): 114 | self.model.eval() 115 | testingMetrics_tensor = torch.zeros(3, len(test_dl.dataset), 1) 116 | batch_iter = enumerateWithEstimate( 117 | test_dl, 118 | "E{} Testing".format(epoch_ndx), 119 | start_ndx=test_dl.num_workers 120 | ) 121 | for batch_ndx, batch_tup in batch_iter: 122 | self.computeBatchLoss(batch_ndx, batch_tup, 123 | test_dl.batch_size, 124 | testingMetrics_tensor) 125 | 126 | self.logMetrics(epoch_ndx, trainingMetrics_tensor, 127 | testingMetrics_tensor) 128 | 129 | def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, 130 | metrics_tensor): 131 | input_tensor, label_tensor, _series_list, _center_list = batch_tup 132 | input_devtensor = input_tensor.to(self.device) 133 | label_devtensor = label_tensor.to(self.device) 134 | 135 | prediction_devtensor = self.model(input_devtensor) 136 | loss_devsensor = nn.MSELoss(reduction='none')(prediction_devtensor, 137 | label_devtensor) 138 | 139 | start_ndx = batch_ndx * batch_size 140 | end_ndx = start_ndx + label_tensor.size(0) 141 | 142 | metrics_tensor[METRICS_LABEL_NDX, start_ndx:end_ndx] = label_tensor 143 | metrics_tensor[METRICS_PRED_NDX, start_ndx:end_ndx] = \ 144 | prediction_devtensor.to('cpu') 145 | metrics_tensor[METRICS_LOSS_NDX, start_ndx:end_ndx] = \ 146 | loss_devsensor 147 | 148 | return loss_devsensor.mean() 149 | 150 | def logMetrics(self, epoch_ndx, trainingMetrics_tensor, 151 | testingMetrics_tensor, classificationThreshold_float=0.5): 152 | log.info("E{} {}".format( 153 | epoch_ndx, 154 | type(self).__name__ 155 | )) 156 | 157 | for mode_str, metrics_tensor in [('trn', trainingMetrics_tensor), 158 | ('tst', testingMetrics_tensor)]: 159 | metrics_ary = metrics_tensor.detach().numpy()[:, :, 0] 160 | assert np.isfinite(metrics_ary).all() 161 | 162 | benLabel_mask = metrics_ary[METRICS_LABEL_NDX] <= \ 163 | classificationThreshold_float 164 | benPred_mask = metrics_ary[METRICS_PRED_NDX] <= \ 165 | classificationThreshold_float 166 | 167 | malLabel_mask = ~benLabel_mask 168 | malPred_mask = ~benPred_mask 169 | 170 | benLabel_count = benLabel_mask.sum() 171 | malLabel_count = malLabel_mask.sum() 172 | 173 | benCorrect_count = (benLabel_mask & benPred_mask).sum() 174 | malCorrect_count = (malLabel_mask & malPred_mask).sum() 175 | 176 | metrics_dict = {} 177 | 178 | metrics_dict['loss/all'] = metrics_ary[METRICS_LOSS_NDX].mean() 179 | metrics_dict['loss/ben'] = metrics_ary[METRICS_LOSS_NDX, 180 | benLabel_mask].mean() 181 | metrics_dict['loss/mal'] = metrics_ary[METRICS_LOSS_NDX, 182 | malLabel_mask].mean() 183 | 184 | metrics_dict['correct/all'] = (malCorrect_count + benCorrect_count) \ 185 | / metrics_ary.shape[1] * 100 186 | metrics_dict['correct/ben'] = benCorrect_count / benLabel_count * 100 187 | metrics_dict['correct/mal'] = malCorrect_count / malLabel_count * 100 188 | 189 | log.info(("E{} {:8} {loss/all:.4f} loss, {correct/all:-5.1f}% " 190 | "correct").format(epoch_ndx, mode_str, **metrics_dict)) 191 | log.info(("E{} {:8} {loss/ben:.4f} loss, {correct/ben:-5.1f}% " 192 | "correct").format(epoch_ndx, mode_str + '_ben', **metrics_dict)) 193 | log.info(("E{} {:8} {loss/mal:.4f} loss, {correct/mal:-5.1f}% " 194 | "correct").format(epoch_ndx, mode_str + 'mal', **metrics_dict)) 195 | 196 | if __name__ == '__main__': 197 | sys.exit(LunaTrainingApp().main() or 0) 198 | -------------------------------------------------------------------------------- /p1ch4/3_text_jane_austin.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import torch\n", 11 | "torch.set_printoptions(edgeitems=2, threshold=50)" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 2, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "with open(\"../data/p1ch4/jane-austin/1342-0.txt\", encoding='utf8') as f:\n", 21 | " text = f.read()" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 3, 27 | "metadata": {}, 28 | "outputs": [ 29 | { 30 | "data": { 31 | "text/plain": [ 32 | "'“Impossible, Mr. Bennet, impossible, when I am not acquainted with him'" 33 | ] 34 | }, 35 | "execution_count": 3, 36 | "metadata": {}, 37 | "output_type": "execute_result" 38 | } 39 | ], 40 | "source": [ 41 | "lines = text.split('\\n')\n", 42 | "line = lines[200]\n", 43 | "line" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 4, 49 | "metadata": {}, 50 | "outputs": [ 51 | { 52 | "data": { 53 | "text/plain": [ 54 | "torch.Size([70, 128])" 55 | ] 56 | }, 57 | "execution_count": 4, 58 | "metadata": {}, 59 | "output_type": "execute_result" 60 | } 61 | ], 62 | "source": [ 63 | "letter_tensor = torch.zeros(len(line), 128) # <1>\n", 64 | "letter_tensor.shape" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 5, 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "for i, letter in enumerate(line.lower().strip()):\n", 74 | " letter_index = ord(letter) if ord(letter) < 128 else 0 # <1>\n", 75 | " letter_tensor[i][letter_index] = 1" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 6, 81 | "metadata": {}, 82 | "outputs": [ 83 | { 84 | "data": { 85 | "text/plain": [ 86 | "('“Impossible, Mr. Bennet, impossible, when I am not acquainted with him',\n", 87 | " ['impossible',\n", 88 | " 'mr',\n", 89 | " 'bennet',\n", 90 | " 'impossible',\n", 91 | " 'when',\n", 92 | " 'i',\n", 93 | " 'am',\n", 94 | " 'not',\n", 95 | " 'acquainted',\n", 96 | " 'with',\n", 97 | " 'him'])" 98 | ] 99 | }, 100 | "execution_count": 6, 101 | "metadata": {}, 102 | "output_type": "execute_result" 103 | } 104 | ], 105 | "source": [ 106 | "def clean_words(input_str):\n", 107 | " punctuation = '.,;:\"!?”“_-'\n", 108 | " word_list = input_str.lower().replace('\\n', ' ').split()\n", 109 | " word_list = [word.strip(punctuation) for word in word_list]\n", 110 | " return word_list\n", 111 | "\n", 112 | "words_in_line = clean_words(line)\n", 113 | "line, words_in_line" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 7, 119 | "metadata": {}, 120 | "outputs": [ 121 | { 122 | "data": { 123 | "text/plain": [ 124 | "(7261, 3394)" 125 | ] 126 | }, 127 | "execution_count": 7, 128 | "metadata": {}, 129 | "output_type": "execute_result" 130 | } 131 | ], 132 | "source": [ 133 | "word_list = sorted(set(clean_words(text)))\n", 134 | "word2index_dict = {word: i for (i, word) in enumerate(word_list)}\n", 135 | "\n", 136 | "len(word2index_dict), word2index_dict['impossible']" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": 8, 142 | "metadata": {}, 143 | "outputs": [ 144 | { 145 | "name": "stdout", 146 | "output_type": "stream", 147 | "text": [ 148 | " 0 3394 impossible\n", 149 | " 1 4305 mr\n", 150 | " 2 813 bennet\n", 151 | " 3 3394 impossible\n", 152 | " 4 7078 when\n", 153 | " 5 3315 i\n", 154 | " 6 415 am\n", 155 | " 7 4436 not\n", 156 | " 8 239 acquainted\n", 157 | " 9 7148 with\n", 158 | "10 3215 him\n", 159 | "torch.Size([11, 7261])\n" 160 | ] 161 | } 162 | ], 163 | "source": [ 164 | "word_tensor = torch.zeros(len(words_in_line), len(word2index_dict))\n", 165 | "\n", 166 | "for i, word in enumerate(words_in_line):\n", 167 | " word_index = word2index_dict[word]\n", 168 | " word_tensor[i][word_index] = 1\n", 169 | " print(\"{:2} {:4} {}\".format(i, word_index, word))\n", 170 | " \n", 171 | "print(word_tensor.shape)" 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": 9, 177 | "metadata": {}, 178 | "outputs": [ 179 | { 180 | "data": { 181 | "text/plain": [ 182 | "torch.Size([11, 1, 7261])" 183 | ] 184 | }, 185 | "execution_count": 9, 186 | "metadata": {}, 187 | "output_type": "execute_result" 188 | } 189 | ], 190 | "source": [ 191 | "word_tensor = word_tensor.unsqueeze(1)\n", 192 | "word_tensor.shape" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": 10, 198 | "metadata": {}, 199 | "outputs": [ 200 | { 201 | "data": { 202 | "text/plain": [ 203 | "[('\\n', 10),\n", 204 | " (' ', 32),\n", 205 | " ('!', 33),\n", 206 | " ('#', 35),\n", 207 | " ('$', 36),\n", 208 | " ('%', 37),\n", 209 | " (\"'\", 39),\n", 210 | " ('(', 40),\n", 211 | " (')', 41),\n", 212 | " ('*', 42),\n", 213 | " (',', 44),\n", 214 | " ('-', 45),\n", 215 | " ('.', 46),\n", 216 | " ('/', 47),\n", 217 | " ('0', 48),\n", 218 | " ('1', 49),\n", 219 | " ('2', 50),\n", 220 | " ('3', 51),\n", 221 | " ('4', 52),\n", 222 | " ('5', 53),\n", 223 | " ('6', 54),\n", 224 | " ('7', 55),\n", 225 | " ('8', 56),\n", 226 | " ('9', 57),\n", 227 | " (':', 58),\n", 228 | " (';', 59),\n", 229 | " ('?', 63),\n", 230 | " ('@', 64),\n", 231 | " ('A', 65),\n", 232 | " ('B', 66),\n", 233 | " ('C', 67),\n", 234 | " ('D', 68),\n", 235 | " ('E', 69),\n", 236 | " ('F', 70),\n", 237 | " ('G', 71),\n", 238 | " ('H', 72),\n", 239 | " ('I', 73),\n", 240 | " ('J', 74),\n", 241 | " ('K', 75),\n", 242 | " ('L', 76),\n", 243 | " ('M', 77),\n", 244 | " ('N', 78),\n", 245 | " ('O', 79),\n", 246 | " ('P', 80),\n", 247 | " ('Q', 81),\n", 248 | " ('R', 82),\n", 249 | " ('S', 83),\n", 250 | " ('T', 84),\n", 251 | " ('U', 85),\n", 252 | " ('V', 86),\n", 253 | " ('W', 87),\n", 254 | " ('X', 88),\n", 255 | " ('Y', 89),\n", 256 | " ('Z', 90),\n", 257 | " ('[', 91),\n", 258 | " (']', 93),\n", 259 | " ('_', 95),\n", 260 | " ('a', 97),\n", 261 | " ('b', 98),\n", 262 | " ('c', 99),\n", 263 | " ('d', 100),\n", 264 | " ('e', 101),\n", 265 | " ('f', 102),\n", 266 | " ('g', 103),\n", 267 | " ('h', 104),\n", 268 | " ('i', 105),\n", 269 | " ('j', 106),\n", 270 | " ('k', 107),\n", 271 | " ('l', 108),\n", 272 | " ('m', 109),\n", 273 | " ('n', 110),\n", 274 | " ('o', 111),\n", 275 | " ('p', 112),\n", 276 | " ('q', 113),\n", 277 | " ('r', 114),\n", 278 | " ('s', 115),\n", 279 | " ('t', 116),\n", 280 | " ('u', 117),\n", 281 | " ('v', 118),\n", 282 | " ('w', 119),\n", 283 | " ('x', 120),\n", 284 | " ('y', 121),\n", 285 | " ('z', 122),\n", 286 | " ('“', 8220),\n", 287 | " ('”', 8221),\n", 288 | " ('\\ufeff', 65279)]" 289 | ] 290 | }, 291 | "execution_count": 10, 292 | "metadata": {}, 293 | "output_type": "execute_result" 294 | } 295 | ], 296 | "source": [ 297 | "[(c, ord(c)) for c in sorted(set(text))]" 298 | ] 299 | }, 300 | { 301 | "cell_type": "code", 302 | "execution_count": 11, 303 | "metadata": {}, 304 | "outputs": [ 305 | { 306 | "data": { 307 | "text/plain": [ 308 | "108" 309 | ] 310 | }, 311 | "execution_count": 11, 312 | "metadata": {}, 313 | "output_type": "execute_result" 314 | } 315 | ], 316 | "source": [ 317 | "ord('l')" 318 | ] 319 | }, 320 | { 321 | "cell_type": "code", 322 | "execution_count": null, 323 | "metadata": {}, 324 | "outputs": [], 325 | "source": [] 326 | } 327 | ], 328 | "metadata": { 329 | "kernelspec": { 330 | "display_name": "Python 3", 331 | "language": "python", 332 | "name": "python3" 333 | }, 334 | "language_info": { 335 | "codemirror_mode": { 336 | "name": "ipython", 337 | "version": 3 338 | }, 339 | "file_extension": ".py", 340 | "mimetype": "text/x-python", 341 | "name": "python", 342 | "nbconvert_exporter": "python", 343 | "pygments_lexer": "ipython3", 344 | "version": "3.6.9" 345 | } 346 | }, 347 | "nbformat": 4, 348 | "nbformat_minor": 2 349 | } 350 | -------------------------------------------------------------------------------- /p1ch4/4_audio_chirp.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "Audio\n", 8 | "====" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "metadata": {}, 15 | "outputs": [], 16 | "source": [ 17 | "import numpy as np\n", 18 | "import torch\n", 19 | "torch.set_printoptions(edgeitems=2, threshold=50)" 20 | ] 21 | }, 22 | { 23 | "cell_type": "markdown", 24 | "metadata": {}, 25 | "source": [ 26 | "Sound can be seen as fluctuations of pressure of a medium, air for instance, at a certain location in time. There are other representations that we'll get into in a minute, but we can think about this as the _raw_, time-domain representation. In order for the human ear to appreciate sound, pressure must fluctuate with a frequency between 20 and 20000 oscillations per second (measured in Hertz, Hz). More oscillations per second will lead to a higher perceived pitch.\n", 27 | "\n", 28 | "By recording pressure fluctuations in time using a microphone and converting every pressure level at each time point into a number (e.g. a 16-bit integer), we can now represent sound as a vector of numbers. This is known as Pulse Code Modulation (PCM), where a continuous signal is both sampled in time and quantized in amplitude. If we want to make sure we hear the highest possible pitch in a recording, we'll have to record our samples at slightly more than twice the maximum audible frequency, i.e. just over 40000 times per second. It is not by chance that audio CD's have a sampling frequency of 44100 Hz. This means that a one-hour stereo (i.e. 2 channels) CD track where samples are recorded at 16-bit precision will amount to `2 * 16 * 44100 * 3600 = 5080320000 bit = 605.6 MB` if stored without compression.\n", 29 | "\n", 30 | "There are a plethora of audio formats, WAV, AIFF, MP3, AAC being the most popular, where raw audio signals are typically encoded in compressed form by leveraging on both correlation between successive samples in the time series, between the two stereo channels as well as elimination of barely audible frequencies. This can result in dramatic reduction of storage requirements (a one-hour audio file in AAC format takes less than 60 MB). In addition, audio players can decode these formats on the fly on dedicated hardware, consuming a tiny amount of power.\n", 31 | "\n", 32 | "In our data scientist role we may have to feed audio samples to our network and classify them, or generate captions, for instance. In that case, we won't work with compressed data, rather we'll have to find a way to load an audio file in some format and lay it out as an uncompressed time series in a tensor. Let's do that now.\n", 33 | "\n", 34 | "We can download a fair number of environmental sounds at the ESC-50 repository (https://github.com/karoldvl/ESC-50) in the `audio` directory. Let's get `1-100038-A-14.wav` for instance, containing the sound of a bird chirping.\n", 35 | "\n", 36 | "In order to load the sound we resort to SciPy, specifically `scipy.io.wavfile.read`, which has the nice property to return data as a NumPy array:" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 2, 42 | "metadata": {}, 43 | "outputs": [ 44 | { 45 | "data": { 46 | "text/plain": [ 47 | "(44100, array([ -388, -3387, -4634, ..., 2289, 1327, 90], dtype=int16))" 48 | ] 49 | }, 50 | "execution_count": 2, 51 | "metadata": {}, 52 | "output_type": "execute_result" 53 | } 54 | ], 55 | "source": [ 56 | "from scipy.io import wavfile\n", 57 | "\n", 58 | "freq, waveform_arr = wavfile.read(\"../data/p1ch4/audio-chirp/1-100038-A-14.wav\")\n", 59 | "freq, waveform_arr" 60 | ] 61 | }, 62 | { 63 | "cell_type": "markdown", 64 | "metadata": {}, 65 | "source": [ 66 | "The `read` function returns two outputs, namely the sampling frequency and the waveform as a 16-bit integer 1D array. It's a single 1D array, which tells us that it's a mono recording - we'd have two waveforms (two channels) if the sound were stereo.\n", 67 | "\n", 68 | "We can convert the array to a tensor and we're good to go. We might also want to convert the waveform tensor to a float tensor since we're at it." 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 3, 74 | "metadata": {}, 75 | "outputs": [ 76 | { 77 | "data": { 78 | "text/plain": [ 79 | "torch.Size([220500])" 80 | ] 81 | }, 82 | "execution_count": 3, 83 | "metadata": {}, 84 | "output_type": "execute_result" 85 | } 86 | ], 87 | "source": [ 88 | "waveform = torch.from_numpy(waveform_arr).float()\n", 89 | "waveform.shape" 90 | ] 91 | }, 92 | { 93 | "cell_type": "markdown", 94 | "metadata": {}, 95 | "source": [ 96 | "In a typical dataset, we'll have more than one waveform, and possibly over more than one channel. Depending on the kind of network employed for carrying out a task, for instance a sound classification task, we would be required to lay out the tensor in one of two ways.\n", 97 | "\n", 98 | "For architectures based on filtering the 1D signal with cascades of learned filter banks, such as convolutional networks, we would need to lay out the tensor as `N x C x L`, where `N` is the number of sounds in a dataset, `C` the number of channels and `L` the number of samples in time.\n", 99 | "\n", 100 | "Conversely, for architectures that incorporate the notion of temporal sequences, just as recurrent networks we mentioned for text, data needs to be laid out as `L x N x C` - sequence length comes first. Intuitively, this is because the latter architectures take one set of `C` values at a time - the signal is not considered as a whole, but as an individual input changing in time.\n", 101 | "\n", 102 | "Although the most straightforward, this is only one of the ways to represent audio so that it is digestible by a neural network. Anther way is turning the audio signal into a _spectrogram_.\n", 103 | "\n", 104 | "Instead of representing oscillations explicitly in time, we can characterize what at frequencies those oscillations occur for short time intervals. So, for instance, if we pluck the fifth string of our (hopefully tuned) guitar and we focus on 0.1 seconds of that recording, we will see that the waveform oscillates at 440 cycles per second, plus smaller spurious oscillations at different frequencies that make up the timbre of the sound. If we move on to subsequent 0.1 second intervals, we now see that the frequency content doesn't change, but the intensity does, as the sound of our string fades. If we now decide to pluck another string, we will observe new frequencies fading in time.\n", 105 | "\n", 106 | "We could indeed build a plot having time in the X-axis, frequencies heard at that time in the Y-axis and encode intensity of those frequencies as a value at that X and Y. Or color. Ok, that starts to look like an image, right?\n", 107 | "\n", 108 | "That's correct, spectrograms are a representation of the intensity at each frequency at each point in time. It turns out that one can train convolutional neural networks built for analyzing images (we'll see about those in a couple of chapters) on sound represented as a spectrogram.\n", 109 | "\n", 110 | "Let's see how we can turn the sound we loaded earlier into a spectrogram. To do that, we need to resort to a method for converting a signal in the time domain into its frequency content. This is known as the Fourier transform, and the algorithm that allows us to compute it efficiently is the Fast Fourier Trasform (FFT), which is one of the most widespread algorithms out there. If we do that consecutively for short bursts of sound in time, we can build out spectrogram column by column.\n", 111 | "\n", 112 | "This is the general idea and we won't go into too many details here. Luckily for us SciPy has a function that gets us a shiny spectrogram given an input waveform. We import the `signal` module from SciPy,\n", 113 | "then provide the `spectrogram` function with the waveform and the sampling frequency that we got previously.\n", 114 | "The return values are all NumPy arrays, namely frequency `f_arr` (values along the Y axis), time `t_arr` (values along the X axis) and the actual spectrogra `sp_arr` as a 2D array. Turning the latter into a PyTorch tensor is trivial:" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": 4, 120 | "metadata": {}, 121 | "outputs": [ 122 | { 123 | "data": { 124 | "text/plain": [ 125 | "torch.Size([129, 984])" 126 | ] 127 | }, 128 | "execution_count": 4, 129 | "metadata": {}, 130 | "output_type": "execute_result" 131 | } 132 | ], 133 | "source": [ 134 | "from scipy import signal\n", 135 | "\n", 136 | "f_arr, t_arr, sp_arr = signal.spectrogram(waveform_arr, freq)\n", 137 | "\n", 138 | "sp_mono = torch.from_numpy(sp_arr)\n", 139 | "sp_mono.shape" 140 | ] 141 | }, 142 | { 143 | "cell_type": "markdown", 144 | "metadata": {}, 145 | "source": [ 146 | "Dimensions are `F x T`, where `F` is frequency and `T` is time.\n", 147 | "\n", 148 | "As we mentioned earlier, stereo sound has two channels, which will lead to a two-channel spectrogram. Suppose we have two spectrograms, one for each channel. We can convert the two channels separately:" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": 5, 154 | "metadata": {}, 155 | "outputs": [ 156 | { 157 | "data": { 158 | "text/plain": [ 159 | "(torch.Size([129, 984]), torch.Size([129, 984]))" 160 | ] 161 | }, 162 | "execution_count": 5, 163 | "metadata": {}, 164 | "output_type": "execute_result" 165 | } 166 | ], 167 | "source": [ 168 | "sp_left = sp_right = sp_arr\n", 169 | "sp_left_tensor = torch.from_numpy(sp_left)\n", 170 | "sp_right_tensor = torch.from_numpy(sp_right)\n", 171 | "sp_left_tensor.shape, sp_right_tensor.shape" 172 | ] 173 | }, 174 | { 175 | "cell_type": "markdown", 176 | "metadata": {}, 177 | "source": [ 178 | "and stack the two tensors along the first dimension to obtain a two channels image of size `C x F x T`, where `C` is the number channels:" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": 6, 184 | "metadata": {}, 185 | "outputs": [ 186 | { 187 | "data": { 188 | "text/plain": [ 189 | "torch.Size([2, 129, 984])" 190 | ] 191 | }, 192 | "execution_count": 6, 193 | "metadata": {}, 194 | "output_type": "execute_result" 195 | } 196 | ], 197 | "source": [ 198 | "sp_tensor = torch.stack((sp_left_tensor, sp_right_tensor), dim=0)\n", 199 | "sp_tensor.shape" 200 | ] 201 | }, 202 | { 203 | "cell_type": "markdown", 204 | "metadata": {}, 205 | "source": [ 206 | "If we want to build a dataset to use as input for a network, we will stack multiple spectrograms representing multiple sounds in a dataset along the first dimension, leading to a `N x C x F x T` tensor.\n", 207 | "\n", 208 | "Such tensor is indistinguishable from what we would build for a dataset set of images, where `F` is represents rows and `T` columns of an image. Indeed, we would tackle a sound classification problem on spectrograms with the exact same networks." 209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "execution_count": null, 214 | "metadata": {}, 215 | "outputs": [], 216 | "source": [] 217 | } 218 | ], 219 | "metadata": { 220 | "kernelspec": { 221 | "display_name": "Python 3", 222 | "language": "python", 223 | "name": "python3" 224 | }, 225 | "language_info": { 226 | "codemirror_mode": { 227 | "name": "ipython", 228 | "version": 3 229 | }, 230 | "file_extension": ".py", 231 | "mimetype": "text/x-python", 232 | "name": "python", 233 | "nbconvert_exporter": "python", 234 | "pygments_lexer": "ipython3", 235 | "version": "3.6.9" 236 | } 237 | }, 238 | "nbformat": 4, 239 | "nbformat_minor": 2 240 | } 241 | -------------------------------------------------------------------------------- /p2ch12/dsets.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import csv 3 | import functools 4 | import glob 5 | import math 6 | import os 7 | import random 8 | 9 | from collections import namedtuple 10 | 11 | import SimpleITK as sitk 12 | import numpy as np 13 | 14 | import torch 15 | import torch.cuda 16 | import torch.nn.functional as F 17 | from torch.utils.data import Dataset 18 | 19 | from util.disk import getCache 20 | from util.util import XyzTuple, xyz2irc 21 | from util.logconf import logging 22 | 23 | log = logging.getLogger(__name__) 24 | # log.setLevel(logging.WARN) 25 | # log.setLevel(logging.INFO) 26 | log.setLevel(logging.DEBUG) 27 | 28 | raw_cache = getCache('part2ch12_raw') 29 | 30 | NoduleInfoTuple = namedtuple('NoduleInfoTuple', 'isMalignant_bool', 31 | 'diameter_mm', 'series_uid', 'center_xyz') 32 | 33 | @functools.lru_cache(1) 34 | def getNoduleInfoList(requireDataOnDisk_bool=True): 35 | # We construct a set with all series_uids that are present on disk. 36 | # This will let us use the data, even if we haven't downloaded all 37 | # of the subsets yet. 38 | mhd_list = glob.glob("data-unversioned/part2/luna/subset*/*.mhd") 39 | dataPresentOnDisk_set = {os.path.split(p)[-1][:-4] for p in mhd_list} 40 | 41 | diameter_dict = {} 42 | with open("data/part2/luna/annotations.csv", 'r') as f: 43 | for row in list(csv.reader(f))[1:]: 44 | series_uid = row[0] 45 | annotationCenter_xyz = tuple([float(x) for x in row[1:4]]) 46 | annotationDiameter_mm = float(row[4]) 47 | 48 | diameter_dict.setdefault(series_uid, []).append( 49 | (annotationCenter_xyz, annotationDiameter_mm)) 50 | 51 | noduleInfo_list = [] 52 | with open("data/part2/luna/candidates.csv", 'r') as f: 53 | for row in list(csv.reader(f))[1:]: 54 | series_uid = row[0] 55 | 56 | if series_uid not in dataPresentOnDisk_set and \ 57 | requireDataOnDisk_bool: 58 | continue 59 | 60 | isMalignant_bool = bool(int(row[4])) 61 | candidateCenter_xyz = tuple([float(x) for x in row[1:4]]) 62 | 63 | candidateDiameter_mm = 0.0 64 | for annotationDiameter_mm, annotationDiameter_mm in \ 65 | diameter_dict.get(series_uid, []): 66 | for i in range(3): 67 | delta_mm = abs(candidateCenter_xyz[i] - 68 | annotationCenter_xyz[i]) 69 | if delta_mm > annotationDiameter_mm / 4: 70 | break 71 | else: 72 | candidateDiameter_mm = annotationDiameter_mm 73 | break 74 | 75 | noduleInfo_list.append(NoduleInfoTuple(isMalignant_bool, 76 | candidateDiameter_mm, 77 | series_uid, 78 | candidateCenter_xyz)) 79 | noduleInfo_list.sort(reverse=True) 80 | 81 | return noduleInfo_list 82 | 83 | 84 | class Ct: 85 | def __init__(self, series_uid): 86 | mhd_path = glob.glob("data-unversioned/part2/luna/subset*/{}.mhd"\ 87 | .format(series_uid))[0] 88 | 89 | ct_mhd = sitk.ReadImage(mhd_path) 90 | ct_a = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32) 91 | 92 | # CTs are natively expressed in 93 | # https://en.wikipedia.org/wiki/Hounsfield_scale HU are scaled 94 | # oddly, with 0 g/cc (air, approximately) being -1000 and 1 g/cc 95 | # (water) being 0. 96 | # This gets rid of negative density stuff used to indicate 97 | # out-of-FOV 98 | ct_a[ct_a < -1000] = -1000 99 | 100 | # This nukes any weird hotspots and clamps bone down 101 | ct_a[ct_a > 1000] = 1000 102 | 103 | self.series_uid = series_uid 104 | self.hu_a = ct_a 105 | 106 | self.origin_xyz = XyzTuple(*ct_mhd.GetOrigin()) 107 | self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing()) 108 | self.direction_tup = tuple(int(round(x)) for x in 109 | ct_mhd.GetDirection()) 110 | 111 | def getRawNodule(self, center_xyz, width_irc): 112 | center_irc = xyz2irc(center_xyz, self.origin_xyz, self.vxSize_xyz, 113 | self.direction_tup) 114 | 115 | slice_list = [] 116 | for axis, center_val in enumerate(center_irc): 117 | start_ndx = int(round(center_val - width_irc[axis] / 2)) 118 | end_ndx = int(start_ndx + width_irc[axis]) 119 | 120 | assert center_val >= 0 and center_val < self.hu_a.shape[axis], \ 121 | repr([self.series_uid, center_xyz, self.origin_xyz, 122 | self.vxSize_xyz, center_irc, axis]) 123 | 124 | if start_ndx < 0: 125 | start_ndx = 0 126 | end_ndx = int(width_irc[axis]) 127 | 128 | if end_ndx > self.hu_a.shape[axis]: 129 | end_ndx = self.hu_a.shape[axis] 130 | start_ndx = int(self.hu_a.shape[axis] - width_irc[axis]) 131 | 132 | slice_list.append(slice(start_ndx, end_ndx)) 133 | 134 | ct_chunk = self.hu_a[tuple(slice_list)] 135 | 136 | return ct_chunk, center_irc 137 | 138 | 139 | @functools.lru_cache(1, typed=True) 140 | def getCt(series_uid): return Ct(series_uid) 141 | 142 | 143 | @raw_cache.memoize(typed=True) 144 | def getCtRawNodule(series_uid, center_xyz, width_irc): 145 | ct = getCt(series_uid) 146 | ct_chunk, center_irc = ct.getRawNodule(center_xyz, width_irc) 147 | 148 | return ct_chunk, center_irc 149 | 150 | 151 | def getCtAugmentedNodule(augmentation_dict, series_uid, center_xyz, width_irc, 152 | use_cache=True): 153 | if use_cache: 154 | ct_chunk, center_irc = getCtRawNodule(series_uid, center_xyz, 155 | width_irc) 156 | else: 157 | ct = getCt(series_uid) 158 | ct_chunk, center_irc = ct.getRawNodule(center_xyz, width_irc) 159 | 160 | ct_t = torch.tensor(ct_chunk).unsqueeze(0).unsqueeze(0).to(torch.float32) 161 | 162 | trasform_t = torch.eye(4).to(torch.float64) 163 | 164 | for i in range(3): 165 | if 'flip' in augmentation_dict: 166 | if random.random() > 0.5: 167 | trasform_t[i, i] *= -1 168 | 169 | if 'offset' in augmentation_dict: 170 | offset_float = augmentation_dict['offset'] 171 | random_float = (random.random() * 2 - 1) 172 | trasform_t[3, i] = offset_float * random_float 173 | 174 | if 'scale' in augmentation_dict: 175 | scale_float = augmentation_dict['scale'] 176 | random_float = (random.random() * 2 - 1) 177 | trasform_t[i, i] *= 1.0 + scale_float * random_float 178 | 179 | if 'rotate' in augmentation_dict: 180 | angle_red = random.random() * math.pi * 2 181 | s = math.sin(angle_red) 182 | c = math.cos(angle_red) 183 | 184 | rotation_t = torch.tensor([ 185 | [c, -s, 0, 0], 186 | [s, c, 0, 0], 187 | [0, 0, 1, 0], 188 | [0, 0, 0, 1], 189 | ], dtype=torch.float64) 190 | 191 | trasform_t @= rotation_t 192 | 193 | affine_t = F.affine_grid( 194 | trasform_t[:3].unsqueeze(0).to(torch.float32), 195 | ct_t.size() 196 | ) 197 | augmented_chunk = F.grid_sample( 198 | ct_t, 199 | affine_t, 200 | padding_mode='border' 201 | ).to('cpu') 202 | 203 | if 'noise' in augmentation_dict: 204 | noise_t = torch.randn_like(augmented_chunk) 205 | noise_t *= augmented_chunk['noise'] 206 | 207 | augmented_chunk += noise_t 208 | 209 | return augmented_chunk[0], center_irc 210 | 211 | 212 | class LunaDataset(Dataset): 213 | def __init__(self, val_stride=0, isValSet_bool=None, series_uid=None, 214 | sortby_str='random', ratio_int=0, augmentation_dict=None, 215 | noduleInfo_list=None): 216 | self.ratio_int = ratio_int 217 | self.augmentation_dict = augmentation_dict 218 | 219 | if noduleInfo_list: 220 | self.noduleInfo_list = copy.copy(noduleInfo_list) 221 | self.use_cache = False 222 | else: 223 | self.noduleInfo_list = copy.copy(getNoduleInfoList()) 224 | self.use_cache = True 225 | 226 | if series_uid: 227 | self.noduleInfo_list = [x for x in self.noduleInfo_list 228 | if x[2] == series_uid] 229 | 230 | if isValSet_bool: 231 | assert val_stride > 0, val_stride 232 | self.noduleInfo_list = self.noduleInfo_list[::val_stride] 233 | assert self.noduleInfo_list 234 | elif val_stride > 0: 235 | del self.noduleInfo_list[::val_stride] 236 | assert self.noduleInfo_list 237 | 238 | if sortby_str == 'random': 239 | random.shuffle(self.noduleInfo_list) 240 | elif sortby_str == 'series_uid': 241 | self.noduleInfo_list.sort(key=lambda x: (x.series_uid, 242 | x.center_xyz)) 243 | elif sortby_str = 'malignancy_size': 244 | pass 245 | else: 246 | raise Exception("Unknown sort: " + repr(sortby_str)) 247 | 248 | self.benign_list = [nt for nt in self.noduleInfo_list in not 249 | nt.isMalignant_bool] 250 | self.malignant_list = [nt for nt in self.noduleInfo_list if 251 | not.isMalignant_bool] 252 | 253 | log.info("{!r}: {} {} samples, {} ben, {} mal, {} ratio".format( 254 | self, 255 | len(self.noduleInfo_list), 256 | "validation" if isValSet_bool else "training", 257 | len(self.benign_list), 258 | len(self.malignant_list), 259 | '{}:1'.format(self.ratio_int) if self.ratio_int else 'unbalanced' 260 | )) 261 | 262 | def shuffleSamples(self): 263 | if self.ratio_int: 264 | random.shuffle(self.benign_list) 265 | random.shuffle(self.malignant_list) 266 | 267 | def __len__(self): 268 | if self.ratio_int: 269 | return 200000 270 | else: 271 | return len(self.noduleInfo_list) // 20 272 | 273 | def __getitem__(self, ndx): 274 | if self.ratio_int: 275 | malignant_ndx = ndx // (self.ratio_int + 1) 276 | 277 | if ndx % (self.ratio_int + 1): 278 | benign_ndx = ndx - 1 - malignant_ndx 279 | benign_ndx %= len(self.benign_list) 280 | nodule_tup = self.benign_list[benign_ndx] 281 | else: 282 | malignant_ndx %= len(self.malignant_list) 283 | nodule_tup = self.malignant_list[malignant_ndx] 284 | else: 285 | nodule_tup = self.noduleInfo_list[ndx] 286 | 287 | if self.augmentation_dict: 288 | nodule_t, center_irc = getCtAugmentedNodule( 289 | self.augmentation_dict, 290 | nodule_tup.series_uid, 291 | nodule_tup.center_xyz, 292 | width_irc, 293 | self.use_cache 294 | ) 295 | elif self.use_cache: 296 | nodule_a, center_irc = getCtRawNodule( 297 | nodule_tup.series_uid, 298 | nodule_tup.center_xyz, 299 | width_irc 300 | ) 301 | nodule_t = torch.from_numpy(nodule_a).to(torch.float32) 302 | nodule_t = nodule_t.unsqueeze(0) 303 | else: 304 | ct = getCt(nodule_tup.series_uid) 305 | nodule_a, center_irc = ct.getCtRawNodule( 306 | nodule_tup.center_xyz, 307 | width_irc, 308 | ) 309 | nodule_t = torch.from_numpy(nodule_a).to(torch.float32) 310 | nodule_t = nodule_t.unsqueeze(0) 311 | 312 | malignant_t = torch.tensor([ 313 | not nodule_tup.isMalignant_bool, 314 | nodule_tup.isMalignant_bool 315 | ], dtype=torch.long) 316 | 317 | 318 | return nodule_t, malignant_t, nodule_tup.series_uid, \ 319 | torch.tensor(center_irc) 320 | -------------------------------------------------------------------------------- /p2ch09/training.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import sys 4 | 5 | import numpy as np 6 | 7 | from tensorboardX import SummaryWriter 8 | 9 | import torch 10 | from torch import nn 11 | from torch.optim import SGD 12 | from torch.utils.data import DataLoader 13 | 14 | from util.util import enumerateWithEstimate 15 | from util.logconf import logging 16 | from .dsets import LunaDataset 17 | from .model import LunaModel 18 | 19 | log = logging.getLogger(__name__) 20 | log.setLevel(logging.INFO) 21 | 22 | # Used for computeBatchLoss and logMetrics to index into 23 | # metrics_tensor/metrics_ary 24 | METRICS_LABEL_NDX = 0 25 | METRICS_PRED_NDX = 1 26 | METRICS_LOSS_NDX = 2 27 | 28 | class LunaTrainingApp(): 29 | def __init__(self, sys_argv=None): 30 | if sys_argv is None: 31 | sys_argv = sys.argv[1:] 32 | 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument( 35 | '--batch-size', 36 | help="Batch size to use for training", 37 | default=32, 38 | type=int 39 | ) 40 | parser.add_argument( 41 | '--num-workers', 42 | help="Number of worker processes for background data loading", 43 | default=8, 44 | type=int 45 | ) 46 | parser.add_argument( 47 | '--epochs', 48 | help="Number of epochs to train for", 49 | default=1, 50 | type=int 51 | ) 52 | parser.add_argument( 53 | '--balanced', 54 | help="Balance the training data to half benign, half malignant.", 55 | action='store_true', 56 | default=False 57 | ) 58 | parser.add_argument( 59 | '--tb-prefix', 60 | help="Data prefix to use for Tensorboard run. Defaults to chapter.", 61 | default='p2ch09' 62 | ) 63 | parser.add_argument( 64 | 'comment', 65 | help="Comment suffix for Tensorboard run.", 66 | nargs='?', 67 | default='none' 68 | ) 69 | 70 | self.cli_args = parser.parse_args(sys_argv) 71 | self.time_str = datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S") 72 | 73 | def main(self): 74 | log.info("Starting {}, {}".format(type(self).__name__, self.cli_args)) 75 | 76 | self.use_cuda = torch.cuda.is_available() 77 | self.device = torch.device("cuda" if self.use_cuda else "cpu") 78 | self.totalTrainingSamples_count = 0 79 | 80 | self.model = LunaModel() 81 | if self.use_cuda: 82 | if torch.cuda.device_count() > 1: 83 | self.model = nn.DataParallel(self.model) 84 | 85 | self.model = self.model.to(self.device) 86 | 87 | self.optimizer = SGD(self.model.parameters(), lr=0.01, momentum=0.9) 88 | 89 | train_dl = DataLoader( 90 | LunaDataset(test_stride=10, isTestSet_bool=False, 91 | ratio_int=int(self.cli_args.balanced)), 92 | batch_size=self.cli_args.batch_size * (torch.cuda.device_count() 93 | if self.use_cuda else 1), 94 | num_workers=self.cli_args.num_workers, 95 | pin_memory=self.use_cuda 96 | ) 97 | 98 | test_dl = DataLoader( 99 | LunaDataset(test_stride=10, isTestSet_bool=True), 100 | batch_size=self.cli_args.batch_size * (torch.cuda.device_count() 101 | if self.use_cuda else 1), 102 | num_workers=self.cli_args.num_workers, 103 | pin_memory=self.use_cuda 104 | ) 105 | 106 | for epoch_ndx in range(1, self.cli_args.epochs + 1): 107 | log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format( 108 | epoch_ndx, 109 | self.cli_args.epochs, 110 | len(train_dl), 111 | len(test_dl), 112 | self.cli_args.batch_size, 113 | (torch.cuda.device_count() if self.use_cuda else 1) 114 | )) 115 | 116 | # Trainig loop, very similar to below 117 | self.model.train() 118 | trainingMetrics_tensor = torch.zeros(3, len(train_dl.dataset), 1) 119 | batch_iter = enumerateWithEstimate( 120 | train_dl, 121 | "E{} Traning".format(epoch_ndx), 122 | start_ndx=train_dl.num_workers 123 | ) 124 | 125 | for batch_ndx, batch_tup in batch_iter: 126 | self.optimizer.zero_grad() 127 | loss_var = self.computeBatchLoss(batch_ndx, batch_tup, 128 | train_dl.batch_size, 129 | trainingMetrics_tensor) 130 | loss_var.backward() 131 | self.optimizer.step() 132 | del loss_var 133 | 134 | # Testing loop, very similar to above, but simplified 135 | with torch.no_grad(): 136 | self.model.eval() 137 | testingMetrics_tensor = torch.zeros(3, len(test_dl.dataset), 1) 138 | batch_iter = enumerateWithEstimate( 139 | test_dl, 140 | "E{} Testing".format(epoch_ndx), 141 | start_ndx=test_dl.num_workers 142 | ) 143 | for batch_ndx, batch_tup in batch_iter: 144 | self.computeBatchLoss(batch_ndx, batch_tup, 145 | test_dl.batch_size, 146 | testingMetrics_tensor) 147 | 148 | self.logMetrics(epoch_ndx, trainingMetrics_tensor, 149 | testingMetrics_tensor) 150 | 151 | if hasattr(self, 'trn_writer'): 152 | self.trn_writer.close() 153 | self.tst_writer.close() 154 | 155 | def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, 156 | metrics_tensor): 157 | input_tensor, label_tensor, _series_list, _center_list = batch_tup 158 | input_devtensor = input_tensor.to(self.device) 159 | label_devtensor = label_tensor.to(self.device) 160 | 161 | prediction_devtensor = self.model(input_devtensor) 162 | loss_devsensor = nn.MSELoss(reduction='none')(prediction_devtensor, 163 | label_devtensor) 164 | 165 | start_ndx = batch_ndx * batch_size 166 | end_ndx = start_ndx + label_tensor.size(0) 167 | 168 | metrics_tensor[METRICS_LABEL_NDX, start_ndx:end_ndx] = label_tensor 169 | metrics_tensor[METRICS_PRED_NDX, start_ndx:end_ndx] = \ 170 | prediction_devtensor.to('cpu') 171 | metrics_tensor[METRICS_LOSS_NDX, start_ndx:end_ndx] = \ 172 | loss_devsensor 173 | 174 | return loss_devsensor.mean() 175 | 176 | def logMetrics(self, epoch_ndx, trainingMetrics_tensor, 177 | testingMetrics_tensor, classificationThreshold_float=0.5): 178 | log.info("E{} {}".format(epoch_ndx, type(self).__name__)) 179 | 180 | if epoch_ndx == 2: 181 | log_dir = os.path.join('runs', self.cli_args.tb_prefix, 182 | self.time_str) 183 | 184 | self.trn_writer = SummaryWriter(log_dir=log_dir + '_trn_' + 185 | self.cli_args.comment) 186 | self.tst_writer = SummaryWriter(log_dir=log_dir + '_tst_' + 187 | self.cli_args.comment) 188 | 189 | self.totalTrainingSamples_count += trainingMetrics_tensor.size(1) 190 | 191 | for mode_str, metrics_tensor in [('trn', trainingMetrics_tensor), 192 | ('tst', testingMetrics_tensor)]: 193 | metrics_ary = metrics_tensor.cpu().detach().numpy()[:, :, 0] 194 | assert np.isfinite(metrics_ary).all() 195 | 196 | benLabel_mask = metrics_ary[METRICS_LABEL_NDX] <= \ 197 | classificationThreshold_float 198 | benPred_mask = metrics_ary[METRICS_PRED_NDX] <= \ 199 | classificationThreshold_float 200 | 201 | malLabel_mask = ~benLabel_mask 202 | malPred_mask = ~benPred_mask 203 | 204 | benLabel_count = benLabel_mask.sum() 205 | malLabel_count = malLabel_mask.sum() 206 | 207 | trueNeg_count = benCorrect_count = (benLabel_mask & benPred_mask).sum() 208 | truePos_count = malCorrect_count = (malLabel_mask & malPred_mask).sum() 209 | 210 | falsePos_count = benLabel_count - benCorrect_count 211 | falseNeg_count = malLabel_count - malCorrect_count 212 | 213 | metrics_dict = {} 214 | 215 | metrics_dict['loss/all'] = metrics_ary[METRICS_LOSS_NDX].mean() 216 | metrics_dict['loss/ben'] = metrics_ary[METRICS_LOSS_NDX, 217 | benLabel_mask].mean() 218 | metrics_dict['loss/mal'] = metrics_ary[METRICS_LOSS_NDX, 219 | malLabel_mask].mean() 220 | 221 | metrics_dict['correct/all'] = (malCorrect_count + benCorrect_count) \ 222 | / metrics_ary.shape[1] * 100 223 | metrics_dict['correct/ben'] = benCorrect_count / benLabel_count * 100 224 | metrics_dict['correct/mal'] = malCorrect_count / malLabel_count * 100 225 | 226 | precision = metrics_dict['pr/precision'] = truePos_count / \ 227 | (truePos_count + falsePos_count) 228 | recall = metrics_dict['pr/recall'] = truePos_count / \ 229 | (truePos_count + falseNeg_count) 230 | 231 | metrics_dict['pr/f1_score'] = 2 * (precision * recall) / \ 232 | (precision + recall) 233 | 234 | log.info(("E{} {:8} {loss/all:.4f} loss, {correct/all:-5.1f}% " 235 | "correct").format(epoch_ndx, mode_str, **metrics_dict)) 236 | log.info(("E{} {:8} {loss/ben:.4f} loss, {correct/ben:-5.1f}% " 237 | "correct ({benCorrect_count:} of {benLabel_count:})")\ 238 | .format(epoch_ndx, 239 | mode_str + '_ben', 240 | benCorrect_count=benCorrect_count, 241 | benLabel_count=benLabel_count, 242 | **metrics_dict)) 243 | log.info(("E{} {:8} {loss/mal:.4f} loss, {correct/mal:-5.1f}% " 244 | "correct ({malCorrect_count:} of {malLabel_count:})")\ 245 | .format(epoch_ndx, 246 | mode_str + '_mal', 247 | malCorrect_count=malCorrect_count, 248 | malLabel_count=malLabel_count, 249 | **metrics_dict)) 250 | if epoch_ndx > 1: 251 | writer = getattr(self, mode_str + '_writer') 252 | 253 | for key, value in metrics_dict.items(): 254 | writer.add_scalar(key, value, 255 | self.totalTrainingSamples_count) 256 | 257 | writer.add_pr_curve( 258 | 'pr', 259 | metrics_ary[METRICS_LABEL_NDX], 260 | metrics_ary[METRICS_PRED_NDX], 261 | self.totalTrainingSamples_count 262 | ) 263 | 264 | benHist_mask = benLabel_mask & \ 265 | (metrics_ary[METRICS_PRED_NDX] > 0.01) 266 | malHist_mask = malLabel_mask & \ 267 | (metrics_ary[METRICS_PRED_NDX] < 0.99) 268 | bins = [x / 50.0 for x in range(51)] 269 | 270 | writer.add_histogram( 271 | 'is_ben', 272 | metrics_ary[METRICS_PRED_NDX, benHist_mask], 273 | self.totalTrainingSamples_count, 274 | bins=bins 275 | ) 276 | writer.add_histogram( 277 | 'is_mal', 278 | metrics_ary[METRICS_PRED_NDX, malHist_mask], 279 | self.totalTrainingSamples_count, 280 | bins=bins 281 | ) 282 | 283 | if __name__ == '__main__': 284 | sys.exit(LunaTrainingApp().main() or 0) 285 | -------------------------------------------------------------------------------- /p1ch4/1_tabular_wine.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import torch\n", 11 | "torch.set_printoptions(edgeitems=2)" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 2, 17 | "metadata": {}, 18 | "outputs": [ 19 | { 20 | "data": { 21 | "text/plain": [ 22 | "array([[ 7. , 0.27, 0.36, ..., 0.45, 8.8 , 6. ],\n", 23 | " [ 6.3 , 0.3 , 0.34, ..., 0.49, 9.5 , 6. ],\n", 24 | " [ 8.1 , 0.28, 0.4 , ..., 0.44, 10.1 , 6. ],\n", 25 | " ...,\n", 26 | " [ 6.5 , 0.24, 0.19, ..., 0.46, 9.4 , 6. ],\n", 27 | " [ 5.5 , 0.29, 0.3 , ..., 0.38, 12.8 , 7. ],\n", 28 | " [ 6. , 0.21, 0.38, ..., 0.32, 11.8 , 6. ]], dtype=float32)" 29 | ] 30 | }, 31 | "execution_count": 2, 32 | "metadata": {}, 33 | "output_type": "execute_result" 34 | } 35 | ], 36 | "source": [ 37 | "import csv\n", 38 | "wine_path = \"../data/p1ch4/tabular-wine/winequality-white.csv\"\n", 39 | "wineq_numpy = np.loadtxt(wine_path, dtype=np.float32, delimiter=\";\", skiprows=1)\n", 40 | "wineq_numpy" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 3, 46 | "metadata": {}, 47 | "outputs": [ 48 | { 49 | "data": { 50 | "text/plain": [ 51 | "((4898, 12),\n", 52 | " ['fixed acidity',\n", 53 | " 'volatile acidity',\n", 54 | " 'citric acid',\n", 55 | " 'residual sugar',\n", 56 | " 'chlorides',\n", 57 | " 'free sulfur dioxide',\n", 58 | " 'total sulfur dioxide',\n", 59 | " 'density',\n", 60 | " 'pH',\n", 61 | " 'sulphates',\n", 62 | " 'alcohol',\n", 63 | " 'quality'])" 64 | ] 65 | }, 66 | "execution_count": 3, 67 | "metadata": {}, 68 | "output_type": "execute_result" 69 | } 70 | ], 71 | "source": [ 72 | "col_list = next(csv.reader(open(wine_path), delimiter=';'))\n", 73 | "\n", 74 | "wineq_numpy.shape, col_list" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 4, 80 | "metadata": {}, 81 | "outputs": [ 82 | { 83 | "data": { 84 | "text/plain": [ 85 | "(torch.Size([4898, 12]), 'torch.FloatTensor')" 86 | ] 87 | }, 88 | "execution_count": 4, 89 | "metadata": {}, 90 | "output_type": "execute_result" 91 | } 92 | ], 93 | "source": [ 94 | "wineq = torch.from_numpy(wineq_numpy)\n", 95 | "\n", 96 | "wineq.shape, wineq.type()" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": 5, 102 | "metadata": {}, 103 | "outputs": [ 104 | { 105 | "data": { 106 | "text/plain": [ 107 | "(tensor([[ 7.0000, 0.2700, ..., 0.4500, 8.8000],\n", 108 | " [ 6.3000, 0.3000, ..., 0.4900, 9.5000],\n", 109 | " ...,\n", 110 | " [ 5.5000, 0.2900, ..., 0.3800, 12.8000],\n", 111 | " [ 6.0000, 0.2100, ..., 0.3200, 11.8000]]), torch.Size([4898, 11]))" 112 | ] 113 | }, 114 | "execution_count": 5, 115 | "metadata": {}, 116 | "output_type": "execute_result" 117 | } 118 | ], 119 | "source": [ 120 | "data = wineq[:, :-1] # <1>\n", 121 | "data, data.shape" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": 6, 127 | "metadata": {}, 128 | "outputs": [ 129 | { 130 | "data": { 131 | "text/plain": [ 132 | "(tensor([6., 6., ..., 7., 6.]), torch.Size([4898]))" 133 | ] 134 | }, 135 | "execution_count": 6, 136 | "metadata": {}, 137 | "output_type": "execute_result" 138 | } 139 | ], 140 | "source": [ 141 | "target = wineq[:, -1] # <2>\n", 142 | "target, target.shape" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": 7, 148 | "metadata": {}, 149 | "outputs": [ 150 | { 151 | "data": { 152 | "text/plain": [ 153 | "tensor([6, 6, ..., 7, 6])" 154 | ] 155 | }, 156 | "execution_count": 7, 157 | "metadata": {}, 158 | "output_type": "execute_result" 159 | } 160 | ], 161 | "source": [ 162 | "target = wineq[:, -1].long()\n", 163 | "target" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": 8, 169 | "metadata": {}, 170 | "outputs": [ 171 | { 172 | "data": { 173 | "text/plain": [ 174 | "tensor([[0., 0., ..., 0., 0.],\n", 175 | " [0., 0., ..., 0., 0.],\n", 176 | " ...,\n", 177 | " [0., 0., ..., 0., 0.],\n", 178 | " [0., 0., ..., 0., 0.]])" 179 | ] 180 | }, 181 | "execution_count": 8, 182 | "metadata": {}, 183 | "output_type": "execute_result" 184 | } 185 | ], 186 | "source": [ 187 | "target_onehot = torch.zeros(target.shape[0], 10)\n", 188 | "\n", 189 | "target_onehot.scatter_(1, target.unsqueeze(1), 1.0)" 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "execution_count": 9, 195 | "metadata": {}, 196 | "outputs": [ 197 | { 198 | "data": { 199 | "text/plain": [ 200 | "tensor([[6],\n", 201 | " [6],\n", 202 | " ...,\n", 203 | " [7],\n", 204 | " [6]])" 205 | ] 206 | }, 207 | "execution_count": 9, 208 | "metadata": {}, 209 | "output_type": "execute_result" 210 | } 211 | ], 212 | "source": [ 213 | "target_unsqueezed = target.unsqueeze(1)\n", 214 | "target_unsqueezed" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": 10, 220 | "metadata": {}, 221 | "outputs": [ 222 | { 223 | "data": { 224 | "text/plain": [ 225 | "tensor([6.8548e+00, 2.7824e-01, 3.3419e-01, 6.3914e+00, 4.5772e-02, 3.5308e+01,\n", 226 | " 1.3836e+02, 9.9403e-01, 3.1883e+00, 4.8985e-01, 1.0514e+01])" 227 | ] 228 | }, 229 | "execution_count": 10, 230 | "metadata": {}, 231 | "output_type": "execute_result" 232 | } 233 | ], 234 | "source": [ 235 | "data_mean = torch.mean(data, dim=0)\n", 236 | "data_mean" 237 | ] 238 | }, 239 | { 240 | "cell_type": "code", 241 | "execution_count": 11, 242 | "metadata": {}, 243 | "outputs": [ 244 | { 245 | "data": { 246 | "text/plain": [ 247 | "tensor([7.1211e-01, 1.0160e-02, 1.4646e-02, 2.5726e+01, 4.7733e-04, 2.8924e+02,\n", 248 | " 1.8061e+03, 8.9455e-06, 2.2801e-02, 1.3025e-02, 1.5144e+00])" 249 | ] 250 | }, 251 | "execution_count": 11, 252 | "metadata": {}, 253 | "output_type": "execute_result" 254 | } 255 | ], 256 | "source": [ 257 | "data_var = torch.var(data, dim=0)\n", 258 | "data_var" 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": 12, 264 | "metadata": {}, 265 | "outputs": [ 266 | { 267 | "data": { 268 | "text/plain": [ 269 | "tensor([[ 1.7209e-01, -8.1764e-02, ..., -3.4914e-01, -1.3930e+00],\n", 270 | " [-6.5743e-01, 2.1587e-01, ..., 1.3467e-03, -8.2418e-01],\n", 271 | " ...,\n", 272 | " [-1.6054e+00, 1.1666e-01, ..., -9.6250e-01, 1.8574e+00],\n", 273 | " [-1.0129e+00, -6.7703e-01, ..., -1.4882e+00, 1.0448e+00]])" 274 | ] 275 | }, 276 | "execution_count": 12, 277 | "metadata": {}, 278 | "output_type": "execute_result" 279 | } 280 | ], 281 | "source": [ 282 | "data_normalized = (data - data_mean) / torch.sqrt(data_var)\n", 283 | "data_normalized" 284 | ] 285 | }, 286 | { 287 | "cell_type": "code", 288 | "execution_count": 13, 289 | "metadata": {}, 290 | "outputs": [ 291 | { 292 | "data": { 293 | "text/plain": [ 294 | "(torch.Size([4898]), torch.bool, tensor(20))" 295 | ] 296 | }, 297 | "execution_count": 13, 298 | "metadata": {}, 299 | "output_type": "execute_result" 300 | } 301 | ], 302 | "source": [ 303 | "bad_indexes = torch.le(target, 3)\n", 304 | "bad_indexes.shape, bad_indexes.dtype, bad_indexes.sum()" 305 | ] 306 | }, 307 | { 308 | "cell_type": "code", 309 | "execution_count": 14, 310 | "metadata": {}, 311 | "outputs": [ 312 | { 313 | "data": { 314 | "text/plain": [ 315 | "torch.Size([20, 11])" 316 | ] 317 | }, 318 | "execution_count": 14, 319 | "metadata": {}, 320 | "output_type": "execute_result" 321 | } 322 | ], 323 | "source": [ 324 | "bad_data = data[bad_indexes]\n", 325 | "bad_data.shape" 326 | ] 327 | }, 328 | { 329 | "cell_type": "code", 330 | "execution_count": 15, 331 | "metadata": {}, 332 | "outputs": [ 333 | { 334 | "name": "stdout", 335 | "output_type": "stream", 336 | "text": [ 337 | " 0 fixed acidity 7.60 6.89 6.73\n", 338 | " 1 volatile acidity 0.33 0.28 0.27\n", 339 | " 2 citric acid 0.34 0.34 0.33\n", 340 | " 3 residual sugar 6.39 6.71 5.26\n", 341 | " 4 chlorides 0.05 0.05 0.04\n", 342 | " 5 free sulfur dioxide 53.33 35.42 34.55\n", 343 | " 6 total sulfur dioxide 170.60 141.83 125.25\n", 344 | " 7 density 0.99 0.99 0.99\n", 345 | " 8 pH 3.19 3.18 3.22\n", 346 | " 9 sulphates 0.47 0.49 0.50\n", 347 | "10 alcohol 10.34 10.26 11.42\n" 348 | ] 349 | } 350 | ], 351 | "source": [ 352 | "bad_data = data[torch.le(target, 3)]\n", 353 | "mid_data = data[torch.gt(target, 3) & torch.lt(target, 7)] # <1>\n", 354 | "good_data = data[torch.ge(target, 7)]\n", 355 | "\n", 356 | "bad_mean = torch.mean(bad_data, dim=0)\n", 357 | "mid_mean = torch.mean(mid_data, dim=0)\n", 358 | "good_mean = torch.mean(good_data, dim=0)\n", 359 | "\n", 360 | "for i, args in enumerate(zip(col_list, bad_mean, mid_mean, good_mean)):\n", 361 | " print(\"{:2} {:20} {:6.2f} {:6.2f} {:6.2f}\".format(i, *args))" 362 | ] 363 | }, 364 | { 365 | "cell_type": "code", 366 | "execution_count": 16, 367 | "metadata": {}, 368 | "outputs": [ 369 | { 370 | "data": { 371 | "text/plain": [ 372 | "(torch.Size([4898]), torch.bool, tensor(2727))" 373 | ] 374 | }, 375 | "execution_count": 16, 376 | "metadata": {}, 377 | "output_type": "execute_result" 378 | } 379 | ], 380 | "source": [ 381 | "total_sulfur_threshold = 141.83\n", 382 | "total_sulfur_data = data[:, 6]\n", 383 | "predicted_indexes = torch.lt(total_sulfur_data, total_sulfur_threshold)\n", 384 | "\n", 385 | "predicted_indexes.shape, predicted_indexes.dtype, predicted_indexes.sum()" 386 | ] 387 | }, 388 | { 389 | "cell_type": "code", 390 | "execution_count": 17, 391 | "metadata": {}, 392 | "outputs": [ 393 | { 394 | "data": { 395 | "text/plain": [ 396 | "(torch.Size([4898]), torch.bool, tensor(3258))" 397 | ] 398 | }, 399 | "execution_count": 17, 400 | "metadata": {}, 401 | "output_type": "execute_result" 402 | } 403 | ], 404 | "source": [ 405 | "actual_indexes = torch.gt(target, 5)\n", 406 | "\n", 407 | "actual_indexes.shape, actual_indexes.dtype, actual_indexes.sum()" 408 | ] 409 | }, 410 | { 411 | "cell_type": "code", 412 | "execution_count": 18, 413 | "metadata": {}, 414 | "outputs": [ 415 | { 416 | "data": { 417 | "text/plain": [ 418 | "(2018, 0.74000733406674, 0.6193984039287906)" 419 | ] 420 | }, 421 | "execution_count": 18, 422 | "metadata": {}, 423 | "output_type": "execute_result" 424 | } 425 | ], 426 | "source": [ 427 | "n_matches = torch.sum(actual_indexes & predicted_indexes).item()\n", 428 | "n_predicted = torch.sum(predicted_indexes).item()\n", 429 | "n_actual = torch.sum(actual_indexes).item()\n", 430 | "\n", 431 | "n_matches, n_matches / n_predicted, n_matches / n_actual" 432 | ] 433 | }, 434 | { 435 | "cell_type": "code", 436 | "execution_count": null, 437 | "metadata": {}, 438 | "outputs": [], 439 | "source": [] 440 | } 441 | ], 442 | "metadata": { 443 | "kernelspec": { 444 | "display_name": "Python 3", 445 | "language": "python", 446 | "name": "python3" 447 | }, 448 | "language_info": { 449 | "codemirror_mode": { 450 | "name": "ipython", 451 | "version": 3 452 | }, 453 | "file_extension": ".py", 454 | "mimetype": "text/x-python", 455 | "name": "python", 456 | "nbconvert_exporter": "python", 457 | "pygments_lexer": "ipython3", 458 | "version": "3.6.9" 459 | } 460 | }, 461 | "nbformat": 4, 462 | "nbformat_minor": 2 463 | } 464 | -------------------------------------------------------------------------------- /p1ch5/3_optimizers.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%matplotlib inline\n", 10 | "\n", 11 | "import numpy as np\n", 12 | "import torch\n", 13 | "torch.set_printoptions(edgeitems=2)" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 2, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "t_c = torch.tensor([0.5, 14.0, 15.0, 28.0, 11.0, 8.0, 3.0, -4.0, 6.0, 13.0, 21.0])\n", 23 | "t_u = torch.tensor([35.7, 55.9, 58.2, 81.9, 56.3, 48.9, 33.9, 21.8, 48.4, 60.4, 68.4])\n", 24 | "t_un = 0.1 * t_u" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 3, 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "def model(t_u, w, b):\n", 34 | " return w * t_u + b" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 4, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "def loss_fn(t_p, t_c):\n", 44 | " squared_diffs = (t_p - t_c) ** 2\n", 45 | " return squared_diffs.mean()" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 5, 51 | "metadata": {}, 52 | "outputs": [ 53 | { 54 | "data": { 55 | "text/plain": [ 56 | "['ASGD',\n", 57 | " 'Adadelta',\n", 58 | " 'Adagrad',\n", 59 | " 'Adam',\n", 60 | " 'AdamW',\n", 61 | " 'Adamax',\n", 62 | " 'LBFGS',\n", 63 | " 'Optimizer',\n", 64 | " 'RMSprop',\n", 65 | " 'Rprop',\n", 66 | " 'SGD',\n", 67 | " 'SparseAdam',\n", 68 | " '__builtins__',\n", 69 | " '__cached__',\n", 70 | " '__doc__',\n", 71 | " '__file__',\n", 72 | " '__loader__',\n", 73 | " '__name__',\n", 74 | " '__package__',\n", 75 | " '__path__',\n", 76 | " '__spec__',\n", 77 | " 'lr_scheduler']" 78 | ] 79 | }, 80 | "execution_count": 5, 81 | "metadata": {}, 82 | "output_type": "execute_result" 83 | } 84 | ], 85 | "source": [ 86 | "import torch.optim as optim\n", 87 | "\n", 88 | "dir(optim)" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 6, 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "params = torch.tensor([1.0, 0.0], requires_grad=True)\n", 98 | "learning_rate = 1e-5\n", 99 | "optimizer = optim.SGD([params], lr=learning_rate)" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": 7, 105 | "metadata": {}, 106 | "outputs": [ 107 | { 108 | "data": { 109 | "text/plain": [ 110 | "tensor([ 9.5483e-01, -8.2600e-04], requires_grad=True)" 111 | ] 112 | }, 113 | "execution_count": 7, 114 | "metadata": {}, 115 | "output_type": "execute_result" 116 | } 117 | ], 118 | "source": [ 119 | "t_p = model(t_u, *params)\n", 120 | "loss = loss_fn(t_p, t_c)\n", 121 | "loss.backward()\n", 122 | "\n", 123 | "optimizer.step()\n", 124 | "\n", 125 | "params" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": 8, 131 | "metadata": {}, 132 | "outputs": [ 133 | { 134 | "data": { 135 | "text/plain": [ 136 | "tensor([1.7761, 0.1064], requires_grad=True)" 137 | ] 138 | }, 139 | "execution_count": 8, 140 | "metadata": {}, 141 | "output_type": "execute_result" 142 | } 143 | ], 144 | "source": [ 145 | "params = torch.tensor([1.0, 0.0], requires_grad=True)\n", 146 | "learning_rate = 1e-2\n", 147 | "optimizer = optim.SGD([params], lr=learning_rate)\n", 148 | "\n", 149 | "t_p = model(t_un, *params)\n", 150 | "loss = loss_fn(t_p, t_c)\n", 151 | "\n", 152 | "optimizer.zero_grad() # <1>\n", 153 | "loss.backward()\n", 154 | "optimizer.step()\n", 155 | "\n", 156 | "params" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": 9, 162 | "metadata": {}, 163 | "outputs": [], 164 | "source": [ 165 | "def training_loop(n_epochs, optimizer, params, t_u, t_c):\n", 166 | " for epoch in range(1, n_epochs + 1):\n", 167 | " t_p = model(t_u, *params)\n", 168 | " loss = loss_fn(t_p, t_c)\n", 169 | " \n", 170 | " optimizer.zero_grad()\n", 171 | " loss.backward()\n", 172 | " optimizer.step()\n", 173 | " \n", 174 | " if epoch % 500 == 0:\n", 175 | " print(\"Epoch %d, Loss %f\" % (epoch, float(loss)))\n", 176 | " \n", 177 | " return params" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": 10, 183 | "metadata": {}, 184 | "outputs": [ 185 | { 186 | "name": "stdout", 187 | "output_type": "stream", 188 | "text": [ 189 | "Epoch 500, Loss 7.860116\n", 190 | "Epoch 1000, Loss 3.828538\n", 191 | "Epoch 1500, Loss 3.092191\n", 192 | "Epoch 2000, Loss 2.957697\n", 193 | "Epoch 2500, Loss 2.933134\n", 194 | "Epoch 3000, Loss 2.928648\n", 195 | "Epoch 3500, Loss 2.927830\n", 196 | "Epoch 4000, Loss 2.927679\n", 197 | "Epoch 4500, Loss 2.927652\n", 198 | "Epoch 5000, Loss 2.927647\n" 199 | ] 200 | }, 201 | { 202 | "data": { 203 | "text/plain": [ 204 | "tensor([ 5.3671, -17.3012], requires_grad=True)" 205 | ] 206 | }, 207 | "execution_count": 10, 208 | "metadata": {}, 209 | "output_type": "execute_result" 210 | } 211 | ], 212 | "source": [ 213 | "params = torch.tensor([1.0, 0.0], requires_grad=True)\n", 214 | "learning_rate = 1e-2\n", 215 | "optimizer = optim.SGD([params], lr=learning_rate) # <1>\n", 216 | "\n", 217 | "training_loop(\n", 218 | " n_epochs=5000,\n", 219 | " optimizer=optimizer,\n", 220 | " params=params, # <1>\n", 221 | " t_u=t_un,\n", 222 | " t_c=t_c)" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": 11, 228 | "metadata": {}, 229 | "outputs": [ 230 | { 231 | "name": "stdout", 232 | "output_type": "stream", 233 | "text": [ 234 | "Epoch 500, Loss 2.962301\n", 235 | "Epoch 1000, Loss 2.927645\n", 236 | "Epoch 1500, Loss 2.927646\n", 237 | "Epoch 2000, Loss 2.927646\n" 238 | ] 239 | }, 240 | { 241 | "data": { 242 | "text/plain": [ 243 | "tensor([ 5.3677, -17.3048], requires_grad=True)" 244 | ] 245 | }, 246 | "execution_count": 11, 247 | "metadata": {}, 248 | "output_type": "execute_result" 249 | } 250 | ], 251 | "source": [ 252 | "params = torch.tensor([1.0, 0.0], requires_grad=True)\n", 253 | "learning_rate = 1e-1\n", 254 | "optimizer = optim.Adam([params], lr=learning_rate) # <1>\n", 255 | "\n", 256 | "training_loop(\n", 257 | " n_epochs=2000,\n", 258 | " optimizer=optimizer,\n", 259 | " params=params, # <2>\n", 260 | " t_u=t_un,\n", 261 | " t_c=t_c)" 262 | ] 263 | }, 264 | { 265 | "cell_type": "code", 266 | "execution_count": 12, 267 | "metadata": {}, 268 | "outputs": [ 269 | { 270 | "data": { 271 | "text/plain": [ 272 | "(tensor([ 4, 10, 8, 2, 3, 1, 0, 9, 7]), tensor([6, 5]))" 273 | ] 274 | }, 275 | "execution_count": 12, 276 | "metadata": {}, 277 | "output_type": "execute_result" 278 | } 279 | ], 280 | "source": [ 281 | "n_samples = t_u.shape[0]\n", 282 | "n_val = int(0.2 * n_samples)\n", 283 | "\n", 284 | "shuffled_indices = torch.randperm(n_samples)\n", 285 | "\n", 286 | "train_indices = shuffled_indices[:-n_val]\n", 287 | "val_indices = shuffled_indices[-n_val:]\n", 288 | "\n", 289 | "train_indices, val_indices # <1>" 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": 13, 295 | "metadata": {}, 296 | "outputs": [], 297 | "source": [ 298 | "train_t_u = t_u[train_indices]\n", 299 | "train_t_c = t_c[train_indices]\n", 300 | "\n", 301 | "val_t_u = t_u[val_indices]\n", 302 | "val_t_c = t_c[val_indices]\n", 303 | "\n", 304 | "train_t_un = 0.1 * train_t_u\n", 305 | "val_t_un = 0.1 * val_t_u" 306 | ] 307 | }, 308 | { 309 | "cell_type": "code", 310 | "execution_count": 14, 311 | "metadata": {}, 312 | "outputs": [], 313 | "source": [ 314 | "def training_loop(n_epochs, optimizer, params, train_t_u, val_t_u, train_t_c, val_t_c):\n", 315 | " for epoch in range(1, n_epochs + 1):\n", 316 | " train_t_p = model(train_t_u, *params) # <1>\n", 317 | " train_loss = loss_fn(train_t_p, train_t_c)\n", 318 | " \n", 319 | " val_t_p = model(val_t_u, *params) # <1>\n", 320 | " val_loss = loss_fn(val_t_p, val_t_c)\n", 321 | " \n", 322 | " optimizer.zero_grad()\n", 323 | " train_loss.backward() # <2>\n", 324 | " optimizer.step()\n", 325 | " \n", 326 | " if epoch <= 3 or epoch % 500 == 0:\n", 327 | " print(\"Epoch {}, Training loss {}, Validation loss\".format(\n", 328 | " epoch, float(train_loss), float(val_loss)))\n", 329 | " \n", 330 | " return params" 331 | ] 332 | }, 333 | { 334 | "cell_type": "code", 335 | "execution_count": 15, 336 | "metadata": {}, 337 | "outputs": [ 338 | { 339 | "name": "stdout", 340 | "output_type": "stream", 341 | "text": [ 342 | "Epoch 1, Training loss 97.13150787353516, Validation loss\n", 343 | "Epoch 2, Training loss 39.57359313964844, Validation loss\n", 344 | "Epoch 3, Training loss 32.7770881652832, Validation loss\n", 345 | "Epoch 500, Training loss 8.47635269165039, Validation loss\n", 346 | "Epoch 1000, Training loss 3.964862823486328, Validation loss\n", 347 | "Epoch 1500, Training loss 3.106287956237793, Validation loss\n", 348 | "Epoch 2000, Training loss 2.942892074584961, Validation loss\n", 349 | "Epoch 2500, Training loss 2.9117960929870605, Validation loss\n", 350 | "Epoch 3000, Training loss 2.905879497528076, Validation loss\n" 351 | ] 352 | }, 353 | { 354 | "data": { 355 | "text/plain": [ 356 | "tensor([ 5.4996, -18.1371], requires_grad=True)" 357 | ] 358 | }, 359 | "execution_count": 15, 360 | "metadata": {}, 361 | "output_type": "execute_result" 362 | } 363 | ], 364 | "source": [ 365 | "params = torch.tensor([1.0, 0.0], requires_grad=True)\n", 366 | "learning_rate = 1e-2\n", 367 | "optimizer = optim.SGD([params], lr=learning_rate)\n", 368 | "\n", 369 | "training_loop(\n", 370 | " n_epochs=3000,\n", 371 | " optimizer=optimizer,\n", 372 | " params=params,\n", 373 | " train_t_u=train_t_un, # <1>\n", 374 | " val_t_u=val_t_un, # <1>\n", 375 | " train_t_c=train_t_c, \n", 376 | " val_t_c=val_t_c)" 377 | ] 378 | }, 379 | { 380 | "cell_type": "code", 381 | "execution_count": 16, 382 | "metadata": {}, 383 | "outputs": [], 384 | "source": [ 385 | "def training_loop(n_epochs, optimizer, params, train_t_u, val_t_u, train_t_c, val_t_c):\n", 386 | " for epoch in range(1, n_epochs + 1):\n", 387 | " train_t_p = model(train_t_u, *params)\n", 388 | " train_loss = loss_fn(train_t_p, train_t_c)\n", 389 | " \n", 390 | " with torch.no_grad(): # <1>\n", 391 | " val_t_p = model(val_t_u, *params)\n", 392 | " val_loss = loss_fn(val_t_p, val_t_c)\n", 393 | " assert val_loss.requires_grad == False # <2>\n", 394 | " \n", 395 | " optimizer.zero_grad()\n", 396 | " train_loss.backward() # <2>\n", 397 | " optimizer.step()" 398 | ] 399 | }, 400 | { 401 | "cell_type": "code", 402 | "execution_count": 17, 403 | "metadata": {}, 404 | "outputs": [], 405 | "source": [ 406 | "def calc_forward(t_u, t_c, is_train):\n", 407 | " with torch.set_grad_enabled(is_train):\n", 408 | " t_p = model(t_u, *params)\n", 409 | " loss = loss_fn(t_p, t_c)\n", 410 | " return loss" 411 | ] 412 | }, 413 | { 414 | "cell_type": "code", 415 | "execution_count": null, 416 | "metadata": {}, 417 | "outputs": [], 418 | "source": [] 419 | } 420 | ], 421 | "metadata": { 422 | "kernelspec": { 423 | "display_name": "Python 3", 424 | "language": "python", 425 | "name": "python3" 426 | }, 427 | "language_info": { 428 | "codemirror_mode": { 429 | "name": "ipython", 430 | "version": 3 431 | }, 432 | "file_extension": ".py", 433 | "mimetype": "text/x-python", 434 | "name": "python", 435 | "nbconvert_exporter": "python", 436 | "pygments_lexer": "ipython3", 437 | "version": "3.6.9" 438 | } 439 | }, 440 | "nbformat": 4, 441 | "nbformat_minor": 2 442 | } 443 | -------------------------------------------------------------------------------- /p2ch13/diagnose.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import os 4 | import sys 5 | 6 | import numpy as np 7 | import scipy.ndimage.measurements as measure 8 | import scipy.ndimage.morphology as morph 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.optim 13 | 14 | from torch.utils.data import DataLoader 15 | 16 | from util.util import enumerateWithEstimate 17 | from util.logconf import logging 18 | from util.util import irc2xyz 19 | 20 | from .dsets import LunaDataset 21 | from .dsets import Luna2dSegmentationDataset 22 | from .dsets import getCt 23 | from .dsets import getNoduleInfoList 24 | from .dsets import NoduleInfoTuple 25 | from .model_cls import LunaModel 26 | from .model_seg import UnetWrapper 27 | 28 | log = logging.getLogger(__name__) 29 | # log.setLevel(logging.WARN) 30 | # log.setLevel(logging.INFO) 31 | log.setLevel(logging.DEBUG) 32 | 33 | 34 | class LunaDiagnoseApp: 35 | def __init__(self, sys_argv=None): 36 | if sys_argv is None: 37 | log.debug(sys.argv) 38 | sys_argv = sys.argv[1:] 39 | 40 | parser = argparse.ArgumentParser() 41 | parser.add_argument( 42 | '--batch-size', 43 | help="batch size to use for training", 44 | default=4, 45 | type=int 46 | ) 47 | parser.add_argument( 48 | '--num-workers', 49 | help="number of worker processes for background data loading", 50 | default=8, 51 | type=int 52 | ) 53 | parser.add_argument( 54 | '--series-uid', 55 | help="limit inference to this Series UID only", 56 | default=None, 57 | type=str 58 | ) 59 | parser.add_argument( 60 | '--include-train', 61 | help=("include data that was in the training set (default: " 62 | "validation only)"), 63 | action='store_true', 64 | default=False, 65 | ) 66 | parser.add_argument( 67 | '--segmentation-path', 68 | help="path to the saved segmentation model", 69 | nargs='?', 70 | default=None, 71 | ) 72 | parser.add_argument( 73 | '--classification-path', 74 | help="path to the saved classification model", 75 | nargs='?', 76 | default=None, 77 | ) 78 | parser.add_argument( 79 | '--tb-prefix', 80 | help=("data prefix to use for TensorBoard run. Defaults to" 81 | "chapter"), 82 | default='p2ch13', 83 | ) 84 | 85 | self.cli_args = parser.parse_args(sys_argv) 86 | 87 | self.use_cuda = torch.cuda.is_available() 88 | self.device = torch.device('cuda' if self.use_cuda else 'cpu') 89 | 90 | if not self.cli_args.segmentation_path: 91 | self.cli_args.segmentation_path = self.initModelPath('seg') 92 | 93 | if not self.cli_args.classification_path: 94 | self.cli_args.classification_path = self.initModelPath('cls') 95 | 96 | self.seg_model, self.cls_model = self.initModels() 97 | 98 | def initModelPath(self, type_str): 99 | local_path = os.path.join( 100 | 'data-unversioned', 101 | 'part2', 102 | 'models', 103 | self.cli_args.tb_prefix, 104 | type_str + '_{}_{}.{}.state'.format('*', '*', 'best') 105 | ) 106 | 107 | file_list = glob.glob(local_path) 108 | if not file_list: 109 | pretrained_path = os.path.join( 110 | 'data', 111 | 'part2', 112 | 'models', 113 | type_str + '_{}_{}.{}.state'.format('*', '*', '*') 114 | ) 115 | file_list = glob.glob(pretrained_path) 116 | else: 117 | pretrained_path = None 118 | 119 | file_list.sort() 120 | 121 | try: 122 | 123 | return file_list[-1] 124 | except IndexError: 125 | log.debug([local_path, pretrained_path, file_list]) 126 | raise 127 | 128 | def initModels(self): 129 | log.debug(self.cli_args.segmentation_path) 130 | seg_dict = torch.load(self.cli_args.segmentation_path) 131 | 132 | seg_model = UnetWrapper( 133 | in_channels=8, 134 | n_classes=1, 135 | depth=4, 136 | wf=3, 137 | padding=True, 138 | batch_norm=True, 139 | up_mode='upconv' 140 | ) 141 | seg_model.load_state_dict(seg_dict['model_state']) 142 | seg_model.eval() 143 | 144 | log.debug(self.cli_args.classification_path) 145 | cls_dict = torch.load(self.cli_args.classification_path) 146 | 147 | cls_model = LunaModel() 148 | cls_model.load_state_dict(cls_dict['model_state']) 149 | cls_model.eval() 150 | 151 | if self.use_cuda: 152 | if torch.cuda.device_count() > 1: 153 | seg_model = nn.DataParallel(seg_model) 154 | cls_model = nn.DataParallel(cls_model) 155 | 156 | seg_model = seg_model.to(self.device) 157 | cls_model = cls_model.to(self.device) 158 | 159 | return seg_model, cls_model 160 | 161 | def initSegmentationDl(self, series_uid): 162 | seg_ds = Luna2dSegmentationDataset( 163 | contextSlices_count=3, 164 | series_uid=series_uid, 165 | fullCt_bool=True 166 | ) 167 | seg_dl = DataLoader( 168 | seg_ds, 169 | batch_size=self.cli_args.batch_size * \ 170 | (torch.cuda.device_count() if self.use_cuda else 1), 171 | num_workers=self.cli_args.num_workers, 172 | pin_memory=self.use_cuda 173 | ) 174 | 175 | return seg_ds 176 | 177 | def initClassificationDl(self, noduleInfo_list): 178 | cls_ds = LunaDataset( 179 | series_uid=series_uid, 180 | noduleInfo_list=noduleInfo_list 181 | ) 182 | cls_dl = DataLoader( 183 | cls_ds, 184 | batch_size=self.cli_args.batch_size * \ 185 | (torch.cuda.device_count() if self.use_cuda else 1), 186 | num_workers=self.cli_args.num_workers, 187 | pin_memory=self.use_cuda 188 | ) 189 | 190 | return cls_dl 191 | 192 | def main(self): 193 | log.info("Starting {}, {}".format(type(self).__name__, self.cli_args)) 194 | 195 | val_ds = LunaDataset( 196 | val_stride=10, 197 | isValSet_bool=True 198 | ) 199 | 200 | val_set = set( 201 | noduleInfo_tup.series_uid 202 | for noduleInfo_tup in val_ds.noduleInfo_list 203 | ) 204 | malignant_set = set( 205 | noduleInfo_tup.series_uid 206 | for noduleInfo_tup in getNoduleInfoList() 207 | if noduleInfo_tup.isMalignant_bool 208 | ) 209 | 210 | if self.cli_args.series_uid: 211 | series_set = set(self.cli_args.series_uid.split(',')) 212 | else: 213 | series_set = set( 214 | noduleInfo_tup.series_uid 215 | for noduleInfo_tup in getNoduleInfoList() 216 | ) 217 | 218 | train_list = sorted(series_set - val_set) if \ 219 | self.cli_args.include_train else [] 220 | val_list = sorted(series_set & val_set) 221 | 222 | noduleInfo_list = [] 223 | series_iter = enumerateWithEstimate( 224 | val_list + train_list, 225 | 'Series' 226 | ) 227 | for _, series_uid, in series_iter: 228 | ct, _, _, clean_a = self.segmentCt(series_uid) 229 | 230 | noduleInfo_list += self.clusterSegmentationOutput( 231 | series_uid, 232 | ct, 233 | clean_a 234 | ) 235 | 236 | cls_dl = self.initClassificationDl(noduleInfo_list) 237 | 238 | series2diagnosis_dict = {} 239 | batch_iter = enumerateWithEstimate( 240 | cls_dl, 241 | "Cls all", 242 | start_ndx=cls_dl.num_workers 243 | ) 244 | for batch_ndx, batch_tup in batch_iter: 245 | input_t, _, series_list, center_list = batch_tup 246 | 247 | input_g = input_t.to(self.device) 248 | with torch.no_grad(): 249 | _, probability_g = self.cls_model(input_g) 250 | 251 | classification_list = zip( 252 | series_list, 253 | center_list, 254 | probability_g[:, 1].to('cpu') 255 | ) 256 | for cls_tup in classification_list: 257 | series_uid, center_irc, probability_t = cls_tup 258 | probability_float = probability_t.item() 259 | 260 | this_tup = (probability_float, tuple(center_irc)) 261 | current_tup = series2diagnosis_dict.get(series_uid, 262 | this_tup) 263 | try: 264 | assert np.all(np.isfinite(tuple(center_irc))) 265 | if this_tup > current_tup: 266 | log.debug([series_uid, this_tup]) 267 | # This part is to cover the eventuality that 268 | # the same series_uid is repeted multiple 269 | # times 270 | series2diagnosis_dict[series_uid] = \ 271 | max(this_tup, current_tup) 272 | except: 273 | log.debug([(type(x), x) for x in this_tup] + 274 | [(type(x), x) for x in current_tup]) 275 | raise 276 | 277 | log.info('Training set:') 278 | self.logResults('Training', train_list, series2diagnosis_dict, 279 | malignant_set) 280 | 281 | log.info('Validation set:') 282 | self.logResults('Validation', val_list, series2diagnosis_dict, 283 | malignant_set) 284 | 285 | def segmentCt(self, series_uid): 286 | with torch.no_grad(): 287 | ct = getCt(series_uid) 288 | 289 | output_a = np.zeros_like(ct.hu_a, dtype=np.float32) 290 | 291 | seg_dl = self.initSegmentationDl(series_uid) 292 | for batch_tup in seg_dl: 293 | input_t = batch_tup[0] 294 | ndx_list = batch_tup[6] 295 | 296 | input_g = input_t.to(self.device) 297 | prediction_g = self.seg_model(input_g) 298 | for i, sample_ndx in enumerate(ndx_list): 299 | output_a[sample_ndx] = prediction_g[i].cpu().numpy() 300 | 301 | mask_a = output_a > 0.5 302 | clean_a = morph.binary_erosion(mask_a, iterations=1) 303 | clean_a = morph.binary_dilation(clean_a, iterations=2) 304 | 305 | return ct, output_a, mask_a, clean_a 306 | 307 | def clusterSegmentationOutput(self, series_uid, ct, clean_a): 308 | # Assign a different label to each group sconnected to 309 | # the others 310 | noduleLabel_a, nodule_count = measure.label(clean_a) 311 | centerIrc_list = measure.center_of_mass( 312 | ct.hu_a + 1001, 313 | labels=noduleLabel_a, 314 | # This part is probably redundant 315 | index=list(range(1, nodule_count + 1)) 316 | ) 317 | 318 | noduleInfo_list = [] 319 | for i, center_irc in enumerate(centerIrc_list): 320 | center_xyz = irc2xyz( 321 | center_irc, 322 | ct.origin_xyz, 323 | ct.vxSize_xyz, 324 | ct.direction_tup 325 | ) 326 | assert np.all(np.isfinite(center_irc)), \ 327 | repr(['irc', center_irc, i, nodule_count]) 328 | assert np.all(np.isfinite(center_xyz), \ 329 | repr(['xyz', center_xyz])) 330 | noduleInfo_tup = \ 331 | NoduleInfoTuple(False, 0.0, series_uid, center_xyz) 332 | noduleInfo_list.append(noduleInfo_tup) 333 | 334 | return noduleInfo_list 335 | 336 | def logResults(self, mode_str, filtered_list, series2diagnosis_dict, 337 | malignant_set): 338 | count_dict = {'tp': 0, 'tn': 0, 'fp': 0, 'fn': 0} 339 | for series_uid in filtered_list: 340 | probability_float, center_irc = series2diagnosis_dict(series_uid, 341 | (0.0, None)) 342 | if center_irc is not None: 343 | center_irc = tuple(int(x.item()) for x in center_irc) 344 | 345 | malignant_bool = series_uid in malignant_set 346 | prediction_bool = probability_float > 0.5 347 | correct_bool = malignant_bool == prediction_bool 348 | 349 | if malignant_bool and prediction_bool: 350 | count_dict['tp'] += 1 351 | elif not malignant_bool and not prediction_bool: 352 | count_dict['tn'] += 1 353 | elif not malignant_bool and prediction_bool: 354 | count_dict['fp'] += 1 355 | else: 356 | count_dict['fn'] += 1 357 | 358 | log.info(("{} {} " 359 | "Mal:{!r:5} " 360 | "Pred:{!r:5} " 361 | "Correct?:{!r:5} " 362 | "Value:{!r:5} {}").format(mode_str, 363 | series_uid, 364 | malignant_bool, 365 | prediction_bool, 366 | correct_bool, 367 | probability_float, 368 | center_irc)) 369 | 370 | total_count = sum(count_dict.values()) 371 | percent_dict = {k: v / (total_count or 1) * 100 for k, v in 372 | count_dict.items()} 373 | 374 | precision = percent_dict['p'] = \ 375 | count_dict['tp'] / ((count_dict['tp'] + count_dict['fp']) or 1) 376 | recall = percent_dict['r'] = \ 377 | count_dict['tp'] / ((count_dict['tp'] + count_dict['fn']) or 1) 378 | percent_dict['f1'] = \ 379 | 2 * (precision + recall) / ((precision + recall) or 1) 380 | 381 | log.info(mode_str + ("tp:{tp:.1f}%, tn:{tn:.1f}%, fp:{fp:.1f}%, " 382 | "fn:{fn:.1f}%").format(**percent_dict)) 383 | log.info(mode_str + ("precision:{p:.3f}, recall:{r.3f}, " 384 | "F1:{f1.3f}").format(**percent_dict)) 385 | 386 | 387 | if __name__ == '__main__': 388 | sys.exit(LunaDiagnoseApp().main() or 0) 389 | 390 | 391 | 392 | 393 | 394 | 395 | 396 | -------------------------------------------------------------------------------- /p2ch12/training.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import os 4 | import sys 5 | 6 | import numpy as np 7 | 8 | from tensorboardX import SummaryWriter 9 | 10 | import torch 11 | from torch import nn 12 | from torch.optim import SGD 13 | from torch.utils.data import DataLoader 14 | 15 | from util.util import enumerateWithEstimate 16 | from util.logconf import logging 17 | from .dsets import LunaDataset 18 | from .model import LunaModel 19 | 20 | log = logging.getLogger(__name__) 21 | # log.setLevel(logging.WARN) 22 | log.setLevel(logging.INFO) 23 | # log.setLevel(logging.DEBUG) 24 | 25 | # Used for computeBatchLoss and logMetrics to index into 26 | # metrics_g/metrics_ary 27 | METRICS_LABEL_NDX = 0 28 | METRICS_PRED_NDX = 1 29 | METRICS_LOSS_NDX = 2 30 | METRICS_SIZE = 3 31 | 32 | class LunaTrainingApp(): 33 | def __init__(self, sys_argv=None): 34 | if sys_argv is None: 35 | sys_argv = sys.argv[1:] 36 | 37 | parser = argparse.ArgumentParser() 38 | parser.add_argument( 39 | '--batch-size', 40 | help="Batch size to use for training", 41 | default=32, 42 | type=int 43 | ) 44 | parser.add_argument( 45 | '--num-workers', 46 | help="Number of worker processes for background data loading", 47 | default=8, 48 | type=int 49 | ) 50 | parser.add_argument( 51 | '--epochs', 52 | help="Number of epochs to train for", 53 | default=1, 54 | type=int 55 | ) 56 | parser.add_argument( 57 | '--balanced', 58 | help="Balance the training data to half benign, half malignant.", 59 | action='store_true', 60 | default=False 61 | ) 62 | parser.add_argument( 63 | '--augmented', 64 | help="Augment the training data.", 65 | action='store_true', 66 | default=False 67 | ) 68 | parser.add_argument( 69 | '--augment-flip', 70 | help=("Augment the training data by randomly flipping the data " 71 | "left-right, up-down and front-back."), 72 | action='store_true', 73 | default=False 74 | ) 75 | parser.add_argument( 76 | '--augment-offset', 77 | help=("Augment the training data by randomly offsetting the data " 78 | "slightly along the X and Y axes."), 79 | action='store_true', 80 | default=False 81 | ) 82 | parser.add_argument( 83 | '--augment-scale', 84 | help=("Augment the training data by randomly increasing or " 85 | "decreasing the size of the nodule."), 86 | action='store_true', 87 | default=False 88 | ) 89 | parser.add_argument( 90 | '--augment-rotate', 91 | help=("Augment the training data by randomly rotating the data " 92 | "around the head-foot axis."), 93 | action='store_true', 94 | default=False 95 | ) 96 | parser.add_argument( 97 | '--augment-noise', 98 | help=("Augment the training data by randomly adding noise to the " 99 | "data."), 100 | action='store_true', 101 | default=False 102 | ) 103 | parser.add_argument( 104 | '--tb-prefix', 105 | help=("Data prefix to use for Tensorboard run. Defaults to " 106 | "chapter."), 107 | default='p2ch12' 108 | ) 109 | parser.add_argument( 110 | 'comment', 111 | help="Comment suffix for Tensorboard run.", 112 | nargs='?', 113 | default='none' 114 | ) 115 | 116 | self.cli_args = parser.parse_args(sys_argv) 117 | self.time_str = datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S") 118 | 119 | self.trn_writer = None 120 | self.val_writer = None 121 | self.totalTrainingSamples_count = 0 122 | 123 | self.augmentation_dict = {} 124 | if self.cli_args.augmented or self.cli_args.augment_flip: 125 | self.augmentation_dict['flip'] = True 126 | if self.cli_args.augmented or self.cli_args.augment_offset: 127 | self.augmentation_dict['offset'] = 0.1 128 | if self.cli_args.augmented or self.cli_args.augment_scale: 129 | self.augmentation_dict['scale'] = 0.2 130 | if self.cli_args.augmented or self.cli_args.augment_rotate: 131 | self.augmentation_dict['rotate'] = True 132 | if self.cli_args.augmented or self.cli_args.augment_noise: 133 | self.augmentation_dict['noise'] = 25.0 134 | 135 | self.use_cuda = torch.cuda.is_available() 136 | self.device = torch.device("cuda" if self.use_cuda else "cpu") 137 | 138 | self.model = self.initModel() 139 | self.optimizer self.initOptimizer() 140 | 141 | def initModel(self): 142 | model = LunaModel() 143 | 144 | if self.use_cuda: 145 | log.info("Using CUDA with {} devices."\ 146 | .format(torch.cuda.device_count())) 147 | if torch.cuda.device_count() > 1: 148 | model = nn.DataParallel(model) 149 | model = model.to(self.device) 150 | 151 | return model 152 | 153 | def initOptimizer(self): 154 | return SGD(self.model.parameters(), lr=0.001, momentum=0.99) 155 | 156 | def initTrainDl(self): 157 | train_ds = LunaDataset( 158 | val_stride=10, 159 | isValSet_bool=False, 160 | ratio_int=int(self.cli_args.balanced), 161 | augmentation_dict=self.augmentation_dict 162 | ) 163 | 164 | train_dl = DataLoader( 165 | train_ds, 166 | batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if 167 | self.use_cuda else 1), 168 | num_workers=self.cli_args.num_workers, 169 | pin_memory=self.use_cuda 170 | ) 171 | 172 | return train_dl 173 | 174 | def initValDl(self): 175 | val_ds = LunaDataset(val_stride=10, isValSet_bool=True) 176 | 177 | val_dl = DataLoader( 178 | val_ds, 179 | batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if 180 | self.use_cuda else 1), 181 | num_workers=self.cli_args.num_workers, 182 | pin_memory=self.use_cuda 183 | ) 184 | 185 | return val_dl 186 | 187 | def initTensorboardWriters(self): 188 | if self.trn_writer is None: 189 | log_dir = os.path.join('runs', self.cli_args.tb_prefix, 190 | self.time_str) 191 | 192 | self.trn_writer = SummaryWriter(log_dir=log_dir + '-trn_cls-' + 193 | self.cli_args.comment) 194 | self.val_writer = SummaryWriter(log_dir=log_dir + '-val_cls-' + 195 | self.cli_args.comment) 196 | 197 | 198 | def main(self): 199 | log.info("Starting {}, {}".format(type(self).__name__, self.cli_args)) 200 | 201 | train_dl = self.initTrainDl() 202 | val_dl = self.initValDl() 203 | 204 | for epoch_ndx in range(1, self.cli_args.epochs + 1): 205 | log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format( 206 | epoch_ndx, 207 | self.cli_args.epochs, 208 | len(train_dl), 209 | len(val_dl), 210 | self.cli_args.batch_size, 211 | (torch.cuda.device_count() if self.use_cuda else 1) 212 | )) 213 | 214 | 215 | trnMetrics_t = self.doTraining(epoch_ndx, train_dl) 216 | self.logMetrics(epoch_ndx, 'trn', trnMetrics_t) 217 | 218 | valMetrics_t = self.doValidation(epoch_ndx, val_dl) 219 | self.logMetrics(epoch_ndx, 'val', valMetrics_t) 220 | 221 | if hasattr(self, 'trn_writer'): 222 | self.trn_writer.close() 223 | self.val_writer.close() 224 | 225 | def doTraining(self, epoch_ndx, train_dl): 226 | self.model.train() 227 | train_dl.dataset.shuffleSamples() 228 | trnMetrics_g = torch.zeros(METRICS_SIZE, 229 | len(train_dl.dataset)).to(self.device) 230 | 231 | batch_iter = enumerateWithEstimate( 232 | train_dl, 233 | "E{} Traning".format(epoch_ndx), 234 | start_ndx=train_dl.num_workers 235 | ) 236 | 237 | for batch_ndx, batch_tup in batch_iter: 238 | self.optimizer.zero_grad() 239 | loss_var = self.computeBatchLoss(batch_ndx, batch_tup, 240 | train_dl.batch_size, 241 | trnMetrics_g) 242 | loss_var.backward() 243 | self.optimizer.step() 244 | del loss_var 245 | 246 | self.totalTrainingSamples_count += len(train_dl.dataset) 247 | 248 | return trnMetrics_g.to('cpu') 249 | 250 | def doValidation(self, epoch_ndx, val_dl): 251 | with torch.no_grad(): 252 | self.model.eval() 253 | valMetrics_g = torch.zeros(METRICS_SIZE, 254 | len(val_dl.dataset)).to(self.device) 255 | batch_iter = enumerateWithEstimate( 256 | val_dl, 257 | "E{} Validation".format(epoch_ndx), 258 | start_ndx=val_dl.num_workers 259 | ) 260 | for batch_ndx, batch_tup in batch_iter: 261 | self.computeBatchLoss(batch_ndx, batch_tup, val_dl.batch_size, 262 | valMetrics_g) 263 | 264 | return valMetrics_g.to('cpu') 265 | 266 | 267 | 268 | def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_g): 269 | input_t, label_t, _series_list, _center_list = batch_tup 270 | 271 | input_g = input_t.to(self.device, non_blocking=True) 272 | label_g = label_t.to(self.device, non_blocking=True) 273 | 274 | logits_g, probability_g = self.model(input_g) 275 | 276 | loss_func = nn.CrossEntropyLoss(reduction='none') 277 | loss_g = loss_func(logits_g, label_g[:, 1]) 278 | 279 | start_ndx = batch_ndx * batch_size 280 | end_ndx = start_ndx + label_t.size(0) 281 | 282 | metrics_g[METRICS_LABEL_NDX, start_ndx:end_ndx] = label_g[:, 1] 283 | metrics_g[METRICS_PRED_NDX, start_ndx:end_ndx] = probability_g[:, 1] 284 | metrics_g[METRICS_LOSS_NDX, start_ndx:end_ndx] = loss_g 285 | 286 | return loss_g.mean() 287 | 288 | def logMetrics(self, epoch_ndx, mode_str, metrics_t): 289 | self.initTensorboardWriters() 290 | log.info("E{} {}".format(epoch_ndx, type(self).__name__)) 291 | 292 | benLabel_mask = metrics_ary[METRICS_LABEL_NDX] <= 0.5 293 | benPred_mask = metrics_ary[METRICS_PRED_NDX] <= 0.5 294 | 295 | malLabel_mask = ~benLabel_mask 296 | malPred_mask = ~benPred_mask 297 | 298 | ben_count = int(benLabel_mask.sum()) 299 | mal_count = int(malLabel_mask.sum()) 300 | 301 | trueNeg_count = ben_correct = (benLabel_mask & benPred_mask).sum() 302 | truePos_count = mal_correct = (malLabel_mask & malPred_mask).sum() 303 | 304 | falsePos_count = ben_count - ben_correct 305 | falseNeg_count = mal_count - mal_correct 306 | 307 | metrics_dict = {} 308 | 309 | metrics_dict['loss/all'] = metrics_ary[METRICS_LOSS_NDX].mean() 310 | metrics_dict['loss/ben'] = metrics_ary[METRICS_LOSS_NDX, 311 | benLabel_mask].mean() 312 | metrics_dict['loss/mal'] = metrics_ary[METRICS_LOSS_NDX, 313 | malLabel_mask].mean() 314 | 315 | metrics_dict['correct/all'] = (mal_correct + ben_correct) \ 316 | / metrics_ary.shape[1] * 100 317 | metrics_dict['correct/ben'] = ben_correct / ben_count * 100 318 | metrics_dict['correct/mal'] = mal_correct / mal_count * 100 319 | 320 | precision = metrics_dict['pr/precision'] = truePos_count / \ 321 | np.float64(truePos_count + falsePos_count) 322 | recall = metrics_dict['pr/recall'] = truePos_count / \ 323 | np.float64(truePos_count + falseNeg_count) 324 | 325 | metrics_dict['pr/f1_score'] = 2 * (precision * recall) / \ 326 | (precision + recall) 327 | 328 | log.info(("E{} {:8} {loss/all:.4f} loss, " 329 | "{correct/all:-5.1f}% correct, " 330 | "{pr/precision:.4f} recall, " 331 | "{pr/f1_score:.4f} f1 score").format(epoch_ndx, 332 | mode_str, 333 | **metrics_dict)) 334 | log.info(("E{} {:8} {loss/ben:.4f} loss, {correct/ben:-5.1f}% " 335 | "correct ({ben_correct:} of {ben_count:})")\ 336 | .format(epoch_ndx, 337 | mode_str + '_ben', 338 | ben_correct=ben_correct, 339 | ben_count=ben_count, 340 | **metrics_dict)) 341 | log.info(("E{} {:8} {loss/mal:.4f} loss, {correct/mal:-5.1f}% " 342 | "correct ({mal_correct:} of {mal_count:})")\ 343 | .format(epoch_ndx, 344 | mode_str + '_mal', 345 | mal_correct=mal_correct, 346 | mal_count=mal_count, 347 | **metrics_dict)) 348 | 349 | writer = getattr(self, mode_str + '_writer') 350 | 351 | for key, value in metrics_dict.items(): 352 | writer.add_scalar(key, value, self.totalTrainingSamples_count) 353 | 354 | writer.add_pr_curve( 355 | 'pr', 356 | metrics_ary[METRICS_LABEL_NDX], 357 | metrics_ary[METRICS_PRED_NDX], 358 | self.totalTrainingSamples_count 359 | ) 360 | 361 | bins = [x / 50.0 for x in range(51)] 362 | benHist_mask = benLabel_mask & \ 363 | (metrics_ary[METRICS_PRED_NDX] > 0.01) 364 | malHist_mask = malLabel_mask & \ 365 | (metrics_ary[METRICS_PRED_NDX] < 0.99) 366 | 367 | if benHist_mask.any(): 368 | writer.add_histogram( 369 | 'is_ben', 370 | metrics_ary[METRICS_PRED_NDX, benHist_mask], 371 | self.totalTrainingSamples_count, 372 | bins=bins 373 | ) 374 | if malHist_mask.any(): 375 | writer.add_histogram( 376 | 'is_mal', 377 | metrics_ary[METRICS_PRED_NDX, malHist_mask], 378 | self.totalTrainingSamples_count, 379 | bins=bins 380 | ) 381 | 382 | 383 | if __name__ == '__main__': 384 | LunaTrainingApp().main() 385 | --------------------------------------------------------------------------------