├── models
├── __init__.py
├── layers.py
├── cnn_res.py
└── net_sphere.py
├── searchspace
├── __init__.py
├── res_search_space.py
└── search_space_utils.py
├── configuration
├── __init__.py
└── config.py
├── generalnetwork
├── __init__.py
└── dpn3d.py
├── _config.yml
├── data
├── model.npy
├── list3.2.csv
├── dataanalysis.py
├── kappatest.py
├── dimcls.py
├── data_enhancement.py
├── dataconverter.py
├── nodclsgbt.py
├── humanperformance.py
├── pthumanperformance.py
└── extclsshpinfo.py
├── imgs
└── architecture.png
├── requirements.txt
├── .idea
├── dictionaries
│ └── 10485.xml
├── .gitignore
├── inspectionProfiles
│ └── profiles_settings.xml
├── modules.xml
├── misc.xml
├── NAS-Lung.iml
├── remote-mappings.xml
└── deployment.xml
├── run_training.sh
├── LICENSE
├── search_main.py
├── .gitignore
├── dataloader.py
├── README.md
├── test.py
├── random_forest.py
├── data_utils.py
├── main.py
├── mainsp.py
└── transforms.py
/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/searchspace/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/configuration/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/generalnetwork/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/_config.yml:
--------------------------------------------------------------------------------
1 | theme: jekyll-theme-slate
--------------------------------------------------------------------------------
/data/model.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fei-aiart/NAS-Lung/HEAD/data/model.npy
--------------------------------------------------------------------------------
/data/list3.2.csv:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fei-aiart/NAS-Lung/HEAD/data/list3.2.csv
--------------------------------------------------------------------------------
/imgs/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fei-aiart/NAS-Lung/HEAD/imgs/architecture.png
--------------------------------------------------------------------------------
/configuration/config.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | class config:
4 | def __init__(self, channel_range):
--------------------------------------------------------------------------------
/data/dataanalysis.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 |
3 | a = pd.read_csv('../data/annotationdetclsconvfnl_v3.csv')
4 | path = 'F:\\医学数据集\\LUNA\\rowfile\\subset5'
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy>=1.18.1
2 | pandas>=0.25.0
3 | matplotlib>=3.1.0
4 | scikit-learn>=0.22.1
5 | torch>=0.4.1.post2
6 | dill>=0.3.0
7 | scipy>=1.4.1
8 | seaborn>=0.10.0
9 | Pillow>=7.0.0
--------------------------------------------------------------------------------
/.idea/dictionaries/10485.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | cbam
5 | learnable
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Datasource local storage ignored files
5 | /dataSources/
6 | /dataSources.local.xml
7 | # Editor-based HTTP Client requests
8 | /httpRequests/
9 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/run_training.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Run main.py once per fold index, one after another, on GPU 0.
# Stop at the first failing run instead of continuing with later folds.
set -e

# Highest fold index (inclusive): fold indices 0..maxeps are processed.
maxeps=9

for (( i = 0; i <= maxeps; i += 1 ))
do
    # The loop variable is passed to --fold, so report it as a fold
    # (the number of epochs is fixed by --num_epochs below).
    echo "process fold $i"
    CUDA_VISIBLE_DEVICES=0 python main.py --batch_size 8 --num_epochs 400 --fold "$i"
done
13 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/.idea/NAS-Lung.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/remote-mappings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
--------------------------------------------------------------------------------
/.idea/deployment.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
--------------------------------------------------------------------------------
/data/kappatest.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
# Count agreement between the labels stored in the dctptlabel files and the
# predictions stored in the modprd files, then print a kappa statistic
# (p0 - pe) / (1 - pe) for each of the four label columns.
#   pp : label 1, prediction 1      nn : label 0, prediction 0
#   pn : label 1, prediction 0      npp: label 0, prediction 1
pp = [0] * 4
nn = [0] * 4
pn = [0] * 4
npp = [0] * 4

for fd in [1, 2, 3, 5]:
    # The .npy files hold pickled dicts; numpy refuses to unpickle them unless
    # allow_pickle=True is passed explicitly.
    dctprd = np.load('/media/data1/wentao/tianchi/luna16/CSVFILES/dctptlabel' + str(fd) + '.npy',
                     allow_pickle=True).item()
    for d in range(4):
        modprd = np.load('modprd' + str(d + 1) + 'fd' + str(fd) + '.npy', allow_pickle=True).item()
        # dict.items() replaces the Python 2-only dict.iteritems().
        for k, v in dctprd.items():
            # Entries equal to -1 are skipped (presumably "no label" -- TODO confirm).
            if v[d] == -1:
                continue
            assert v[d] in [0, 1]
            if v[d] == 1 and modprd[k] == 1: pp[d] += 1
            if v[d] == 0 and modprd[k] == 0: nn[d] += 1
            if v[d] == 1 and modprd[k] == 0: pn[d] += 1
            if v[d] == 0 and modprd[k] == 1: npp[d] += 1

print(pp, nn, pn, npp)
for d in range(4):
    n = pp[d] + nn[d] + pn[d] + npp[d]
    # Observed agreement.
    p0 = (pp[d] + nn[d]) / float(n)
    # Chance agreement from the marginal totals of both raters.
    pe = (pp[d] + pn[d]) * (pp[d] + npp[d])
    pe += (nn[d] + pn[d]) * (nn[d] + npp[d])
    pe /= float(n * n)
    print((p0 - pe) / (1 - pe))
