├── models ├── __init__.py ├── layers.py ├── cnn_res.py └── net_sphere.py ├── searchspace ├── __init__.py ├── res_search_space.py └── search_space_utils.py ├── configuration ├── __init__.py └── config.py ├── generalnetwork ├── __init__.py └── dpn3d.py ├── _config.yml ├── data ├── model.npy ├── list3.2.csv ├── dataanalysis.py ├── kappatest.py ├── dimcls.py ├── data_enhancement.py ├── dataconverter.py ├── nodclsgbt.py ├── humanperformance.py ├── pthumanperformance.py └── extclsshpinfo.py ├── imgs └── architecture.png ├── requirements.txt ├── .idea ├── dictionaries │ └── 10485.xml ├── .gitignore ├── inspectionProfiles │ └── profiles_settings.xml ├── modules.xml ├── misc.xml ├── NAS-Lung.iml ├── remote-mappings.xml └── deployment.xml ├── run_training.sh ├── LICENSE ├── search_main.py ├── .gitignore ├── dataloader.py ├── README.md ├── test.py ├── random_forest.py ├── data_utils.py ├── main.py ├── mainsp.py └── transforms.py /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /searchspace/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /configuration/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /generalnetwork/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-slate -------------------------------------------------------------------------------- /data/model.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/fei-aiart/NAS-Lung/HEAD/data/model.npy -------------------------------------------------------------------------------- /data/list3.2.csv: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fei-aiart/NAS-Lung/HEAD/data/list3.2.csv -------------------------------------------------------------------------------- /imgs/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fei-aiart/NAS-Lung/HEAD/imgs/architecture.png -------------------------------------------------------------------------------- /configuration/config.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class config: 4 | def __init__(self, channel_range): -------------------------------------------------------------------------------- /data/dataanalysis.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | a = pd.read_csv('../data/annotationdetclsconvfnl_v3.csv') 4 | path = 'F:\\医学数据集\\LUNA\\rowfile\\subset5' -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.18.1 2 | pandas>=0.25.0 3 | matplotlib>=3.1.0 4 | scikit-learn>=0.22.1 5 | torch>=0.4.1.post2 6 | dill>=0.3.0 7 | scipy>=1.4.1 8 | seaborn>=0.10.0 9 | Pillow>=7.0.0 -------------------------------------------------------------------------------- /.idea/dictionaries/10485.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | cbam 5 | learnable 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/.gitignore: 
-------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Datasource local storage ignored files 5 | /dataSources/ 6 | /dataSources.local.xml 7 | # Editor-based HTTP Client requests 8 | /httpRequests/ 9 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /run_training.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | 5 | maxeps=9 6 | 7 | 8 | for (( i=0; i<=$maxeps; i+=1)) 9 | do 10 | echo "process $i epoch" 11 | CUDA_VISIBLE_DEVICES=0 python main.py --batch_size 8 --num_epochs 400 --fold $i 12 | done 13 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | -------------------------------------------------------------------------------- /.idea/NAS-Lung.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/remote-mappings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | -------------------------------------------------------------------------------- /.idea/deployment.xml: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | -------------------------------------------------------------------------------- /data/kappatest.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | pp = [0]*4#[111]*4#[0]*4 4 | nn = [0]*4#[106]*4#[0]*4 5 | pn = [0]*4#[12]*4#[0]*4 6 | npp = [0]*4#[47]*4#[0]*4 7 | 8 | for fd in [1,2,3,5]:#xrange(1,4,1): 9 | dctprd = np.load('/media/data1/wentao/tianchi/luna16/CSVFILES/dctptlabel'+str(fd)+'.npy').item() 10 | for d in range(4): 11 | modprd = np.load('modprd'+str(d+1)+'fd'+str(fd)+'.npy').item() 12 | for k, v in dctprd.iteritems(): 13 | if v[d] != -1: 14 | assert v[d] in [0, 1] 15 | if v[d] == 1 and modprd[k] == 1: pp[d] += 1 16 | if v[d] == 0 and modprd[k] == 0: nn[d] += 1 17 | if v[d] == 1 and modprd[k] == 0: pn[d] += 1 18 | if v[d] == 0 and modprd[k] == 1: npp[d] += 1 19 | print(pp, nn, pn, npp) 20 | for d in range(4): 21 | n = pp[d] + nn[d] + pn[d] + npp[d] 22 | p0 = (pp[d] + nn[d]) / float(n) 23 | pe = (pp[d] + pn[d]) * (pp[d] + npp[d]) 24 | pe += (nn[d] + pn[d]) * (nn[d] + npp[d]) 25 | pe /= float(n * n) 26 | print((p0-pe)/(1-pe)) 27 | 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Fei 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The 
above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /generalnetwork/dpn3d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from torch.autograd import Variable 6 | 7 | class Bottleneck(nn.Module): 8 | def __init__(self, last_planes, in_planes, out_planes, dense_depth, stride, first_layer): 9 | super(Bottleneck, self).__init__() 10 | self.out_planes = out_planes 11 | self.dense_depth = dense_depth 12 | self.last_planes = last_planes 13 | self.in_planes = in_planes 14 | 15 | self.conv1 = nn.Conv3d(last_planes, in_planes, kernel_size=1, bias=False) 16 | self.bn1 = nn.BatchNorm3d(in_planes) 17 | self.conv2 = nn.Conv3d(in_planes, in_planes, kernel_size=3, stride=stride, padding=1, groups=32, bias=False) 18 | self.bn2 = nn.BatchNorm3d(in_planes) 19 | self.conv3 = nn.Conv3d(in_planes, out_planes + dense_depth, kernel_size=1, bias=False) 20 | self.bn3 = nn.BatchNorm3d(out_planes + dense_depth) 21 | 22 | self.shortcut = nn.Sequential() 23 | if first_layer: 24 | self.shortcut = nn.Sequential( 25 | nn.Conv3d(last_planes, out_planes + dense_depth, kernel_size=1, stride=stride, bias=False), 26 | nn.BatchNorm3d(out_planes + dense_depth) 27 | ) 28 | 29 | def forward(self, x): 30 | # print 'bottleneck_0', 
x.size(), self.last_planes, self.in_planes, 1 31 | out = F.relu(self.bn1(self.conv1(x))) 32 | # print 'bottleneck_1', out.size(), self.in_planes, self.in_planes, 3 33 | out = F.relu(self.bn2(self.conv2(out))) 34 | # print 'bottleneck_2', out.size(), self.in_planes, self.out_planes+self.dense_depth, 1 35 | out = self.bn3(self.conv3(out)) 36 | # print 'bottleneck_3', out.size() 37 | x = self.shortcut(x) 38 | d = self.out_planes 39 | # print 'bottleneck_4', x.size(), self.last_planes, self.out_planes+self.dense_depth, d 40 | out = torch.cat([x[:, :d, :, :] + out[:, :d, :, :], x[:, d:, :, :], out[:, d:, :, :]], 1) 41 | # print 'bottleneck_5', out.size() 42 | out = F.relu(out) 43 | return out 44 | 45 | -------------------------------------------------------------------------------- /search_main.py: -------------------------------------------------------------------------------- 1 | import searchspace.res_search_space as res_search_space 2 | import torch.nn as nn 3 | import logging 4 | import torch 5 | import argparse 6 | 7 | # set args 8 | parser = argparse.ArgumentParser(description='searching') 9 | # parser.add_argument('--sub', type=int, default=5, help="sub data set") 10 | parser.add_argument('--fold', type=int, default=5, help="fold") 11 | parser.add_argument('--gpu_id', type=str, default='0', help="gpu_id") 12 | parser.add_argument('--lr', type=float, default=0.0002, help="lr") 13 | parser.add_argument('--epoch', type=int, default=20, help="epoch") 14 | parser.add_argument('--num_workers', type=int, default=20, help="num_workers") 15 | parser.add_argument('--train_data_path', type=str, default='/data/xxx/LUNA/cls/crop_v3', help="train_data_path") 16 | parser.add_argument('--test_data_path', type=str, default='/data/xxx/LUNA/rowfile/subset', help="test_data_path") 17 | parser.add_argument('--batch_size', type=int, default=8, help="batch_size") 18 | parser.add_argument('--max_depth', type=int, default=9, help="max_depth") 19 | parser.add_argument('--min_depth', 
type=int, default=3, help="min_depth") 20 | parser.add_argument('--save_module_path', type=str, default='Module') 21 | parser.add_argument('--log_file', type=str, default='log_search') 22 | 23 | if __name__ == '__main__': 24 | args = parser.parse_args() 25 | fold = args.fold 26 | channel_range = [4, 8, 16, 32, 64, 128] 27 | batch_size = args.batch_size 28 | max_depth = args.max_depth 29 | min_depth = args.min_depth 30 | criterion = nn.CrossEntropyLoss() 31 | gpu_id = args.gpu_id 32 | 33 | input_shape = [1, 1, 32, 32, 32] 34 | logging.basicConfig(filename=args.log_file, level=logging.INFO) 35 | use_gpu = torch.cuda.is_available() 36 | train_data_path = args.train_data_path 37 | test_data_path = args.test_data_path 38 | lr = args.lr 39 | save_module_path = args.save_module_path 40 | num_workers = args.num_workers 41 | epoch = args.epoch 42 | # sub = args.sub 43 | # search model 44 | res_search = res_search_space.ResSearchSpace(channel_range, max_depth, min_depth, train_data_path, test_data_path, fold, 45 | batch_size, logging, input_shape, use_gpu, gpu_id, criterion, lr, 46 | save_module_path, num_works=num_workers, epoch=epoch) 47 | res_search.main_method() 48 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /data/dimcls.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import os 4 | CROPSIZE = 36 5 | pdframe = pd.read_csv('annotationdetclsconv_v3.csv', names=['seriesuid', 'coordX', 'coordY', 'coordZ', 'diameter_mm', 'malignant']) 6 | srslst = pdframe['seriesuid'].tolist()[1:] 7 | crdxlst = pdframe['coordX'].tolist()[1:] 8 | crdylst = pdframe['coordY'].tolist()[1:] 9 | crdzlst = pdframe['coordZ'].tolist()[1:] 10 | dimlst = pdframe['diameter_mm'].tolist()[1:] 11 | mlglst = pdframe['malignant'].tolist()[1:] 12 | 13 | newlst = [] 14 | import csv 15 | fid = open('annotationdetclsconvfnl_v3.csv', 'w') 16 | writer = csv.writer(fid) 17 | writer.writerow(['seriesuid', 'coordX', 'coordY', 'coordZ', 'diameter_mm', 'malignant']) 18 | for i in range(len(srslst)): 19 | writer.writerow([srslst[i]+'-'+str(i), crdxlst[i], crdylst[i], crdzlst[i], dimlst[i], mlglst[i]]) 20 | newlst.append([srslst[i]+'-'+str(i), crdxlst[i], crdylst[i], crdzlst[i], dimlst[i], mlglst[i]]) 21 | fid.close() 22 | 23 | # train use gbt 24 | subset1path = '/media/data1/wentao/tianchi/luna16/subset1/' 25 | testfnamelst = [] 26 | for fname in os.listdir(subset1path): 27 | if fname.endswith('.mhd'): 28 | testfnamelst.append(fname[:-4]) 29 | ntest = 0 30 | for idx in 
range(len(newlst)): 31 | fname = newlst[idx][0] 32 | if fname.split('-')[0] in testfnamelst: ntest +=1 33 | print('ntest', ntest, 'ntrain', len(newlst)-ntest) 34 | 35 | traindata = np.zeros((len(newlst)-ntest,)) 36 | trainlabel = np.zeros((len(newlst)-ntest,)) 37 | testdata = np.zeros((ntest,)) 38 | testlabel = np.zeros((ntest,)) 39 | 40 | trainidx = testidx = 0 41 | for idx in range(len(newlst)): 42 | fname = newlst[idx][0] 43 | if fname.split('-')[0] in testfnamelst: 44 | testdata[testidx] = newlst[idx][-2] 45 | testlabel[testidx] = newlst[idx][-1] 46 | testidx += 1 47 | else: 48 | traindata[trainidx] = newlst[idx][-2] 49 | trainlabel[trainidx] = newlst[idx][-1] 50 | trainidx += 1 51 | 52 | tracclst = [] 53 | teacclst = [] 54 | thlst = np.sort(traindata).tolist() 55 | besttr = bestte = 0 56 | for th in thlst: 57 | tracc = np.mean(trainlabel == (traindata > th)) 58 | teacc = np.mean(testlabel == (testdata > th)) 59 | if tracc > besttr: 60 | besttr = tracc 61 | bestte = teacc 62 | tracclst.append(tracc) 63 | teacclst.append(teacc) 64 | import matplotlib 65 | matplotlib.use('Agg') 66 | import matplotlib.pyplot as plt 67 | plt.plot(thlst, tracclst, label='train acc') 68 | plt.plot(thlst, teacclst, label='test acc') 69 | plt.xlabel('Threshold for diameter (mm)') 70 | plt.ylabel('Diagnosis (malignant vs. 
benign) accuracy (%)') 71 | plt.title('Diagnosis accuracy using diameter feature on fold 1') 72 | plt.legend() 73 | plt.savefig('accwrtdim.png') 74 | print(max(teacclst)) 75 | 76 | print(besttr, bestte) -------------------------------------------------------------------------------- /data/data_enhancement.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import os 4 | 5 | 6 | def horizontal_flip(img): 7 | x, y, d = img.shape 8 | for i in range(y // 2): 9 | intermediate = img[:, i, :] 10 | img[:, i, :] = img[:, y - 1 - i, :] 11 | img[:, y - 1 - i, :] = intermediate 12 | return img 13 | 14 | 15 | def vertical_flip(img): 16 | x, y, d = img.shape 17 | for i in range(x // 2): 18 | intermediate = img[i, :, :] 19 | img[i, :, :] = img[x - 1 - i, :, :] 20 | img[x - 1 - i, :, :] = intermediate 21 | return img 22 | 23 | 24 | def deep_flip(img): 25 | x, y, d = img.shape 26 | for i in range(d // 2): 27 | intermediate = img[:, :, i] 28 | img[:, :, i] = img[:, :, d - 1 - i] 29 | img[:, :, d - 1 - i] = intermediate 30 | return img 31 | 32 | 33 | path = "F:\\医学数据集\\LUNA\\cls\\crop_v3\\" 34 | dataframe = pd.read_csv("F:/PycharmProjects/DeepLung/data/annotationdetclsconvfnl_v3.csv", encoding='utf-8') 35 | data_list = dataframe.to_numpy()[1:] 36 | enhancement_data = dataframe.to_numpy() 37 | test_data = np.empty(shape=(0, 32, 32, 32)) 38 | train_data = np.empty(shape=(0, 32, 32, 32)) 39 | teidlst = [] 40 | for fname in os.listdir('F:\\医学数据集\\LUNA\\rowfile\\subset5' + '\\'): 41 | if fname.endswith('.mhd'): 42 | teidlst.append(fname[:-4]) 43 | for data in data_list: 44 | if data[0].split('-')[0] not in teidlst: 45 | a = np.load(path + data[0] + '.npy') 46 | horizontal_data = horizontal_flip(a) 47 | horizontal_name = path + data[0] + 'horizontal' 48 | horizontal_information = np.copy(data) 49 | horizontal_information[0] = data[0] + 'horizontal' 50 | enhancement_data = np.append(enhancement_data, 
[horizontal_information], 0) 51 | # np.save(horizontal_name, horizontal_data) 52 | vertical_data = vertical_flip(a) 53 | vertical_name = path + data[0] + 'vertical' 54 | vertical_information = np.copy(data) 55 | vertical_information[0] = data[0] + 'vertical' 56 | enhancement_data = np.append(enhancement_data, [vertical_information], 0) 57 | # np.save(vertical_name, vertical_data) 58 | deep_data = deep_flip(a) 59 | deep_name = path + data[0] + 'deep' 60 | deep_information = np.copy(data) 61 | deep_information[0] = data[0] + 'deep' 62 | enhancement_data = np.append(enhancement_data, [deep_information], 0) 63 | # np.save(deep_name, deep_data) 64 | frame = pd.DataFrame(enhancement_data, index=None, 65 | columns=['seriesuid', 'coordX', 'coordY', 'coordZ', 'diameter_mm', 'malignant']) 66 | frame.to_csv('H:\\a.csv', index=None, columns=['seriesuid', 'coordX', 'coordY', 'coordZ', 'diameter_mm', 'malignant']) 67 | -------------------------------------------------------------------------------- /data/dataconverter.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import csv 3 | import pandas as pd 4 | import SimpleITK as sitk 5 | import os 6 | import os.path 7 | def load_itk_image(filename): 8 | with open(filename) as f: 9 | contents = f.readlines() 10 | line = [k for k in contents if k.startswith('TransformMatrix')][0] 11 | transformM = np.array(line.split(' = ')[1].split(' ')).astype('float') 12 | transformM = np.round(transformM) 13 | if np.any( transformM!=np.array([1,0,0, 0, 1, 0, 0, 0, 1])): 14 | isflip = True 15 | else: 16 | isflip = False 17 | itkimage = sitk.ReadImage(filename) 18 | numpyImage = sitk.GetArrayFromImage(itkimage) 19 | numpyOrigin = np.array(list(reversed(itkimage.GetOrigin()))) 20 | numpySpacing = np.array(list(reversed(itkimage.GetSpacing()))) 21 | return numpyImage, numpyOrigin, numpySpacing,isflip 22 | def worldToVoxelCoord(worldCoord, origin, spacing): 23 | stretchedVoxelCoord = 
np.absolute(worldCoord - origin) 24 | voxelCoord = stretchedVoxelCoord / spacing 25 | return voxelCoord 26 | # read groundtruth from original data space 27 | # remove data of 0 value 28 | pdframe = pd.read_csv('annotationdetclsgt.csv', names=['seriesuid', 'coordX', 'coordY', 'coordZ', 'diameter_mm', 'malignant']) 29 | srslst = pdframe['seriesuid'].tolist()[1:] 30 | crdxlst = pdframe['coordX'].tolist()[1:] 31 | crdylst = pdframe['coordY'].tolist()[1:] 32 | crdzlst = pdframe['coordZ'].tolist()[1:] 33 | dimlst = pdframe['diameter_mm'].tolist()[1:] 34 | mlglst = pdframe['malignant'].tolist()[1:] 35 | dct = {} 36 | for idx in range(len(srslst)): 37 | # if mlglst[idx] == '0': 38 | # continue 39 | assert mlglst[idx] in ['1', '0'] 40 | vlu = [float(crdxlst[idx]), float(crdylst[idx]), float(crdzlst[idx]), float(dimlst[idx]), int(mlglst[idx])] 41 | if srslst[idx] in dct: 42 | dct[srslst[idx]].append(vlu) 43 | else: 44 | dct[srslst[idx]] = [vlu] 45 | # convert it to the preprocessed space 46 | newlst = [] 47 | rawpath = '/media/data1/wentao/tianchi/luna16/lunaall/' 48 | preprocesspath = '/media/data1/wentao/tianchi/luna16/preprocess/lunaall/' 49 | resolution = np.array([1,1,1]) 50 | def process(pid): 51 | # print pid 52 | Mask,origin,spacing,isflip = load_itk_image(os.path.join(rawpath, pid+'.mhd')) 53 | spacing = np.load(os.path.join(preprocesspath, pid+'_spacing.npy')) 54 | extendbox = np.load(os.path.join(preprocesspath, pid+'_extendbox.npy')) 55 | origin = np.load(os.path.join(preprocesspath, pid+'_origin.npy')) 56 | if isflip: 57 | Mask = np.load(os.path.join(preprocesspath, pid+'_mask.npy')) 58 | retlst = [] 59 | for vlu in dct[pid]: 60 | pos = worldToVoxelCoord(vlu[:3][::-1], origin=origin, spacing=spacing) 61 | if isflip: 62 | pos[1:] = Mask.shape[1:3] - pos[1:] 63 | label = np.concatenate([pos, [vlu[3]/spacing[1]]]) 64 | label2 = np.expand_dims(np.copy(label), 1) 65 | # print label2.shape 66 | label2[:3] = 
label2[:3]*np.expand_dims(spacing,1)/np.expand_dims(resolution,1) 67 | label2[3] = label2[3]*spacing[1]/resolution[1] 68 | label2[:3] = label2[:3]-np.expand_dims(extendbox[:,0],1) 69 | label2 = label2[:4].T 70 | retlst.append([pid, label2[0,0], label2[0,1], label2[0,2], label2[0,3], vlu[-1]]) 71 | return retlst 72 | from multiprocessing import Pool 73 | p = Pool(30) 74 | newlst = p.map(process, dct.keys()) 75 | p.close() 76 | print(len(dct.keys()), len(newlst)) 77 | # for pid in dct.keys(): 78 | # print pid 79 | # Mask,origin,spacing,isflip = load_itk_image(os.path.join(rawpath, pid+'.mhd')) 80 | # spacing = np.load(os.path.join(preprocesspath, pid+'_spacing.npy')) 81 | # extendbox = np.load(os.path.join(preprocesspath, pid+'_extendbox.npy')) 82 | # origin = np.load(os.path.join(preprocesspath, pid+'_origin.npy')) 83 | # if isflip: 84 | # Mask = np.load(os.path.join(preprocesspath, pid+'_mask.npy')) 85 | # for vlu in dct[pid]: 86 | # pos = worldToVoxelCoord(vlu[:3][::-1], origin=origin, spacing=spacing) 87 | # if isflip: 88 | # pos[1:] = Mask.shape[1:3] - pos[1:] 89 | # label = np.concatenate([pos, [vlu[3]/spacing[1]]]) 90 | # label2 = np.expand_dims(np.copy(label), 1) 91 | # # print label2.shape 92 | # label2[:3] = label2[:3]*np.expand_dims(spacing,1)/np.expand_dims(resolution,1) 93 | # label2[3] = label2[3]*spacing[1]/resolution[1] 94 | # label2[:3] = label2[:3]-np.expand_dims(extendbox[:,0],1) 95 | # label2 = label2[:4].T 96 | # newlst.append([pid, label2[0,0], label2[0,1], label2[0,2], label2[0,3], vlu[-1]]) 97 | # save it to the csv 98 | savecsv = 'annotationdetclsconv_v3.csv' 99 | fid = open(savecsv, 'w') 100 | writer = csv.writer(fid) 101 | writer.writerow(['seriesuid', 'coordX', 'coordY', 'coordZ', 'diameter_mm', 'malignant']) 102 | for idx in xrange(len(newlst)): 103 | for subidx in xrange(len(newlst[idx])): 104 | writer.writerow(newlst[idx][subidx]) 105 | fid.close() -------------------------------------------------------------------------------- 
/models/layers.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from .net_sphere import * 7 | 8 | 9 | class ResCBAMLayer(nn.Module): 10 | """ 11 | CBAM+Res model 12 | """ 13 | 14 | def __init__(self, in_planes, feature_size): 15 | super(ResCBAMLayer, self).__init__() 16 | self.in_planes = in_planes 17 | self.feature_size = feature_size 18 | self.ch_AvgPool = nn.AvgPool3d(feature_size, feature_size) 19 | self.ch_MaxPool = nn.MaxPool3d(feature_size, feature_size) 20 | self.ch_Linear1 = nn.Linear(in_planes, in_planes // 4, bias=False) 21 | self.ch_Linear2 = nn.Linear(in_planes // 4, in_planes, bias=False) 22 | self.ch_Softmax = nn.Softmax(1) 23 | self.sp_Conv = nn.Conv3d(2, 1, kernel_size=3, stride=1, padding=1, bias=False) 24 | self.sp_Softmax = nn.Softmax(1) 25 | 26 | def forward(self, x): 27 | x_ch_avg_pool = self.ch_AvgPool(x).view(x.size(0), -1) 28 | x_ch_max_pool = self.ch_MaxPool(x).view(x.size(0), -1) 29 | # x_ch_avg_linear = self.ch_Linear2(self.ch_Linear1(x_ch_avg_pool)) 30 | a = self.ch_Linear1(x_ch_avg_pool) 31 | x_ch_avg_linear = self.ch_Linear2(a) 32 | 33 | x_ch_max_linear = self.ch_Linear2(self.ch_Linear1(x_ch_max_pool)) 34 | ch_out = (self.ch_Softmax(x_ch_avg_linear + x_ch_max_linear).view(x.size(0), self.in_planes, 1, 1, 1)) * x 35 | x_sp_max_pool = torch.max(ch_out, 1, keepdim=True)[0] 36 | x_sp_avg_pool = torch.sum(ch_out, 1, keepdim=True) / self.in_planes 37 | sp_conv1 = torch.cat([x_sp_max_pool, x_sp_avg_pool], dim=1) 38 | sp_out = self.sp_Conv(sp_conv1) 39 | sp_out = self.sp_Softmax(sp_out.view(x.size(0), -1)).view(x.size(0), 1, x.size(2), x.size(3), x.size(4)) 40 | out = sp_out * x + x 41 | return out 42 | 43 | 44 | def make_conv3d(in_channels: int, out_channels: int, kernel_size: typing.Union[int, tuple], stride: int, 45 | padding: int, dilation=1, groups=1, 46 | bias=True) -> nn.Module: 47 | """ 48 | 
produce a Conv3D with Batch Normalization and ReLU 49 | 50 | :param in_channels: num of in in channels 51 | :param out_channels: num of out channels 52 | :param kernel_size: size of kernel int or tuple 53 | :param stride: num of stride 54 | :param padding: num of padding 55 | :param bias: bias 56 | :param groups: groups 57 | :param dilation: dilation 58 | :return: conv3d module 59 | """ 60 | module = nn.Sequential( 61 | 62 | nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, 63 | groups=groups, 64 | bias=bias), 65 | nn.BatchNorm3d(out_channels), 66 | nn.ReLU()) 67 | return module 68 | 69 | 70 | def conv3d_same_size(in_channels, out_channels, kernel_size, stride=1, 71 | dilation=1, groups=1, 72 | bias=True): 73 | """ 74 | keep the w,h of inputs same as the outputs 75 | 76 | :param in_channels: num of in in channels 77 | :param out_channels: num of out channels 78 | :param kernel_size: size of kernel int or tuple 79 | :param stride: num of stride 80 | :param dilation: Spacing between kernel elements 81 | :param groups: Number of blocked connections from input channels to output channels. 
82 | :param bias: If True, adds a learnable bias to the output 83 | :return: conv3d 84 | """ 85 | padding = kernel_size // 2 86 | return make_conv3d(in_channels, out_channels, kernel_size, stride, 87 | padding, dilation, groups, 88 | bias) 89 | 90 | 91 | def conv3d_pooling(in_channels, kernel_size, stride=1, 92 | dilation=1, groups=1, 93 | bias=False): 94 | """ 95 | pooling with convolution 96 | 97 | :param in_channels: 98 | :param kernel_size: 99 | :param stride: 100 | :param dilation: 101 | :param groups: 102 | :param bias: 103 | :return: pooling-convolution 104 | """ 105 | 106 | padding = kernel_size // 2 107 | return make_conv3d(in_channels, in_channels, kernel_size, stride, 108 | padding, dilation, groups, 109 | bias) 110 | 111 | 112 | class ResidualBlock(nn.Module): 113 | """ 114 | a simple residual block 115 | """ 116 | 117 | def __init__(self, in_channels, out_channels): 118 | super(ResidualBlock, self).__init__() 119 | self.my_conv1 = make_conv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) 120 | self.my_conv2 = make_conv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 121 | self.conv3 = make_conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) 122 | 123 | def forward(self, inputs): 124 | out1 = self.conv3(inputs) 125 | out = self.my_conv1(inputs) 126 | out = self.my_conv2(out) 127 | out = out + out1 128 | return out 129 | -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from PIL import Image 3 | import os 4 | import os.path 5 | import errno 6 | import numpy as np 7 | import sys 8 | if sys.version_info[0] == 2: 9 | import cPickle as pickle 10 | else: 11 | import pickle 12 | import torch 13 | import torch.utils.data as data 14 | from torch.autograd import Variable 15 | # from .utils import download_url, check_integrity 16 | 17 | # 
npypath = '/media/data1/wentao/tianchi/luna16/cls/crop_v3/' 18 | class lunanod(data.Dataset): 19 | """Dataset of nodule volumes loaded from ``.npy`` files and reshaped to (N, 32, 32, 32). 20 | 21 | Args: 22 | npypath (string): Directory that holds the ``.npy`` files named in 23 | ``fnamelst``. 24 | fnamelst (list): File names joined onto ``npypath``; for the test split an 25 | entry that is not a ``str`` is used directly as an already-loaded array. 26 | labellst (list): Target for each entry of ``fnamelst``, in the same order. 27 | featlst (list): Extra per-sample values, returned unchanged as the third item. 28 | train (bool, optional): If True, builds the training split, otherwise the 29 | test split. 30 | transform (callable, optional): Applied to each volume in ``__getitem__``. 31 | target_transform (callable, optional): Applied to each target in ``__getitem__``. 32 | download (bool, optional): Not used by this class. 33 | 34 | """ 35 | def __init__(self, npypath, fnamelst, labellst, featlst, train=True, 36 | transform=None, target_transform=None, 37 | download=False): 38 | self.transform = transform 39 | self.target_transform = target_transform 40 | self.train = train # training set or test set 41 | # now load the .npy volumes into memory 42 | if self.train: 43 | self.train_data = [] 44 | self.train_labels = [] 45 | self.train_feat = featlst 46 | for label, fentry in zip(labellst, fnamelst): 47 | file = os.path.join(npypath, fentry) 48 | self.train_data.append(np.load(file)) 49 | self.train_labels.append(label) 50 | self.train_data = np.concatenate(self.train_data) 51 | print(len(fnamelst)) 52 | print(self.train_data.shape) 53 | self.train_data = self.train_data.reshape((len(fnamelst), 32, 32, 32)) 54 | # self.train_labels = np.asarray(self.train_labels) 55 | # self.train_data = self.train_data.transpose((0, 2, 3, 4, 1)) # convert to HWZC 56 | self.train_len = len(fnamelst) 57 | else: 58 | self.test_data = [] 59 | self.test_labels = [] 60 | self.test_feat = featlst 61 | for label, fentry in zip(labellst, fnamelst): 62 | # if fentry.shape[0] != 32 or fentry.shape[1] != 32 or fentry.shape[2] !=
32: 63 | # print(fentry.shape, type(fentry), type(fentry)!='str') 64 | if not isinstance(fentry,str): 65 | self.test_data.append(fentry) 66 | self.test_labels.append(label) 67 | # print('1') 68 | else: 69 | file = os.path.join(npypath, fentry) 70 | self.test_data.append(np.load(file)) 71 | self.test_labels.append(label) 72 | self.test_data = np.concatenate(self.test_data) 73 | # print(self.test_data.shape) 74 | self.test_data = self.test_data.reshape((len(fnamelst), 32, 32, 32)) 75 | # self.test_labels = np.asarray(self.test_labels) 76 | # self.test_data = self.test_data.transpose((0, 2, 3, 4, 1)) # convert to HWZC 77 | self.test_len = len(fnamelst) 78 | print(self.test_data.shape, len(self.test_labels), len(self.test_feat)) 79 | 80 | def __getitem__(self, index): 81 | """ 82 | Args: 83 | index (int): Index 84 | 85 | Returns: 86 | tuple: (image, target) where target is index of the target class. 87 | """ 88 | if self.train: 89 | img, target, feat = self.train_data[index], self.train_labels[index], self.train_feat[index] 90 | else: 91 | img, target, feat = self.test_data[index], self.test_labels[index], self.test_feat[index] 92 | # img = torch.from_numpy(img) 93 | # img = img.cuda(async = True) 94 | 95 | # doing this so that it is consistent with all other datasets 96 | # to return a PIL Image 97 | # print('1', img.shape, type(img)) 98 | # img = Image.fromarray(img) 99 | # print('2', img.size) 100 | 101 | if self.transform is not None: 102 | img = self.transform(img) 103 | 104 | if self.target_transform is not None: 105 | target = self.target_transform(target) 106 | # print(img.shape, target.shape, feat.shape) 107 | # print(target) 108 | 109 | return img, target, feat 110 | 111 | def __len__(self): 112 | if self.train: 113 | return self.train_len 114 | else: 115 | return self.test_len -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 
NAS-Lung 2 | 3 | > **3D Neural Architecture Search (NAS) for Pulmonary Nodules Classification** 4 | > 5 | > Hanliang Jiang, Fuhao Shen, Fei Gao\*, Weidong Han. *Learning Efficient, Explainable and Discriminative Representations for Pulmonary Nodules Classification*. Pattern Recognition, 113: 107825, 2021. 6 | > 7 | > ```latex 8 | > @article{Jiang2021naslung, 9 | > author = {Hanliang Jiang and Fuhao Shen and Fei Gao and Weidong Han}, 10 | > title = {Learning efficient, explainable and discriminative representations for pulmonary nodules classification}, 11 | > journal = {Pattern Recognition}, 12 | > volume = {113}, 13 | > pages = {107825}, 14 | > year = {2021}, 15 | > issn = {0031-3203}, 16 | > doi = {https://doi.org/10.1016/j.patcog.2021.107825}, 17 | > } 18 | > ``` 19 | > 20 | > [[Paper@PR]](https://www.sciencedirect.com/science/article/pii/S0031320321000121) [[Paper@arxiv]](https://arxiv.org/abs/2101.07429) [[Code@Github]](https://github.com/fei-hdu/NAS-Lung/issues/3) 21 | 22 | 23 | 24 | ## Architecture 25 | 26 | ![Architecture](imgs/architecture.png) 27 | 28 | ## Results 29 | 30 | ### NASLung 31 | 32 | | model | Accu. | Sens. | Spec. | F1 Score | para.(M) | 33 | | ------------------- | ----- | ----- | ----- | -------- | -------- | 34 | | Multi-crop CNN | 87.14 | - | - | - | - | 35 | | Nodule-level 2D CNN | 87.30 | 88.50 | 86.00 | 87.23 | - | 36 | | Vanilla 3D CNN | 87.40 | 89.40 | 85.20 | 87.25 | - | 37 | | DeepLung | 90.44 | 81.42 | - | - | 141.57 | 38 | | AE-DPN | 90.24 | 92.04 | 88.94 | 90.45 | 678.69 | 39 | | | | | | | | 40 | | **NASLung (ours)** | 90.77 | 85.37 | 95.04 | 89.29 | 16.84 | 41 | 42 | ### Searched 3D Networks 43 | 44 | | Model | Accu. | Sens. | Spec. | F1 Score | para. 
| 45 | | -------- | ----- | ----- | ----- | -------- | ----- | 46 | | Model-1 | 88.83 | 87.20 | 90.12 | 87.50 | 0.14 | 47 | | Model-2 | 88.42 | 84.38 | 91.46 | 86.67 | 2.61 | 48 | | Model-3 | 88.17 | 84.44 | 91.60 | 86.50 | 3.90 | 49 | | Model-4 | 88.13 | 83.20 | 92.28 | 86.30 | 2.54 | 50 | | Model-5 | 87.97 | 83.72 | 91.31 | 86.22 | 0.43 | 51 | | Model-6 | 87.77 | 83.67 | 91.00 | 86.03 | 0.22 | 52 | | Model-7 | 87.76 | 84.14 | 89.79 | 85.98 | 0.86 | 53 | | Model-8 | 88.00 | 82.43 | 92.69 | 85.97 | 4.02 | 54 | | Model-9 | 88.04 | 78.01 | 96.09 | 85.36 | 4.06 | 55 | | Model-10 | 87.22 | 82.70 | 90.92 | 85.32 | 0.24 | 56 | 57 | ## Prerequisites 58 | 59 | - Linux or similar environment 60 | - Python 3.7 61 | - Pytorch 0.4.1 62 | - NVIDIA GPU + CUDA CuDNN 63 | 64 | ## Getting Started 65 | 66 | ### Installation 67 | 68 | - Clone this repo: 69 | 70 | ```shell script 71 | git clone https://github.com/fei-hdu/NAS-Lung 72 | cd NAS-Lung 73 | ``` 74 | 75 | - Install PyTorch 0.4+ and torchvision from [Pytorch](http://pytorch.org) and other dependencies (e.g., visdom and dominate). 
You can install all the dependencies by 76 | 77 | ```shell script 78 | pip install -r requirements.txt 79 | ``` 80 | 81 | - Download Dataset [LIDC-IDRI](https://wiki.cancerimagingarchive.net/display/Public/LIDC-IDRI) 82 | 83 | ### Neural Architecture Search 84 | 85 | ```shell script 86 | python search_main.py --train_data_path {train_data_path} --test_data_path {test_data_path} --save_module_path {save_module_path} 87 | ``` 88 | 89 | ### Train/Test 90 | 91 | - Train a model 92 | 93 | ```shell script 94 | sh run_training.sh 95 | ``` 96 | 97 | - Test a model 98 | 99 | ```shell script 100 | python test.py --test_data_path {test_data_path} --preprocess_path {preprocess_path} --model_path {model_path} 101 | ``` 102 | 103 | ### DataSet 104 | 105 | - [LIDC-IDRI](https://wiki.cancerimagingarchive.net/display/Public/LIDC-IDRI) 106 | 107 | ### Model Result 108 | 109 | - Our final results can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1vUFi5tEfMcDcKqMbxuN3Tt44QwLcDZnA?usp=sharing) 110 | 111 | ### Training/Test Tips 112 | 113 | - Best practice for training and testing your models. 114 | - Feel free to ask any questions about **_coding_**. **Fuhao Shen, `1048532267sfh@gmail.com`** 115 | 116 | ## Acknowledgement 117 | 118 | - Our work/code is inspired by [Partial Order Pruning: for Best Speed/Accuracy Trade-off in Neural Architecture Search, CVPR 2019](https://github.com/lixincn2015/Partial-Order-Pruning). 119 | 120 | ## Selected References 121 | 122 | - S. G. Armato III et al., Data from **LIDC-IDRI**, The Cancer Imaging Archive. [LIDC-IDRI](http://doi.org/10.7937/K9/TCIA.2015.LO9QL9SX). 123 | - X. Li, Y. Zhou, Z. Pan, J. Feng, **Partial order pruning**: For best speed/accuracy trade-off in neural architecture search, in: The IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2019, pp. 9145–9153. 124 | - S. Woo, J. Park, J.-Y. Lee, I. So Kweon, **CBAM**: Convolutional block attention module, in: Proceedings of the European Conference on Computer Vision (ECCV), 2018, pp. 3–19. 125 | - W. Liu, Y. Wen, Z. Yu, M. Li, B. Raj, L.
Song, **Sphereface**: Deep hypersphere embedding for face recognition, in: The IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2017. 126 | - T. Elsken, J. H. Metzen, F. Hutter, **Neural architecture search**: A survey, Journal of Machine Learning Research 20 (55) (2019) 1–21. 127 | - W. Zhu, C. Liu, W. Fan, X. Xie, **Deeplung**: Deep 3d dual path nets for automated pulmonary nodule detection and classification, in: 2018 IEEE Winter Conference on Applications of Computer Vision (WACV), IEEE, 2018, pp. 673–681. 128 | -------------------------------------------------------------------------------- /models/cnn_res.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from .net_sphere import * 7 | 8 | 9 | debug = False 10 | 11 | 12 | class ResCBAMLayer(nn.Module): 13 | def __init__(self, in_planes, feature_size): 14 | super(ResCBAMLayer, self).__init__() 15 | self.in_planes = in_planes 16 | self.feature_size = feature_size 17 | self.ch_AvgPool = nn.AvgPool3d(feature_size, feature_size) 18 | self.ch_MaxPool = nn.MaxPool3d(feature_size, feature_size) 19 | self.ch_Linear1 = nn.Linear(in_planes, in_planes // 4, bias=False) 20 | self.ch_Linear2 = nn.Linear(in_planes // 4, in_planes, bias=False) 21 | self.ch_Softmax = nn.Softmax(1) 22 | self.sp_Conv = nn.Conv3d(2, 1, kernel_size=3, stride=1, padding=1, bias=False) 23 | self.sp_Softmax = nn.Softmax(1) 24 | self.sp_sigmoid = nn.Sigmoid() 25 | def forward(self, x): 26 | x_ch_avg_pool = self.ch_AvgPool(x).view(x.size(0), -1) 27 | x_ch_max_pool = self.ch_MaxPool(x).view(x.size(0), -1) 28 | # x_ch_avg_linear = self.ch_Linear2(self.ch_Linear1(x_ch_avg_pool)) 29 | a = self.ch_Linear1(x_ch_avg_pool) 30 | x_ch_avg_linear = self.ch_Linear2(a) 31 | 32 | x_ch_max_linear = self.ch_Linear2(self.ch_Linear1(x_ch_max_pool)) 33 | ch_out = (self.ch_Softmax(x_ch_avg_linear + 
x_ch_max_linear).view(x.size(0), self.in_planes, 1, 1, 1)) * x 34 | x_sp_max_pool = torch.max(ch_out, 1, keepdim=True)[0] 35 | x_sp_avg_pool = torch.sum(ch_out, 1, keepdim=True) / self.in_planes 36 | sp_conv1 = torch.cat([x_sp_max_pool, x_sp_avg_pool], dim=1) 37 | sp_out = self.sp_Conv(sp_conv1) 38 | sp_out = self.sp_sigmoid(sp_out.view(x.size(0), -1)).view(x.size(0), 1, x.size(2), x.size(3), x.size(4)) 39 | out = sp_out * x + x 40 | return out 41 | 42 | 43 | def make_conv3d(in_channels: int, out_channels: int, kernel_size: typing.Union[int, tuple], stride: int, 44 | padding: int, dilation=1, groups=1, 45 | bias=True) -> nn.Module: 46 | """ 47 | produce a Conv3D with Batch Normalization and ReLU 48 | 49 | :param in_channels: num of in in 50 | :param out_channels: num of out channels 51 | :param kernel_size: size of kernel int or tuple 52 | :param stride: num of stride 53 | :param padding: num of padding 54 | :param bias: bias 55 | :param groups: groups 56 | :param dilation: dilation 57 | :return: my conv3d module 58 | """ 59 | module = nn.Sequential( 60 | 61 | nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, 62 | groups=groups, 63 | bias=bias), 64 | nn.BatchNorm3d(out_channels), 65 | nn.ReLU()) 66 | return module 67 | 68 | 69 | def conv3d_same_size(in_channels, out_channels, kernel_size, stride=1, 70 | dilation=1, groups=1, 71 | bias=True): 72 | padding = kernel_size // 2 73 | return make_conv3d(in_channels, out_channels, kernel_size, stride, 74 | padding, dilation, groups, 75 | bias) 76 | 77 | 78 | def conv3d_pooling(in_channels, kernel_size, stride=1, 79 | dilation=1, groups=1, 80 | bias=False): 81 | padding = kernel_size // 2 82 | return make_conv3d(in_channels, in_channels, kernel_size, stride, 83 | padding, dilation, groups, 84 | bias) 85 | 86 | 87 | class ResidualBlock(nn.Module): 88 | """ 89 | a simple residual block 90 | """ 91 | 92 | def __init__(self, in_channels, out_channels): 93 | 
super(ResidualBlock, self).__init__() 94 | self.my_conv1 = make_conv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) 95 | self.my_conv2 = make_conv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 96 | self.conv3 = make_conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) 97 | 98 | def forward(self, inputs): 99 | out1 = self.conv3(inputs) 100 | out = self.my_conv1(inputs) 101 | out = self.my_conv2(out) 102 | out = out + out1 103 | return out 104 | 105 | 106 | class ConvRes(nn.Module): 107 | def __init__(self, config): 108 | super(ConvRes, self).__init__() 109 | self.conv1 = conv3d_same_size(in_channels=1, out_channels=4, kernel_size=3) 110 | self.conv2 = conv3d_same_size(in_channels=4, out_channels=4, kernel_size=3) 111 | self.config = config 112 | self.last_channel = 4 113 | self.first_cbam = ResCBAMLayer(4, 32) 114 | layers = [] 115 | i = 0 116 | for stage in config: 117 | i = i+1 118 | layers.append(conv3d_pooling(self.last_channel, kernel_size=3, stride=2)) 119 | for channel in stage: 120 | layers.append(ResidualBlock(self.last_channel, channel)) 121 | self.last_channel = channel 122 | layers.append(ResCBAMLayer(self.last_channel, 32//(2**i))) 123 | self.layers = nn.Sequential(*layers) 124 | self.avg_pooling = nn.AvgPool3d(kernel_size=4, stride=4) 125 | self.fc = AngleLinear(in_features=self.last_channel, out_features=2) 126 | 127 | def forward(self, inputs): 128 | if debug: 129 | print(inputs.size()) 130 | out = self.conv1(inputs) 131 | if debug: 132 | print(out.size()) 133 | out = self.conv2(out) 134 | if debug: 135 | print(out.size()) 136 | out = self.first_cbam(out) 137 | out = self.layers(out) 138 | if debug: 139 | print(out.size()) 140 | out = self.avg_pooling(out) 141 | out = out.view(out.size(0), -1) 142 | if debug: 143 | print(out.size()) 144 | out = self.fc(out) 145 | return out 146 | 147 | 148 | def test(): 149 | global debug 150 | debug = True 151 | net = ConvRes([[64, 64, 64], [128, 128, 256], 
[256, 256, 256, 512]]) 152 | inputs = torch.randn((1, 1, 32, 32, 32)) 153 | output = net(inputs) 154 | print(net.config) 155 | print(output) 156 | -------------------------------------------------------------------------------- /data/nodclsgbt.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import mahotas 4 | from mahotas.features.lbp import lbp 5 | 6 | CROPSIZE = 32 # 24#30#36 7 | print(CROPSIZE) 8 | pdframe = pd.read_csv('annotationdetclsconvfnl_v3.csv', 9 | names=['seriesuid', 'coordX', 'coordY', 'coordZ', 'diameter_mm', 'malignant']) 10 | srslst = pdframe['seriesuid'].tolist()[1:] 11 | crdxlst = pdframe['coordX'].tolist()[1:] 12 | crdylst = pdframe['coordY'].tolist()[1:] 13 | crdzlst = pdframe['coordZ'].tolist()[1:] 14 | dimlst = pdframe['diameter_mm'].tolist()[1:] 15 | mlglst = pdframe['malignant'].tolist()[1:] 16 | 17 | newlst = [] 18 | import csv 19 | 20 | fid = open('annotationdetclsconvfnl_v3.csv', 'r') 21 | # writer = csv.writer(fid) 22 | # writer.writerow(['seriesuid', 'coordX', 'coordY', 'coordZ', 'diameter_mm', 'malignant']) 23 | for i in range(len(srslst)): 24 | # writer.writerow([srslst[i] + '-' + str(i), crdxlst[i], crdylst[i], crdzlst[i], dimlst[i], mlglst[i]]) 25 | newlst.append([srslst[i] + '-' + str(i), crdxlst[i], crdylst[i], crdzlst[i], dimlst[i], mlglst[i]]) 26 | fid.close() 27 | 28 | preprocesspath = '/media/jehovah/Work/data/LUNA/propocess/all/' 29 | savepath = '/media/jehovah/Work/data/LUNA/cls/crop_v3/' 30 | import os 31 | import os.path 32 | 33 | if not os.path.exists(savepath): os.mkdir(savepath) 34 | for idx in range(len(newlst)): 35 | fname = newlst[idx][0] 36 | # if fname != '1.3.6.1.4.1.14519.5.2.1.6279.6001.119209873306155771318545953948-581': continue 37 | pid = fname.split('-')[0] 38 | crdx = int(float(newlst[idx][1])) 39 | crdy = int(float(newlst[idx][2])) 40 | crdz = int(float(newlst[idx][3])) 41 | dim = int(float(newlst[idx][4])) 42 | data 
= np.load(os.path.join(preprocesspath, pid + '_clean.npy')) 43 | bgx = max(0, crdx - CROPSIZE / 2) 44 | bgy = max(0, crdy - CROPSIZE / 2) 45 | bgz = max(0, crdz - CROPSIZE / 2) 46 | cropdata = np.ones((CROPSIZE, CROPSIZE, CROPSIZE)) * 170 47 | cropdatatmp = np.array(data[0, bgx:bgx + CROPSIZE, bgy:bgy + CROPSIZE, bgz:bgz + CROPSIZE]) 48 | cropdata[CROPSIZE / 2 - cropdatatmp.shape[0] / 2:CROPSIZE / 2 - cropdatatmp.shape[0] / 2 + cropdatatmp.shape[0], \ 49 | CROPSIZE / 2 - cropdatatmp.shape[1] / 2:CROPSIZE / 2 - cropdatatmp.shape[1] / 2 + cropdatatmp.shape[1], \ 50 | CROPSIZE / 2 - cropdatatmp.shape[2] / 2:CROPSIZE / 2 - cropdatatmp.shape[2] / 2 + cropdatatmp.shape[2]] = np.array( 51 | 2 - cropdatatmp) 52 | assert cropdata.shape[0] == CROPSIZE and cropdata.shape[1] == CROPSIZE and cropdata.shape[2] == CROPSIZE 53 | np.save(os.path.join(savepath, fname + '.npy'), cropdata) 54 | 55 | # train use gbt 56 | subset1path = '/media/jehovah/Work/data/LUNA/rowfile/subset1/' 57 | testfnamelst = [] 58 | for fname in os.listdir(subset1path): 59 | if fname.endswith('.mhd'): 60 | testfnamelst.append(fname[:-4]) 61 | ntest = 0 62 | for idx in range(len(newlst)): 63 | fname = newlst[idx][0] 64 | if fname.split('-')[0] in testfnamelst: ntest += 1 65 | print('ntest', ntest, 'ntrain', len(newlst) - ntest) 66 | 67 | traindata = np.zeros((len(newlst) - ntest, CROPSIZE * CROPSIZE * CROPSIZE)) 68 | trainlabel = np.zeros((len(newlst) - ntest,)) 69 | testdata = np.zeros((ntest, CROPSIZE * CROPSIZE * CROPSIZE)) 70 | testlabel = np.zeros((ntest,)) 71 | 72 | trainidx = testidx = 0 73 | for idx in range(len(newlst)): 74 | fname = newlst[idx][0] 75 | # print fname 76 | data = np.load(os.path.join(savepath, fname + '.npy')) 77 | # print data.shape 78 | bgx = data.shape[0] / 2 - CROPSIZE / 2 79 | bgy = data.shape[1] / 2 - CROPSIZE / 2 80 | bgz = data.shape[2] / 2 - CROPSIZE / 2 81 | data = np.array(data[bgx:bgx + CROPSIZE, bgy:bgy + CROPSIZE, bgz:bgz + CROPSIZE]) 82 | if fname.split('-')[0] in 
testfnamelst: 83 | testdata[testidx, :] = np.reshape(data, (-1,)) / 255 84 | # testdata[testidx, -4] = newlst[idx][1] 85 | # testdata[testidx, -3] = newlst[idx][2] 86 | # testdata[testidx, -2] = newlst[idx][3] 87 | # testdata[testidx, -1] = newlst[idx][4] 88 | testlabel[testidx] = newlst[idx][-1] 89 | testidx += 1 90 | else: 91 | traindata[trainidx, :] = np.reshape(data, (-1,)) / 255 92 | # traindata[trainidx, -4] = newlst[idx][1] 93 | # traindata[trainidx, -3] = newlst[idx][2] 94 | # traindata[trainidx, -2] = newlst[idx][3] 95 | # traindata[trainidx, -1] = newlst[idx][4] 96 | trainlabel[trainidx] = newlst[idx][-1] 97 | trainidx += 1 98 | maxtraindata1 = max(traindata[:, -1]) 99 | # traindata[:, -1] = np.array(traindata[:, -1] / maxtraindata1) 100 | # maxtraindata2 = max(traindata[:, -2]) 101 | # traindata[:, -2] = np.array(traindata[:, -2] / maxtraindata2) 102 | # maxtraindata3 = max(traindata[:, -3]) 103 | # traindata[:, -3] = np.array(traindata[:, -3] / maxtraindata3) 104 | # maxtraindata4 = max(traindata[:, -4]) 105 | # traindata[:, -4] = np.array(traindata[:, -4] / maxtraindata4) 106 | # testdata[:, -1] = np.array(testdata[:, -1] / maxtraindata1) 107 | # testdata[:, -2] = np.array(testdata[:, -2] / maxtraindata2) 108 | # testdata[:, -3] = np.array(testdata[:, -3] / maxtraindata3) 109 | # testdata[:, -4] = np.array(testdata[:, -4] / maxtraindata4) 110 | from sklearn.ensemble import GradientBoostingClassifier as gbt 111 | 112 | 113 | def gbtfunc(dep): 114 | m = gbt(max_depth=dep, random_state=0) 115 | m.fit(traindata, trainlabel) 116 | predtrain = m.predict(traindata) 117 | predtest = m.predict_proba(testdata) 118 | # print predtest.shape, predtest[1,:] 119 | return np.sum(predtrain == trainlabel) / float(traindata.shape[0]), \ 120 | np.mean((predtest[:, 1] > 0.5).astype(int) == testlabel), predtest # / float(testdata.shape[0]), 121 | 122 | 123 | # trainacc, testacc, predtest = gbtfunc(3) 124 | # print trainacc, testacc 125 | # np.save('pixradiustest.npy', 
predtest[:,1]) 126 | from multiprocessing import Pool 127 | 128 | p = Pool(30) 129 | acclst = p.map(gbtfunc, range(1, 9, 1)) # 3,4,1))#5,1))#1,9,1)) 130 | for acc in acclst: 131 | print("{0:.4f}".format(acc[0]), "{0:.4f}".format(acc[1])) 132 | p.close() 133 | # for dep in xrange(1,9,1): 134 | # m = gbt(max_depth=dep) 135 | # m.fit(traindata, trainlabel) 136 | # print dep, 'trainacc', np.sum(m.predict(traindata) == trainlabel) / float(traindata.shape[0]) 137 | # print dep, 'testacc', np.sum(m.predict(testdata) == testlabel) / float(testdata.shape[0]) 138 | -------------------------------------------------------------------------------- /searchspace/res_search_space.py: -------------------------------------------------------------------------------- 1 | """the width of a block to be no narrower than its preceding block""" 2 | 3 | import numpy as np 4 | from .search_space_utils import * 5 | import random 6 | from data_utils import * 7 | from models.cnn_res import ConvRes 8 | 9 | 10 | class ResSearchSpace: 11 | """ 12 | search class 13 | """ 14 | def __init__(self, channel_range, max_depth, min_depth, trained_data_path, test_data_path, fold, batch_size, 15 | logging, input_shape, use_gpu, gpu_id, criterion, lr, save_module_path, num_works, epoch): 16 | self.lr = lr 17 | self.epoch = epoch 18 | self.num_works = num_works 19 | self.fold = fold 20 | # self.sub = sub 21 | self.save_module_path = save_module_path 22 | self.max_depth = max_depth 23 | self.min_depth = min_depth 24 | self.criterion = criterion 25 | self.logging = logging 26 | self.input_shape = input_shape 27 | self.use_gpu = use_gpu 28 | self.gpu_id = gpu_id 29 | self.max_depth = max_depth 30 | self.channel_range = channel_range 31 | self.trained_module_acc_lat = np.empty((0, 3)) 32 | # self.trained_module_acc_lat = [module_config,acc,lat] 33 | self.pruned_module = [] 34 | # initialize all architecture set 35 | self.untrained_module = get_all_search_space(min_len=min_depth, max_len=max_depth, 
channel_range=channel_range) 36 | # load data set 37 | self.train_loader, self.test_loader = load_data(trained_data_path, test_data_path, self.fold, batch_size, 38 | self.num_works) 39 | self.trained_yw = [] 40 | self.save_module_path = save_module_path 41 | self.trained_yw_and_module = [] 42 | if not os.path.isdir(self.save_module_path): 43 | os.mkdir(self.save_module_path) 44 | 45 | def random_generate(self): 46 | """ 47 | get untrained model config 48 | 49 | :return: model config 50 | """ 51 | count = self.untrained_module.__len__() 52 | index = random.randint(0, count) 53 | return self.untrained_module[index] 54 | 55 | def main_method(self): 56 | """ 57 | search main method 58 | """ 59 | B = [] 60 | stable_time = 0 61 | repeat_time = 0 62 | while True: 63 | # get untrained model 64 | config = self.random_generate() 65 | # config = [[512, 512, 512,512,512], [512, 512, 512,512,512], [512, 512, 512, 512,512]] 66 | self.untrained_module.remove(config) 67 | # train model 68 | net = ConvRes(config) 69 | net_lat = get_module_lat(net, input_shape=self.input_shape) 70 | net = net_to_cuda(net, use_gpu=self.use_gpu, gpu_ids=self.gpu_id) 71 | optimizer = optim.Adam(net.parameters(), lr=self.lr, betas=(0.5, 0.999)) 72 | acc = get_acc(net, self.use_gpu, self.train_loader, self.test_loader, optimizer, self.criterion, 73 | self.logging, self.lr, config, self.epoch) 74 | print(f'module:{config}\nacc:{acc} lat:{net_lat}') 75 | self.logging.info(f'config:{config}\nacc:{acc} lat:{net_lat}') 76 | del net 77 | self.trained_module_acc_lat = np.append(self.trained_module_acc_lat, [[config, acc, net_lat]], axis=0) 78 | # prune model 79 | for module in self.trained_module_acc_lat: 80 | yw = get_yw(self.trained_module_acc_lat, module) 81 | if len(yw) != 0: 82 | module_config = module[0] 83 | yw_config = yw[0] 84 | yw_lat = yw[2] 85 | if [yw_config, module_config] not in self.trained_yw_and_module: 86 | print(f'yw:{yw_config}\nmodule:{module_config}') 87 | 
self.logging.info(f'yw:{yw_config}\nmodule:{module_config}') 88 | narrower_module = get_narrower_module(self.channel_range, module_config) 89 | print('found narrower_module:' + str(narrower_module.__len__())) 90 | self.logging.info( 91 | 'found narrower_module:' + str(narrower_module.__len__()) 92 | ) 93 | shallower_module = get_shallower_module(self.min_depth, [module_config], shallower_module=[]) 94 | print('found shallower_module:' + str(shallower_module.__len__())) 95 | self.logging.info( 96 | 'found shallower_module:' + str(shallower_module.__len__()) 97 | ) 98 | pruned_narrower_module = 0 99 | for i in narrower_module: 100 | if i in self.untrained_module: 101 | lat = get_latency(ConvRes(i), input_size=self.input_shape) 102 | if lat > yw_lat: 103 | self.pruned_module.append(i) 104 | self.untrained_module.remove(i) 105 | pruned_narrower_module = pruned_narrower_module + 1 106 | print('pruned_narrower_module:' + str(pruned_narrower_module)) 107 | self.logging.info( 108 | 'pruned_narrower_module:' + str(pruned_narrower_module) 109 | ) 110 | pruned_shallower_module = 0 111 | for i in shallower_module: 112 | if i in self.untrained_module: 113 | lat = get_latency(ConvRes(i), input_size=self.input_shape) 114 | if lat > yw_lat: 115 | pruned_shallower_module = pruned_shallower_module + 1 116 | self.pruned_module.append(i) 117 | self.untrained_module.remove(i) 118 | print('pruned_shallower_module:' + str(pruned_shallower_module)) 119 | self.logging.info( 120 | 'pruned_shallower_module:' + str(pruned_shallower_module) 121 | ) 122 | self.trained_yw_and_module.append([yw_config, module_config]) 123 | else: 124 | print(f'{module[0]}yw not found') 125 | 126 | B1 = get_excellent_module(self.trained_module_acc_lat) 127 | if repeat_time % 20 == 0: 128 | np.save(self.save_module_path + '/' + str(repeat_time) + 'trained', self.trained_module_acc_lat) 129 | np.save(self.save_module_path + '/' + str(repeat_time) + 'excellent', B1) 130 | if np.array_equal(B1, B): 131 | stable_time = 
stable_time + 1 132 | B = copy.deepcopy(B1) 133 | else: 134 | stable_time = 0 135 | B = copy.deepcopy(B1) 136 | if stable_time > 30: 137 | break 138 | repeat_time = repeat_time + 1 139 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | import transforms as transforms 5 | import pandas as pd 6 | from dataloader import lunanod 7 | from torch.autograd import Variable 8 | from itertools import combinations, permutations 9 | import logging 10 | import pandas as pd 11 | import argparse 12 | 13 | 14 | def load_data(test_data_path, preprocess_path, fold, batch_size, num_workers): 15 | test_data_path = '/data/xxx/LUNA/rowfile/subset' 16 | crop_size = 32 17 | black_list = [] 18 | 19 | preprocess_path = '/data/xxx/LUNA/cls/crop_v3' 20 | pix_value, npix = 0, 0 21 | for file_name in os.listdir(preprocess_path): 22 | if file_name.endswith('.npy'): 23 | if file_name[:-4] in black_list: 24 | continue 25 | data = np.load(os.path.join(preprocess_path, file_name)) 26 | pix_value += np.sum(data) 27 | npix += np.prod(data.shape) 28 | pix_mean = pix_value / float(npix) 29 | pix_value = 0 30 | for file_name in os.listdir(preprocess_path): 31 | if file_name.endswith('.npy'): 32 | if file_name[:-4] in black_list: continue 33 | data = np.load(os.path.join(preprocess_path, file_name)) - pix_mean 34 | pix_value += np.sum(data * data) 35 | pix_std = np.sqrt(pix_value / float(npix)) 36 | print(pix_mean, pix_std) 37 | transform_train = transforms.Compose([ 38 | # transforms.RandomScale(range(28, 38)), 39 | transforms.RandomCrop(32, padding=4), 40 | transforms.RandomHorizontalFlip(), 41 | transforms.RandomYFlip(), 42 | transforms.RandomZFlip(), 43 | transforms.ZeroOut(4), 44 | transforms.ToTensor(), 45 | transforms.Normalize((pix_mean), (pix_std)), # need to cal mean and std, revise norm func 46 | ]) 47 | 48 | 
transform_test = transforms.Compose([ 49 | transforms.ToTensor(), 50 | transforms.Normalize((pix_mean), (pix_std)), 51 | ]) 52 | 53 | # load data list 54 | test_file_name_list = [] 55 | test_label_list = [] 56 | test_feat_list = [] 57 | 58 | data_frame = pd.read_csv('./data/annotationdetclsconvfnl_v3.csv', 59 | names=['seriesuid', 'coordX', 'coordY', 'coordZ', 'diameter_mm', 'malignant']) 60 | 61 | all_list = data_frame['seriesuid'].tolist()[1:] 62 | label_list = data_frame['malignant'].tolist()[1:] 63 | crdx_list = data_frame['coordX'].tolist()[1:] 64 | crdy_list = data_frame['coordY'].tolist()[1:] 65 | crdz_list = data_frame['coordZ'].tolist()[1:] 66 | dim_list = data_frame['diameter_mm'].tolist()[1:] 67 | # test id 68 | test_id_list = [] 69 | for file_name in os.listdir(test_data_path + str(fold) + '/'): 70 | 71 | if file_name.endswith('.mhd'): 72 | test_id_list.append(file_name[:-4]) 73 | mxx = mxy = mxz = mxd = 0 74 | for srsid, label, x, y, z, d in zip(all_list, label_list, crdx_list, crdy_list, crdz_list, dim_list): 75 | mxx = max(abs(float(x)), mxx) 76 | mxy = max(abs(float(y)), mxy) 77 | mxz = max(abs(float(z)), mxz) 78 | mxd = max(abs(float(d)), mxd) 79 | if srsid in black_list: 80 | continue 81 | # crop raw pixel as feature 82 | data = np.load(os.path.join(preprocess_path, srsid + '.npy')) 83 | bgx = int(data.shape[0] / 2 - crop_size / 2) 84 | bgy = int(data.shape[1] / 2 - crop_size / 2) 85 | bgz = int(data.shape[2] / 2 - crop_size / 2) 86 | data = np.array(data[bgx:bgx + crop_size, bgy:bgy + crop_size, bgz:bgz + crop_size]) 87 | y, x, z = np.ogrid[-crop_size / 2:crop_size / 2, -crop_size / 2:crop_size / 2, -crop_size / 2:crop_size / 2] 88 | mask = abs(y ** 3 + x ** 3 + z ** 3) <= abs(float(d)) ** 3 89 | feat = np.zeros((crop_size, crop_size, crop_size), dtype=float) 90 | feat[mask] = 1 91 | if srsid.split('-')[0] in test_id_list: 92 | test_file_name_list.append(srsid + '.npy') 93 | test_label_list.append(int(label)) 94 | test_feat_list.append(feat) 95 | 
for idx in range(len(test_feat_list)): 96 | test_feat_list[idx][-1] /= mxd 97 | 98 | test_set = lunanod(preprocess_path, test_file_name_list, test_label_list, test_feat_list, train=False, 99 | download=True, 100 | transform=transform_test) 101 | test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers) 102 | return test_loader 103 | 104 | 105 | def load_module(path): 106 | checkpoint = torch.load(path) 107 | net = checkpoint['net'] 108 | net.cuda() 109 | return net 110 | 111 | 112 | def get_targets(test_loader): 113 | target_list = np.empty(shape=0) 114 | for batch_idx, (inputs, targets, feat) in enumerate(test_loader): 115 | target_list = np.append(target_list, targets) 116 | target_list = target_list.astype(int) 117 | return target_list 118 | 119 | 120 | def get_permutations(model_list, count, top_count): 121 | result = [] 122 | for i in permutations(model_list, count): 123 | result.append(list(i)) 124 | if result.__len__() >= top_count: 125 | return result 126 | return result 127 | 128 | 129 | def test_module(module, test_loader): 130 | module.eval() 131 | total = 0 132 | correct = 0 133 | TP = 0 134 | TN = 0 135 | FN = 0 136 | FP = 0 137 | for batch_idx, (inputs, targets, feat) in enumerate(test_loader): 138 | inputs, targets = inputs.cuda(), targets.cuda() 139 | 140 | inputs, targets = Variable(inputs, requires_grad=False), Variable(targets) 141 | outputs = module(inputs) 142 | prediction = 0 143 | if not isinstance(outputs, tuple): 144 | _, prediction = torch.max(outputs.data, 1) 145 | # print(prediction.shape) 146 | # print('1') 147 | else: 148 | _, prediction = torch.max(outputs[0].data, 1) 149 | # print('2') 150 | TP += ((prediction == 1) & (targets.data == 1)).cpu().sum() 151 | TN += ((prediction == 0) & (targets.data == 0)).cpu().sum() 152 | FN += ((prediction == 0) & (targets.data == 1)).cpu().sum() 153 | FP += ((prediction == 1) & (targets.data == 0)).cpu().sum() 154 | tpr = 100. 
* TP.data.item() / (TP.data.item() + FN.data.item()) 155 | fpr = 100. * FP.data.item() / (FP.data.item() + TN.data.item()) 156 | acc = (TP.data.item()+TN.data.item()) / (TP.data.item()+TN.data.item()+FN.data.item()+FP.data.item()) 157 | print(f'acc:{acc}') 158 | print('tpr ' + str(tpr) + ' fpr ' + str(fpr)) 159 | 160 | 161 | def get_predicted(result_array): 162 | positive_array = result_array == 1 163 | negative_array = result_array == 0 164 | positive_count = np.sum(positive_array, axis=0) 165 | negative_count = np.sum(negative_array, axis=0) 166 | predicted = positive_count > negative_count 167 | return predicted.astype(int) 168 | 169 | 170 | parser = argparse.ArgumentParser(description='test') 171 | parser.add_argument('--model_path', type=str, 172 | default='/data/fuhao/PartialOrderPrunning/[4,4,[4, 8, 16, 16, 16], [32, 128], [128]]/checkpoint-5/ckpt.t7', 173 | help='ckpt.t7') 174 | parser.add_argument('--fold', type=int, default=5, help='1-5') 175 | parser.add_argument('--batch_size', type=int, default=8) 176 | parser.add_argument('--num_workers', type=int, default=24) 177 | parser.add_argument('--test_data_path', type=str, default='/data/xxx/LUNA/rowfile/subset') 178 | parser.add_argument('--preprocess_path', type=str, default='/data/xxx/LUNA/cls/crop_v3') 179 | args = parser.parse_args() 180 | 181 | if __name__ == '__main__': 182 | fold = args.fold 183 | batch_size = args.batch_size 184 | num_workers = args.num_workers 185 | test_data_path = args.test_data_path 186 | preprocess_path = args.preprocess_path 187 | net = load_module(args.model_path) 188 | test_data_loader = load_data(test_data_path, preprocess_path, fold, batch_size, num_workers) 189 | test_module(net, test_data_loader) 190 | -------------------------------------------------------------------------------- /searchspace/search_space_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import copy 3 | import time 4 | import torch 5 | 
import copy
import itertools


def _unique_prefixes(config_array, depth):
    """
    Return the distinct length-``depth`` prefixes of the rows of ``config_array``.

    :param config_array: 2-D numpy array, one channel configuration per row
    :param depth: prefix length to keep
    :return: sorted list of unique prefixes (each a plain list)
    """
    return remove_repeated_element(config_array[:, :depth].tolist())


def get_all_search_space(min_len, max_len, channel_range):
    """
    get all configs of model

    Every channel sequence (each channel >= the previous one, see
    ``get_search_space``) whose length lies in [min_len, max_len] is split in
    every possible way into three non-empty stages.

    :param min_len: min of the depth of model
    :param max_len: max of the depth of model
    :param channel_range: list, the range of channel
    :return: all search space, a list of [stage1, stage2, stage3] channel lists
    """
    all_search_space = []
    # configs of the maximal length; every shorter config is one of their prefixes
    max_array = np.array(get_search_space(max_len, channel_range))
    for depth in range(min_len, max_len + 1):
        for channels in _unique_prefixes(max_array, depth):
            for first_split in range(1, depth - 1):
                for second_split in range(first_split + 1, depth):
                    all_search_space.append(
                        [channels[:first_split], channels[first_split:second_split], channels[second_split:]])
    return all_search_space


def get_limited_search_space(min_len, max_len, channel_range):
    """
    get all limited configs of model, where stage depths are bounded relative
    to the model depth ``d``:

    * first stage length in [d // 4, d - d // 2]
    * second and third stage length at least d // 4

    NOTE(review): for d < 4 the lower bound d // 4 is 0, so empty stages can
    be produced.

    :param min_len: min of the depth of model
    :param max_len: max of the depth of model
    :param channel_range: list, the range of channel
    :return: all search space, a list of [stage1, stage2, stage3] channel lists
    """
    all_search_space = []
    # Length-max_len configs suffice: any valid shorter config is a prefix of one
    # (extend it by repeating its last channel). The previous version enumerated
    # length max_len + 1 configs, which yields the same prefixes at extra cost.
    max_array = np.array(get_search_space(max_len, channel_range))
    for depth in range(min_len, max_len + 1):
        quarter = depth // 4
        for channels in _unique_prefixes(max_array, depth):
            for first_split in range(quarter, depth - depth // 2 + 1):
                for second_split in range(first_split + quarter, depth - quarter + 1):
                    all_search_space.append(
                        [channels[:first_split], channels[first_split:second_split], channels[second_split:]])
    return all_search_space


def get_search_space(max_len, channel_range, search_space=None, now=0):
    """
    Get all configuration combinations.

    Builds every sequence of ``max_len`` channels taken from ``channel_range``
    in which each channel is >= the previous one.

    :param max_len: max of the depth of model
    :param channel_range: list, the range of channel
    :param search_space: partial configs (lists) to extend; only used when ``now`` > 0
    :param now: number of channels already present in each entry of ``search_space``
    :return: list of configs, each a list of channels
    """
    if now == 0:
        result = [[channel] for channel in channel_range]
    else:
        seeds = [] if search_space is None else search_space
        result = [config + [channel]
                  for config in seeds
                  for channel in get_larger_channel(channel_range, config[-1])]
    # one extension step per remaining position (formerly one recursive call each)
    for _ in range(now + 1, max_len):
        result = [config + [channel]
                  for config in result
                  for channel in get_larger_channel(channel_range, config[-1])]
    return result


def get_larger_channel(channel_range, channel_num):
    """
    get channels which are greater than or equal to the input channel

    :param channel_range: list, channel range
    :param channel_num: input channel
    :return: list, channels from channel_range that are >= channel_num (original order kept)
    """
    return [channel for channel in channel_range if channel >= channel_num]


def get_smaller_channel(channel, channel_range):
    """
    get channels which are strictly smaller than the input channel

    :param channel: input channel
    :param channel_range: list, channel range
    :return: list, channels from channel_range that are < channel (original order kept)
    """
    return [candidate for candidate in channel_range if candidate < channel]


def get_shallower_module(min_len, module_config, shallower_module=None):
    """
    get module configs which are shallower than module_config

    Layers are removed one at a time, level by level, from every stage that has
    more than one channel. This stops once the configs produced at a level have
    at most ``min_len`` channels in total, or when nothing can be removed.

    :param min_len: min depth of model
    :param module_config: list of input module configs (each a list of stages)
    :param shallower_module: optional list that results are appended to
    :return: list, module configs which are shallower than module_config
    """
    if shallower_module is None:
        # A fresh list per call. The former ``shallower_module=[]`` default was
        # shared between calls, so results leaked from one call into the next.
        shallower_module = []
    current_configs = module_config
    while True:
        new_module_config = []
        for config in current_configs:
            for m, stage in enumerate(config):
                # only list-like stages with more than one channel can lose a layer
                if isinstance(stage, int) or len(stage) <= 1:
                    continue
                for n in range(len(stage)):
                    tmp = copy.deepcopy(config)
                    del tmp[m][n]
                    new_module_config.append(tmp)
        new_module_config = remove_repeated_element(new_module_config)
        shallower_module.extend(new_module_config)
        if not shallower_module:
            return []
        if not new_module_config:
            # Nothing left to remove. The recursive version looped forever here
            # whenever the remaining depth was still above min_len.
            return shallower_module
        if get_element_count(shallower_module[-1]) <= min_len:
            return shallower_module
        current_configs = new_module_config


def remove_repeated_element(repeated_list):
    """
    Remove duplicate elements

    NOTE: ``repeated_list`` is sorted in place as a side effect.

    :param repeated_list: input list
    :return: new sorted list without duplicate elements
    """
    repeated_list.sort()
    return [item for k, item in enumerate(repeated_list)
            if k == 0 or item != repeated_list[k - 1]]


def get_element_count(the_list):
    """
    get depth of model

    :param the_list: input model config (a list of stages)
    :return: depth of model, i.e. the total number of channels over all stages
    """
    return sum(len(stage) for stage in the_list)


def flat_list(the_list):
    """
    flatten list by one level

    :param the_list: list of iterables
    :return: flatten list
    """
    return list(itertools.chain.from_iterable(the_list))


def get_narrower_module(channel_range, module_config):
    """
    get module configs which are narrower than module_config

    A config is narrower when it has the same stage lengths and, at every
    position, a channel <= the corresponding channel of ``module_config``.
    The result is always split into three stages using the lengths of the
    first two stages of ``module_config``.

    :param channel_range: channel range
    :param module_config: input model config (a list of stages)
    :return: list, sorted unique module configs narrower than module_config
    """
    stage_lens = [len(stage) for stage in module_config]
    config_array = np.array(get_search_space(get_element_count(module_config), channel_range))
    upper_bound = np.array(flat_list(module_config))
    # keep the rows that are <= the module config at every position
    narrower_config = config_array[np.all(config_array <= upper_bound, axis=1)].tolist()
    first_end = stage_lens[0]
    second_end = stage_lens[0] + stage_lens[1]
    result = [[config[:first_end], config[first_end:second_end], config[second_end:]]
              for config in narrower_config]
    return remove_repeated_element(result)


def get_latency(module, input_size):
    """
    get the latency of module for one forward pass on a random CPU input

    :param module: callable torch module
    :param input_size: shape of the random input tensor
    :return: latency in seconds
    """
    module_input = torch.randn(input_size)
    # no_grad: measure plain inference, without the cost of building an autograd graph
    with torch.no_grad():
        # perf_counter is monotonic and high resolution, unlike time.time
        start = time.perf_counter()
        module(module_input)
        elapsed = time.perf_counter() - start
    return elapsed


def get_excellent_module(trained_module):
    """
    get models with less latency and higher acc (the Pareto front)

    A module is kept when no other module has both strictly higher accuracy and
    strictly lower latency.

    :param trained_module: array of rows [config, acc, latency]
    :return: excellent modules, an array with 3 columns
    """
    excellent_module = np.empty(shape=(0, 3))
    acc_all = trained_module[:, 1]
    lat_all = trained_module[:, 2]
    for module in trained_module:
        # a row does not dominate `module` if its acc is not higher or its latency is not lower
        not_dominating = (acc_all <= module[1]) | (lat_all >= module[2])
        if not_dominating.all():
            excellent_module = np.append(excellent_module, [module], axis=0)
    return excellent_module

# NOTE(review): the next line is the start of the following file in this repository dump; kept verbatim, commented out so this block parses.
# -------------------------------------------------------------------------------- /random_forest.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | import transforms as transforms 5 | import pandas as pd 6 | from
dataloader import lunanod 7 | from torch.autograd import Variable 8 | from itertools import combinations, permutations 9 | import logging 10 | import pandas as pd 11 | 12 | 13 | def load_data(fold, batch_size, num_workers): 14 | test_data_path = '/data/xxx/LUNA/rowfile/subset' 15 | crop_size = 32 16 | black_list = [] 17 | 18 | preprocess_path = '/data/xxx/LUNA/cls/crop_v3' 19 | pix_value, npix = 0, 0 20 | for file_name in os.listdir(preprocess_path): 21 | if file_name.endswith('.npy'): 22 | if file_name[:-4] in black_list: 23 | continue 24 | data = np.load(os.path.join(preprocess_path, file_name)) 25 | pix_value += np.sum(data) 26 | npix += np.prod(data.shape) 27 | pix_mean = pix_value / float(npix) 28 | pix_value = 0 29 | for file_name in os.listdir(preprocess_path): 30 | if file_name.endswith('.npy'): 31 | if file_name[:-4] in black_list: continue 32 | data = np.load(os.path.join(preprocess_path, file_name)) - pix_mean 33 | pix_value += np.sum(data * data) 34 | pix_std = np.sqrt(pix_value / float(npix)) 35 | print(pix_mean, pix_std) 36 | transform_train = transforms.Compose([ 37 | # transforms.RandomScale(range(28, 38)), 38 | transforms.RandomCrop(32, padding=4), 39 | transforms.RandomHorizontalFlip(), 40 | transforms.RandomYFlip(), 41 | transforms.RandomZFlip(), 42 | transforms.ZeroOut(4), 43 | transforms.ToTensor(), 44 | transforms.Normalize((pix_mean), (pix_std)), # need to cal mean and std, revise norm func 45 | ]) 46 | 47 | transform_test = transforms.Compose([ 48 | transforms.ToTensor(), 49 | transforms.Normalize((pix_mean), (pix_std)), 50 | ]) 51 | 52 | # load data list 53 | test_file_name_list = [] 54 | test_label_list = [] 55 | test_feat_list = [] 56 | 57 | data_frame = pd.read_csv('./data/annotationdetclsconvfnl_v3.csv', 58 | names=['seriesuid', 'coordX', 'coordY', 'coordZ', 'diameter_mm', 'malignant']) 59 | 60 | all_list = data_frame['seriesuid'].tolist()[1:] 61 | label_list = data_frame['malignant'].tolist()[1:] 62 | crdx_list = 
data_frame['coordX'].tolist()[1:] 63 | crdy_list = data_frame['coordY'].tolist()[1:] 64 | crdz_list = data_frame['coordZ'].tolist()[1:] 65 | dim_list = data_frame['diameter_mm'].tolist()[1:] 66 | # test id 67 | test_id_list = [] 68 | for file_name in os.listdir(test_data_path + str(fold) + '/'): 69 | 70 | if file_name.endswith('.mhd'): 71 | test_id_list.append(file_name[:-4]) 72 | mxx = mxy = mxz = mxd = 0 73 | for srsid, label, x, y, z, d in zip(all_list, label_list, crdx_list, crdy_list, crdz_list, dim_list): 74 | mxx = max(abs(float(x)), mxx) 75 | mxy = max(abs(float(y)), mxy) 76 | mxz = max(abs(float(z)), mxz) 77 | mxd = max(abs(float(d)), mxd) 78 | if srsid in black_list: 79 | continue 80 | # crop raw pixel as feature 81 | data = np.load(os.path.join(preprocess_path, srsid + '.npy')) 82 | bgx = int(data.shape[0] / 2 - crop_size / 2) 83 | bgy = int(data.shape[1] / 2 - crop_size / 2) 84 | bgz = int(data.shape[2] / 2 - crop_size / 2) 85 | data = np.array(data[bgx:bgx + crop_size, bgy:bgy + crop_size, bgz:bgz + crop_size]) 86 | y, x, z = np.ogrid[-crop_size / 2:crop_size / 2, -crop_size / 2:crop_size / 2, -crop_size / 2:crop_size / 2] 87 | mask = abs(y ** 3 + x ** 3 + z ** 3) <= abs(float(d)) ** 3 88 | feat = np.zeros((crop_size, crop_size, crop_size), dtype=float) 89 | feat[mask] = 1 90 | if srsid.split('-')[0] in test_id_list: 91 | test_file_name_list.append(srsid + '.npy') 92 | test_label_list.append(int(label)) 93 | test_feat_list.append(feat) 94 | for idx in range(len(test_feat_list)): 95 | test_feat_list[idx][-1] /= mxd 96 | 97 | test_set = lunanod(preprocess_path, test_file_name_list, test_label_list, test_feat_list, train=False, 98 | download=True, 99 | transform=transform_test) 100 | test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers) 101 | return test_loader 102 | 103 | 104 | def load_module(module_config, set_num): 105 | path = 
f'/data/fuhao/PartialOrderPrunning/{module_config}/checkpoint-{set_num}/ckpt.t7' 106 | checkpoint = torch.load(path) 107 | net = checkpoint['net'] 108 | net.cuda() 109 | return net 110 | 111 | 112 | def get_targets(test_loader): 113 | target_list = np.empty(shape=0) 114 | for batch_idx, (inputs, targets, feat) in enumerate(test_loader): 115 | target_list = np.append(target_list, targets) 116 | target_list = target_list.astype(int) 117 | return target_list 118 | 119 | 120 | def get_permutations(model_list, count, top_count): 121 | result = [] 122 | for i in permutations(model_list, count): 123 | result.append(list(i)) 124 | if result.__len__() >= top_count: 125 | return result 126 | return result 127 | 128 | 129 | def test_module(module_config, set_num, test_loader): 130 | module = load_module(module_config, set_num) 131 | module.eval() 132 | result = np.empty(shape=0) 133 | for batch_idx, (inputs, targets, feat) in enumerate(test_loader): 134 | inputs, targets = inputs.cuda(), targets.cuda() 135 | 136 | inputs, targets = Variable(inputs, requires_grad=False), Variable(targets) 137 | outputs = module(inputs) 138 | if not isinstance(outputs, tuple): 139 | _, predicted = torch.max(outputs.data, 1) 140 | else: 141 | _, predicted = torch.max(outputs[0].data, 1) 142 | result = np.append(result, predicted) 143 | return result 144 | 145 | 146 | def get_predicted(result_array): 147 | positive_array = result_array == 1 148 | negative_array = result_array == 0 149 | positive_count = np.sum(positive_array, axis=0) 150 | negative_count = np.sum(negative_array, axis=0) 151 | predicted = positive_count > negative_count 152 | return predicted.astype(int) 153 | 154 | 155 | if __name__ == '__main__': 156 | run_result = np.empty(shape=(0, 20)) 157 | top_count = 20 158 | module_list = np.load('data/model.npy') 159 | module_list = list(filter(lambda x: '[32,64,[' in x, module_list)) 160 | logging.basicConfig(filename='modelfusion_huge_log', level=logging.INFO) 161 | save_excel = 
'modelfusion_huge' 162 | for i in range(3, 20): 163 | if i % 2 == 1: 164 | permutations_result = get_permutations(module_list[:i + 4], i, top_count) 165 | num = 0 166 | for modules in permutations_result: 167 | num += 1 168 | logging.info(f'model_count={i}') 169 | print(f'model_count={i}') 170 | logging.info(f'num:{num}') 171 | print(f'num:{num}') 172 | logging.info(modules) 173 | print(modules) 174 | line = [] 175 | for fold in range(6): 176 | test_loader = load_data(fold, 8, 20) 177 | targets = get_targets(test_loader) 178 | length = targets.shape[0] 179 | all_result = np.empty(shape=(0, length)) 180 | for module_config in modules: 181 | result = test_module(module_config, fold, test_loader) 182 | all_result = np.append(all_result, [result], axis=0) 183 | predicted = get_predicted(all_result) 184 | TP = np.sum((predicted == 1) & (targets == 1)) 185 | TN = np.sum((predicted == 0) & (targets == 0)) 186 | FN = np.sum((predicted == 0) & (targets == 1)) 187 | FP = np.sum((predicted == 1) & (targets == 0)) 188 | tpr = 100. * TP / (TP + FN) 189 | fpr = 100. * FP / (FP + TN) 190 | acc = 100. 
* np.sum(predicted == targets) / length 191 | line.append(acc) 192 | line.append(tpr) 193 | line.append(fpr) 194 | logging.info(f'set={fold}') 195 | print(f'set={fold}') 196 | logging.info(f'acc={acc}') 197 | print(f'acc={acc}') 198 | logging.info(f'tpr={tpr} fpr={fpr}') 199 | print(f'tpr={tpr} fpr={fpr}') 200 | run_result = np.append(run_result, np.array(line)) 201 | np.save('run_result_huge', run_result) 202 | df = pd.DataFrame(data=run_result, 203 | columns=['module_count', 'module_config', 204 | 'fold-0-acc', 'fold-0-tpr', 'fold-0-fpr', 205 | 'fold-1-acc', 'fold-1-tpr', 'fold-1-fpr', 206 | 'fold-2-acc', 'fold-2-tpr', 'fold-2-fpr', 207 | 'fold-3-acc', 'fold-3-tpr', 'fold-3-fpr', 208 | 'fold-4-acc', 'fold-4-tpr', 'fold-4-fpr', 209 | 'fold-5-acc', 'fold-5-tpr', 'fold-5-fpr'], 210 | index=None) 211 | df.to_excel(save_excel) 212 | -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | import torch.nn.functional as F 7 | import torch.backends.cudnn as cudnn 8 | import pandas as pd 9 | import transforms as transforms 10 | from dataloader import lunanod 11 | import os 12 | import argparse 13 | import time 14 | from models.cnn_res import * 15 | # from utils import progress_bar 16 | from torch.autograd import Variable 17 | import logging 18 | import numpy as np 19 | import copy 20 | 21 | 22 | def load_data(trained_data_path, test_data_path, fold, batch_size, num_workers): 23 | crop_size = 32 24 | black_list = [] 25 | 26 | preprocess_path = trained_data_path 27 | pix_value, npix = 0, 0 28 | for file_name in os.listdir(preprocess_path): 29 | if file_name.endswith('.npy'): 30 | if file_name[:-4] in black_list: 31 | continue 32 | data = np.load(os.path.join(preprocess_path, file_name)) 33 | pix_value += np.sum(data) 34 | npix += 
np.prod(data.shape) 35 | pix_mean = pix_value / float(npix) 36 | pix_value = 0 37 | for file_name in os.listdir(preprocess_path): 38 | if file_name.endswith('.npy'): 39 | if file_name[:-4] in black_list: continue 40 | data = np.load(os.path.join(preprocess_path, file_name)) - pix_mean 41 | pix_value += np.sum(data * data) 42 | pix_std = np.sqrt(pix_value / float(npix)) 43 | print(pix_mean, pix_std) 44 | transform_train = transforms.Compose([ 45 | # transforms.RandomScale(range(28, 38)), 46 | transforms.RandomCrop(32, padding=4), 47 | transforms.RandomHorizontalFlip(), 48 | transforms.RandomYFlip(), 49 | transforms.RandomZFlip(), 50 | transforms.ZeroOut(4), 51 | transforms.ToTensor(), 52 | transforms.Normalize((pix_mean), (pix_std)), # need to cal mean and std, revise norm func 53 | ]) 54 | 55 | transform_test = transforms.Compose([ 56 | transforms.ToTensor(), 57 | transforms.Normalize((pix_mean), (pix_std)), 58 | ]) 59 | 60 | # load data list 61 | train_file_name_list = [] 62 | train_label_list = [] 63 | train_feat_list = [] 64 | test_file_name_list = [] 65 | test_label_list = [] 66 | test_feat_list = [] 67 | 68 | data_frame = pd.read_csv('./data/annotationdetclsconvfnl_v3.csv', 69 | names=['seriesuid', 'coordX', 'coordY', 'coordZ', 'diameter_mm', 'malignant']) 70 | 71 | all_list = data_frame['seriesuid'].tolist()[1:] 72 | label_list = data_frame['malignant'].tolist()[1:] 73 | crdx_list = data_frame['coordX'].tolist()[1:] 74 | crdy_list = data_frame['coordY'].tolist()[1:] 75 | crdz_list = data_frame['coordZ'].tolist()[1:] 76 | dim_list = data_frame['diameter_mm'].tolist()[1:] 77 | # test id 78 | test_id_list = [] 79 | for file_name in os.listdir(test_data_path + str(fold) + '/'): 80 | 81 | if file_name.endswith('.mhd'): 82 | test_id_list.append(file_name[:-4]) 83 | mxx = mxy = mxz = mxd = 0 84 | for srsid, label, x, y, z, d in zip(all_list, label_list, crdx_list, crdy_list, crdz_list, dim_list): 85 | mxx = max(abs(float(x)), mxx) 86 | mxy = max(abs(float(y)), mxy) 
87 | mxz = max(abs(float(z)), mxz) 88 | mxd = max(abs(float(d)), mxd) 89 | if srsid in black_list: 90 | continue 91 | # crop raw pixel as feature 92 | data = np.load(os.path.join(preprocess_path, srsid + '.npy')) 93 | bgx = int(data.shape[0] / 2 - crop_size / 2) 94 | bgy = int(data.shape[1] / 2 - crop_size / 2) 95 | bgz = int(data.shape[2] / 2 - crop_size / 2) 96 | data = np.array(data[bgx:bgx + crop_size, bgy:bgy + crop_size, bgz:bgz + crop_size]) 97 | y, x, z = np.ogrid[-crop_size / 2:crop_size / 2, -crop_size / 2:crop_size / 2, -crop_size / 2:crop_size / 2] 98 | mask = abs(y ** 3 + x ** 3 + z ** 3) <= abs(float(d)) ** 3 99 | feat = np.zeros((crop_size, crop_size, crop_size), dtype=float) 100 | feat[mask] = 1 101 | if srsid.split('-')[0] in test_id_list: 102 | test_file_name_list.append(srsid + '.npy') 103 | test_label_list.append(int(label)) 104 | test_feat_list.append(feat) 105 | else: 106 | train_file_name_list.append(srsid + '.npy') 107 | train_label_list.append(int(label)) 108 | train_feat_list.append(feat) 109 | for idx in range(len(train_feat_list)): 110 | train_feat_list[idx][-1] /= mxd 111 | for idx in range(len(test_feat_list)): 112 | test_feat_list[idx][-1] /= mxd 113 | train_set = lunanod(preprocess_path, train_file_name_list, train_label_list, train_feat_list, train=True, 114 | download=True, 115 | transform=transform_train) 116 | train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers) 117 | 118 | test_set = lunanod(preprocess_path, test_file_name_list, test_label_list, test_feat_list, train=False, 119 | download=True, 120 | transform=transform_test) 121 | test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers) 122 | return train_loader, test_loader 123 | 124 | 125 | def train_module(net, use_cuda, train_loader, optimizer, criterion, log, lr, config, epoch): 126 | net.train() 127 | 128 | for i in range(epoch): 129 | correct = 0 130 | 
total = 0 131 | for batch_idx, (inputs, targets, feat) in enumerate(train_loader): 132 | if use_cuda: 133 | inputs, targets = inputs.cuda(), targets.cuda() 134 | 135 | optimizer.zero_grad() 136 | inputs, targets = Variable(inputs), Variable(targets) 137 | outputs = net(inputs) 138 | loss = criterion(outputs, targets) 139 | 140 | loss.backward() 141 | optimizer.step() 142 | _, predicted = torch.max(outputs.data, 1) 143 | total += targets.size(0) 144 | correct += predicted.eq(targets.data).cpu().sum() 145 | # progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' 146 | print( 147 | 'ep ' + str(i) + str(config) + ' tracc ' + str(correct.data.item() / float(total)) + ' lr ' + str(lr)) 148 | 149 | log.info( 150 | 'ep ' + str(i) + str(config) + ' tracc ' + str(correct.data.item() / float(total)) + ' lr ' + str(lr)) 151 | return net 152 | 153 | 154 | def my_test_module(net, use_cuda, test_loader, criterion, log): 155 | epoch_start_time = time.time() 156 | # global best_acc 157 | # global best_acc_gbt 158 | net.eval() 159 | test_loss = 0 160 | correct = 0 161 | total = 0 162 | TP = FP = FN = TN = 0 163 | for batch_idx, (inputs, targets, feat) in enumerate(test_loader): 164 | if use_cuda: 165 | inputs, targets = inputs.cuda(), targets.cuda() 166 | 167 | inputs, targets = Variable(inputs, requires_grad=False), Variable(targets) 168 | outputs = net(inputs) 169 | 170 | loss = criterion(outputs, targets) 171 | test_loss += loss.data.item() 172 | _, predicted = torch.max(outputs.data, 1) 173 | total += targets.size(0) 174 | correct += predicted.eq(targets.data).cpu().sum() 175 | TP += ((predicted == 1) & (targets.data == 1)).cpu().sum() 176 | TN += ((predicted == 0) & (targets.data == 0)).cpu().sum() 177 | FN += ((predicted == 0) & (targets.data == 1)).cpu().sum() 178 | FP += ((predicted == 1) & (targets.data == 0)).cpu().sum() 179 | 180 | acc = 100. * correct.data.item() / total 181 | 182 | tpr = 100. 
* TP.data.item() / (TP.data.item() + FN.data.item()) 183 | fpr = 100. * FP.data.item() / (FP.data.item() + TN.data.item()) 184 | 185 | print('teacc ' + str(acc)) 186 | print('tpr ' + str(tpr) + ' fpr ' + str(fpr)) 187 | print('Time Taken: %d sec' % (time.time() - epoch_start_time)) 188 | log.info( 189 | 'teacc ' + str(acc)) 190 | log.info( 191 | 'tpr ' + str(tpr) + ' fpr ' + str(fpr)) 192 | return acc 193 | 194 | 195 | def get_acc(net, use_cuda, train_loader, test_loader, optimizer, criterion, log, lr, config, epoch): 196 | net = train_module(net, use_cuda, train_loader, optimizer, criterion, log, lr, config, epoch) 197 | acc = my_test_module(net, use_cuda, test_loader, criterion, log) 198 | return acc 199 | 200 | 201 | def net_to_cuda(net, use_gpu, gpu_ids): 202 | if use_gpu: 203 | net.cuda() 204 | if gpu_ids == 'all': 205 | device_ids = range(torch.cuda.device_count()) 206 | else: 207 | device_ids = list(map(int, list(filter(str.isdigit, gpu_ids)))) 208 | 209 | print('gpu use' + str(device_ids)) 210 | net = torch.nn.DataParallel(net, device_ids=device_ids) 211 | return net 212 | 213 | 214 | def get_module_lat(module, input_shape): 215 | x = torch.randn(input_shape) 216 | start = time.time() 217 | y = module(x) 218 | print(y) 219 | end = time.time() 220 | return end - start 221 | 222 | 223 | def get_yw(modules, module): 224 | module_acc = module[1] 225 | tmp_modules = copy.deepcopy(modules) 226 | # original_module_index = np.where(tmp_modules[:, 0] == [module[0]])[0] 227 | # tmp_modules = np.delete(tmp_modules, original_module_index, 0) 228 | tmp_modules = tmp_modules.tolist() 229 | tmp_modules.remove(module.tolist()) 230 | tmp_modules = np.array(tmp_modules) 231 | if tmp_modules.size > 0: 232 | modules_acc = tmp_modules[:, 1] 233 | tmp = np.where(modules_acc >= module_acc)[0] 234 | better_modules = tmp_modules[tmp] 235 | if better_modules.size > 0: 236 | min_lat_index = np.argmin(better_modules[:, 2]) 237 | return better_modules[min_lat_index].tolist() 238 | 
return [] 239 | 240 | 241 | # a = np.array([[[[1, 2, 3], [12, 3], [1, 2]], 0.2, 0.3], [[[1, 2, 3], [12, 3], [1, 2]], 0.3, 0.3]]) 242 | # # a = np.array([[[[1, 2, 3], [12, 3], [1, 2]], 0.2, 0.3]]) 243 | # get_yw(a, a[1]) 244 | # # for i in a: 245 | # # print(i) 246 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | '''Train CIFAR10 with PyTorch.''' 2 | from __future__ import print_function 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | import torch.nn.functional as F 8 | import torch.backends.cudnn as cudnn 9 | 10 | import transforms as transforms 11 | from dataloader import lunanod 12 | import os 13 | import argparse 14 | import time 15 | from models.cnn_res import * 16 | # from utils import progress_bar 17 | from torch.autograd import Variable 18 | import logging 19 | import numpy as np 20 | import ast 21 | 22 | parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training') 23 | parser.add_argument('--lr', default=0.0002, type=float, help='learning rate') 24 | parser.add_argument('--batch_size', default=1, type=int, help='batch size ') 25 | parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint') 26 | parser.add_argument('--savemodel', type=str, default='', help='resume from checkpoint model') 27 | parser.add_argument("--gpuids", type=str, default='all', help='use which gpu') 28 | 29 | parser.add_argument('--num_epochs', type=int, default=700) 30 | parser.add_argument('--num_epochs_decay', type=int, default=70) 31 | 32 | parser.add_argument('--num_workers', type=int, default=24) 33 | 34 | parser.add_argument('--beta1', type=float, default=0.5) # momentum1 in Adam 35 | parser.add_argument('--beta2', type=float, default=0.999) # momentum2 in Adam 36 | parser.add_argument('--lamb', type=float, default=1, help="lambda for loss2") 37 | parser.add_argument('--fold', 
type=int, default=5, help="fold") 38 | 39 | args = parser.parse_args() 40 | 41 | CROPSIZE = 32 42 | gbtdepth = 1 43 | fold = args.fold 44 | blklst = [] # ['1.3.6.1.4.1.14519.5.2.1.6279.6001.121993590721161347818774929286-388', \ 45 | # '1.3.6.1.4.1.14519.5.2.1.6279.6001.121993590721161347818774929286-389', \ 46 | # '1.3.6.1.4.1.14519.5.2.1.6279.6001.132817748896065918417924920957-660'] 47 | logging.basicConfig(filename='log-' + str(fold), level=logging.INFO) 48 | 49 | use_cuda = torch.cuda.is_available() 50 | best_acc = 0 # best test accuracy 51 | best_acc_gbt = 0 52 | start_epoch = 0 # start from epoch 0 or last checkpoint epoch 53 | # Cal mean std 54 | # preprocesspath = '/media/data1/wentao/tianchi/luna16/cls/crop_v3/' 55 | preprocesspath = '/data/xxx/LUNA/cls/crop_v3/' 56 | # preprocesspath = '/media/jehovah/Work/data/LUNA/cls/crop_v3/' 57 | pixvlu, npix = 0, 0 58 | for fname in os.listdir(preprocesspath): 59 | # print(fname) 60 | if fname.endswith('.npy'): 61 | if fname[:-4] in blklst: continue 62 | data = np.load(os.path.join(preprocesspath, fname)) 63 | pixvlu += np.sum(data) 64 | # print("data.shape = " + str(data.shape)) 65 | npix += np.prod(data.shape) 66 | pixmean = pixvlu / float(npix) 67 | pixvlu = 0 68 | for fname in os.listdir(preprocesspath): 69 | if fname.endswith('.npy'): 70 | if fname[:-4] in blklst: continue 71 | data = np.load(os.path.join(preprocesspath, fname)) - pixmean 72 | pixvlu += np.sum(data * data) 73 | pixstd = np.sqrt(pixvlu / float(npix)) 74 | # pixstd /= 255 75 | print(pixmean, pixstd) 76 | logging.info('mean ' + str(pixmean) + ' std ' + str(pixstd)) 77 | # Datatransforms 78 | logging.info('==> Preparing data..') # Random Crop, Zero out, x z flip, scale, 79 | transform_train = transforms.Compose([ 80 | # transforms.RandomScale(range(28, 38)), 81 | transforms.RandomCrop(32, padding=4), 82 | transforms.RandomHorizontalFlip(), 83 | transforms.RandomYFlip(), 84 | transforms.RandomZFlip(), 85 | transforms.ZeroOut(4), 86 | 
transforms.ToTensor(), 87 | transforms.Normalize((pixmean), (pixstd)), # need to cal mean and std, revise norm func 88 | ]) 89 | 90 | transform_test = transforms.Compose([ 91 | transforms.ToTensor(), 92 | transforms.Normalize((pixmean), (pixstd)), 93 | ]) 94 | 95 | # load data list 96 | trfnamelst = [] 97 | trlabellst = [] 98 | trfeatlst = [] 99 | tefnamelst = [] 100 | telabellst = [] 101 | tefeatlst = [] 102 | import pandas as pd 103 | 104 | dataframe = pd.read_csv('./data/annotationdetclsconvfnl_v3.csv', 105 | names=['seriesuid', 'coordX', 'coordY', 'coordZ', 'diameter_mm', 'malignant']) 106 | 107 | alllst = dataframe['seriesuid'].tolist()[1:] 108 | labellst = dataframe['malignant'].tolist()[1:] 109 | crdxlst = dataframe['coordX'].tolist()[1:] 110 | crdylst = dataframe['coordY'].tolist()[1:] 111 | crdzlst = dataframe['coordZ'].tolist()[1:] 112 | dimlst = dataframe['diameter_mm'].tolist()[1:] 113 | # test id 114 | teidlst = [] 115 | for fname in os.listdir('/data/xxx/LUNA/rowfile/subset' + str(fold) + '/'): 116 | # for fname in os.listdir('/media/jehovah/Work/data/LUNA/rowfile/subset' + str(fold) + '/'): 117 | 118 | if fname.endswith('.mhd'): 119 | teidlst.append(fname[:-4]) 120 | mxx = mxy = mxz = mxd = 0 121 | for srsid, label, x, y, z, d in zip(alllst, labellst, crdxlst, crdylst, crdzlst, dimlst): 122 | mxx = max(abs(float(x)), mxx) 123 | mxy = max(abs(float(y)), mxy) 124 | mxz = max(abs(float(z)), mxz) 125 | mxd = max(abs(float(d)), mxd) 126 | if srsid in blklst: continue 127 | # crop raw pixel as feature 128 | data = np.load(os.path.join(preprocesspath, srsid + '.npy')) 129 | bgx = int(data.shape[0] / 2 - CROPSIZE / 2) 130 | bgy = int(data.shape[1] / 2 - CROPSIZE / 2) 131 | bgz = int(data.shape[2] / 2 - CROPSIZE / 2) 132 | data = np.array(data[bgx:bgx + CROPSIZE, bgy:bgy + CROPSIZE, bgz:bgz + CROPSIZE]) 133 | # feat = np.hstack((np.reshape(data, (-1,)) / 255, float(d))) 134 | y, x, z = np.ogrid[-CROPSIZE / 2:CROPSIZE / 2, -CROPSIZE / 2:CROPSIZE / 2, -CROPSIZE 
/ 2:CROPSIZE / 2] 135 | mask = abs(y ** 3 + x ** 3 + z ** 3) <= abs(float(d)) ** 3 136 | feat = np.zeros((CROPSIZE, CROPSIZE, CROPSIZE), dtype=float) 137 | feat[mask] = 1 138 | # print(feat.shape) 139 | if srsid.split('-')[0] in teidlst: 140 | tefnamelst.append(srsid + '.npy') 141 | telabellst.append(int(label)) 142 | tefeatlst.append(feat) 143 | else: 144 | trfnamelst.append(srsid + '.npy') 145 | trlabellst.append(int(label)) 146 | trfeatlst.append(feat) 147 | for idx in range(len(trfeatlst)): 148 | # trfeatlst[idx][0] /= mxx 149 | # trfeatlst[idx][1] /= mxy 150 | # trfeatlst[idx][2] /= mxz 151 | trfeatlst[idx][-1] /= mxd 152 | for idx in range(len(tefeatlst)): 153 | # tefeatlst[idx][0] /= mxx 154 | # tefeatlst[idx][1] /= mxy 155 | # tefeatlst[idx][2] /= mxz 156 | tefeatlst[idx][-1] /= mxd 157 | trainset = lunanod(preprocesspath, trfnamelst, trlabellst, trfeatlst, train=True, download=True, 158 | transform=transform_train) 159 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=20) 160 | 161 | testset = lunanod(preprocesspath, tefnamelst, telabellst, tefeatlst, train=False, download=True, 162 | transform=transform_test) 163 | testloader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size, shuffle=False, num_workers=20) 164 | savemodelpath = './checkpoint-' + str(fold) + '/' 165 | # Model 166 | print(args.resume) 167 | if args.resume: 168 | print('==> Resuming from checkpoint..') 169 | print(args.savemodel) 170 | if args.savemodel == '': 171 | logging.info('==> Resuming from checkpoint..') 172 | assert os.path.isdir(savemodelpath), 'Error: no checkpoint directory found!' 173 | checkpoint = torch.load(savemodelpath + 'ckpt.t7') 174 | 175 | else: 176 | logging.info('==> Resuming from checkpoint..') 177 | assert os.path.isdir(savemodelpath), 'Error: no checkpoint directory found!' 
178 | checkpoint = torch.load(args.savemodel) 179 | net = checkpoint['net'] 180 | best_acc = checkpoint['acc'] 181 | start_epoch = checkpoint['epoch'] 182 | print(savemodelpath + " load success") 183 | print(start_epoch) 184 | else: 185 | logging.info('==> Building model..') 186 | logging.info('args.savemodel : ' + args.savemodel) 187 | net = ConvRes([[64, 64, 64], [128, 128, 256], [256, 256, 256, 512]]) 188 | if args.savemodel != "": 189 | # args.savemodel = '/home/xxx/DeepLung-master/nodcls/checkpoint-5/ckpt.t7' 190 | checkpoint = torch.load(args.savemodel) 191 | finenet = checkpoint 192 | Low_rankmodel_dic = net.state_dict() 193 | finenet = {k: v for k, v in finenet.items() if k in Low_rankmodel_dic} 194 | Low_rankmodel_dic.update(finenet) 195 | net.load_state_dict(Low_rankmodel_dic) 196 | print("net_loaded") 197 | 198 | lr = args.lr 199 | 200 | 201 | def get_lr(epoch): 202 | global lr 203 | if (epoch + 1) > (args.num_epochs - args.num_epochs_decay): 204 | lr -= (lr / float(args.num_epochs_decay)) 205 | for param_group in optimizer.param_groups: 206 | param_group['lr'] = lr 207 | print('Decay learning rate to lr: {}.'.format(lr)) 208 | 209 | 210 | if use_cuda: 211 | net.cuda() 212 | if args.gpuids == 'all': 213 | device_ids = range(torch.cuda.device_count()) 214 | else: 215 | device_ids = map(int, list(filter(str.isdigit, args.gpuids))) 216 | 217 | print('gpu use' + str(device_ids)) 218 | net = torch.nn.DataParallel(net, device_ids=device_ids) 219 | cudnn.benchmark = False # True 220 | 221 | criterion = nn.CrossEntropyLoss() 222 | optimizer = optim.Adam(net.parameters(), lr=args.lr, betas=(args.beta1, args.beta2)) 223 | 224 | 225 | # L2Loss = torch.nn.MSELoss() 226 | 227 | # Training 228 | def train(epoch): 229 | logging.info('\nEpoch: ' + str(epoch)) 230 | net.train() 231 | get_lr(epoch) 232 | train_loss = 0 233 | correct = 0 234 | total = 0 235 | 236 | for batch_idx, (inputs, targets, feat) in enumerate(trainloader): 237 | if use_cuda: 238 | inputs, targets = 
inputs.cuda(), targets.cuda() 239 | 240 | optimizer.zero_grad() 241 | inputs, targets = Variable(inputs), Variable(targets) 242 | outputs = net(inputs) 243 | loss = criterion(outputs, targets) 244 | 245 | loss.backward() 246 | optimizer.step() 247 | train_loss += loss.data.item() 248 | _, predicted = torch.max(outputs.data, 1) 249 | total += targets.size(0) 250 | correct += predicted.eq(targets.data).cpu().sum() 251 | # progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' 252 | 253 | print('ep ' + str(epoch) + ' tracc ' + str(correct.data.item() / float(total)) + ' lr ' + str(lr)) 254 | 255 | logging.info( 256 | 'ep ' + str(epoch) + ' tracc ' + str(correct.data.item() / float(total)) + ' lr ' + str(lr)) 257 | 258 | 259 | def test(epoch): 260 | epoch_start_time = time.time() 261 | global best_acc 262 | global best_acc_gbt 263 | net.eval() 264 | test_loss = 0 265 | correct = 0 266 | total = 0 267 | TP = FP = FN = TN = 0 268 | for batch_idx, (inputs, targets, feat) in enumerate(testloader): 269 | if use_cuda: 270 | inputs, targets = inputs.cuda(), targets.cuda() 271 | 272 | inputs, targets = Variable(inputs, requires_grad=False), Variable(targets) 273 | outputs = net(inputs) 274 | 275 | loss = criterion(outputs, targets) 276 | test_loss += loss.data.item() 277 | _, predicted = torch.max(outputs.data, 1) 278 | total += targets.size(0) 279 | correct += predicted.eq(targets.data).cpu().sum() 280 | TP += ((predicted == 1) & (targets.data == 1)).cpu().sum() 281 | TN += ((predicted == 0) & (targets.data == 0)).cpu().sum() 282 | FN += ((predicted == 0) & (targets.data == 1)).cpu().sum() 283 | FP += ((predicted == 1) & (targets.data == 0)).cpu().sum() 284 | 285 | # Save checkpoint. 286 | acc = 100. 
* correct.data.item() / total 287 | if acc > best_acc: 288 | logging.info('Saving..') 289 | state = { 290 | 'net': net.module if use_cuda else net, 291 | 'acc': acc, 292 | 'epoch': epoch, 293 | } 294 | if not os.path.isdir(savemodelpath): 295 | os.mkdir(savemodelpath) 296 | torch.save(state, savemodelpath + 'ckpt.t7') 297 | best_acc = acc 298 | logging.info('Saving..') 299 | state = { 300 | 'net': net.module if use_cuda else net, 301 | 'acc': acc, 302 | 'epoch': epoch, 303 | } 304 | if not os.path.isdir(savemodelpath): 305 | os.mkdir(savemodelpath) 306 | if epoch % 50 == 0: 307 | torch.save(state, savemodelpath + 'ckpt' + str(epoch) + '.t7') 308 | # best_acc = acc 309 | tpr = 100. * TP.data.item() / (TP.data.item() + FN.data.item()) 310 | fpr = 100. * FP.data.item() / (FP.data.item() + TN.data.item()) 311 | 312 | print('teacc ' + str(acc) + ' bestacc ' + str(best_acc)) 313 | print('tpr ' + str(tpr) + ' fpr ' + str(fpr)) 314 | print('Time Taken: %d sec' % (time.time() - epoch_start_time)) 315 | logging.info( 316 | 'teacc ' + str(acc) + ' bestacc ' + str(best_acc)) 317 | logging.info( 318 | 'tpr ' + str(tpr) + ' fpr ' + str(fpr)) 319 | 320 | 321 | if __name__ == '__main__': 322 | for epoch in range(start_epoch + 1, start_epoch + args.num_epochs + 1): # 200): 323 | train(epoch) 324 | test(epoch) 325 | -------------------------------------------------------------------------------- /models/net_sphere.py: -------------------------------------------------------------------------------- 1 | ''' 2 | https://github.com/clcarwin/sphereface_pytorch/blob/master/net_sphere.py 3 | ''' 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch.autograd import Variable 8 | import torch.nn.functional as F 9 | from torch.nn import Parameter 10 | import dill 11 | import math 12 | 13 | 14 | def myphi(x, m): 15 | x = x * m 16 | return 1 - x ** 2 / math.factorial(2) + x ** 4 / math.factorial(4) - x ** 6 / math.factorial(6) + \ 17 | x ** 8 / math.factorial(8) - x ** 9 / 
math.factorial(9) 18 | 19 | 20 | import math 21 | import torch 22 | from torch import nn 23 | # from scipy.special import binom 24 | import scipy.special as special 25 | 26 | class LSoftmaxLinear(nn.Module): 27 | 28 | def __init__(self, input_features, output_features, margin=4, device='cuda'): 29 | super().__init__() 30 | self.input_dim = input_features # number of input feature i.e. output of the last fc layer 31 | self.output_dim = output_features # number of output = class numbers 32 | self.margin = margin # m 33 | self.beta = 100 34 | self.beta_min = 0 35 | self.scale = 0.99 36 | 37 | # self.device = device # gpu or cpu 38 | use_cuda = not False and torch.cuda.is_available() 39 | self.device = torch.device("cuda" if use_cuda else "cpu") 40 | 41 | # Initialize L-Softmax parameters 42 | self.weight = nn.Parameter(torch.FloatTensor(input_features, output_features)) 43 | self.divisor = math.pi / self.margin # pi/m 44 | self.C_m_2n = torch.Tensor(special.binom(margin, range(0, margin + 1, 2))).to(device) # C_m{2n} 45 | self.cos_powers = torch.Tensor(range(self.margin, -1, -2)).to(device) # m - 2n 46 | self.sin2_powers = torch.Tensor(range(len(self.cos_powers))).to(device) # n 47 | self.signs = torch.ones(margin // 2 + 1).to(device) # 1, -1, 1, -1, ... 
48 | self.signs[1::2] = -1 49 | 50 | def calculate_cos_m_theta(self, cos_theta): 51 | sin2_theta = 1 - cos_theta ** 2 52 | cos_terms = cos_theta.unsqueeze(1) ** self.cos_powers.unsqueeze(0) # cos^{m - 2n} 53 | sin2_terms = (sin2_theta.unsqueeze(1) # sin2^{n} 54 | ** self.sin2_powers.unsqueeze(0)) 55 | 56 | cos_m_theta = (self.signs.unsqueeze(0) * # -1^{n} * C_m{2n} * cos^{m - 2n} * sin2^{n} 57 | self.C_m_2n.unsqueeze(0) * 58 | cos_terms * 59 | sin2_terms).sum(1) # summation of all terms 60 | 61 | return cos_m_theta 62 | 63 | def reset_parameters(self): 64 | nn.init.kaiming_normal_(self.weight.data.t()) 65 | 66 | def find_k(self, cos): 67 | # to account for acos numerical errors 68 | eps = 1e-7 69 | cos = torch.clamp(cos, -1 + eps, 1 - eps) 70 | acos = cos.acos() 71 | k = (acos / self.divisor).floor().detach() 72 | return k 73 | 74 | def forward(self, input, target=None): 75 | a = 0 76 | if self.training: 77 | assert target is not None 78 | x, w = input, self.weight 79 | beta = max(self.beta, self.beta_min) 80 | logit = x.mm(w) 81 | indexes = range(logit.size(0)) 82 | logit_target = logit[indexes, target] 83 | 84 | # cos(theta) = w * x / ||w||*||x|| 85 | w_target_norm = w[:, target].norm(p=2, dim=0) 86 | x_norm = x.norm(p=2, dim=1) 87 | cos_theta_target = logit_target / (w_target_norm * x_norm + 1e-10) 88 | 89 | # equation 7 90 | cos_m_theta_target = self.calculate_cos_m_theta(cos_theta_target) 91 | 92 | # find k in equation 6 93 | k = self.find_k(cos_theta_target) 94 | 95 | # f_y_i 96 | logit_target_updated = (w_target_norm * 97 | x_norm * 98 | (((-1) ** k * cos_m_theta_target) - 2 * k)) 99 | logit_target_updated_beta = (logit_target_updated + beta * logit[indexes, target]) / (1 + beta) 100 | 101 | logit[indexes, target] = logit_target_updated_beta 102 | self.beta *= self.scale 103 | return logit 104 | else: 105 | assert target is None 106 | return input.mm(self.weight) 107 | 108 | 109 | class AngleLinear(nn.Module): 110 | def __init__(self, in_features, 
out_features, m=4, phiflag=True): 111 | super(AngleLinear, self).__init__() 112 | self.in_features = in_features 113 | self.out_features = out_features 114 | self.weight = Parameter(torch.Tensor(in_features, out_features)) 115 | self.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5) 116 | self.phiflag = phiflag 117 | self.m = m 118 | # self.mlambda = [ 119 | # lambda x: x ** 0, 120 | # lambda x: x ** 1, 121 | # lambda x: 2 * x ** 2 - 1, 122 | # lambda x: 4 * x ** 3 - 3 * x, 123 | # lambda x: 8 * x ** 4 - 8 * x ** 2 + 1, 124 | # lambda x: 16 * x ** 5 - 20 * x ** 3 + 5 * x 125 | # ] 126 | 127 | def forward(self, input): 128 | x = input # size=(B,F) F is feature len (128*512) 129 | w = self.weight # size=(F,Classnum) F=in_features Classnum=out_features 130 | # w = 512*227 131 | ww = w.renorm(2, 1, 1e-5).mul(1e5) 132 | xlen = x.pow(2).sum(1).pow(0.5) # size=B 133 | wlen = ww.pow(2).sum(0).pow(0.5) # size=Classnum 134 | 135 | cos_theta = x.mm(ww) # size=(B,Classnum) 136 | cos_theta = cos_theta / xlen.view(-1, 1) / wlen.view(1, -1) 137 | cos_theta = cos_theta.clamp(-1, 1) 138 | 139 | if self.phiflag: 140 | cos_m_theta = 8 * cos_theta ** 4 - 8 * cos_theta ** 2 + 1 141 | theta = Variable(cos_theta.data.acos()) 142 | k = (self.m * theta / 3.14159265).floor() 143 | n_one = k * 0.0 - 1 144 | phi_theta = (n_one ** k) * cos_m_theta - 2 * k 145 | else: 146 | theta = cos_theta.acos() 147 | phi_theta = myphi(theta, self.m) 148 | phi_theta = phi_theta.clamp(-1 * self.m, 1) 149 | 150 | cos_theta = cos_theta * xlen.view(-1, 1) 151 | phi_theta = phi_theta * xlen.view(-1, 1) 152 | output = (cos_theta, phi_theta) 153 | return output # size=(B,Classnum,2) 154 | 155 | 156 | class AngleLoss(nn.Module): 157 | def __init__(self, gamma=0): 158 | super(AngleLoss, self).__init__() 159 | self.gamma = gamma 160 | self.it = 1 161 | self.LambdaMin = 5.0 162 | self.LambdaMax = 1500.0 163 | self.lamb = 1500.0 164 | 165 | def forward(self, input, target): 166 | cos_theta, phi_theta = input 167 
| target = target.view(-1, 1) # size=(B,1) 168 | index = cos_theta.data * 0.0 # size=(B, Classnum) 169 | # index = index.scatter(1, target.data.view(-1, 1).long(), 1) 170 | index = index.byte() 171 | index = Variable(index) 172 | # index = Variable(torch.randn(1,2)).byte() 173 | 174 | self.lamb = max(self.LambdaMin, self.LambdaMax / (1 + 0.1 * self.it)) 175 | output = cos_theta * 1.0 # size=(B,Classnum) 176 | output1 = output.clone() 177 | # output1[index1] = output[index] - cos_theta[index] * (1.0 + 0) / (1 + self.lamb) 178 | # output1[index1] = output[index] + phi_theta[index] * (1.0 + 0) / (1 + self.lamb) 179 | output[index] = output1[index]- cos_theta[index] * (1.0 + 0) / (1 + self.lamb)+ phi_theta[index] * (1.0 + 0) / (1 + self.lamb) 180 | logpt = F.log_softmax(output) 181 | logpt = logpt.gather(1, target.long()) 182 | logpt = logpt.view(-1) 183 | pt = Variable(logpt.data.exp()) 184 | 185 | loss = -1 * (1 - pt) ** self.gamma * logpt 186 | loss = loss.mean() 187 | # loss = torch.sum(cos_theta)+ torch.sum(phi_theta) 188 | return loss 189 | 190 | 191 | # class STN(nn.Module): 192 | # def __init__(self ): 193 | # super(STN, self).__init__() 194 | # self.localization = nn.Sequential( 195 | # nn.Conv2d(3, 8, kernel_size=7), 196 | # nn.MaxPool2d(2, stride=2), 197 | # nn.ReLU(True), 198 | # nn.Conv2d(8, 10, kernel_size=5), 199 | # nn.MaxPool2d(2, stride=2), 200 | # nn.ReLU(True) 201 | # ) 202 | # self.fc_loc = nn.Sequential( 203 | # nn.Linear(10*24*20, 32), 204 | # nn.ReLU(True), 205 | # nn.Linear(32, 3 * 2) 206 | # ) 207 | # 208 | # # Initialize the weights/bias with identity transformation 209 | # # self.fc_loc[2].weight.data.zero_() 210 | # # self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float)) 211 | # 212 | # def forward(self, x): 213 | # xs = self.localization(x) 214 | # xs = xs.view(-1, 10*24*20) 215 | # theta = self.fc_loc(xs) 216 | # theta = theta.view(-1, 2, 3) 217 | # 218 | # grid = F.affine_grid(theta, x.size()) 219 | # x = 
F.grid_sample(x, grid) 220 | # 221 | # return x 222 | 223 | 224 | class sphere20a(nn.Module): 225 | def __init__(self, classnum=10574, feature=False): 226 | # classnum = dataloader.dataset.class_num = 227 227 | super(sphere20a, self).__init__() 228 | self.classnum = classnum 229 | self.feature = feature 230 | # input = B*3*112*96 231 | self.conv1_1 = nn.Conv2d(3, 64, 3, 2, 1) # =>B*64*56*48 232 | self.relu1_1 = nn.PReLU(64) 233 | self.conv1_2 = nn.Conv2d(64, 64, 3, 1, 1) 234 | self.relu1_2 = nn.PReLU(64) 235 | self.conv1_3 = nn.Conv2d(64, 64, 3, 1, 1) 236 | self.relu1_3 = nn.PReLU(64) 237 | 238 | self.conv2_1 = nn.Conv2d(64, 128, 3, 2, 1) # =>B*128*28*24 239 | self.relu2_1 = nn.PReLU(128) 240 | self.conv2_2 = nn.Conv2d(128, 128, 3, 1, 1) 241 | self.relu2_2 = nn.PReLU(128) 242 | self.conv2_3 = nn.Conv2d(128, 128, 3, 1, 1) 243 | self.relu2_3 = nn.PReLU(128) 244 | 245 | self.conv2_4 = nn.Conv2d(128, 128, 3, 1, 1) # =>B*128*28*24 246 | self.relu2_4 = nn.PReLU(128) 247 | self.conv2_5 = nn.Conv2d(128, 128, 3, 1, 1) 248 | self.relu2_5 = nn.PReLU(128) 249 | 250 | self.conv3_1 = nn.Conv2d(128, 256, 3, 2, 1) # =>B*256*14*12 251 | self.relu3_1 = nn.PReLU(256) 252 | self.conv3_2 = nn.Conv2d(256, 256, 3, 1, 1) 253 | self.relu3_2 = nn.PReLU(256) 254 | self.conv3_3 = nn.Conv2d(256, 256, 3, 1, 1) 255 | self.relu3_3 = nn.PReLU(256) 256 | 257 | self.conv3_4 = nn.Conv2d(256, 256, 3, 1, 1) # =>B*256*14*12 258 | self.relu3_4 = nn.PReLU(256) 259 | self.conv3_5 = nn.Conv2d(256, 256, 3, 1, 1) 260 | self.relu3_5 = nn.PReLU(256) 261 | 262 | self.conv3_6 = nn.Conv2d(256, 256, 3, 1, 1) # =>B*256*14*12 263 | self.relu3_6 = nn.PReLU(256) 264 | self.conv3_7 = nn.Conv2d(256, 256, 3, 1, 1) 265 | self.relu3_7 = nn.PReLU(256) 266 | 267 | self.conv3_8 = nn.Conv2d(256, 256, 3, 1, 1) # =>B*256*14*12 268 | self.relu3_8 = nn.PReLU(256) 269 | self.conv3_9 = nn.Conv2d(256, 256, 3, 1, 1) 270 | self.relu3_9 = nn.PReLU(256) 271 | 272 | self.conv4_1 = nn.Conv2d(256, 512, 3, 2, 1) # =>B*512*7*6 273 | 
self.relu4_1 = nn.PReLU(512) 274 | self.conv4_2 = nn.Conv2d(512, 512, 3, 1, 1) 275 | self.relu4_2 = nn.PReLU(512) 276 | self.conv4_3 = nn.Conv2d(512, 512, 3, 1, 1) 277 | self.relu4_3 = nn.PReLU(512) 278 | 279 | self.fc5 = nn.Linear(512 * 7 * 6, 512) 280 | self.fc6 = AngleLinear(512, self.classnum) 281 | # self.stn = STN() 282 | 283 | def forward(self, x, target=None): 284 | # x = self.stn(x) 285 | x = self.relu1_1(self.conv1_1(x)) 286 | x = x + self.relu1_3(self.conv1_3(self.relu1_2(self.conv1_2(x)))) 287 | 288 | x = self.relu2_1(self.conv2_1(x)) 289 | x = x + self.relu2_3(self.conv2_3(self.relu2_2(self.conv2_2(x)))) 290 | x = x + self.relu2_5(self.conv2_5(self.relu2_4(self.conv2_4(x)))) 291 | 292 | x = self.relu3_1(self.conv3_1(x)) 293 | x = x + self.relu3_3(self.conv3_3(self.relu3_2(self.conv3_2(x)))) 294 | x = x + self.relu3_5(self.conv3_5(self.relu3_4(self.conv3_4(x)))) 295 | x = x + self.relu3_7(self.conv3_7(self.relu3_6(self.conv3_6(x)))) 296 | x = x + self.relu3_9(self.conv3_9(self.relu3_8(self.conv3_8(x)))) 297 | 298 | x = self.relu4_1(self.conv4_1(x)) 299 | x = x + self.relu4_3(self.conv4_3(self.relu4_2(self.conv4_2(x)))) 300 | 301 | x = x.view(x.size(0), -1) 302 | x = self.fc5(x) # 128*512 303 | if self.feature: # self.feature=False 304 | return x 305 | 306 | x = self.fc6(x) 307 | 308 | return x 309 | 310 | 311 | class testsp(nn.Module): 312 | def __init__(self): 313 | super(testsp, self).__init__() 314 | self.linear = AngleLinear(100, 2) 315 | 316 | def forward(self, x): 317 | out = self.linear(x) 318 | return out 319 | 320 | def test(): 321 | net = testsp() 322 | x = Variable(torch.randn(1,100)) 323 | tar = Variable(torch.Tensor(1)) 324 | out = net(x) 325 | cre = AngleLoss() 326 | loss = cre(out, tar) 327 | loss.backward() 328 | -------------------------------------------------------------------------------- /mainsp.py: -------------------------------------------------------------------------------- 1 | '''Train CIFAR10 with PyTorch.''' 2 | from 
__future__ import print_function 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | import torch.nn.functional as F 8 | import torch.backends.cudnn as cudnn 9 | import models.net_sphere as sp_net 10 | import transforms as transforms 11 | from dataloader import lunanod 12 | import os 13 | import argparse 14 | import time 15 | from models.cnn_res import * 16 | # from utils import progress_bar 17 | from torch.autograd import Variable 18 | import logging 19 | import numpy as np 20 | 21 | parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training') 22 | parser.add_argument('--lr', default=0.0002, type=float, help='learning rate') 23 | parser.add_argument('--batch_size', default=1, type=int, help='batch size ') 24 | parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint') 25 | parser.add_argument('--savemodel', type=str, default='', help='resume from checkpoint model') 26 | parser.add_argument("--gpuids", type=str, default='all', help='use which gpu') 27 | 28 | parser.add_argument('--num_epochs', type=int, default=700) 29 | parser.add_argument('--num_epochs_decay', type=int, default=70) 30 | 31 | parser.add_argument('--num_workers', type=int, default=24) 32 | 33 | parser.add_argument('--beta1', type=float, default=0.5) # momentum1 in Adam 34 | parser.add_argument('--beta2', type=float, default=0.999) # momentum2 in Adam 35 | parser.add_argument('--lamb', type=float, default=1, help="lambda for loss2") 36 | parser.add_argument('--fold', type=int, default=5, help="fold") 37 | 38 | args = parser.parse_args() 39 | 40 | CROPSIZE = 32 41 | gbtdepth = 1 42 | fold = args.fold 43 | blklst = [] # ['1.3.6.1.4.1.14519.5.2.1.6279.6001.121993590721161347818774929286-388', \ 44 | # '1.3.6.1.4.1.14519.5.2.1.6279.6001.121993590721161347818774929286-389', \ 45 | # '1.3.6.1.4.1.14519.5.2.1.6279.6001.132817748896065918417924920957-660'] 46 | logging.basicConfig(filename='log-' + str(fold), level=logging.INFO) 47 | 
48 | use_cuda = torch.cuda.is_available() 49 | best_acc = 0 # best test accuracy 50 | best_acc_gbt = 0 51 | start_epoch = 0 # start from epoch 0 or last checkpoint epoch 52 | # Cal mean std 53 | # preprocesspath = '/media/data1/wentao/tianchi/luna16/cls/crop_v3/' 54 | preprocesspath = '/data/xxx/LUNA/cls/crop_v3/' 55 | # preprocesspath = '/media/jehovah/Work/data/LUNA/cls/crop_v3/' 56 | pixvlu, npix = 0, 0 57 | for fname in os.listdir(preprocesspath): 58 | # print(fname) 59 | if fname.endswith('.npy'): 60 | if fname[:-4] in blklst: continue 61 | data = np.load(os.path.join(preprocesspath, fname)) 62 | pixvlu += np.sum(data) 63 | # print("data.shape = " + str(data.shape)) 64 | npix += np.prod(data.shape) 65 | pixmean = pixvlu / float(npix) 66 | pixvlu = 0 67 | for fname in os.listdir(preprocesspath): 68 | if fname.endswith('.npy'): 69 | if fname[:-4] in blklst: continue 70 | data = np.load(os.path.join(preprocesspath, fname)) - pixmean 71 | pixvlu += np.sum(data * data) 72 | pixstd = np.sqrt(pixvlu / float(npix)) 73 | # pixstd /= 255 74 | print(pixmean, pixstd) 75 | logging.info('mean ' + str(pixmean) + ' std ' + str(pixstd)) 76 | # Datatransforms 77 | logging.info('==> Preparing data..') # Random Crop, Zero out, x z flip, scale, 78 | transform_train = transforms.Compose([ 79 | # transforms.RandomScale(range(28, 38)), 80 | transforms.RandomCrop(32, padding=4), 81 | transforms.RandomHorizontalFlip(), 82 | transforms.RandomYFlip(), 83 | transforms.RandomZFlip(), 84 | transforms.ZeroOut(4), 85 | transforms.ToTensor(), 86 | transforms.Normalize((pixmean), (pixstd)), # need to cal mean and std, revise norm func 87 | ]) 88 | 89 | transform_test = transforms.Compose([ 90 | transforms.ToTensor(), 91 | transforms.Normalize((pixmean), (pixstd)), 92 | ]) 93 | 94 | # load data list 95 | trfnamelst = [] 96 | trlabellst = [] 97 | trfeatlst = [] 98 | tefnamelst = [] 99 | telabellst = [] 100 | tefeatlst = [] 101 | import pandas as pd 102 | 103 | dataframe = 
pd.read_csv('./data/annotationdetclsconvfnl_v3.csv', 104 | names=['seriesuid', 'coordX', 'coordY', 'coordZ', 'diameter_mm', 'malignant']) 105 | 106 | alllst = dataframe['seriesuid'].tolist()[1:] 107 | labellst = dataframe['malignant'].tolist()[1:] 108 | crdxlst = dataframe['coordX'].tolist()[1:] 109 | crdylst = dataframe['coordY'].tolist()[1:] 110 | crdzlst = dataframe['coordZ'].tolist()[1:] 111 | dimlst = dataframe['diameter_mm'].tolist()[1:] 112 | # test id 113 | teidlst = [] 114 | for fname in os.listdir('/data/xxx/LUNA/rowfile/subset' + str(fold) + '/'): 115 | # for fname in os.listdir('/media/jehovah/Work/data/LUNA/rowfile/subset' + str(fold) + '/'): 116 | 117 | if fname.endswith('.mhd'): 118 | teidlst.append(fname[:-4]) 119 | mxx = mxy = mxz = mxd = 0 120 | for srsid, label, x, y, z, d in zip(alllst, labellst, crdxlst, crdylst, crdzlst, dimlst): 121 | mxx = max(abs(float(x)), mxx) 122 | mxy = max(abs(float(y)), mxy) 123 | mxz = max(abs(float(z)), mxz) 124 | mxd = max(abs(float(d)), mxd) 125 | if srsid in blklst: continue 126 | # crop raw pixel as feature 127 | data = np.load(os.path.join(preprocesspath, srsid + '.npy')) 128 | bgx = int(data.shape[0] / 2 - CROPSIZE / 2) 129 | bgy = int(data.shape[1] / 2 - CROPSIZE / 2) 130 | bgz = int(data.shape[2] / 2 - CROPSIZE / 2) 131 | data = np.array(data[bgx:bgx + CROPSIZE, bgy:bgy + CROPSIZE, bgz:bgz + CROPSIZE]) 132 | # feat = np.hstack((np.reshape(data, (-1,)) / 255, float(d))) 133 | y, x, z = np.ogrid[-CROPSIZE / 2:CROPSIZE / 2, -CROPSIZE / 2:CROPSIZE / 2, -CROPSIZE / 2:CROPSIZE / 2] 134 | mask = abs(y ** 3 + x ** 3 + z ** 3) <= abs(float(d)) ** 3 135 | feat = np.zeros((CROPSIZE, CROPSIZE, CROPSIZE), dtype=float) 136 | feat[mask] = 1 137 | # print(feat.shape) 138 | if srsid.split('-')[0] in teidlst: 139 | tefnamelst.append(srsid + '.npy') 140 | telabellst.append(int(label)) 141 | tefeatlst.append(feat) 142 | else: 143 | trfnamelst.append(srsid + '.npy') 144 | trlabellst.append(int(label)) 145 | 
trfeatlst.append(feat) 146 | for idx in range(len(trfeatlst)): 147 | # trfeatlst[idx][0] /= mxx 148 | # trfeatlst[idx][1] /= mxy 149 | # trfeatlst[idx][2] /= mxz 150 | trfeatlst[idx][-1] /= mxd 151 | for idx in range(len(tefeatlst)): 152 | # tefeatlst[idx][0] /= mxx 153 | # tefeatlst[idx][1] /= mxy 154 | # tefeatlst[idx][2] /= mxz 155 | tefeatlst[idx][-1] /= mxd 156 | trainset = lunanod(preprocesspath, trfnamelst, trlabellst, trfeatlst, train=True, download=True, 157 | transform=transform_train) 158 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=20) 159 | 160 | testset = lunanod(preprocesspath, tefnamelst, telabellst, tefeatlst, train=False, download=True, 161 | transform=transform_test) 162 | testloader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size, shuffle=False, num_workers=20) 163 | savemodelpath = './checkpoint-' + str(fold) + '/' 164 | train_val = np.empty(shape=0) 165 | test_val = np.empty(shape=(0, 3)) 166 | # Model 167 | print(args.resume) 168 | if args.resume: 169 | 170 | print('==> Resuming from checkpoint..') 171 | print(args.savemodel) 172 | if args.savemodel == '': 173 | logging.info('==> Resuming from checkpoint..') 174 | assert os.path.isdir(savemodelpath), 'Error: no checkpoint directory found!' 175 | checkpoint = torch.load(savemodelpath + 'ckpt.t7') 176 | 177 | else: 178 | logging.info('==> Resuming from checkpoint..') 179 | assert os.path.isdir(savemodelpath), 'Error: no checkpoint directory found!' 
180 | checkpoint = torch.load(args.savemodel) 181 | net = checkpoint['net'] 182 | best_acc = checkpoint['acc'] 183 | start_epoch = checkpoint['epoch'] 184 | print(savemodelpath + " load success") 185 | print(start_epoch) 186 | else: 187 | logging.info('==> Building model..') 188 | logging.info('args.savemodel : ' + args.savemodel) 189 | net = ConvRes([[64, 64, 64], [128, 128, 256], [256, 256, 256, 512]]) 190 | if args.savemodel != "": 191 | # args.savemodel = '/home/xxx/DeepLung-master/nodcls/checkpoint-5/ckpt.t7' 192 | checkpoint = torch.load(args.savemodel) 193 | finenet = checkpoint 194 | Low_rankmodel_dic = net.state_dict() 195 | finenet = {k: v for k, v in finenet.items() if k in Low_rankmodel_dic} 196 | Low_rankmodel_dic.update(finenet) 197 | net.load_state_dict(Low_rankmodel_dic) 198 | print("net_loaded") 199 | 200 | lr = args.lr 201 | 202 | 203 | def get_lr(epoch): 204 | global lr 205 | if (epoch + 1) > (args.num_epochs - args.num_epochs_decay): 206 | lr -= (lr / float(args.num_epochs_decay)) 207 | for param_group in optimizer.param_groups: 208 | param_group['lr'] = lr 209 | print('Decay learning rate to lr: {}.'.format(lr)) 210 | 211 | 212 | if use_cuda: 213 | net.cuda() 214 | if args.gpuids == 'all': 215 | device_ids = range(torch.cuda.device_count()) 216 | else: 217 | device_ids = map(int, list(filter(str.isdigit, args.gpuids))) 218 | 219 | print('gpu use' + str(device_ids)) 220 | net = torch.nn.DataParallel(net, device_ids=device_ids) 221 | cudnn.benchmark = False # True 222 | 223 | criterion = sp_net.AngleLoss() 224 | optimizer = optim.Adam(net.parameters(), lr=args.lr, betas=(args.beta1, args.beta2)) 225 | optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9) 226 | 227 | # L2Loss = torch.nn.MSELoss() 228 | 229 | # Training 230 | def train(epoch): 231 | logging.info('\nEpoch: ' + str(epoch)) 232 | net.train() 233 | get_lr(epoch) 234 | train_loss = 0 235 | correct = 0 236 | total = 0 237 | 238 | for batch_idx, (inputs, targets, feat) in 
enumerate(trainloader): 239 | if use_cuda: 240 | inputs, targets = inputs.cuda(), targets.cuda() 241 | 242 | optimizer.zero_grad() 243 | inputs, targets = Variable(inputs), Variable(targets) 244 | outputs = net(inputs) 245 | loss = criterion(outputs, targets) 246 | 247 | loss.backward() 248 | optimizer.step() 249 | train_loss += loss.data.item() 250 | _, predicted = torch.max(outputs[0].data, 1) 251 | total += targets.size(0) 252 | correct += predicted.eq(targets.data).cpu().sum() 253 | # progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' 254 | 255 | print('ep ' + str(epoch) + ' tracc ' + str(correct.data.item() / float(total)) + ' lr ' + str(lr)) 256 | logging.info( 257 | 'ep ' + str(epoch) + ' tracc ' + str(correct.data.item() / float(total)) + ' lr ' + str(lr)) 258 | np.append(train_val, correct.data.item() / float(total)) 259 | 260 | 261 | def test(epoch): 262 | epoch_start_time = time.time() 263 | global best_acc 264 | global best_acc_gbt 265 | net.eval() 266 | test_loss = 0 267 | correct = 0 268 | total = 0 269 | TP = FP = FN = TN = 0 270 | for batch_idx, (inputs, targets, feat) in enumerate(testloader): 271 | if use_cuda: 272 | inputs, targets = inputs.cuda(), targets.cuda() 273 | 274 | inputs, targets = Variable(inputs, requires_grad=False), Variable(targets) 275 | outputs = net(inputs) 276 | 277 | loss = criterion(outputs, targets) 278 | test_loss += loss.data.item() 279 | _, predicted = torch.max(outputs[0].data, 1) 280 | total += targets.size(0) 281 | correct += predicted.eq(targets.data).cpu().sum() 282 | TP += ((predicted == 1) & (targets.data == 1)).cpu().sum() 283 | TN += ((predicted == 0) & (targets.data == 0)).cpu().sum() 284 | FN += ((predicted == 0) & (targets.data == 1)).cpu().sum() 285 | FP += ((predicted == 1) & (targets.data == 0)).cpu().sum() 286 | 287 | # Save checkpoint. 288 | acc = 100. 
* correct.data.item() / total 289 | if acc > best_acc: 290 | logging.info('Saving..') 291 | state = { 292 | 'net': net.module if use_cuda else net, 293 | 'acc': acc, 294 | 'epoch': epoch, 295 | } 296 | if not os.path.isdir(savemodelpath): 297 | os.mkdir(savemodelpath) 298 | torch.save(state, savemodelpath + 'ckpt.t7') 299 | best_acc = acc 300 | logging.info('Saving..') 301 | state = { 302 | 'net': net.module if use_cuda else net, 303 | 'acc': acc, 304 | 'epoch': epoch, 305 | } 306 | if not os.path.isdir(savemodelpath): 307 | os.mkdir(savemodelpath) 308 | if epoch % 50 == 0: 309 | torch.save(state, savemodelpath + 'ckpt' + str(epoch) + '.t7') 310 | # best_acc = acc 311 | tpr = 100. * TP.data.item() / (TP.data.item() + FN.data.item()) 312 | fpr = 100. * FP.data.item() / (FP.data.item() + TN.data.item()) 313 | 314 | print('teacc ' + str(acc) + ' bestacc ' + str(best_acc)) 315 | print('tpr ' + str(tpr) + ' fpr ' + str(fpr)) 316 | print('Time Taken: %d sec' % (time.time() - epoch_start_time)) 317 | logging.info( 318 | 'teacc ' + str(acc) + ' bestacc ' + str(best_acc)) 319 | logging.info( 320 | 'tpr ' + str(tpr) + ' fpr ' + str(fpr)) 321 | np.append(test_val, [[acc, tpr, fpr]], axis=0) 322 | 323 | 324 | if __name__ == '__main__': 325 | for epoch in range(start_epoch + 1, start_epoch + args.num_epochs + 1): # 200): 326 | train(epoch) 327 | test(epoch) 328 | np.save(savemodelpath + "train_acc", train_val) 329 | np.save(savemodelpath + "test_acc", test_val) 330 | -------------------------------------------------------------------------------- /data/humanperformance.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import SimpleITK as sitk 3 | import os 4 | import os.path 5 | import numpy as np 6 | fold = 1 7 | def load_itk_image(filename): 8 | with open(filename) as f: 9 | contents = f.readlines() 10 | line = [k for k in contents if k.startswith('TransformMatrix')][0] 11 | transformM = np.array(line.split(' = 
')[1].split(' ')).astype('float') 12 | transformM = np.round(transformM) 13 | if np.any( transformM!=np.array([1,0,0, 0, 1, 0, 0, 0, 1])): 14 | isflip = True 15 | else: 16 | isflip = False 17 | itkimage = sitk.ReadImage(filename) 18 | numpyImage = sitk.GetArrayFromImage(itkimage) 19 | numpyOrigin = np.array(list(reversed(itkimage.GetOrigin()))) 20 | numpySpacing = np.array(list(reversed(itkimage.GetSpacing()))) 21 | return numpyImage, numpyOrigin, numpySpacing,isflip 22 | def worldToVoxelCoord(worldCoord, origin, spacing): 23 | stretchedVoxelCoord = np.absolute(worldCoord - origin) 24 | voxelCoord = stretchedVoxelCoord / spacing 25 | return voxelCoord 26 | # read map file 27 | mapfname = 'LIDC-IDRI-mappingLUNA16' 28 | sidmap = {} 29 | fid = open(mapfname, 'r') 30 | line = fid.readline() 31 | line = fid.readline() 32 | while line: 33 | pidlist = line.split(' ') 34 | # print pidlist 35 | pid = pidlist[0] 36 | stdid = pidlist[1] 37 | srsid = pidlist[2] 38 | if srsid not in sidmap: 39 | sidmap[srsid] = [pid, stdid] 40 | else: 41 | assert sidmap[srsid][0] == pid 42 | assert sidmap[srsid][1] == stdid 43 | line = fid.readline() 44 | fid.close() 45 | # read luna16 annotation 46 | colname = ['seriesuid', 'coordX', 'coordY', 'coordZ', 'diameter_mm'] 47 | lunaantframe = pd.read_csv('annotations.csv', names=colname) 48 | srslist = lunaantframe.seriesuid.tolist()[1:] 49 | cdxlist = lunaantframe.coordX.tolist()[1:] 50 | cdylist = lunaantframe.coordY.tolist()[1:] 51 | cdzlist = lunaantframe.coordZ.tolist()[1:] 52 | dimlist = lunaantframe.diameter_mm.tolist()[1:] 53 | lunaantdict = {} 54 | for idx in xrange(len(srslist)): 55 | vlu = [float(cdxlist[idx]), float(cdylist[idx]), float(cdzlist[idx]), float(dimlist[idx])] 56 | if srslist[idx] in lunaantdict: 57 | lunaantdict[srslist[idx]].append(vlu) 58 | else: 59 | lunaantdict[srslist[idx]] = [vlu] 60 | # # convert luna16 annotation to LIDC-IDRI annotation space 61 | # from multiprocessing import Pool 62 | # lunantdictlidc = {} 63 | # 
for fold in xrange(10): 64 | # mhdpath = '/media/data1/wentao/tianchi/luna16/subset'+str(fold) 65 | # print 'fold', fold 66 | # def getvoxcrd(fname): 67 | # sliceim,origin,spacing,isflip = load_itk_image(os.path.join(mhdpath, fname)) 68 | # lunantdictlidc[fname[:-4]] = [] 69 | # voxcrdlist = [] 70 | # for lunaant in lunaantdict[fname[:-4]]: 71 | # voxcrd = worldToVoxelCoord(lunaant[:3][::-1], origin, spacing) 72 | # voxcrd[-1] = sliceim.shape[0] - voxcrd[0] 73 | # voxcrdlist.append(voxcrd) 74 | # return voxcrdlist 75 | # p = Pool(30) 76 | # fnamelist = [] 77 | # for fname in os.listdir(mhdpath): 78 | # if fname.endswith('.mhd') and fname[:-4] in lunaantdict: 79 | # fnamelist.append(fname) 80 | # voxcrdlist = p.map(getvoxcrd, fnamelist) 81 | # listidx = 0 82 | # for fname in os.listdir(mhdpath): 83 | # if fname.endswith('.mhd') and fname[:-4] in lunaantdict: 84 | # lunantdictlidc[fname[:-4]] = [] 85 | # for subidx, lunaant in enumerate(lunaantdict[fname[:-4]]): 86 | # # voxcrd = worldToVoxelCoord(lunaant[:3][::-1], origin, spacing) 87 | # # voxcrd[-1] = sliceim.shape[0] - voxcrd[0] 88 | # lunantdictlidc[fname[:-4]].append([lunaant, voxcrdlist[listidx][subidx]]) 89 | # listidx += 1 90 | # p.close() 91 | # np.save('lunaantdictlidc.npy', lunantdictlidc) 92 | # read LIDC dataset 93 | lunantdictlidc = np.load('lunaantdictlidc.npy').item() 94 | import xlrd 95 | lidccsvfname = '/media/data1/wentao/LIDC-IDRI/list3.2.xls' 96 | antdict = {} 97 | wb = xlrd.open_workbook(os.path.join(lidccsvfname)) 98 | for s in wb.sheets(): 99 | if s.name == 'list3.2': 100 | for row in range(1, s.nrows): 101 | valuelist = [int(s.cell(row, 2).value), s.cell(row, 3).value, s.cell(row, 4).value, \ 102 | int(s.cell(row, 5).value), int(s.cell(row, 6).value), int(s.cell(row, 7).value)] 103 | assert abs(s.cell(row, 1).value - int(s.cell(row, 1).value)) < 1e-8 104 | assert abs(s.cell(row, 2).value - int(s.cell(row, 2).value)) < 1e-8 105 | assert abs(s.cell(row, 5).value - int(s.cell(row, 5).value)) < 
1e-8 106 | assert abs(s.cell(row, 6).value - int(s.cell(row, 6).value)) < 1e-8 107 | assert abs(s.cell(row, 7).value - int(s.cell(row, 7).value)) < 1e-8 108 | for col in range(9, 16): 109 | if s.cell(row, col).value != '': 110 | if isinstance(s.cell(row, col).value, float): 111 | valuelist.append(str(int(s.cell(row, col).value))) 112 | assert abs(s.cell(row, col).value - int(s.cell(row, col).value)) < 1e-8 113 | else: 114 | valuelist.append(s.cell(row, col).value) 115 | if s.cell(row, 0).value+'_'+str(int(s.cell(row, 1).value)) not in antdict: 116 | antdict[s.cell(row, 0).value+'_'+str(int(s.cell(row, 1).value))] = [valuelist] 117 | else: 118 | antdict[s.cell(row, 0).value+'_'+str(int(s.cell(row, 1).value))].append(valuelist) 119 | # update LIDC annotation with series number, rather than scan id 120 | import dicom 121 | LIDCpath = '/media/data1/wentao/LIDC-IDRI/DOI/' 122 | antdictscan = {} 123 | for k, v in antdict.iteritems(): 124 | pid, scan = k.split('_') 125 | hasscan = False 126 | for sdu in os.listdir(os.path.join(LIDCpath, 'LIDC-IDRI-'+pid)): 127 | for srs in os.listdir(os.path.join(*[LIDCpath, 'LIDC-IDRI-'+pid, sdu])): 128 | if srs.endswith('.npy'): 129 | print('npy', pid, scan, srs) 130 | continue 131 | RefDs = dicom.read_file(os.path.join(*[LIDCpath, 'LIDC-IDRI-'+pid, sdu, srs, '000006.dcm'])) 132 | # print scan, str(RefDs[0x20, 0x11].value) 133 | if str(RefDs[0x20, 0x11].value) == scan or scan == '0': 134 | if hasscan: print('rep', pid, sdu, srs) 135 | hasscan = True 136 | antdictscan[pid+'_'+srs] = v 137 | break 138 | if not hasscan: print('not found', pid, scan, sdu, srs) 139 | # find the match from LIDC-IDRI annotation 140 | import math 141 | lunaantdictnodid = {} 142 | maxdist = 0 143 | for srcid, lunaantlidc in lunantdictlidc.iteritems(): 144 | lunaantdictnodid[srcid] = [] 145 | pid, stdid = sidmap[srcid] 146 | # print pid 147 | pid = pid[len('LIDC-IDRI-'):] 148 | for lunantdictlidcsub in lunaantlidc: 149 | lunaant = lunantdictlidcsub[0] 150 | 
voxcrd = lunantdictlidcsub[1] # z y x 151 | mindist, minidx = 1e8, -1 152 | if srcid in ['1.3.6.1.4.1.14519.5.2.1.6279.6001.174692377730646477496286081479', '1.3.6.1.4.1.14519.5.2.1.6279.6001.300246184547502297539521283806']: 153 | continue 154 | for idx, lidcant in enumerate(antdictscan[pid+'_'+srcid]): 155 | dist = math.pow(voxcrd[0] - lidcant[3], 2) # z 156 | dist = math.pow(voxcrd[1] - lidcant[4], 2) # y 157 | dist += math.pow(voxcrd[2] - lidcant[5], 2) # x 158 | if dist < mindist: 159 | mindist = dist 160 | minidx = idx 161 | if mindist > 71:#15.1: 162 | print(srcid, pid, voxcrd, antdictscan[pid+'_'+srcid], mindist) 163 | maxdist = max(maxdist, mindist) 164 | lunaantdictnodid[srcid].append([lunaant, antdictscan[pid+'_'+srcid][minidx][6:]]) 165 | # np.save('lunaantdictnodid.npy', lunaantdictnodid) 166 | print('maxdist', maxdist) 167 | # save it into a csv 168 | # import csv 169 | # savename = 'annotationnodid.csv' 170 | # fid = open(savename, 'w') 171 | # writer = csv.writer(fid) 172 | # writer.writerow(['seriesuid', 'coordX', 'coordY', 'coordZ', 'diameter_mm']) 173 | # for srcid, ant in lunaantdictnodid.iteritems(): 174 | # for antsub in ant: 175 | # writer.writerow([srcid] + [antsub[0][0], antsub[0][1], antsub[0][2], antsub[0][3]] + antsub[1]) 176 | # fid.close() 177 | # fd 1 178 | fd1lst = [] 179 | for fname in os.listdir('/media/data1/wentao/tianchi/luna16/subset'+str(fold)+'/'): 180 | if fname.endswith('.mhd'): fd1lst.append(fname[:-4]) 181 | # find the malignancy, shape information from xml file 182 | import xml.dom.minidom 183 | ndoc = 0 184 | lunadctclssgmdict = {} 185 | mallstall, callstall, sphlstall, marlstall, loblstall, spilstall, texlstall = [], [], [], [], [], [], [] 186 | for srsid, extant in lunaantdictnodid.iteritems(): 187 | if srsid not in fd1lst: continue 188 | lunadctclssgmdict[srsid] = [] 189 | pid, stdid = sidmap[srsid] 190 | for extantvlu in extant: 191 | getnodid = [] 192 | nant = 0 193 | mallst = [] 194 | for fname in 
os.listdir(os.path.join(*['/media/data1/wentao/LIDC-IDRI/DOI/', pid, stdid, srsid])): 195 | if fname.endswith('.xml'): 196 | nant += 1 197 | dom = xml.dom.minidom.parse(os.path.join(*['/media/data1/wentao/LIDC-IDRI/DOI/', pid, stdid, srsid, fname])) 198 | root = dom.documentElement 199 | rsessions = root.getElementsByTagName('readingSession') 200 | for rsess in rsessions: 201 | unblinds = rsess.getElementsByTagName('unblindedReadNodule') 202 | for unb in unblinds: 203 | nod = unb.getElementsByTagName('noduleID') 204 | if len(nod) != 1: 205 | print('more nod', nod) 206 | continue 207 | if nod[0].firstChild.data in extantvlu[1]: 208 | getnodid.append(nod[0].firstChild.data) 209 | mal = unb.getElementsByTagName('malignancy') 210 | if len(mal) == 1 and int(mal[0].firstChild.data) in range(1, 6, 1): 211 | mallst.append(float(mal[0].firstChild.data)) 212 | # print(getnodid, extantvlu[1], nant) 213 | if len(getnodid) > len(extantvlu[1]): 214 | print(pid, srsid) 215 | # assert 1 == 0 216 | ndoc = max(ndoc, len(getnodid), len(extantvlu[1])) 217 | vlulst = [srsid, extantvlu[0][0], extantvlu[0][1], extantvlu[0][2], extantvlu[0][3]] 218 | if len(mallst) == 0: vlulst.append(0) 219 | else: vlulst.append(sum(mallst)/float(len(mallst))) 220 | lunadctclssgmdict[srsid].append(vlulst+mallst) 221 | import csv 222 | # load predition array 223 | pixdimpred = np.load('../../../CTnoddetector/training/nodcls/besttestpred.npy')#'pixradiustest.npy') 224 | pdframe = pd.read_csv('annotationdetclsconv_v3.csv', names=['seriesuid', 'coordX', 'coordY', 'coordZ', 'diameter_mm', 'malignant']) 225 | srslst = pdframe['seriesuid'].tolist()[1:] 226 | crdxlst = pdframe['coordX'].tolist()[1:] 227 | crdylst = pdframe['coordY'].tolist()[1:] 228 | crdzlst = pdframe['coordZ'].tolist()[1:] 229 | dimlst = pdframe['diameter_mm'].tolist()[1:] 230 | mlglst = pdframe['malignant'].tolist()[1:] 231 | newlst = [] 232 | import csv 233 | fid = open('annotationdetclsconvfnl_v3.csv', 'w') 234 | writer = csv.writer(fid) 
235 | writer.writerow(['seriesuid', 'coordX', 'coordY', 'coordZ', 'diameter_mm', 'malignant']) 236 | for i in range(len(srslst)): 237 | writer.writerow([srslst[i]+'-'+str(i), crdxlst[i], crdylst[i], crdzlst[i], dimlst[i], mlglst[i]]) 238 | newlst.append([srslst[i]+'-'+str(i), crdxlst[i], crdylst[i], crdzlst[i], dimlst[i], mlglst[i]]) 239 | fid.close() 240 | subset1path = '/media/data1/wentao/tianchi/luna16/subset'+str(fold)+'/' 241 | testfnamelst = [] 242 | for fname in os.listdir(subset1path): 243 | if fname.endswith('.mhd'): 244 | testfnamelst.append(fname[:-4]) 245 | ntest = 0 246 | for idx in range(len(newlst)): 247 | fname = newlst[idx][0] 248 | if fname.split('-')[0] in testfnamelst: ntest +=1 249 | print('ntest', ntest, 'ntrain', len(newlst)-ntest) 250 | prednamelst = {} 251 | predacc = 0 252 | predidx = 0 253 | # predlabellst = [] 254 | for idx in range(len(newlst)): 255 | fname = newlst[idx][0] 256 | if fname.split('-')[0] in testfnamelst: 257 | # print newlst[idx][-1], pixdimpred[predidx] 258 | if int(pixdimpred[predidx]>0.5) == int(newlst[idx][-1]): predacc += 1 259 | if fname.split('-')[0] not in prednamelst: 260 | prednamelst[fname.split('-')[0]] = [[pixdimpred[predidx], fname.split('-')[1]]] 261 | else: 262 | prednamelst[fname.split('-')[0]].append([pixdimpred[predidx], fname.split('-')[1]]) 263 | predidx += 1 264 | print('pred acc', predacc/float(predidx)) 265 | pixdimidx = -1 266 | # savename = 'annotationdetclssgm_doctor_fd2.csv' 267 | # fid = open(savename, 'w') 268 | # writer = csv.writer(fid) 269 | # writer.writerow(['seriesuid', 'coordX', 'coordY', 'coordZ', 'diameter_mm', 'malignant']) 270 | doctornacc, doctornpid = [0]*ndoc, [0]*ndoc 271 | nacc = 0 272 | ntot = 0 273 | for srsid, extant in lunadctclssgmdict.iteritems(): 274 | curidx = 0 275 | if srsid not in fd1lst: continue 276 | for subextant in extant: 277 | if subextant[5] in [3, 0]: continue 278 | if abs(subextant[5] - 3) < 1e-2: continue 279 | if subextant[5] > 3: subextant[5] = 1 280 | 
else: subextant[5] = 0 281 | if subextant[5] == int(prednamelst[srsid][curidx][0]>0.5): nacc += 1 282 | ntot += 1 283 | # writer.writerow(subextant) 284 | for did in range(6, len(subextant), 1): 285 | # if 0.499 <= prednamelst[srsid][curidx] <= 0.501: continue 286 | if subextant[did] == 3: continue 287 | if subextant[5] != (prednamelst[srsid][curidx][0]>0.5): 288 | print(srsid+'-'+prednamelst[srsid][curidx][1], prednamelst[srsid][curidx][0], subextant[5]) 289 | # if subextant[5] == 1 and prednamelst[srsid][curidx][0]>0.5:# subextant[did] > 3: # we treat 3 as the positive label 290 | # if subextant[did] < 3: print(srsid+'-'+prednamelst[srsid][curidx][1], prednamelst[srsid][curidx][0], did-6) 291 | # doctornacc[did-6] += 1 292 | # elif subextant[5] == 0 and prednamelst[srsid][curidx][0]<=0.5:# subextant[did] < 3: 293 | # if subextant[did] > 3: print(srsid+'-'+prednamelst[srsid][curidx][1], prednamelst[srsid][curidx][0], did-6) 294 | # doctornacc[did-6] += 1 295 | # if subextant[did] != 3: 296 | # doctornpid[did-6] += 1 297 | curidx += 1 298 | fid.close() 299 | print(nacc / float(ntot)) 300 | for i in range(ndoc): 301 | print(i, doctornacc[i], doctornpid[i], doctornacc[i]/float(doctornpid[i])) -------------------------------------------------------------------------------- /data/pthumanperformance.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import SimpleITK as sitk 3 | import os 4 | import os.path 5 | import numpy as np 6 | fold = 1 7 | def load_itk_image(filename): 8 | with open(filename) as f: 9 | contents = f.readlines() 10 | line = [k for k in contents if k.startswith('TransformMatrix')][0] 11 | transformM = np.array(line.split(' = ')[1].split(' ')).astype('float') 12 | transformM = np.round(transformM) 13 | if np.any( transformM!=np.array([1,0,0, 0, 1, 0, 0, 0, 1])): 14 | isflip = True 15 | else: 16 | isflip = False 17 | itkimage = sitk.ReadImage(filename) 18 | numpyImage = 
sitk.GetArrayFromImage(itkimage) 19 | numpyOrigin = np.array(list(reversed(itkimage.GetOrigin()))) 20 | numpySpacing = np.array(list(reversed(itkimage.GetSpacing()))) 21 | return numpyImage, numpyOrigin, numpySpacing,isflip 22 | def worldToVoxelCoord(worldCoord, origin, spacing): 23 | stretchedVoxelCoord = np.absolute(worldCoord - origin) 24 | voxelCoord = stretchedVoxelCoord / spacing 25 | return voxelCoord 26 | # read map file 27 | mapfname = 'LIDC-IDRI-mappingLUNA16' 28 | sidmap = {} 29 | fid = open(mapfname, 'r') 30 | line = fid.readline() 31 | line = fid.readline() 32 | while line: 33 | pidlist = line.split(' ') 34 | # print pidlist 35 | pid = pidlist[0] 36 | stdid = pidlist[1] 37 | srsid = pidlist[2] 38 | if srsid not in sidmap: 39 | sidmap[srsid] = [pid, stdid] 40 | else: 41 | assert sidmap[srsid][0] == pid 42 | assert sidmap[srsid][1] == stdid 43 | line = fid.readline() 44 | fid.close() 45 | # read luna16 annotation 46 | colname = ['seriesuid', 'coordX', 'coordY', 'coordZ', 'diameter_mm'] 47 | lunaantframe = pd.read_csv('annotations.csv', names=colname) 48 | srslist = lunaantframe.seriesuid.tolist()[1:] 49 | cdxlist = lunaantframe.coordX.tolist()[1:] 50 | cdylist = lunaantframe.coordY.tolist()[1:] 51 | cdzlist = lunaantframe.coordZ.tolist()[1:] 52 | dimlist = lunaantframe.diameter_mm.tolist()[1:] 53 | lunaantdict = {} 54 | for idx in range(len(srslist)): 55 | vlu = [float(cdxlist[idx]), float(cdylist[idx]), float(cdzlist[idx]), float(dimlist[idx])] 56 | if srslist[idx] in lunaantdict: 57 | lunaantdict[srslist[idx]].append(vlu) 58 | else: 59 | lunaantdict[srslist[idx]] = [vlu] 60 | # # convert luna16 annotation to LIDC-IDRI annotation space 61 | # from multiprocessing import Pool 62 | # lunantdictlidc = {} 63 | # for fold in xrange(10): 64 | # mhdpath = '/media/data1/wentao/tianchi/luna16/subset'+str(fold) 65 | # print 'fold', fold 66 | # def getvoxcrd(fname): 67 | # sliceim,origin,spacing,isflip = load_itk_image(os.path.join(mhdpath, fname)) 68 | # 
lunantdictlidc[fname[:-4]] = [] 69 | # voxcrdlist = [] 70 | # for lunaant in lunaantdict[fname[:-4]]: 71 | # voxcrd = worldToVoxelCoord(lunaant[:3][::-1], origin, spacing) 72 | # voxcrd[-1] = sliceim.shape[0] - voxcrd[0] 73 | # voxcrdlist.append(voxcrd) 74 | # return voxcrdlist 75 | # p = Pool(30) 76 | # fnamelist = [] 77 | # for fname in os.listdir(mhdpath): 78 | # if fname.endswith('.mhd') and fname[:-4] in lunaantdict: 79 | # fnamelist.append(fname) 80 | # voxcrdlist = p.map(getvoxcrd, fnamelist) 81 | # listidx = 0 82 | # for fname in os.listdir(mhdpath): 83 | # if fname.endswith('.mhd') and fname[:-4] in lunaantdict: 84 | # lunantdictlidc[fname[:-4]] = [] 85 | # for subidx, lunaant in enumerate(lunaantdict[fname[:-4]]): 86 | # # voxcrd = worldToVoxelCoord(lunaant[:3][::-1], origin, spacing) 87 | # # voxcrd[-1] = sliceim.shape[0] - voxcrd[0] 88 | # lunantdictlidc[fname[:-4]].append([lunaant, voxcrdlist[listidx][subidx]]) 89 | # listidx += 1 90 | # p.close() 91 | # np.save('lunaantdictlidc.npy', lunantdictlidc) 92 | # read LIDC dataset 93 | lunantdictlidc = np.load('lunaantdictlidc.npy').item() 94 | import xlrd 95 | lidccsvfname = '/media/data1/wentao/LIDC-IDRI/list3.2.xls' 96 | antdict = {} 97 | wb = xlrd.open_workbook(os.path.join(lidccsvfname)) 98 | for s in wb.sheets(): 99 | if s.name == 'list3.2': 100 | for row in range(1, s.nrows): 101 | valuelist = [int(s.cell(row, 2).value), s.cell(row, 3).value, s.cell(row, 4).value, \ 102 | int(s.cell(row, 5).value), int(s.cell(row, 6).value), int(s.cell(row, 7).value)] 103 | assert abs(s.cell(row, 1).value - int(s.cell(row, 1).value)) < 1e-8 104 | assert abs(s.cell(row, 2).value - int(s.cell(row, 2).value)) < 1e-8 105 | assert abs(s.cell(row, 5).value - int(s.cell(row, 5).value)) < 1e-8 106 | assert abs(s.cell(row, 6).value - int(s.cell(row, 6).value)) < 1e-8 107 | assert abs(s.cell(row, 7).value - int(s.cell(row, 7).value)) < 1e-8 108 | for col in range(9, 16): 109 | if s.cell(row, col).value != '': 110 | if 
isinstance(s.cell(row, col).value, float): 111 | valuelist.append(str(int(s.cell(row, col).value))) 112 | assert abs(s.cell(row, col).value - int(s.cell(row, col).value)) < 1e-8 113 | else: 114 | valuelist.append(s.cell(row, col).value) 115 | if s.cell(row, 0).value+'_'+str(int(s.cell(row, 1).value)) not in antdict: 116 | antdict[s.cell(row, 0).value+'_'+str(int(s.cell(row, 1).value))] = [valuelist] 117 | else: 118 | antdict[s.cell(row, 0).value+'_'+str(int(s.cell(row, 1).value))].append(valuelist) 119 | # update LIDC annotation with series number, rather than scan id 120 | import dicom 121 | LIDCpath = '/media/data1/wentao/LIDC-IDRI/DOI/' 122 | antdictscan = {} 123 | for k, v in antdict.iteritems(): 124 | pid, scan = k.split('_') 125 | hasscan = False 126 | for sdu in os.listdir(os.path.join(LIDCpath, 'LIDC-IDRI-'+pid)): 127 | for srs in os.listdir(os.path.join(*[LIDCpath, 'LIDC-IDRI-'+pid, sdu])): 128 | if srs.endswith('.npy'): 129 | print('npy', pid, scan, srs) 130 | continue 131 | RefDs = dicom.read_file(os.path.join(*[LIDCpath, 'LIDC-IDRI-'+pid, sdu, srs, '000006.dcm'])) 132 | # print scan, str(RefDs[0x20, 0x11].value) 133 | if str(RefDs[0x20, 0x11].value) == scan or scan == '0': 134 | if hasscan: print('rep', pid, sdu, srs) 135 | hasscan = True 136 | antdictscan[pid+'_'+srs] = v 137 | break 138 | if not hasscan: print('not found', pid, scan, sdu, srs) 139 | # find the match from LIDC-IDRI annotation 140 | import math 141 | lunaantdictnodid = {} 142 | maxdist = 0 143 | for srcid, lunaantlidc in lunantdictlidc.iteritems(): 144 | lunaantdictnodid[srcid] = [] 145 | pid, stdid = sidmap[srcid] 146 | # print pid 147 | pid = pid[len('LIDC-IDRI-'):] 148 | for lunantdictlidcsub in lunaantlidc: 149 | lunaant = lunantdictlidcsub[0] 150 | voxcrd = lunantdictlidcsub[1] # z y x 151 | mindist, minidx = 1e8, -1 152 | if srcid in ['1.3.6.1.4.1.14519.5.2.1.6279.6001.174692377730646477496286081479', '1.3.6.1.4.1.14519.5.2.1.6279.6001.300246184547502297539521283806']: 153 | 
continue 154 | for idx, lidcant in enumerate(antdictscan[pid+'_'+srcid]): 155 | dist = math.pow(voxcrd[0] - lidcant[3], 2) # z 156 | dist = math.pow(voxcrd[1] - lidcant[4], 2) # y 157 | dist += math.pow(voxcrd[2] - lidcant[5], 2) # x 158 | if dist < mindist: 159 | mindist = dist 160 | minidx = idx 161 | if mindist > 71:#15.1: 162 | print(srcid, pid, voxcrd, antdictscan[pid+'_'+srcid], mindist) 163 | maxdist = max(maxdist, mindist) 164 | lunaantdictnodid[srcid].append([lunaant, antdictscan[pid+'_'+srcid][minidx][6:]]) 165 | # np.save('lunaantdictnodid.npy', lunaantdictnodid) 166 | print('maxdist', maxdist) 167 | # save it into a csv 168 | # import csv 169 | # savename = 'annotationnodid.csv' 170 | # fid = open(savename, 'w') 171 | # writer = csv.writer(fid) 172 | # writer.writerow(['seriesuid', 'coordX', 'coordY', 'coordZ', 'diameter_mm']) 173 | # for srcid, ant in lunaantdictnodid.iteritems(): 174 | # for antsub in ant: 175 | # writer.writerow([srcid] + [antsub[0][0], antsub[0][1], antsub[0][2], antsub[0][3]] + antsub[1]) 176 | # fid.close() 177 | # fd 1 178 | fd1lst = [] 179 | for fname in os.listdir('/media/data1/wentao/tianchi/luna16/subset'+str(fold)+'/'): 180 | if fname.endswith('.mhd'): fd1lst.append(fname[:-4]) 181 | # find the malignancy, shape information from xml file 182 | import xml.dom.minidom 183 | ndoc = 0 184 | lunadctclssgmdict = {} 185 | mallstall, callstall, sphlstall, marlstall, loblstall, spilstall, texlstall = [], [], [], [], [], [], [] 186 | for srsid, extant in lunaantdictnodid.iteritems(): 187 | if srsid not in fd1lst: continue 188 | lunadctclssgmdict[srsid] = [] 189 | pid, stdid = sidmap[srsid] 190 | for extantvlu in extant: 191 | getnodid = [] 192 | nant = 0 193 | mallst = [] 194 | for fname in os.listdir(os.path.join(*['/media/data1/wentao/LIDC-IDRI/DOI/', pid, stdid, srsid])): 195 | if fname.endswith('.xml'): 196 | nant += 1 197 | dom = xml.dom.minidom.parse(os.path.join(*['/media/data1/wentao/LIDC-IDRI/DOI/', pid, stdid, srsid, 
fname])) 198 | root = dom.documentElement 199 | rsessions = root.getElementsByTagName('readingSession') 200 | for rsess in rsessions: 201 | unblinds = rsess.getElementsByTagName('unblindedReadNodule') 202 | for unb in unblinds: 203 | nod = unb.getElementsByTagName('noduleID') 204 | if len(nod) != 1: 205 | print('more nod', nod) 206 | continue 207 | if nod[0].firstChild.data in extantvlu[1]: 208 | getnodid.append(nod[0].firstChild.data) 209 | mal = unb.getElementsByTagName('malignancy') 210 | if len(mal) == 1 and int(mal[0].firstChild.data) in range(1, 6, 1): 211 | mallst.append(float(mal[0].firstChild.data)) 212 | # print(getnodid, extantvlu[1], nant) 213 | if len(getnodid) > len(extantvlu[1]): 214 | print(pid, srsid) 215 | # assert 1 == 0 216 | ndoc = max(ndoc, len(getnodid), len(extantvlu[1])) 217 | vlulst = [srsid, extantvlu[0][0], extantvlu[0][1], extantvlu[0][2], extantvlu[0][3]] 218 | if len(mallst) == 0: vlulst.append(0) 219 | else: vlulst.append(sum(mallst)/float(len(mallst))) 220 | lunadctclssgmdict[srsid].append(vlulst+mallst) 221 | import csv 222 | # load predition array 223 | # pixdimpred = np.load('../../../CTnoddetector/training/nodcls/checkpoint-'+str(fold)+'/besttestpred.npy')#'pixradiustest.npy') 224 | pdframe = pd.read_csv('annotationdetclsconv_v3.csv', names=['seriesuid', 'coordX', 'coordY', 'coordZ', 'diameter_mm', 'malignant']) 225 | srslst = pdframe['seriesuid'].tolist()[1:] 226 | crdxlst = pdframe['coordX'].tolist()[1:] 227 | crdylst = pdframe['coordY'].tolist()[1:] 228 | crdzlst = pdframe['coordZ'].tolist()[1:] 229 | dimlst = pdframe['diameter_mm'].tolist()[1:] 230 | mlglst = pdframe['malignant'].tolist()[1:] 231 | newlst = [] 232 | import csv 233 | # fid = open('annotationdetclsconvfnl_v3.csv', 'w') 234 | # writer = csv.writer(fid) 235 | # writer.writerow(['seriesuid', 'coordX', 'coordY', 'coordZ', 'diameter_mm', 'malignant']) 236 | for i in range(len(srslst)): 237 | # writer.writerow([srslst[i]+'-'+str(i), crdxlst[i], crdylst[i], 
crdzlst[i], dimlst[i], mlglst[i]]) 238 | newlst.append([srslst[i]+'-'+str(i), crdxlst[i], crdylst[i], crdzlst[i], dimlst[i], mlglst[i]]) 239 | # fid.close() 240 | subset1path = '/media/data1/wentao/tianchi/luna16/subset'+str(fold)+'/' 241 | testfnamelst = [] 242 | for fname in os.listdir(subset1path): 243 | if fname.endswith('.mhd'): 244 | testfnamelst.append(fname[:-4]) 245 | ntest = 0 246 | for idx in range(len(newlst)): 247 | fname = newlst[idx][0] 248 | if fname.split('-')[0] in testfnamelst: ntest +=1 249 | print('ntest', ntest, 'ntrain', len(newlst)-ntest) 250 | prednamelst = {} 251 | predacc = 0 252 | predidx = 0 253 | # predlabellst = [] 254 | # for idx in xrange(len(newlst)): 255 | # fname = newlst[idx][0] 256 | # if fname.split('-')[0] in testfnamelst: 257 | # # print newlst[idx][-1], pixdimpred[predidx] 258 | # if int(pixdimpred[predidx]>0.5) == int(newlst[idx][-1]): predacc += 1 259 | # if fname.split('-')[0] not in prednamelst: 260 | # prednamelst[fname.split('-')[0]] = [pixdimpred[predidx]] 261 | # else: 262 | # prednamelst[fname.split('-')[0]].append(pixdimpred[predidx]) 263 | # predidx += 1 264 | # print 'pred acc', predacc/float(predidx) 265 | pixdimidx = -1 266 | # savename = 'annotationdetclssgm_doctor_fd2.csv' 267 | # fid = open(savename, 'w') 268 | # writer = csv.writer(fid) 269 | # writer.writerow(['seriesuid', 'coordX', 'coordY', 'coordZ', 'diameter_mm', 'malignant']) 270 | # get the patient level cancer 271 | ptlabel = {} 272 | nfold1 = 0 273 | for srsid, extant in lunadctclssgmdict.iteritems(): 274 | if srsid not in fd1lst: continue 275 | nfold1 += 1 276 | for subextant in extant: 277 | if subextant[5] in [3, 0]: continue 278 | if abs(subextant[5] - 3) < 1e-2: continue 279 | if subextant[5] > 3: 280 | ptlabel[srsid] = 1 281 | break 282 | else: 283 | ptlabel[srsid] = 0 284 | print(len(lunadctclssgmdict.keys()), nfold1, len(ptlabel.keys()), sum(ptlabel.values())) 285 | # get doctors prediction on patient level cancer 286 | dctptlabel = {} 287 
# Per-reader patient labels for the scans in this fold.
# Each entry of lunadctclssgmdict[srsid] is
# [srsid, x, y, z, diameter, mean_malignancy, rating_0, rating_1, ...].
# For every reader index the patient label is:
#    1 -> the reader rated at least one counted nodule above 3
#    0 -> the reader rated counted nodules, none above 3
#   -1 -> the reader only gave ratings of 3 (indeterminate) on counted nodules
for srsid, extant in lunadctclssgmdict.items():
    if srsid not in fd1lst:
        continue
    for subextant in extant:
        # Skip nodules whose mean malignancy is indeterminate (3) or missing (0),
        # mirroring the filter used to build ptlabel.
        if subextant[5] in [3, 0]:
            continue
        if abs(subextant[5] - 3) < 1e-2:
            continue
        ratings = subextant[6:]
        # Initialise once per scan (not once per nodule) so a malignant vote on an
        # earlier nodule is not discarded by a later one. Grow the list if this
        # nodule carries more reader ratings than any seen so far for the scan.
        labels = dctptlabel.setdefault(srsid, [])
        labels.extend([-1] * (max(len(ratings), 4) - len(labels)))
        for did, rating in enumerate(ratings):
            if rating == 3:
                continue
            if rating > 3:
                labels[did] = 1
            # A "benign" rating only fills an unset slot; a 1 from any nodule sticks.
            if labels[did] == -1:
                labels[did] = 0

# Patient-level agreement of each of the first four readers with ptlabel.
#   p[d]    -> scans where reader d's label equals ptlabel
#   n[d]    -> scans reader d labelled (label != -1)
#   tp[d]   -> correctly labelled scans whose ptlabel is 1
#   npos[d] -> labelled scans whose ptlabel is 1
# The positives counter is deliberately not named ``np`` so the numpy alias
# imported at the top of the file stays usable below.
ndoctor = 4
p = [0] * ndoctor
n = [0] * ndoctor
tp = [0] * ndoctor
npos = [0] * ndoctor
for srs in ptlabel:
    for did in range(ndoctor):
        dctlabel = dctptlabel[srs][did]
        if dctlabel == -1:
            continue
        n[did] += 1
        if ptlabel[srs] == 1:
            npos[did] += 1
        if ptlabel[srs] == dctlabel:
            p[did] += 1
            if ptlabel[srs] == 1:
                tp[did] += 1
print(p, n, tp, npos, len(ptlabel))


def _ratio(num, den):
    """Return num / den as a float, or nan when den is 0 (empty group)."""
    return num / float(den) if den else float('nan')


for did in range(ndoctor):
    # accuracy, sensitivity, specificity of reader ``did`` at patient level
    print(_ratio(p[did], n[did]),
          _ratio(tp[did], npos[did]),
          _ratio(p[did] - tp[did], n[did] - npos[did]))
np.save('ptlabel'+str(fold)+'.npy', ptlabel)
342 | np.save('dctptlabel'+str(fold)+'.npy', dctptlabel) -------------------------------------------------------------------------------- /transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import math 4 | import random 5 | from PIL import Image, ImageOps 6 | 7 | try: 8 | import accimage 9 | except ImportError: 10 | accimage = None 11 | import numpy as np 12 | import numbers 13 | import types 14 | import collections 15 | from torch.autograd import Variable 16 | 17 | torch.cuda.set_device(0) 18 | 19 | 20 | def resample3d(inp, inp_space, out_space=(1, 1, 1)): 21 | # Infer new shape 22 | # inp = torch.from_numpy(inp) 23 | # inp=torch.FloatTensor(inp) 24 | # inp=Variable(inp) 25 | inp = inp.cuda() 26 | out = resample1d(inp, inp_space[2], out_space[2]).permute(0, 2, 1) 27 | out = resample1d(out, inp_space[1], out_space[1]).permute(2, 1, 0) 28 | out = resample1d(out, inp_space[0], out_space[0]).permute(2, 0, 1) 29 | return out 30 | 31 | 32 | def resample1d(inp, inp_space, out_space=1): 33 | # Output shape 34 | print(inp.size(), inp_space, out_space) 35 | out_shape = list(np.int64(inp.size()[:-1])) + [ 36 | int(np.floor(inp.size()[-1] * inp_space / out_space))] # Optional for if we expect a float_tensor 37 | out_shape = [int(item) for item in out_shape] 38 | # Get output coordinates, deltas, and t (chord distances) 39 | # torch.cuda.set_device(inp.get_device()) 40 | # Output coordinates in real space 41 | coords = torch.cuda.HalfTensor(range(out_shape[-1])) * out_space 42 | delta = coords.fmod(inp_space).div(inp_space).repeat(out_shape[0], out_shape[1], 1) 43 | t = torch.cuda.HalfTensor(4, out_shape[0], out_shape[1], out_shape[2]).zero_() 44 | t[0] = 1 45 | t[1] = delta 46 | t[2] = delta ** 2 47 | t[3] = delta ** 3 48 | # Nearest neighbours indices 49 | nn = coords.div(inp_space).floor().long() 50 | # Stack the nearest neighbors into P, the Points Array 51 | P = 
torch.cuda.HalfTensor(4, out_shape[0], out_shape[1], out_shape[2]).zero_() 52 | for i in range(-1, 3): 53 | P[i + 1] = inp.index_select(2, torch.clamp(nn + i, 0, inp.size()[-1] - 1)) 54 | # Take catmull-rom spline interpolation: 55 | return 0.5 * t.mul(torch.cuda.HalfTensor([[0, 2, 0, 0], 56 | [-1, 0, 1, 0], 57 | [2, -5, 4, -1], 58 | [-1, 3, -3, 1]]).mm(P.view(4, -1)) \ 59 | .view(4, 60 | out_shape[0], 61 | out_shape[1], 62 | out_shape[2])) \ 63 | .sum(0) \ 64 | .squeeze() 65 | 66 | 67 | class Compose(object): 68 | """Composes several transforms together. 69 | 70 | Args: 71 | transforms (list of ``Transform`` objects): list of transforms to compose. 72 | 73 | Example: 74 | >>> transforms.Compose([ 75 | >>> transforms.CenterCrop(10), 76 | >>> transforms.ToTensor(), 77 | >>> ]) 78 | """ 79 | 80 | def __init__(self, transforms): 81 | self.transforms = transforms 82 | 83 | def __call__(self, img): 84 | for t in self.transforms: 85 | # print(t) 86 | img = t(img) 87 | return img 88 | 89 | 90 | class Normalize(object): 91 | """Normalize an tensor image with mean and standard deviation. 92 | 93 | Given mean: (R, G, B) and std: (R, G, B), 94 | will normalize each channel of the torch.*Tensor, i.e. 95 | channel = (channel - mean) / std 96 | 97 | Args: 98 | mean (sequence): Sequence of means for R, G, B channels respecitvely. 99 | std (sequence): Sequence of standard deviations for R, G, B channels 100 | respecitvely. 101 | """ 102 | 103 | def __init__(self, mean, std): 104 | self.mean = mean 105 | self.std = std 106 | 107 | def __call__(self, tensor): 108 | """ 109 | Args: 110 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 111 | 112 | Returns: 113 | Tensor: Normalized image. 
114 | """ 115 | # TODO: make efficient 116 | # for t, m, s in zip(tensor, self.mean, self.std): 117 | 118 | tensor.sub_(self.mean).div_(self.std) 119 | return tensor 120 | 121 | 122 | from scipy.ndimage.interpolation import zoom 123 | 124 | 125 | class RandomScale(object): 126 | ''' Randomly scale from scale size list ''' 127 | 128 | def __init__(self, size, interpolation=Image.BILINEAR): 129 | # assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 3) 130 | self.size = size 131 | self.interpolation = interpolation 132 | 133 | def __call__(self, img): 134 | # scale = np.random.permutation(len(self.size))[0] / 32.0 135 | scale = random.randint(self.size[0], 136 | self.size[-1] + 1) # (self.size[np.random.permutation(len(self.size))[0]])#, \ 137 | # self.size[np.random.permutation(len(self.size))[0]], \ 138 | # self.size[np.random.permutation(len(self.size))[0]]) 139 | # print img.shape, scale, img.shape*scale 140 | # print('scale', 32.0/scale) 141 | return zoom(img, (scale, scale, scale), 142 | mode='nearest') # resample3d(img,(32,32,32),out_space=scale)#zoom(img, scale) #img.resize(scale, self.interpolation) resample3d(img,img.shape,out_space=scale) 143 | 144 | 145 | class Scale(object): 146 | """Rescale the input PIL.Image to the given size. 147 | 148 | Args: 149 | size (sequence or int): Desired output size. If size is a sequence like 150 | (w, h), output size will be matched to this. If size is an int, 151 | smaller edge of the image will be matched to this number. 152 | i.e, if height > width, then image will be rescaled to 153 | (size * height / width, size) 154 | interpolation (int, optional): Desired interpolation. 
Default is 155 | ``PIL.Image.BILINEAR`` 156 | """ 157 | 158 | def __init__(self, size, interpolation=Image.BILINEAR): 159 | assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 3) 160 | self.size = size 161 | self.interpolation = interpolation 162 | 163 | def __call__(self, img): 164 | """ 165 | Args: 166 | img (PIL.Image): Image to be scaled. 167 | 168 | Returns: 169 | PIL.Image: Rescaled image. 170 | """ 171 | if isinstance(self.size, int): 172 | w, h, d = img.size 173 | if (w <= h and w == self.size) or (h <= w and h == self.size): 174 | return img 175 | if w < h: 176 | ow = self.size 177 | oh = int(self.size * h / w) 178 | return img.resize((ow, oh), self.interpolation) 179 | else: 180 | oh = self.size 181 | ow = int(self.size * w / h) 182 | return img.resize((ow, oh), self.interpolation) 183 | else: 184 | return img.resize(self.size, self.interpolation) 185 | 186 | 187 | class ZeroOut(object): 188 | """Crops the given PIL.Image at the center. 189 | Args: 190 | size (sequence or int): Desired output size of the crop. If size is an 191 | int instead of sequence like (w, h), a square crop (size, size) is 192 | made. 193 | """ 194 | 195 | def __init__(self, size): 196 | self.size = int(size) 197 | 198 | def __call__(self, img): 199 | w, h, d = img.shape # size 200 | x1 = random.randint(0, w - self.size) # np.random.permutation(w-self.size)[0] 201 | y1 = random.randint(0, h - self.size) # np.random.permutation(h-self.size)[0] 202 | z1 = random.randint(0, d - self.size) # np.random.permutation(d-self.size)[0] 203 | img1 = np.array(img) 204 | # print 'zero out', x1, y1, z1, w, h, d, self.size 205 | img1[x1:x1 + self.size, y1:y1 + self.size, z1:z1 + self.size] = np.array( 206 | np.zeros((self.size, self.size, self.size))) 207 | return np.array(img1) 208 | 209 | 210 | class ToTensor(object): 211 | """Convert a ``PIL.Image`` or ``numpy.ndarray`` to tensor. 
212 | 213 | Converts a PIL.Image or numpy.ndarray (H x W x C) in the range 214 | [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]. 215 | """ 216 | 217 | def __call__(self, pic): 218 | """ 219 | Args: 220 | pic (PIL.Image or numpy.ndarray): Image to be converted to tensor. 221 | 222 | Returns: 223 | Tensor: Converted image. 224 | """ 225 | if isinstance(pic, np.ndarray): 226 | # handle numpy array 227 | pic = np.expand_dims(pic, -1) 228 | # print('before tensor', pic.shape) 229 | img = torch.from_numpy(pic.transpose((3, 0, 1, 2))) 230 | # backward compatibility 231 | return img.float() # .div(255) 232 | 233 | if accimage is not None and isinstance(pic, accimage.Image): 234 | nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32) 235 | pic.copyto(nppic) 236 | return torch.from_numpy(nppic) 237 | 238 | # handle PIL Image 239 | if pic.mode == 'I': 240 | img = torch.from_numpy(np.array(pic, np.int32, copy=False)) 241 | elif pic.mode == 'I;16': 242 | img = torch.from_numpy(np.array(pic, np.int16, copy=False)) 243 | else: 244 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 245 | # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK 246 | if pic.mode == 'YCbCr': 247 | nchannel = 3 248 | elif pic.mode == 'I;16': 249 | nchannel = 1 250 | else: 251 | nchannel = len(pic.mode) 252 | img = img.view(pic.size[1], pic.size[0], nchannel) 253 | # put it from HWC to CHW format 254 | # yikes, this transpose takes 80% of the loading time/CPU 255 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 256 | if isinstance(img, torch.ByteTensor): 257 | return img.float() # .div(255) 258 | else: 259 | return img 260 | 261 | 262 | class CenterCrop(object): 263 | """Crops the given PIL.Image at the center. 264 | 265 | Args: 266 | size (sequence or int): Desired output size of the crop. If size is an 267 | int instead of sequence like (w, h), a square crop (size, size) is 268 | made. 
269 | """ 270 | 271 | def __init__(self, size): 272 | if isinstance(size, numbers.Number): 273 | self.size = (int(size), int(size)) 274 | else: 275 | self.size = size 276 | 277 | def __call__(self, img): 278 | """ 279 | Args: 280 | img (PIL.Image): Image to be cropped. 281 | 282 | Returns: 283 | PIL.Image: Cropped image. 284 | """ 285 | w, h = img.size 286 | th, tw = self.size 287 | x1 = int(round((w - tw) / 2.)) 288 | y1 = int(round((h - th) / 2.)) 289 | return img.crop((x1, y1, x1 + tw, y1 + th)) 290 | 291 | 292 | class Pad(object): 293 | """Pad the given PIL.Image on all sides with the given "pad" value. 294 | 295 | Args: 296 | padding (int or sequence): Padding on each border. If a sequence of 297 | length 4, it is used to pad left, top, right and bottom borders respectively. 298 | fill: Pixel fill value. Default is 0. 299 | """ 300 | 301 | def __init__(self, padding, fill=0): 302 | assert isinstance(padding, numbers.Number) 303 | assert isinstance(fill, numbers.Number) or isinstance(fill, str) or isinstance(fill, tuple) 304 | self.padding = padding 305 | self.fill = fill 306 | 307 | def __call__(self, img): 308 | """ 309 | Args: 310 | img (PIL.Image): Image to be padded. 311 | 312 | Returns: 313 | PIL.Image: Padded image. 314 | """ 315 | return ImageOps.expand(img, border=self.padding, fill=self.fill) 316 | 317 | 318 | class Lambda(object): 319 | """Apply a user-defined lambda as a transform. 320 | 321 | Args: 322 | lambd (function): Lambda/function to be used for transform. 323 | """ 324 | 325 | def __init__(self, lambd): 326 | assert isinstance(lambd, types.LambdaType) 327 | self.lambd = lambd 328 | 329 | def __call__(self, img): 330 | return self.lambd(img) 331 | 332 | 333 | class RandomCrop(object): 334 | """Crop the given PIL.Image at a random location. 335 | 336 | Args: 337 | size (sequence or int): Desired output size of the crop. If size is an 338 | int instead of sequence like (w, h), a square crop (size, size) is 339 | made. 
340 | padding (int or sequence, optional): Optional padding on each border 341 | of the image. Default is 0, i.e no padding. If a sequence of length 342 | 4 is provided, it is used to pad left, top, right, bottom borders 343 | respectively. 344 | """ 345 | 346 | def __init__(self, size, padding=0): 347 | if isinstance(size, numbers.Number): 348 | self.size = (int(size), int(size), int(size)) 349 | else: 350 | self.size = int(size) 351 | self.padding = int(padding) 352 | 353 | def __call__(self, img): 354 | """ 355 | Args: 356 | img (PIL.Image): Image to be cropped. 357 | 358 | Returns: 359 | PIL.Image: Cropped image. 360 | """ 361 | if self.padding > 0: 362 | # print 'scale out', img.shape 363 | pad = int(self.padding / 2) 364 | img1 = np.ones((img.shape[0] + pad, img.shape[1] + pad, img.shape[2] + pad)) * 170 365 | bg = int(self.padding / 2) 366 | img1[bg:bg + img.shape[0], bg:bg + img.shape[1], bg:bg + img.shape[2]] = np.array(img) 367 | img = np.array(img1) 368 | # img = ImageOps.expand(img, border=self.padding, fill=170) 369 | 370 | w, h, d = img.shape # size 371 | th, tw, td = self.size 372 | # print 'pad out', w, h, d, th, tw, td 373 | if w == tw and h == th and d == td: 374 | return img 375 | x1 = random.randint(0, w - tw) 376 | y1 = random.randint(0, h - th) 377 | z1 = random.randint(0, d - td) 378 | return np.array(img[x1:x1 + th, y1:y1 + tw, z1:z1 + td]) 379 | # return img.crop((x1, y1, x1 + tw, y1 + th, z1 + td)) 380 | 381 | 382 | class RandomHorizontalFlip(object): 383 | """Horizontally flip the given PIL.Image randomly with a probability of 0.5.""" 384 | 385 | def __call__(self, img): 386 | """ 387 | Args: 388 | img (PIL.Image): Image to be flipped. 389 | 390 | Returns: 391 | PIL.Image: Randomly flipped image. 
392 | """ 393 | if random.random() < 0.5: 394 | return np.array(img[:, :, ::-1]) # .transpose(Image.FLIP_LEFT_RIGHT) 395 | return img 396 | 397 | 398 | class RandomZFlip(object): 399 | """Horizontally flip the given PIL.Image randomly with a probability of 0.5.""" 400 | 401 | def __call__(self, img): 402 | """ 403 | Args: 404 | img (PIL.Image): Image to be flipped. 405 | 406 | Returns: 407 | PIL.Image: Randomly flipped image. 408 | """ 409 | if random.random() < 0.5: 410 | return np.array(img[::-1, :, :]) 411 | return img 412 | 413 | 414 | class RandomYFlip(object): 415 | """Horizontally flip the given PIL.Image randomly with a probability of 0.5.""" 416 | 417 | def __call__(self, img): 418 | """ 419 | Args: 420 | img (PIL.Image): Image to be flipped. 421 | 422 | Returns: 423 | PIL.Image: Randomly flipped image. 424 | """ 425 | if random.random() < 0.5: 426 | return np.array(img[:, ::-1, :]) 427 | return img 428 | 429 | 430 | class RandomSizedCrop(object): 431 | """Crop the given PIL.Image to random size and aspect ratio. 432 | 433 | A crop of random size of (0.08 to 1.0) of the original size and a random 434 | aspect ratio of 3/4 to 4/3 of the original aspect ratio is made. This crop 435 | is finally resized to given size. 436 | This is popularly used to train the Inception networks. 437 | 438 | Args: 439 | size: size of the smaller edge 440 | interpolation: Default: PIL.Image.BILINEAR 441 | """ 442 | 443 | def __init__(self, size, interpolation=Image.BILINEAR): 444 | self.size = size 445 | self.interpolation = interpolation 446 | 447 | def __call__(self, img): 448 | for attempt in range(10): 449 | area = img.size[0] * img.size[1] 450 | target_area = random.uniform(0.08, 1.0) * area 451 | aspect_ratio = random.uniform(3. / 4, 4. 
/ 3) 452 | 453 | w = int(round(math.sqrt(target_area * aspect_ratio))) 454 | h = int(round(math.sqrt(target_area / aspect_ratio))) 455 | 456 | if random.random() < 0.5: 457 | w, h = h, w 458 | 459 | if w <= img.size[0] and h <= img.size[1]: 460 | x1 = random.randint(0, img.size[0] - w) 461 | y1 = random.randint(0, img.size[1] - h) 462 | 463 | img = img.crop((x1, y1, x1 + w, y1 + h)) 464 | assert (img.size == (w, h)) 465 | 466 | return img.resize((self.size, self.size), self.interpolation) 467 | 468 | # Fallback 469 | scale = Scale(self.size, interpolation=self.interpolation) 470 | crop = CenterCrop(self.size) 471 | return crop(scale(img)) 472 | -------------------------------------------------------------------------------- /data/extclsshpinfo.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import SimpleITK as sitk 3 | import os 4 | import os.path 5 | import numpy as np 6 | 7 | 8 | def load_itk_image(filename): 9 | with open(filename) as f: 10 | contents = f.readlines() 11 | line = [k for k in contents if k.startswith('TransformMatrix')][0] 12 | transformM = np.array(line.split(' = ')[1].split(' ')).astype('float') 13 | transformM = np.round(transformM) 14 | if np.any(transformM != np.array([1, 0, 0, 0, 1, 0, 0, 0, 1])): 15 | isflip = True 16 | else: 17 | isflip = False 18 | itkimage = sitk.ReadImage(filename) 19 | numpyImage = sitk.GetArrayFromImage(itkimage) 20 | numpyOrigin = np.array(list(reversed(itkimage.GetOrigin()))) 21 | numpySpacing = np.array(list(reversed(itkimage.GetSpacing()))) 22 | return numpyImage, numpyOrigin, numpySpacing, isflip 23 | 24 | 25 | def worldToVoxelCoord(worldCoord, origin, spacing): 26 | stretchedVoxelCoord = np.absolute(worldCoord - origin) 27 | voxelCoord = stretchedVoxelCoord / spacing 28 | return voxelCoord 29 | 30 | 31 | # read map file 32 | mapfname = 'LIDC-IDRI-mappingLUNA16' 33 | sidmap = {} 34 | fid = open(mapfname, 'r') 35 | line = fid.readline() 36 | line = 
fid.readline() 37 | while line: 38 | pidlist = line.split(' ') 39 | # print pidlist 40 | pid = pidlist[0] 41 | stdid = pidlist[1] 42 | srsid = pidlist[2] 43 | if srsid not in sidmap: 44 | sidmap[srsid] = [pid, stdid] 45 | else: 46 | assert sidmap[srsid][0] == pid 47 | assert sidmap[srsid][1] == stdid 48 | line = fid.readline() 49 | fid.close() 50 | # read luna16 annotation 51 | colname = ['seriesuid', 'coordX', 'coordY', 'coordZ', 'diameter_mm'] 52 | lunaantframe = pd.read_csv('annotations.csv', names=colname) 53 | srslist = lunaantframe.seriesuid.tolist()[1:] 54 | cdxlist = lunaantframe.coordX.tolist()[1:] 55 | cdylist = lunaantframe.coordY.tolist()[1:] 56 | cdzlist = lunaantframe.coordZ.tolist()[1:] 57 | dimlist = lunaantframe.diameter_mm.tolist()[1:] 58 | lunaantdict = {} 59 | for idx in range(len(srslist)): 60 | vlu = [float(cdxlist[idx]), float(cdylist[idx]), float(cdzlist[idx]), float(dimlist[idx])] 61 | if srslist[idx] in lunaantdict: 62 | lunaantdict[srslist[idx]].append(vlu) 63 | else: 64 | lunaantdict[srslist[idx]] = [vlu] 65 | # convert luna16 annotation to LIDC-IDRI annotation space 66 | from multiprocessing import Pool 67 | 68 | lunantdictlidc = {} 69 | for fold in range(10): 70 | mhdpath = '/media/data1/wentao/tianchi/luna16/subset' + str(fold) 71 | print('fold', fold) 72 | 73 | 74 | def getvoxcrd(fname): 75 | sliceim, origin, spacing, isflip = load_itk_image(os.path.join(mhdpath, fname)) 76 | lunantdictlidc[fname[:-4]] = [] 77 | voxcrdlist = [] 78 | for lunaant in lunaantdict[fname[:-4]]: 79 | voxcrd = worldToVoxelCoord(lunaant[:3][::-1], origin, spacing) 80 | voxcrd[-1] = sliceim.shape[0] - voxcrd[0] 81 | voxcrdlist.append(voxcrd) 82 | return voxcrdlist 83 | 84 | 85 | p = Pool(30) 86 | fnamelist = [] 87 | for fname in os.listdir(mhdpath): 88 | if fname.endswith('.mhd') and fname[:-4] in lunaantdict: 89 | fnamelist.append(fname) 90 | voxcrdlist = p.map(getvoxcrd, fnamelist) 91 | listidx = 0 92 | for fname in os.listdir(mhdpath): 93 | if 
fname.endswith('.mhd') and fname[:-4] in lunaantdict: 94 | lunantdictlidc[fname[:-4]] = [] 95 | for subidx, lunaant in enumerate(lunaantdict[fname[:-4]]): 96 | # voxcrd = worldToVoxelCoord(lunaant[:3][::-1], origin, spacing) 97 | # voxcrd[-1] = sliceim.shape[0] - voxcrd[0] 98 | lunantdictlidc[fname[:-4]].append([lunaant, voxcrdlist[listidx][subidx]]) 99 | listidx += 1 100 | p.close() 101 | np.save('lunaantdictlidc.npy', lunantdictlidc) 102 | # read LIDC dataset 103 | lunantdictlidc = np.load('lunaantdictlidc.npy').item() 104 | import xlrd 105 | 106 | lidccsvfname = '/media/data1/wentao/LIDC-IDRI/list3.2.xls' 107 | antdict = {} 108 | wb = xlrd.open_workbook(os.path.join(lidccsvfname)) 109 | for s in wb.sheets(): 110 | if s.name == 'list3.2': 111 | for row in range(1, s.nrows): 112 | valuelist = [int(s.cell(row, 2).value), s.cell(row, 3).value, s.cell(row, 4).value, \ 113 | int(s.cell(row, 5).value), int(s.cell(row, 6).value), int(s.cell(row, 7).value)] 114 | assert abs(s.cell(row, 1).value - int(s.cell(row, 1).value)) < 1e-8 115 | assert abs(s.cell(row, 2).value - int(s.cell(row, 2).value)) < 1e-8 116 | assert abs(s.cell(row, 5).value - int(s.cell(row, 5).value)) < 1e-8 117 | assert abs(s.cell(row, 6).value - int(s.cell(row, 6).value)) < 1e-8 118 | assert abs(s.cell(row, 7).value - int(s.cell(row, 7).value)) < 1e-8 119 | for col in range(9, 16): 120 | if s.cell(row, col).value != '': 121 | if isinstance(s.cell(row, col).value, float): 122 | valuelist.append(int(s.cell(row, col).value)) 123 | assert abs(s.cell(row, col).value - int(s.cell(row, col).value)) < 1e-8 124 | else: 125 | valuelist.append(s.cell(row, col).value) 126 | if s.cell(row, 0).value + '_' + str(int(s.cell(row, 1).value)) not in antdict: 127 | antdict[s.cell(row, 0).value + '_' + str(int(s.cell(row, 1).value))] = [valuelist] 128 | else: 129 | antdict[s.cell(row, 0).value + '_' + str(int(s.cell(row, 1).value))].append(valuelist) 130 | # update LIDC annotation with series number, rather than scan id 
131 | import dicom 132 | 133 | LIDCpath = '/media/data1/wentao/LIDC-IDRI/DOI/' 134 | antdictscan = {} 135 | for k, v in antdict.iteritems(): 136 | pid, scan = k.split('_') 137 | hasscan = False 138 | for sdu in os.listdir(os.path.join(LIDCpath, 'LIDC-IDRI-' + pid)): 139 | for srs in os.listdir(os.path.join(*[LIDCpath, 'LIDC-IDRI-' + pid, sdu])): 140 | if srs.endswith('.npy'): 141 | print('npy', pid, scan, srs) 142 | continue 143 | RefDs = dicom.read_file(os.path.join(*[LIDCpath, 'LIDC-IDRI-' + pid, sdu, srs, '000006.dcm'])) 144 | # print scan, str(RefDs[0x20, 0x11].value) 145 | if str(RefDs[0x20, 0x11].value) == scan or scan == '0': 146 | if hasscan: print('rep', pid, sdu, srs) 147 | hasscan = True 148 | antdictscan[pid + '_' + srs] = v 149 | break 150 | if not hasscan: print('not found', pid, scan, sdu, srs) 151 | # find the match from LIDC-IDRI annotation 152 | import math 153 | 154 | lunaantdictnodid = {} 155 | maxdist = 0 156 | for srcid, lunaantlidc in lunantdictlidc.iteritems(): 157 | lunaantdictnodid[srcid] = [] 158 | pid, stdid = sidmap[srcid] 159 | # print pid 160 | pid = pid[len('LIDC-IDRI-'):] 161 | for lunantdictlidcsub in lunaantlidc: 162 | lunaant = lunantdictlidcsub[0] 163 | voxcrd = lunantdictlidcsub[1] # z y x 164 | mindist, minidx = 1e8, -1 165 | if srcid in ['1.3.6.1.4.1.14519.5.2.1.6279.6001.174692377730646477496286081479', 166 | '1.3.6.1.4.1.14519.5.2.1.6279.6001.300246184547502297539521283806']: 167 | continue 168 | for idx, lidcant in enumerate(antdictscan[pid + '_' + srcid]): 169 | dist = math.pow(voxcrd[0] - lidcant[3], 2) # z 170 | dist = math.pow(voxcrd[1] - lidcant[4], 2) # y 171 | dist += math.pow(voxcrd[2] - lidcant[5], 2) # x 172 | if dist < mindist: 173 | mindist = dist 174 | minidx = idx 175 | if mindist > 71: # 15.1: 176 | print(srcid, pid, voxcrd, antdictscan[pid + '_' + srcid], mindist) 177 | maxdist = max(maxdist, mindist) 178 | lunaantdictnodid[srcid].append([lunaant, antdictscan[pid + '_' + srcid][minidx][6:]]) 179 | 
np.save('lunaantdictnodid.npy', lunaantdictnodid) 180 | print('maxdist', maxdist) 181 | # save it into a csv 182 | import csv 183 | 184 | savename = 'annotationnodid.csv' 185 | fid = open(savename, 'w') 186 | writer = csv.writer(fid) 187 | writer.writerow(['seriesuid', 'coordX', 'coordY', 'coordZ', 'diameter_mm']) 188 | for srcid, ant in lunaantdictnodid.iteritems(): 189 | for antsub in ant: 190 | writer.writerow([srcid] + [antsub[0][0], antsub[0][1], antsub[0][2], antsub[0][3]] + antsub[1]) 191 | fid.close() 192 | # find the malignancy, shape information from xml file 193 | import xml.dom.minidom 194 | 195 | lunadctclssgmdict = {} 196 | for srsid, extant in lunaantdictnodid.iteritems(): 197 | lunadctclssgmdict[srsid] = [] 198 | pid, stdid = sidmap[srsid] 199 | for extantvlu in extant: 200 | mallst, callst, sphlst, marlst, loblst, spilst, texlst = [], [], [], [], [], [], [] 201 | for fname in os.listdir(os.path.join(*['/media/data1/wentao/LIDC-IDRI/DOI/', pid, stdid, srsid])): 202 | if fname.endswith('.xml'): 203 | dom = xml.dom.minidom.parse( 204 | os.path.join(*['/media/data1/wentao/LIDC-IDRI/DOI/', pid, stdid, srsid, fname])) 205 | root = dom.documentElement 206 | rsessions = root.getElementsByTagName('readingSession') 207 | for rsess in rsessions: 208 | unblinds = rsess.getElementsByTagName('unblindedReadNodule') 209 | for unb in unblinds: 210 | nod = unb.getElementsByTagName('noduleID') 211 | if len(nod) != 1: continue 212 | if nod[0].firstChild.data in extantvlu[1]: 213 | cal = unb.getElementsByTagName('calcification') 214 | # print cal[0].firstChild.data, range(1,7,1), int(cal[0].firstChild.data) in range(1,7,1) 215 | if len(cal) == 1 and int(cal[0].firstChild.data) in range(1, 7, 1): 216 | callst.append(float(cal[0].firstChild.data)) 217 | sph = unb.getElementsByTagName('sphericity') 218 | if len(sph) == 1 and int(sph[0].firstChild.data) in range(1, 6, 1): 219 | sphlst.append(float(sph[0].firstChild.data)) 220 | mar = unb.getElementsByTagName('margin') 221 
| if len(mar) == 1 and int(mar[0].firstChild.data) in range(1, 6, 1): 222 | marlst.append(float(mar[0].firstChild.data)) 223 | lob = unb.getElementsByTagName('lobulation') 224 | if len(lob) == 1 and int(lob[0].firstChild.data) in range(1, 6, 1): 225 | loblst.append(float(lob[0].firstChild.data)) 226 | spi = unb.getElementsByTagName('spiculation') 227 | if len(spi) == 1 and int(spi[0].firstChild.data) in range(1, 6, 1): 228 | spilst.append(float(spi[0].firstChild.data)) 229 | tex = unb.getElementsByTagName('texture') 230 | if len(tex) == 1 and int(tex[0].firstChild.data) in range(1, 6, 1): 231 | texlst.append(float(tex[0].firstChild.data)) 232 | mal = unb.getElementsByTagName('malignancy') 233 | if len(mal) == 1 and int(mal[0].firstChild.data) in range(1, 6, 1): 234 | mallst.append(float(mal[0].firstChild.data)) 235 | vlulst = [srsid, extantvlu[0][0], extantvlu[0][1], extantvlu[0][2], extantvlu[0][3]] 236 | if len(mallst) == 0: 237 | vlulst.append(0) 238 | else: 239 | vlulst.append(sum(mallst) / float(len(mallst))) 240 | if len(callst) == 0: 241 | vlulst.append(0) 242 | else: 243 | vlulst.append(sum(callst) / float(len(callst))) 244 | if len(sphlst) == 0: 245 | vlulst.append(0) 246 | else: 247 | vlulst.append(sum(sphlst) / float(len(sphlst))) 248 | if len(marlst) == 0: 249 | vlulst.append(0) 250 | else: 251 | vlulst.append(sum(marlst) / float(len(marlst))) 252 | if len(loblst) == 0: 253 | vlulst.append(0) 254 | else: 255 | vlulst.append(sum(loblst) / float(len(loblst))) 256 | if len(spilst) == 0: 257 | vlulst.append(0) 258 | else: 259 | vlulst.append(sum(spilst) / float(len(spilst))) 260 | if len(texlst) == 0: 261 | vlulst.append(0) 262 | else: 263 | vlulst.append(sum(texlst) / float(len(texlst))) 264 | lunadctclssgmdict[srsid].append(vlulst) 265 | # lunadctclssgmdict[srsid].append([extantvlu[0][0], extantvlu[0][1], extantvlu[0][2], extantvlu[0][3]]+\ 266 | # [sum(mallst)/float(len(mallst)), sum(callst)/float(len(callst)), sum(sphlst)/float(len(sphlst)), \ 267 | # 
sum(marlst)/float(len(marlst)), sum(loblst)/float(len(loblst)), sum(spilst)/float(len(spilst)), \ 268 | # sum(texlst)/float(len(texlst))]) 269 | np.save('lunadctclssgmdict.npy', lunadctclssgmdict) 270 | savename = 'annotationdetclssgm.csv' 271 | fid = open(savename, 'w') 272 | writer = csv.writer(fid) 273 | writer.writerow(['seriesuid', 'coordX', 'coordY', 'coordZ', 'diameter_mm', 'malignant', 'calcification', \ 274 | 'sphericity', 'margin', 'lobulation', 'spiculation', 'texture']) 275 | for srsid, extant in lunadctclssgmdict.iteritems(): 276 | for subextant in extant: 277 | writer.writerow(subextant) 278 | fid.close() 279 | # discrete the generated csv 280 | import pandas as pd 281 | import csv 282 | 283 | srcname = 'annotationdetclssgm.csv' 284 | dstname = 'annotationdetclssgmfnl.csv' 285 | colname = ['seriesuid', 'coordX', 'coordY', 'coordZ', 'diameter_mm', 'malignant', 'calcification', 'sphericity', \ 286 | 'margin', 'lobulation', 'spiculation', 'texture'] 287 | srcframe = pd.read_csv(srcname, names=colname) 288 | srslist = srcframe.seriesuid.tolist()[1:] 289 | cdxlist = srcframe.coordX.tolist()[1:] 290 | cdylist = srcframe.coordY.tolist()[1:] 291 | cdzlist = srcframe.coordZ.tolist()[1:] 292 | dimlist = srcframe.diameter_mm.tolist()[1:] 293 | mlglist = srcframe.malignant.tolist()[1:] 294 | callist = srcframe.calcification.tolist()[1:] 295 | sphlist = srcframe.sphericity.tolist()[1:] 296 | mrglist = srcframe.margin.tolist()[1:] 297 | loblist = srcframe.lobulation.tolist()[1:] 298 | spclist = srcframe.spiculation.tolist()[1:] 299 | txtlist = srcframe.texture.tolist()[1:] 300 | fid = open(dstname, 'w') 301 | writer = csv.writer(fid) 302 | writer.writerow(colname) 303 | for idx in range(len(srslist)): 304 | lst = [srslist[idx], cdxlist[idx], cdylist[idx], cdzlist[idx], dimlist[idx]] 305 | if abs(float(mlglist[idx]) - 0) < 1e-2: # 0 1 2 306 | lst.append(0) 307 | elif abs(float(mlglist[idx]) - 3) < 1e-2: 308 | lst.append(0) 309 | elif float(mlglist[idx]) > 3: 310 | 
lst.append(2) 311 | else: 312 | lst.append(1) 313 | lst.append(int(round(float(callist[idx])))) # 0 - 6 314 | if abs(float(sphlist[idx]) - 0) < 1e-2: # 0 1 2 3 315 | lst.append(0) 316 | elif float(sphlist[idx]) < 2: 317 | lst.append(1) 318 | elif float(sphlist[idx]) < 4: 319 | lst.append(2) 320 | else: 321 | lst.append(3) 322 | if abs(float(mrglist[idx]) - 0) < 1e-2: # 0 1 2 323 | lst.append(0) 324 | elif float(mrglist[idx]) < 3: 325 | lst.append(1) 326 | else: 327 | lst.append(2) 328 | if abs(float(loblist[idx]) - 0) < 1e-2: # 0 1 2 329 | lst.append(0) 330 | elif float(loblist[idx]) < 3: 331 | lst.append(1) 332 | else: 333 | lst.append(2) 334 | if abs(float(spclist[idx]) - 0) < 1e-2: # 0 1 2 335 | lst.append(0) 336 | elif float(spclist[idx]) < 3: 337 | lst.append(1) 338 | else: 339 | lst.append(2) 340 | if abs(float(txtlist[idx]) - 0) < 1e-2: # 0 1 2 3 341 | lst.append(0) 342 | elif float(txtlist[idx]) < 2: 343 | lst.append(1) 344 | elif float(txtlist[idx]) < 4: 345 | lst.append(2) 346 | else: 347 | lst.append(3) 348 | writer.writerow(lst) 349 | fid.close() 350 | # fuse annotations for different nodules, generate patient level annotation 351 | import pandas as pd 352 | import csv 353 | 354 | antpdframe = pd.read_csv('annotationdetclssgmfnl.csv', names=['seriesuid', 'coordX', 'coordY', 'coordZ', \ 355 | 'diameter_mm', 'malignant', 'calcification', 'sphericity', 356 | 'margin', 'lobulation', 'spiculation', 'texture']) 357 | srslst = antpdframe.seriesuid.tolist()[1:] 358 | cdxlst = antpdframe.coordX.tolist()[1:] 359 | cdylst = antpdframe.coordY.tolist()[1:] 360 | cdzlst = antpdframe.coordZ.tolist()[1:] 361 | mlglst = antpdframe.malignant.tolist()[1:] 362 | dimlst = antpdframe.diameter_mm.tolist()[1:] 363 | clclst = antpdframe.calcification.tolist()[1:] 364 | sphlst = antpdframe.sphericity.tolist()[1:] 365 | mrglst = antpdframe.margin.tolist()[1:] 366 | loblst = antpdframe.lobulation.tolist()[1:] 367 | spclst = antpdframe.spiculation.tolist()[1:] 368 | txtlst = 
antpdframe.texture.tolist()[1:] 369 | dctdat = {} 370 | for idx, srs in enumerate(srslst): 371 | if mlglst[idx] == '0': 372 | continue 373 | vlu = [mlglst[idx], clclst[idx], sphlst[idx], mrglst[idx], loblst[idx], spclst[idx], txtlst[idx]] 374 | if srs not in dctdat: 375 | dctdat[srs] = [vlu] 376 | else: 377 | dctdat[srs].append(vlu) 378 | fid = open('annotationdetclssgmv2.csv', 'w') 379 | writer = csv.writer(fid) 380 | writer.writerow(['seriesuid', 'malignant']) 381 | for srs, vlulst in dctdat.iteritems(): 382 | mlg = -1 383 | for vlu in vlulst: 384 | mlg = max(mlg, vlu[0]) 385 | writer.writerow([srs, mlg]) 386 | fid.close() 387 | --------------------------------------------------------------------------------