27 |
28 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Fei
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/generalnetwork/dpn3d.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from torch.autograd import Variable
6 |
class Bottleneck(nn.Module):
    """3D bottleneck block that combines a residual sum with channel concatenation.

    The main path is 1x1x1 conv -> 3x3x3 grouped conv (32 groups) -> 1x1x1 conv,
    each followed by batch normalisation. The first ``out_planes`` channels of
    the main path are added to the first ``out_planes`` channels of the shortcut;
    the remaining channels of both are concatenated after that sum.

    :param last_planes: number of channels of the incoming tensor
    :param in_planes: width of the two inner convolutions
    :param out_planes: number of channels that are summed residually
    :param dense_depth: extra channels produced by the main path for concatenation
    :param stride: stride of the 3x3x3 convolution (and of the projection shortcut)
    :param first_layer: when true the shortcut is a 1x1x1 conv + batch norm
        projection, otherwise the input is passed through unchanged
    """

    def __init__(self, last_planes, in_planes, out_planes, dense_depth, stride, first_layer):
        super(Bottleneck, self).__init__()
        self.last_planes = last_planes
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.dense_depth = dense_depth

        widened = out_planes + dense_depth

        self.conv1 = nn.Conv3d(last_planes, in_planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm3d(in_planes)
        self.conv2 = nn.Conv3d(in_planes, in_planes, kernel_size=3, stride=stride, padding=1, groups=32,
                               bias=False)
        self.bn2 = nn.BatchNorm3d(in_planes)
        self.conv3 = nn.Conv3d(in_planes, widened, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm3d(widened)

        if first_layer:
            # Project the input so its channel count matches the main path.
            self.shortcut = nn.Sequential(
                nn.Conv3d(last_planes, widened, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm3d(widened),
            )
        else:
            # An empty Sequential returns its input unchanged.
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return relu(cat(shortcut[:d] + main[:d], shortcut[d:], main[d:])) with d = out_planes."""
        main = F.relu(self.bn1(self.conv1(x)))
        main = F.relu(self.bn2(self.conv2(main)))
        main = self.bn3(self.conv3(main))
        skip = self.shortcut(x)
        split = self.out_planes
        summed = skip[:, :split] + main[:, :split]
        merged = torch.cat([summed, skip[:, split:], main[:, split:]], 1)
        return F.relu(merged)
44 |
45 |
--------------------------------------------------------------------------------
/search_main.py:
--------------------------------------------------------------------------------
1 | import searchspace.res_search_space as res_search_space
2 | import torch.nn as nn
3 | import logging
4 | import torch
5 | import argparse
6 |
# Command-line options for the architecture search, declared as
# (flag, type, default, help) rows and registered in this order.
parser = argparse.ArgumentParser(description='searching')
_SEARCH_OPTIONS = (
    ('--fold', int, 5, "fold"),
    ('--gpu_id', str, '0', "gpu_id"),
    ('--lr', float, 0.0002, "lr"),
    ('--epoch', int, 20, "epoch"),
    ('--num_workers', int, 20, "num_workers"),
    ('--train_data_path', str, '/data/xxx/LUNA/cls/crop_v3', "train_data_path"),
    ('--test_data_path', str, '/data/xxx/LUNA/rowfile/subset', "test_data_path"),
    ('--batch_size', int, 8, "batch_size"),
    ('--max_depth', int, 9, "max_depth"),
    ('--min_depth', int, 3, "min_depth"),
    ('--save_module_path', str, 'Module', None),
    ('--log_file', str, 'log_search', None),
)
for _flag, _kind, _default, _help in _SEARCH_OPTIONS:
    parser.add_argument(_flag, type=_kind, default=_default, help=_help)

if __name__ == '__main__':
    args = parser.parse_args()

    # Fixed search settings that are not exposed on the command line.
    channel_range = [4, 8, 16, 32, 64, 128]
    input_shape = [1, 1, 32, 32, 32]
    criterion = nn.CrossEntropyLoss()

    logging.basicConfig(filename=args.log_file, level=logging.INFO)
    use_gpu = torch.cuda.is_available()

    # The logging module itself is handed over as the logger object.
    res_search = res_search_space.ResSearchSpace(
        channel_range, args.max_depth, args.min_depth,
        args.train_data_path, args.test_data_path, args.fold,
        args.batch_size, logging, input_shape, use_gpu, args.gpu_id, criterion, args.lr,
        args.save_module_path, num_works=args.num_workers, epoch=args.epoch)
    res_search.main_method()
48 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/data/dimcls.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import os
import csv

CROPSIZE = 36
# Because names= is passed, pandas treats the file's own header line as the
# first data row; the [1:] slices below drop that row again.
pdframe = pd.read_csv('annotationdetclsconv_v3.csv', names=['seriesuid', 'coordX', 'coordY', 'coordZ', 'diameter_mm', 'malignant'])
srslst = pdframe['seriesuid'].tolist()[1:]
crdxlst = pdframe['coordX'].tolist()[1:]
crdylst = pdframe['coordY'].tolist()[1:]
crdzlst = pdframe['coordZ'].tolist()[1:]
dimlst = pdframe['diameter_mm'].tolist()[1:]
mlglst = pdframe['malignant'].tolist()[1:]

# Give every annotation a unique id "<seriesuid>-<row index>" and write the
# result to a new CSV file.
newlst = [
    [srslst[i] + '-' + str(i), crdxlst[i], crdylst[i], crdzlst[i], dimlst[i], mlglst[i]]
    for i in range(len(srslst))
]
# newline='' is what the csv module expects for files it writes; without it
# extra blank lines appear on platforms that translate '\n'. The with-block
# closes the file even if writing fails.
with open('annotationdetclsconvfnl_v3.csv', 'w', newline='') as fid:
    writer = csv.writer(fid)
    writer.writerow(['seriesuid', 'coordX', 'coordY', 'coordZ', 'diameter_mm', 'malignant'])
    writer.writerows(newlst)

# train use gbt
# Every series that has an .mhd file in subset1 goes to the test split.
subset1path = '/media/data1/wentao/tianchi/luna16/subset1/'
testfnamelst = [fname[:-4] for fname in os.listdir(subset1path) if fname.endswith('.mhd')]
testfnameset = set(testfnamelst)  # constant-time membership tests below

# Decide the split once per row instead of repeating the lookup in two loops.
istest = [row[0].split('-')[0] in testfnameset for row in newlst]
ntest = sum(istest)
print('ntest', ntest, 'ntrain', len(newlst)-ntest)

# Feature = diameter (second to last column), label = malignancy (last column).
# The CSV values are converted to float here.
traindata = np.array([row[-2] for row, held in zip(newlst, istest) if not held], dtype=float)
trainlabel = np.array([row[-1] for row, held in zip(newlst, istest) if not held], dtype=float)
testdata = np.array([row[-2] for row, held in zip(newlst, istest) if held], dtype=float)
testlabel = np.array([row[-1] for row, held in zip(newlst, istest) if held], dtype=float)

# Sweep every training diameter as a threshold ("malignant if diameter > th")
# and record the accuracy on both splits.
tracclst = []
teacclst = []
thlst = np.sort(traindata).tolist()
besttr = bestte = 0
for th in thlst:
    tracc = np.mean(trainlabel == (traindata > th))
    teacc = np.mean(testlabel == (testdata > th))
    # Remember the test accuracy at the threshold that is best on the training split.
    if tracc > besttr:
        besttr = tracc
        bestte = teacc
    tracclst.append(tracc)
    teacclst.append(teacc)

import matplotlib
matplotlib.use('Agg')  # select the Agg backend before pyplot is imported
import matplotlib.pyplot as plt
plt.plot(thlst, tracclst, label='train acc')
plt.plot(thlst, teacclst, label='test acc')
plt.xlabel('Threshold for diameter (mm)')
# NOTE(review): the plotted values are fractions in [0, 1] although the label says "%".
plt.ylabel('Diagnosis (malignant vs. benign) accuracy (%)')
plt.title('Diagnosis accuracy using diameter feature on fold 1')
plt.legend()
plt.savefig('accwrtdim.png')
print(max(teacclst))

print(besttr, bestte)
--------------------------------------------------------------------------------
/data/data_enhancement.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import os
4 |
5 |
def horizontal_flip(img):
    """Reverse a 3-D array along its second axis, in place, and return it.

    The previous implementation saved ``img[:, i, :]`` as a temporary, but a
    slice of a numpy array is a view, so the "saved" values were overwritten
    before being written back and the result was a mirrored-symmetric array
    rather than a flip.

    :param img: 3-D numpy array; it is modified in place
    :return: the same array object, flipped along axis 1
    :raises ValueError: if ``img`` is not 3-dimensional
    """
    if len(img.shape) != 3:
        raise ValueError('expected a 3-D array, got shape %r' % (img.shape,))
    # Copy the reversed data first so source and destination never alias.
    img[:] = img[:, ::-1, :].copy()
    return img
13 |
14 |
def vertical_flip(img):
    """Reverse a 3-D array along its first axis, in place, and return it.

    The previous implementation saved ``img[i, :, :]`` as a temporary, but a
    slice of a numpy array is a view, so the "saved" values were overwritten
    before being written back and the result was a mirrored-symmetric array
    rather than a flip.

    :param img: 3-D numpy array; it is modified in place
    :return: the same array object, flipped along axis 0
    :raises ValueError: if ``img`` is not 3-dimensional
    """
    if len(img.shape) != 3:
        raise ValueError('expected a 3-D array, got shape %r' % (img.shape,))
    # Copy the reversed data first so source and destination never alias.
    img[:] = img[::-1, :, :].copy()
    return img
22 |
23 |
def deep_flip(img):
    """Reverse a 3-D array along its third axis, in place, and return it.

    The previous implementation saved ``img[:, :, i]`` as a temporary, but a
    slice of a numpy array is a view, so the "saved" values were overwritten
    before being written back and the result was a mirrored-symmetric array
    rather than a flip.

    :param img: 3-D numpy array; it is modified in place
    :return: the same array object, flipped along axis 2
    :raises ValueError: if ``img`` is not 3-dimensional
    """
    if len(img.shape) != 3:
        raise ValueError('expected a 3-D array, got shape %r' % (img.shape,))
    # Copy the reversed data first so source and destination never alias.
    img[:] = img[:, :, ::-1].copy()
    return img
31 |
32 |
# Build an annotation table that contains, for every sample outside the
# held-out subset, three additional rows named "<id>horizontal",
# "<id>vertical" and "<id>deep", and write the table to a CSV file.
path = "F:\\医学数据集\\LUNA\\cls\\crop_v3\\"
dataframe = pd.read_csv("F:/PycharmProjects/DeepLung/data/annotationdetclsconvfnl_v3.csv", encoding='utf-8')
# NOTE(review): read_csv already consumes the header line, so [1:] leaves the
# first annotation row without flipped variants -- confirm this is intended.
data_list = dataframe.to_numpy()[1:]
enhancement_data = dataframe.to_numpy()
test_data = np.empty(shape=(0, 32, 32, 32))
train_data = np.empty(shape=(0, 32, 32, 32))

# Series ids that have an .mhd file in subset5 are left out of the augmentation.
teidlst = [fname[:-4]
           for fname in os.listdir('F:\\医学数据集\\LUNA\\rowfile\\subset5' + '\\')
           if fname.endswith('.mhd')]
held_out_ids = set(teidlst)  # constant-time membership tests in the loop

flips = (('horizontal', horizontal_flip),
         ('vertical', vertical_flip),
         ('deep', deep_flip))

# Collect the new rows in a list and stack them once at the end; calling
# np.append inside the loop re-copies the whole table for every new row.
extra_rows = []
for data in data_list:
    if data[0].split('-')[0] in held_out_ids:
        continue
    a = np.load(path + data[0] + '.npy')
    for suffix, flip in flips:
        # The flip helpers modify their argument in place, so each one gets
        # its own copy; otherwise the three flips pile up on the same array.
        flipped = flip(a.copy())
        information = np.copy(data)
        information[0] = data[0] + suffix
        extra_rows.append(information)
        # np.save(path + data[0] + suffix, flipped)

enhancement_data = np.vstack([enhancement_data] + extra_rows)

frame = pd.DataFrame(enhancement_data, index=None,
                     columns=['seriesuid', 'coordX', 'coordY', 'coordZ', 'diameter_mm', 'malignant'])
frame.to_csv('H:\\a.csv', index=None, columns=['seriesuid', 'coordX', 'coordY', 'coordZ', 'diameter_mm', 'malignant'])
67 |
--------------------------------------------------------------------------------
/data/dataconverter.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import csv
3 | import pandas as pd
4 | import SimpleITK as sitk
5 | import os
6 | import os.path
def load_itk_image(filename):
    """Read an image with SimpleITK and report whether its header matrix is non-identity.

    The file is first opened as text to find the ``TransformMatrix`` line; its
    values are rounded and compared with the flattened 3x3 identity matrix.
    The image is then read with SimpleITK.

    :param filename: path of the image header file
    :return: tuple ``(numpyImage, numpyOrigin, numpySpacing, isflip)`` where
        origin and spacing are the SimpleITK values in reversed order and
        ``isflip`` is True when the rounded matrix differs from the identity
    """
    with open(filename) as header:
        matrix_lines = [text for text in header.readlines() if text.startswith('TransformMatrix')]
    matrix_values = matrix_lines[0].split(' = ')[1].split(' ')
    transformM = np.round(np.array(matrix_values).astype('float'))
    identity = np.array([1, 0, 0, 0, 1, 0, 0, 0, 1])
    isflip = bool(np.any(transformM != identity))

    itkimage = sitk.ReadImage(filename)
    numpyImage = sitk.GetArrayFromImage(itkimage)
    # Reverse the component order of origin and spacing.
    numpyOrigin = np.array(itkimage.GetOrigin()[::-1])
    numpySpacing = np.array(itkimage.GetSpacing()[::-1])
    return numpyImage, numpyOrigin, numpySpacing, isflip
def worldToVoxelCoord(worldCoord, origin, spacing):
    """Return ``|worldCoord - origin| / spacing``, computed element-wise.

    :param worldCoord: coordinate to convert
    :param origin: origin subtracted from the coordinate
    :param spacing: divisor applied to the absolute offset
    :return: the converted coordinate
    """
    offset = np.absolute(worldCoord - origin)
    return np.divide(offset, spacing)
# Load the ground-truth annotations and group them by series id.
# names= is given explicitly, so the file's header line arrives as the first
# data row and is dropped again with [1:]; the remaining values are strings.
_GT_COLUMNS = ['seriesuid', 'coordX', 'coordY', 'coordZ', 'diameter_mm', 'malignant']
pdframe = pd.read_csv('annotationdetclsgt.csv', names=_GT_COLUMNS)
srslst, crdxlst, crdylst, crdzlst, dimlst, mlglst = [
    pdframe[column].tolist()[1:] for column in _GT_COLUMNS
]

# dct maps seriesuid -> list of [x, y, z, diameter, malignant] entries.
dct = {}
for uid, cx, cy, cz, dim, mlg in zip(srslst, crdxlst, crdylst, crdzlst, dimlst, mlglst):
    assert mlg in ['1', '0']
    vlu = [float(cx), float(cy), float(cz), float(dim), int(mlg)]
    dct.setdefault(uid, []).append(vlu)
# Settings for converting the annotations to the preprocessed space.
rawpath = '/media/data1/wentao/tianchi/luna16/lunaall/'  # raw .mhd files
preprocesspath = '/media/data1/wentao/tianchi/luna16/preprocess/lunaall/'  # per-series .npy files
resolution = np.ones(3, dtype=int)  # divisor used per axis when rescaling coordinates
newlst = []  # filled with the converted annotations further below
def process(pid):
    """Convert all annotations of one series into the preprocessed coordinate frame.

    Only the flip flag is taken from the raw ``.mhd`` file; spacing, origin and
    the extend box come from the ``.npy`` files written next to the
    preprocessed volume (assumed here to have shapes (3,), (3,) and (3, 2) --
    TODO confirm against the preprocessing code).

    :param pid: series id, a key of ``dct``
    :return: list of ``[pid, c0, c1, c2, diameter, malignant]`` rows, one per
        annotation of the series
    """
    def _load(suffix):
        return np.load(os.path.join(preprocesspath, pid + suffix))

    _, _, _, isflip = load_itk_image(os.path.join(rawpath, pid + '.mhd'))
    spacing = _load('_spacing.npy')
    extendbox = _load('_extendbox.npy')
    origin = _load('_origin.npy')
    # The mask is only needed for its shape when coordinates must be mirrored.
    Mask = _load('_mask.npy') if isflip else None

    retlst = []
    for vlu in dct[pid]:
        # Annotation coordinates are reversed before the conversion.
        pos = worldToVoxelCoord(vlu[:3][::-1], origin=origin, spacing=spacing)
        if isflip:
            pos[1:] = Mask.shape[1:3] - pos[1:]
        converted = np.concatenate([pos, [vlu[3] / spacing[1]]])
        # Rescale from the original spacing to the target resolution ...
        converted[:3] = converted[:3] * spacing / resolution
        converted[3] = converted[3] * spacing[1] / resolution[1]
        # ... and shift by the lower corner of the extend box.
        converted[:3] = converted[:3] - extendbox[:, 0]
        retlst.append([pid, converted[0], converted[1], converted[2], converted[3], vlu[-1]])
    return retlst
from multiprocessing import Pool

# Convert every series in parallel; each call to process() returns the list of
# converted rows for one series.
# NOTE(review): this runs at import time without an `if __name__ == '__main__'`
# guard, which the multiprocessing docs require for the spawn start method.
p = Pool(30)
try:
    newlst = p.map(process, list(dct))
finally:
    # Release the worker processes even when a conversion fails.
    p.close()
    p.join()
print(len(dct.keys()), len(newlst))

# (A commented-out sequential Python 2 copy of process() used to live here.)

# save it to the csv
savecsv = 'annotationdetclsconv_v3.csv'
# newline='' is what the csv module expects for files it writes; the
# with-block closes the file even if writing fails.
with open(savecsv, 'w', newline='') as fid:
    writer = csv.writer(fid)
    writer.writerow(['seriesuid', 'coordX', 'coordY', 'coordZ', 'diameter_mm', 'malignant'])
    # newlst holds one list of rows per series. The old loops used xrange,
    # which does not exist in Python 3.
    for series_rows in newlst:
        writer.writerows(series_rows)
--------------------------------------------------------------------------------
/models/layers.py:
--------------------------------------------------------------------------------
1 | import typing
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from .net_sphere import *
7 |
8 |
class ResCBAMLayer(nn.Module):
    """
    Channel and spatial attention with a residual connection (CBAM+Res).

    Channel weights come from a shared two-layer linear bottleneck applied to
    the average-pooled and max-pooled input, summed and passed through a
    softmax over channels. Spatial weights come from a 3x3x3 convolution over
    the channel-wise max and mean of the channel-weighted input, passed
    through a softmax over all spatial positions. The output is
    ``spatial_weights * x + x``.

    :param in_planes: number of input channels
    :param feature_size: kernel size and stride of the pooling layers; the
        pooled tensor is flattened per sample and fed to a linear layer with
        ``in_planes`` inputs
    """

    def __init__(self, in_planes, feature_size):
        super(ResCBAMLayer, self).__init__()
        self.in_planes = in_planes
        self.feature_size = feature_size
        # Channel attention.
        self.ch_AvgPool = nn.AvgPool3d(feature_size, feature_size)
        self.ch_MaxPool = nn.MaxPool3d(feature_size, feature_size)
        self.ch_Linear1 = nn.Linear(in_planes, in_planes // 4, bias=False)
        self.ch_Linear2 = nn.Linear(in_planes // 4, in_planes, bias=False)
        self.ch_Softmax = nn.Softmax(1)
        # Spatial attention.
        self.sp_Conv = nn.Conv3d(2, 1, kernel_size=3, stride=1, padding=1, bias=False)
        self.sp_Softmax = nn.Softmax(1)

    def forward(self, x):
        batch = x.size(0)

        # Channel attention: both pooled descriptors share the same linear layers.
        avg_desc = self.ch_AvgPool(x).view(batch, -1)
        max_desc = self.ch_MaxPool(x).view(batch, -1)
        avg_logits = self.ch_Linear2(self.ch_Linear1(avg_desc))
        max_logits = self.ch_Linear2(self.ch_Linear1(max_desc))
        ch_weights = self.ch_Softmax(avg_logits + max_logits).view(batch, self.in_planes, 1, 1, 1)
        ch_out = ch_weights * x

        # Spatial attention computed from the channel-weighted tensor.
        sp_max = torch.max(ch_out, 1, keepdim=True)[0]
        sp_avg = torch.sum(ch_out, 1, keepdim=True) / self.in_planes
        sp_logits = self.sp_Conv(torch.cat([sp_max, sp_avg], dim=1))
        sp_weights = self.sp_Softmax(sp_logits.view(batch, -1))
        sp_weights = sp_weights.view(batch, 1, x.size(2), x.size(3), x.size(4))

        # The spatial weights are applied to the original input, plus a skip.
        return sp_weights * x + x
42 |
43 |
def make_conv3d(in_channels: int, out_channels: int, kernel_size: typing.Union[int, tuple], stride: int,
                padding: int, dilation=1, groups=1,
                bias=True) -> nn.Module:
    """
    Build a Conv3d -> BatchNorm3d -> ReLU stack.

    :param in_channels: number of input channels
    :param out_channels: number of output channels
    :param kernel_size: kernel size, int or tuple
    :param stride: convolution stride
    :param padding: convolution padding
    :param dilation: dilation of the convolution
    :param groups: number of groups of the convolution
    :param bias: whether the convolution has a bias term
    :return: an ``nn.Sequential`` holding the three layers in that order
    """
    conv = nn.Conv3d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
        bias=bias,
    )
    norm = nn.BatchNorm3d(out_channels)
    return nn.Sequential(conv, norm, nn.ReLU())
68 |
69 |
def conv3d_same_size(in_channels, out_channels, kernel_size, stride=1,
                     dilation=1, groups=1,
                     bias=True):
    """
    Conv3d + BatchNorm3d + ReLU whose padding is half the kernel size.

    With an odd kernel size, ``stride=1`` and ``dilation=1`` the spatial size
    of the output equals that of the input.

    :param in_channels: number of input channels
    :param out_channels: number of output channels
    :param kernel_size: kernel size, an int or a tuple with one int per
        spatial dimension (a tuple used to raise TypeError here)
    :param stride: convolution stride
    :param dilation: spacing between kernel elements
    :param groups: number of blocked connections from input channels to output channels
    :param bias: if True, adds a learnable bias to the output
    :return: the module built by ``make_conv3d``
    """
    if isinstance(kernel_size, int):
        padding = kernel_size // 2
    else:
        # Pad every spatial dimension by half of its own kernel extent.
        padding = tuple(k // 2 for k in kernel_size)
    return make_conv3d(in_channels, out_channels, kernel_size, stride,
                       padding, dilation, groups,
                       bias)
89 |
90 |
def conv3d_pooling(in_channels, kernel_size, stride=1,
                   dilation=1, groups=1,
                   bias=False):
    """
    Pooling implemented as a convolution that keeps the channel count.

    Builds Conv3d + BatchNorm3d + ReLU with ``in_channels`` input and output
    channels and a padding of half the kernel size; downsampling comes from
    ``stride``.

    :param in_channels: number of input (and output) channels
    :param kernel_size: kernel size, an int or a tuple with one int per
        spatial dimension (a tuple used to raise TypeError here)
    :param stride: convolution stride
    :param dilation: spacing between kernel elements
    :param groups: number of groups of the convolution
    :param bias: whether the convolution has a bias term
    :return: the module built by ``make_conv3d``
    """
    if isinstance(kernel_size, int):
        padding = kernel_size // 2
    else:
        # Pad every spatial dimension by half of its own kernel extent.
        padding = tuple(k // 2 for k in kernel_size)
    return make_conv3d(in_channels, in_channels, kernel_size, stride,
                       padding, dilation, groups,
                       bias)
110 |
111 |
class ResidualBlock(nn.Module):
    """
    Residual block made of two 3x3x3 conv units plus a 1x1x1 conv shortcut.

    Every unit is built with ``make_conv3d``, so the shortcut branch also ends
    with batch normalisation and ReLU. The two branches are added without a
    further activation.
    """

    def __init__(self, in_channels, out_channels):
        super(ResidualBlock, self).__init__()
        # Main branch: two 3x3x3 units with padding 1.
        self.my_conv1 = make_conv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.my_conv2 = make_conv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        # Shortcut branch: 1x1x1 unit that maps the input to out_channels.
        self.conv3 = make_conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, inputs):
        shortcut = self.conv3(inputs)
        main = self.my_conv2(self.my_conv1(inputs))
        return main + shortcut
129 |
--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | from PIL import Image
3 | import os
4 | import os.path
5 | import errno
6 | import numpy as np
7 | import sys
8 | if sys.version_info[0] == 2:
9 | import cPickle as pickle
10 | else:
11 | import pickle
12 | import torch
13 | import torch.utils.data as data
14 | from torch.autograd import Variable
15 | # from .utils import download_url, check_integrity
16 |
17 | # npypath = '/media/data1/wentao/tianchi/luna16/cls/crop_v3/'
class lunanod(data.Dataset):
    """Dataset of 32x32x32 volumes, each paired with a label and a feature entry.

    Args:
        npypath (string): Directory that string entries of ``fnamelst`` are
            joined to before being loaded with ``np.load``.
        fnamelst (list): One entry per sample. A ``str`` is treated as a file
            name under ``npypath``; anything else is used directly as the
            sample's array.
        labellst (list): Label of each sample, in the same order as ``fnamelst``.
        featlst (sequence): Per-sample feature, indexed in ``__getitem__``.
        train (bool, optional): If True the samples are stored in the
            ``train_*`` attributes, otherwise in the ``test_*`` attributes.
        transform (callable, optional): Applied to the volume in ``__getitem__``.
        target_transform (callable, optional): Applied to the label in
            ``__getitem__``.
        download (bool, optional): Accepted for call compatibility; not used.
    """

    def __init__(self, npypath, fnamelst, labellst, featlst, train=True,
                 transform=None, target_transform=None,
                 download=False):
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set
        volumes, labels = self._load_volumes(npypath, fnamelst, labellst)
        if self.train:
            self.train_data = volumes
            self.train_labels = labels
            self.train_feat = featlst
            self.train_len = len(fnamelst)
            print(len(fnamelst))
            print(self.train_data.shape)
        else:
            self.test_data = volumes
            self.test_labels = labels
            self.test_feat = featlst
            self.test_len = len(fnamelst)
            print(self.test_data.shape, len(self.test_labels), len(self.test_feat))

    @staticmethod
    def _load_volumes(npypath, fnamelst, labellst):
        """Load every sample and return ``(array of shape (N, 32, 32, 32), labels)``."""
        volumes = []
        labels = []
        for label, fentry in zip(labellst, fnamelst):
            # File names are loaded from disk; already-loaded arrays pass through.
            if isinstance(fentry, str):
                fentry = np.load(os.path.join(npypath, fentry))
            volumes.append(fentry)
            labels.append(label)
        if len(fnamelst) == 0:
            # np.concatenate rejects an empty list, so build the empty result directly.
            return np.empty((0, 32, 32, 32)), labels
        stacked = np.concatenate(volumes).reshape((len(fnamelst), 32, 32, 32))
        return stacked, labels

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (volume, target, feat) for the sample at ``index``, with the
            configured transforms applied to the volume and the target.
        """
        if self.train:
            img, target, feat = self.train_data[index], self.train_labels[index], self.train_feat[index]
        else:
            img, target, feat = self.test_data[index], self.test_labels[index], self.test_feat[index]

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target, feat

    def __len__(self):
        if self.train:
            return self.train_len
        else:
            return self.test_len
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # NAS-Lung
2 |
3 | > **3D Neural Architecture Search (NAS) for Pulmonary Nodules Classification**
4 | >
5 | > Hanliang Jiang, Fuhao Shen, Fei Gao\*, Weidong Han. *Learning Efficient, Explainable and Discriminative Representations for Pulmonary Nodules Classification*. Pattern Recognition, 113: 107825, 2021.
6 | >
7 | > ```latex
8 | > @article{Jiang2021naslung,
9 | > author = {Hanliang Jiang and Fuhao Shen and Fei Gao and Weidong Han},
10 | > title = {Learning efficient, explainable and discriminative representations for pulmonary nodules classification},
11 | > journal = {Pattern Recognition},
12 | > volume = {113},
13 | > pages = {107825},
14 | > year = {2021},
15 | > issn = {0031-3203},
16 | > doi = {https://doi.org/10.1016/j.patcog.2021.107825},
17 | > }
18 | > ```
19 | >
20 | > [[Paper@PR]](https://www.sciencedirect.com/science/article/pii/S0031320321000121) [[Paper@arxiv]](https://arxiv.org/abs/2101.07429) [[Code@Github]](https://github.com/fei-hdu/NAS-Lung/issues/3)
21 |
22 |
23 |
24 | ## Architecture
25 |
26 | 
27 |
28 | ## Results
29 |
30 | ### NASLung
31 |
32 | | model | Accu. | Sens. | Spec. | F1 Score | para.(M) |
33 | | ------------------- | ----- | ----- | ----- | -------- | -------- |
34 | | Multi-crop CNN | 87.14 | - | - | - | - |
35 | | Nodule-level 2D CNN | 87.30 | 88.50 | 86.00 | 87.23 | - |
36 | | Vanilla 3D CNN | 87.40 | 89.40 | 85.20 | 87.25 | - |
37 | | DeepLung | 90.44 | 81.42 | - | - | 141.57 |
38 | | AE-DPN | 90.24 | 92.04 | 88.94 | 90.45 | 678.69 |
39 | | | | | | | |
40 | | **NASLung (ours)** | 90.77 | 85.37 | 95.04 | 89.29 | 16.84 |
41 |
42 | ### Searched 3D Networks
43 |
44 | | Model | Accu. | Sens. | Spec. | F1 Score | para. |
45 | | -------- | ----- | ----- | ----- | -------- | ----- |
46 | | Model-1 | 88.83 | 87.20 | 90.12 | 87.50 | 0.14 |
47 | | Model-2 | 88.42 | 84.38 | 91.46 | 86.67 | 2.61 |
48 | | Model-3 | 88.17 | 84.44 | 91.60 | 86.50 | 3.90 |
49 | | Model-4 | 88.13 | 83.20 | 92.28 | 86.30 | 2.54 |
50 | | Model-5 | 87.97 | 83.72 | 91.31 | 86.22 | 0.43 |
51 | | Model-6 | 87.77 | 83.67 | 91.00 | 86.03 | 0.22 |
52 | | Model-7 | 87.76 | 84.14 | 89.79 | 85.98 | 0.86 |
53 | | Model-8 | 88.00 | 82.43 | 92.69 | 85.97 | 4.02 |
54 | | Model-9 | 88.04 | 78.01 | 96.09 | 85.36 | 4.06 |
55 | | Model-10 | 87.22 | 82.70 | 90.92 | 85.32 | 0.24 |
56 |
57 | ## Prerequisites
58 |
59 | - Linux or similar environment
60 | - Python 3.7
61 | - Pytorch 0.4.1
62 | - NVIDIA GPU + CUDA CuDNN
63 |
64 | ## Getting Started
65 |
66 | ### Installation
67 |
68 | - Clone this repo:
69 |
70 | ```shell script
71 | git clone https://github.com/fei-hdu/NAS-Lung
72 | cd NAS-Lung
73 | ```
74 |
75 | - Install PyTorch 0.4+ and torchvision from [Pytorch](http://pytorch.org) and other dependencies (e.g., visdom and dominate). You can install all the dependencies by
76 |
77 | ```shell script
78 | pip install -r requirements.txt
79 | ```
80 |
81 | - Download Dataset [LIDC-IDRI](https://wiki.cancerimagingarchive.net/display/Public/LIDC-IDRI)
82 |
83 | ### Neural Architecture Search
84 |
85 | ```shell script
86 | python search_main.py --train_data_path {train_data_path} --test_data_path {test_data_path} --save_module_path {save_module_path}
87 | ```
88 |
89 | ### Train/Test
90 |
91 | - Train a model
92 |
93 | ```shell script
94 | sh run_training.sh
95 | ```
96 |
97 | - Test a model
98 |
99 | ```shell script
100 | python test.py --test_data_path {test_data_path} --preprocess_path {preprocess_path} --model_path {model_path}
101 | ```
102 |
103 | ### DataSet
104 |
105 | - [LIDC-IDRI](https://wiki.cancerimagingarchive.net/display/Public/LIDC-IDRI)
106 |
107 | ### Model Result
108 |
109 | - Our final results can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1vUFi5tEfMcDcKqMbxuN3Tt44QwLcDZnA?usp=sharing)
110 |
111 | ### Training/Test Tips
112 |
113 | - Best practice for training and testing your models.
114 | - Feel free to ask any questions about **_coding_**. **Fuhao Shen, `1048532267sfh@gmail.com`**
115 |
116 | ## Acknowledgement
117 |
118 | - Our work/code is inspired by [Partial Order Pruning: for Best Speed/Accuracy Trade-off in Neural Architecture Search, CVPR 2019](https://github.com/lixincn2015/Partial-Order-Pruning).
119 |
120 | ## Selected References
121 |
122 | - S. Armato III, G. et al., Data from **LIDC-IDRI**, The Cancer Imaging Archive. [LIDC-IDRI](http://doi.org/10.7937/K9/TCIA.2015.LO9QL9SX).
123 | - X. Li, Y. Zhou, Z. Pan, J. Feng, **Partial order pruning**: For best speed/accuracy trade-off in neural architecture search (2019) 9145–9153.
124 | - S. Woo, J. Park, J.-Y. Lee, I. So Kweon, **CBAM**: Convolutional block attention module, in: Proceedings of the European Conference on Computer Vision (ECCV), 2018, pp. 3–19.
125 | - W. Liu, Y. Wen, Z. Yu, M. Li, B. Raj, L. Song, **Sphereface**: Deep hypersphere embedding for face recognition, in: The IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2017.
126 | - T. Elsken, J. H. Metzen, F. Hutter, **Neural architecture search**: A survey, Journal of Machine Learning Research 20 (55) (2019) 1–21.
127 | - W. Zhu, C. Liu, W. Fan, X. Xie, **Deeplung**: Deep 3d dual path nets for automated pulmonary nodule detection and classification, in: 2018 IEEE Winter Conference on Applications of Computer Vision (WACV), IEEE, 2018, pp. 673–681.
128 |
--------------------------------------------------------------------------------
/models/cnn_res.py:
--------------------------------------------------------------------------------
1 | import typing
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from .net_sphere import *
7 |
8 |
9 | debug = False
10 |
11 |
class ResCBAMLayer(nn.Module):
    """Channel attention followed by a spatial gate, added back onto the input.

    ``feature_size`` is used as both kernel size and stride of the channel
    pooling layers; ``in_planes`` is the number of input channels.
    """

    def __init__(self, in_planes, feature_size):
        super(ResCBAMLayer, self).__init__()
        self.in_planes = in_planes
        self.feature_size = feature_size
        # Channel branch: pooled descriptors go through a shared bottleneck
        # (in_planes -> in_planes // 4 -> in_planes).
        self.ch_AvgPool = nn.AvgPool3d(feature_size, feature_size)
        self.ch_MaxPool = nn.MaxPool3d(feature_size, feature_size)
        self.ch_Linear1 = nn.Linear(in_planes, in_planes // 4, bias=False)
        self.ch_Linear2 = nn.Linear(in_planes // 4, in_planes, bias=False)
        self.ch_Softmax = nn.Softmax(1)
        # Spatial branch: 2-channel (max, mean) map -> 1-channel gate.
        self.sp_Conv = nn.Conv3d(2, 1, kernel_size=3, stride=1, padding=1, bias=False)
        self.sp_Softmax = nn.Softmax(1)  # not used in forward
        self.sp_sigmoid = nn.Sigmoid()

    def forward(self, x):
        batch = x.size(0)

        # Channel attention weights from average- and max-pooled descriptors.
        avg_desc = self.ch_AvgPool(x).view(batch, -1)
        max_desc = self.ch_MaxPool(x).view(batch, -1)
        avg_logits = self.ch_Linear2(self.ch_Linear1(avg_desc))
        max_logits = self.ch_Linear2(self.ch_Linear1(max_desc))
        ch_weights = self.ch_Softmax(avg_logits + max_logits).view(batch, self.in_planes, 1, 1, 1)
        ch_out = ch_weights * x

        # Spatial gate computed from the channel-weighted tensor.
        sp_max = torch.max(ch_out, 1, keepdim=True)[0]
        sp_avg = torch.sum(ch_out, 1, keepdim=True) / self.in_planes
        sp_logits = self.sp_Conv(torch.cat([sp_max, sp_avg], dim=1))
        sp_gate = self.sp_sigmoid(sp_logits.view(batch, -1))
        sp_gate = sp_gate.view(batch, 1, x.size(2), x.size(3), x.size(4))

        # The gate multiplies the original input x (not ch_out), plus a residual x.
        return sp_gate * x + x
41 |
42 |
def make_conv3d(in_channels: int, out_channels: int, kernel_size: typing.Union[int, tuple], stride: int,
                padding: int, dilation=1, groups=1,
                bias=True) -> nn.Module:
    """
    Build a Conv3d -> BatchNorm3d -> ReLU unit.

    :param in_channels: number of input channels
    :param out_channels: number of output channels
    :param kernel_size: kernel size, int or tuple
    :param stride: convolution stride
    :param padding: convolution padding
    :param dilation: convolution dilation
    :param groups: number of convolution groups
    :param bias: whether the convolution has a bias term
    :return: ``nn.Sequential`` holding the three layers in order
    """
    conv = nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
                     padding=padding, dilation=dilation, groups=groups, bias=bias)
    norm = nn.BatchNorm3d(out_channels)
    activation = nn.ReLU()
    return nn.Sequential(conv, norm, activation)
67 |
68 |
def conv3d_same_size(in_channels, out_channels, kernel_size, stride=1,
                     dilation=1, groups=1,
                     bias=True):
    """Return a ``make_conv3d`` unit whose padding is ``kernel_size // 2``."""
    half_kernel = kernel_size // 2
    return make_conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                       stride=stride, padding=half_kernel, dilation=dilation, groups=groups,
                       bias=bias)
76 |
77 |
def conv3d_pooling(in_channels, kernel_size, stride=1,
                   dilation=1, groups=1,
                   bias=False):
    """Return a channel-preserving ``make_conv3d`` unit padded by ``kernel_size // 2``.

    The output channel count equals ``in_channels``; callers pass ``stride``
    greater than 1 to use the unit for downsampling.
    """
    half_kernel = kernel_size // 2
    return make_conv3d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size,
                       stride=stride, padding=half_kernel, dilation=dilation, groups=groups,
                       bias=bias)
85 |
86 |
class ResidualBlock(nn.Module):
    """
    Residual unit built from three ``make_conv3d`` units: two 3x3x3 units in
    series on the main path and one 1x1x1 unit on the shortcut path.
    """

    def __init__(self, in_channels, out_channels):
        super(ResidualBlock, self).__init__()
        # Main path: in_channels -> out_channels -> out_channels.
        self.my_conv1 = make_conv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.my_conv2 = make_conv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        # Shortcut path: projects the input to out_channels so the sum is well-formed.
        self.conv3 = make_conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, inputs):
        """Return main path output plus the projected shortcut."""
        shortcut = self.conv3(inputs)
        main_path = self.my_conv2(self.my_conv1(inputs))
        return main_path + shortcut
104 |
105 |
class ConvRes(nn.Module):
    """Classifier assembled from a stage configuration.

    ``config`` is a list of stages, each a list of output channel counts.
    Every stage contributes one stride-2 ``conv3d_pooling`` unit, one
    ``ResidualBlock`` per channel entry, and a closing ``ResCBAMLayer``.
    """

    def __init__(self, config):
        super(ConvRes, self).__init__()
        self.conv1 = conv3d_same_size(in_channels=1, out_channels=4, kernel_size=3)
        self.conv2 = conv3d_same_size(in_channels=4, out_channels=4, kernel_size=3)
        self.config = config
        self.last_channel = 4
        self.first_cbam = ResCBAMLayer(4, 32)

        stage_modules = []
        for stage_no, stage in enumerate(config, start=1):
            stage_modules.append(conv3d_pooling(self.last_channel, kernel_size=3, stride=2))
            for channel in stage:
                stage_modules.append(ResidualBlock(self.last_channel, channel))
                self.last_channel = channel
            # Feature size passed to the attention layer halves with every stage: 32 // 2**stage_no.
            stage_modules.append(ResCBAMLayer(self.last_channel, 32 // (2 ** stage_no)))
        self.layers = nn.Sequential(*stage_modules)

        self.avg_pooling = nn.AvgPool3d(kernel_size=4, stride=4)
        self.fc = AngleLinear(in_features=self.last_channel, out_features=2)

    def forward(self, inputs):
        def trace(tensor):
            # Shape tracing is controlled by the module-level ``debug`` flag.
            if debug:
                print(tensor.size())

        trace(inputs)
        out = self.conv1(inputs)
        trace(out)
        out = self.conv2(out)
        trace(out)
        out = self.first_cbam(out)
        out = self.layers(out)
        trace(out)
        out = self.avg_pooling(out)
        out = out.view(out.size(0), -1)
        trace(out)
        return self.fc(out)
146 |
147 |
def test():
    """Run ConvRes once on a random 1x1x32x32x32 input with debug printing switched on."""
    global debug
    debug = True
    model = ConvRes([[64, 64, 64], [128, 128, 256], [256, 256, 256, 512]])
    sample = torch.randn((1, 1, 32, 32, 32))
    result = model(sample)
    print(model.config)
    print(result)
156 |
--------------------------------------------------------------------------------
/data/nodclsgbt.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import mahotas
4 | from mahotas.features.lbp import lbp
5 |
CROPSIZE = 32  # 24#30#36
# Integer half-size: slice bounds must be ints, and ``/`` yields a float on Python 3.
HALFCROP = CROPSIZE // 2
print(CROPSIZE)
pdframe = pd.read_csv('annotationdetclsconvfnl_v3.csv',
                      names=['seriesuid', 'coordX', 'coordY', 'coordZ', 'diameter_mm', 'malignant'])
# Entry 0 of every column is dropped -- presumably the CSV's own header row, which is
# read as data because explicit column names are supplied (confirm against the file).
srslst = pdframe['seriesuid'].tolist()[1:]
crdxlst = pdframe['coordX'].tolist()[1:]
crdylst = pdframe['coordY'].tolist()[1:]
crdzlst = pdframe['coordZ'].tolist()[1:]
dimlst = pdframe['diameter_mm'].tolist()[1:]
mlglst = pdframe['malignant'].tolist()[1:]

import csv

# One record per annotation: [name, x, y, z, diameter, malignant]. The row index is
# appended to the series uid to form the name used for the saved crop file.
newlst = [
    [srs + '-' + str(i), x, y, z, diameter, mlg]
    for i, (srs, x, y, z, diameter, mlg)
    in enumerate(zip(srslst, crdxlst, crdylst, crdzlst, dimlst, mlglst))
]

preprocesspath = '/media/jehovah/Work/data/LUNA/propocess/all/'
savepath = '/media/jehovah/Work/data/LUNA/cls/crop_v3/'
import os
import os.path

if not os.path.exists(savepath):
    os.mkdir(savepath)

# Cut a CROPSIZE^3 cube around every annotation and save it under ``savepath``.
for record in newlst:
    fname = record[0]
    pid = fname.split('-')[0]
    crdx = int(float(record[1]))
    crdy = int(float(record[2]))
    crdz = int(float(record[3]))
    data = np.load(os.path.join(preprocesspath, pid + '_clean.npy'))
    bgx = max(0, crdx - HALFCROP)
    bgy = max(0, crdy - HALFCROP)
    bgz = max(0, crdz - HALFCROP)
    # Start from a cube filled with 170, then paste the extracted region in its centre;
    # the region can be smaller than CROPSIZE when the slice runs past the volume edge.
    cropdata = np.ones((CROPSIZE, CROPSIZE, CROPSIZE)) * 170
    cropdatatmp = np.array(data[0, bgx:bgx + CROPSIZE, bgy:bgy + CROPSIZE, bgz:bgz + CROPSIZE])
    offx = HALFCROP - cropdatatmp.shape[0] // 2
    offy = HALFCROP - cropdatatmp.shape[1] // 2
    offz = HALFCROP - cropdatatmp.shape[2] // 2
    cropdata[offx:offx + cropdatatmp.shape[0],
             offy:offy + cropdatatmp.shape[1],
             offz:offz + cropdatatmp.shape[2]] = np.array(2 - cropdatatmp)
    assert cropdata.shape[0] == CROPSIZE and cropdata.shape[1] == CROPSIZE and cropdata.shape[2] == CROPSIZE
    np.save(os.path.join(savepath, fname + '.npy'), cropdata)

# train use gbt
# Series whose .mhd file sits in subset1 form the test split; the rest is training data.
subset1path = '/media/jehovah/Work/data/LUNA/rowfile/subset1/'
testfnamelst = [fname[:-4] for fname in os.listdir(subset1path) if fname.endswith('.mhd')]
testidset = set(testfnamelst)  # constant-time membership tests in the loops below
ntest = sum(1 for record in newlst if record[0].split('-')[0] in testidset)
print('ntest', ntest, 'ntrain', len(newlst) - ntest)

traindata = np.zeros((len(newlst) - ntest, CROPSIZE * CROPSIZE * CROPSIZE))
trainlabel = np.zeros((len(newlst) - ntest,))
testdata = np.zeros((ntest, CROPSIZE * CROPSIZE * CROPSIZE))
testlabel = np.zeros((ntest,))

# Flatten every saved crop into one row of the matching design matrix.
trainidx = testidx = 0
for record in newlst:
    fname = record[0]
    data = np.load(os.path.join(savepath, fname + '.npy'))
    bgx = data.shape[0] // 2 - HALFCROP
    bgy = data.shape[1] // 2 - HALFCROP
    bgz = data.shape[2] // 2 - HALFCROP
    data = np.array(data[bgx:bgx + CROPSIZE, bgy:bgy + CROPSIZE, bgz:bgz + CROPSIZE])
    if fname.split('-')[0] in testidset:
        testdata[testidx, :] = np.reshape(data, (-1,)) / 255
        testlabel[testidx] = record[-1]
        testidx += 1
    else:
        traindata[trainidx, :] = np.reshape(data, (-1,)) / 255
        trainlabel[trainidx] = record[-1]
        trainidx += 1
110 | from sklearn.ensemble import GradientBoostingClassifier as gbt
111 |
112 |
def gbtfunc(dep):
    """Fit a gradient-boosting classifier with ``max_depth=dep`` on the module-level data.

    Returns a tuple ``(train accuracy, test accuracy, test class probabilities)``;
    a test sample counts as class 1 when its second probability column exceeds 0.5.
    """
    model = gbt(max_depth=dep, random_state=0)
    model.fit(traindata, trainlabel)
    train_hits = model.predict(traindata) == trainlabel
    test_proba = model.predict_proba(testdata)
    test_pred = (test_proba[:, 1] > 0.5).astype(int)
    train_acc = np.sum(train_hits) / float(traindata.shape[0])
    test_acc = np.mean(test_pred == testlabel)
    return train_acc, test_acc, test_proba
121 |
122 |
123 | # trainacc, testacc, predtest = gbtfunc(3)
124 | # print trainacc, testacc
125 | # np.save('pixradiustest.npy', predtest[:,1])
126 | from multiprocessing import Pool
127 |
# Evaluate tree depths 1..8 in worker processes and print train/test accuracy per depth.
pool = Pool(30)
acclst = pool.map(gbtfunc, range(1, 9, 1))  # 3,4,1))#5,1))#1,9,1))
for train_acc, test_acc, _proba in acclst:
    print("{0:.4f}".format(train_acc), "{0:.4f}".format(test_acc))
pool.close()
133 | # for dep in xrange(1,9,1):
134 | # m = gbt(max_depth=dep)
135 | # m.fit(traindata, trainlabel)
136 | # print dep, 'trainacc', np.sum(m.predict(traindata) == trainlabel) / float(traindata.shape[0])
137 | # print dep, 'testacc', np.sum(m.predict(testdata) == testlabel) / float(testdata.shape[0])
138 |
--------------------------------------------------------------------------------
/searchspace/res_search_space.py:
--------------------------------------------------------------------------------
1 | """the width of a block to be no narrower than its preceding block"""
2 |
3 | import numpy as np
4 | from .search_space_utils import *
5 | import random
6 | from data_utils import *
7 | from models.cnn_res import ConvRes
8 |
9 |
class ResSearchSpace:
    """
    Random architecture search over ``ConvRes`` configurations.

    Each round trains one randomly chosen untrained configuration, records its
    accuracy and latency, and removes from the untrained pool those narrower or
    shallower variants whose latency exceeds a reference latency obtained from
    ``get_yw``.
    """

    def __init__(self, channel_range, max_depth, min_depth, trained_data_path, test_data_path, fold, batch_size,
                 logging, input_shape, use_gpu, gpu_id, criterion, lr, save_module_path, num_works, epoch):
        self.lr = lr
        self.epoch = epoch
        self.num_works = num_works
        self.fold = fold
        self.save_module_path = save_module_path
        self.max_depth = max_depth
        self.min_depth = min_depth
        self.criterion = criterion
        self.logging = logging
        self.input_shape = input_shape
        self.use_gpu = use_gpu
        self.gpu_id = gpu_id
        self.channel_range = channel_range
        # Rows of [module_config, acc, lat], one per trained configuration.
        self.trained_module_acc_lat = np.empty((0, 3))
        self.pruned_module = []
        # initialize all architecture set
        self.untrained_module = get_all_search_space(min_len=min_depth, max_len=max_depth, channel_range=channel_range)
        # load data set
        self.train_loader, self.test_loader = load_data(trained_data_path, test_data_path, self.fold, batch_size,
                                                        self.num_works)
        self.trained_yw = []
        # [yw_config, module_config] pairs that have already driven a pruning pass.
        self.trained_yw_and_module = []
        if not os.path.isdir(self.save_module_path):
            os.makedirs(self.save_module_path)

    def random_generate(self):
        """
        get untrained model config

        :return: one entry of ``self.untrained_module`` chosen uniformly at random
        :raises IndexError: if no untrained configuration is left
        """
        # random.randint(0, len) is inclusive on both ends and could index one past
        # the last element; random.choice only ever returns an existing entry.
        return random.choice(self.untrained_module)

    def _prune_slower(self, candidates, yw_lat):
        """Move still-untrained ``candidates`` slower than ``yw_lat`` to the pruned list; return how many."""
        pruned = 0
        for candidate in candidates:
            if candidate in self.untrained_module:
                lat = get_latency(ConvRes(candidate), input_size=self.input_shape)
                if lat > yw_lat:
                    self.pruned_module.append(candidate)
                    self.untrained_module.remove(candidate)
                    pruned += 1
        return pruned

    def _prune_with(self, module):
        """Run one pruning pass for a trained ``module`` row ([config, acc, lat])."""
        yw = get_yw(self.trained_module_acc_lat, module)
        if len(yw) == 0:
            print(f'{module[0]}yw not found')
            return
        module_config = module[0]
        yw_config = yw[0]
        yw_lat = yw[2]
        if [yw_config, module_config] in self.trained_yw_and_module:
            # This pair has already been used for pruning.
            return
        print(f'yw:{yw_config}\nmodule:{module_config}')
        self.logging.info(f'yw:{yw_config}\nmodule:{module_config}')

        narrower_module = get_narrower_module(self.channel_range, module_config)
        print('found narrower_module:' + str(len(narrower_module)))
        self.logging.info('found narrower_module:' + str(len(narrower_module)))
        shallower_module = get_shallower_module(self.min_depth, [module_config], shallower_module=[])
        print('found shallower_module:' + str(len(shallower_module)))
        self.logging.info('found shallower_module:' + str(len(shallower_module)))

        pruned_narrower_module = self._prune_slower(narrower_module, yw_lat)
        print('pruned_narrower_module:' + str(pruned_narrower_module))
        self.logging.info('pruned_narrower_module:' + str(pruned_narrower_module))

        pruned_shallower_module = self._prune_slower(shallower_module, yw_lat)
        print('pruned_shallower_module:' + str(pruned_shallower_module))
        self.logging.info('pruned_shallower_module:' + str(pruned_shallower_module))

        self.trained_yw_and_module.append([yw_config, module_config])

    def main_method(self):
        """
        search main method

        Loops until the result of ``get_excellent_module`` has stayed unchanged
        for more than 30 consecutive rounds, or until no untrained configuration
        is left. Snapshots of the trained and excellent sets are saved under
        ``save_module_path`` every 20 rounds.
        """
        B = []
        stable_time = 0
        repeat_time = 0
        while True:
            if len(self.untrained_module) == 0:
                # Nothing left to sample; stop instead of failing in random_generate.
                self.logging.info('search space exhausted')
                break
            # get untrained model
            config = self.random_generate()
            self.untrained_module.remove(config)
            # train model
            net = ConvRes(config)
            net_lat = get_module_lat(net, input_shape=self.input_shape)
            net = net_to_cuda(net, use_gpu=self.use_gpu, gpu_ids=self.gpu_id)
            optimizer = optim.Adam(net.parameters(), lr=self.lr, betas=(0.5, 0.999))
            acc = get_acc(net, self.use_gpu, self.train_loader, self.test_loader, optimizer, self.criterion,
                          self.logging, self.lr, config, self.epoch)
            print(f'module:{config}\nacc:{acc} lat:{net_lat}')
            self.logging.info(f'config:{config}\nacc:{acc} lat:{net_lat}')
            del net
            self.trained_module_acc_lat = np.append(self.trained_module_acc_lat, [[config, acc, net_lat]], axis=0)
            # prune model
            for module in self.trained_module_acc_lat:
                self._prune_with(module)

            B1 = get_excellent_module(self.trained_module_acc_lat)
            if repeat_time % 20 == 0:
                np.save(self.save_module_path + '/' + str(repeat_time) + 'trained', self.trained_module_acc_lat)
                np.save(self.save_module_path + '/' + str(repeat_time) + 'excellent', B1)
            # Count consecutive rounds in which the excellent set did not change.
            if np.array_equal(B1, B):
                stable_time = stable_time + 1
            else:
                stable_time = 0
            B = copy.deepcopy(B1)
            if stable_time > 30:
                break
            repeat_time = repeat_time + 1
139 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import os
4 | import transforms as transforms
5 | import pandas as pd
6 | from dataloader import lunanod
7 | from torch.autograd import Variable
8 | from itertools import combinations, permutations
9 | import logging
10 | import pandas as pd
11 | import argparse
12 |
13 |
def _pixel_mean_std(preprocess_path, black_list):
    """Return (mean, std) over all voxels of the non-blacklisted ``.npy`` files in ``preprocess_path``."""
    file_names = [name for name in os.listdir(preprocess_path)
                  if name.endswith('.npy') and name[:-4] not in black_list]
    pix_value, npix = 0, 0
    for file_name in file_names:
        data = np.load(os.path.join(preprocess_path, file_name))
        pix_value += np.sum(data)
        npix += np.prod(data.shape)
    pix_mean = pix_value / float(npix)
    # Second pass: accumulate squared deviations from the mean.
    pix_value = 0
    for file_name in file_names:
        data = np.load(os.path.join(preprocess_path, file_name)) - pix_mean
        pix_value += np.sum(data * data)
    pix_std = np.sqrt(pix_value / float(npix))
    return pix_mean, pix_std


def load_data(test_data_path, preprocess_path, fold, batch_size, num_workers):
    """Build the test ``DataLoader`` for one fold.

    The given paths are used as passed; they are no longer replaced by
    hard-coded locations, so the command-line options take effect.

    :param test_data_path: path prefix; ``test_data_path + str(fold) + '/'`` is
        listed for ``.mhd`` files, whose base names define the test series
    :param preprocess_path: directory holding the cropped ``.npy`` volumes
    :param fold: fold number appended to ``test_data_path``
    :param batch_size: batch size of the returned loader
    :param num_workers: worker count of the returned loader
    :return: ``torch.utils.data.DataLoader`` over the test samples (not shuffled)
    """
    crop_size = 32
    black_list = []

    pix_mean, pix_std = _pixel_mean_std(preprocess_path, black_list)
    print(pix_mean, pix_std)

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((pix_mean), (pix_std)),
    ])

    # load data list
    test_file_name_list = []
    test_label_list = []
    test_feat_list = []

    data_frame = pd.read_csv('./data/annotationdetclsconvfnl_v3.csv',
                             names=['seriesuid', 'coordX', 'coordY', 'coordZ', 'diameter_mm', 'malignant'])

    # Entry 0 of each column is dropped (presumably the CSV's own header row -- confirm).
    all_list = data_frame['seriesuid'].tolist()[1:]
    label_list = data_frame['malignant'].tolist()[1:]
    dim_list = data_frame['diameter_mm'].tolist()[1:]

    # test id
    test_id_list = []
    for file_name in os.listdir(test_data_path + str(fold) + '/'):
        if file_name.endswith('.mhd'):
            test_id_list.append(file_name[:-4])

    mxd = 0
    for srsid, label, d in zip(all_list, label_list, dim_list):
        # The largest diameter is taken over every annotation, test or not.
        mxd = max(abs(float(d)), mxd)
        if srsid in black_list:
            continue
        if srsid.split('-')[0] in test_id_list:
            # Binary mask feature derived from the diameter.
            # NOTE(review): the mask uses cubes (y**3 + x**3 + z**3), not squares;
            # kept unchanged -- confirm it matches the features used in training.
            y, x, z = np.ogrid[-crop_size / 2:crop_size / 2, -crop_size / 2:crop_size / 2,
                               -crop_size / 2:crop_size / 2]
            mask = abs(y ** 3 + x ** 3 + z ** 3) <= abs(float(d)) ** 3
            feat = np.zeros((crop_size, crop_size, crop_size), dtype=float)
            feat[mask] = 1
            test_file_name_list.append(srsid + '.npy')
            test_label_list.append(int(label))
            test_feat_list.append(feat)
    # Only the last slice along the first axis of each feature is divided by mxd.
    for feat in test_feat_list:
        feat[-1] /= mxd

    test_set = lunanod(preprocess_path, test_file_name_list, test_label_list, test_feat_list, train=False,
                       download=True,
                       transform=transform_test)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return test_loader
103 |
104 |
def load_module(path):
    """Load the checkpoint at ``path`` and return its ``'net'`` entry moved to CUDA."""
    # NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
    checkpoint = torch.load(path)
    model = checkpoint['net']
    model.cuda()
    return model
110 |
111 |
def get_targets(test_loader):
    """Collect the targets of every ``(inputs, targets, feat)`` batch into one flat int array."""
    # Start from an empty float array so the result dtype is promoted exactly as
    # repeated np.append calls on np.empty(0) would promote it.
    pieces = [np.empty(shape=0)]
    for _inputs, targets, _feat in test_loader:
        pieces.append(np.ravel(targets))
    return np.concatenate(pieces).astype(int)
118 |
119 |
def get_permutations(model_list, count, top_count):
    """Return the first permutations of ``count`` items from ``model_list`` as lists.

    Permutations are taken in ``itertools.permutations`` order and collection
    stops once ``top_count`` of them have been gathered (the check happens
    after each append, so at least one is returned whenever any exists).
    """
    picked = []
    for ordering in permutations(model_list, count):
        picked.append(list(ordering))
        if len(picked) >= top_count:
            break
    return picked
127 |
128 |
def test_module(module, test_loader):
    """Evaluate a two-class model on ``test_loader`` and print accuracy, TPR and FPR.

    :param module: model to evaluate; when it returns a tuple, element 0 is
        used as the class scores
    :param test_loader: iterable yielding ``(inputs, targets, feat)`` batches
    :return: ``(acc, tpr, fpr)``; ``tpr`` and ``fpr`` are percentages, ``acc`` is
        a fraction. A metric whose denominator is zero is reported as ``nan``.
    """
    module.eval()
    # Feed batches on the device the model lives on; a model without parameters
    # falls back to CUDA, which is what this function always used before.
    try:
        device = next(module.parameters()).device
    except StopIteration:
        device = torch.device('cuda')

    # Confusion-matrix counters kept as plain ints so they stay valid for an empty loader.
    TP = TN = FN = FP = 0
    with torch.no_grad():  # inference only, no autograd graph needed
        for inputs, targets, _feat in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = module(inputs)
            scores = outputs[0] if isinstance(outputs, tuple) else outputs
            _, prediction = torch.max(scores, 1)
            TP += int(((prediction == 1) & (targets == 1)).sum().item())
            TN += int(((prediction == 0) & (targets == 0)).sum().item())
            FN += int(((prediction == 0) & (targets == 1)).sum().item())
            FP += int(((prediction == 1) & (targets == 0)).sum().item())

    positives = TP + FN
    negatives = FP + TN
    counted = positives + negatives
    nan = float('nan')
    tpr = 100. * TP / positives if positives else nan
    fpr = 100. * FP / negatives if negatives else nan
    acc = (TP + TN) / counted if counted else nan
    print(f'acc:{acc}')
    print('tpr ' + str(tpr) + ' fpr ' + str(fpr))
    return acc, tpr, fpr
159 |
160 |
def get_predicted(result_array):
    """
    majority vote over stacked 0/1 predictions

    :param result_array: array of predictions; votes are counted along axis 0
    :return: int array, 1 where strictly more entries equal 1 than equal 0,
        otherwise 0 (a tie gives 0)
    """
    votes_for = np.sum(result_array == 1, axis=0)
    votes_against = np.sum(result_array == 0, axis=0)
    return (votes_for > votes_against).astype(int)
168 |
169 |
# Command line interface: each entry is (flag, add_argument keyword options).
parser = argparse.ArgumentParser(description='test')
for flag, options in (
    ('--model_path', dict(
        type=str,
        default='/data/fuhao/PartialOrderPrunning/[4,4,[4, 8, 16, 16, 16], [32, 128], [128]]/checkpoint-5/ckpt.t7',
        help='ckpt.t7')),
    ('--fold', dict(type=int, default=5, help='1-5')),
    ('--batch_size', dict(type=int, default=8)),
    ('--num_workers', dict(type=int, default=24)),
    ('--test_data_path', dict(type=str, default='/data/xxx/LUNA/rowfile/subset')),
    ('--preprocess_path', dict(type=str, default='/data/xxx/LUNA/cls/crop_v3')),
):
    parser.add_argument(flag, **options)
# Parsed at import time, so `args` is available as a module-level name.
args = parser.parse_args()

if __name__ == '__main__':
    # Load the checkpoint, build the loader of the requested fold, evaluate.
    fold, batch_size, num_workers = args.fold, args.batch_size, args.num_workers
    test_data_path, preprocess_path = args.test_data_path, args.preprocess_path
    net = load_module(args.model_path)
    test_data_loader = load_data(test_data_path, preprocess_path, fold, batch_size, num_workers)
    test_module(net, test_data_loader)
190 |
--------------------------------------------------------------------------------
/searchspace/search_space_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import copy
3 | import time
4 | import torch
5 | import copy
6 | import itertools
7 |
8 |
def get_all_search_space(min_len, max_len, channel_range):
    """
    get all configs of model

    every channel sequence of depth min_len..max_len is cut into three
    consecutive, non-empty stages in every possible way

    :param min_len: min of the depth of model
    :param max_len: max of the depth of model
    :param channel_range: list, the range of channel
    :return: all search space, a list of [stage1, stage2, stage3] configs
    """
    all_search_space = []
    # get all model config with max length; shallower configs are their prefixes
    max_configs = get_search_space(max_len, channel_range)
    for depth in range(min_len, max_len + 1):
        # remove repeated list from lists
        prefixes = remove_repeated_element([config[:depth] for config in max_configs])
        for channels in prefixes:
            # two cut points 1 <= first < second <= depth - 1 keep all three
            # stages non-empty
            for first_split, second_split in itertools.combinations(range(1, depth), 2):
                all_search_space.append(
                    [channels[:first_split], channels[first_split:second_split], channels[second_split:]])
    return all_search_space
34 |
35 |
def get_limited_search_space(min_len, max_len, channel_range):
    """
    get all limited configs of model

    every channel sequence of depth min_len..max_len is cut into three
    stages; each stage gets at least depth // 4 layers and the first stage
    at most depth - depth // 2 layers. For depths below 4, depth // 4 is 0,
    so stages may be empty.

    :param min_len: min of the depth of model
    :param max_len: max of the depth of model
    :param channel_range: list, the range of channel
    :return: all search space, a list of [stage1, stage2, stage3] configs
    """
    all_search_space = []
    # get all model config one layer deeper than max_len; shallower configs
    # are their prefixes
    max_configs = get_search_space(max_len + 1, channel_range)
    for depth in range(min_len, max_len + 1):
        quarter = depth // 4
        # remove repeated list from lists
        prefixes = remove_repeated_element([config[:depth] for config in max_configs])
        for channels in prefixes:
            for first_split in range(quarter, depth - depth // 2 + 1):
                for second_split in range(first_split + quarter, depth - quarter + 1):
                    all_search_space.append(
                        [channels[:first_split], channels[first_split:second_split], channels[second_split:]])
    return all_search_space
63 |
64 |
def get_search_space(max_len, channel_range, search_space=None, now=0):
    """
    get all configuration combinations

    builds every channel sequence in which each layer has at least as many
    channels as the layer before it

    :param max_len: max of the depth of model
    :param channel_range: list, the range of channel
    :param search_space: partial configs to extend; ignored when ``now`` is 0
    :param now: depth of the configs in ``search_space``
    :return: list of channel lists; each round adds one layer until ``now``
        reaches ``max_len`` (at least one round is always run)
    """
    current = [] if search_space is None else search_space
    while True:
        if now == 0:
            # first layer: any channel of the range
            current = [[channel] for channel in channel_range]
        else:
            # next layer: any channel not smaller than the previous layer's
            current = [prefix + [channel]
                       for prefix in current
                       for channel in channel_range
                       if channel >= prefix[-1]]
        now += 1
        if now >= max_len:
            return current
92 |
93 |
def get_larger_channel(channel_range, channel_num):
    """
    select the channels that are not smaller than the input channel

    :param channel_range: list, channel range
    :param channel_num: input channel
    :return: list, channels of channel_range that are >= channel_num,
        in their original order
    """
    return [channel for channel in channel_range if channel >= channel_num]
104 |
105 |
def get_smaller_channel(channel, channel_range):
    """
    select the channels that are strictly smaller than the input channel

    :param channel: input channel
    :param channel_range: list, channel range
    :return: list, channels of channel_range that are < channel,
        in their original order
    """
    return [candidate for candidate in channel_range if candidate < channel]
116 |
117 |
def get_shallower_module(min_len, module_config, shallower_module=None):
    """
    get module config which is shallower than module_config

    each round removes one layer from every stage that has more than one
    layer, for every config of the previous round. Rounds continue until the
    last collected config has at most min_len layers, or no stage can be
    shortened any further.

    :param min_len: min depth of model
    :param module_config: list of input module configs, each a list of stages
    :param shallower_module: optional list the results are appended to; a
        fresh list is used when omitted
    :return: list, module config which is shallower than module_config
    """
    collected = [] if shallower_module is None else shallower_module
    frontier = module_config
    while True:
        candidates = []
        for config in frontier:
            for stage_idx, stage in enumerate(config):
                # only stages with more than one layer can lose a layer
                if isinstance(stage, int) or len(stage) <= 1:
                    continue
                for layer_idx in range(len(stage)):
                    shallower = copy.deepcopy(config)
                    del shallower[stage_idx][layer_idx]
                    candidates.append(shallower)
        # sort, then keep one representative of every run of equal configs
        candidates = [config for config, _ in itertools.groupby(sorted(candidates))]
        collected.extend(candidates)
        if not collected:
            return []
        if not candidates:
            # nothing left to shorten; without this stop the search would
            # never end when min_len is below the number of stages
            return collected
        # depth of the most recently collected config
        if sum(len(stage) for stage in collected[-1]) <= min_len:
            return collected
        frontier = candidates
147 |
148 |
def remove_repeated_element(repeated_list):
    """
    Remove duplicate elements

    the input list is sorted in place first, so equal elements end up next
    to each other and one element of every run is kept

    :param repeated_list: input list, it is left sorted
    :return: new sorted list without duplicate elements
    """
    repeated_list.sort()
    return [element for element, _ in itertools.groupby(repeated_list)]
160 |
161 |
def get_element_count(the_list):
    """
    get depth of model

    :param the_list: input model config, a list of sized stages
    :return: depth of model, the summed length of all stages
    """
    return sum(map(len, the_list))
171 |
172 |
def flat_list(the_list):
    """
    flatten list by one level

    :param the_list: list of iterables
    :return: flatten list, the items of every sublist in order
    """
    return list(itertools.chain.from_iterable(the_list))
181 |
182 |
def get_narrower_module(channel_range, module_config):
    """
    get module config which is narrower than module_config

    a candidate qualifies when every layer has at most as many channels as
    the same layer of module_config, so module_config itself also qualifies
    when its channels come from channel_range. The result is cut into three
    stages using the lengths of the first two stages of module_config.

    :param channel_range: channel range
    :param module_config: input model config
    :return: list, module config which is narrower than module_config
    """
    stage_sizes = [len(stage) for stage in module_config]
    depth = get_element_count(module_config)
    candidates = np.array(get_search_space(depth, channel_range))
    reference = np.array(flat_list(module_config))
    # the product of the per-layer flags is 1 only when every layer passes
    fits = np.prod(candidates <= reference, 1)
    kept = candidates[np.where(fits == 1)[0]].tolist()
    first_cut = stage_sizes[0]
    second_cut = stage_sizes[1] + stage_sizes[0]
    result = [[channels[:first_cut], channels[first_cut:second_cut], channels[second_cut:]]
              for channels in kept]
    return remove_repeated_element(result)
207 |
208 |
def get_latency(module, input_size):
    """
    get the latency of module

    :param module: callable that takes one tensor
    :param input_size: shape of the random input tensor fed to module
    :return: latency, seconds taken by a single call of module
    """
    module_input = torch.randn(input_size)
    # perf_counter is monotonic and high resolution; the wall clock can jump
    # while the module is running
    start = time.perf_counter()
    module(module_input)
    return time.perf_counter() - start
222 |
223 |
def get_excellent_module(trained_module):
    """
    get model with less latency and higher acc

    a row is kept when no row has both a higher value in column 1 and a
    lower value in column 2

    :param trained_module: trained module list, a 2-D array whose column 1
        is compared as accuracy and column 2 as latency
    :return: excellent module, the kept rows in their original order
    """
    excellent_module = np.empty(shape=(0, 3))
    acc_column = trained_module[:, 1]
    lat_column = trained_module[:, 2]
    for candidate in trained_module:
        acc_not_higher = acc_column <= candidate[1]
        lat_not_lower = lat_column >= candidate[2]
        # every row must fail to beat the candidate on at least one metric
        if (acc_not_higher | lat_not_lower).all():
            excellent_module = np.append(excellent_module, [candidate], axis=0)
    return excellent_module
241 |
242 |
243 |
244 |
--------------------------------------------------------------------------------
/random_forest.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import os
4 | import transforms as transforms
5 | import pandas as pd
6 | from dataloader import lunanod
7 | from torch.autograd import Variable
8 | from itertools import combinations, permutations
9 | import logging
10 | import pandas as pd
11 |
12 |
def load_data(fold, batch_size, num_workers,
              test_data_path='/data/xxx/LUNA/rowfile/subset',
              preprocess_path='/data/xxx/LUNA/cls/crop_v3'):
    """
    build the test loader of one fold

    A nodule belongs to the test set when the part of its series uid before
    the first '-' names an .mhd file in ``test_data_path + str(fold)``.

    :param fold: index of the subset used for testing
    :param batch_size: batch size of the loader
    :param num_workers: worker count of the loader
    :param test_data_path: path prefix of the raw subsets; str(fold) + '/'
        is appended to it
    :param preprocess_path: directory with the preprocessed .npy volumes
    :return: test_loader
    """
    crop_size = 32
    black_list = []

    # List the volumes once so both statistics passes read the same files in
    # the same order.
    volume_names = [name for name in os.listdir(preprocess_path)
                    if name.endswith('.npy') and name[:-4] not in black_list]

    # First pass: mean over every voxel of every volume.
    pix_value, npix = 0, 0
    for file_name in volume_names:
        data = np.load(os.path.join(preprocess_path, file_name))
        pix_value += np.sum(data)
        npix += np.prod(data.shape)
    pix_mean = pix_value / float(npix)

    # Second pass: standard deviation around that mean.
    pix_value = 0
    for file_name in volume_names:
        data = np.load(os.path.join(preprocess_path, file_name)) - pix_mean
        pix_value += np.sum(data * data)
    pix_std = np.sqrt(pix_value / float(npix))
    print(pix_mean, pix_std)

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((pix_mean), (pix_std)),
    ])

    data_frame = pd.read_csv('./data/annotationdetclsconvfnl_v3.csv',
                             names=['seriesuid', 'coordX', 'coordY', 'coordZ', 'diameter_mm', 'malignant'])
    # [1:] drops the first row -- presumably the header line of the csv,
    # which is read as data because the column names are passed explicitly.
    all_list = data_frame['seriesuid'].tolist()[1:]
    label_list = data_frame['malignant'].tolist()[1:]
    dim_list = data_frame['diameter_mm'].tolist()[1:]

    # test id; a set gives constant-time membership checks below
    test_ids = {file_name[:-4]
                for file_name in os.listdir(test_data_path + str(fold) + '/')
                if file_name.endswith('.mhd')}

    # The coordinate grid is the same for every nodule, so build it once.
    half = crop_size / 2
    grid_y, grid_x, grid_z = np.ogrid[-half:half, -half:half, -half:half]
    cubic_sum = abs(grid_y ** 3 + grid_x ** 3 + grid_z ** 3)

    test_file_name_list = []
    test_label_list = []
    test_feat_list = []
    mxd = 0
    for srsid, label, d in zip(all_list, label_list, dim_list):
        # the largest diameter is taken over every row, not just test rows
        mxd = max(abs(float(d)), mxd)
        if srsid in black_list:
            continue
        if srsid.split('-')[0] not in test_ids:
            continue
        # binary mask feature: 1 where the cubic sum of the voxel offsets is
        # within diameter ** 3
        feat = np.zeros((crop_size, crop_size, crop_size), dtype=float)
        feat[cubic_sum <= abs(float(d)) ** 3] = 1
        test_file_name_list.append(srsid + '.npy')
        test_label_list.append(int(label))
        test_feat_list.append(feat)

    # scales the last slice (axis 0) of every mask by the largest diameter
    # NOTE(review): looks like a leftover from a flat feature vector whose
    # last entry was the diameter -- confirm before changing
    for feat in test_feat_list:
        feat[-1] /= mxd

    test_set = lunanod(preprocess_path, test_file_name_list, test_label_list, test_feat_list, train=False,
                       download=True,
                       transform=transform_test)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return test_loader
102 |
103 |
def load_module(module_config, set_num, checkpoint_root='/data/fuhao/PartialOrderPrunning'):
    """
    load the trained network of one configuration and move it to the GPU

    :param module_config: model configuration; its string form is the name
        of the checkpoint directory
    :param set_num: index used in the 'checkpoint-<set_num>' sub directory
    :param checkpoint_root: directory that holds one folder per configuration
    :return: the network stored under the 'net' entry of ckpt.t7, after
        ``.cuda()`` has been called on it
    """
    path = f'{checkpoint_root}/{module_config}/checkpoint-{set_num}/ckpt.t7'
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # you trust.
    checkpoint = torch.load(path)
    net = checkpoint['net']
    net.cuda()
    return net
110 |
111 |
def get_targets(test_loader):
    """
    collect the labels of every batch into one flat integer array

    :param test_loader: iterable of (inputs, targets, feat) batches
    :return: 1-D int numpy array holding all targets in iteration order
    """
    # Gather the per-batch arrays and join them once; appending inside the
    # loop would re-copy everything collected so far on every batch.
    batches = [np.ravel(targets) for _, targets, _ in test_loader]
    if not batches:
        return np.empty(shape=0).astype(int)
    return np.concatenate(batches).astype(int)
118 |
119 |
def get_permutations(model_list, count, top_count):
    """
    return the first permutations of ``count`` models, capped at ``top_count``

    :param model_list: sequence of models to pick from
    :param count: number of models in each permutation
    :param top_count: maximum number of permutations to return; values
        below 1 give an empty list
    :return: list of permutations, each one a list, in itertools order
    """
    from itertools import islice

    # islice stops the generator early, so the remaining permutations are
    # never materialised.
    limit = max(top_count, 0)
    return [list(combo) for combo in islice(permutations(model_list, count), limit)]
127 |
128 |
def test_module(module_config, set_num, test_loader):
    """
    run one trained model over the test set and return its predicted classes

    :param module_config: model configuration, forwarded to load_module
    :param set_num: checkpoint index, forwarded to load_module
    :param test_loader: iterable of (inputs, targets, feat) batches
    :return: 1-D float numpy array with the predicted class index of every
        sample, in loader order
    """
    module = load_module(module_config, set_num)
    module.eval()
    # Seeding with an empty float array keeps the result float64, also when
    # the loader yields no batch.
    chunks = [np.empty(shape=0)]
    with torch.no_grad():
        for inputs, _, _ in test_loader:
            outputs = module(inputs.cuda())
            # networks that return a tuple carry the class scores first
            scores = outputs[0] if isinstance(outputs, tuple) else outputs
            _, predicted = torch.max(scores, 1)
            # numpy cannot read GPU memory, so copy the batch to the host
            chunks.append(predicted.cpu().numpy())
    return np.concatenate(chunks)
144 |
145 |
def get_predicted(result_array):
    """
    majority vote over stacked 0/1 predictions

    :param result_array: array of predictions; votes are counted along axis 0
    :return: int array, 1 where strictly more entries equal 1 than equal 0,
        otherwise 0 (a tie gives 0)
    """
    votes_for = np.sum(result_array == 1, axis=0)
    votes_against = np.sum(result_array == 0, axis=0)
    return (votes_for > votes_against).astype(int)
153 |
154 |
if __name__ == '__main__':
    # Majority-vote ensembles: for every odd ensemble size, evaluate the
    # first `top_count` permutations of the leading models on every fold and
    # record accuracy / TPR / FPR per fold.
    top_count = 20
    fold_count = 6
    module_list = np.load('data/model.npy')
    module_list = list(filter(lambda x: '[32,64,[' in x, module_list))
    logging.basicConfig(filename='modelfusion_huge_log', level=logging.INFO)
    # to_excel chooses its writer from the file extension, so one is needed
    save_excel = 'modelfusion_huge.xlsx'
    columns = ['module_count', 'module_config'] + [
        f'fold-{fold}-{metric}' for fold in range(fold_count) for metric in ('acc', 'tpr', 'fpr')]

    def report(message):
        # every progress line goes to the log file and to stdout
        logging.info(message)
        print(message)

    run_result = []   # one row per evaluated ensemble, laid out as `columns`
    fold_data = {}    # fold -> (test loader, targets); built once per fold
    predictions = {}  # (module_config, fold) -> predictions of that model

    for i in range(3, 20, 2):
        permutations_result = get_permutations(module_list[:i + 4], i, top_count)
        for num, modules in enumerate(permutations_result, start=1):
            report(f'model_count={i}')
            report(f'num:{num}')
            report(modules)
            line = [i, str(modules)]
            for fold in range(fold_count):
                if fold not in fold_data:
                    loader = load_data(fold, 8, 20)
                    fold_data[fold] = (loader, get_targets(loader))
                test_loader, targets = fold_data[fold]
                length = targets.shape[0]

                votes = []
                for module_config in modules:
                    key = (module_config, fold)
                    # a model's predictions on a fold do not depend on the
                    # ensemble it is part of, so compute them only once
                    if key not in predictions:
                        predictions[key] = test_module(module_config, fold, test_loader)
                    votes.append(predictions[key])
                predicted = get_predicted(np.array(votes))

                TP = np.sum((predicted == 1) & (targets == 1))
                TN = np.sum((predicted == 0) & (targets == 0))
                FN = np.sum((predicted == 0) & (targets == 1))
                FP = np.sum((predicted == 1) & (targets == 0))
                tpr = 100. * TP / (TP + FN)
                fpr = 100. * FP / (FP + TN)
                acc = 100. * np.sum(predicted == targets) / length
                line.extend([acc, tpr, fpr])
                report(f'set={fold}')
                report(f'acc={acc}')
                report(f'tpr={tpr} fpr={fpr}')
            run_result.append(line)
            # checkpoint after every ensemble; rows mix numbers and text, so
            # they are stored as an object array
            np.save('run_result_huge', np.array(run_result, dtype=object))

    df = pd.DataFrame(data=run_result, columns=columns, index=None)
    df.to_excel(save_excel)
212 |
--------------------------------------------------------------------------------
/data_utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.optim as optim
6 | import torch.nn.functional as F
7 | import torch.backends.cudnn as cudnn
8 | import pandas as pd
9 | import transforms as transforms
10 | from dataloader import lunanod
11 | import os
12 | import argparse
13 | import time
14 | from models.cnn_res import *
15 | # from utils import progress_bar
16 | from torch.autograd import Variable
17 | import logging
18 | import numpy as np
19 | import copy
20 |
21 |
def load_data(trained_data_path, test_data_path, fold, batch_size, num_workers):
    """
    build the train and test loaders of one fold

    A nodule belongs to the test set when the part of its series uid before
    the first '-' names an .mhd file in ``test_data_path + str(fold)``;
    every other nodule goes to the training set.

    :param trained_data_path: directory with the preprocessed .npy volumes
    :param test_data_path: path prefix of the raw subsets; str(fold) + '/'
        is appended to it
    :param fold: index of the subset used for testing
    :param batch_size: batch size of both loaders
    :param num_workers: worker count of both loaders
    :return: (train_loader, test_loader)
    """
    crop_size = 32
    black_list = []
    preprocess_path = trained_data_path

    # List the volumes once so both statistics passes read the same files in
    # the same order.
    volume_names = [name for name in os.listdir(preprocess_path)
                    if name.endswith('.npy') and name[:-4] not in black_list]

    # First pass: mean over every voxel of every volume.
    pix_value, npix = 0, 0
    for file_name in volume_names:
        data = np.load(os.path.join(preprocess_path, file_name))
        pix_value += np.sum(data)
        npix += np.prod(data.shape)
    pix_mean = pix_value / float(npix)

    # Second pass: standard deviation around that mean.
    pix_value = 0
    for file_name in volume_names:
        data = np.load(os.path.join(preprocess_path, file_name)) - pix_mean
        pix_value += np.sum(data * data)
    pix_std = np.sqrt(pix_value / float(npix))
    print(pix_mean, pix_std)

    transform_train = transforms.Compose([
        # transforms.RandomScale(range(28, 38)),
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomYFlip(),
        transforms.RandomZFlip(),
        transforms.ZeroOut(4),
        transforms.ToTensor(),
        transforms.Normalize((pix_mean), (pix_std)),  # need to cal mean and std, revise norm func
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((pix_mean), (pix_std)),
    ])

    data_frame = pd.read_csv('./data/annotationdetclsconvfnl_v3.csv',
                             names=['seriesuid', 'coordX', 'coordY', 'coordZ', 'diameter_mm', 'malignant'])
    # [1:] drops the first row -- presumably the header line of the csv,
    # which is read as data because the column names are passed explicitly.
    all_list = data_frame['seriesuid'].tolist()[1:]
    label_list = data_frame['malignant'].tolist()[1:]
    dim_list = data_frame['diameter_mm'].tolist()[1:]

    # test id; a set gives constant-time membership checks below
    test_ids = {file_name[:-4]
                for file_name in os.listdir(test_data_path + str(fold) + '/')
                if file_name.endswith('.mhd')}

    # The coordinate grid is the same for every nodule, so build it once.
    half = crop_size / 2
    grid_y, grid_x, grid_z = np.ogrid[-half:half, -half:half, -half:half]
    cubic_sum = abs(grid_y ** 3 + grid_x ** 3 + grid_z ** 3)

    # load data list
    train_file_name_list = []
    train_label_list = []
    train_feat_list = []
    test_file_name_list = []
    test_label_list = []
    test_feat_list = []
    mxd = 0
    for srsid, label, d in zip(all_list, label_list, dim_list):
        mxd = max(abs(float(d)), mxd)
        if srsid in black_list:
            continue
        # binary mask feature: 1 where the cubic sum of the voxel offsets is
        # within diameter ** 3
        feat = np.zeros((crop_size, crop_size, crop_size), dtype=float)
        feat[cubic_sum <= abs(float(d)) ** 3] = 1
        if srsid.split('-')[0] in test_ids:
            test_file_name_list.append(srsid + '.npy')
            test_label_list.append(int(label))
            test_feat_list.append(feat)
        else:
            train_file_name_list.append(srsid + '.npy')
            train_label_list.append(int(label))
            train_feat_list.append(feat)

    # scales the last slice (axis 0) of every mask by the largest diameter
    # NOTE(review): looks like a leftover from a flat feature vector whose
    # last entry was the diameter -- confirm before changing
    for feat in train_feat_list:
        feat[-1] /= mxd
    for feat in test_feat_list:
        feat[-1] /= mxd

    train_set = lunanod(preprocess_path, train_file_name_list, train_label_list, train_feat_list, train=True,
                        download=True,
                        transform=transform_train)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)

    test_set = lunanod(preprocess_path, test_file_name_list, test_label_list, test_feat_list, train=False,
                       download=True,
                       transform=transform_test)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return train_loader, test_loader
123 |
124 |
def train_module(net, use_cuda, train_loader, optimizer, criterion, log, lr, config, epoch):
    """
    train a network for a fixed number of epochs

    :param net: network to train; it is switched to train mode
    :param use_cuda: move every batch to the GPU before the forward pass
    :param train_loader: iterable of (inputs, targets, feat) batches
    :param optimizer: optimizer stepped once per batch
    :param criterion: loss called as criterion(outputs, targets)
    :param log: logger; one line per epoch is written with log.info
    :param lr: learning rate, only used in the progress message
    :param config: model configuration, only used in the progress message
    :param epoch: number of passes over train_loader
    :return: the trained network (the same object as net)
    """
    net.train()

    for i in range(epoch):
        correct = 0
        total = 0
        for inputs, targets, _ in train_loader:
            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda()

            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            loss.backward()
            optimizer.step()
            _, predicted = torch.max(outputs.data, 1)
            total += targets.size(0)
            # plain ints keep the counters usable when the loader is empty
            correct += predicted.eq(targets).sum().item()

        # an epoch without samples has no defined accuracy
        accuracy = correct / float(total) if total else float('nan')
        message = 'ep ' + str(i) + str(config) + ' tracc ' + str(accuracy) + ' lr ' + str(lr)
        print(message)
        log.info(message)
    return net
152 |
153 |
def my_test_module(net, use_cuda, test_loader, criterion, log):
    """
    evaluate a binary classifier and report accuracy, TPR and FPR

    :param net: network to evaluate; it is switched to eval mode
    :param use_cuda: move every batch to the GPU before the forward pass
    :param test_loader: iterable of (inputs, targets, feat) batches
    :param criterion: not used by the evaluation; kept so existing calls
        keep working
    :param log: logger; the metrics are written with log.info
    :return: accuracy in percent; metrics whose denominator is zero are nan
    """
    epoch_start_time = time.time()
    net.eval()
    correct = 0
    total = 0
    TP = FP = FN = TN = 0
    with torch.no_grad():
        for inputs, targets, _ in test_loader:
            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda()

            outputs = net(inputs)
            _, predicted = torch.max(outputs, 1)
            total += targets.size(0)
            # plain ints keep the counters usable when the loader is empty
            correct += predicted.eq(targets).sum().item()
            TP += ((predicted == 1) & (targets == 1)).sum().item()
            TN += ((predicted == 0) & (targets == 0)).sum().item()
            FN += ((predicted == 0) & (targets == 1)).sum().item()
            FP += ((predicted == 1) & (targets == 0)).sum().item()

    nan = float('nan')
    acc = 100. * correct / total if total else nan
    tpr = 100. * TP / (TP + FN) if TP + FN else nan
    fpr = 100. * FP / (FP + TN) if FP + TN else nan

    print('teacc ' + str(acc))
    print('tpr ' + str(tpr) + ' fpr ' + str(fpr))
    print('Time Taken: %d sec' % (time.time() - epoch_start_time))
    log.info(
        'teacc ' + str(acc))
    log.info(
        'tpr ' + str(tpr) + ' fpr ' + str(fpr))
    return acc
193 |
194 |
def get_acc(net, use_cuda, train_loader, test_loader, optimizer, criterion, log, lr, config, epoch):
    """
    train a network, then evaluate it on the test loader

    :param net: network to train and evaluate
    :param use_cuda: forwarded to train_module and my_test_module
    :param train_loader: batches used for training
    :param test_loader: batches used for evaluation
    :param optimizer: forwarded to train_module
    :param criterion: forwarded to train_module and my_test_module
    :param log: forwarded to train_module and my_test_module
    :param lr: forwarded to train_module
    :param config: forwarded to train_module
    :param epoch: forwarded to train_module
    :return: the value returned by my_test_module for the trained network
    """
    trained = train_module(net, use_cuda, train_loader, optimizer, criterion, log, lr, config, epoch)
    return my_test_module(trained, use_cuda, test_loader, criterion, log)
199 |
200 |
def net_to_cuda(net, use_gpu, gpu_ids):
    """
    move a network to the GPU and wrap it in DataParallel

    :param net: network to move
    :param use_gpu: when false, net is returned untouched
    :param gpu_ids: 'all' for every visible device, otherwise an iterable
        whose all-digit elements are taken as device indices; for a string
        that means every digit character, so '0,1' selects devices 0 and 1
    :return: net itself without GPU, otherwise net wrapped in DataParallel
    """
    if not use_gpu:
        return net

    net.cuda()
    if gpu_ids == 'all':
        device_ids = range(torch.cuda.device_count())
    else:
        device_ids = [int(token) for token in gpu_ids if str.isdigit(token)]

    print('gpu use' + str(device_ids))
    return torch.nn.DataParallel(net, device_ids=device_ids)
212 |
213 |
def get_module_lat(module, input_shape):
    """
    measure how long one call of a module takes

    :param module: callable that takes one tensor
    :param input_shape: shape of the random input tensor fed to module
    :return: seconds taken by the call; printing the output is not included
    """
    x = torch.randn(input_shape)
    # perf_counter is monotonic and high resolution; the wall clock can jump
    # while the module is running
    start = time.perf_counter()
    y = module(x)
    elapsed = time.perf_counter() - start
    # the output is printed after the clock is stopped, so formatting it
    # does not count as latency
    print(y)
    return elapsed
221 |
222 |
def get_yw(modules, module):
    """
    find the fastest other module that is at least as accurate as ``module``

    :param modules: 2-D array with one row per module; column 1 is compared
        as accuracy and column 2 as latency
    :param module: one row of modules; one occurrence of it is left out of
        the search
    :return: the matching row as a list (the first one on equal latency),
        or [] when no other row reaches the accuracy of module
    :raises ValueError: when module is not a row of modules
    """
    module_acc = module[1]
    # Work on plain lists: the rows may hold nested configs of different
    # lengths, which cannot be packed back into a regular ndarray.
    others = modules.tolist()
    others.remove(module.tolist())
    better_modules = [row for row in others if row[1] >= module_acc]
    if not better_modules:
        return []
    return min(better_modules, key=lambda row: row[2])
239 |
240 |
241 | # a = np.array([[[[1, 2, 3], [12, 3], [1, 2]], 0.2, 0.3], [[[1, 2, 3], [12, 3], [1, 2]], 0.3, 0.3]])
242 | # # a = np.array([[[[1, 2, 3], [12, 3], [1, 2]], 0.2, 0.3]])
243 | # get_yw(a, a[1])
244 | # # for i in a:
245 | # # print(i)
246 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | '''Train CIFAR10 with PyTorch.'''
2 | from __future__ import print_function
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.optim as optim
7 | import torch.nn.functional as F
8 | import torch.backends.cudnn as cudnn
9 |
10 | import transforms as transforms
11 | from dataloader import lunanod
12 | import os
13 | import argparse
14 | import time
15 | from models.cnn_res import *
16 | # from utils import progress_bar
17 | from torch.autograd import Variable
18 | import logging
19 | import numpy as np
20 | import ast
21 |
22 | parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
23 | parser.add_argument('--lr', default=0.0002, type=float, help='learning rate')
24 | parser.add_argument('--batch_size', default=1, type=int, help='batch size ')
25 | parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
26 | parser.add_argument('--savemodel', type=str, default='', help='resume from checkpoint model')
27 | parser.add_argument("--gpuids", type=str, default='all', help='use which gpu')
28 |
29 | parser.add_argument('--num_epochs', type=int, default=700)
30 | parser.add_argument('--num_epochs_decay', type=int, default=70)
31 |
32 | parser.add_argument('--num_workers', type=int, default=24)
33 |
34 | parser.add_argument('--beta1', type=float, default=0.5) # momentum1 in Adam
35 | parser.add_argument('--beta2', type=float, default=0.999) # momentum2 in Adam
36 | parser.add_argument('--lamb', type=float, default=1, help="lambda for loss2")
37 | parser.add_argument('--fold', type=int, default=5, help="fold")
38 |
39 | args = parser.parse_args()
40 |
41 | CROPSIZE = 32
42 | gbtdepth = 1
43 | fold = args.fold
44 | blklst = [] # ['1.3.6.1.4.1.14519.5.2.1.6279.6001.121993590721161347818774929286-388', \
45 | # '1.3.6.1.4.1.14519.5.2.1.6279.6001.121993590721161347818774929286-389', \
46 | # '1.3.6.1.4.1.14519.5.2.1.6279.6001.132817748896065918417924920957-660']
47 | logging.basicConfig(filename='log-' + str(fold), level=logging.INFO)
48 |
49 | use_cuda = torch.cuda.is_available()
50 | best_acc = 0 # best test accuracy
51 | best_acc_gbt = 0
52 | start_epoch = 0 # start from epoch 0 or last checkpoint epoch
53 | # Cal mean std
54 | # preprocesspath = '/media/data1/wentao/tianchi/luna16/cls/crop_v3/'
55 | preprocesspath = '/data/xxx/LUNA/cls/crop_v3/'
56 | # preprocesspath = '/media/jehovah/Work/data/LUNA/cls/crop_v3/'
57 | pixvlu, npix = 0, 0
58 | for fname in os.listdir(preprocesspath):
59 | # print(fname)
60 | if fname.endswith('.npy'):
61 | if fname[:-4] in blklst: continue
62 | data = np.load(os.path.join(preprocesspath, fname))
63 | pixvlu += np.sum(data)
64 | # print("data.shape = " + str(data.shape))
65 | npix += np.prod(data.shape)
66 | pixmean = pixvlu / float(npix)
67 | pixvlu = 0
68 | for fname in os.listdir(preprocesspath):
69 | if fname.endswith('.npy'):
70 | if fname[:-4] in blklst: continue
71 | data = np.load(os.path.join(preprocesspath, fname)) - pixmean
72 | pixvlu += np.sum(data * data)
73 | pixstd = np.sqrt(pixvlu / float(npix))
74 | # pixstd /= 255
75 | print(pixmean, pixstd)
76 | logging.info('mean ' + str(pixmean) + ' std ' + str(pixstd))
77 | # Datatransforms
78 | logging.info('==> Preparing data..') # Random Crop, Zero out, x z flip, scale,
79 | transform_train = transforms.Compose([
80 | # transforms.RandomScale(range(28, 38)),
81 | transforms.RandomCrop(32, padding=4),
82 | transforms.RandomHorizontalFlip(),
83 | transforms.RandomYFlip(),
84 | transforms.RandomZFlip(),
85 | transforms.ZeroOut(4),
86 | transforms.ToTensor(),
87 | transforms.Normalize((pixmean), (pixstd)), # need to cal mean and std, revise norm func
88 | ])
89 |
90 | transform_test = transforms.Compose([
91 | transforms.ToTensor(),
92 | transforms.Normalize((pixmean), (pixstd)),
93 | ])
94 |
# load data list
# Parallel per-sample lists for the training (tr*) and test (te*) splits:
# crop file names, integer labels and per-nodule feature volumes.
trfnamelst = []
trlabellst = []
trfeatlst = []
tefnamelst = []
telabellst = []
tefeatlst = []
import pandas as pd

# Column names are passed explicitly, so pandas reads every row of the file as data.
dataframe = pd.read_csv('./data/annotationdetclsconvfnl_v3.csv',
                        names=['seriesuid', 'coordX', 'coordY', 'coordZ', 'diameter_mm', 'malignant'])

# [1:] skips the first row of every column (presumably the CSV's own header line - verify).
alllst = dataframe['seriesuid'].tolist()[1:]
labellst = dataframe['malignant'].tolist()[1:]
crdxlst = dataframe['coordX'].tolist()[1:]
crdylst = dataframe['coordY'].tolist()[1:]
crdzlst = dataframe['coordZ'].tolist()[1:]
dimlst = dataframe['diameter_mm'].tolist()[1:]
# test id
# The stem of every '.mhd' file in the fold's subset directory marks a held-out scan.
teidlst = []
for fname in os.listdir('/data/xxx/LUNA/rowfile/subset' + str(fold) + '/'):
    # for fname in os.listdir('/media/jehovah/Work/data/LUNA/rowfile/subset' + str(fold) + '/'):

    if fname.endswith('.mhd'):
        teidlst.append(fname[:-4])
# Running maxima of |coordinate| and |diameter| over all annotation rows
# (blacklisted rows included, because the update happens before the skip).
mxx = mxy = mxz = mxd = 0
for srsid, label, x, y, z, d in zip(alllst, labellst, crdxlst, crdylst, crdzlst, dimlst):
    mxx = max(abs(float(x)), mxx)
    mxy = max(abs(float(y)), mxy)
    mxz = max(abs(float(z)), mxz)
    mxd = max(abs(float(d)), mxd)
    if srsid in blklst: continue
    # crop raw pixel as feature
    data = np.load(os.path.join(preprocesspath, srsid + '.npy'))
    # Start indices of a centred CROPSIZE^3 window.
    bgx = int(data.shape[0] / 2 - CROPSIZE / 2)
    bgy = int(data.shape[1] / 2 - CROPSIZE / 2)
    bgz = int(data.shape[2] / 2 - CROPSIZE / 2)
    # NOTE(review): this centre crop is not read again in the visible code.
    data = np.array(data[bgx:bgx + CROPSIZE, bgy:bgy + CROPSIZE, bgz:bgz + CROPSIZE])
    # feat = np.hstack((np.reshape(data, (-1,)) / 255, float(d)))
    # Rebinds y, x, z (the CSV coordinates were already folded into mxx/mxy/mxz)
    # to open coordinate grids centred on the crop.
    y, x, z = np.ogrid[-CROPSIZE / 2:CROPSIZE / 2, -CROPSIZE / 2:CROPSIZE / 2, -CROPSIZE / 2:CROPSIZE / 2]
    # Binary feature volume: 1 where the test below holds, 0 elsewhere.
    # NOTE(review): the test sums cubes, not squares, so the region is not a
    # Euclidean ball of radius d - confirm this is intended.
    mask = abs(y ** 3 + x ** 3 + z ** 3) <= abs(float(d)) ** 3
    feat = np.zeros((CROPSIZE, CROPSIZE, CROPSIZE), dtype=float)
    feat[mask] = 1
    # print(feat.shape)
    # A sample belongs to the test split when the part of its id before the
    # first '-' matches one of the fold's scan ids.
    if srsid.split('-')[0] in teidlst:
        tefnamelst.append(srsid + '.npy')
        telabellst.append(int(label))
        tefeatlst.append(feat)
    else:
        trfnamelst.append(srsid + '.npy')
        trlabellst.append(int(label))
        trfeatlst.append(feat)
# Divides the last slice along the first axis of every feature volume by the
# largest diameter.
# NOTE(review): looks like a leftover from the commented-out hstack feature,
# whose last entry was the diameter - confirm it is still wanted for the mask.
for idx in range(len(trfeatlst)):
    # trfeatlst[idx][0] /= mxx
    # trfeatlst[idx][1] /= mxy
    # trfeatlst[idx][2] /= mxz
    trfeatlst[idx][-1] /= mxd
for idx in range(len(tefeatlst)):
    # tefeatlst[idx][0] /= mxx
    # tefeatlst[idx][1] /= mxy
    # tefeatlst[idx][2] /= mxz
    tefeatlst[idx][-1] /= mxd
# Options shared by both loaders; only the shuffling differs between them.
_loader_opts = {'batch_size': args.batch_size, 'num_workers': 20}

# Training split: augmented samples, reshuffled every epoch.
trainset = lunanod(preprocesspath, trfnamelst, trlabellst, trfeatlst,
                   train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, **_loader_opts)

# Held-out split: deterministic order, no augmentation.
testset = lunanod(preprocesspath, tefnamelst, telabellst, tefeatlst,
                  train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, shuffle=False, **_loader_opts)

# Per-fold checkpoint directory.
savemodelpath = ''.join(('./checkpoint-', str(fold), '/'))
# Model
print(args.resume)
if args.resume:
    # Resume a complete training state: pickled model object, best accuracy
    # so far and the epoch it was reached at.
    print('==> Resuming from checkpoint..')
    print(args.savemodel)
    logging.info('==> Resuming from checkpoint..')
    # Without --savemodel fall back to this fold's best checkpoint.
    ckpt_path = args.savemodel if args.savemodel != '' else savemodelpath + 'ckpt.t7'
    # Validate the file that is actually loaded. The previous assert checked
    # the fold directory even when an explicit path was given, and asserts
    # are stripped under ``python -O``.
    if not os.path.isfile(ckpt_path):
        raise FileNotFoundError('Error: no checkpoint found at ' + ckpt_path)
    # NOTE: torch.load unpickles arbitrary objects - only load trusted files.
    checkpoint = torch.load(ckpt_path)
    net = checkpoint['net']
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']
    print(ckpt_path + " load success")
    print(start_epoch)
else:
    logging.info('==> Building model..')
    logging.info('args.savemodel : ' + args.savemodel)
    net = ConvRes([[64, 64, 64], [128, 128, 256], [256, 256, 256, 512]])
    if args.savemodel != "":
        # Fine-tuning: copy every tensor whose key also exists in the freshly
        # built network and keep the new initialisation for all other keys.
        # args.savemodel = '/home/xxx/DeepLung-master/nodcls/checkpoint-5/ckpt.t7'
        # NOTE: torch.load unpickles arbitrary objects - only load trusted files.
        checkpoint = torch.load(args.savemodel)
        finenet = checkpoint
        Low_rankmodel_dic = net.state_dict()
        finenet = {k: v for k, v in finenet.items() if k in Low_rankmodel_dic}
        # Make a key mismatch visible instead of silently loading nothing.
        logging.info('pretrained tensors matched: ' + str(len(finenet)) + '/' + str(len(Low_rankmodel_dic)))
        Low_rankmodel_dic.update(finenet)
        net.load_state_dict(Low_rankmodel_dic)
        print("net_loaded")

# Current learning rate; get_lr() decays it during the final epochs.
lr = args.lr
199 |
200 |
def get_lr(epoch):
    """Decay the global learning rate during the final epochs.

    Once ``epoch + 1`` exceeds ``args.num_epochs - args.num_epochs_decay`` the
    module-level ``lr`` is reduced by ``lr / args.num_epochs_decay`` and the
    new value is written to every parameter group of ``optimizer``. Earlier
    epochs leave everything untouched.

    Args:
        epoch: Index of the epoch about to be trained.
    """
    global lr
    last_constant_epoch = args.num_epochs - args.num_epochs_decay
    if epoch + 1 <= last_constant_epoch:
        return
    lr = lr - lr / float(args.num_epochs_decay)
    for group in optimizer.param_groups:
        group['lr'] = lr
    print('Decay learning rate to lr: {}.'.format(lr))
208 |
209 |
if use_cuda:
    net.cuda()
    # DataParallel indexes device_ids (device_ids[0]), so it needs a real
    # list: on Python 3 ``map`` returns a one-shot, non-subscriptable iterator.
    if args.gpuids == 'all':
        device_ids = list(range(torch.cuda.device_count()))
    else:
        # Every digit character of --gpuids is one GPU index ("0,2" -> [0, 2]);
        # multi-digit indices are therefore not supported by this format.
        device_ids = [int(ch) for ch in args.gpuids if ch.isdigit()]

    print('gpu use' + str(device_ids))
    net = torch.nn.DataParallel(net, device_ids=device_ids)
    cudnn.benchmark = False  # True

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
223 |
224 |
225 | # L2Loss = torch.nn.MSELoss()
226 |
227 | # Training
def train(epoch):
    """Train ``net`` for one epoch on ``trainloader`` and report the accuracy.

    Applies the learning-rate schedule via ``get_lr``, runs one optimiser
    step per batch and finally prints and logs the epoch's training accuracy
    together with the current learning rate.

    Args:
        epoch: Index of the epoch being trained (used for logging and for the
            learning-rate schedule).
    """
    logging.info('\nEpoch: ' + str(epoch))
    net.train()
    get_lr(epoch)
    running_loss = 0
    n_correct = 0
    n_seen = 0

    # The third element of every batch (the feature volume) is not used here.
    for inputs, targets, _feat in trainloader:
        if use_cuda:
            inputs = inputs.cuda()
            targets = targets.cuda()

        optimizer.zero_grad()
        inputs, targets = Variable(inputs), Variable(targets)
        logits = net(inputs)
        loss = criterion(logits, targets)

        loss.backward()
        optimizer.step()
        running_loss += loss.data.item()
        predicted = torch.max(logits.data, 1)[1]
        n_seen += targets.size(0)
        n_correct += predicted.eq(targets.data).cpu().sum()
        # progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'

    summary = 'ep ' + str(epoch) + ' tracc ' + str(n_correct.data.item() / float(n_seen)) + ' lr ' + str(lr)
    print(summary)

    logging.info(summary)
257 |
258 |
def test(epoch):
    """Evaluate ``net`` on ``testloader`` and write checkpoints.

    Computes accuracy, true-positive rate and false-positive rate (label 1 is
    the positive class), saves ``ckpt.t7`` whenever the accuracy beats
    ``best_acc`` and an additional ``ckpt<epoch>.t7`` every 50 epochs.
    Rates whose denominator is zero are reported as NaN instead of raising.

    Args:
        epoch: Index of the epoch that was just trained; stored in the
            checkpoint and used in the periodic snapshot's file name.
    """
    epoch_start_time = time.time()
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    TP = FP = FN = TN = 0

    def _percent(numerator, denominator):
        # Undefined ratios (no samples / no positives / no negatives) -> NaN.
        return 100. * numerator / denominator if denominator else float('nan')

    # Evaluation needs no autograd graph; no_grad() avoids building one.
    with torch.no_grad():
        for batch_idx, (inputs, targets, feat) in enumerate(testloader):
            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda()

            outputs = net(inputs)

            loss = criterion(outputs, targets)
            test_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += targets.size(0)
            # Accumulate plain Python ints so the counters stay usable even
            # when the loader yields no batches.
            correct += predicted.eq(targets).sum().item()
            TP += ((predicted == 1) & (targets == 1)).sum().item()
            TN += ((predicted == 0) & (targets == 0)).sum().item()
            FN += ((predicted == 0) & (targets == 1)).sum().item()
            FP += ((predicted == 1) & (targets == 0)).sum().item()

    # Save checkpoint.
    acc = _percent(correct, total)
    # The whole module is pickled; unwrap DataParallel when it was applied.
    state = {
        'net': net.module if use_cuda else net,
        'acc': acc,
        'epoch': epoch,
    }
    if not os.path.isdir(savemodelpath):
        os.mkdir(savemodelpath)
    if acc > best_acc:
        logging.info('Saving..')
        torch.save(state, savemodelpath + 'ckpt.t7')
        best_acc = acc
    if epoch % 50 == 0:
        # Periodic snapshot, independent of whether this epoch was the best.
        logging.info('Saving..')
        torch.save(state, savemodelpath + 'ckpt' + str(epoch) + '.t7')
    tpr = _percent(TP, TP + FN)
    fpr = _percent(FP, FP + TN)

    print('teacc ' + str(acc) + ' bestacc ' + str(best_acc))
    print('tpr ' + str(tpr) + ' fpr ' + str(fpr))
    print('Time Taken: %d sec' % (time.time() - epoch_start_time))
    logging.info(
        'teacc ' + str(acc) + ' bestacc ' + str(best_acc))
    logging.info(
        'tpr ' + str(tpr) + ' fpr ' + str(fpr))
319 |
320 |
if __name__ == '__main__':
    # Epochs are numbered from start_epoch + 1, so a resumed run keeps
    # counting where the checkpoint stopped; args.num_epochs epochs are run.
    _first_epoch = start_epoch + 1
    for epoch in range(_first_epoch, _first_epoch + args.num_epochs):  # 200):
        train(epoch)
        test(epoch)
325 |
--------------------------------------------------------------------------------
/models/net_sphere.py:
--------------------------------------------------------------------------------
1 | '''
2 | https://github.com/clcarwin/sphereface_pytorch/blob/master/net_sphere.py
3 | '''
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torch.autograd import Variable
8 | import torch.nn.functional as F
9 | from torch.nn import Parameter
10 | import dill
11 | import math
12 |
13 |
def myphi(x, m):
    """Approximate ``cos(m * x)`` with a truncated Taylor series.

    The series is evaluated up to and including the ``x**10`` term. The
    previous version ended with ``x**9 / 9!``, which is not a term of the
    cosine series (it has even powers only) and made the result asymmetric
    in ``x``.

    Args:
        x: Angle(s) in radians; a number or anything supporting ``*`` and ``**``.
        m: Multiplier applied to the angle.

    Returns:
        The polynomial approximation of ``cos(m * x)``, same type as ``x``.
    """
    x = x * m
    return (1 - x ** 2 / math.factorial(2) + x ** 4 / math.factorial(4)
            - x ** 6 / math.factorial(6) + x ** 8 / math.factorial(8)
            - x ** 10 / math.factorial(10))
18 |
19 |
20 | import math
21 | import torch
22 | from torch import nn
23 | # from scipy.special import binom
24 | import scipy.special as special
25 |
class LSoftmaxLinear(nn.Module):
    """Final linear layer with the large-margin softmax (L-Softmax) logit.

    In training mode the logit of the target class is replaced by a blend of
    the plain logit and the margin logit
    ``||w|| * ||x|| * ((-1)**k * cos(m * theta) - 2 * k)``; the blend weight
    ``beta`` starts at 100 and is multiplied by ``scale`` on every training
    forward pass. In evaluation mode the layer is a plain bias-free linear map.

    Args:
        input_features: Size of each input sample.
        output_features: Number of classes.
        margin: Integer margin ``m``.
        device: Preferred device for the constant tables. When a CUDA device
            is requested but CUDA is unavailable the CPU is used instead.
    """

    def __init__(self, input_features, output_features, margin=4, device='cuda'):
        super().__init__()
        self.input_dim = input_features  # number of input feature i.e. output of the last fc layer
        self.output_dim = output_features  # number of output = class numbers
        self.margin = margin  # m
        self.beta = 100
        self.beta_min = 0
        self.scale = 0.99

        # Honour the requested device, but never ask for CUDA on a host that
        # has none (the 'cuda' default used to crash on CPU-only machines).
        resolved = torch.device(device)
        if resolved.type == 'cuda' and not torch.cuda.is_available():
            resolved = torch.device('cpu')
        self.device = resolved

        # Initialize L-Softmax parameters
        self.weight = nn.Parameter(torch.FloatTensor(input_features, output_features))
        self.divisor = math.pi / self.margin  # pi/m
        even_terms = list(range(0, margin + 1, 2))
        self.C_m_2n = torch.Tensor(special.binom(margin, even_terms)).to(resolved)  # C_m{2n}
        self.cos_powers = torch.Tensor(list(range(self.margin, -1, -2))).to(resolved)  # m - 2n
        self.sin2_powers = torch.Tensor(list(range(len(self.cos_powers)))).to(resolved)  # n
        self.signs = torch.ones(margin // 2 + 1).to(resolved)  # 1, -1, 1, -1, ...
        self.signs[1::2] = -1

        # torch.FloatTensor(rows, cols) is uninitialised memory; give the
        # weight a defined starting value right away.
        self.reset_parameters()

    def calculate_cos_m_theta(self, cos_theta):
        """Return ``cos(m * theta)`` for a 1-D tensor of ``cos(theta)`` values.

        Uses the expansion
        ``sum_n (-1)**n * C(m, 2n) * cos**(m - 2n) * (sin**2)**n``.
        """
        # The constant tables are plain attributes (not buffers), so align
        # them with the input in case the module was moved after creation.
        dev = cos_theta.device
        sin2_theta = 1 - cos_theta ** 2
        cos_terms = cos_theta.unsqueeze(1) ** self.cos_powers.to(dev).unsqueeze(0)  # cos^{m - 2n}
        sin2_terms = sin2_theta.unsqueeze(1) ** self.sin2_powers.to(dev).unsqueeze(0)  # sin2^{n}

        cos_m_theta = (self.signs.to(dev).unsqueeze(0) *  # -1^{n} * C_m{2n} * cos^{m - 2n} * sin2^{n}
                       self.C_m_2n.to(dev).unsqueeze(0) *
                       cos_terms *
                       sin2_terms).sum(1)  # summation of all terms

        return cos_m_theta

    def reset_parameters(self):
        """Re-initialise the weight in place with Kaiming-normal values."""
        # The transposed view shares storage with the weight, so the in-place
        # initialisation fills the parameter itself.
        nn.init.kaiming_normal_(self.weight.data.t())

    def find_k(self, cos):
        """Return ``floor(theta / (pi / m))`` for each cosine, detached from autograd."""
        # to account for acos numerical errors
        eps = 1e-7
        cos = torch.clamp(cos, -1 + eps, 1 - eps)
        acos = cos.acos()
        k = (acos / self.divisor).floor().detach()
        return k

    def forward(self, input, target=None):
        """Compute the class logits.

        Args:
            input: Tensor of shape (batch, input_features).
            target: Class indices of shape (batch,); required in training
                mode and must be ``None`` in evaluation mode.

        Returns:
            Logits of shape (batch, output_features).
        """
        if self.training:
            assert target is not None
            x, w = input, self.weight
            beta = max(self.beta, self.beta_min)
            logit = x.mm(w)
            rows = torch.arange(logit.size(0), device=logit.device)
            logit_target = logit[rows, target]

            # cos(theta) = w * x / ||w||*||x||
            w_target_norm = w[:, target].norm(p=2, dim=0)
            x_norm = x.norm(p=2, dim=1)
            cos_theta_target = logit_target / (w_target_norm * x_norm + 1e-10)

            # equation 7
            cos_m_theta_target = self.calculate_cos_m_theta(cos_theta_target)

            # find k in equation 6
            k = self.find_k(cos_theta_target)

            # f_y_i
            logit_target_updated = (w_target_norm *
                                    x_norm *
                                    (((-1) ** k * cos_m_theta_target) - 2 * k))
            # Blend margin logit and plain logit; beta decays every call.
            logit_target_updated_beta = (logit_target_updated + beta * logit_target) / (1 + beta)

            logit[rows, target] = logit_target_updated_beta
            self.beta *= self.scale
            return logit
        else:
            assert target is None
            return input.mm(self.weight)
107 |
108 |
class AngleLinear(nn.Module):
    """Angular-margin linear layer (A-Softmax / SphereFace head).

    ``forward`` returns the pair ``(cos_theta, phi_theta)``, both of shape
    (batch, out_features) and both scaled by the L2 norm of the input row.
    ``phi_theta`` is the margin version of the cosine:
    ``(-1)**k * cos(m * theta) - 2 * k`` with ``k = floor(m * theta / pi)``
    when ``phiflag`` is true, otherwise the clamped ``myphi`` approximation.

    Args:
        in_features: Size of each input sample.
        out_features: Number of classes.
        m: Integer angular margin.
        phiflag: Select the exact piecewise formula (True) or ``myphi`` (False).
    """

    def __init__(self, in_features, out_features, m=4, phiflag=True):
        super(AngleLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(in_features, out_features))
        # Uniform init, then rescale the columns via renorm (max norm 1e-5)
        # followed by a multiplication with 1e5.
        self.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)
        self.phiflag = phiflag
        self.m = m

    @staticmethod
    def _cos_m_theta(cos_theta, m):
        """Return ``cos(m * theta)`` as the Chebyshev polynomial T_m(cos_theta).

        Implemented as a method rather than a table of lambdas stored on the
        instance, so the module stays picklable. ``m == 4`` keeps the closed
        form that used to be hard-coded, so its results are unchanged.
        """
        if m == 4:
            return 8 * cos_theta ** 4 - 8 * cos_theta ** 2 + 1
        previous, current = torch.ones_like(cos_theta), cos_theta
        if m == 0:
            return previous
        # T_{n+1}(x) = 2 x T_n(x) - T_{n-1}(x)
        for _ in range(m - 1):
            previous, current = current, 2 * cos_theta * current - previous
        return current

    def forward(self, input):
        """Return ``(cos_theta * ||x||, phi_theta * ||x||)`` for ``input`` of shape (B, F)."""
        x = input  # size=(B,F)    F is feature len
        w = self.weight  # size=(F,Classnum) F=in_features Classnum=out_features
        ww = w.renorm(2, 1, 1e-5).mul(1e5)
        xlen = x.pow(2).sum(1).pow(0.5)  # size=B
        wlen = ww.pow(2).sum(0).pow(0.5)  # size=Classnum

        cos_theta = x.mm(ww)  # size=(B,Classnum)
        cos_theta = cos_theta / xlen.view(-1, 1) / wlen.view(1, -1)
        cos_theta = cos_theta.clamp(-1, 1)

        if self.phiflag:
            # Use the polynomial matching self.m; it was previously fixed to
            # m = 4 while k below already depended on self.m.
            cos_m_theta = self._cos_m_theta(cos_theta, self.m)
            # k is piecewise constant, so theta is taken without gradient.
            theta = cos_theta.detach().acos()
            k = (self.m * theta / 3.14159265).floor()
            n_one = k * 0.0 - 1
            phi_theta = (n_one ** k) * cos_m_theta - 2 * k
        else:
            theta = cos_theta.acos()
            phi_theta = myphi(theta, self.m)
            phi_theta = phi_theta.clamp(-1 * self.m, 1)

        cos_theta = cos_theta * xlen.view(-1, 1)
        phi_theta = phi_theta * xlen.view(-1, 1)
        output = (cos_theta, phi_theta)
        return output  # tuple of two (B,Classnum) tensors
154 |
155 |
class AngleLoss(nn.Module):
    """A-Softmax loss for the ``(cos_theta, phi_theta)`` pair of ``AngleLinear``.

    For the target class the logit is moved from ``cos_theta`` towards
    ``phi_theta`` by the factor ``1 / (1 + lamb)``; all other logits stay at
    ``cos_theta``. ``lamb`` follows
    ``max(LambdaMin, LambdaMax / (1 + 0.1 * it))`` where ``it`` counts the
    forward passes made in training mode. The result is a focal-weighted
    negative log-likelihood (plain cross-entropy for ``gamma == 0``).

    Previously the target mask was left all-zero and ``it`` never advanced,
    so the margin term was never applied; both are restored here.

    Args:
        gamma: Focal-loss exponent applied to ``(1 - p_target)``.
    """

    def __init__(self, gamma=0):
        super(AngleLoss, self).__init__()
        self.gamma = gamma
        self.it = 1
        self.LambdaMin = 5.0
        self.LambdaMax = 1500.0
        self.lamb = 1500.0

    def forward(self, input, target):
        """Return the mean loss.

        Args:
            input: Tuple ``(cos_theta, phi_theta)``, each of shape (B, Classnum).
            target: Class indices with B elements.
        """
        cos_theta, phi_theta = input
        target = target.view(-1, 1).long()  # size=(B,1)

        # One-hot mask of the target class, size=(B,Classnum).
        onehot = torch.zeros_like(cos_theta).scatter_(1, target, 1.0)

        self.lamb = max(self.LambdaMin, self.LambdaMax / (1 + 0.1 * self.it))
        if self.training:
            # Advance the annealing schedule once per training-mode call.
            self.it += 1

        # Out-of-place blend keeps autograd simple: only the target entries
        # move from cos_theta towards phi_theta.
        output = cos_theta + onehot * (phi_theta - cos_theta) * (1.0 + 0) / (1 + self.lamb)

        logpt = F.log_softmax(output, dim=1)
        logpt = logpt.gather(1, target)
        logpt = logpt.view(-1)
        # The focal weight is treated as a constant during back-propagation.
        pt = logpt.detach().exp()

        loss = -1 * (1 - pt) ** self.gamma * logpt
        loss = loss.mean()
        return loss
189 |
190 |
191 | # class STN(nn.Module):
192 | # def __init__(self ):
193 | # super(STN, self).__init__()
194 | # self.localization = nn.Sequential(
195 | # nn.Conv2d(3, 8, kernel_size=7),
196 | # nn.MaxPool2d(2, stride=2),
197 | # nn.ReLU(True),
198 | # nn.Conv2d(8, 10, kernel_size=5),
199 | # nn.MaxPool2d(2, stride=2),
200 | # nn.ReLU(True)
201 | # )
202 | # self.fc_loc = nn.Sequential(
203 | # nn.Linear(10*24*20, 32),
204 | # nn.ReLU(True),
205 | # nn.Linear(32, 3 * 2)
206 | # )
207 | #
208 | # # Initialize the weights/bias with identity transformation
209 | # # self.fc_loc[2].weight.data.zero_()
210 | # # self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))
211 | #
212 | # def forward(self, x):
213 | # xs = self.localization(x)
214 | # xs = xs.view(-1, 10*24*20)
215 | # theta = self.fc_loc(xs)
216 | # theta = theta.view(-1, 2, 3)
217 | #
218 | # grid = F.affine_grid(theta, x.size())
219 | # x = F.grid_sample(x, grid)
220 | #
221 | # return x
222 |
223 |
class sphere20a(nn.Module):
    """20-layer SphereFace CNN: four residual conv stages, ``fc5`` and an ``AngleLinear`` head.

    Each stage starts with a stride-2 3x3 convolution followed by residual
    units made of two stride-1 3x3 convolutions, every convolution being
    followed by a channel-wise PReLU. ``fc5`` expects 512 * 7 * 6 flattened
    features (B*3*112*96 inputs according to the original comments).

    Args:
        classnum: Number of classes of the ``AngleLinear`` head ``fc6``.
        feature: When true, ``forward`` returns the 512-d ``fc5`` output.
    """

    # (stage number, input channels, output channels, convolutions in the stage)
    _STAGES = ((1, 3, 64, 3), (2, 64, 128, 5), (3, 128, 256, 9), (4, 256, 512, 3))

    def __init__(self, classnum=10574, feature=False):
        # classnum = dataloader.dataset.class_num = 227
        super(sphere20a, self).__init__()
        self.classnum = classnum
        self.feature = feature
        # Register conv<s>_<i> / relu<s>_<i> in the same order and under the
        # same names as a hand-written layer list would.
        for stage, in_ch, out_ch, depth in self._STAGES:
            for idx in range(1, depth + 1):
                opens_stage = idx == 1
                conv = nn.Conv2d(in_ch if opens_stage else out_ch, out_ch, 3,
                                 2 if opens_stage else 1, 1)
                setattr(self, 'conv%d_%d' % (stage, idx), conv)
                setattr(self, 'relu%d_%d' % (stage, idx), nn.PReLU(out_ch))

        self.fc5 = nn.Linear(512 * 7 * 6, 512)
        self.fc6 = AngleLinear(512, self.classnum)
        # self.stn = STN()

    def _unit(self, x, stage, idx):
        # One convolution followed by its PReLU.
        conv = getattr(self, 'conv%d_%d' % (stage, idx))
        act = getattr(self, 'relu%d_%d' % (stage, idx))
        return act(conv(x))

    def forward(self, x, target=None):
        """Return the ``fc5`` features or the ``fc6`` output; ``target`` is unused."""
        # x = self.stn(x)
        for stage, _in_ch, _out_ch, depth in self._STAGES:
            # Down-sampling convolution that opens the stage ...
            x = self._unit(x, stage, 1)
            # ... followed by the residual units (two convolutions each).
            for idx in range(2, depth, 2):
                x = x + self._unit(self._unit(x, stage, idx), stage, idx + 1)

        x = x.view(x.size(0), -1)
        x = self.fc5(x)
        if self.feature:  # self.feature=False
            return x

        return self.fc6(x)
309 |
310 |
class testsp(nn.Module):
    """Minimal module wrapping a single ``AngleLinear(100, 2)`` layer."""

    def __init__(self):
        super().__init__()
        self.linear = AngleLinear(100, 2)

    def forward(self, x):
        """Apply the wrapped layer to ``x`` and return its output unchanged."""
        return self.linear(x)
319 |
def test():
    """Smoke test: one forward/backward pass of ``testsp`` with ``AngleLoss``."""
    net = testsp()
    x = torch.randn(1, 100)
    # A valid class index for the single sample. The previous
    # ``torch.Tensor(1)`` was uninitialised float memory, so the label
    # (and the gather index derived from it) was arbitrary.
    tar = torch.zeros(1, dtype=torch.long)
    out = net(x)
    cre = AngleLoss()
    loss = cre(out, tar)
    loss.backward()
328 |
--------------------------------------------------------------------------------
/mainsp.py:
--------------------------------------------------------------------------------
1 | '''Train CIFAR10 with PyTorch.'''
2 | from __future__ import print_function
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.optim as optim
7 | import torch.nn.functional as F
8 | import torch.backends.cudnn as cudnn
9 | import models.net_sphere as sp_net
10 | import transforms as transforms
11 | from dataloader import lunanod
12 | import os
13 | import argparse
14 | import time
15 | from models.cnn_res import *
16 | # from utils import progress_bar
17 | from torch.autograd import Variable
18 | import logging
19 | import numpy as np
20 |
# Command-line interface, declared as a table of (flags, add_argument options).
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
for _flags, _options in (
    (('--lr',), dict(default=0.0002, type=float, help='learning rate')),
    (('--batch_size',), dict(default=1, type=int, help='batch size ')),
    (('--resume', '-r'), dict(action='store_true', help='resume from checkpoint')),
    (('--savemodel',), dict(type=str, default='', help='resume from checkpoint model')),
    (("--gpuids",), dict(type=str, default='all', help='use which gpu')),
    (('--num_epochs',), dict(type=int, default=700)),
    (('--num_epochs_decay',), dict(type=int, default=70)),
    (('--num_workers',), dict(type=int, default=24)),
    (('--beta1',), dict(type=float, default=0.5)),  # momentum1 in Adam
    (('--beta2',), dict(type=float, default=0.999)),  # momentum2 in Adam
    (('--lamb',), dict(type=float, default=1, help="lambda for loss2")),
    (('--fold',), dict(type=int, default=5, help="fold")),
):
    parser.add_argument(*_flags, **_options)

args = parser.parse_args()

# Run-wide constants and mutable training state.
CROPSIZE = 32
gbtdepth = 1
fold = args.fold
# Series ids to exclude from statistics and splits (currently none), e.g.
# ['1.3.6.1.4.1.14519.5.2.1.6279.6001.121993590721161347818774929286-388',
#  '1.3.6.1.4.1.14519.5.2.1.6279.6001.121993590721161347818774929286-389',
#  '1.3.6.1.4.1.14519.5.2.1.6279.6001.132817748896065918417924920957-660']
blklst = []
logging.basicConfig(filename='log-' + str(fold), level=logging.INFO)

use_cuda = torch.cuda.is_available()
best_acc, best_acc_gbt = 0, 0  # best test accuracy so far
start_epoch = 0  # start from epoch 0 or last checkpoint epoch
# Cal mean std
# Dataset-wide pixel statistics over every non-blacklisted '.npy' crop,
# computed in two passes (mean first, then squared deviations).
# Alternative locations used previously:
#   '/media/data1/wentao/tianchi/luna16/cls/crop_v3/'
#   '/media/jehovah/Work/data/LUNA/cls/crop_v3/'
preprocesspath = '/data/xxx/LUNA/cls/crop_v3/'

# Pass 1: total intensity and voxel count.
pixvlu, npix = 0, 0
for fname in os.listdir(preprocesspath):
    if not fname.endswith('.npy') or fname[:-4] in blklst:
        continue
    data = np.load(os.path.join(preprocesspath, fname))
    pixvlu += np.sum(data)
    npix += data.size
pixmean = pixvlu / float(npix)

# Pass 2: sum of squared deviations from the mean.
pixvlu = 0
for fname in os.listdir(preprocesspath):
    if not fname.endswith('.npy') or fname[:-4] in blklst:
        continue
    data = np.load(os.path.join(preprocesspath, fname)) - pixmean
    pixvlu += np.sum(np.square(data))
pixstd = np.sqrt(pixvlu / float(npix))
# pixstd /= 255
print(pixmean, pixstd)
logging.info(' '.join(('mean', str(pixmean), 'std', str(pixstd))))
# Data transforms
logging.info('==> Preparing data..')  # Random Crop, Zero out, x z flip, scale,

# Training pipeline: random augmentation first, then tensor conversion and
# normalisation with the dataset statistics computed above.
_train_steps = [
    # transforms.RandomScale(range(28, 38)),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomYFlip(),
    transforms.RandomZFlip(),
    transforms.ZeroOut(4),
    transforms.ToTensor(),
    transforms.Normalize((pixmean), (pixstd)),  # need to cal mean and std, revise norm func
]
transform_train = transforms.Compose(_train_steps)

# Evaluation pipeline: no augmentation, only conversion and normalisation.
_test_steps = [
    transforms.ToTensor(),
    transforms.Normalize((pixmean), (pixstd)),
]
transform_test = transforms.Compose(_test_steps)
93 |
# load data list
# Parallel per-sample lists for the training (tr*) and test (te*) splits:
# crop file names, integer labels and per-nodule feature volumes.
trfnamelst = []
trlabellst = []
trfeatlst = []
tefnamelst = []
telabellst = []
tefeatlst = []
import pandas as pd

# Column names are passed explicitly, so pandas reads every row of the file as data.
dataframe = pd.read_csv('./data/annotationdetclsconvfnl_v3.csv',
                        names=['seriesuid', 'coordX', 'coordY', 'coordZ', 'diameter_mm', 'malignant'])

# [1:] skips the first row of every column (presumably the CSV's own header line - verify).
alllst = dataframe['seriesuid'].tolist()[1:]
labellst = dataframe['malignant'].tolist()[1:]
crdxlst = dataframe['coordX'].tolist()[1:]
crdylst = dataframe['coordY'].tolist()[1:]
crdzlst = dataframe['coordZ'].tolist()[1:]
dimlst = dataframe['diameter_mm'].tolist()[1:]
# test id
# The stem of every '.mhd' file in the fold's subset directory marks a held-out scan.
teidlst = []
for fname in os.listdir('/data/xxx/LUNA/rowfile/subset' + str(fold) + '/'):
    # for fname in os.listdir('/media/jehovah/Work/data/LUNA/rowfile/subset' + str(fold) + '/'):

    if fname.endswith('.mhd'):
        teidlst.append(fname[:-4])
# Running maxima of |coordinate| and |diameter| over all annotation rows
# (blacklisted rows included, because the update happens before the skip).
mxx = mxy = mxz = mxd = 0
for srsid, label, x, y, z, d in zip(alllst, labellst, crdxlst, crdylst, crdzlst, dimlst):
    mxx = max(abs(float(x)), mxx)
    mxy = max(abs(float(y)), mxy)
    mxz = max(abs(float(z)), mxz)
    mxd = max(abs(float(d)), mxd)
    if srsid in blklst: continue
    # crop raw pixel as feature
    data = np.load(os.path.join(preprocesspath, srsid + '.npy'))
    # Start indices of a centred CROPSIZE^3 window.
    bgx = int(data.shape[0] / 2 - CROPSIZE / 2)
    bgy = int(data.shape[1] / 2 - CROPSIZE / 2)
    bgz = int(data.shape[2] / 2 - CROPSIZE / 2)
    # NOTE(review): this centre crop is not read again in the visible code.
    data = np.array(data[bgx:bgx + CROPSIZE, bgy:bgy + CROPSIZE, bgz:bgz + CROPSIZE])
    # feat = np.hstack((np.reshape(data, (-1,)) / 255, float(d)))
    # Rebinds y, x, z (the CSV coordinates were already folded into mxx/mxy/mxz)
    # to open coordinate grids centred on the crop.
    y, x, z = np.ogrid[-CROPSIZE / 2:CROPSIZE / 2, -CROPSIZE / 2:CROPSIZE / 2, -CROPSIZE / 2:CROPSIZE / 2]
    # Binary feature volume: 1 where the test below holds, 0 elsewhere.
    # NOTE(review): the test sums cubes, not squares, so the region is not a
    # Euclidean ball of radius d - confirm this is intended.
    mask = abs(y ** 3 + x ** 3 + z ** 3) <= abs(float(d)) ** 3
    feat = np.zeros((CROPSIZE, CROPSIZE, CROPSIZE), dtype=float)
    feat[mask] = 1
    # print(feat.shape)
    # A sample belongs to the test split when the part of its id before the
    # first '-' matches one of the fold's scan ids.
    if srsid.split('-')[0] in teidlst:
        tefnamelst.append(srsid + '.npy')
        telabellst.append(int(label))
        tefeatlst.append(feat)
    else:
        trfnamelst.append(srsid + '.npy')
        trlabellst.append(int(label))
        trfeatlst.append(feat)
# Divides the last slice along the first axis of every feature volume by the
# largest diameter.
# NOTE(review): looks like a leftover from the commented-out hstack feature,
# whose last entry was the diameter - confirm it is still wanted for the mask.
for idx in range(len(trfeatlst)):
    # trfeatlst[idx][0] /= mxx
    # trfeatlst[idx][1] /= mxy
    # trfeatlst[idx][2] /= mxz
    trfeatlst[idx][-1] /= mxd
for idx in range(len(tefeatlst)):
    # tefeatlst[idx][0] /= mxx
    # tefeatlst[idx][1] /= mxy
    # tefeatlst[idx][2] /= mxz
    tefeatlst[idx][-1] /= mxd
# Options shared by both loaders. The worker count now comes from
# --num_workers (default 24), which was parsed but ignored while both
# loaders hard-coded 20 workers.
_loader_opts = {'batch_size': args.batch_size, 'num_workers': args.num_workers}

# Training split: augmented samples, reshuffled every epoch.
trainset = lunanod(preprocesspath, trfnamelst, trlabellst, trfeatlst, train=True, download=True,
                   transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, **_loader_opts)

# Held-out split: deterministic order, no augmentation.
testset = lunanod(preprocesspath, tefnamelst, telabellst, tefeatlst, train=False, download=True,
                  transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, shuffle=False, **_loader_opts)

# Per-fold checkpoint directory.
savemodelpath = './checkpoint-' + str(fold) + '/'
# Per-epoch histories: training accuracy, and (acc, tpr, fpr) rows for the test split.
train_val = np.empty(shape=0)
test_val = np.empty(shape=(0, 3))
# Model
print(args.resume)
if args.resume:
    # Resume a complete training state: pickled model object, best accuracy
    # so far and the epoch it was reached at.
    print('==> Resuming from checkpoint..')
    print(args.savemodel)
    logging.info('==> Resuming from checkpoint..')
    # Without --savemodel fall back to this fold's best checkpoint.
    ckpt_path = args.savemodel if args.savemodel != '' else savemodelpath + 'ckpt.t7'
    # Validate the file that is actually loaded. The previous assert checked
    # the fold directory even when an explicit path was given, and asserts
    # are stripped under ``python -O``.
    if not os.path.isfile(ckpt_path):
        raise FileNotFoundError('Error: no checkpoint found at ' + ckpt_path)
    # NOTE: torch.load unpickles arbitrary objects - only load trusted files.
    checkpoint = torch.load(ckpt_path)
    net = checkpoint['net']
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']
    print(ckpt_path + " load success")
    print(start_epoch)
else:
    logging.info('==> Building model..')
    logging.info('args.savemodel : ' + args.savemodel)
    net = ConvRes([[64, 64, 64], [128, 128, 256], [256, 256, 256, 512]])
    if args.savemodel != "":
        # Fine-tuning: copy every tensor whose key also exists in the freshly
        # built network and keep the new initialisation for all other keys.
        # args.savemodel = '/home/xxx/DeepLung-master/nodcls/checkpoint-5/ckpt.t7'
        # NOTE: torch.load unpickles arbitrary objects - only load trusted files.
        checkpoint = torch.load(args.savemodel)
        finenet = checkpoint
        Low_rankmodel_dic = net.state_dict()
        finenet = {k: v for k, v in finenet.items() if k in Low_rankmodel_dic}
        # Make a key mismatch visible instead of silently loading nothing.
        logging.info('pretrained tensors matched: ' + str(len(finenet)) + '/' + str(len(Low_rankmodel_dic)))
        Low_rankmodel_dic.update(finenet)
        net.load_state_dict(Low_rankmodel_dic)
        print("net_loaded")

# Current learning rate; get_lr() decays it during the final epochs.
lr = args.lr
201 |
202 |
def get_lr(epoch):
    """Decay the global learning rate during the final epochs.

    Once ``epoch + 1`` exceeds ``args.num_epochs - args.num_epochs_decay`` the
    module-level ``lr`` is reduced by ``lr / args.num_epochs_decay`` and the
    new value is written to every parameter group of ``optimizer``. Earlier
    epochs leave everything untouched.

    Args:
        epoch: Index of the epoch about to be trained.
    """
    global lr
    last_constant_epoch = args.num_epochs - args.num_epochs_decay
    if epoch + 1 <= last_constant_epoch:
        return
    lr = lr - lr / float(args.num_epochs_decay)
    for group in optimizer.param_groups:
        group['lr'] = lr
    print('Decay learning rate to lr: {}.'.format(lr))
210 |
211 |
if use_cuda:
    net.cuda()
    # DataParallel indexes device_ids (device_ids[0]), so it needs a real
    # list: on Python 3 ``map`` returns a one-shot, non-subscriptable iterator.
    if args.gpuids == 'all':
        device_ids = list(range(torch.cuda.device_count()))
    else:
        # Every digit character of --gpuids is one GPU index ("0,2" -> [0, 2]);
        # multi-digit indices are therefore not supported by this format.
        device_ids = [int(ch) for ch in args.gpuids if ch.isdigit()]

    print('gpu use' + str(device_ids))
    net = torch.nn.DataParallel(net, device_ids=device_ids)
    cudnn.benchmark = False  # True

criterion = sp_net.AngleLoss()
optimizer = optim.Adam(net.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
# NOTE(review): the Adam optimizer above is replaced immediately, so training
# runs with SGD at lr=0.01; --beta1/--beta2 have no effect and --lr only
# reaches the optimizer through get_lr() once the decay phase starts - confirm
# which optimizer is intended.
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
226 |
227 | # L2Loss = torch.nn.MSELoss()
228 |
229 | # Training
def train(epoch):
    """Train ``net`` for one epoch on ``trainloader`` and record the accuracy.

    Applies the learning-rate schedule via ``get_lr``, runs one optimiser
    step per batch, prints/logs the epoch's training accuracy and appends it
    to the module-level ``train_val`` history.

    Args:
        epoch: Index of the epoch being trained (used for logging and for the
            learning-rate schedule).
    """
    global train_val
    logging.info('\nEpoch: ' + str(epoch))
    net.train()
    get_lr(epoch)
    train_loss = 0
    correct = 0
    total = 0

    for batch_idx, (inputs, targets, feat) in enumerate(trainloader):
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()

        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)

        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        # The network output is indexable; element 0 holds the scores used
        # for the class prediction.
        _, predicted = torch.max(outputs[0].data, 1)
        total += targets.size(0)
        # Plain Python ints keep the counter usable for an empty loader too.
        correct += predicted.eq(targets.data).cpu().sum().item()
        # progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'

    tracc = correct / float(total) if total else float('nan')
    summary = 'ep ' + str(epoch) + ' tracc ' + str(tracc) + ' lr ' + str(lr)
    print(summary)
    logging.info(summary)
    # np.append returns a new array; the result must be re-bound, otherwise
    # the history stays empty.
    train_val = np.append(train_val, tracc)
259 |
260 |
def test(epoch):
    """Evaluate ``net`` on ``testloader``, write checkpoints and record metrics.

    Computes accuracy, true-positive rate and false-positive rate (label 1 is
    the positive class), saves ``ckpt.t7`` whenever the accuracy beats
    ``best_acc`` and an additional ``ckpt<epoch>.t7`` every 50 epochs, and
    appends ``[acc, tpr, fpr]`` to the module-level ``test_val`` history.
    Rates whose denominator is zero are reported as NaN instead of raising.

    Args:
        epoch: Index of the epoch that was just trained; stored in the
            checkpoint and used in the periodic snapshot's file name.
    """
    epoch_start_time = time.time()
    global best_acc
    global test_val
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    TP = FP = FN = TN = 0

    def _percent(numerator, denominator):
        # Undefined ratios (no samples / no positives / no negatives) -> NaN.
        return 100. * numerator / denominator if denominator else float('nan')

    # Evaluation needs no autograd graph; no_grad() avoids building one.
    with torch.no_grad():
        for batch_idx, (inputs, targets, feat) in enumerate(testloader):
            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda()

            outputs = net(inputs)

            loss = criterion(outputs, targets)
            test_loss += loss.item()
            # Element 0 of the network output holds the scores used for the
            # class prediction.
            _, predicted = torch.max(outputs[0], 1)
            total += targets.size(0)
            # Accumulate plain Python ints so the counters stay usable even
            # when the loader yields no batches.
            correct += predicted.eq(targets).sum().item()
            TP += ((predicted == 1) & (targets == 1)).sum().item()
            TN += ((predicted == 0) & (targets == 0)).sum().item()
            FN += ((predicted == 0) & (targets == 1)).sum().item()
            FP += ((predicted == 1) & (targets == 0)).sum().item()

    # Save checkpoint.
    acc = _percent(correct, total)
    # The whole module is pickled; unwrap DataParallel when it was applied.
    state = {
        'net': net.module if use_cuda else net,
        'acc': acc,
        'epoch': epoch,
    }
    if not os.path.isdir(savemodelpath):
        os.mkdir(savemodelpath)
    if acc > best_acc:
        logging.info('Saving..')
        torch.save(state, savemodelpath + 'ckpt.t7')
        best_acc = acc
    if epoch % 50 == 0:
        # Periodic snapshot, independent of whether this epoch was the best.
        logging.info('Saving..')
        torch.save(state, savemodelpath + 'ckpt' + str(epoch) + '.t7')
    tpr = _percent(TP, TP + FN)
    fpr = _percent(FP, FP + TN)

    print('teacc ' + str(acc) + ' bestacc ' + str(best_acc))
    print('tpr ' + str(tpr) + ' fpr ' + str(fpr))
    print('Time Taken: %d sec' % (time.time() - epoch_start_time))
    logging.info(
        'teacc ' + str(acc) + ' bestacc ' + str(best_acc))
    logging.info(
        'tpr ' + str(tpr) + ' fpr ' + str(fpr))
    # np.append returns a new array; the result must be re-bound, otherwise
    # the history stays empty.
    test_val = np.append(test_val, [[acc, tpr, fpr]], axis=0)
322 |
323 |
if __name__ == '__main__':
    # Run args.num_epochs epochs of train + test, numbering them from the
    # epoch after start_epoch, then persist the metric histories.
    first_epoch = start_epoch + 1
    for epoch in range(first_epoch, first_epoch + args.num_epochs):  # 200):
        train(epoch)
        test(epoch)
    for suffix, history in (("train_acc", train_val), ("test_acc", test_val)):
        np.save(savemodelpath + suffix, history)
330 |
--------------------------------------------------------------------------------
/data/humanperformance.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import SimpleITK as sitk
3 | import os
4 | import os.path
5 | import numpy as np
# Index of the LUNA16 subset directory ('subset' + str(fold)) whose scans are
# evaluated further down in this script.
fold = 1
def load_itk_image(filename):
    """Load an image with SimpleITK and inspect its ``TransformMatrix`` header.

    The file is first read as text to find the line starting with
    ``TransformMatrix``; its values are rounded and compared with the
    flattened 3x3 identity.

    Args:
        filename: Path of the header file to read.

    Returns:
        tuple: ``(numpyImage, numpyOrigin, numpySpacing, isflip)``. Origin and
        spacing are the SimpleITK values in reversed order; ``isflip`` is True
        when the rounded matrix differs from the identity.
    """
    with open(filename) as header:
        matrix_lines = [row for row in header.readlines()
                        if row.startswith('TransformMatrix')]
    matrix_text = matrix_lines[0].split(' = ')[1]
    transformM = np.round(np.array(matrix_text.split(' ')).astype('float'))
    identity = np.array([1, 0, 0, 0, 1, 0, 0, 0, 1])
    isflip = bool(np.any(transformM != identity))

    itkimage = sitk.ReadImage(filename)
    numpyImage = sitk.GetArrayFromImage(itkimage)
    numpyOrigin = np.array(list(reversed(itkimage.GetOrigin())))
    numpySpacing = np.array(list(reversed(itkimage.GetSpacing())))
    return numpyImage, numpyOrigin, numpySpacing, isflip
def worldToVoxelCoord(worldCoord, origin, spacing):
    """Return ``|worldCoord - origin| / spacing`` element-wise.

    Args:
        worldCoord: Coordinate(s) in world space.
        origin: World-space origin of the image.
        spacing: Voxel spacing along each axis.

    Returns:
        The absolute offset from the origin divided by the spacing.
    """
    return np.absolute(worldCoord - origin) / spacing
# read map file
# Builds sidmap: third space-separated field -> [first field, second field].
# Repeated keys must agree with the entry already stored.
mapfname = 'LIDC-IDRI-mappingLUNA16'
sidmap = {}
# ``with`` closes the handle even when one of the consistency asserts fails.
with open(mapfname, 'r') as fid:
    next(fid, None)  # the first line is skipped, as before
    for line in fid:
        pidlist = line.split(' ')
        # print pidlist
        pid = pidlist[0]
        stdid = pidlist[1]
        srsid = pidlist[2]
        if srsid not in sidmap:
            sidmap[srsid] = [pid, stdid]
        else:
            assert sidmap[srsid][0] == pid
            assert sidmap[srsid][1] == stdid
# read luna16 annotation
# Groups the annotation rows by series uid:
#   lunaantdict[seriesuid] -> list of [coordX, coordY, coordZ, diameter_mm].
colname = ['seriesuid', 'coordX', 'coordY', 'coordZ', 'diameter_mm']
lunaantframe = pd.read_csv('annotations.csv', names=colname)
# The first parsed row is dropped from every column (presumably the file's own
# header line, which becomes data once explicit names are passed).
srslist = lunaantframe.seriesuid.tolist()[1:]
cdxlist = lunaantframe.coordX.tolist()[1:]
cdylist = lunaantframe.coordY.tolist()[1:]
cdzlist = lunaantframe.coordZ.tolist()[1:]
dimlist = lunaantframe.diameter_mm.tolist()[1:]
lunaantdict = {}
# enumerate/zip replaces the Python-2-only xrange(len(...)) index loop, so the
# script runs unchanged on Python 2 and Python 3.
for idx, coords in enumerate(zip(cdxlist, cdylist, cdzlist, dimlist)):
    vlu = [float(c) for c in coords]
    lunaantdict.setdefault(srslist[idx], []).append(vlu)
60 | # # convert luna16 annotation to LIDC-IDRI annotation space
61 | # from multiprocessing import Pool
62 | # lunantdictlidc = {}
63 | # for fold in xrange(10):
64 | # mhdpath = '/media/data1/wentao/tianchi/luna16/subset'+str(fold)
65 | # print 'fold', fold
66 | # def getvoxcrd(fname):
67 | # sliceim,origin,spacing,isflip = load_itk_image(os.path.join(mhdpath, fname))
68 | # lunantdictlidc[fname[:-4]] = []
69 | # voxcrdlist = []
70 | # for lunaant in lunaantdict[fname[:-4]]:
71 | # voxcrd = worldToVoxelCoord(lunaant[:3][::-1], origin, spacing)
72 | # voxcrd[-1] = sliceim.shape[0] - voxcrd[0]
73 | # voxcrdlist.append(voxcrd)
74 | # return voxcrdlist
75 | # p = Pool(30)
76 | # fnamelist = []
77 | # for fname in os.listdir(mhdpath):
78 | # if fname.endswith('.mhd') and fname[:-4] in lunaantdict:
79 | # fnamelist.append(fname)
80 | # voxcrdlist = p.map(getvoxcrd, fnamelist)
81 | # listidx = 0
82 | # for fname in os.listdir(mhdpath):
83 | # if fname.endswith('.mhd') and fname[:-4] in lunaantdict:
84 | # lunantdictlidc[fname[:-4]] = []
85 | # for subidx, lunaant in enumerate(lunaantdict[fname[:-4]]):
86 | # # voxcrd = worldToVoxelCoord(lunaant[:3][::-1], origin, spacing)
87 | # # voxcrd[-1] = sliceim.shape[0] - voxcrd[0]
88 | # lunantdictlidc[fname[:-4]].append([lunaant, voxcrdlist[listidx][subidx]])
89 | # listidx += 1
90 | # p.close()
91 | # np.save('lunaantdictlidc.npy', lunantdictlidc)
# read LIDC dataset
# The .npy file holds a pickled Python object (hence .item()); NumPy >= 1.16.3
# refuses to unpickle unless allow_pickle=True is passed explicitly.
lunantdictlidc = np.load('lunaantdictlidc.npy', allow_pickle=True).item()
import xlrd
lidccsvfname = '/media/data1/wentao/LIDC-IDRI/list3.2.xls'
# antdict maps '<column 0>_<int(column 1)>' to a list of rows. Each row is
# [int(c2), c3, c4, int(c5), int(c6), int(c7)] followed by the non-empty cells
# of columns 9-15 (floats are stored as integer strings).
antdict = {}
wb = xlrd.open_workbook(os.path.join(lidccsvfname))
for s in wb.sheets():
    if s.name != 'list3.2':
        continue
    for row in range(1, s.nrows):
        valuelist = [int(s.cell(row, 2).value), s.cell(row, 3).value, s.cell(row, 4).value,
                     int(s.cell(row, 5).value), int(s.cell(row, 6).value), int(s.cell(row, 7).value)]
        # Sanity check: the columns converted with int() must hold whole numbers.
        for intcol in (1, 2, 5, 6, 7):
            assert abs(s.cell(row, intcol).value - int(s.cell(row, intcol).value)) < 1e-8
        for col in range(9, 16):
            value = s.cell(row, col).value
            if value == '':
                continue
            if isinstance(value, float):
                assert abs(value - int(value)) < 1e-8
                valuelist.append(str(int(value)))
            else:
                valuelist.append(value)
        key = s.cell(row, 0).value+'_'+str(int(s.cell(row, 1).value))
        antdict.setdefault(key, []).append(valuelist)
# update LIDC annotation with series number, rather than scan id
# Re-keys antdict from '<pid>_<scan>' to '<pid>_<series directory name>' by
# reading DICOM element (0x20, 0x11) from '000006.dcm' of every series
# directory and comparing it with the scan id.
import dicom
LIDCpath = '/media/data1/wentao/LIDC-IDRI/DOI/'
antdictscan = {}
# .items() works on both Python 2 and 3; .iteritems() exists only on Python 2.
for k, v in antdict.items():
    pid, scan = k.split('_')
    hasscan = False
    piddir = os.path.join(LIDCpath, 'LIDC-IDRI-'+pid)
    for sdu in os.listdir(piddir):
        for srs in os.listdir(os.path.join(piddir, sdu)):
            if srs.endswith('.npy'):
                print('npy', pid, scan, srs)
                continue
            RefDs = dicom.read_file(os.path.join(piddir, sdu, srs, '000006.dcm'))
            # print scan, str(RefDs[0x20, 0x11].value)
            # A scan id of '0' accepts the first series that is encountered.
            if str(RefDs[0x20, 0x11].value) == scan or scan == '0':
                if hasscan: print('rep', pid, sdu, srs)
                hasscan = True
                antdictscan[pid+'_'+srs] = v
                break
    # NOTE(review): sdu/srs here are whatever the loops above visited last.
    if not hasscan: print('not found', pid, scan, sdu, srs)
# find the match from LIDC-IDRI annotation
# For every LUNA16 nodule pick the LIDC row with the smallest distance and keep
# that row's entries from index 6 on:
#   lunaantdictnodid[series] -> list of [luna annotation, matched row[6:]].
import math
lunaantdictnodid = {}
maxdist = 0
# Series that are skipped; their entry stays an empty list.
skipsrs = ['1.3.6.1.4.1.14519.5.2.1.6279.6001.174692377730646477496286081479', '1.3.6.1.4.1.14519.5.2.1.6279.6001.300246184547502297539521283806']
# .items() works on both Python 2 and 3; .iteritems() exists only on Python 2.
for srcid, lunaantlidc in lunantdictlidc.items():
    lunaantdictnodid[srcid] = []
    pid, stdid = sidmap[srcid]
    # print pid
    pid = pid[len('LIDC-IDRI-'):]
    for lunantdictlidcsub in lunaantlidc:
        lunaant = lunantdictlidcsub[0]
        voxcrd = lunantdictlidcsub[1] # z y x
        mindist, minidx = 1e8, -1
        if srcid in skipsrs:
            continue
        for idx, lidcant in enumerate(antdictscan[pid+'_'+srcid]):
            # NOTE(review): the second assignment overwrites the first, so the
            # voxcrd[0] term never contributes to dist. Kept as-is because the
            # 71 threshold below was tuned against this value - confirm intent
            # before changing it to '+='.
            dist = math.pow(voxcrd[0] - lidcant[3], 2) # z
            dist = math.pow(voxcrd[1] - lidcant[4], 2) # y
            dist += math.pow(voxcrd[2] - lidcant[5], 2) # x
            if dist < mindist:
                mindist = dist
                minidx = idx
        if mindist > 71:#15.1:
            print(srcid, pid, voxcrd, antdictscan[pid+'_'+srcid], mindist)
        maxdist = max(maxdist, mindist)
        lunaantdictnodid[srcid].append([lunaant, antdictscan[pid+'_'+srcid][minidx][6:]])
# np.save('lunaantdictnodid.npy', lunaantdictnodid)
print('maxdist', maxdist)
167 | # save it into a csv
168 | # import csv
169 | # savename = 'annotationnodid.csv'
170 | # fid = open(savename, 'w')
171 | # writer = csv.writer(fid)
172 | # writer.writerow(['seriesuid', 'coordX', 'coordY', 'coordZ', 'diameter_mm'])
173 | # for srcid, ant in lunaantdictnodid.iteritems():
174 | # for antsub in ant:
175 | # writer.writerow([srcid] + [antsub[0][0], antsub[0][1], antsub[0][2], antsub[0][3]] + antsub[1])
176 | # fid.close()
# fd 1
# Names (extension stripped) of the .mhd files in the selected subset directory.
fd1lst = [fname[:-4]
          for fname in os.listdir('/media/data1/wentao/tianchi/luna16/subset'+str(fold)+'/')
          if fname.endswith('.mhd')]
# find the malignancy, shape information from xml file
# For every matched nodule of the selected fold, collect the malignancy values
# (1..5) of the XML 'unblindedReadNodule' entries whose noduleID is among the
# nodule's matched ids. Each stored row is
#   [series, x, y, z, diameter, mean malignancy (0 if none)] + individual values.
import xml.dom.minidom
ndoc = 0
lunadctclssgmdict = {}
mallstall, callstall, sphlstall, marlstall, loblstall, spilstall, texlstall = [], [], [], [], [], [], []
# .items() works on both Python 2 and 3; .iteritems() exists only on Python 2.
for srsid, extant in lunaantdictnodid.items():
    if srsid not in fd1lst: continue
    lunadctclssgmdict[srsid] = []
    pid, stdid = sidmap[srsid]
    srsdir = os.path.join(*['/media/data1/wentao/LIDC-IDRI/DOI/', pid, stdid, srsid])
    for extantvlu in extant:
        getnodid = []
        nant = 0
        mallst = []
        for fname in os.listdir(srsdir):
            if not fname.endswith('.xml'):
                continue
            nant += 1
            dom = xml.dom.minidom.parse(os.path.join(srsdir, fname))
            root = dom.documentElement
            for rsess in root.getElementsByTagName('readingSession'):
                for unb in rsess.getElementsByTagName('unblindedReadNodule'):
                    nod = unb.getElementsByTagName('noduleID')
                    if len(nod) != 1:
                        print('more nod', nod)
                        continue
                    if nod[0].firstChild.data in extantvlu[1]:
                        getnodid.append(nod[0].firstChild.data)
                        mal = unb.getElementsByTagName('malignancy')
                        if len(mal) == 1 and int(mal[0].firstChild.data) in range(1, 6, 1):
                            mallst.append(float(mal[0].firstChild.data))
        # print(getnodid, extantvlu[1], nant)
        if len(getnodid) > len(extantvlu[1]):
            print(pid, srsid)
            # assert 1 == 0
        # ndoc tracks the largest number of ids seen for any single nodule.
        ndoc = max(ndoc, len(getnodid), len(extantvlu[1]))
        vlulst = [srsid, extantvlu[0][0], extantvlu[0][1], extantvlu[0][2], extantvlu[0][3]]
        if len(mallst) == 0: vlulst.append(0)
        else: vlulst.append(sum(mallst)/float(len(mallst)))
        lunadctclssgmdict[srsid].append(vlulst+mallst)
import csv
# load prediction array
pixdimpred = np.load('../../../CTnoddetector/training/nodcls/besttestpred.npy')#'pixradiustest.npy')
pdframe = pd.read_csv('annotationdetclsconv_v3.csv', names=['seriesuid', 'coordX', 'coordY', 'coordZ', 'diameter_mm', 'malignant'])
# The first parsed row of every column is dropped.
srslst = pdframe['seriesuid'].tolist()[1:]
crdxlst = pdframe['coordX'].tolist()[1:]
crdylst = pdframe['coordY'].tolist()[1:]
crdzlst = pdframe['coordZ'].tolist()[1:]
dimlst = pdframe['diameter_mm'].tolist()[1:]
mlglst = pdframe['malignant'].tolist()[1:]
newlst = []
import csv
# Re-write the annotations with '-<row index>' appended to each series uid and
# keep the same rows in newlst. The ``with`` block closes the file even if a
# write fails; ``fid`` stays bound (closed) for the later fid.close() call.
with open('annotationdetclsconvfnl_v3.csv', 'w') as fid:
    writer = csv.writer(fid)
    writer.writerow(['seriesuid', 'coordX', 'coordY', 'coordZ', 'diameter_mm', 'malignant'])
    for i in range(len(srslst)):
        rec = [srslst[i]+'-'+str(i), crdxlst[i], crdylst[i], crdzlst[i], dimlst[i], mlglst[i]]
        writer.writerow(rec)
        newlst.append(rec)
# Series names (.mhd files, extension stripped) of the held-out subset, and the
# number of rows in newlst that belong to them.
subset1path = '/media/data1/wentao/tianchi/luna16/subset'+str(fold)+'/'
testfnamelst = [fname[:-4] for fname in os.listdir(subset1path) if fname.endswith('.mhd')]
ntest = sum(1 for rec in newlst if rec[0].split('-')[0] in testfnamelst)
print('ntest', ntest, 'ntrain', len(newlst)-ntest)
# Walk the test rows in file order, pairing the k-th test row with the k-th
# entry of pixdimpred. prednamelst[series] -> list of [prediction, row index].
prednamelst = {}
predacc = 0
predidx = 0
# predlabellst = []
for rec in newlst:
    parts = rec[0].split('-')
    if parts[0] not in testfnamelst:
        continue
    score = pixdimpred[predidx]
    # print newlst[idx][-1], pixdimpred[predidx]
    if int(score > 0.5) == int(rec[-1]):
        predacc += 1
    prednamelst.setdefault(parts[0], []).append([score, parts[1]])
    predidx += 1
print('pred acc', predacc/float(predidx))
# Compare the thresholded prediction with the label derived from the mean
# malignancy (> 3 -> 1, otherwise 0; nodules with mean 3 or 0 are skipped) and
# print every nodule where the two disagree.
pixdimidx = -1
# savename = 'annotationdetclssgm_doctor_fd2.csv'
# fid = open(savename, 'w')
# writer = csv.writer(fid)
# writer.writerow(['seriesuid', 'coordX', 'coordY', 'coordZ', 'diameter_mm', 'malignant'])
doctornacc, doctornpid = [0]*ndoc, [0]*ndoc
nacc = 0
ntot = 0
# .items() works on both Python 2 and 3; .iteritems() exists only on Python 2.
for srsid, extant in lunadctclssgmdict.items():
    curidx = 0
    if srsid not in fd1lst: continue
    for subextant in extant:
        if subextant[5] in [3, 0]: continue
        if abs(subextant[5] - 3) < 1e-2: continue
        if subextant[5] > 3: subextant[5] = 1
        else: subextant[5] = 0
        if subextant[5] == int(prednamelst[srsid][curidx][0]>0.5): nacc += 1
        ntot += 1
        # writer.writerow(subextant)
        for did in range(6, len(subextant), 1):
            # if 0.499 <= prednamelst[srsid][curidx] <= 0.501: continue
            if subextant[did] == 3: continue
            if subextant[5] != (prednamelst[srsid][curidx][0]>0.5):
                print(srsid+'-'+prednamelst[srsid][curidx][1], prednamelst[srsid][curidx][0], subextant[5])
            # if subextant[5] == 1 and prednamelst[srsid][curidx][0]>0.5:# subextant[did] > 3: # we treat 3 as the positive label
            # if subextant[did] < 3: print(srsid+'-'+prednamelst[srsid][curidx][1], prednamelst[srsid][curidx][0], did-6)
            # doctornacc[did-6] += 1
            # elif subextant[5] == 0 and prednamelst[srsid][curidx][0]<=0.5:# subextant[did] < 3:
            # if subextant[did] > 3: print(srsid+'-'+prednamelst[srsid][curidx][1], prednamelst[srsid][curidx][0], did-6)
            # doctornacc[did-6] += 1
            # if subextant[did] != 3:
            # doctornpid[did-6] += 1
        curidx += 1
fid.close()
# Guard the ratios: ntot can be 0, and doctornpid is never incremented while
# the per-reader tallies above stay commented out, so the unguarded divisions
# raised ZeroDivisionError.
print(nacc / float(ntot) if ntot else float('nan'))
for i in range(ndoc):
    print(i, doctornacc[i], doctornpid[i],
          doctornacc[i]/float(doctornpid[i]) if doctornpid[i] else float('nan'))
--------------------------------------------------------------------------------
/data/pthumanperformance.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import SimpleITK as sitk
3 | import os
4 | import os.path
5 | import numpy as np
# Index of the LUNA16 subset directory ('subset' + str(fold)) evaluated below;
# also used in the names of the .npy files saved at the end of the script.
fold = 1
def load_itk_image(filename):
    """Load an image with SimpleITK and inspect its ``TransformMatrix`` header.

    The file is first read as text to find the line starting with
    ``TransformMatrix``; its values are rounded and compared with the
    flattened 3x3 identity.

    Args:
        filename: Path of the header file to read.

    Returns:
        tuple: ``(numpyImage, numpyOrigin, numpySpacing, isflip)``. Origin and
        spacing are the SimpleITK values in reversed order; ``isflip`` is True
        when the rounded matrix differs from the identity.
    """
    with open(filename) as header:
        matrix_lines = [row for row in header.readlines()
                        if row.startswith('TransformMatrix')]
    matrix_text = matrix_lines[0].split(' = ')[1]
    transformM = np.round(np.array(matrix_text.split(' ')).astype('float'))
    identity = np.array([1, 0, 0, 0, 1, 0, 0, 0, 1])
    isflip = bool(np.any(transformM != identity))

    itkimage = sitk.ReadImage(filename)
    numpyImage = sitk.GetArrayFromImage(itkimage)
    numpyOrigin = np.array(list(reversed(itkimage.GetOrigin())))
    numpySpacing = np.array(list(reversed(itkimage.GetSpacing())))
    return numpyImage, numpyOrigin, numpySpacing, isflip
def worldToVoxelCoord(worldCoord, origin, spacing):
    """Return ``|worldCoord - origin| / spacing`` element-wise.

    Args:
        worldCoord: Coordinate(s) in world space.
        origin: World-space origin of the image.
        spacing: Voxel spacing along each axis.

    Returns:
        The absolute offset from the origin divided by the spacing.
    """
    return np.absolute(worldCoord - origin) / spacing
# read map file
# Builds sidmap: third space-separated field -> [first field, second field].
# Repeated keys must agree with the entry already stored.
mapfname = 'LIDC-IDRI-mappingLUNA16'
sidmap = {}
# ``with`` closes the handle even when one of the consistency asserts fails.
with open(mapfname, 'r') as fid:
    next(fid, None)  # the first line is skipped, as before
    for line in fid:
        pidlist = line.split(' ')
        # print pidlist
        pid = pidlist[0]
        stdid = pidlist[1]
        srsid = pidlist[2]
        if srsid not in sidmap:
            sidmap[srsid] = [pid, stdid]
        else:
            assert sidmap[srsid][0] == pid
            assert sidmap[srsid][1] == stdid
# read luna16 annotation
# Groups the annotation rows by series uid:
#   lunaantdict[seriesuid] -> list of [coordX, coordY, coordZ, diameter_mm].
colname = ['seriesuid', 'coordX', 'coordY', 'coordZ', 'diameter_mm']
lunaantframe = pd.read_csv('annotations.csv', names=colname)
# The first parsed row of every column is dropped.
srslist = lunaantframe.seriesuid.tolist()[1:]
cdxlist = lunaantframe.coordX.tolist()[1:]
cdylist = lunaantframe.coordY.tolist()[1:]
cdzlist = lunaantframe.coordZ.tolist()[1:]
dimlist = lunaantframe.diameter_mm.tolist()[1:]
lunaantdict = {}
for idx, coords in enumerate(zip(cdxlist, cdylist, cdzlist, dimlist)):
    vlu = [float(c) for c in coords]
    lunaantdict.setdefault(srslist[idx], []).append(vlu)
60 | # # convert luna16 annotation to LIDC-IDRI annotation space
61 | # from multiprocessing import Pool
62 | # lunantdictlidc = {}
63 | # for fold in xrange(10):
64 | # mhdpath = '/media/data1/wentao/tianchi/luna16/subset'+str(fold)
65 | # print 'fold', fold
66 | # def getvoxcrd(fname):
67 | # sliceim,origin,spacing,isflip = load_itk_image(os.path.join(mhdpath, fname))
68 | # lunantdictlidc[fname[:-4]] = []
69 | # voxcrdlist = []
70 | # for lunaant in lunaantdict[fname[:-4]]:
71 | # voxcrd = worldToVoxelCoord(lunaant[:3][::-1], origin, spacing)
72 | # voxcrd[-1] = sliceim.shape[0] - voxcrd[0]
73 | # voxcrdlist.append(voxcrd)
74 | # return voxcrdlist
75 | # p = Pool(30)
76 | # fnamelist = []
77 | # for fname in os.listdir(mhdpath):
78 | # if fname.endswith('.mhd') and fname[:-4] in lunaantdict:
79 | # fnamelist.append(fname)
80 | # voxcrdlist = p.map(getvoxcrd, fnamelist)
81 | # listidx = 0
82 | # for fname in os.listdir(mhdpath):
83 | # if fname.endswith('.mhd') and fname[:-4] in lunaantdict:
84 | # lunantdictlidc[fname[:-4]] = []
85 | # for subidx, lunaant in enumerate(lunaantdict[fname[:-4]]):
86 | # # voxcrd = worldToVoxelCoord(lunaant[:3][::-1], origin, spacing)
87 | # # voxcrd[-1] = sliceim.shape[0] - voxcrd[0]
88 | # lunantdictlidc[fname[:-4]].append([lunaant, voxcrdlist[listidx][subidx]])
89 | # listidx += 1
90 | # p.close()
91 | # np.save('lunaantdictlidc.npy', lunantdictlidc)
# read LIDC dataset
# The .npy file holds a pickled Python object (hence .item()); NumPy >= 1.16.3
# refuses to unpickle unless allow_pickle=True is passed explicitly.
lunantdictlidc = np.load('lunaantdictlidc.npy', allow_pickle=True).item()
import xlrd
lidccsvfname = '/media/data1/wentao/LIDC-IDRI/list3.2.xls'
# antdict maps '<column 0>_<int(column 1)>' to a list of rows. Each row is
# [int(c2), c3, c4, int(c5), int(c6), int(c7)] followed by the non-empty cells
# of columns 9-15 (floats are stored as integer strings).
antdict = {}
wb = xlrd.open_workbook(os.path.join(lidccsvfname))
for s in wb.sheets():
    if s.name != 'list3.2':
        continue
    for row in range(1, s.nrows):
        valuelist = [int(s.cell(row, 2).value), s.cell(row, 3).value, s.cell(row, 4).value,
                     int(s.cell(row, 5).value), int(s.cell(row, 6).value), int(s.cell(row, 7).value)]
        # Sanity check: the columns converted with int() must hold whole numbers.
        for intcol in (1, 2, 5, 6, 7):
            assert abs(s.cell(row, intcol).value - int(s.cell(row, intcol).value)) < 1e-8
        for col in range(9, 16):
            value = s.cell(row, col).value
            if value == '':
                continue
            if isinstance(value, float):
                assert abs(value - int(value)) < 1e-8
                valuelist.append(str(int(value)))
            else:
                valuelist.append(value)
        key = s.cell(row, 0).value+'_'+str(int(s.cell(row, 1).value))
        antdict.setdefault(key, []).append(valuelist)
# update LIDC annotation with series number, rather than scan id
# Re-keys antdict from '<pid>_<scan>' to '<pid>_<series directory name>' by
# reading DICOM element (0x20, 0x11) from '000006.dcm' of every series
# directory and comparing it with the scan id.
import dicom
LIDCpath = '/media/data1/wentao/LIDC-IDRI/DOI/'
antdictscan = {}
# .items() works on both Python 2 and 3; .iteritems() exists only on Python 2.
for k, v in antdict.items():
    pid, scan = k.split('_')
    hasscan = False
    piddir = os.path.join(LIDCpath, 'LIDC-IDRI-'+pid)
    for sdu in os.listdir(piddir):
        for srs in os.listdir(os.path.join(piddir, sdu)):
            if srs.endswith('.npy'):
                print('npy', pid, scan, srs)
                continue
            RefDs = dicom.read_file(os.path.join(piddir, sdu, srs, '000006.dcm'))
            # print scan, str(RefDs[0x20, 0x11].value)
            # A scan id of '0' accepts the first series that is encountered.
            if str(RefDs[0x20, 0x11].value) == scan or scan == '0':
                if hasscan: print('rep', pid, sdu, srs)
                hasscan = True
                antdictscan[pid+'_'+srs] = v
                break
    # NOTE(review): sdu/srs here are whatever the loops above visited last.
    if not hasscan: print('not found', pid, scan, sdu, srs)
# find the match from LIDC-IDRI annotation
# For every LUNA16 nodule pick the LIDC row with the smallest distance and keep
# that row's entries from index 6 on:
#   lunaantdictnodid[series] -> list of [luna annotation, matched row[6:]].
import math
lunaantdictnodid = {}
maxdist = 0
# Series that are skipped; their entry stays an empty list.
skipsrs = ['1.3.6.1.4.1.14519.5.2.1.6279.6001.174692377730646477496286081479', '1.3.6.1.4.1.14519.5.2.1.6279.6001.300246184547502297539521283806']
# .items() works on both Python 2 and 3; .iteritems() exists only on Python 2.
for srcid, lunaantlidc in lunantdictlidc.items():
    lunaantdictnodid[srcid] = []
    pid, stdid = sidmap[srcid]
    # print pid
    pid = pid[len('LIDC-IDRI-'):]
    for lunantdictlidcsub in lunaantlidc:
        lunaant = lunantdictlidcsub[0]
        voxcrd = lunantdictlidcsub[1] # z y x
        mindist, minidx = 1e8, -1
        if srcid in skipsrs:
            continue
        for idx, lidcant in enumerate(antdictscan[pid+'_'+srcid]):
            # NOTE(review): the second assignment overwrites the first, so the
            # voxcrd[0] term never contributes to dist. Kept as-is because the
            # 71 threshold below was tuned against this value - confirm intent
            # before changing it to '+='.
            dist = math.pow(voxcrd[0] - lidcant[3], 2) # z
            dist = math.pow(voxcrd[1] - lidcant[4], 2) # y
            dist += math.pow(voxcrd[2] - lidcant[5], 2) # x
            if dist < mindist:
                mindist = dist
                minidx = idx
        if mindist > 71:#15.1:
            print(srcid, pid, voxcrd, antdictscan[pid+'_'+srcid], mindist)
        maxdist = max(maxdist, mindist)
        lunaantdictnodid[srcid].append([lunaant, antdictscan[pid+'_'+srcid][minidx][6:]])
# np.save('lunaantdictnodid.npy', lunaantdictnodid)
print('maxdist', maxdist)
167 | # save it into a csv
168 | # import csv
169 | # savename = 'annotationnodid.csv'
170 | # fid = open(savename, 'w')
171 | # writer = csv.writer(fid)
172 | # writer.writerow(['seriesuid', 'coordX', 'coordY', 'coordZ', 'diameter_mm'])
173 | # for srcid, ant in lunaantdictnodid.iteritems():
174 | # for antsub in ant:
175 | # writer.writerow([srcid] + [antsub[0][0], antsub[0][1], antsub[0][2], antsub[0][3]] + antsub[1])
176 | # fid.close()
# fd 1
# Names (extension stripped) of the .mhd files in the selected subset directory.
fd1lst = [fname[:-4]
          for fname in os.listdir('/media/data1/wentao/tianchi/luna16/subset'+str(fold)+'/')
          if fname.endswith('.mhd')]
# find the malignancy, shape information from xml file
# For every matched nodule of the selected fold, collect the malignancy values
# (1..5) of the XML 'unblindedReadNodule' entries whose noduleID is among the
# nodule's matched ids. Each stored row is
#   [series, x, y, z, diameter, mean malignancy (0 if none)] + individual values.
import xml.dom.minidom
ndoc = 0
lunadctclssgmdict = {}
mallstall, callstall, sphlstall, marlstall, loblstall, spilstall, texlstall = [], [], [], [], [], [], []
# .items() works on both Python 2 and 3; .iteritems() exists only on Python 2.
for srsid, extant in lunaantdictnodid.items():
    if srsid not in fd1lst: continue
    lunadctclssgmdict[srsid] = []
    pid, stdid = sidmap[srsid]
    srsdir = os.path.join(*['/media/data1/wentao/LIDC-IDRI/DOI/', pid, stdid, srsid])
    for extantvlu in extant:
        getnodid = []
        nant = 0
        mallst = []
        for fname in os.listdir(srsdir):
            if not fname.endswith('.xml'):
                continue
            nant += 1
            dom = xml.dom.minidom.parse(os.path.join(srsdir, fname))
            root = dom.documentElement
            for rsess in root.getElementsByTagName('readingSession'):
                for unb in rsess.getElementsByTagName('unblindedReadNodule'):
                    nod = unb.getElementsByTagName('noduleID')
                    if len(nod) != 1:
                        print('more nod', nod)
                        continue
                    if nod[0].firstChild.data in extantvlu[1]:
                        getnodid.append(nod[0].firstChild.data)
                        mal = unb.getElementsByTagName('malignancy')
                        if len(mal) == 1 and int(mal[0].firstChild.data) in range(1, 6, 1):
                            mallst.append(float(mal[0].firstChild.data))
        # print(getnodid, extantvlu[1], nant)
        if len(getnodid) > len(extantvlu[1]):
            print(pid, srsid)
            # assert 1 == 0
        # ndoc tracks the largest number of ids seen for any single nodule.
        ndoc = max(ndoc, len(getnodid), len(extantvlu[1]))
        vlulst = [srsid, extantvlu[0][0], extantvlu[0][1], extantvlu[0][2], extantvlu[0][3]]
        if len(mallst) == 0: vlulst.append(0)
        else: vlulst.append(sum(mallst)/float(len(mallst)))
        lunadctclssgmdict[srsid].append(vlulst+mallst)
import csv
# load prediction array
# pixdimpred = np.load('../../../CTnoddetector/training/nodcls/checkpoint-'+str(fold)+'/besttestpred.npy')#'pixradiustest.npy')
annotcols = ['seriesuid', 'coordX', 'coordY', 'coordZ', 'diameter_mm', 'malignant']
pdframe = pd.read_csv('annotationdetclsconv_v3.csv', names=annotcols)
# One list per column, each with its first parsed row dropped.
srslst, crdxlst, crdylst, crdzlst, dimlst, mlglst = [pdframe[c].tolist()[1:] for c in annotcols]
import csv
# fid = open('annotationdetclsconvfnl_v3.csv', 'w')
# writer = csv.writer(fid)
# writer.writerow(['seriesuid', 'coordX', 'coordY', 'coordZ', 'diameter_mm', 'malignant'])
# writer.writerow([srslst[i]+'-'+str(i), crdxlst[i], crdylst[i], crdzlst[i], dimlst[i], mlglst[i]])
# fid.close()
# newlst holds the annotation rows with '-<row index>' appended to the uid.
newlst = [[srs+'-'+str(i), crdx, crdy, crdz, dim, mlg]
          for i, (srs, crdx, crdy, crdz, dim, mlg)
          in enumerate(zip(srslst, crdxlst, crdylst, crdzlst, dimlst, mlglst))]
# Series names (.mhd files, extension stripped) of the held-out subset, and the
# number of rows in newlst that belong to them.
subset1path = '/media/data1/wentao/tianchi/luna16/subset'+str(fold)+'/'
testfnamelst = [fname[:-4] for fname in os.listdir(subset1path) if fname.endswith('.mhd')]
ntest = sum(1 for rec in newlst if rec[0].split('-')[0] in testfnamelst)
print('ntest', ntest, 'ntrain', len(newlst)-ntest)
# Prediction bookkeeping. The loop that used to fill these from pixdimpred is
# commented out below, so in this script they keep their initial values.
prednamelst = {}
predacc = 0
predidx = 0
# predlabellst = []
# for idx in xrange(len(newlst)):
# fname = newlst[idx][0]
# if fname.split('-')[0] in testfnamelst:
# # print newlst[idx][-1], pixdimpred[predidx]
# if int(pixdimpred[predidx]>0.5) == int(newlst[idx][-1]): predacc += 1
# if fname.split('-')[0] not in prednamelst:
# prednamelst[fname.split('-')[0]] = [pixdimpred[predidx]]
# else:
# prednamelst[fname.split('-')[0]].append(pixdimpred[predidx])
# predidx += 1
# print 'pred acc', predacc/float(predidx)
pixdimidx = -1
# savename = 'annotationdetclssgm_doctor_fd2.csv'
# fid = open(savename, 'w')
# writer = csv.writer(fid)
# writer.writerow(['seriesuid', 'coordX', 'coordY', 'coordZ', 'diameter_mm', 'malignant'])
# get the patient level cancer
# ptlabel[series] becomes 1 as soon as one nodule has a mean malignancy above
# 3, and 0 when only lower-rated nodules were seen. Nodules whose mean is 0 or
# (approximately) 3 are ignored.
ptlabel = {}
nfold1 = 0
# .items() works on both Python 2 and 3; .iteritems() exists only on Python 2.
for srsid, extant in lunadctclssgmdict.items():
    if srsid not in fd1lst: continue
    nfold1 += 1
    for subextant in extant:
        if subextant[5] in [3, 0]: continue
        if abs(subextant[5] - 3) < 1e-2: continue
        if subextant[5] > 3:
            ptlabel[srsid] = 1
            break
        else:
            ptlabel[srsid] = 0
print(len(lunadctclssgmdict.keys()), nfold1, len(ptlabel.keys()), sum(ptlabel.values()))
# get doctors prediction on patient level cancer
# dctptlabel[series][k] is 1 / 0 / -1 for "rating above 3" / "rating below 3" /
# "no usable rating" of the k-th individual malignancy value.
dctptlabel = {}
# .items() works on both Python 2 and 3; .iteritems() exists only on Python 2.
for srsid, extant in lunadctclssgmdict.items():
    curidx = 0
    if srsid not in fd1lst: continue
    for subextant in extant:
        if subextant[5] in [3, 0]: continue
        if abs(subextant[5] - 3) < 1e-2: continue
        # NOTE(review): the list is re-created for every accepted nodule, so
        # only the last accepted nodule of a series determines the stored
        # labels - confirm this is intended.
        dctptlabel[srsid] = [-1]*max(len(subextant) - 6, 4)

        # writer.writerow(subextant)
        for did in range(6, len(subextant),1):#len(subextant), 1):
            # if 0.499 <= prednamelst[srsid][curidx] <= 0.501: continue
            if subextant[did] == 3: continue
            if subextant[did] > 3: dctptlabel[srsid][did-6] = 1
            if dctptlabel[srsid][did-6] == -1: dctptlabel[srsid][did-6] = 0
# calculate patient level performance
def _ratio(numerator, denominator):
    """Return numerator/denominator as a float, or NaN for a zero denominator."""
    if not denominator:
        return float('nan')
    return numerator / float(denominator)


# Per reader slot 0..3:
#   n    - patients for which the slot holds a label (!= -1)
#   npos - of those, patients whose ptlabel is 1
#   p    - patients where the slot's label equals ptlabel
#   tp   - agreements on patients whose ptlabel is 1
# The positive counter used to be called ``np``, which shadowed numpy until the
# re-import below; it is now ``npos``.
p = [0]*4
n = [0]*4
tp = [0]*4
npos = [0]*4
for srs in ptlabel.keys():
    for did in range(4):
        if dctptlabel[srs][did] == -1:
            continue
        n[did] += 1
        if ptlabel[srs] == 1: npos[did] += 1
        if ptlabel[srs] == dctptlabel[srs][did]:
            p[did] += 1
            if ptlabel[srs] == 1:
                tp[did] += 1
print(p, n, tp, npos, len(ptlabel))
# One line per slot: agreement rate, rate on positives, rate on negatives.
# _ratio prints NaN instead of raising when a slot has no patients.
for did in range(4):
    print(_ratio(p[did], n[did]), _ratio(tp[did], npos[did]),
          _ratio(p[did]-tp[did], n[did]-npos[did]))
import numpy as np
np.save('ptlabel'+str(fold)+'.npy', ptlabel)
np.save('dctptlabel'+str(fold)+'.npy', dctptlabel)
--------------------------------------------------------------------------------
/transforms.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import torch
3 | import math
4 | import random
5 | from PIL import Image, ImageOps
6 |
7 | try:
8 | import accimage
9 | except ImportError:
10 | accimage = None
11 | import numpy as np
12 | import numbers
13 | import types
14 | import collections
15 | from torch.autograd import Variable
16 |
17 | torch.cuda.set_device(0)
18 |
19 |
def resample3d(inp, inp_space, out_space=(1, 1, 1)):
    """Resample a tensor along three axes by repeated calls to ``resample1d``.

    The input is moved to the GPU, then each pass resamples the last axis and
    permutes the result so that the next axis to process becomes the last one.

    Args:
        inp: Tensor to resample; ``.cuda()`` is called on it.
        inp_space: Indexable with the input spacing of axes 0, 1 and 2.
        out_space: Indexable with the target spacing of axes 0, 1 and 2.

    Returns:
        The resampled tensor after the three passes.
    """
    # Infer new shape
    # inp = torch.from_numpy(inp)
    # inp=torch.FloatTensor(inp)
    # inp=Variable(inp)
    volume = inp.cuda()
    # (spacing index used for the pass, permutation applied afterwards)
    passes = ((2, (0, 2, 1)), (1, (2, 1, 0)), (0, (2, 0, 1)))
    for axis, order in passes:
        volume = resample1d(volume, inp_space[axis], out_space[axis]).permute(*order)
    return volume
30 |
31 |
def resample1d(inp, inp_space, out_space=1):
    """Resample the last axis of ``inp`` from spacing ``inp_space`` to ``out_space``.

    Uses ``torch.cuda.HalfTensor`` throughout, so a CUDA device is required.
    The indexing below (``index_select(2, ...)``, ``out_shape[0..2]``) assumes
    a 3-D input - TODO confirm with callers.

    Args:
        inp: Tensor whose last axis is resampled.
        inp_space: Spacing of the input samples along the last axis.
        out_space: Desired spacing of the output samples (default 1).

    Returns:
        The interpolated tensor, squeezed; its last-axis length is
        ``floor(inp.size()[-1] * inp_space / out_space)``.
    """
    # Output shape
    # NOTE(review): this print runs on every call and looks like debug output.
    print(inp.size(), inp_space, out_space)
    out_shape = list(np.int64(inp.size()[:-1])) + [
        int(np.floor(inp.size()[-1] * inp_space / out_space))]  # Optional for if we expect a float_tensor
    out_shape = [int(item) for item in out_shape]
    # Get output coordinates, deltas, and t (chord distances)
    # torch.cuda.set_device(inp.get_device())
    # Output coordinates in real space
    coords = torch.cuda.HalfTensor(range(out_shape[-1])) * out_space
    # delta: fractional position of each output coordinate inside its input cell
    delta = coords.fmod(inp_space).div(inp_space).repeat(out_shape[0], out_shape[1], 1)
    # t stacks the powers delta^0 .. delta^3 along a new leading axis
    t = torch.cuda.HalfTensor(4, out_shape[0], out_shape[1], out_shape[2]).zero_()
    t[0] = 1
    t[1] = delta
    t[2] = delta ** 2
    t[3] = delta ** 3
    # Nearest neighbours indices
    nn = coords.div(inp_space).floor().long()
    # Stack the nearest neighbors into P, the Points Array
    P = torch.cuda.HalfTensor(4, out_shape[0], out_shape[1], out_shape[2]).zero_()
    # Offsets -1..2 around nn; indices are clamped to the valid range at the borders
    for i in range(-1, 3):
        P[i + 1] = inp.index_select(2, torch.clamp(nn + i, 0, inp.size()[-1] - 1))
    # Take catmull-rom spline interpolation:
    # 0.5 * sum over the four powers of t * (coefficient matrix @ P)
    return 0.5 * t.mul(torch.cuda.HalfTensor([[0, 2, 0, 0],
                                              [-1, 0, 1, 0],
                                              [2, -5, 4, -1],
                                              [-1, 3, -3, 1]]).mm(P.view(4, -1)) \
                       .view(4,
                             out_shape[0],
                             out_shape[1],
                             out_shape[2])) \
        .sum(0) \
        .squeeze()
65 |
66 |
class Compose(object):
    """Chain several transforms into a single callable.

    Args:
        transforms (list of callables): applied in order, each receiving the
            output of the previous one.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        """Run ``img`` through every transform and return the final result."""
        result = img
        for step in self.transforms:
            # print(step)
            result = step(result)
        return result
88 |
89 |
class Normalize(object):
    """Normalize a tensor in place: ``tensor = (tensor - mean) / std``.

    ``mean`` and ``std`` are handed directly to ``Tensor.sub_`` and
    ``Tensor.div_``; no per-channel loop is performed here.

    Args:
        mean: Value subtracted from the tensor.
        std: Value the shifted tensor is divided by.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor to be normalized; it is modified in place.

        Returns:
            Tensor: The same tensor object, normalized.
        """
        # TODO: make efficient
        # for t, m, s in zip(tensor, self.mean, self.std):

        tensor.sub_(self.mean)
        tensor.div_(self.std)
        return tensor
120 |
121 |
122 | from scipy.ndimage.interpolation import zoom
123 |
124 |
class RandomScale(object):
    """Zoom a 3-D volume by one randomly drawn factor applied to all axes.

    Args:
        size (sequence): only its first and last entries are read; they bound
            the random draw.
        interpolation: stored on the instance but not used by ``__call__``.
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        """Return ``img`` zoomed by the same random factor along three axes."""
        lower = self.size[0]
        upper = self.size[-1] + 1
        # NOTE(review): if ``random`` is the stdlib module, randint includes
        # both bounds, so the factor can reach ``size[-1] + 1`` -- confirm
        # that this is intended.
        scale = random.randint(lower, upper)
        factors = (scale, scale, scale)
        return zoom(img, factors, mode='nearest')
143 |
144 |
class Scale(object):
    """Rescale an image through its ``resize`` method.

    Args:
        size (sequence or int): Desired output size. A sequence (it must have
            exactly three entries) is passed to ``img.resize`` unchanged. An
            int is matched to the smaller of the first two image dimensions
            and the other one is scaled to keep the aspect ratio.
        interpolation (int, optional): Forwarded to ``img.resize``. Default is
            ``PIL.Image.BILINEAR``.
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        # ``collections.Iterable`` was removed in Python 3.10; the ABC lives
        # in ``collections.abc``. Fall back for Python 2 interpreters.
        try:
            from collections.abc import Iterable
        except ImportError:
            from collections import Iterable
        assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 3)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        """
        Args:
            img: Image to be scaled; must expose ``size`` and ``resize``.

        Returns:
            The rescaled image, or ``img`` itself when its smaller edge
            already equals the requested int size.
        """
        if not isinstance(self.size, int):
            return img.resize(self.size, self.interpolation)

        # Only the first two dimensions are used. Reading them by index lets
        # ordinary 2-tuple sizes work; the former 3-way unpacking raised
        # ValueError on them and never used the third value.
        w, h = img.size[0], img.size[1]
        if (w <= h and w == self.size) or (h <= w and h == self.size):
            return img
        if w < h:
            ow = self.size
            oh = int(self.size * h / w)
        else:
            oh = self.size
            ow = int(self.size * w / h)
        return img.resize((ow, oh), self.interpolation)
185 |
186 |
class ZeroOut(object):
    """Set a randomly placed cube inside a 3-D volume to zero.

    Args:
        size (int): edge length of the zeroed cube (coerced with ``int``).
    """

    def __init__(self, size):
        self.size = int(size)

    def __call__(self, img):
        """Return a copy of ``img`` in which one ``size``-cube is zero.

        The cube's corner is drawn uniformly so that the cube lies entirely
        inside the volume; ``img`` itself is not modified.
        """
        edge = self.size
        w, h, d = img.shape
        # The corner is drawn axis by axis in x, y, z order.
        x0 = random.randint(0, w - edge)
        y0 = random.randint(0, h - edge)
        z0 = random.randint(0, d - edge)
        out = np.array(img)
        out[x0:x0 + edge, y0:y0 + edge, z0:z0 + edge] = np.zeros((edge, edge, edge))
        return np.array(out)
208 |
209 |
class ToTensor(object):
    """Convert a ``numpy.ndarray`` volume or an image object to a tensor.

    * A 3-D ``numpy.ndarray`` gets a singleton axis that ends up in front
      (``(a, b, c)`` becomes ``(1, a, b, c)``) and is returned as a float
      tensor. Values are not rescaled.
    * An ``accimage.Image`` is copied into a float32 channel-first buffer.
    * Anything else is treated as a PIL image and returned channel-first;
      byte data is converted to float without rescaling.
    """

    def __call__(self, pic):
        """
        Args:
            pic (numpy.ndarray, accimage.Image or PIL.Image): input to convert.

        Returns:
            Tensor: Converted image.
        """
        if isinstance(pic, np.ndarray):
            # Append a singleton axis, then move it to the front.
            with_channel = np.expand_dims(pic, -1)
            channel_first = with_channel.transpose((3, 0, 1, 2))
            return torch.from_numpy(channel_first).float()

        if accimage is not None and isinstance(pic, accimage.Image):
            buf = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
            pic.copyto(buf)
            return torch.from_numpy(buf)

        # PIL image: pick the raw storage according to the image mode.
        mode = pic.mode
        if mode == 'I':
            flat = torch.from_numpy(np.array(pic, np.int32, copy=False))
        elif mode == 'I;16':
            flat = torch.from_numpy(np.array(pic, np.int16, copy=False))
        else:
            flat = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))

        # Channel count per mode; every other mode has one channel per letter.
        nchannel = {'YCbCr': 3, 'I;16': 1}.get(mode, len(mode))
        hwc = flat.view(pic.size[1], pic.size[0], nchannel)
        # HWC -> CHW
        chw = hwc.transpose(0, 1).transpose(0, 2).contiguous()
        return chw.float() if isinstance(chw, torch.ByteTensor) else chw
260 |
261 |
class CenterCrop(object):
    """Crop an image around its center through its ``crop`` method.

    Args:
        size (sequence or int): Desired crop size as ``(height, width)``. A
            single number produces a square ``(size, size)`` crop.
    """

    def __init__(self, size):
        self.size = (int(size), int(size)) if isinstance(size, numbers.Number) else size

    def __call__(self, img):
        """
        Args:
            img: Image to crop; ``img.size`` must be ``(width, height)``.

        Returns:
            Whatever ``img.crop`` returns for the centered box.
        """
        width, height = img.size
        crop_h, crop_w = self.size
        left = int(round((width - crop_w) / 2.))
        top = int(round((height - crop_h) / 2.))
        box = (left, top, left + crop_w, top + crop_h)
        return img.crop(box)
290 |
291 |
class Pad(object):
    """Add a border of constant value around an image.

    Args:
        padding (number): Border width handed to ``ImageOps.expand``. The
            constructor asserts that it is a single number.
        fill (number, str or tuple): Pixel fill value. Default is 0.
    """

    def __init__(self, padding, fill=0):
        assert isinstance(padding, numbers.Number)
        assert isinstance(fill, (numbers.Number, str, tuple))
        self.padding, self.fill = padding, fill

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be padded.

        Returns:
            PIL.Image: Padded image.
        """
        border, colour = self.padding, self.fill
        return ImageOps.expand(img, border=border, fill=colour)
316 |
317 |
class Lambda(object):
    """Wrap a user-supplied function so it can be used as a transform.

    Args:
        lambd (function): callable applied to the input; the constructor
            asserts that it is an instance of ``types.LambdaType``.
    """

    def __init__(self, lambd):
        assert isinstance(lambd, types.LambdaType)
        self.lambd = lambd

    def __call__(self, img):
        """Return ``lambd(img)``."""
        func = self.lambd
        return func(img)
331 |
332 |
class RandomCrop(object):
    """Crop a randomly positioned sub-volume out of a 3-D array.

    Args:
        size (int or sequence of 3 ints): Desired output size, one entry per
            array axis. A single number gives a cube ``(size, size, size)``.
        padding (int, optional): If positive, the volume is first enlarged by
            ``int(padding / 2)`` voxels of value 170 at the low-index end of
            every axis. Default is 0, i.e. no padding.

    Raises:
        ValueError: if a sequence ``size`` does not have exactly 3 entries.
    """

    def __init__(self, size, padding=0):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size), int(size))
        else:
            # ``int(size)`` raised TypeError for every sequence; keep one
            # integer per axis instead.
            self.size = tuple(int(s) for s in size)
            if len(self.size) != 3:
                raise ValueError('size must be a number or a sequence of 3 numbers')
        self.padding = int(padding)

    def __call__(self, img):
        """
        Args:
            img (numpy.ndarray): 3-D volume to be cropped.

        Returns:
            numpy.ndarray: the cropped volume; the (possibly padded) input
            itself when it already has the requested shape.
        """
        if self.padding > 0:
            pad = int(self.padding / 2)
            # NOTE(review): the enlarged volume is only ``pad`` larger per
            # axis and the data is placed at offset ``pad``, so nothing is
            # added at the high-index end. Kept as is -- confirm intent.
            padded = np.ones((img.shape[0] + pad, img.shape[1] + pad, img.shape[2] + pad)) * 170
            padded[pad:pad + img.shape[0], pad:pad + img.shape[1], pad:pad + img.shape[2]] = np.array(img)
            img = padded

        w, h, d = img.shape
        tw, th, td = self.size  # target extent along axis 0, 1, 2
        if w == tw and h == th and d == td:
            return img
        # Offset range and slice length use the same axis extent, so
        # non-cubic crops come out with the requested shape.
        x1 = random.randint(0, w - tw)
        y1 = random.randint(0, h - th)
        z1 = random.randint(0, d - td)
        return np.array(img[x1:x1 + tw, y1:y1 + th, z1:z1 + td])
380 |
381 |
class RandomHorizontalFlip(object):
    """Reverse a 3-D array along its third axis with probability 0.5."""

    def __call__(self, img):
        """
        Args:
            img (numpy.ndarray): volume to be flipped.

        Returns:
            numpy.ndarray: a reversed copy, or ``img`` itself when no flip
            is drawn.
        """
        if random.random() >= 0.5:
            return img
        mirrored = img[:, :, ::-1]
        return np.array(mirrored)
396 |
397 |
class RandomZFlip(object):
    """Reverse a 3-D array along its first axis with probability 0.5."""

    def __call__(self, img):
        """
        Args:
            img (numpy.ndarray): volume to be flipped.

        Returns:
            numpy.ndarray: a reversed copy, or ``img`` itself when no flip
            is drawn.
        """
        if random.random() >= 0.5:
            return img
        mirrored = img[::-1, :, :]
        return np.array(mirrored)
412 |
413 |
class RandomYFlip(object):
    """Reverse a 3-D array along its second axis with probability 0.5."""

    def __call__(self, img):
        """
        Args:
            img (numpy.ndarray): volume to be flipped.

        Returns:
            numpy.ndarray: a reversed copy, or ``img`` itself when no flip
            is drawn.
        """
        if random.random() >= 0.5:
            return img
        mirrored = img[:, ::-1, :]
        return np.array(mirrored)
428 |
429 |
class RandomSizedCrop(object):
    """Crop a random region of an image and resize it to a square.

    Up to ten times a region covering 8%-100% of the image area with an
    aspect ratio between 3/4 and 4/3 is drawn; the first one that fits is
    cropped and resized to ``(size, size)``. If none fits, the image is
    passed through ``Scale`` and ``CenterCrop`` instead.

    Args:
        size: edge length of the square output
        interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        """Return the randomly cropped and resized image."""
        for _ in range(10):
            img_w, img_h = img.size[0], img.size[1]
            target_area = random.uniform(0.08, 1.0) * (img_w * img_h)
            aspect_ratio = random.uniform(3. / 4, 4. / 3)

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            # Swap orientation half of the time.
            if random.random() < 0.5:
                w, h = h, w

            if w > img_w or h > img_h:
                continue  # the drawn region does not fit, try again

            x1 = random.randint(0, img_w - w)
            y1 = random.randint(0, img_h - h)
            cropped = img.crop((x1, y1, x1 + w, y1 + h))
            assert (cropped.size == (w, h))
            return cropped.resize((self.size, self.size), self.interpolation)

        # Fallback
        fallback_scale = Scale(self.size, interpolation=self.interpolation)
        fallback_crop = CenterCrop(self.size)
        return fallback_crop(fallback_scale(img))
472 |
--------------------------------------------------------------------------------
/data/extclsshpinfo.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import SimpleITK as sitk
3 | import os
4 | import os.path
5 | import numpy as np
6 |
7 |
def load_itk_image(filename):
    """Load an image file through SimpleITK together with its geometry.

    The file is first read as text to find its ``TransformMatrix`` line; the
    scan counts as flipped when the rounded matrix entries differ from the
    3x3 identity.

    Args:
        filename: path of the image header file.

    Returns:
        tuple: ``(numpyImage, numpyOrigin, numpySpacing, isflip)`` where
        origin and spacing are in reversed SimpleITK axis order.
    """
    with open(filename) as f:
        header_lines = f.readlines()
    matrix_line = [k for k in header_lines if k.startswith('TransformMatrix')][0]
    entries = matrix_line.split(' = ')[1].split(' ')
    transformM = np.round(np.array(entries).astype('float'))
    identity = np.array([1, 0, 0, 0, 1, 0, 0, 0, 1])
    isflip = True if np.any(transformM != identity) else False

    itkimage = sitk.ReadImage(filename)
    numpyImage = sitk.GetArrayFromImage(itkimage)
    # SimpleITK reports origin and spacing in the opposite axis order.
    numpyOrigin = np.array(list(reversed(itkimage.GetOrigin())))
    numpySpacing = np.array(list(reversed(itkimage.GetSpacing())))
    return numpyImage, numpyOrigin, numpySpacing, isflip
23 |
24 |
def worldToVoxelCoord(worldCoord, origin, spacing):
    """Return ``|worldCoord - origin| / spacing``, computed element-wise."""
    return np.absolute(worldCoord - origin) / spacing
29 |
30 |
# read map file
# Builds sidmap: series id -> [patient id, study id]. The first line of the
# file is skipped; every other line is split on single spaces and its first
# three fields are used. A series id that occurs again must map to the same
# patient and study.
mapfname = 'LIDC-IDRI-mappingLUNA16'
sidmap = {}
# The context manager closes the file even if a line is malformed.
with open(mapfname, 'r') as fid:
    fid.readline()  # skip the first line
    for line in fid:
        pidlist = line.split(' ')
        pid, stdid, srsid = pidlist[0], pidlist[1], pidlist[2]
        if srsid not in sidmap:
            sidmap[srsid] = [pid, stdid]
        else:
            assert sidmap[srsid][0] == pid
            assert sidmap[srsid][1] == stdid
# read luna16 annotation
# Builds lunaantdict: series uid -> list of [coordX, coordY, coordZ, diameter_mm].
colname = ['seriesuid', 'coordX', 'coordY', 'coordZ', 'diameter_mm']
lunaantframe = pd.read_csv('annotations.csv', names=colname)
# Explicit ``names`` make the file's first row come back as data, so it is
# dropped from every column.
srslist = lunaantframe['seriesuid'].tolist()[1:]
cdxlist = lunaantframe['coordX'].tolist()[1:]
cdylist = lunaantframe['coordY'].tolist()[1:]
cdzlist = lunaantframe['coordZ'].tolist()[1:]
dimlist = lunaantframe['diameter_mm'].tolist()[1:]
lunaantdict = {}
for idx, srs in enumerate(srslist):
    vlu = [float(cdxlist[idx]), float(cdylist[idx]), float(cdzlist[idx]), float(dimlist[idx])]
    lunaantdict.setdefault(srs, []).append(vlu)
# convert luna16 annotation to LIDC-IDRI annotation space
# For each of the ten subset folders, every annotated .mhd scan is loaded in
# a worker pool and its annotations are converted to voxel coordinates. The
# result, lunantdictlidc, maps series uid -> list of
# [luna annotation, voxel coordinate] and is saved with np.save.
from multiprocessing import Pool

lunantdictlidc = {}
for fold in range(10):
    mhdpath = '/media/data1/wentao/tianchi/luna16/subset' + str(fold)
    print('fold', fold)


    def getvoxcrd(fname):
        """Return the voxel coordinates of all annotations of scan ``fname``.

        Reads ``mhdpath`` and ``lunaantdict`` from module scope.
        """
        sliceim, origin, spacing, isflip = load_itk_image(os.path.join(mhdpath, fname))
        # When called through p.map this runs in a worker process, so
        # presumably this assignment never reaches the parent's dict; the
        # parent fills the entry itself further down.
        lunantdictlidc[fname[:-4]] = []
        voxcrdlist = []
        for lunaant in lunaantdict[fname[:-4]]:
            # lunaant is [x, y, z, diameter]; reverse to z, y, x before converting.
            voxcrd = worldToVoxelCoord(lunaant[:3][::-1], origin, spacing)
            # Overwrites the last component with (first image dimension - first
            # component). NOTE(review): presumably a slice index counted from
            # the other end, stored in the last slot -- confirm.
            voxcrd[-1] = sliceim.shape[0] - voxcrd[0]
            voxcrdlist.append(voxcrd)
        return voxcrdlist


    p = Pool(30)
    fnamelist = []
    for fname in os.listdir(mhdpath):
        if fname.endswith('.mhd') and fname[:-4] in lunaantdict:
            fnamelist.append(fname)
    voxcrdlist = p.map(getvoxcrd, fnamelist)
    # Walk the directory again with the same filter; listidx pairs each file
    # with its p.map result, which relies on os.listdir returning the same
    # order both times.
    listidx = 0
    for fname in os.listdir(mhdpath):
        if fname.endswith('.mhd') and fname[:-4] in lunaantdict:
            lunantdictlidc[fname[:-4]] = []
            for subidx, lunaant in enumerate(lunaantdict[fname[:-4]]):
                # voxcrd = worldToVoxelCoord(lunaant[:3][::-1], origin, spacing)
                # voxcrd[-1] = sliceim.shape[0] - voxcrd[0]
                lunantdictlidc[fname[:-4]].append([lunaant, voxcrdlist[listidx][subidx]])
            listidx += 1
    p.close()
np.save('lunaantdictlidc.npy', lunantdictlidc)
# read LIDC dataset
# The dict was written by np.save as a pickled object array. NumPy >= 1.16.3
# refuses to unpickle unless allow_pickle=True is given. The file is produced
# by this script itself; do not point this at untrusted data.
lunantdictlidc = np.load('lunaantdictlidc.npy', allow_pickle=True).item()
import xlrd

# Builds antdict: '<column 0>_<int(column 1)>' -> list of value lists, one per
# spreadsheet row. A value list holds columns 2-7 followed by the non-empty
# cells of columns 9-15.
lidccsvfname = '/media/data1/wentao/LIDC-IDRI/list3.2.xls'
antdict = {}
wb = xlrd.open_workbook(os.path.join(lidccsvfname))
for s in wb.sheets():
    if s.name != 'list3.2':
        continue
    for row in range(1, s.nrows):  # row 0 is skipped
        valuelist = [int(s.cell(row, 2).value), s.cell(row, 3).value, s.cell(row, 4).value,
                     int(s.cell(row, 5).value), int(s.cell(row, 6).value), int(s.cell(row, 7).value)]
        # Columns converted with int() must hold whole numbers.
        for col in (1, 2, 5, 6, 7):
            assert abs(s.cell(row, col).value - int(s.cell(row, col).value)) < 1e-8
        for col in range(9, 16):
            cellvalue = s.cell(row, col).value
            if cellvalue == '':
                continue
            if isinstance(cellvalue, float):
                valuelist.append(int(cellvalue))
                assert abs(cellvalue - int(cellvalue)) < 1e-8
            else:
                valuelist.append(cellvalue)
        key = s.cell(row, 0).value + '_' + str(int(s.cell(row, 1).value))
        antdict.setdefault(key, []).append(valuelist)
# update LIDC annotation with series number, rather than scan id
# Re-keys antdict from '<pid>_<scan>' to '<pid>_<series directory name>' by
# reading one DICOM file of every series directory of the patient and
# comparing its (0x20, 0x11) element with the scan id.
# NOTE(review): (0x20, 0x11) is presumably the DICOM Series Number -- confirm.
import dicom

LIDCpath = '/media/data1/wentao/LIDC-IDRI/DOI/'
antdictscan = {}
# dict.items() works on Python 2 and 3; iteritems() exists on Python 2 only.
for k, v in antdict.items():
    pid, scan = k.split('_')
    hasscan = False
    patientdir = os.path.join(LIDCpath, 'LIDC-IDRI-' + pid)
    for sdu in os.listdir(patientdir):
        for srs in os.listdir(os.path.join(*[LIDCpath, 'LIDC-IDRI-' + pid, sdu])):
            if srs.endswith('.npy'):
                print('npy', pid, scan, srs)
                continue
            RefDs = dicom.read_file(os.path.join(*[LIDCpath, 'LIDC-IDRI-' + pid, sdu, srs, '000006.dcm']))
            # A scan id of '0' matches the first series encountered.
            if str(RefDs[0x20, 0x11].value) == scan or scan == '0':
                if hasscan:
                    print('rep', pid, sdu, srs)
                hasscan = True
                antdictscan[pid + '_' + srs] = v
                break
    if not hasscan:
        print('not found', pid, scan, sdu, srs)
# find the match from LIDC-IDRI annotation
# For every LUNA16 annotation, pick the LIDC annotation of the same series
# with the smallest distance and keep its entries from index 6 on.
# Result: lunaantdictnodid, series uid -> list of [luna annotation, entries].
import math

lunaantdictnodid = {}
maxdist = 0
# dict.items() works on Python 2 and 3; iteritems() exists on Python 2 only.
for srcid, lunaantlidc in lunantdictlidc.items():
    lunaantdictnodid[srcid] = []
    pid, stdid = sidmap[srcid]
    pid = pid[len('LIDC-IDRI-'):]  # strip the prefix to get the bare patient number
    for lunantdictlidcsub in lunaantlidc:
        lunaant = lunantdictlidcsub[0]
        voxcrd = lunantdictlidcsub[1]  # z y x
        mindist, minidx = 1e8, -1
        # These two series are skipped entirely.
        if srcid in ['1.3.6.1.4.1.14519.5.2.1.6279.6001.174692377730646477496286081479',
                     '1.3.6.1.4.1.14519.5.2.1.6279.6001.300246184547502297539521283806']:
            continue
        for idx, lidcant in enumerate(antdictscan[pid + '_' + srcid]):
            # NOTE(review): the second assignment overwrites the first, so
            # only the last two terms contribute to dist. Left unchanged on
            # purpose: the 71 threshold below was presumably tuned against
            # this behavior -- confirm before "fixing" it to +=.
            dist = math.pow(voxcrd[0] - lidcant[3], 2)  # z
            dist = math.pow(voxcrd[1] - lidcant[4], 2)  # y
            dist += math.pow(voxcrd[2] - lidcant[5], 2)  # x
            if dist < mindist:
                mindist = dist
                minidx = idx
        if mindist > 71:  # 15.1:
            print(srcid, pid, voxcrd, antdictscan[pid + '_' + srcid], mindist)
        maxdist = max(maxdist, mindist)
        lunaantdictnodid[srcid].append([lunaant, antdictscan[pid + '_' + srcid][minidx][6:]])
np.save('lunaantdictnodid.npy', lunaantdictnodid)
print('maxdist', maxdist)
# save it into a csv
# One row per annotation: series uid, the four annotation values, then the
# matched entries appended as extra columns after the five named ones.
import csv

savename = 'annotationnodid.csv'
# The context manager closes the file even if a row cannot be written.
with open(savename, 'w') as fid:
    writer = csv.writer(fid)
    writer.writerow(['seriesuid', 'coordX', 'coordY', 'coordZ', 'diameter_mm'])
    # dict.items() works on Python 2 and 3; iteritems() exists on Python 2 only.
    for srcid, ant in lunaantdictnodid.items():
        for antsub in ant:
            writer.writerow([srcid] + [antsub[0][0], antsub[0][1], antsub[0][2], antsub[0][3]] + antsub[1])
# find the malignancy, shape information from xml file
# For every matched annotation, all .xml files of its series directory are
# parsed and the ratings of the matching nodule ids are averaged per
# characteristic. Results go to lunadctclssgmdict / annotationdetclssgm.csv.
import xml.dom.minidom


def _valid_rating(unb, tagname, maxvalue):
    """Return the rating stored under ``tagname`` as a float, or None.

    A rating counts only when the element occurs exactly once inside ``unb``
    and its integer value lies in ``1..maxvalue``.
    """
    nodes = unb.getElementsByTagName(tagname)
    if len(nodes) == 1 and 1 <= int(nodes[0].firstChild.data) <= maxvalue:
        return float(nodes[0].firstChild.data)
    return None


def _mean_or_zero(values):
    """Return the arithmetic mean of ``values``, or 0 for an empty list."""
    if len(values) == 0:
        return 0
    return sum(values) / float(len(values))


# (tag, largest accepted rating), in the order the tags are parsed
_RATING_TAGS = [('calcification', 6), ('sphericity', 5), ('margin', 5), ('lobulation', 5),
                ('spiculation', 5), ('texture', 5), ('malignancy', 5)]
# order of the averaged ratings in each output row
_OUTPUT_TAGS = ['malignancy', 'calcification', 'sphericity', 'margin', 'lobulation',
                'spiculation', 'texture']

lunadctclssgmdict = {}
# dict.items() works on Python 2 and 3; iteritems() exists on Python 2 only.
for srsid, extant in lunaantdictnodid.items():
    lunadctclssgmdict[srsid] = []
    pid, stdid = sidmap[srsid]
    srsdir = os.path.join(*['/media/data1/wentao/LIDC-IDRI/DOI/', pid, stdid, srsid])
    for extantvlu in extant:
        ratings = dict((tag, []) for tag, _ in _RATING_TAGS)
        for fname in os.listdir(srsdir):
            if not fname.endswith('.xml'):
                continue
            dom = xml.dom.minidom.parse(os.path.join(srsdir, fname))
            root = dom.documentElement
            for rsess in root.getElementsByTagName('readingSession'):
                for unb in rsess.getElementsByTagName('unblindedReadNodule'):
                    nod = unb.getElementsByTagName('noduleID')
                    if len(nod) != 1:
                        continue
                    # extantvlu[1] holds the nodule ids matched to this annotation.
                    if nod[0].firstChild.data not in extantvlu[1]:
                        continue
                    for tag, maxvalue in _RATING_TAGS:
                        rating = _valid_rating(unb, tag, maxvalue)
                        if rating is not None:
                            ratings[tag].append(rating)
        vlulst = [srsid, extantvlu[0][0], extantvlu[0][1], extantvlu[0][2], extantvlu[0][3]]
        # A characteristic without any valid rating is written as 0.
        vlulst.extend([_mean_or_zero(ratings[tag]) for tag in _OUTPUT_TAGS])
        lunadctclssgmdict[srsid].append(vlulst)
np.save('lunadctclssgmdict.npy', lunadctclssgmdict)
savename = 'annotationdetclssgm.csv'
# The context manager closes the file even if a row cannot be written.
with open(savename, 'w') as fid:
    writer = csv.writer(fid)
    writer.writerow(['seriesuid', 'coordX', 'coordY', 'coordZ', 'diameter_mm', 'malignant', 'calcification',
                     'sphericity', 'margin', 'lobulation', 'spiculation', 'texture'])
    for srsid, extant in lunadctclssgmdict.items():
        for subextant in extant:
            writer.writerow(subextant)
# discrete the generated csv
# Reads annotationdetclssgm.csv and writes annotationdetclssgmfnl.csv with
# every averaged rating replaced by a small integer class; 0 always stands
# for "no rating" (an average within 1e-2 of zero).
import pandas as pd
import csv


def _malignancy_class(score):
    """Map an averaged malignancy to 0 (about 0 or about 3), 2 (> 3) or 1."""
    if abs(score - 0) < 1e-2 or abs(score - 3) < 1e-2:
        return 0
    return 2 if score > 3 else 1


def _bucket(score, cuts):
    """Return 0 for a score of about 0, else 1 + the number of cuts <= score.

    ``cuts`` must be ascending; a score below ``cuts[0]`` gives 1.
    """
    if abs(score - 0) < 1e-2:
        return 0
    for position, cut in enumerate(cuts):
        if score < cut:
            return position + 1
    return len(cuts) + 1


srcname = 'annotationdetclssgm.csv'
dstname = 'annotationdetclssgmfnl.csv'
colname = ['seriesuid', 'coordX', 'coordY', 'coordZ', 'diameter_mm', 'malignant', 'calcification', 'sphericity',
           'margin', 'lobulation', 'spiculation', 'texture']
srcframe = pd.read_csv(srcname, names=colname)
# Explicit ``names`` make the file's first row come back as data, so it is
# dropped from every column.
srslist = srcframe.seriesuid.tolist()[1:]
cdxlist = srcframe.coordX.tolist()[1:]
cdylist = srcframe.coordY.tolist()[1:]
cdzlist = srcframe.coordZ.tolist()[1:]
dimlist = srcframe.diameter_mm.tolist()[1:]
mlglist = srcframe.malignant.tolist()[1:]
callist = srcframe.calcification.tolist()[1:]
sphlist = srcframe.sphericity.tolist()[1:]
mrglist = srcframe.margin.tolist()[1:]
loblist = srcframe.lobulation.tolist()[1:]
spclist = srcframe.spiculation.tolist()[1:]
txtlist = srcframe.texture.tolist()[1:]
# The context manager closes the file even if a value cannot be converted.
with open(dstname, 'w') as fid:
    writer = csv.writer(fid)
    writer.writerow(colname)
    for idx in range(len(srslist)):
        lst = [srslist[idx], cdxlist[idx], cdylist[idx], cdzlist[idx], dimlist[idx]]
        lst.append(_malignancy_class(float(mlglist[idx])))  # 0 1 2
        lst.append(int(round(float(callist[idx]))))  # 0 - 6
        lst.append(_bucket(float(sphlist[idx]), (2, 4)))  # 0 1 2 3
        lst.append(_bucket(float(mrglist[idx]), (3,)))  # 0 1 2
        lst.append(_bucket(float(loblist[idx]), (3,)))  # 0 1 2
        lst.append(_bucket(float(spclist[idx]), (3,)))  # 0 1 2
        lst.append(_bucket(float(txtlist[idx]), (2, 4)))  # 0 1 2 3
        writer.writerow(lst)
# fuse annotations for different nodules, generate patient level annotation
# Writes annotationdetclssgmv2.csv with one row per series: the largest
# malignancy class among its nodules. Nodules whose class is '0' are ignored.
import pandas as pd
import csv

antpdframe = pd.read_csv('annotationdetclssgmfnl.csv', names=['seriesuid', 'coordX', 'coordY', 'coordZ',
                                                              'diameter_mm', 'malignant', 'calcification',
                                                              'sphericity', 'margin', 'lobulation',
                                                              'spiculation', 'texture'])
# Explicit ``names`` make the file's first row come back as data, so it is
# dropped from every column.
srslst = antpdframe.seriesuid.tolist()[1:]
cdxlst = antpdframe.coordX.tolist()[1:]
cdylst = antpdframe.coordY.tolist()[1:]
cdzlst = antpdframe.coordZ.tolist()[1:]
mlglst = antpdframe.malignant.tolist()[1:]
dimlst = antpdframe.diameter_mm.tolist()[1:]
clclst = antpdframe.calcification.tolist()[1:]
sphlst = antpdframe.sphericity.tolist()[1:]
mrglst = antpdframe.margin.tolist()[1:]
loblst = antpdframe.lobulation.tolist()[1:]
spclst = antpdframe.spiculation.tolist()[1:]
txtlst = antpdframe.texture.tolist()[1:]
dctdat = {}
for idx, srs in enumerate(srslst):
    if mlglst[idx] == '0':
        continue
    vlu = [mlglst[idx], clclst[idx], sphlst[idx], mrglst[idx], loblst[idx], spclst[idx], txtlst[idx]]
    dctdat.setdefault(srs, []).append(vlu)
# The context manager closes the file even if a value cannot be converted.
with open('annotationdetclssgmv2.csv', 'w') as fid:
    writer = csv.writer(fid)
    writer.writerow(['seriesuid', 'malignant'])
    # dict.items() works on Python 2 and 3; iteritems() exists on Python 2 only.
    for srs, vlulst in dctdat.items():
        mlg = -1
        for vlu in vlulst:
            # The class was read back from csv as text; compare it as a
            # number, because Python 3 cannot order str against int.
            mlg = max(mlg, int(float(vlu[0])))
        writer.writerow([srs, mlg])
387 |
--------------------------------------------------------------------------------