├── medseg ├── __init__.py ├── models │ ├── __init__.py │ ├── model.py │ ├── unet.py │ ├── unet_plain.py │ ├── res_unet.py │ ├── hrnet.py │ └── deeplabv3p.py ├── utils │ ├── __init__.py │ ├── config.py │ └── util.py ├── prep_2d.py ├── loss.py ├── eval.py ├── infer.py ├── vis.py ├── aug.py ├── prep_3d.py ├── prep_png.py └── train.py ├── tool ├── infer │ ├── util.py │ ├── connected.py │ ├── merge.py │ ├── vote.py │ ├── aorta.py │ ├── 3d_diameter.py │ ├── slice2nii.py │ ├── 2d_diameter.py │ └── util.bk ├── train │ ├── util.py │ ├── vis.py │ ├── mhd2nii.py │ ├── gen_list.py │ ├── folder_split.py │ ├── resize.py │ ├── to_slice.py │ ├── slice_mp.py │ └── dataset_scan.py ├── SimHei.ttf ├── to_pinyin.py ├── README.md ├── flood_fill.py └── zip_dataset.py ├── .gitattributes ├── requirements.txt ├── config ├── lits-zprep.yaml ├── hrnet_optic.yaml ├── unet_optic.yaml ├── deeplab_optic.yaml ├── resunet_optic.yaml ├── eval.yaml ├── resunet_aorta.yaml ├── test.yaml └── resunet_lits.yaml ├── .gitignore ├── README.md └── README_en.md /medseg/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /medseg/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /medseg/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tool/infer/util.py: -------------------------------------------------------------------------------- 1 | ../util.py -------------------------------------------------------------------------------- /tool/train/util.py: -------------------------------------------------------------------------------- 1 | ../util.py -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ttf filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | nibabel 2 | scikit-image 3 | tqdm 4 | numpy 5 | visualdl 6 | medpy 7 | pypinyin 8 | trimesh 9 | SimpleITK 10 | paddleseg 11 | -------------------------------------------------------------------------------- /config/lits-zprep.yaml: -------------------------------------------------------------------------------- 1 | PREP: 2 | PLANE: 'xz' 3 | FRONT: 1 4 | SIZE: [-1, 512,512] 5 | THRESH: 512 6 | INTERP: True 7 | INTERP_PIXDIM: [-1, -1, 1] 8 | -------------------------------------------------------------------------------- /tool/SimHei.ttf: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:fd952619133560ec32c44d9b5f95a795d273db6aafa507edbcb2686c26f2c203 3 | size 10061591 4 | -------------------------------------------------------------------------------- /config/hrnet_optic.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | NAME: "optic" 3 | INPUTS_PATH: "/home/aistudio/optic/image" 4 | LABELS_PATH: "/home/aistudio/optic/label" 5 | PREP_PATH: "/home/aistudio/optic/prep" 6 | TRAIN: 7 | ARCHITECTURE: "hrnet" 8 | BATCH_SIZE: 8 9 | SNAPSHOT_BATCH: 30 10 | DISP_BATCH: 1 11 | -------------------------------------------------------------------------------- /config/unet_optic.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | NAME: "optic" 3 | INPUTS_PATH: "/home/aistudio/optic/image" 4 | LABELS_PATH: "/home/aistudio/optic/label" 5 | PREP_PATH: "/home/aistudio/optic/prep" 6 | TRAIN: 7 | ARCHITECTURE: "res_unet" 8 | BATCH_SIZE: 8 9 | SNAPSHOT_BATCH: 30 10 | DISP_BATCH: 1 11 | -------------------------------------------------------------------------------- /config/deeplab_optic.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | NAME: "optic" 3 | INPUTS_PATH: "/home/aistudio/optic/image" 4 | LABELS_PATH: "/home/aistudio/optic/label" 5 | PREP_PATH: "/home/aistudio/optic/prep" 6 | TRAIN: 7 | ARCHITECTURE: "deeplabv3p" 8 | BATCH_SIZE: 8 9 | SNAPSHOT_BATCH: 30 10 | DISP_BATCH: 1 11 | -------------------------------------------------------------------------------- /config/resunet_optic.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | NAME: "optic" 3 | INPUTS_PATH: "/home/aistudio/optic/image" 4 | LABELS_PATH: "/home/aistudio/optic/label" 5 | PREP_PATH: "/home/aistudio/optic/prep" 6 | TRAIN: 7 | ARCHITECTURE: "res_unet" 8 | BATCH_SIZE: 8 9 | SNAPSHOT_BATCH: 30 10 | DISP_BATCH: 1 11 | -------------------------------------------------------------------------------- /config/eval.yaml: -------------------------------------------------------------------------------- 1 | EVAL: 2 | PATH: 3 | SEG: "/data/aorta/300/label/nii/aunet-best" 4 | GT: "/data/aorta/300/test/label" 5 | NAME: "aunet-best" 6 | METRICS: [ 7 | "IOU", 8 | "Dice", 9 | "TP", 10 | "TN", 11 | "Precision", 12 | "Recall", 13 | "Sensitivity", 14 | "Specificity", 15 | "Accuracy", 16 | "Dice", 17 | "IOU", 18 | "Assd", 19 | "Ravd", 20 | ] 21 | -------------------------------------------------------------------------------- /config/resunet_aorta.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | NAME: "lits" 3 | INPUTS_PATH: "/home/aistudio/data/volume" 4 | LABELS_PATH: "/home/aistudio/data/label" 5 | PREP_PATH: "/home/aistudio/data/preprocess" 6 | PREP: 7 | FRONT: 1 8 | WINDOW: False 9 | CROP: False 10 | THRESH: 128 11 | SIZE: (512, 512, -1) 12 | BATCH_SIZE: 16 13 | TRAIN: 14 | ARCHITECTURE: "res_unet" 15 | BATCH_SIZE: 30 16 | SNAPSHOT_BATCH: 300 17 | DISP_BATCH: 10 18 | 19 | AUG: 20 | ROTATE: 21 | RATIO: (0, 0.3, 0) 22 | RANGE: (0,(-10,10),0) 23 | WINDOWLIZE: True 24 | WWWC: (400, 0) 25 | -------------------------------------------------------------------------------- /tool/infer/connected.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | from tqdm import tqdm 5 | import nibabel as nib 6 | import numpy as np 7 | 8 | from util import filter_largest_volume 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("-i", "--in_dir", type=str, required=True) 12 | parser.add_argument("-o", "--out_dir", type=str, required=True) 13 | args = parser.parse_args() 14 | 15 | names = os.listdir(args.in_dir) 16 | 17 | for name in tqdm(names): 18 | segf = nib.load(os.path.join(args.in_dir, name)) 19 | header = segf.header 20 | data = segf.get_fdata() 21 | data = filter_largest_volume(data, mode="hard") 22 | newf = nib.Nifti1Image(data, segf.affine, header) 23 | nib.save(newf, os.path.join(args.out_dir, name)) 24 | -------------------------------------------------------------------------------- /tool/train/vis.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | 4 | import cv2 5 | import matplotlib.pyplot as plt 6 | 7 | 8 | for k, v in os.environ.items(): 9 | if k.startswith("QT_") and "cv2" in v: 10 | del os.environ[k] 11 | 12 | 13 | base_dir = "/home/lin/Desktop/data/lits/img" 14 | img_dir = osp.join(base_dir, "JPEGImages") 15 | lab_dir = osp.join(base_dir, "Annotations") 16 | 17 | for f in os.listdir(lab_dir)[:30]: 18 | img = osp.join(img_dir, f) 19 | lab = osp.join(lab_dir, f) 20 | 21 | img = cv2.imread(img) 22 | lab = cv2.imread(lab, cv2.IMREAD_UNCHANGED) 23 | print(lab.sum()) 24 | lab = lab * 255 25 | # lab = lab.reshape([512, 512, 1]) 26 | print(img.shape, lab.shape) 27 | 28 | fig = plt.figure(figsize=(10, 10)) 29 | fig.add_subplot(1, 2, 1) 30 | plt.imshow(img) 31 | fig.add_subplot(1, 2, 2) 32 | plt.imshow(lab) 33 | plt.show() 34 | -------------------------------------------------------------------------------- /tool/to_pinyin.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | import argparse 5 | from tqdm import tqdm 6 | 7 | import util 8 | 9 | 10 | parser = argparse.ArgumentParser("zip_dataset") 11 | parser.add_argument("--in_dir", type=str) 12 | parser.add_argument("--out_dir", type=str) 13 | args = parser.parse_args() 14 | 15 | 16 | def to_pinyin(name, nonum=False): 17 | new_name = "" 18 | for ch in name: 19 | if u"\u4e00" <= ch <= u"\u9fff": 20 | new_name += pinyin(ch, style=Style.NORMAL)[0][0] 21 | else: 22 | # if nonum and ("0" <= ch <= "9" or ch == "_"): 23 | # continue 24 | new_name += ch 25 | return new_name 26 | 27 | 28 | def main(): 29 | files = os.listdir(args.in_dir) 30 | for f in tqdm(files): 31 | shutil.copy( 32 | os.path.join(args.in_dir, f), 33 | os.path.join(args.out_dir, to_pinyin(f)), 34 | ) 35 | # input("here") 36 | 37 | 38 | if __name__ == "__main__": 39 | main() 40 | -------------------------------------------------------------------------------- /tool/train/mhd2nii.py: -------------------------------------------------------------------------------- 1 | # 将mhd的扫描转换为nii格式 2 | import os 3 | import argparse 4 | 5 | import SimpleITK as sitk 6 | import nibabel as nib 7 | import numpy as np 8 | from tqdm import tqdm 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("--mhd_dir", type=str, required=True) 12 | parser.add_argument("--nii_dir", type=str, required=True) 13 | parser.add_argument("--rot", type=int, default=0) 14 | args = parser.parse_args() 15 | 16 | if not os.path.exists(args.nii_dir): 17 | os.makedirs(args.nii_dir) 18 | 19 | # TODO: 添加多进程 20 | for fname in tqdm(os.listdir(args.mhd_dir)): 21 | if not fname.endswith(".mhd"): 22 | continue 23 | scan = sitk.GetArrayFromImage(sitk.ReadImage(os.path.join(args.mhd_dir, fname))) 24 | scan = scan.swapaxes(0, 1).swapaxes(1, 2) 25 | for _ in range(args.rot): 26 | scan = np.rot90(scan) 27 | 28 | # TODO: 研究mhd/raw格式是否带有更多头文件信息 29 | new_scan = nib.Nifti1Image(scan, np.eye(4)) 30 | nib.save(new_scan, os.path.join(args.nii_dir, fname.replace("mhd", "nii.gz"))) 31 | -------------------------------------------------------------------------------- /tool/infer/merge.py: -------------------------------------------------------------------------------- 1 | # 将肝脏和肿瘤的分割标签合并 2 | import nibabel as nib 3 | import os 4 | from tqdm import tqdm 5 | import numpy as np 6 | 7 | 8 | liver_path = "/home/aistudio/data/liver" 9 | tumor_path = "/home/aistudio/data/tumor" 10 | merge_path = "/home/aistudio/data/merge" 11 | 12 | 13 | assert len(os.listdir(liver_path)) == len(os.listdir(tumor_path)), "肝脏和肿瘤的分割标签数量不想等" 14 | 15 | for liver_fname in tqdm(os.listdir(liver_path)): 16 | liverf = nib.load(os.path.join(liver_path, liver_fname)) 17 | tumorf = nib.load(os.path.join(tumor_path, liver_fname)) 18 | 19 | liver = liverf.get_fdata() 20 | tumor = tumorf.get_fdata() 21 | 22 | assert len(np.where(tumor == 1)[0]) == 0, "肿瘤中包含标签为1的前景,是不是肝脏和肿瘤数据弄反了" 23 | assert len(np.where(liver == 2)[0]) == 0, "肝脏中包含标签为2的前景,是不是肝脏和肿瘤数据弄反了" 24 | 25 | print("肝体{}, 肿瘤{}".format(liver.sum(), tumor.sum())) 26 | 27 | tumor = tumor / 2 28 | liver += tumor 29 | 30 | merge_file = nib.Nifti1Image(liver, liverf.affine) 31 | nib.save(merge_file, os.path.join(merge_path, liver_fname)) 32 | -------------------------------------------------------------------------------- /config/test.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | NAME: "lits" 3 | INPUTS_PATH: "/home/aistudio/data/scan_temp" 4 | LABELS_PATH: "/home/aistudio/data/label_temp" 5 | PREP_PATH: "/home/aistudio/data/preprocess" 6 | PREP: 7 | FRONT: 1 8 | WINDOW: False 9 | CROP: False 10 | THRESH: 1024 11 | SIZE: (512, 512, -1) 12 | BATCH_SIZE: 16 13 | TRAIN: 14 | ARCHITECTURE: "unet_plain" 15 | BATCH_SIZE: 1 16 | SNAPSHOT_BATCH: 600 17 | DISP_BATCH: 1 18 | REG_TYPE: L1 19 | REG_COEFF: 1e-6 20 | AUG: 21 | ROTATE: 22 | RATIO: (0, 0.2, 0) 23 | RANGE: (0,(-15, -16),0) 24 | FLIP: 25 | RATIO: (0, 0, 0) 26 | WINDOWLIZE: True 27 | WWWC: (400, 0) 28 | ZOOM: 29 | RATIO: (0, 0.3, 0.3) 30 | RANGE: (0, (0.8, 1.0), (0.8, 1.0)) 31 | INFER: 32 | PATH: 33 | INPUT: "/home/aistudio/data/inference" 34 | OUTPUT: "/home/aistudio/data/infer_lab" 35 | PARAM: "/home/aistudio/param/liver_inf" 36 | BATCH_SIZE: 196 37 | EVAL: 38 | PATH: 39 | SEG: "/home/aistudio/data/infer_lab" 40 | GT: "/home/aistudio/data/label" 41 | METRICS: [ 42 | "IOU", 43 | "Dice", 44 | "TP", 45 | "TN", 46 | "Precision", 47 | "Recall", 48 | "Sensitivity", 49 | "Specificity", 50 | "Accuracy", 51 | ] 52 | -------------------------------------------------------------------------------- /medseg/models/model.py: -------------------------------------------------------------------------------- 1 | from utils.config import cfg 2 | from models.res_unet import res_unet 3 | from models.unet_plain import unet_plain 4 | from models.hrnet import hrnet 5 | from models.deeplabv3p import deeplabv3p 6 | 7 | # models = { 8 | # "unet_simple": unet_simple(volume, 2, [512, 512]), 9 | # "res_unet": unet_base(volume, 2, [512, 512]), 10 | # "deeplabv3": deeplabv3p(volume, 2), 11 | # "hrnet": hrnet(volume, 2), 12 | # } 13 | 14 | 15 | def create_model(input, num_class=2): 16 | """构建训练模型. 17 | 18 | Parameters 19 | ---------- 20 | input : paddle.data 21 | 输入的 placeholder. 22 | num_class : int 23 | 输出分类有几类. 24 | 25 | Returns 26 | ------- 27 | type 28 | 构建好的模型. 29 | 30 | """ 31 | if cfg.TRAIN.ARCHITECTURE == "unet_plain": 32 | return unet_plain(input, num_class, cfg.TRAIN.INPUT_SIZE) 33 | if cfg.TRAIN.ARCHITECTURE == "res_unet": 34 | return res_unet(input, num_class, cfg.TRAIN.INPUT_SIZE) 35 | if cfg.TRAIN.ARCHITECTURE == "hrnet": 36 | return hrnet(input, num_class) 37 | if cfg.TRAIN.ARCHITECTURE == "deeplabv3p": 38 | return deeplabv3p(input, num_class) 39 | raise Exception("错误的网络类型: {}".format(cfg.TRAIN.ARCHITECTURE)) 40 | -------------------------------------------------------------------------------- /config/resunet_lits.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | NAME: "lits" 3 | INPUTS_PATH: "/home/aistudio/data/scan" 4 | LABELS_PATH: "/home/aistudio/data/label" 5 | PREP_PATH: "/home/aistudio/data/preprocess" 6 | PREP: 7 | FRONT: 1 8 | WINDOW: False 9 | CROP: False 10 | THRESH: 0 11 | SIZE: (512, 512, -1) 12 | BATCH_SIZE: 20 13 | THICKNESS: 5 14 | 15 | TRAIN: 16 | ARCHITECTURE: "res_unet" 17 | BATCH_SIZE: 32 18 | EPOCHS: 30 19 | SNAPSHOT_BATCH: 600 20 | DISP_BATCH: 20 21 | OPTIMIZER: "adam" 22 | REG_TYPE: L1 23 | REG_COEFF: 3e-6 24 | BOUNDARIES: [80000, 20000] 25 | LR: [0.002, 0.001, 0.0005] 26 | 27 | AUG: 28 | ROTATE: 29 | RATIO: (0, 0, 0) 30 | RANGE: (0,(-15, -16),0) 31 | FLIP: 32 | RATIO: (0, 0, 0) 33 | WINDOWLIZE: True 34 | WWWC: (400, 0) 35 | ZOOM: 36 | RATIO: (0, 0.3, 0.3) 37 | RANGE: (0, (0.8, 1.1), (0.8, 1.1)) 38 | 39 | INFER: 40 | PATH: 41 | INPUT: "/home/aistudio/data/inference" 42 | OUTPUT: "/home/aistudio/data/infer_lab" 43 | PARAM: "/home/aistudio/liverSeg/model/lits/inf" 44 | BATCH_SIZE: 64 45 | THRESH: 0.5 46 | FILTER_LARGES: True 47 | WWWC: (400, 0) 48 | 49 | EVAL: 50 | PATH: 51 | SEG: "/home/aistudio/data/infer_lab" 52 | GT: "/home/aistudio/data/label" 53 | METRICS: [ 54 | "IOU", 55 | "Dice", 56 | "TP", 57 | "TN", 58 | "Precision", 59 | "Recall", 60 | "Sensitivity", 61 | "Specificity", 62 | "Accuracy", 63 | ] 64 | -------------------------------------------------------------------------------- /tool/README.md: -------------------------------------------------------------------------------- 1 | 对数据进行预处理和后处理的一些有用的脚本,功能介绍在[../README.md]中有写 2 | 3 | # 数据格式转换 4 | 写代码角度讲nii格式一般用起来比较方便。dcm包含的信息多,但是一个文件夹不一定就是一个序列;一层一个文件的保存方式下,进行大量文件读写I/O效率也不高。这个项目基本都用nii格式,记录一些格式转换方法。 5 | 6 | [dcm2niix](https://github.com/rordenlab/dcm2niix)可以将dcm转换nii,下面是一个命令行转换的例子。 7 | ```shell 8 | dcm2niix -f 输出文件名,支持填入多种扫描里的信息 -d 9 -c 在输出nii文件中写注释 dcm文件夹 9 | ``` 10 | train目录下的 [mhd2nii.py](./train/mhd2nii.py) 可以将一个目录下的mhd格式扫描转成nii。 11 | 12 | # 数据标注 13 | 标注数据的时候ITK-snap用起来很方便,用好命令行参数可以自动进行文件打开,节省时间。 14 | ```shell 15 | count=0 ; 16 | tot=`ls -l | wc -l` 17 | for f in `ls`; 18 | do count=`expr $count + 1`; 19 | echo $count / $tot; 20 | echo $f; 21 | echo -e "\n"; 22 | itksnap -s /path/to/label/${f} -g /path/to/scan/${f} --geometry 1920x1080+0+0; 23 | done 24 | ``` 25 | 26 | beep函数可以用扬声器测试程序发出一声beep,结合定时可以更好掌握标注时间 27 | ```shell 28 | beep1() 29 | { 30 | ( \speaker-test --frequency $1 --test sine )& 31 | pid=$! 32 | \sleep ${2}s 33 | \kill -9 $pid 34 | } 35 | 36 | beep() 37 | { 38 | beep1 350 0.2 39 | beep1 350 0.2 40 | beep1 350 0.2 41 | beep1 350 0.4 42 | sleep 1 43 | } 44 | 45 | count=0 ; 46 | for f in `ls`; 47 | do 48 | # 四个 beep 49 | (sleep 2m; beep;) & 50 | pid2=$! 51 | (sleep 4m; beep; beep;) & 52 | pid4=$! 53 | (sleep 6m; beep; beep; beep;) & 54 | pid6=$! 55 | (sleep 8m; beep; beep; beep; beep; ) & 56 | pid8=$! 57 | 58 | # 计数 59 | count=`expr $count + 1`; 60 | echo $count / `ls -l | wc -l`; 61 | echo ${f}; 62 | echo -e "\n"; 63 | 64 | # 打开扫描和标签 65 | itksnap -s ./${f} -g ../nii/${f} --geometry 1920x1080+0+0; 66 | 67 | # 文件归档 68 | # cp ./${f} ../manual-label/ 69 | mv ./${f} ../manual-fin../ished/ 70 | 71 | # 关闭没响的beep 72 | kill -9 $pid2 73 | kill -9 $pid4 74 | kill -9 $pid6 75 | kill -9 $pid8 76 | done 77 | ``` 78 | -------------------------------------------------------------------------------- /tool/infer/vote.py: -------------------------------------------------------------------------------- 1 | import nibabel as nib 2 | import numpy as np 3 | from multiprocessing import Pool 4 | from multiprocessing import cpu_count 5 | import os 6 | 7 | 8 | def listdir(path): 9 | dirs = os.listdir(path) 10 | if ".DS_Store" in dirs: 11 | dirs.remove(".DS_Store") 12 | if "checkpoint" in dirs: 13 | dirs.remove("checkpoint") 14 | 15 | dirs.sort() # 通过一样的sort保持vol和seg的对应 16 | return dirs 17 | 18 | 19 | voters_base = "/home/aistudio/data/voting" 20 | voter_paths = ["voter1", "voter2", "voter3"] 21 | voter_paths = [os.path.join(voters_base, path) for path in voter_paths] 22 | 23 | 24 | def voting(fname): 25 | voter0 = nib.load(os.path.join(voter_paths[0], fname)) 26 | merged = voter0.get_fdata() 27 | print("voting {}".format(fname)) 28 | # print(merged.sum()) 29 | for ind in range(1, len(voter_paths)): 30 | voterf = nib.load(os.path.join(voter_paths[ind], fname)) 31 | merged += voterf.get_fdata() 32 | merged[merged > len(voter_paths) / 2] = 1 33 | merged[merged != 1] = 0 34 | 35 | 36 | def main(): 37 | voter_names = [listdir(path) for path in voter_paths] 38 | print(voter_names) 39 | for data_ind in range(len(voter_names[0])): 40 | for voter_ind in range(1, len(voter_names)): 41 | # print(data_ind, voter_ind) 42 | assert voter_names[voter_ind][data_ind] == voter_names[0][data_ind], "第 {} 组数据,{} 和 {} 名称不相同".format( 43 | data_ind, voter_names[voter_ind][data_ind], voter_names[0][data_ind] 44 | ) 45 | # for fname in voter_names[0]: 46 | # voting(fname) 47 | with Pool(cpu_count()) as p: 48 | p.map(voting, voter_names[0]) 49 | 50 | 51 | if __name__ == "__main__": 52 | main() 53 | -------------------------------------------------------------------------------- /tool/train/gen_list.py: -------------------------------------------------------------------------------- 1 | # 按照文件名排序,生成文件列表 2 | import os 3 | import os.path as osp 4 | import argparse 5 | import logging 6 | 7 | from tqdm import tqdm 8 | import numpy as np 9 | 10 | import util 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--base_dir", type=str, help="数据集路径", default=None) 14 | parser.add_argument( 15 | "--img_fdr", 16 | type=str, 17 | help="扫描文件夹", 18 | default="JPEGImages", 19 | ) 20 | parser.add_argument( 21 | "--lab_fdr", 22 | type=str, 23 | help="标签文件路径", 24 | default="Annotations", 25 | ) 26 | parser.add_argument("-d", "--delimiter", type=str, help="分隔符", default=" ") 27 | parser.add_argument( 28 | "-s", 29 | "--split", 30 | nargs=3, 31 | help="训练/验证/测试划分比例,比如 7 2 1", 32 | default=["7", "2", "1"], 33 | ) 34 | args = parser.parse_args() 35 | 36 | imgs = util.listdir(osp.join(args.base_dir, args.img_fdr)) 37 | labs = util.listdir(osp.join(args.base_dir, args.lab_fdr)) 38 | assert len(imgs) == len( 39 | labs 40 | ), f"Scan slice number ({len(imgs)}) isn't equal to label slice number({len(labs)})" 41 | 42 | names = [[i, l] for i, l in zip(imgs, labs)] 43 | file_names = ["train_list.txt", "eval_list.txt", "test_list.txt"] 44 | split = util.toint(args.split) 45 | tot = np.sum(split) 46 | split = [int(s / tot * len(names)) for s in split] 47 | print(f"Train/Eval/Test split is {split}") 48 | split[1] += split[0] 49 | split[2] = tot 50 | part = 0 51 | f = open(osp.join(args.base_dir, file_names[part]), "w") 52 | for idx, (img, lab) in enumerate(names): 53 | if idx == split[part] and idx != len(names) - 1: 54 | f.close() 55 | part += 1 56 | f = open(osp.join(args.base_dir, file_names[part]), "w") 57 | 58 | print( 59 | "{:s}{:s}{:s}".format( 60 | osp.join(args.img_fdr, img), 61 | args.delimiter, 62 | osp.join(args.lab_fdr, lab), 63 | ), 64 | file=f, 65 | ) 66 | -------------------------------------------------------------------------------- /tool/infer/aorta.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import math 5 | from multiprocessing import Pool 6 | import time 7 | import random 8 | 9 | import numpy as np 10 | 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--in_dir", type=str, default="./train-diameter") 14 | args = parser.parse_args() 15 | 16 | 17 | def score(data, split=0.2): 18 | if len(data) < 10: 19 | return -1, -1, -1, -1, -1 20 | 21 | mins = [] 22 | # print("\n\n\n\n\n", data) 23 | for idx in range(len(data)): 24 | if len(data[idx]) == 1: 25 | continue 26 | mins.append(np.min(data[idx][1:])) 27 | # print(mins) 28 | abdomin = np.max(mins[: int(len(mins) * split)]) 29 | chest = np.max(mins[int(len(mins) * split) :]) 30 | if abdomin > 30: 31 | abdomin += 20 32 | res = max(abdomin, chest) 33 | if 40 < res < 50: 34 | cat1 = 1 35 | elif res > 50: 36 | cat1 = 2 37 | else: 38 | cat1 = 0 39 | if res > 39.5: 40 | cat2 = 1 41 | else: 42 | cat2 = 0 43 | if res > 50: 44 | cat3 = 1 45 | else: 46 | cat3 = 0 47 | # print(abdomin, chest, cat1, cat2, cat3) 48 | return abdomin, chest, cat1, cat2, cat3 49 | 50 | 51 | if __name__ == "__main__": 52 | names = os.listdir(args.in_dir) 53 | for name in names: 54 | path = os.path.join(args.in_dir, name) 55 | with open(path, "r") as f: 56 | data = f.readlines() 57 | data = [d.split(",") for d in data[1:]] 58 | for i in range(len(data)): 59 | for j in range(len(data[i])): 60 | try: 61 | data[i][j] = float(data[i][j]) 62 | except: 63 | del data[i][j] 64 | print(name.split("_")[0], end="\t") 65 | for d in score(data): 66 | print(d, end="\t") 67 | print() 68 | -------------------------------------------------------------------------------- /medseg/prep_2d.py: -------------------------------------------------------------------------------- 1 | # 将图片组batch,存成和3D预处理一样的npz格式 2 | import os 3 | 4 | import cv2 5 | import argparse 6 | import numpy as np 7 | 8 | import utils.util as util 9 | from utils.config import cfg 10 | import aug 11 | 12 | 13 | def parse_args(): 14 | parser = argparse.ArgumentParser(description="数据预处理") 15 | parser.add_argument("-c", "--cfg_file", type=str, help="配置文件路径") 16 | parser.add_argument("opts", nargs=argparse.REMAINDER) 17 | args = parser.parse_args() 18 | 19 | if args.cfg_file is not None: 20 | cfg.update_from_file(args.cfg_file) 21 | if args.opts: 22 | cfg.update_from_list(args.opts) 23 | 24 | 25 | def main(): 26 | images = util.listdir(cfg.DATA.INPUTS_PATH) 27 | labels = util.listdir(cfg.DATA.LABELS_PATH) 28 | print(images) 29 | 30 | if not os.path.exists(cfg.DATA.PREP_PATH): 31 | os.makedirs(cfg.DATA.PREP_PATH) 32 | npz_count = 0 33 | img_npz = [] 34 | lab_npz = [] 35 | for ind in range(len(images)): 36 | img = cv2.imread(os.path.join(cfg.DATA.INPUTS_PATH, images[ind])) 37 | lab = cv2.imread(os.path.join(cfg.DATA.LABELS_PATH, labels[ind])) 38 | img = img.swapaxes(0, 2) 39 | lab = lab.swapaxes(0, 2) 40 | 41 | lab = lab[0] / 255 42 | lab = lab[np.newaxis, :, :] 43 | 44 | img, lab = aug.crop(img, lab, size=[3, 512, 512]) 45 | 46 | img_npz.append(img) 47 | lab_npz.append(lab) 48 | 49 | print(img.shape, lab.shape) 50 | 51 | if len(img_npz) == cfg.PREP.BATCH_SIZE or ind == len(images) - 1: 52 | imgs = np.array(img_npz) 53 | labs = np.array(lab_npz) 54 | file_name = "{}-{}".format(cfg.DATA.NAME, npz_count) 55 | file_path = os.path.join(cfg.DATA.PREP_PATH, file_name) 56 | np.savez(file_path, imgs=imgs, labs=labs) 57 | img_npz = [] 58 | lab_npz = [] 59 | npz_count += 1 60 | 61 | 62 | if __name__ == "__main__": 63 | parse_args() 64 | main() 65 | -------------------------------------------------------------------------------- /tool/flood_fill.py: -------------------------------------------------------------------------------- 1 | """ 2 | 对分割标签进行漫水填充 3 | """ 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | from skimage.morphology import flood_fill 7 | import numpy as np 8 | import nibabel as nib 9 | import queue 10 | import os 11 | 12 | # TODO: -60以下前景的要去掉 13 | 14 | lab_path = "/home/lin/Desktop/data/ann/flood/" 15 | fill_path = "/home/lin/Desktop/data/ann/flooded/" 16 | processed = 0 17 | lab_names = os.listdir(lab_path) 18 | for lab_name in lab_names: 19 | print("--------") 20 | print(lab_name) 21 | 22 | if os.path.exists(os.path.join(fill_path, lab_name)): 23 | print("[skp]跳过已经手工refine过的标签") 24 | continue 25 | 26 | labf = nib.load(os.path.join(lab_path, lab_name)) 27 | lab = labf.get_fdata() 28 | 29 | # print(np.where(lab == 2)) 30 | if len(np.where(lab == 2)[0]) == 0: 31 | print("[skp]跳过没有lab2的标签") 32 | continue 33 | 34 | print(lab.shape) 35 | processed += 1 36 | 37 | # 第 2 个维度是层数,从0开始,和itk里面下表是 i 对应 i + 1 38 | # 片内是 itk 显示的片子顺时针转90度 39 | # 第 0 个维度在数组中是上到下,在itk中是左到右 40 | # 第 1 个维度在数组中是左到右,在itk中是上到下 41 | 42 | for sli_ind in range(lab.shape[2]): 43 | 44 | seeds = np.where(lab[:, :, sli_ind] == 2) 45 | # print(seeds) 46 | 47 | # plt.imshow(lab[:, :, sli_ind]) 48 | # plt.show() 49 | 50 | for i in range(len(seeds[0])): 51 | lab[seeds[0][i], seeds[1][i], sli_ind] = 0 52 | slice = lab[:, :, sli_ind] 53 | flood_fill( 54 | slice, 55 | (seeds[1][i], seeds[0][i]), 56 | 1, 57 | selem=[[0, 1, 0], [1, 0, 1], [0, 1, 0]], 58 | inplace=True, 59 | ) 60 | 61 | lab[seeds[0][i], seeds[1][i], sli_ind] = 1 62 | 63 | # plt.imshow(lab[:, :, sli_ind]) 64 | # plt.show() 65 | 66 | filled = nib.Nifti1Image(lab, np.eye(4)) 67 | nib.save(filled, os.path.join(fill_path, lab_name)) 68 | print("\t处理完成") 69 | 70 | print("共: ", len(lab_names), "处理了: ", processed) 71 | -------------------------------------------------------------------------------- /tool/train/folder_split.py: -------------------------------------------------------------------------------- 1 | # 原来的数据和标签分别在一个目录里,进行随机split之后按照pdseg的目录结构放 2 | import os 3 | import argparse 4 | import random 5 | 6 | import util 7 | 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("--dst_dir", type=str, required=True) 11 | parser.add_argument("--img_folder", type=str, required=True) 12 | parser.add_argument("--lab_folder", type=str, required=True) 13 | args = parser.parse_args() 14 | 15 | 16 | def mv(curr, dest): 17 | print(os.path.join(dest.rstrip(dest.split("/")[-1]))) 18 | if not os.path.exists(os.path.join(dest.rstrip(dest.split("/")[-1]))): 19 | os.makedirs(os.path.join(dest.rstrip(dest.split("/")[-1]))) 20 | os.rename(curr, dest) 21 | 22 | 23 | def split_dataset( 24 | img_folder, 25 | lab_folder, 26 | dst_dir, 27 | split=[8, 2, 0], 28 | folders=["imgs", "annotations"], 29 | sub_folders=["train", "val", "test"], 30 | ): 31 | # 1. 创建目标目录 32 | for fd1 in folders: 33 | for fd2 in sub_folders: 34 | dir = os.path.join(dst_dir, fd1, fd2) 35 | if not os.path.exists(dir): 36 | os.makedirs(dir) 37 | 38 | # 2. 获取图像和标签文件名,打乱 39 | img_names = util.listdir(img_folder) 40 | lab_names = util.listdir(lab_folder) 41 | names = [[i, l] for i, l in zip(img_names, lab_names)] 42 | random.shuffle(names) 43 | 44 | for idx in range(10): 45 | print(names[idx]) 46 | 47 | # 3. 计算划分点 48 | split.insert(0, 0) 49 | for ind in range(1, len(split)): 50 | split[ind] += split[ind - 1] 51 | 52 | split = [x / split[-1] for x in split] 53 | split = [int(len(img_names) * split[ind]) for ind in range(4)] 54 | print(split) 55 | 56 | # 4. 进行移动 57 | for part in range(3): 58 | print(f"正在处理{sub_folders[part]}") 59 | for idx in range(split[part], split[part + 1]): 60 | img, lab = names[idx] 61 | mv( 62 | os.path.join(img_folder, img), 63 | os.path.join(dst_dir, folders[0], sub_folders[part], img), 64 | ) 65 | mv( 66 | os.path.join(lab_folder, lab), 67 | os.path.join(dst_dir, folders[1], sub_folders[part], lab), 68 | ) 69 | 70 | 71 | if __name__ == "__main__": 72 | split_dataset( 73 | img_folder=args.img_folder, 74 | lab_folder=args.lab_folder, 75 | dst_dir=args.dst_dir, 76 | ) 77 | -------------------------------------------------------------------------------- /medseg/loss.py: -------------------------------------------------------------------------------- 1 | import paddle 2 | import paddle.fluid as fluid 3 | from paddle.fluid.layers import log 4 | 5 | 6 | def mean_iou(pred, label, num_classes=2): 7 | """ 8 | 计算miou 9 | """ 10 | pred = fluid.layers.argmax(pred, axis=1) 11 | pred = fluid.layers.cast(pred, "int32") 12 | label = fluid.layers.cast(label, "int32") 13 | miou, wrong, correct = fluid.layers.mean_iou(pred, label, num_classes) 14 | return miou 15 | 16 | 17 | def weighed_binary_cross_entropy(y, y_predict, beta=2, epsilon=1e-6): 18 | """ 19 | 返回 wce loss 20 | beta标记的是希望positive类给到多少的权重,如果positive少,beta给大于1相当与比0的类更重视 21 | """ 22 | y = fluid.layers.clip(y, epsilon, 1 - epsilon) 23 | y_predict = fluid.layers.clip(y_predict, epsilon, 1 - epsilon) 24 | 25 | ylogp = fluid.layers.elementwise_mul(y, log(y_predict)) 26 | betas = fluid.layers.fill_constant(ylogp.shape, "float32", beta) 27 | ylogp = fluid.layers.elementwise_mul(betas, ylogp) 28 | 29 | ones = fluid.layers.fill_constant(y_predict.shape, "float32", 1) 30 | ylogp = fluid.layers.elementwise_add( 31 | ylogp, elementwise_mul(elementwise_sub(ones, y), log(elementwise_sub(ones, y_predict))) 32 | ) 33 | 34 | zeros = fluid.layers.fill_constant(y_predict.shape, "float32", 0) 35 | return fluid.layers.elementwise_sub(zeros, ylogp) 36 | 37 | 38 | def focal_loss(y_predict, y, alpha=0.85, gamma=2, epsilon=1e-6): 39 | """ 40 | alpha 变大,对前景类惩罚变大,更加重视 41 | gamma 变大,对信心大的例子更加忽略,学习难的例子 42 | """ 43 | y = fluid.layers.clip(y, epsilon, 1 - epsilon) 44 | y_predict = fluid.layers.clip(y_predict, epsilon, 1 - epsilon) 45 | 46 | return -1 * ( 47 | alpha * fluid.layers.pow((1 - y_predict), gamma) * y * log(y_predict) 48 | + (1 - alpha) * fluid.layers.pow(y_predict, gamma) * (1 - y) * log(1 - y_predict) 49 | ) 50 | 51 | 52 | def create_loss(predict, label, num_classes=2): 53 | predict = fluid.layers.transpose(predict, perm=[0, 2, 3, 1]) 54 | predict = fluid.layers.reshape(predict, shape=[-1, num_classes]) 55 | predict = fluid.layers.softmax(predict) 56 | label = fluid.layers.reshape(label, shape=[-1, 1]) 57 | label = fluid.layers.cast(label, "int64") 58 | dice_loss = fluid.layers.dice_loss(predict, label) 59 | 60 | # label = fluid.layers.cast(label, "int64") 61 | 62 | ce_loss = fluid.layers.cross_entropy(predict, label) 63 | # focal = focal_loss(predict, label) 64 | 65 | return fluid.layers.reduce_mean(ce_loss + dice_loss) 66 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # project cache 2 | lits.csv 3 | test.py 4 | *.csv 5 | summary.txt 6 | 7 | # BOS 8 | bos_conf.py 9 | 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | wheels/ 31 | pip-wheel-metadata/ 32 | share/python-wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | MANIFEST 37 | 38 | # PyInstaller 39 | # Usually these files are written by a python script from a template 40 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 41 | *.manifest 42 | *.spec 43 | 44 | # Installer logs 45 | pip-log.txt 46 | pip-delete-this-directory.txt 47 | 48 | # Unit test / coverage reports 49 | htmlcov/ 50 | .tox/ 51 | .nox/ 52 | .coverage 53 | .coverage.* 54 | .cache 55 | nosetests.xml 56 | coverage.xml 57 | *.cover 58 | *.py,cover 59 | .hypothesis/ 60 | .pytest_cache/ 61 | 62 | # Translations 63 | *.mo 64 | *.pot 65 | 66 | # Django stuff: 67 | *.log 68 | local_settings.py 69 | db.sqlite3 70 | db.sqlite3-journal 71 | 72 | # Flask stuff: 73 | instance/ 74 | .webassets-cache 75 | 76 | # Scrapy stuff: 77 | .scrapy 78 | 79 | # Sphinx documentation 80 | docs/_build/ 81 | 82 | # PyBuilder 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | 88 | # IPython 89 | profile_default/ 90 | ipython_config.py 91 | 92 | # pyenv 93 | .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 103 | __pypackages__/ 104 | 105 | # Celery stuff 106 | celerybeat-schedule 107 | celerybeat.pid 108 | 109 | # SageMath parsed files 110 | *.sage.py 111 | 112 | # Environments 113 | .env 114 | .venv 115 | env/ 116 | venv/ 117 | ENV/ 118 | env.bak/ 119 | venv.bak/ 120 | 121 | # Spyder project settings 122 | .spyderproject 123 | .spyproject 124 | 125 | # Rope project settings 126 | .ropeproject 127 | 128 | # mkdocs documentation 129 | /site 130 | 131 | # mypy 132 | .mypy_cache/ 133 | .dmypy.json 134 | dmypy.json 135 | 136 | # Pyre type checker 137 | .pyre/ 138 | -------------------------------------------------------------------------------- /tool/train/resize.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import argparse 4 | from multiprocessing import Pool 5 | 6 | import nibabel as nib 7 | import numpy as np 8 | import scipy.ndimage 9 | 10 | import util 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--scan_dir", type=str, required=True, help="扫描或标签路径") 14 | parser.add_argument("--out_dir", type=str, required=True, help="resize后输出路径") 15 | parser.add_argument("--size", nargs=2, default=[512, 512], help="resize目标大小") 16 | parser.add_argument("-t", "--thickness", type=float, default=None, help="统一的层间距") 17 | parser.add_argument( 18 | "-l", 19 | "--is_label", 20 | default=False, 21 | action="store_true", 22 | help="是否是标签,如果是标签会使用零阶插值", 23 | ) 24 | args = parser.parse_args() 25 | 26 | # TODO: 区分标签和扫描 27 | def resize(path): 28 | name = osp.basename(path) 29 | scanf = nib.load(path) 30 | header = scanf.header.copy() 31 | scan_data = scanf.get_fdata() 32 | old_pixdim = scanf.header.copy()["pixdim"] 33 | old_shape = scan_data.shape 34 | if scan_data.shape[:2] != args.size: 35 | scale = [t / c for c, t in zip(scan_data.shape[:2], args.size)] 36 | s = scale 37 | header["pixdim"][1] /= s[0] 38 | header["pixdim"][2] /= s[1] 39 | else: 40 | scale = [1, 1] 41 | 42 | if args.thickness and header["pixdim"][3] != args.thickness: 43 | scale.append(header["pixdim"][3] / args.thickness) 44 | header["pixdim"][3] = args.thickness 45 | else: 46 | scale.append(1) 47 | 48 | if scale != [1, 1, 1]: 49 | s = scale 50 | scan_data = scipy.ndimage.interpolation.zoom( 51 | scan_data, (s[0], s[1], s[2]), order=0 if args.is_label else 3 52 | ) 53 | # if args.is_label: 54 | # scan_data = scan_data.astype("uint8") 55 | 56 | newf = nib.Nifti1Image(scan_data.astype(np.float32), scanf.affine, header) 57 | nib.save(newf, osp.join(args.out_dir, name)) 58 | print( 59 | name, 60 | ":", 61 | old_pixdim[1:4], 62 | old_shape, 63 | scale, 64 | header["pixdim"][1:4], 65 | scan_data.shape, 66 | ) 67 | 68 | 69 | if __name__ == "__main__": 70 | args.size = util.toint(args.size) 71 | names = os.listdir(args.scan_dir) 72 | names = [ 73 | osp.join(args.scan_dir, n) 74 | for n in names 75 | if n.endswith("nii") or n.endswith("nii.gz") 76 | ] 77 | if not osp.exists(args.out_dir): 78 | os.makedirs(args.out_dir) 79 | 80 | p = Pool(8) 81 | p.map(resize, names) 82 | 83 | # for name in names: 84 | # to_512(name) 85 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # medSeg 2 | 中文 | [English](./README_en.md) 3 | 4 | 仿照百度[PaddleSeg](https://github.com/paddlepaddle/paddleseg)结构实现的一个医学影像方向分割任务开发套件。主要目标是实现多种2D和3D网络(3D网络还在开发中),多种loss和数据增强策略。目前项目还在开发中,但是已经能在肝脏分割场景下做到 .94 的IOU。2.5D P-Unet项目基于这个套件实现。开发计划见[Project](https://github.com/davidlinhl/medSeg/projects/1) 5 | 6 | ## 项目结构 7 | #### medseg 项目主体 8 | 9 | - prep_3d.py prep_2d.py 对3D扫描或2D切片进行推理 10 | - loss.py 定义loss 11 | - models 定义模型 12 | - aug.py 定义数据增强方法 13 | - train.py 训练网络 14 | - infer.py 用训练完的模型进行推理 15 | - vis.py 对数据进行可视化 16 | - eval.py 对分割结果进行评估,支持基本所有医学影像常用2d/3d metric 17 | 18 | #### tool 工具脚本 19 | tool中提供了一些实用的工具脚本,[train](./train)目录下主要用于训练前的预处理,[infer](./infer)目录下的主要用于推理和后处理。 20 | 21 | - train 22 | - mhd2nii.py : 将mhd格式文件转成nii 23 | - resize.py : 将nii格式的扫描或标签转成512大小 24 | - dataset_scan.py : 生成数据集总览,包括强度分布和归一化需要的平均中位数 25 | - to_slice.py : 将3D扫描和标签转成2D的切片,实测多线程提速1倍左右 26 | - vis.py : 随机抽取切片结果或3D序列中的片进行可视化 27 | - gen_list.py : 生成数据文件列表,按照比例划分训练/验证/测试集 28 | - folder_split.py : 将整个数据集随机划分成训练,验证和测试集 29 | - infer 30 | - 2d_diameter.py : 在切片内测量分割标签中的血管直径 31 | - zip_dataset.py : 将一个路径下的文件打包,每个压缩包不超过指定大小。和分包zip不同的是每个压缩包都是单独的包,都可以解压实例测 32 | - flood_fill.py : 对分割标签进行漫水填充 33 | - to_pinyin.py : 将中文文件名转拼音 34 | 35 | #### config 配置文件 36 | 所有配置参考[config.py](https://github.com/davidlinhl/medSeg/blob/master/medseg/utils/config.py) 37 | 38 | ## 使用方法 39 | ### 配置环境 40 | 安装环境依赖: 41 | ```shell 42 | pip install -r requirements.txt 43 | ``` 44 | paddle框架的安装参考[paddle官网](https://www.paddlepaddle.org.cn/) 45 | 如果进行训练需要有数据,目前项目主要面向lits调试,在aistudio上可以找到。[训练集](https://aistudio.baidu.com/aistudio/datasetDetail/10273) [测试集](https://aistudio.baidu.com/aistudio/datasetDetail/10292) 46 | 47 | 数据集下载,解压之后将所有的训练集volume放到一个文件夹,所有的训练集label放到一个文件夹,测试集volume放到一个文件夹。修改 lits.yaml 中对应的路径。 48 | 49 | ### 预处理 50 | 配置完毕需要首先进行数据预处理,这里主要是将数据统一成npz格式,方便后续训练。也可以在这一步结合一些预处理步骤对3D CT数据可以做窗口化,3D旋转。 51 | ```shell 52 | python medseg/prep_3d.py -c config/lits.yaml 53 | ``` 54 | ### 训练 55 | 网络用预处理后的数据进行训练,训练提供一些参数,-h 可以显示。如果用的是cpu版本的paddle,不要添加 --use_gpu 参数。 56 | ```shell 57 | python medseg/train.py -c config/lits.yaml --use_gpu --do_eval 58 | ``` 59 | ### 预测 60 | 最后一步是用训练好的网络进行预测,要配置好模型权重的路径,按照上一步实际输出的路径进行修改。代码会读取inference路径下所有的nii逐个进行预测。目前支持的数据格式有 .nii, .nii.gz。 61 | ```shell 62 | python infer.py -c config/lits.yaml --use_gpu 63 | ``` 64 |
65 | 66 | 这个项目在aistudio中有完整的环境,fork项目可以直接运行,[项目地址](https://aistudio.baidu.com/aistudio/projectdetail/250994)中运行 67 | 68 | # 更新日志 69 | * 2020.6.21 70 | **v0.2.0** 71 | * 项目整体修改为使用配置文件,丰富功能,增加可视化 72 | 73 | * 2020.5.1 74 | **v0.1.0** 75 | * 在lits数据集上跑通预处理,训练,预测脚本 76 | 77 | 如项目使用中有任何问题,欢迎加入 Aistudio医学兴趣组,在群中提问,也可以和更多大佬一起学习和进步。 78 | 79 |
80 | 2132453929.jpg 81 |
82 | -------------------------------------------------------------------------------- /README_en.md: -------------------------------------------------------------------------------- 1 | # MedSeg 2 | English | [简体中文](./README_cn.md) 3 | 4 | Medical image segmentation toolkit based on PaddlePaddle framework. Our target is to implement various 2D and 3D model architectures, various loss function and data augmentation methods. This is still a work in progress but has achieved promising results on liver segmentation and aorta segmentation. The development plans can be seen in the [Project](https://github.com/davidlinhl/medSeg/projects/1) 5 | 6 | ## Project Structure 7 | Currently this project contains only 2D segmentation models. The structure is as follows. 8 | 9 | - medseg: Promary code 10 | - train.py: Training pipeline 11 | - aug.py: Data augmentation 12 | - loss.py: Various model loss 13 | - eval.py: Various metrics to evaluate segmentation result 14 | - vis.py: Visualize results 15 | - models: Currently only 2D models 16 | - tool: Useful scripts 17 | - train: Tools for converting scan format, generating 2D slices from scan etc. 18 | - infer: Tools used after inference, merging segmentation results from slices, voting model fusion etc. 19 | - config: Training configurations 20 | All configurations can be found in [utils/config.py](https://github.com/davidlinhl/medSeg/blob/master/medseg/utils/config.py) 21 | 22 | ## Usage 23 | ### Environment Set Up 24 | Install project dependencies with 25 | ```shell 26 | pip install -r requirements.txt 27 | ``` 28 | Instructions for installing PaddlePaddle-GPU can be found on PaddlePaddle's [official home page](https://www.paddlepaddle.org.cn/) 29 | 30 | ### Preprocess 31 | Preprocess 3D scans either into 2D slices or 3D patches. Applying WWWC or other slice-wise augmentation can also be done here. 32 | ```shell 33 | python medseg/prep_3d.py -c config/lits.yaml 34 | ``` 35 | 36 | ### Training 37 | The training script contains several choices. Run with -h command to see details about them. If u r training with CPU only, don't include the --use_gpu command. 38 | ```shell 39 | python medseg/train.py -c config/lits.yaml --use_gpu --do_eval 40 | ``` 41 | 42 | ### Inference 43 | The last step is doing inference with previously trained model. The script would perform inference on all data under specified path and perform inference. Currently supports nii format only. 44 | ```shell 45 | python medseg/infer.py -c config/lits.yaml --use_gpu 46 | ``` 47 | 48 | ### Evaluation and Else 49 | After getting inference results, you may want to know how well the model performs. We have an evaluation script with multiple metrics implemented with medpy. 50 | ```shell 51 | python medseg/eval.py -c config/eval.yaml 52 | ``` 53 | For aorta, specifically, we also have scripts for measuring blood vessel diameter and reporting aorta aneurysm. 54 | ```shell 55 | python tool/infer/2d_diameter.py 56 | python tool/infer/aorta.py 57 | ``` 58 | These two scripts combined can calculate aorta diameter and report aorta aneurysm based on it. 59 | 60 | We have an [Aistudio](https://aistudio.baidu.com/aistudio/projectdetail/250994) project that have all the data and environment ready. 61 | 62 | Should u have any question in using this toolkit, u can contact the developer at linhandev@qq.com 63 | -------------------------------------------------------------------------------- /tool/train/to_slice.py: -------------------------------------------------------------------------------- 1 | """ 2 | 将nii和标签批量转成2D切片(png/npy) 3 | 要求扫描和标签按照字典序排序相同(文件名相同,拓展名不同就可以满足这个) 4 | """ 5 | 6 | import os 7 | import os.path as osp 8 | import argparse 9 | import logging 10 | from tqdm import tqdm 11 | 12 | import util 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument("--scan_dir", type=str, help="扫描文件路径", required=True) 16 | parser.add_argument("--label_dir", type=str, help="标签文件路径", default=None) 17 | parser.add_argument("--out_dir", type=str, help="数据集输出路径", required=True) 18 | parser.add_argument( 19 | "--thick", 20 | type=int, 21 | help="切片厚度,默认3。如果是保存成png格式必须为3", 22 | default=3, 23 | ) 24 | parser.add_argument( 25 | "-t", 26 | "--thresh", 27 | type=int, 28 | help="前景像素数量大于这个数才包含到数据集里,否则这个slice跳过", 29 | default=None, 30 | ) 31 | parser.add_argument( 32 | "-s", 33 | "--size", 34 | nargs=2, 35 | help="输出片的大小,不声明这个参数不进行任何插值,否则扫描3阶插值,标签0阶缩放到这个大小", 36 | default=None, 37 | ) 38 | parser.add_argument("--wwwc", nargs=2, help="窗宽窗位", default=["1000", "0"]) 39 | parser.add_argument( 40 | "-r", 41 | "--rot", 42 | type=int, 43 | help="逆时针90度转多少次,可以为负", 44 | default=0, 45 | ) # TODO: 用库做体位校正 46 | parser.add_argument("-f", "--front", type=int, help="如果标签有多种前景,要保留的前景值", default=None) 47 | parser.add_argument( 48 | "-fm", 49 | "--front_mode", 50 | type=str, 51 | help="多个前景保留一个的策略。stack:把大于front的标签都设成front,小于front的标签设成背景。single:只保留front,其他的都设成背景", 52 | default=None, 53 | ) 54 | parser.add_argument( 55 | "-itv", 56 | "--interval", 57 | type=int, 58 | help="每隔这个数量取1片,重建层间距很小,片层很多的时候可以用这个跳过一些片", 59 | default=1, 60 | ) 61 | parser.add_argument("-c", "--check", default=False, action="store_true", help="是否检查数据集") 62 | parser.add_argument("--ext", type=str, help="文件保存的拓展名,不带点", default="png") 63 | parser.add_argument("--transpose", type=bool, default=False, help="是否调整数据维度顺序") 64 | parser.add_argument("--prefix", type=str, help="文件保存的前缀", default=None) 65 | args = parser.parse_args() 66 | 67 | logging.basicConfig( 68 | level=logging.DEBUG, 69 | format="%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s", 70 | ) 71 | 72 | if args.thick % 2 != 1: 73 | logging.error( 74 | f"The thickkess argument {args.thick} is not odd, plz use an odd number." 75 | ) 76 | exit() 77 | 78 | if args.ext == "png" and args.thick not in [1, 3]: 79 | logging.error( 80 | f"Can't save {args.thick} channel image with png format. Png format only supports 1 or 3 channels. Switch to save in npy format instead." 81 | ) 82 | exit() 83 | 84 | 85 | # TODO: 完善对扫描和标签的检查 86 | if args.check: 87 | util.check_nii_match(args.scan_dir, args.label_dir) 88 | 89 | scans = util.listdir(args.scan_dir) 90 | if args.label_dir is not None: 91 | labels = util.listdir(args.label_dir) 92 | assert len(labels) == len( 93 | scans 94 | ), f"The number of labels {len(labels)} is not equal to number of scans {len(scans)}" 95 | logging.info("Discovered scan/label pairs:") 96 | for s, l in zip(scans, labels): 97 | logging.info(f" {s} \t {l}") 98 | cmd = input( 99 | f"""Totally {len(scans)} pairs, plz check for any mismatch. 100 | Input Y/y to continue, input anything else to stop: """ 101 | ) 102 | else: 103 | cmd = input( 104 | f"""Totally {len(scans)} scans. 105 | Input Y/y to continue, input anything else to stop: """ 106 | ) 107 | labels = [None for _ in range(len(scans))] 108 | if cmd.lower() != "y": 109 | exit("Exit on user command") 110 | 111 | progress = tqdm(range(len(scans))) 112 | 113 | for scan, label in zip(scans, labels): 114 | progress.set_description(f"Processing {osp.basename(scan)}") 115 | util.slice_med( 116 | osp.join(args.scan_dir, scan), 117 | osp.join(args.out_dir, "JPEGImages"), 118 | osp.join(args.label_dir, label) if args.label_dir else None, 119 | osp.join(args.out_dir, "Annotations") if args.label_dir else None, 120 | args.thick, 121 | rot=args.rot, 122 | wwwc=util.toint(args.wwwc), 123 | thresh=args.thresh, 124 | front=args.front, 125 | front_mode=args.front_mode, 126 | itv=args.interval, 127 | ext=args.ext, 128 | transpose=args.transpose, 129 | prefix=args.prefix, 130 | ) 131 | progress.update(n=1) 132 | -------------------------------------------------------------------------------- /tool/train/slice_mp.py: -------------------------------------------------------------------------------- 1 | """ 2 | 将nii和标签批量转成2D切片(png/npy) 3 | 要求扫描和标签按照字典序排序相同(文件名相同,拓展名不同就可以满足这个) 4 | """ 5 | import multiprocessing 6 | import os 7 | import os.path as osp 8 | import argparse 9 | import logging 10 | from tqdm import tqdm 11 | import functools 12 | 13 | import util 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--scan_dir", type=str, help="扫描文件路径", required=True) 17 | parser.add_argument("--label_dir", type=str, help="标签文件路径", default=None) 18 | parser.add_argument("--out_dir", type=str, help="数据集输出路径", required=True) 19 | parser.add_argument( 20 | "--thick", 21 | type=int, 22 | help="切片厚度,默认3。如果是保存成png格式必须为3", 23 | default=3, 24 | ) 25 | parser.add_argument( 26 | "-t", 27 | "--thresh", 28 | type=int, 29 | help="前景像素数量大于这个数才包含到数据集里,否则这个slice跳过", 30 | default=None, 31 | ) 32 | parser.add_argument( 33 | "-s", 34 | "--size", 35 | nargs=2, 36 | help="输出片的大小,不声明这个参数不进行任何插值,否则扫描3阶插值,标签0阶缩放到这个大小", 37 | default=None, 38 | ) 39 | parser.add_argument("--wwwc", nargs=2, help="窗宽窗位", default=["1000", "0"]) 40 | parser.add_argument( 41 | "-r", 42 | "--rot", 43 | type=int, 44 | help="逆时针90度转多少次,可以为负", 45 | default=0, 46 | ) # TODO: 用库做体位校正 47 | parser.add_argument("-f", "--front", type=int, help="如果标签有多种前景,要保留的前景值", default=None) 48 | parser.add_argument( 49 | "-fm", 50 | "--front_mode", 51 | type=str, 52 | help="多个前景保留一个的策略。stack:把大于front的标签都设成front,小于front的标签设成背景。single:只保留front,其他的都设成背景", 53 | default=None, 54 | ) 55 | parser.add_argument( 56 | "-itv", 57 | "--interval", 58 | type=int, 59 | help="每隔这个数量取1片,重建层间距很小,片层很多的时候可以用这个跳过一些片", 60 | default=1, 61 | ) 62 | parser.add_argument("-c", "--check", default=False, action="store_true", help="是否检查数据集") 63 | parser.add_argument("--ext", type=str, help="文件保存的拓展名,不带点", default="png") 64 | parser.add_argument( 65 | "--transpose", default=False, action="store_true", help="是否调整数据维度顺序" 66 | ) 67 | parser.add_argument("--prefix", type=str, help="文件保存的前缀", default="") 68 | parser.add_argument("-p", "--process", type=int, help="进程数", default=2) 69 | 70 | args = parser.parse_args() 71 | 72 | logging.basicConfig( 73 | level=logging.DEBUG, 74 | format="%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s", 75 | ) 76 | 77 | if args.thick % 2 != 1: 78 | logging.error( 79 | f"The thickkess argument {args.thick} is not odd, plz use an odd number." 80 | ) 81 | exit() 82 | 83 | if args.ext == "png" and args.thick not in [1, 3]: 84 | logging.error( 85 | f"Can't save {args.thick} channel image with png format. Png format only supports 1 or 3 channels. Switch to save in npy format instead." 86 | ) 87 | exit() 88 | 89 | 90 | # TODO: 完善对扫描和标签的检查 91 | if args.check: 92 | util.check_nii_match(args.scan_dir, args.label_dir) 93 | 94 | scans = util.listdir(args.scan_dir) 95 | if args.label_dir is not None: 96 | labels = util.listdir(args.label_dir) 97 | assert len(labels) == len( 98 | scans 99 | ), f"The number of labels {len(labels)} is not equal to number of scans {len(scans)}" 100 | logging.info("Discovered scan/label pairs:") 101 | for s, l in zip(scans, labels): 102 | logging.info(f" {s} \t {l}") 103 | cmd = input( 104 | f"""Totally {len(scans)} pairs, plz check for any mismatch. 105 | Input Y/y to continue, input anything else to stop: """ 106 | ) 107 | else: 108 | cmd = input( 109 | f"""Totally {len(scans)} scans. 110 | Input Y/y to continue, input anything else to stop: """ 111 | ) 112 | labels = [None for _ in range(len(scans))] 113 | if cmd.lower() != "y": 114 | exit("Exit on user command") 115 | 116 | pool = multiprocessing.Pool(processes=args.process) 117 | tasks = [] 118 | for scan, label in zip(scans, labels): 119 | tasks.append( 120 | ( 121 | osp.join(args.scan_dir, scan), 122 | osp.join(args.out_dir, "JPEGImages"), 123 | osp.join(args.label_dir, label) if args.label_dir else None, 124 | osp.join(args.out_dir, "Annotations") if args.label_dir else None, 125 | args.thick, 126 | args.rot, 127 | util.toint(args.wwwc), 128 | args.thresh, 129 | args.front, 130 | args.front_mode, 131 | args.interval, 132 | None, 133 | args.ext, 134 | args.transpose, 135 | args.prefix, 136 | ) 137 | ) 138 | pool.starmap(util.slice_med, tasks) 139 | 140 | pool.close() 141 | pool.join() 142 | -------------------------------------------------------------------------------- /medseg/models/unet.py: -------------------------------------------------------------------------------- 1 | # coding: utf8 2 | # copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | import contextlib 20 | import paddle 21 | import paddle.fluid as fluid 22 | from utils.config import cfg 23 | from models.libs.model_libs import scope, name_scope 24 | from models.libs.model_libs import bn, bn_relu, relu 25 | from models.libs.model_libs import conv, max_pool, deconv 26 | 27 | 28 | def double_conv(data, out_ch): 29 | param_attr = fluid.ParamAttr( 30 | name='weights', 31 | regularizer=fluid.regularizer.L2DecayRegularizer( 32 | regularization_coeff=0.0), 33 | initializer=fluid.initializer.TruncatedNormal(loc=0.0, scale=0.33)) 34 | with scope("conv0"): 35 | data = bn_relu( 36 | conv(data, out_ch, 3, stride=1, padding=1, param_attr=param_attr)) 37 | with scope("conv1"): 38 | data = bn_relu( 39 | conv(data, out_ch, 3, stride=1, padding=1, param_attr=param_attr)) 40 | return data 41 | 42 | 43 | def down(data, out_ch): 44 | # 下采样:max_pool + 2个卷积 45 | with scope("down"): 46 | data = max_pool(data, 2, 2, 0) 47 | data = double_conv(data, out_ch) 48 | return data 49 | 50 | 51 | def up(data, short_cut, out_ch): 52 | # 上采样:data上采样(resize或deconv), 并与short_cut concat 53 | param_attr = fluid.ParamAttr( 54 | name='weights', 55 | regularizer=fluid.regularizer.L2DecayRegularizer( 56 | regularization_coeff=0.0), 57 | initializer=fluid.initializer.XavierInitializer(), 58 | ) 59 | with scope("up"): 60 | if cfg.MODEL.UNET.UPSAMPLE_MODE == 'bilinear': 61 | data = fluid.layers.resize_bilinear(data, short_cut.shape[2:]) 62 | else: 63 | data = deconv( 64 | data, 65 | out_ch // 2, 66 | filter_size=2, 67 | stride=2, 68 | padding=0, 69 | param_attr=param_attr) 70 | data = fluid.layers.concat([data, short_cut], axis=1) 71 | data = double_conv(data, out_ch) 72 | return data 73 | 74 | 75 | def encode(data): 76 | # 编码器设置 77 | short_cuts = [] 78 | with scope("encode"): 79 | with scope("block1"): 80 | data = double_conv(data, 64) 81 | short_cuts.append(data) 82 | with scope("block2"): 83 | data = down(data, 128) 84 | short_cuts.append(data) 85 | with scope("block3"): 86 | data = down(data, 256) 87 | short_cuts.append(data) 88 | with scope("block4"): 89 | data = down(data, 512) 90 | short_cuts.append(data) 91 | with scope("block5"): 92 | data = down(data, 512) 93 | return data, short_cuts 94 | 95 | 96 | def decode(data, short_cuts): 97 | # 解码器设置,与编码器对称 98 | with scope("decode"): 99 | with scope("decode1"): 100 | data = up(data, short_cuts[3], 256) 101 | with scope("decode2"): 102 | data = up(data, short_cuts[2], 128) 103 | with scope("decode3"): 104 | data = up(data, short_cuts[1], 64) 105 | with scope("decode4"): 106 | data = up(data, short_cuts[0], 64) 107 | return data 108 | 109 | 110 | def get_logit(data, num_classes): 111 | # 根据类别数设置最后一个卷积层输出 112 | param_attr = fluid.ParamAttr( 113 | name='weights', 114 | regularizer=fluid.regularizer.L2DecayRegularizer( 115 | regularization_coeff=0.0), 116 | initializer=fluid.initializer.TruncatedNormal(loc=0.0, scale=0.01)) 117 | with scope("logit"): 118 | data = conv( 119 | data, num_classes, 3, stride=1, padding=1, param_attr=param_attr) 120 | return data 121 | 122 | 123 | def unet(input, num_classes): 124 | # UNET网络配置,对称的编码器解码器 125 | encode_data, short_cuts = encode(input) 126 | decode_data = decode(encode_data, short_cuts) 127 | logit = get_logit(decode_data, num_classes) 128 | return logit 129 | 130 | 131 | if __name__ == '__main__': 132 | image_shape = [3, 320, 320] 133 | image = fluid.layers.data(name='image', shape=image_shape, dtype='float32') 134 | logit = unet(image, 4) 135 | print("logit:", logit.shape) 136 | -------------------------------------------------------------------------------- /tool/infer/3d_diameter.py: -------------------------------------------------------------------------------- 1 | # 用mesh进行3d重建,之后计算管径 2 | import argparse 3 | import os 4 | 5 | import skimage.measure 6 | import scipy.ndimage 7 | import numpy as np 8 | import nibabel as nib 9 | import trimesh 10 | from util import filter_polygon, sort_line, Polygon 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument( 14 | "--in_dir", type=str, default="/home/lin/Desktop/aorta/private/label/test" 15 | ) 16 | parser.add_argument("--out_dir", type=str, default="./img") 17 | args = parser.parse_args() 18 | 19 | 20 | vol_names = os.listdir(args.in_dir) 21 | for vol_name in vol_names: 22 | # 1. 获取需要测量的标签数据,插值成1024分辨率提升精度 23 | volf = nib.load(os.path.join(args.in_dir, vol_name)) 24 | vol = volf.get_fdata() 25 | print(vol.shape) 26 | # if vol.shape[0] < 1024: 27 | # vol = scipy.ndimage.interpolation.zoom(vol, (2, 2, 1), order=3) 28 | # 2. 进行3D重建 29 | verts, faces, normals, values = skimage.measure.marching_cubes(vol) 30 | # print(verts) 31 | verts = [[v[0], v[1], v[2] * 10] for v in verts] 32 | # print(verts) 33 | # 将bb左下角放到原点 34 | min = np.min(verts, axis=0) 35 | print(min) 36 | verts = verts - min 37 | # print(verts) 38 | mesh = trimesh.Trimesh(vertices=verts, faces=faces) 39 | 40 | # print(mesh.is_watertight) 41 | # print(mesh.bounds) 42 | 43 | # mesh.export("aorta.stl") 44 | # mesh = trimesh.load("./aorta.stl") 45 | 46 | scene = trimesh.Scene() 47 | scene.add_geometry(mesh) 48 | axis = trimesh.creation.axis(origin_size=10, origin_color=[1.0, 0, 0]) 49 | scene.add_geometry(axis) 50 | scene.show() 51 | 52 | # 3. 获取测量路径,在血管壁上取一条线的点 53 | # 3.1 查所有不同的高度,作为一个片曾 54 | heights = [] 55 | for v in verts: 56 | if v[2] not in heights: 57 | heights.append(v[2]) 58 | print("heights", heights) 59 | 60 | slices = [[] for _ in range(len(heights))] 61 | for ind, h in enumerate(heights): 62 | for v in verts: 63 | if v[2] == h: 64 | slices[ind].append(v) 65 | 66 | # 3.2 算所有片曾的圆心 67 | centers = [] 68 | polygons = [] 69 | for ind in range(len(heights)): 70 | res = filter_polygon(slices[ind], "all", 15) 71 | for poly in res: 72 | polygons.append(Polygon(poly)) 73 | 74 | polygons = sort_line(polygons) 75 | 76 | # center_cloud = trimesh.points.PointCloud([a.center for a in polygons], [0, 255, 0, 100]) 77 | # scene.add_geometry(center_cloud) 78 | # 79 | # base_cloud = trimesh.points.PointCloud([a.base for a in polygons], [255, 0, 0, 100]) 80 | # scene.add_geometry(base_cloud) 81 | # scene.show() 82 | 83 | # 4. 找每个一base和血管相交,最小的圆 84 | for polygon in polygons[10:]: 85 | base = polygon.base 86 | last_prin = [0, 0, 1] 87 | last_size = 65535 88 | diameters = [] 89 | while True: 90 | stride = 0.01 91 | tweak = [ 92 | [0, 0, stride], 93 | [0, 0, -stride], 94 | [0, stride, 0], 95 | [0, -stride, 0], 96 | [stride, 0, 0], 97 | [-stride, 0, 0], 98 | ] 99 | new_prins = np.array(last_prin) + np.array(tweak) 100 | print(new_prins) 101 | sizes = [] 102 | 103 | for prin in new_prins: 104 | print(prin) 105 | lines = trimesh.intersections.mesh_plane(mesh, prin, base) 106 | size = Polygon(lines).cal_size() 107 | if size < 5: 108 | size = 65535 109 | sizes.append(size) 110 | print(sizes) 111 | 112 | min_ind = np.array(sizes).argmin() 113 | if last_size > sizes[min_ind]: 114 | last_prin = new_prins[min_ind] 115 | last_size = sizes[min_ind] 116 | print("+_+_+_+_", last_size) 117 | else: 118 | break 119 | 120 | lines = trimesh.intersections.mesh_plane(mesh, last_prin, base) 121 | min_polygon = Polygon(lines) 122 | diameters.append(min_polygon.cal_diameter()) 123 | center_cloud = trimesh.points.PointCloud(polygon.points, [0, 255, 0, 100]) 124 | scene.add_geometry(center_cloud) 125 | center_cloud = trimesh.points.PointCloud(min_polygon.points, [255, 0, 0, 100]) 126 | scene.add_geometry(center_cloud) 127 | scene.show() 128 | print("get_min") 129 | input("here") 130 | 131 | # points = [] 132 | # for p in lines: 133 | # points.append([p[0][0], p[0][1], p[0][2]]) 134 | # points = ang_sort(points) 135 | # print(points) 136 | # points = filter_polygon(points, [0, 0, 0], 10) 137 | 138 | # points = trimesh.points.PointCloud(points, [200, 200, 250, 100]) 139 | 140 | # faces = [] 141 | # for ind in range(len(points)): 142 | # faces.append([ind, int((ind + len(points) / 2) % len(points)), int((ind + len(points) / 2 + 1) % len(points))]) 143 | # plane = trimesh.creation.extrude_triangulation(points, [[ind, ind + 1, ind + 2] for ind in range(20)], height=-1) 144 | -------------------------------------------------------------------------------- /medseg/eval.py: -------------------------------------------------------------------------------- 1 | # 在验证集上对模型的多种指标进行评估 2 | import os 3 | from multiprocessing import Pool, cpu_count 4 | from datetime import datetime 5 | import argparse 6 | 7 | from medpy import metric 8 | import nibabel as nib 9 | from tqdm import tqdm 10 | import scipy.ndimage 11 | 12 | import utils.util as util 13 | from utils.config import cfg 14 | 15 | 16 | def parse_args(): 17 | parser = argparse.ArgumentParser(description="预测") 18 | parser.add_argument("-c", "--cfg_file", type=str, help="配置文件路径") 19 | parser.add_argument("opts", nargs=argparse.REMAINDER) 20 | args = parser.parse_args() 21 | 22 | if args.cfg_file is not None: 23 | cfg.update_from_file(args.cfg_file) 24 | if args.opts: 25 | cfg.update_from_list(args.opts) 26 | 27 | cfg.set_immutable(True) 28 | 29 | 30 | def main(): 31 | headers = [] 32 | # TODO: 研究一个更简洁的保持第一行和下面的数顺序一致的方法 33 | metrics = [ 34 | "FP", 35 | "FN", 36 | "TP", 37 | "TN", 38 | "Precision", 39 | "Recall", 40 | "Sensitivity", 41 | "Specificity", 42 | "Accuracy", 43 | "Kappa", 44 | "Dice", 45 | "IOU", 46 | # 3D 47 | "Assd", 48 | "Ravd", 49 | ] 50 | for m in metrics: 51 | if m in cfg.EVAL.METRICS: 52 | headers.append(m) 53 | 54 | preds = util.listdir(cfg.EVAL.PATH.SEG) 55 | labels = util.listdir(cfg.EVAL.PATH.GT) 56 | 57 | f = open(cfg.EVAL.PATH.NAME + "-" + str(datetime.now()) + ".csv", "w") 58 | print("文件", end=",", file=f) 59 | for ind, header in enumerate(headers): 60 | print(header, end="," if ind != len(headers) - 1 else "\n", file=f) 61 | with Pool(cpu_count()) as p: 62 | res = p.map(calculate, [(preds[idx], labels[idx]) for idx in range(len(preds))]) 63 | for pred, r in res: 64 | print(pred, end=",", file=f) 65 | for ind, x in enumerate(r): 66 | print(x, end="," if ind != len(headers) - 1 else "\n", file=f) 67 | f.close() 68 | 69 | 70 | # def write_res(pred, res): 71 | # f = open(cfg.EVAL.PATH.RESULT + "-" + str(datetime.now()) + ".csv", "w+") 72 | # print(pred, end=",", file=f) 73 | # for ind, x in enumerate(res[ind]): 74 | # print(x, end="," if ieval.csv-2020-12-16 18:59:40.731539.csvnd != len(headers) - 1 else "\n", file=f) 75 | # print("\n", file=f) 76 | # f.close() 77 | 78 | 79 | def calculate(input): 80 | pred_name = input[0] 81 | lab_name = input[1] 82 | 83 | predf = nib.load(os.path.join(cfg.EVAL.PATH.SEG, pred_name)) 84 | labf = nib.load(os.path.join(cfg.EVAL.PATH.GT, lab_name)) 85 | 86 | pred = predf.get_fdata() 87 | lab = labf.get_fdata() 88 | print(pred_name, lab_name, pred.shape, lab.shape) 89 | if lab.shape[0] != pred.shape[0]: 90 | ratio = [a / b for a, b in zip(pred.shape, lab.shape)] 91 | lab = scipy.ndimage.interpolation.zoom(lab, ratio, order=1) 92 | print("插值后大小: ", pred.shape, lab.shape) 93 | assert pred.shape == lab.shape, "分割结果和GT大小不同: {},{}, {}".format( 94 | pred.shape, lab.shape, preds[ind] 95 | ) 96 | 97 | temp = [] 98 | if "FP" in cfg.EVAL.METRICS: 99 | # fp = metric.binary.obj_fpr(pred, lab) 100 | # temp.append(fp) 101 | pass 102 | 103 | if "FN" in cfg.EVAL.METRICS: 104 | pass 105 | 106 | if "TP" in cfg.EVAL.METRICS: 107 | tp = metric.binary.true_positive_rate(pred, lab) 108 | temp.append(tp) 109 | 110 | if "TN" in cfg.EVAL.METRICS: 111 | tn = metric.binary.true_negative_rate(pred, lab) 112 | temp.append(tn) 113 | 114 | if "Precision" in cfg.EVAL.METRICS: 115 | prec = metric.binary.precision(pred, lab) 116 | temp.append(prec) 117 | 118 | if "Recall" in cfg.EVAL.METRICS: 119 | rec = metric.binary.recall(pred, lab) 120 | temp.append(rec) 121 | 122 | if "Sensitivity" in cfg.EVAL.METRICS: 123 | rec = metric.binary.sensitivity(pred, lab) # same as recall 124 | temp.append(rec) 125 | 126 | if "Specificity" in cfg.EVAL.METRICS: 127 | spec = metric.binary.specificity(pred, lab) 128 | temp.append(spec) 129 | 130 | if "Accuracy" in cfg.EVAL.METRICS: 131 | tp = metric.binary.true_positive_rate(pred, lab) 132 | tn = metric.binary.true_negative_rate(pred, lab) 133 | acc = (tp + tn) / 2 134 | temp.append(acc) 135 | 136 | if "Kappa" in cfg.EVAL.METRICS: 137 | pass 138 | 139 | if "Dice" in cfg.EVAL.METRICS: 140 | dice = metric.dc(pred, lab) 141 | temp.append(dice) 142 | 143 | if "IOU" in cfg.EVAL.METRICS: 144 | iou = metric.binary.jc(pred, lab) 145 | temp.append(iou) 146 | 147 | if "Assd" in cfg.EVAL.METRICS: 148 | assd = metric.binary.assd(pred, lab) 149 | temp.append(assd) 150 | 151 | if "Ravd" in cfg.EVAL.METRICS: 152 | ravd = metric.binary.ravd(pred, lab) 153 | temp.append(ravd) 154 | 155 | return pred_name, temp 156 | 157 | 158 | # TODO: 绘制箱须图 159 | # https://matplotlib.org/gallery/pyplots/boxplot_demo_pyplot.html#sphx-glr-gallery-pyplots-boxplot-demo-pyplot-py 160 | 161 | if __name__ == "__main__": 162 | parse_args() 163 | main() 164 | -------------------------------------------------------------------------------- /tool/infer/slice2nii.py: -------------------------------------------------------------------------------- 1 | """ 2 | 将切片的推理结果合起来 3 | 默认推理结果是 name-idx.png 格式,格式不同的话修改 get_name 函数 4 | """ 5 | import os 6 | import os.path as osp 7 | import concurrent 8 | from time import sleep 9 | from queue import Queue 10 | import argparse 11 | import multiprocessing 12 | from multiprocessing import Pool 13 | 14 | import nibabel as nib 15 | import cv2 16 | import matplotlib.pyplot as plt 17 | import numpy as np 18 | from scipy import ndimage 19 | from tqdm import tqdm 20 | 21 | from util import to_pinyin 22 | import util 23 | 24 | 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument( 27 | "--scan_dir", 28 | type=str, 29 | required=True, 30 | help="扫描路径,会去找头文件信息", 31 | ) 32 | parser.add_argument( 33 | "--seg_dir", 34 | type=str, 35 | required=True, 36 | help="nii分割标签输出路径", 37 | ) 38 | parser.add_argument( 39 | "--png_dir", 40 | type=str, 41 | required=True, 42 | help="png格式分割推理结果路径", 43 | ) 44 | parser.add_argument( 45 | "--rot", 46 | type=int, 47 | default=0, 48 | help="对结果进行几次旋转", 49 | ) 50 | parser.add_argument( 51 | "--filter", 52 | default=False, 53 | action="store_true", 54 | help="是否过滤最大连通块", 55 | ) 56 | parser.add_argument( 57 | "--percent", 58 | type=str, 59 | default=None, 60 | help="最大连通块占所有前景标签比例,可以估计分割结果质量,不写不进行统计", 61 | ) 62 | args = parser.parse_args() 63 | 64 | 65 | def get_name(name): 66 | """从文件名中解析扫描序列的名字和片层下标 67 | 68 | Parameters 69 | ---------- 70 | name : str 71 | 片层标签文件名 72 | 73 | Returns 74 | ------- 75 | str, int 76 | 序列名,用于group一个序列的所有推理结果 77 | 序列下标,表示这个片层在序列中的下标 78 | 79 | """ 80 | # TODO: rfind 81 | pos = -name[::-1].find("-") - 1 # 找到最后一个 - 82 | # print(name, name[:pos], int(name[pos + 1 :].split(".")[0])) 83 | return name[:pos], int(name[pos + 1 :].split(".")[0]) 84 | 85 | 86 | # TODO: 多进程不是多线程 87 | class ThreadPool(concurrent.futures.ThreadPoolExecutor): 88 | def __init__(self, maxsize=6, *args, **kwargs): 89 | super(ThreadPool, self).__init__(*args, **kwargs) 90 | self._work_queue = Queue(maxsize=maxsize) 91 | 92 | 93 | # 检查文件匹配情况 94 | img_names = util.listdir(args.png_dir, sort=False) 95 | patient_names = [] 96 | for n in img_names: 97 | n, _ = get_name(n) 98 | if n not in patient_names: 99 | patient_names.append(n) 100 | patient_names = [n + ".nii.gz" for n in patient_names] 101 | 102 | nii_names_set = set(os.listdir(args.scan_dir)) 103 | patient_names_set = set(patient_names) 104 | for n in patient_names_set - nii_names_set: 105 | print(n, "dont have nii") 106 | for n in nii_names_set - patient_names_set: 107 | print(n, "dont have segmentation result") 108 | patient_names.sort() 109 | print(patient_names) 110 | 111 | input("Press any key to start!") 112 | if args.percent: 113 | percent_file = open(args.percent, "a+") 114 | if not os.path.exists(args.seg_dir): 115 | os.makedirs(args.seg_dir) 116 | 117 | 118 | def run(patient): 119 | if osp.exists(osp.join(args.seg_dir, patient)): 120 | print(patient, "already finished, skipping") 121 | return 122 | 123 | patient_imgs = [n for n in img_names if get_name(n)[0] == patient.split(".")[0]] 124 | patient_imgs.sort(key=lambda n: int(get_name(n)[1])) 125 | # print(patient, patient_imgs, len(patient_imgs)) 126 | label = cv2.imread( 127 | os.path.join(args.png_dir, patient_imgs[0]), cv2.IMREAD_UNCHANGED 128 | ) 129 | s = label.shape 130 | label_data = np.zeros([s[0], s[1], len(patient_imgs)], dtype="uint8") 131 | 132 | try: 133 | # print(os.path.join(args.scan_dir, patient)) 134 | scanf = nib.load(os.path.join(args.scan_dir, patient)) 135 | scan_header = scanf.header 136 | except: 137 | print(f"[ERROR] {patient}'s scan is not found! Skipping {patient}") 138 | return 139 | # scanf = nib.load(os.path.join(args.scan_dir, "张金华_20201024213424575a.nii")) 140 | # scan_header = scanf.header 141 | 142 | for img_name in patient_imgs: 143 | img = cv2.imread(os.path.join(args.png_dir, img_name), cv2.IMREAD_UNCHANGED) 144 | ind = int(get_name(img_name)[1]) 145 | label_data[:, :, ind] = img 146 | 147 | save_nii( 148 | label_data, 149 | scanf.affine, 150 | scan_header, 151 | os.path.join(args.seg_dir, patient), 152 | ) # BUG: 貌似会出现最后一两个进程卡住,无法保存的情况 153 | 154 | 155 | if args.percent: 156 | percent_file.close() 157 | 158 | 159 | def save_nii(label_data, affine, header, dir): 160 | print("++++", dir) 161 | print(label_data.shape) 162 | label_data = np.rot90(label_data, args.rot, axes=(0, 1)) 163 | label_data = np.transpose(label_data, [1, 0, 2]) 164 | if args.filter: 165 | tot = label_data.sum() 166 | label_data = util.filter_largest_volume(label_data, mode="hard") 167 | largest = label_data.sum() 168 | if args.percent: 169 | print(osp.basename(dir), largest / tot, file=percent_file) 170 | percent_file.flush() 171 | newf = nib.Nifti1Image(label_data.astype(np.float64), affine, header) 172 | nib.save(newf, dir) 173 | print("--------", "finish", dir) 174 | 175 | 176 | print(patient_names) 177 | with Pool(multiprocessing.cpu_count()) as p: 178 | p.map(run, patient_names) 179 | -------------------------------------------------------------------------------- /medseg/models/unet_plain.py: -------------------------------------------------------------------------------- 1 | # Author: Jingxiao Gu 2 | # Baidu Account: Seigato 3 | # Description: Unet Base Network for Lane Segmentation Competition 4 | # 80 unet 5 | 6 | 7 | import paddle.fluid as fluid 8 | from paddle.fluid.initializer import MSRA 9 | from paddle.fluid.param_attr import ParamAttr 10 | 11 | 12 | def conv_bn_layer(input, num_filters, filter_size, stride=1, groups=1, act=None, bn=True, bias_attr=False): 13 | conv = fluid.layers.conv2d( 14 | input=input, 15 | num_filters=num_filters, 16 | filter_size=filter_size, 17 | stride=stride, 18 | padding=(filter_size - 1) // 2, 19 | groups=groups, 20 | act=None, 21 | bias_attr=bias_attr, 22 | param_attr=ParamAttr(initializer=MSRA()), 23 | ) 24 | if bn == True: 25 | conv = fluid.layers.batch_norm(input=conv, act=act) 26 | return conv 27 | 28 | 29 | def conv_layer(input, num_filters, filter_size, stride=1, groups=1, act=None): 30 | conv = fluid.layers.conv2d( 31 | input=input, 32 | num_filters=num_filters, 33 | filter_size=filter_size, 34 | stride=stride, 35 | padding=(filter_size - 1) // 2, 36 | groups=groups, 37 | act=act, 38 | bias_attr=ParamAttr(initializer=MSRA()), 39 | param_attr=ParamAttr(initializer=MSRA()), 40 | ) 41 | return conv 42 | 43 | 44 | def shortcut(input, ch_out, stride): 45 | ch_in = input.shape[1] 46 | if ch_in != ch_out or stride != 1: 47 | return conv_bn_layer(input, ch_out, 1, stride) 48 | else: 49 | return input 50 | 51 | 52 | def bottleneck_block(input, num_filters, stride): 53 | conv_bn = conv_bn_layer(input=input, num_filters=num_filters, filter_size=1, act="relu") 54 | conv_bn = conv_bn_layer(input=conv_bn, num_filters=num_filters, filter_size=3, stride=stride, act=None) 55 | short_bn = shortcut(input, num_filters, stride) 56 | return fluid.layers.elementwise_add(x=short_bn, y=conv_bn, act="relu") 57 | 58 | 59 | def encoder_block(input, encoder_depths, encoder_filters, block): 60 | conv_bn = input 61 | for i in range(encoder_depths[block]): 62 | conv_bn = bottleneck_block( 63 | input=conv_bn, num_filters=encoder_filters[block], stride=2 if i == 0 and block != 0 else 1, 64 | ) 65 | print("| Encoder Block", block, conv_bn.shape) 66 | return conv_bn 67 | 68 | 69 | def decoder_block(input, concat_input, decoder_depths, decoder_filters, block): 70 | deconv_bn = input 71 | deconv_bn = fluid.layers.resize_bilinear( 72 | input=deconv_bn, out_shape=(deconv_bn.shape[2] * 2, deconv_bn.shape[3] * 2) 73 | ) 74 | deconv_bn = bottleneck_block(input=deconv_bn, num_filters=decoder_filters[block], stride=1) 75 | 76 | concat_input = conv_bn_layer( 77 | input=concat_input, num_filters=concat_input.shape[1] // 2, filter_size=1, act="relu" 78 | ) 79 | 80 | deconv_bn = fluid.layers.concat([deconv_bn, concat_input], axis=1) 81 | for i in range(decoder_depths[block]): 82 | deconv_bn = bottleneck_block(input=deconv_bn, num_filters=decoder_filters[block], stride=1) 83 | print("| Decoder Block", block, deconv_bn.shape) 84 | return deconv_bn 85 | 86 | 87 | def unet_plain(img, label_number, img_size): 88 | print("| Build Custom-Designed Resnet-Unet:") 89 | encoder_depth = [3, 4, 6, 4] 90 | encoder_filters = [64, 128, 256, 512] 91 | decoder_depth = [4, 3, 3, 2] 92 | decoder_filters = [256, 128, 64, 32] 93 | print("| Input Image Data", img.shape) 94 | """ 95 | Encoder 96 | """ 97 | # Start Conv 98 | start_conv = conv_bn_layer(input=img, num_filters=32, filter_size=3, stride=2, act="relu") 99 | start_conv = conv_bn_layer(input=start_conv, num_filters=32, filter_size=3, stride=1, act="relu") 100 | start_pool = fluid.layers.pool2d( 101 | input=start_conv, pool_size=3, pool_stride=2, pool_padding=1, pool_type="max" 102 | ) 103 | print("| Start Convolution", start_conv.shape) 104 | 105 | conv0 = encoder_block(start_pool, encoder_depth, encoder_filters, block=0) 106 | conv1 = encoder_block(conv0, encoder_depth, encoder_filters, block=1) 107 | conv2 = encoder_block(conv1, encoder_depth, encoder_filters, block=2) 108 | conv3 = encoder_block(conv2, encoder_depth, encoder_filters, block=3) 109 | 110 | """ 111 | Decoder 112 | """ 113 | decode_conv1 = decoder_block(conv3, conv2, decoder_depth, decoder_filters, block=0) 114 | decode_conv2 = decoder_block(decode_conv1, conv1, decoder_depth, decoder_filters, block=1) 115 | decode_conv3 = decoder_block(decode_conv2, conv0, decoder_depth, decoder_filters, block=2) 116 | decode_conv4 = decoder_block(decode_conv3, start_conv, decoder_depth, decoder_filters, block=3) 117 | 118 | """ 119 | Output Coder 120 | """ 121 | decode_conv5 = fluid.layers.resize_bilinear(input=decode_conv4, out_shape=img_size) 122 | decode_conv5 = bottleneck_block(input=decode_conv5, num_filters=32, stride=1) 123 | decode_conv5 = bottleneck_block(input=decode_conv5, num_filters=16, stride=1) 124 | logit = conv_layer(input=decode_conv5, num_filters=label_number, filter_size=1, act=None) 125 | print("| Output Predictions:", logit.shape) 126 | # logit = fluid.layers.resize_bilinear(input=logit, out_shape=(3384, 1020)) 127 | print("| Final Predictions:", logit.shape) 128 | 129 | return logit 130 | -------------------------------------------------------------------------------- /medseg/models/res_unet.py: -------------------------------------------------------------------------------- 1 | # Author: Jingxiao Gu 2 | # Baidu Account: Seigato 3 | # Description: Unet Simple Network for Lane Segmentation Competition 4 | 5 | import paddle.fluid as fluid 6 | from paddle.fluid.initializer import MSRA 7 | from paddle.fluid.param_attr import ParamAttr 8 | 9 | # 测试:所有的relu都换成leaky_relu试一下 10 | 11 | 12 | def conv_bn_layer(input, num_filters, filter_size, stride=1, groups=1, act=None, bn=True, bias_attr=False): 13 | conv = fluid.layers.conv2d( 14 | input=input, 15 | num_filters=num_filters, 16 | filter_size=filter_size, 17 | stride=stride, 18 | padding=(filter_size - 1) // 2, # same 19 | groups=groups, 20 | act=act, 21 | bias_attr=bias_attr, 22 | param_attr=ParamAttr(initializer=MSRA()), 23 | ) 24 | if bn == True: 25 | conv = fluid.layers.batch_norm(input=conv, act=act) 26 | return conv 27 | 28 | 29 | def conv_layer(input, num_filters, filter_size, stride=1, groups=1, act=None): 30 | conv = fluid.layers.conv2d( 31 | input=input, 32 | num_filters=num_filters, 33 | filter_size=filter_size, 34 | stride=stride, 35 | padding=(filter_size - 1) // 2, 36 | groups=groups, 37 | act=act, 38 | bias_attr=ParamAttr(initializer=MSRA()), 39 | param_attr=ParamAttr(initializer=MSRA()), 40 | ) 41 | return conv 42 | 43 | 44 | def shortcut(input, ch_out, stride): 45 | ch_in = input.shape[1] 46 | if ch_in != ch_out or stride != 1: 47 | return conv_bn_layer(input, ch_out, 1, stride) 48 | else: 49 | return input 50 | 51 | 52 | def bottleneck_block(input, num_filters, stride): 53 | conv_bn = conv_bn_layer(input=input, num_filters=num_filters, filter_size=3, act="leaky_relu") 54 | conv_bn = conv_bn_layer(input=conv_bn, num_filters=num_filters, filter_size=3, stride=stride, act=None) 55 | short_bn = shortcut(input, num_filters, stride) 56 | return fluid.layers.elementwise_add(x=short_bn, y=conv_bn, act="leaky_relu") 57 | 58 | 59 | def encoder_block(input, encoder_depths, encoder_filters, block): 60 | conv_bn = input 61 | for i in range(encoder_depths[block]): 62 | conv_bn = bottleneck_block( 63 | input=conv_bn, num_filters=encoder_filters[block], stride=2 if i == 0 and block != 0 else 1, 64 | ) 65 | print("| Encoder Block", block, conv_bn.shape) 66 | return conv_bn 67 | 68 | 69 | def decoder_block(input, concat_input, decoder_depths, decoder_filters, block): 70 | deconv_bn = input 71 | deconv_bn = fluid.layers.resize_bilinear( 72 | input=deconv_bn, out_shape=(deconv_bn.shape[2] * 2, deconv_bn.shape[3] * 2) 73 | ) 74 | deconv_bn = bottleneck_block(input=deconv_bn, num_filters=decoder_filters[block], stride=1) 75 | 76 | concat_input = conv_bn_layer( 77 | input=concat_input, num_filters=concat_input.shape[1] // 2, filter_size=1, act="leaky_relu" 78 | ) 79 | 80 | deconv_bn = fluid.layers.concat([deconv_bn, concat_input], axis=1) 81 | for i in range(decoder_depths[block]): 82 | deconv_bn = bottleneck_block(input=deconv_bn, num_filters=decoder_filters[block], stride=1) 83 | print("| Decoder Block", block, deconv_bn.shape) 84 | return deconv_bn 85 | 86 | 87 | def res_unet(img, label_number, img_size): 88 | print("| Build Custom-Designed Resnet-Unet:") 89 | encoder_depth = [3, 4, 5, 3] 90 | encoder_filters = [64, 128, 256, 512] 91 | decoder_depth = [2, 3, 3, 2] 92 | decoder_filters = [256, 128, 64, 32] 93 | print("| Input Image Data", img.shape) 94 | """ 95 | Encoder 96 | """ 97 | # Start Conv 98 | start_conv = conv_bn_layer(input=img, num_filters=32, filter_size=3, stride=2, act="leaky_relu") 99 | start_conv = conv_bn_layer(input=start_conv, num_filters=32, filter_size=3, stride=1, act="leaky_relu") 100 | start_pool = fluid.layers.pool2d( 101 | input=start_conv, pool_size=3, pool_stride=2, pool_padding=1, pool_type="max" 102 | ) 103 | print("| Start Convolution", start_conv.shape) 104 | 105 | conv0 = encoder_block(start_pool, encoder_depth, encoder_filters, block=0) 106 | conv1 = encoder_block(conv0, encoder_depth, encoder_filters, block=1) 107 | conv2 = encoder_block(conv1, encoder_depth, encoder_filters, block=2) 108 | conv3 = encoder_block(conv2, encoder_depth, encoder_filters, block=3) 109 | 110 | """ 111 | Decoder 112 | """ 113 | decode_conv1 = decoder_block(conv3, conv2, decoder_depth, decoder_filters, block=0) 114 | decode_conv2 = decoder_block(decode_conv1, conv1, decoder_depth, decoder_filters, block=1) 115 | decode_conv3 = decoder_block(decode_conv2, conv0, decoder_depth, decoder_filters, block=2) 116 | decode_conv4 = decoder_block(decode_conv3, start_conv, decoder_depth, decoder_filters, block=3) 117 | 118 | """ 119 | Output Coder 120 | """ 121 | print("+_+_+") 122 | print(decode_conv4.shape) 123 | print(img_size) 124 | print("+_+_+") 125 | decode_conv5 = fluid.layers.resize_bilinear(input=decode_conv4, out_shape=img_size) 126 | decode_conv5 = bottleneck_block(input=decode_conv5, num_filters=32, stride=1) 127 | decode_conv5 = bottleneck_block(input=decode_conv5, num_filters=16, stride=1) 128 | logit = conv_layer(input=decode_conv5, num_filters=label_number, filter_size=1, act=None) 129 | print("| Output Predictions:", logit.shape) 130 | # logit = fluid.layers.resize_bilinear(input=logit, out_shape=(3384, 1020)) 131 | print("| Final Predictions:", logit.shape) 132 | 133 | return logit 134 | -------------------------------------------------------------------------------- /tool/train/dataset_scan.py: -------------------------------------------------------------------------------- 1 | """ 2 | 扫描数据集,生成summary 3 | """ 4 | # TODO: 添加窗宽窗位 5 | # TODO: check大小是否一样 6 | 7 | import argparse 8 | import os 9 | import os.path as osp 10 | 11 | import nibabel as nib 12 | import numpy as np 13 | from matplotlib import pyplot as plt 14 | from matplotlib.font_manager import FontProperties 15 | from tqdm import tqdm 16 | 17 | import util 18 | 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument( 21 | "-s", 22 | "--scan_dir", 23 | type=str, 24 | required=True, 25 | help="扫描路径", 26 | ) 27 | parser.add_argument("-l", "--label_dir", type=str, help="标签路径") 28 | parser.add_argument("-p", "--plt_dir", type=str, help="强度分布输出路径,不写不进行绘制", default=None) 29 | parser.add_argument( 30 | "--wwwc", 31 | nargs=2, 32 | default=None, 33 | help="窗宽窗位,不写不进行窗宽窗位处理", 34 | ) 35 | parser.add_argument( 36 | "--skip", 37 | nargs=2, 38 | default=["", ""], 39 | help="在检查文件配对的时候,扫描和标签文件开头略过两个字符串,用于匹配 scan-1.nii 和 label-1.nii 这种情况", 40 | ) 41 | args = parser.parse_args() 42 | 43 | font = FontProperties(fname="../SimHei.ttf", size=16) 44 | 45 | util.check_nii_match(args.scan_dir, args.label_dir, args.skip) 46 | 47 | scans = util.listdir(args.scan_dir) 48 | labels = util.listdir(args.label_dir) 49 | assert len(scans) == len(labels), "扫描和标签数量不相等" 50 | 51 | print(f"数据集中共{len(scans)}组扫描和标签,对应情况:") 52 | for idx in range(len(scans)): 53 | print(f"{scans[idx]} \t {labels[idx]}") 54 | 55 | if not osp.exists(osp.join(args.plt_dir, "scan")): 56 | os.makedirs(osp.join(args.plt_dir, "scan")) 57 | if not osp.exists(osp.join(args.plt_dir, "label")): 58 | os.makedirs(osp.join(args.plt_dir, "label")) 59 | 60 | pixdims = [] 61 | shapes = [] 62 | norms = [] 63 | 64 | pbar = tqdm(range(len(scans)), desc="正在统计") 65 | for idx in range(len(scans)): 66 | pbar.set_postfix(filename=scans[idx].split(".")[0]) 67 | pbar.update(1) 68 | 69 | scanf = nib.load(os.path.join(args.scan_dir, scans[idx])) 70 | labelf = nib.load(os.path.join(args.label_dir, labels[idx])) 71 | 72 | header = scanf.header.structarr 73 | shape = scanf.header.get_data_shape() 74 | shapes.append([shape[0], shape[1], shape[2]]) 75 | pixdims.append(header["pixdim"][1:4]) 76 | 77 | scan = scanf.get_fdata() 78 | norms.append([scan.min(), np.median(scan), scan.max()]) 79 | 80 | if args.plt_dir: 81 | scan_plt = osp.join(args.plt_dir, "scan") 82 | label_plt = osp.join(args.plt_dir, "label") 83 | 84 | scan = scan.reshape([scan.size]) 85 | 86 | plt.title(scans[idx].split(".")[0], fontproperties=font) 87 | plt.xlabel( 88 | "size:[{},{},{}] pixdims:[{},{},{}] ".format( 89 | shape[0], 90 | shape[1], 91 | shape[2], 92 | header["pixdim"][1], 93 | header["pixdim"][2], 94 | header["pixdim"][3], 95 | ) 96 | ) 97 | nums, bins, patchs = plt.hist(scan, bins=1000) 98 | plt.savefig(osp.join(scan_plt, scans[idx].split(".")[0] + ".png")) 99 | plt.close() 100 | 101 | file = open(osp.join(scan_plt, f"{scans[idx].split('.')[0]}.txt"), "w") 102 | print("--------- {} --------".format(scans[idx]), file=file) 103 | 104 | sum = 0 105 | for num in nums: 106 | sum += num 107 | nowsum = 0 108 | for i in range(0, len(nums)): 109 | nowsum += nums[i] 110 | print( 111 | "[{:<10f},{:<10f}] : {:>10} percentage : {}".format( 112 | bins[i], bins[i + 1], nums[i], nowsum / sum 113 | ), 114 | file=file, 115 | ) 116 | file.close() 117 | 118 | label = labelf.get_fdata() 119 | label = np.reshape(label, [label.size]) 120 | plt.title( 121 | f"{scans[idx].split('.')[0]} [{np.min(label)},{np.max(label)}]", 122 | fontproperties=font, 123 | ) 124 | plt.xlabel( 125 | "size:[{},{},{}] pixdims:[{},{},{}] ".format( 126 | shape[0], 127 | shape[1], 128 | shape[2], 129 | header["pixdim"][1], 130 | header["pixdim"][2], 131 | header["pixdim"][3], 132 | ) 133 | ) 134 | nums, bins, patchs = plt.hist(label, bins=5) 135 | plt.savefig(osp.join(label_plt, scans[idx].rstrip(".nii") + ".png")) 136 | plt.close() 137 | 138 | file = open( 139 | os.path.join(label_plt, "{}.txt".format(scans[idx].split(".")[0])), 140 | "w", 141 | ) 142 | print("--------- {} --------".format(scans[idx]), file=file) 143 | 144 | sum = 0 145 | for num in nums: 146 | sum += num 147 | nowsum = 0 148 | for i in range(0, len(nums)): 149 | nowsum += nums[i] 150 | print( 151 | "[{:<10f},{:<10f}] : {:>10} percentage : {}".format( 152 | bins[i], bins[i + 1], nums[i], nowsum / sum 153 | ), 154 | file=file, 155 | ) 156 | 157 | pbar.close() 158 | 159 | spacing = np.median(pixdims, axis=0) 160 | size = np.median(shapes, axis=0) 161 | 162 | print(norms) 163 | norm = [] 164 | norms = np.array(norms) 165 | norm.append(np.min(norms[:, 0])) 166 | norm.append(np.median(norms[:, 1])) 167 | norm.append(np.max(norms[:, 2])) 168 | 169 | print(spacing, size, norm) 170 | 171 | file = open("./summary.txt", "w") 172 | 173 | print("spacing", file=file) 174 | for dat in spacing: 175 | print(dat, file=file) 176 | 177 | print("\nsize", file=file) 178 | for dat in size: 179 | print(dat, file=file) 180 | 181 | print("\nnorm", file=file) 182 | for dat in norm: 183 | print(dat, file=file) 184 | file.close() 185 | -------------------------------------------------------------------------------- /medseg/infer.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | import time 4 | import os 5 | from multiprocessing import Process, Queue 6 | 7 | from tqdm import tqdm 8 | import numpy as np 9 | import cv2 10 | import nibabel as nib 11 | import paddle 12 | from paddle import fluid 13 | 14 | import utils.util as util 15 | from utils.config import cfg 16 | 17 | 18 | def parse_args(): 19 | parser = argparse.ArgumentParser(description="预测") 20 | parser.add_argument("-c", "--cfg_file", type=str, help="配置文件路径") 21 | parser.add_argument("--use_gpu", action="store_true", default=False, help="使用GPU推理") 22 | parser.add_argument("opts", nargs=argparse.REMAINDER) 23 | args = parser.parse_args() 24 | 25 | if args.cfg_file is not None: 26 | cfg.update_from_file(args.cfg_file) 27 | if args.opts: 28 | cfg.update_from_list(args.opts) 29 | if args.use_gpu: # 命令行参数只能从false改成true,不能声明false 30 | cfg.TRAIN.USE_GPU = True 31 | 32 | cfg.set_immutable(True) 33 | 34 | 35 | def read_data(file_path, q): 36 | """读取数据并进行预处理. 37 | 38 | Parameters 39 | ---------- 40 | file_path : str 41 | 需要读取的数据文件名. 42 | q : queue 43 | 读入的数据放入这个队列. 44 | 45 | """ 46 | print("Start Reading: ", file_path) 47 | volf = nib.load(file_path) 48 | volume = np.array(volf.get_fdata()) 49 | 50 | if cfg.INFER.WINDOWLIZE: 51 | volume = util.windowlize_image(volume, cfg.INFER.WWWC) 52 | if cfg.INFER.DO_INTERP: 53 | header = volf.header.structarr 54 | # pixdim 是这套 ct 三个维度的间距 55 | pixdim = [header["pixdim"][ind] for ind in range(1, 4)] 56 | spacing = list(cfg.INFER.SPACING) 57 | for ind in range(3): 58 | if spacing[ind] == -1: 59 | spacing[ind] = pixdim[ind] 60 | ratio = [pixdim[0] / spacing[0], pixdim[1] / spacing[1], pixdim[2] / spacing[2]] 61 | volume = scipy.ndimage.interpolation.zoom(volume, ratio, order=3) 62 | q.put([volume, volf.affine]) 63 | print("Finish Reading: ", file_path) 64 | 65 | 66 | def post_process(fpath, inference, affine): 67 | """这个函数对数据进行后处理和存盘,防止GPU空等. 68 | 69 | Parameters 70 | ---------- 71 | fpath : str 72 | 要保存的文件名称. 73 | inference : ndarray 74 | 输出的分割标签数组. 75 | """ 76 | print("Start Postprocess: ", fpath) 77 | inference[inference >= cfg.INFER.THRESH] = 1 78 | inference[inference < cfg.INFER.THRESH] = 0 79 | if cfg.INFER.FILTER_LARGES: 80 | inference = util.filter_largest_volume(inference, 1, "soft") 81 | if cfg.INFER.DO_INTERP: 82 | ratio = [1 / x for x in ratio] 83 | inference = scipy.ndimage.interpolation.zoom(inference, ratio, order=3) 84 | 85 | inference = inference.astype("int8") 86 | inference_file = nib.Nifti1Image(inference, affine) 87 | inferece_path = fpath 88 | nib.save(inference_file, inferece_path) 89 | print("Finish POstprocess: ", fpath) 90 | 91 | 92 | def main(): 93 | places = fluid.CUDAPlace(0) if cfg.TRAIN.USE_GPU else fluid.CPUPlace() 94 | exe = fluid.Executor(places) 95 | 96 | infer_exe = fluid.Executor(places) 97 | inference_scope = fluid.core.Scope() 98 | vol_queue = Queue() 99 | 100 | if not os.path.exists(cfg.INFER.PATH.OUTPUT): 101 | os.makedirs(cfg.INFER.PATH.OUTPUT) 102 | 103 | with fluid.scope_guard(inference_scope): 104 | [inference_program, feed_target_names, fetch_targets] = fluid.io.load_inference_model( 105 | cfg.INFER.PATH.PARAM, infer_exe 106 | ) 107 | 108 | inf_volumes = os.listdir(cfg.INFER.PATH.INPUT) 109 | for vol_ind in tqdm(range(len(inf_volumes)), position=0): 110 | inf_volume = inf_volumes[vol_ind] 111 | if vol_queue.empty(): 112 | read_data(os.path.join(cfg.INFER.PATH.INPUT, inf_volumes[vol_ind]), vol_queue) 113 | 114 | volume, affine = vol_queue.get() 115 | print(volume.shape) 116 | 117 | # 这里异步调一个读取数据,必须在get后面,队列是有锁的 118 | if vol_ind != len(inf_volumes): 119 | p = Process( 120 | target=read_data, 121 | args=(os.path.join(cfg.INFER.PATH.INPUT, inf_volumes[vol_ind + 1]), vol_queue), 122 | ) 123 | p.start() 124 | # read_data(inf_path, vol_queue) 125 | 126 | inference = np.zeros(volume.shape) 127 | 128 | batch_size = cfg.INFER.BATCH_SIZE 129 | ind = 0 130 | flag = True 131 | pbar = tqdm(total=volume.shape[2] - 2, position=1, leave=False) 132 | pbar.set_postfix(file=inf_volume) 133 | while flag: 134 | batch_data = [] 135 | for j in range(0, batch_size): 136 | ind = ind + 1 137 | pbar.update(1) 138 | data = volume[:, :, ind - 1 : ind + 2] 139 | data = data.swapaxes(0, 2).reshape([3, data.shape[1], data.shape[0]]).astype("float32") 140 | batch_data.append(data) 141 | 142 | if ind == volume.shape[2] - 2: 143 | flag = False 144 | pbar.refresh() 145 | break 146 | batch_data = np.array(batch_data) 147 | 148 | result = infer_exe.run( 149 | inference_program, feed={feed_target_names[0]: batch_data}, fetch_list=fetch_targets, 150 | ) 151 | 152 | result = np.array(result) 153 | result = result.reshape([-1, 2, 512, 512]) 154 | 155 | ii = ind 156 | for j in range(result.shape[0] - 1, -1, -1): 157 | resp = result[j, 1, :, :].reshape([512, 512]) 158 | resp = resp.swapaxes(0, 1) 159 | inference[:, :, ii] = resp 160 | ii = ii - 1 161 | 162 | # 这里调用多进程后处理和存储 163 | p = Process( 164 | target=post_process, 165 | args=( 166 | os.path.join(cfg.INFER.PATH.OUTPUT, inf_volume).replace("volume", "segmentation"), 167 | inference, 168 | affine, 169 | ), 170 | ) 171 | p.start() 172 | # post_process( 173 | # os.path.join(cfg.INFER.PATH.OUTPUT, inf_volume).replace("volume", "segmentation"), 174 | # inference, 175 | # affine, 176 | # ) 177 | pbar.close() 178 | 179 | 180 | if __name__ == "__main__": 181 | parse_args() 182 | main() 183 | -------------------------------------------------------------------------------- /medseg/vis.py: -------------------------------------------------------------------------------- 1 | """ 2 | 对内存中的ndarray,npz,nii进行可视化 3 | """ 4 | import sys 5 | import os 6 | import argparse 7 | import time 8 | 9 | import SimpleITK as sitk 10 | import nibabel as nib 11 | import numpy as np 12 | import matplotlib.pyplot as plt 13 | 14 | from utils.config import cfg 15 | import utils.util as util 16 | import train 17 | 18 | 19 | def parse_args(): 20 | parser = argparse.ArgumentParser(description="数据预处理") 21 | parser.add_argument("-c", "--cfg_file", type=str, help="配置文件路径") 22 | parser.add_argument("opts", nargs=argparse.REMAINDER) 23 | args = parser.parse_args() 24 | 25 | if args.cfg_file is not None: 26 | cfg.update_from_file(args.cfg_file) 27 | if args.opts: 28 | cfg.update_from_list(args.opts) 29 | 30 | 31 | def show_slice(vol, lab): 32 | """展示一个2.5D的数据对. 33 | 34 | Parameters 35 | ---------- 36 | vol : ndarray 37 | 2.5D的扫描slice. 38 | lab : ndarray 39 | 1片分割标签. 40 | """ 41 | if vol.shape[0] <= 3: # CWH 需要转换WHC 42 | vol = vol.swapaxes(0, 2) 43 | lab = lab.swapaxes(0, 2) 44 | if lab.ndim == 2: 45 | lab = lab[:, :, np.newaxis] 46 | if len(vol.shape) == 3 and vol.shape[2] == 3: # 如果输入是3 channel的取中间一片 47 | vol = vol[:, :, 1] 48 | if len(vol.shape) == 2: 49 | vol = vol[:, :, np.newaxis] 50 | 51 | vol = np.tile(vol, (1, 1, 3)) 52 | lab = np.tile(lab, (1, 1, 3)) 53 | print("vis shape", vol.shape, lab.shape) 54 | vmax = vol.max() 55 | vmin = vol.min() 56 | vol = (vol - vmin) / (vmax - vmin) * 255 57 | lab = lab * 255 58 | 59 | vol = vol.astype("uint8") 60 | lab = lab.astype("uint8") 61 | 62 | plt.figure(figsize=(15, 15)) 63 | plt.subplot(121) 64 | plt.imshow(vol) 65 | plt.subplot(122) 66 | plt.imshow(lab) 67 | plt.show() 68 | plt.close() 69 | 70 | 71 | def show_nii(): 72 | scans = util.listdir(cfg.DATA.INPUTS_PATH) 73 | labels = util.listdir(cfg.DATA.LABELS_PATH) 74 | records = [] 75 | for ind in range(len(scans)): 76 | # for ind in range(3): 77 | print(scans[ind], labels[ind]) 78 | 79 | scanf = nib.load(os.path.join(cfg.DATA.INPUTS_PATH, scans[ind])) 80 | labelf = nib.load(os.path.join(cfg.DATA.LABELS_PATH, labels[ind])) 81 | scan = scanf.get_fdata() 82 | scan = util.windowlize_image(scan, cfg.PREP.WWWC) 83 | label = labelf.get_fdata() 84 | scan, label = util.cal_direction(scans[ind], scan, label) 85 | print(scan.shape) 86 | print(label.shape) 87 | sli_ind = int(scan.shape[2] / 6) 88 | # for sli_ind in range(vol.shape[2]): 89 | show_slice(scan[:, :, sli_ind * 2], label[:, :, sli_ind * 2]) 90 | show_slice(scan[:, :, sli_ind * 3], label[:, :, sli_ind * 3]) 91 | show_slice(scan[:, :, sli_ind * 4], label[:, :, sli_ind * 4]) 92 | t = input("是否左右翻转: ") 93 | records.append([scans[ind], t]) 94 | time.sleep(1) 95 | print(records) 96 | 97 | f = open("./flip.csv", "w") 98 | for record in records: 99 | print(record[0] + "," + record[1], file=f) 100 | f.close() 101 | # 1 rot 1 次,2 rot 3 次 102 | # 0 左右不 flip, 1 左右 flip 103 | 104 | 105 | def show_npz(): 106 | """对训练数据npz进行可视化. 107 | 108 | """ 109 | for npz in os.listdir(cfg.TRAIN.DATA_PATH): 110 | data = np.load(os.path.join(cfg.TRAIN.DATA_PATH, npz)) 111 | vol = data["imgs"] 112 | lab = data["labs"] 113 | # for ind in range(vol.shape[0]): 114 | # show_slice(vol[ind], lab[ind]) 115 | show_slice(vol[0], lab[0]) 116 | show_slice(vol[vol.shape[0] - 1], lab[vol.shape[0] - 1]) 117 | 118 | 119 | def show_aug(): 120 | """在读取npz基础上做aug之后展示. 121 | 122 | """ 123 | for npz in os.listdir(cfg.TRAIN.DATA_PATH): 124 | data = np.load(os.path.join(cfg.TRAIN.DATA_PATH, npz)) 125 | vol = data["imgs"] 126 | lab = data["labs"] 127 | vol = vol.astype("float32") 128 | lab = lab.astype("int32") 129 | if cfg.AUG.WINDOWLIZE: 130 | vol = util.windowlize_image(vol, cfg.AUG.WWWC) # 肝脏常用 131 | # for ind in range(vol.shape[0]): 132 | vol_slice, lab_slice = train.aug_mapper([vol[0], lab[0]]) 133 | show_slice(vol_slice, lab_slice) 134 | 135 | vol_slice, lab_slice = train.aug_mapper([vol[vol.shape[0] - 1], lab[vol.shape[0] - 1]]) 136 | show_slice(vol_slice, lab_slice) 137 | 138 | 139 | if __name__ == "__main__": 140 | parse_args() 141 | show_nii() 142 | # show_npz() 143 | # show_aug() 144 | 145 | 146 | # import os 147 | # import matplotlib.pyplot as plt 148 | # from nibabel.orientations import aff2axcodes 149 | 150 | # vols = "/home/aistudio/data/volume" 151 | # for voln in os.listdir(vols): 152 | # print("--------") 153 | # print(voln) 154 | # 155 | # volf = sitk.ReadImage(os.path.join(vols, voln)) 156 | # 157 | # vold = sitk.GetArrayFromImage(volf) 158 | # print(vold.shape) 159 | # vold[500:512, 250:260, 0] = 2048 160 | # 161 | # plt.imshow(vold[0, :, :]) 162 | # plt.show() 163 | 164 | # vols = "/home/aistudio/data/volume" 165 | # directions = [] 166 | # for voln in os.listdir(vols): 167 | # print("--------") 168 | # print(voln) 169 | # 170 | # volf = nib.load(os.path.join(vols, voln)) 171 | # print(volf.affine) 172 | # print("codes", aff2axcodes(volf.affine)) 173 | # 174 | # vold = volf.get_fdata() 175 | # vold[500:512, 250:260, 0] = 2048 176 | # 177 | # plt.imshow(vold[:, :, 0]) 178 | # plt.show() 179 | # 180 | # cmd = input("direction: ") 181 | # if cmd == "a": 182 | # dir = [voln, 1] # 床在左边 183 | # else: 184 | # dir = [voln, 2] # 床在右边 185 | # directions.append(dir) 186 | # 187 | # 188 | # f = open("./directions.csv", "w") 189 | # for dir in directions: 190 | # print(dir[0], ",", dir[1], file=f) 191 | # f.close() 192 | 193 | 194 | # vols = "/home/aistudio/data/volume" 195 | # 196 | # 197 | # # 获取体位信息 198 | # f = open("./directions.csv") 199 | # dirs = f.readlines() 200 | # print("dirs: ", dirs) 201 | # dirs = [x.rstrip("\n") for x in dirs] 202 | # dirs = [x.split(",") for x in dirs] 203 | # dic = {} 204 | # for dir in dirs: 205 | # dic[dir[0].strip()] = dir[1].strip() 206 | # f.close() 207 | # 208 | # print(dic) 209 | # 210 | # 211 | # for voln in os.listdir(vols): 212 | # print("--------") 213 | # print(voln) 214 | # 215 | # volf = nib.load(os.path.join(vols, voln)) 216 | # vold = volf.get_fdata() 217 | # print(dic[voln]) 218 | # if dic[voln] == "2": 219 | # vold = np.rot90(vold, 3) 220 | # else: 221 | # vold = np.rot90(vold, 1) 222 | # 223 | # plt.imshow(vold[:, :, 50]) 224 | # plt.show() 225 | -------------------------------------------------------------------------------- /tool/zip_dataset.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | """ 3 | 将一个目录下的所有文件和文件夹按照原来的文件结构打包,每个包不超过指定大小 4 | """ 5 | 6 | # TODO 测试压缩后删除原文件功能 7 | # TODO: 修改成按照zip文件的大小显示进度条 8 | 9 | import zipfile 10 | import os 11 | from tqdm import tqdm 12 | import platform 13 | import argparse 14 | import logging 15 | 16 | 17 | def get_args(): 18 | parser = argparse.ArgumentParser("zip_dataset") 19 | parser.add_argument( 20 | "-i", "--dataset_dir", type=str, required=True, help="[必填] 需要压缩的数据集路径,所有文件所在的文件夹" 21 | ) 22 | parser.add_argument( 23 | "-o", 24 | "--zip_dir", 25 | type=str, 26 | required=True, 27 | help="[必填] 压缩后的压缩包保存路径,如果有条件可以和待压缩数据放到不同硬件上,能加快一点速度", 28 | ) 29 | parser.add_argument("--size", type=float, default=10.0, help="[可选] 压缩文件过程中每个包不超过这个大小,G为单位") 30 | parser.add_argument( 31 | "-m", "--method", type=str, default="zip", help="[可选] 压缩方法,可选的有:store(只打包不压缩),zip,bz2,lzma", 32 | ) 33 | parser.add_argument( 34 | "-v", "--verbos", action="store_true", default=False, help="[可选] 执行过程中显示详细信息" 35 | ) 36 | parser.add_argument("--debug", action="store_true", default=False, help="[可选] 执行过程中显示详细信息") 37 | parser.add_argument( 38 | "-d", 39 | "--delete", 40 | action="store_true", 41 | default=False, 42 | help="[慎用] 在压缩完文件后删除对应的原文件,如果盘上空间不够压完就删掉原文件不会炸盘。!!慎用此功能!!", 43 | ) 44 | 45 | return parser.parse_args() 46 | 47 | 48 | def do_zip(args): 49 | # 1. 参数校验和设置 50 | if args.verbos: 51 | level = logging.INFO 52 | elif args.debug: 53 | level = logging.DEBUG 54 | else: 55 | level = logging.CRITICAL 56 | logging.basicConfig(format="%(asctime)s : %(message)s", level=level) 57 | 58 | # TODO: 不同压缩方法不用后缀 59 | methods = { 60 | "zip": zipfile.ZIP_DEFLATED, 61 | "store": zipfile.ZIP_STORED, 62 | "bz2": zipfile.ZIP_BZIP2, 63 | "lzma": zipfile.ZIP_LZMA, 64 | } 65 | try: 66 | mode = methods[args.method] 67 | except KeyError: 68 | raise RuntimeError("mode 参数 {} 不合法".format(mode)) 69 | 70 | # if args.method == "zip": 71 | # mode = zipfile.ZIP_DEFLATED 72 | # elif args.method == "store": 73 | # mode = zipfile.ZIP_STORED 74 | # elif args.method == "bz2": 75 | # mode = zipfile.ZIP_BZIP2 76 | # elif argms.mode == "lzma": 77 | # mode = zipfile.ZIP_LZMA 78 | # else: 79 | # raise RuntimeError("mode参数 {} 不合法".format(mode)) 80 | 81 | # 2. 确认路径和数据集名 82 | print("\n", os.listdir(args.dataset_dir)[:10], "\n") 83 | print("以上是您指定的待压缩路径 {} 下的前10个文件(夹),请确定该路径是否正确".format(args.dataset_dir)) 84 | cmd = input("如果 是 请输入 y/Y ,按其他任意键退出执行: ") 85 | if cmd != "y" and cmd != "Y": 86 | logging.error("用户退出执行") 87 | exit(0) 88 | 89 | if not os.path.exists(args.zip_dir): # 如果zip输出路径不存在创建它 90 | logging.info("创建zip文件夹: {}".format(args.zip_dir)) 91 | os.makedirs(args.zip_dir) 92 | 93 | dataset_name = os.path.basename(args.dataset_dir.rstrip("\\").rstrip("/")) # 文件夹名做数据集名 94 | logging.info("默认数据集名称为: {}".format(dataset_name)) 95 | print("默认数据集名称为: {}".format(dataset_name)) 96 | cmd = input("确认使用该名称请输入 y/Y,如想使用其他名称请输入: ") 97 | if cmd != "y" and cmd != "Y": 98 | dataset_name = cmd 99 | logging.info("最终使用数据集名称为: {}".format(dataset_name)) 100 | 101 | # 3. 制作当前压缩包名,创建压缩包文件 102 | zip_num = 1 103 | curr_name = "{}-{}.zip".format(dataset_name, zip_num) 104 | curr_zip_path = os.path.join(args.zip_dir, curr_name) 105 | f = zipfile.ZipFile(curr_zip_path, "a", mode) 106 | 107 | files_list = [] # 用来存储待压缩的文件路径和在压缩包中的路径,存到一定数量之后一起往包里压,避免频繁查看当前压缩包大小 108 | list_size = 0 # 当前 files_list 中文件总大小, 单位B 109 | zip_tot_size = args.size * 1000 ** 3 # 每个压缩包不超过这个大小,单位B。因为有的设备上是按照1k换算的,所以保险用1B*1000^3做1G 110 | zip_left_size = zip_tot_size # 当前压缩包离最大大小还有多少 111 | 112 | """ 113 | 压缩的整体策略是 114 | 1. 将文件路径添加进 files_list,直到列表中再添加一个文件就会超过压缩包离最大限制的空间 zip_left_size。 115 | 因为files_list文件的文件计算的是没压缩的大小,所以这些文件都加进压缩包大小一定不会超过限制。 116 | 2. 将 files_list 中的文件都加入压缩包,检查当前压缩包的大小,决定是否开新压缩包。 之后继续制作 files_list 117 | 3. 经过多次 2 的添加压缩包大小会接近最大限制,开新压缩包的条件是当前压缩包的大小加上当前准备压缩的文件大小超过了最大限制。 118 | 这里可能会有一点浪费,但是这个文件没实际压进包没法知道是不是会超限制,所以就直接保守开新包了。这里在压缩率比较高包还很小的时候可能会不停的开新包。 119 | """ 120 | for dirpath, dirnames, filenames in os.walk(os.path.join(args.dataset_dir)): 121 | for filename in filenames: 122 | # 获取当前文件大小,判断列表中加入这个文件是否超过限制 123 | curr_file_size = os.path.getsize(os.path.join(dirpath, filename)) 124 | logging.debug("Name: {}, Size: {}".format(filename, curr_file_size / 1024 ** 2)) 125 | list_size += curr_file_size 126 | 127 | if list_size >= zip_left_size: # 如果当前列表中未压缩文件的大小大于zip包能装的大小,那么开始压包 128 | logging.info("当前列表中文件大小是: {} M ".format(list_size / 1024 ** 2)) 129 | logging.info("当前压缩包剩余大小: {} M".format(zip_left_size / 1024 ** 2)) 130 | 131 | logging.critical("正在将 {} 个文件写入压缩包".format(len(files_list))) 132 | logging.debug("前三个文件是: {}".format(str(files_list[:3]))) 133 | logging.debug("最后三个文件是: {}".format(str(files_list[-3:]))) 134 | # 将列表里所有的文件写入zip 135 | for pair in tqdm(files_list, ascii=True): 136 | f.write(pair[0], pair[1]) 137 | if args.delete: 138 | os.remove(pair[0]) 139 | 140 | files_list = [] # 写入完成,清空列表 141 | # 循环中这个pass的文件是在if之后加入列表的,所以列表文件的大小直接就是这个pass文件的大小,下面就添加了 142 | list_size = curr_file_size 143 | 144 | curr_zip_size = os.path.getsize(curr_zip_path) 145 | logging.info("当前压缩包的大小是: {} M\n".format(curr_zip_size / 1024 ** 2)) 146 | 147 | if curr_zip_size + curr_file_size > zip_tot_size: # 如果加入这个文件压缩包就超大小限制了就开新的压缩包 148 | f.close() 149 | zip_num += 1 150 | curr_name = "{}-{}.zip".format(dataset_name, zip_num) 151 | curr_zip_path = os.path.join(args.zip_dir, curr_name) 152 | f = zipfile.ZipFile(curr_zip_path, "a", mode) 153 | logging.critical("\n\n\n正创建新的压缩包: {} ".format(curr_name)) 154 | zip_left_size = zip_tot_size 155 | else: 156 | zip_left_size = zip_tot_size - curr_zip_size 157 | 158 | # 第一个是文件路径,第二个是压缩包中的路径,压缩包中保存原来文件夹的结构 159 | files_list.append( 160 | [ 161 | os.path.join(dirpath, filename), 162 | os.path.join(dataset_name, dirpath[len(args.dataset_dir) + 1 :], filename), 163 | ] 164 | ) 165 | 166 | # 最后一个压缩包一般都不会到限制的大小触发写入,剩下的所有文件写入最后一个压缩包 167 | if len(files_list) != 0: 168 | logging.critical("正在将 {} 个文件写入最后一个压缩包".format(len(files_list))) 169 | for pair in tqdm(files_list, ascii=True): 170 | f.write(pair[0], pair[1]) 171 | files_list = [] 172 | list_size = 0 173 | f.close() 174 | 175 | logging.critical("压缩结束,共 {} 个压缩包".format(zip_num)) 176 | 177 | 178 | if __name__ == "__main__": 179 | args = get_args() 180 | do_zip(args) 181 | -------------------------------------------------------------------------------- /medseg/aug.py: -------------------------------------------------------------------------------- 1 | # 对数据进行增强 2 | # 1. 需要能处理numpy数组,不只是图片 3 | # 2. 要能处理2d和3d的情况 4 | # 3. 所有的操作执行不执行要概率控制 5 | # 4. 首先实现一个能用的版本,逐步实现都用基础矩阵操作不调包 6 | # 5. 所有的函数的默认参数都是调用不做任何变化 7 | 8 | import random 9 | import math 10 | import time 11 | 12 | import cv2 13 | import numpy as np 14 | from scipy.ndimage.interpolation import map_coordinates 15 | from scipy.ndimage.filters import gaussian_filter 16 | import scipy.ndimage 17 | import matplotlib.pyplot as plt 18 | import skimage.io 19 | 20 | from utils.util import pad_volume 21 | 22 | random.seed(time.time()) 23 | 24 | 25 | def flip(volume, label=None, chance=(0, 0, 0)): 26 | """. 27 | 28 | Parameters 29 | ---------- 30 | volume : type 31 | Description of parameter `volume`. 32 | label : type 33 | Description of parameter `label`. 34 | chance : type 35 | Description of parameter `chance`. 36 | 37 | Returns 38 | ------- 39 | flip(volume, label=None, 40 | Description of returned object. 41 | 42 | """ 43 | if random.random() < chance[0]: 44 | volume = volume[::-1, :, :] 45 | if label is not None: 46 | label = label[::-1, :, :] 47 | if random.random() < chance[1]: 48 | volume = volume[:, ::-1, :] 49 | if label is not None: 50 | label = label[:, ::-1, :] 51 | if random.random() < chance[2]: 52 | volume = volume[:, :, ::-1] 53 | if label is not None: 54 | label = label[:, :, ::-1] 55 | 56 | if label is not None: 57 | return volume, label 58 | return volume 59 | 60 | 61 | # x,y,z 任意角度旋转,背景填充,mirror,0,extend 62 | # https://stackoverflow.com/questions/20161175/how-can-i-use-scipy-ndimage-interpolation-affine-transform-to-rotate-an-image-ab 需要研究 63 | def rotate(volume, label=None, angel=([0, 0], [0, 0], [0, 0]), chance=(0, 0, 0), cval=0): 64 | """ 按照指定象限旋转 65 | angel:是角度不是弧度 66 | """ 67 | 68 | for axes in range(3): 69 | if random.random() < chance[axes]: 70 | rand_ang = angel[axes][0] + random.random() * (angel[axes][1] - angel[axes][0]) 71 | volume = scipy.ndimage.rotate( 72 | volume, 73 | rand_ang, 74 | axes=(axes, (axes + 1) % 3), 75 | reshape=False, 76 | mode="constant", 77 | cval=cval, 78 | ) 79 | if label is not None: 80 | label = scipy.ndimage.rotate( 81 | label, rand_ang, axes=(axes, (axes + 1) % 3), reshape=False, 82 | ) 83 | if label is not None: 84 | return volume, label 85 | return volume 86 | 87 | 88 | # 缩放大小, vol 是三阶, lab 是插值, 给的是目标大小 89 | def zoom(volume, label=None, ratio=[(1, 1), (1, 1), (1, 1)], chance=(0, 0, 0)): 90 | ratio = list(ratio) 91 | chance = list(chance) 92 | for axes in range(3): 93 | if random.random() < chance[axes]: # 如果随机超过做zoom的概率,那就是不做缩放 94 | ratio[axes] = ratio[axes][0] + random.random() * (ratio[axes][1] - ratio[axes][0]) 95 | else: 96 | ratio[axes] = 1 97 | volume = scipy.ndimage.zoom(volume, ratio, order=3, mode="constant") 98 | if label is not None: 99 | label = scipy.ndimage.zoom(label, ratio, order=3, mode="constant") 100 | return volume, label 101 | return volume 102 | 103 | 104 | def crop(volume, label=None, size=[3, 512, 512], pad_value=0): 105 | """在随机位置裁剪出一个指定大小的体积 106 | 每个维度都有输入图片更大或者size更大两种情况: 107 | - 如果输入图片更大,保证不会裁剪出图片,位置随机; 108 | - 如果size更大,只进行pad操作,体积在正中间. 109 | 对于标签,标签中是1的维度不会进行pad;不是1的和volume都一样 110 | Parameters 111 | ---------- 112 | volume : np.ndarray 113 | Description of parameter `volume`. 114 | label : np.ndarray 115 | Description of parameter `label`. 116 | size : 需要裁出的体积,list 117 | Description of parameter `size`. 118 | pad_value : int 119 | volume用pad_value填充,标签默认用0填充. 120 | 121 | Returns 122 | ------- 123 | type 124 | size大小的ndarray. 125 | 126 | """ 127 | # 1. 先pad一手,让数据至少size大 128 | volume = pad_volume(volume, size, pad_value, False) 129 | if label is not None: 130 | lab_size = list(size) 131 | for ind, s in enumerate(label.shape): 132 | if s == 1: # 是1的维度都不动 133 | lab_size[ind] = -1 134 | label = pad_volume(label, lab_size, 0, False) 135 | # 2.随机一个裁剪范围起点,之后进行crop裁剪 136 | crop_low = [int(random.random() * (x - y)) for x, y in zip(volume.shape, size)] 137 | r = [[l, l + s] for l, s in zip(crop_low, size)] 138 | volume = volume[r[0][0] : r[0][1], r[1][0] : r[1][1], r[2][0] : r[2][1]] 139 | if label is not None: 140 | for ind in range(3): 141 | if label.shape[ind] == 1: 142 | r[ind][0] = 0 143 | r[ind][1] = 1 144 | label = label[r[0][0] : r[0][1], r[1][0] : r[1][1], r[2][0] : r[2][1]] 145 | return volume, label 146 | return volume 147 | 148 | 149 | def elastic_transform(image, alpha, sigma, random_state=None): 150 | if random_state is None: 151 | random_state = np.random.RandomState(None) 152 | shape = image.shape 153 | print("randomstate", random_state.rand(*shape) * 2 - 1, "end") 154 | dx = ( 155 | gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0) * alpha 156 | ) 157 | print(dx.shape) 158 | dy = ( 159 | gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0) * alpha 160 | ) 161 | dz = np.zeros_like(dx) 162 | 163 | x, y, z = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), np.arange(shape[2])) 164 | indices = np.reshape(y + dy, (-1, 1)), np.reshape(x + dx, (-1, 1)), np.reshape(z, (-1, 1)) 165 | 166 | distored_image = map_coordinates(image, indices, order=1, mode="reflect") 167 | return distored_image.reshape(image.shape) 168 | 169 | 170 | # TODO: 增加随机shift的增强,这个针对显影剂,随机给整个输入加上一个值,肝脏是-80到0的正态 171 | # TODO: 增加随机噪音的增强 172 | # TODO: 增加图片随机组合的增强 173 | # TODO: 全部流程放进一个函数 174 | 175 | # cat = skimage.io.imread("~/Desktop/cat.jpg") 176 | # checker = skimage.io.imread("~/Desktop/checker.png") 177 | # img = np.ones([10, 10]) 178 | # 179 | # img[0:5, 0:5] = 0 180 | # 181 | # 182 | # plt.imshow(cat) 183 | # plt.show() 184 | # if img.ndim == 2: 185 | # img = img[:, :, np.newaxis] 186 | 187 | # print(img.shape) 188 | # img, lab = flip(cat, cat, (1, 1, 0)) 189 | # plt.imshow(img) 190 | # plt.show() 191 | # 192 | # img, lab = rotate(cat, cat, ([-45, 45], 0, [0, 0]), (1, 0, 0)) 193 | # plt.imshow(img) 194 | # plt.show() 195 | # 196 | # img, lab = zoom(cat, cat, [(0.2, 0.3), (0.7, 0.8), (0.9, 1)], (0.5, 1, 0)) 197 | # plt.imshow(img) 198 | # plt.show() 199 | 200 | # print(cat.shape) 201 | # img, label = crop(cat, cat, [400, 500, 3]) 202 | # plt.imshow(img) 203 | # plt.show() 204 | # 205 | # plt.imshow(checker) 206 | # plt.show() 207 | # checker = checker[:, :, np.newaxis] 208 | # img = crop(cat, None, [512, 512, 3]) 209 | # img = elastic_transform(checker, 900, 8) 210 | # img = img.reshape(img.shape[0], img.shape[1]) 211 | # plt.imshow(img) 212 | # plt.show() 213 | 214 | # 215 | # print(img) 216 | # if img.shape[2] == 1: 217 | # img = img.reshape(img.shape[0], img.shape[1]) 218 | # plt.imshow(img) 219 | # plt.show() 220 | 221 | # plt.imshow(lab) 222 | # plt.show() 223 | 224 | 225 | """ 226 | paddle clas 增广策略 227 | 228 | 图像变换类: 229 | 旋转 230 | 色调 231 | 背景模糊 232 | 透明度变换 233 | 饱和度变换 234 | 235 | 图像裁剪类: 236 | 遮挡 237 | 238 | 图像混叠: 239 | 两幅图一定权重直接叠 240 | 一张图切一部分放到另一张图 241 | """ 242 | -------------------------------------------------------------------------------- /medseg/prep_3d.py: -------------------------------------------------------------------------------- 1 | # encoding=utf-8 2 | 3 | 4 | import os 5 | 6 | 7 | import numpy as np 8 | import nibabel as nib 9 | from tqdm import tqdm 10 | import scipy 11 | import matplotlib.pyplot as plt 12 | 13 | import utils.util as util 14 | from utils.config import cfg 15 | import utils.util as util 16 | 17 | import argparse 18 | import aug 19 | 20 | # import vis 21 | 22 | np.set_printoptions(threshold=np.inf) 23 | 24 | 25 | """ 26 | 对 3D 体数据进行一些预处理,并保存成npz文件 27 | 每个npz文件包含volume和label两个数组,volume和label各包含n条扫描记录,文件进行压缩 28 | """ 29 | # TODO: 支持更多的影像格式 30 | # TODO: 提供预处理npz gzip选项 31 | # https://stackoverflow.com/questions/54238670/what-is-the-advantage-of-saving-npz-files-instead-of-npy-in-python-regard 32 | 33 | 34 | def parse_args(): 35 | parser = argparse.ArgumentParser(description="数据预处理") 36 | parser.add_argument("-c", "--cfg_file", type=str, help="配置文件路径") 37 | parser.add_argument("opts", nargs=argparse.REMAINDER) 38 | args = parser.parse_args() 39 | 40 | if args.cfg_file is not None: 41 | cfg.update_from_file(args.cfg_file) 42 | if args.opts: 43 | cfg.update_from_list(args.opts) 44 | 45 | 46 | def main(): 47 | # 1. 创建输出路径,删除非空的summary表格 48 | if cfg.PREP.PLANE == "xy" and not os.path.exists(cfg.DATA.PREP_PATH): 49 | os.makedirs(cfg.DATA.PREP_PATH) 50 | if cfg.PREP.PLANE == "xz" and not os.path.exists(cfg.DATA.Z_PREP_PATH): 51 | os.makedirs(cfg.DATA.Z_PREP_PATH) 52 | 53 | if os.path.exists(cfg.DATA.SUMMARY_FILE) and os.path.getsize(cfg.DATA.SUMMARY_FILE) != 0: 54 | os.remove(cfg.DATA.SUMMARY_FILE) 55 | 56 | volumes = util.listdir(cfg.DATA.INPUTS_PATH) 57 | labels = util.listdir(cfg.DATA.LABELS_PATH) 58 | 59 | # 获取体位信息 60 | f = open("./directions.csv") 61 | dirs = f.readlines() 62 | # print("dirs: ", dirs) 63 | dirs = [x.rstrip("\n") for x in dirs] 64 | dirs = [x.split(",") for x in dirs] 65 | dic = {} 66 | for dir in dirs: 67 | dic[dir[0].strip()] = dir[1].strip() 68 | f.close() 69 | 70 | vol_npz = [] 71 | lab_npz = [] 72 | npz_count = 0 73 | thick = (cfg.TRAIN.THICKNESS - 1) / 2 74 | pbar = tqdm(range(len(labels)), desc="数据处理中") 75 | for i in range(len(labels)): 76 | pbar.set_postfix(filename=labels[i] + " " + volumes[i]) 77 | pbar.update(1) 78 | 79 | print(volumes[i], labels[i]) 80 | 81 | volf = nib.load(os.path.join(cfg.DATA.INPUTS_PATH, volumes[i])) 82 | labf = nib.load(os.path.join(cfg.DATA.LABELS_PATH, labels[i])) 83 | 84 | util.save_info(volumes[i], volf.header, cfg.DATA.SUMMARY_FILE) 85 | 86 | volume = volf.get_fdata() 87 | label = labf.get_fdata() 88 | label = label.astype(int) 89 | # plt.imshow(volume[:, :, 0]) 90 | # plt.show() 91 | if dic[volumes[i]] == "2": 92 | volume = np.rot90(volume, 3) 93 | label = np.rot90(label, 3) 94 | else: 95 | volume = np.rot90(volume, 1) 96 | label = np.rot90(label, 1) 97 | 98 | # plt.imshow(volume[:, :, 0]) 99 | # plt.show() 100 | 101 | if cfg.PREP.INTERP: 102 | print("interping") 103 | header = volf.header.structarr 104 | spacing = cfg.PREP.INTERP_PIXDIM 105 | # pixdim 是 ct 三个维度的间距 106 | pixdim = [header["pixdim"][x] for x in range(1, 4)] 107 | for ind in range(3): 108 | if spacing[ind] == -1: # 如果目标spacing为 -1 ,这个维度不进行插值 109 | spacing[ind] = pixdim[ind] 110 | ratio = [x / y for x, y in zip(spacing, pixdim)] 111 | volume = scipy.ndimage.interpolation.zoom(volume, ratio, order=3) 112 | label = scipy.ndimage.interpolation.zoom(label, ratio, order=0) 113 | 114 | if cfg.PREP.WINDOW: 115 | volume = util.windowlize_image(volume, cfg.PREP.WWWC) 116 | 117 | label = util.clip_label(label, cfg.PREP.FRONT) 118 | 119 | if cfg.PREP.CROP: # 裁到只有前景 120 | bb_min, bb_max = get_bbs(label) 121 | label = crop_to_bbs(label, bb_min, bb_max, 0.5)[0] 122 | volume = crop_to_bbs(volume, bb_min, bb_max)[0] 123 | 124 | label = pad_volume(label, [512, 512, 0], 0) # NOTE: 注意标签使用 0 125 | volume = pad_volume(volume, [512, 512, 0], -1024) 126 | print("after padding", volume.shape, label.shape) 127 | 128 | volume = volume.astype(np.float16) 129 | label = label.astype(np.int8) 130 | 131 | crop_size = list(cfg.PREP.SIZE) 132 | for ind in range(3): 133 | if crop_size[ind] == -1: 134 | crop_size[ind] = volume.shape[ind] 135 | volume, label = aug.crop(volume, label, crop_size) 136 | 137 | # 开始切片 138 | if cfg.PREP.PLANE == "xy": 139 | for frame in range(1, volume.shape[2] - 1): 140 | if label[:, :, frame].sum() > cfg.PREP.THRESH: 141 | vol = volume[:, :, frame - thick : frame + thick + 1] 142 | lab = label[:, :, frame] 143 | lab = lab[:, :, np.newaxis] 144 | 145 | vol = np.swapaxes(vol, 0, 2) 146 | lab = np.swapaxes(lab, 0, 2) # [3,512,512],CWH 的顺序 147 | 148 | vol_npz.append(vol.copy()) 149 | lab_npz.append(lab.copy()) 150 | print("{} 片满足,当前共 {}".format(frame, len(vol_npz))) 151 | 152 | if len(vol_npz) == cfg.PREP.BATCH_SIZE or ( 153 | i == (len(labels) - 1) and frame == volume.shape[2] - 1 154 | ): 155 | imgs = np.array(vol_npz) 156 | labs = np.array(lab_npz) 157 | print(imgs.shape) 158 | print(labs.shape) 159 | print("正在存盘") 160 | file_name = "{}_{}_f{}-{}".format( 161 | cfg.DATA.NAME, cfg.PREP.PLANE, cfg.PREP.FRONT, npz_count 162 | ) 163 | file_path = os.path.join(cfg.DATA.PREP_PATH, file_name) 164 | np.savez(file_path, imgs=imgs, labs=labs) 165 | vol_npz = [] 166 | lab_npz = [] 167 | npz_count += 1 168 | else: 169 | print(volume.shape, label.shape) 170 | for frame in range(1, volume.shape[0] - 1): 171 | if label[frame, :, :].sum() > cfg.PREP.THRESH: 172 | vol = volume[frame - 1 : frame + 2, :, :] 173 | lab = label[frame, :, :] 174 | lab = lab.reshape([1, lab.shape[0], lab.shape[1]]) 175 | 176 | vol_npz.append(vol.copy()) 177 | lab_npz.append(lab.copy()) 178 | 179 | if len(vol_npz) == cfg.PREP.BATCH_SIZE: 180 | vols = np.array(vol_npz) 181 | labs = np.array(lab_npz) 182 | print(vols.shape) 183 | print(labs.shape) 184 | print("正在存盘") 185 | file_name = "{}_{}_f{}-{}".format( 186 | cfg.DATA.NAME, cfg.PREP.PLANE, cfg.PREP.FRONT, npz_count 187 | ) 188 | file_path = os.path.join(cfg.DATA.Z_PREP_PATH, file_name) 189 | np.savez(file_path, vols=vols, labs=labs) 190 | vol_npz = [] 191 | lab_npz = [] 192 | npz_count += 1 193 | 194 | pbar.close() 195 | 196 | 197 | if __name__ == "__main__": 198 | parse_args() 199 | main() 200 | -------------------------------------------------------------------------------- /medseg/utils/config.py: -------------------------------------------------------------------------------- 1 | import six 2 | from ast import literal_eval 3 | import codecs 4 | import yaml 5 | 6 | # 使用的时候如果直接赋值出去,默认是不可变的,如果需要再赋值一定注意 7 | class PjConfig(dict): 8 | def __init__(self, *args, **kwargs): 9 | super(PjConfig, self).__init__(*args, **kwargs) 10 | self.immutable = False 11 | 12 | def __setattr__(self, key, value, create_if_not_exist=True): 13 | if key in ["immutable"]: 14 | self.__dict__[key] = value 15 | return 16 | 17 | t = self 18 | keylist = key.split(".") 19 | for k in keylist[:-1]: 20 | t = t.__getattr__(k, create_if_not_exist) 21 | 22 | t.__getattr__(keylist[-1], create_if_not_exist) 23 | t[keylist[-1]] = value 24 | 25 | def __getattr__(self, key, create_if_not_exist=True): 26 | if key in ["immutable"]: 27 | return self.__dict__[key] 28 | 29 | if not key in self: 30 | if not create_if_not_exist: 31 | raise KeyError 32 | self[key] = PjConfig() 33 | return self[key] 34 | 35 | def __setitem__(self, key, value): 36 | if self.immutable: 37 | raise AttributeError( 38 | 'Attempted to set "{}" to "{}", but PjConfig is immutable'.format( 39 | key, value 40 | ) 41 | ) 42 | if isinstance(value, six.string_types): 43 | try: 44 | value = literal_eval(value) 45 | except ValueError: 46 | pass 47 | except SyntaxError: 48 | pass 49 | super(PjConfig, self).__setitem__(key, value) 50 | 51 | def update_from_Config(self, other): 52 | if isinstance(other, dict): 53 | other = PjConfig(other) 54 | assert isinstance(other, PjConfig) 55 | diclist = [("", other)] 56 | while len(diclist): 57 | prefix, tdic = diclist[0] 58 | diclist = diclist[1:] 59 | for key, value in tdic.items(): 60 | key = "{}.{}".format(prefix, key) if prefix else key 61 | if isinstance(value, dict): 62 | diclist.append((key, value)) 63 | continue 64 | try: 65 | self.__setattr__(key, value, create_if_not_exist=False) 66 | except KeyError: 67 | raise KeyError("Non-existent config key: {}".format(key)) 68 | self.check() 69 | 70 | def update_from_list(self, config_list): 71 | if len(config_list) % 2 != 0: 72 | raise ValueError( 73 | "Command line options config format error! Please check it: {}".format( 74 | config_list 75 | ) 76 | ) 77 | for key, value in zip(config_list[0::2], config_list[1::2]): 78 | try: 79 | self.__setattr__(key, value, create_if_not_exist=False) 80 | except KeyError: 81 | raise KeyError("Non-existent config key: {}".format(key)) 82 | self.check() 83 | 84 | def update_from_file(self, config_file): 85 | with codecs.open(config_file, "r", "utf-8") as file: 86 | dic = yaml.load(file, Loader=yaml.FullLoader) 87 | self.update_from_Config(dic) 88 | 89 | def set_immutable(self, immutable): 90 | self.immutable = immutable 91 | for value in self.values(): 92 | if isinstance(value, PjConfig): 93 | value.set_immutable(immutable) 94 | 95 | def is_immutable(self): 96 | return self.immutable 97 | 98 | def check(self): 99 | if cfg.PREP.THICKNESS % 2 != 1: 100 | raise ValueError("2.5D预处理厚度 {} 不是奇数".format(cfg.TRAIN.THICKNESS)) 101 | 102 | 103 | cfg = PjConfig() 104 | 105 | """数据集配置""" 106 | # 数据集名称 107 | cfg.DATA.NAME = "lits" 108 | # 输入的2D或3D图像路径 109 | cfg.DATA.INPUTS_PATH = "/home/aistudio/data/scan" 110 | # 标签路径 111 | cfg.DATA.LABELS_PATH = "/home/aistudio/data/label" 112 | # 预处理输出npz路径 113 | cfg.DATA.PREP_PATH = "/home/aistudio/data/preprocess" 114 | # z 方向初始化可以指定一个独立的输出文件路径 115 | cfg.DATA.Z_PREP_PATH = cfg.DATA.PREP_PATH 116 | # 预处理过程中数据信息写到这个文件 117 | cfg.DATA.SUMMARY_FILE = "./{}.csv".format(cfg.DATA.NAME) 118 | 119 | """ 预处理配置 """ 120 | # 预处理进行的平面 121 | cfg.PREP.PLANE = "xy" 122 | # 处理过程中所有比这个数字大的标签都设为前景 123 | cfg.PREP.FRONT = 1 124 | # 是否将数据只 crop 到前景 125 | cfg.PREP.CROP = False 126 | # 是否对数据插值改变大小 127 | cfg.PREP.INTERP = False 128 | # 进行插值的话目标片间间隔是多少,单位mm,-1的维度不会进行插值 129 | cfg.PREP.INTERP_PIXDIM = (-1, -1, 1.0) 130 | # 是否进行窗口化,在预处理阶段不建议做,灵活性太低 131 | cfg.PREP.WINDOW = False 132 | # 窗宽窗位 133 | cfg.PREP.WWWC = (400, 0) 134 | # 丢弃前景数量少于thresh的slice 135 | cfg.PREP.THRESH = 256 136 | # 3D的数据在开始切割之前pad到这个大小,-1的维度会放着不动 137 | cfg.PREP.SIZE = (512, 512, -1) 138 | # 2.5D预处理一片的厚度 139 | cfg.PREP.THICKNESS = 3 140 | # 预处理过程中多少组数据组成一个npz文件 141 | # 可以先跑bs=1,看看一对数据多大;尽量至少将训练数据分入10个npz,否则分训练和验证集的时候会很不准 142 | # 这个值不建议给成 2^n,这样更利于随机打乱数据 143 | cfg.PREP.BATCH_SIZE = 128 144 | 145 | """训练配置""" 146 | cfg.TRAIN.DATA_PATH = "/home/aistudio/data/preprocess" 147 | # 训练数据的数量,用来显示训练进度条和时间估计,如果不知道有多少写-1 148 | cfg.TRAIN.DATA_COUNT = -1 149 | # 预训练权重路径,如果没有写空,有的话会尝试加载 150 | cfg.TRAIN.PRETRAINED_WEIGHT = "" 151 | # 预测裁剪模型保存路径 152 | cfg.TRAIN.INF_MODEL_PATH = "./model/lits/inf" 153 | # 可以继续训练的ckpt模型保存路径 154 | cfg.TRAIN.CKPT_MODEL_PATH = "./model/lits/ckpt" 155 | # 效果最好的模型保存路径 156 | cfg.TRAIN.BEST_MODEL_PATH = "./model/lits/best" 157 | # 训练过程中输入图像大小,不加channel 158 | cfg.TRAIN.INPUT_SIZE = (512, 512) 159 | # 训练过程中用的batch_size 160 | cfg.TRAIN.BATCH_SIZE = 32 161 | # 共训练多少个epoch 162 | cfg.TRAIN.EPOCHS = 20 163 | # 使用的模型结构 164 | cfg.TRAIN.ARCHITECTURE = "res_unet" 165 | # 使用的正则化方法,支持L1,L2,其他一切值都是不加正则化 166 | cfg.TRAIN.REG_TYPE = "L1" 167 | # 正则化的权重 168 | cfg.TRAIN.REG_COEFF = 1e-6 169 | # 梯度下降方法 170 | cfg.TRAIN.OPTIMIZER = "adam" 171 | # 学习率 172 | cfg.TRAIN.LR = [0.003, 0.002, 0.001] 173 | # 学习率变化step 174 | cfg.TRAIN.BOUNDARIES = [10000, 20000] 175 | # Loss 支持ce,dice,miou,wce,focal 176 | cfg.TRAIN.LOSS = ["ce", "dice"] 177 | # 是否使用GPU进行训练 178 | cfg.TRAIN.USE_GPU = False 179 | # 进行验证 180 | cfg.TRAIN.DO_EVAL = False 181 | # 每 snapchost_epoch 做一次eval并保存模型 182 | cfg.TRAIN.SNAPSHOT_BATCH = 500 183 | # 每 disp_epoch 打出一次训练过程 184 | cfg.TRAIN.DISP_BATCH = 10 185 | # VDL log路径 186 | cfg.TRAIN.VDL_LOG = "/home/aistudio/log" 187 | 188 | """ HRNET 设置""" 189 | # HRNET STAGE2 设置 190 | cfg.MODEL.HRNET.STAGE2.NUM_MODULES = 1 191 | cfg.MODEL.HRNET.STAGE2.NUM_CHANNELS = [40, 80] 192 | # HRNET STAGE3 设置 193 | cfg.MODEL.HRNET.STAGE3.NUM_MODULES = 4 194 | cfg.MODEL.HRNET.STAGE3.NUM_CHANNELS = [40, 80, 160] 195 | # HRNET STAGE4 设置 196 | cfg.MODEL.HRNET.STAGE4.NUM_MODULES = 3 197 | cfg.MODEL.HRNET.STAGE4.NUM_CHANNELS = [40, 80, 160, 320] 198 | 199 | """数据增强""" 200 | # 不单独为增强操作设做不做的config,不想做概率设成 0,注意CWH 201 | # 每个维度进行翻转增强的概率,CWH 202 | # 是否进行窗口化 203 | cfg.AUG.WINDOWLIZE = True 204 | # 窗宽窗位 205 | cfg.AUG.WWWC = cfg.PREP.WWWC 206 | # 进行翻转数据增强的概率 207 | cfg.AUG.FLIP.RATIO = (0, 0, 0) 208 | # 进行旋转增强的概率 209 | cfg.AUG.ROTATE.RATIO = (0, 0, 0) 210 | # 旋转的角度范围,单位度 211 | cfg.AUG.ROTATE.RANGE = (0, (0, 0), 0) 212 | # 进行缩放的概率 213 | cfg.AUG.ZOOM.RATIO = (0, 0, 0) 214 | # 进行缩放的比例 215 | cfg.AUG.ZOOM.RANGE = ((1, 1), (1, 1), (1, 1)) 216 | # 进行随机crop的目标大小 217 | cfg.AUG.CROP.SIZE = (3, 512, 512) 218 | 219 | """推理配置""" 220 | # 推理的输入数据路径 221 | cfg.INFER.PATH.INPUT = "/home/aistudio/data/inference" 222 | # 推理的结果输出路径 223 | cfg.INFER.PATH.OUTPUT = "/home/aistudio/data/infer_lab" 224 | # 推理的模型权重路径 225 | cfg.INFER.PATH.PARAM = "/home/aistudio/weight/liver/inf" 226 | # 是否使用GPU进行推理 227 | cfg.INFER.USE_GPU = False 228 | # 推理过程中的 batch_size 229 | cfg.INFER.BATCH_SIZE = 128 230 | # 是否进行窗口化,这个和训练过程中的配置应当相同 231 | cfg.INFER.WINDOWLIZE = True 232 | # 窗宽窗位 233 | cfg.INFER.WWWC = cfg.PREP.WWWC 234 | # 是否进行插值 235 | cfg.INFER.DO_INTERP = False 236 | # 如果进行插值,目标的spacing,-1的维度忽略 237 | cfg.INFER.SPACING = [-1, -1, 1] 238 | # 是否进行最大连通块过滤 239 | cfg.INFER.FILTER_LARGES = True 240 | # 推理过程中区分前景和背景的阈值 241 | cfg.INFER.THRESH = 0.5 242 | 243 | """ 测试配置 """ 244 | # 分割结果的路径 245 | cfg.EVAL.PATH.SEG = "/home/aistudio/data/infer_lab" 246 | # 分割GT标签的路径 247 | cfg.EVAL.PATH.GT = "/home/aistudio/data/eval_lab" 248 | # 评估结果存储的文件 249 | cfg.EVAL.PATH.NAME = "eval" 250 | # 测试过程中要计算的指标,包括 251 | # FP,FN,TP,TN(绝对数量) 252 | # Precision,Recall/Sensitivity,Specificity,Accuracy,Kappa 253 | # Dice,IOU/VOE 254 | cfg.EVAL.METRICS = [ 255 | "IOU", 256 | "Dice", 257 | "TP", 258 | "TN", 259 | "Precision", 260 | "Recall", 261 | "Sensitivity", 262 | "Specificity", 263 | "Accuracy", 264 | ] 265 | -------------------------------------------------------------------------------- /medseg/prep_png.py: -------------------------------------------------------------------------------- 1 | # encoding=utf-8 2 | 3 | 4 | import os 5 | import argparse 6 | 7 | import numpy as np 8 | import nibabel as nib 9 | from tqdm import tqdm 10 | import scipy 11 | import cv2 12 | import matplotlib.pyplot as plt 13 | 14 | import utils.util as util 15 | from utils.config import cfg 16 | import utils.util as util 17 | import aug 18 | 19 | np.set_printoptions(threshold=np.inf) 20 | 21 | 22 | """ 23 | 对 3D 体数据进行一些预处理,并保存成npz文件 24 | 每个npz文件包含volume和label两个数组,volume和label各包含n条扫描记录,文件进行压缩 25 | """ 26 | # TODO: 支持更多的影像格式 27 | # TODO: 提供预处理npz gzip选项 28 | # https://stackoverflow.com/questions/54238670/what-is-the-advantage-of-saving-npz-files-instead-of-npy-in-python-regard 29 | 30 | 31 | def parse_args(): 32 | parser = argparse.ArgumentParser(description="数据预处理") 33 | parser.add_argument("-c", "--cfg_file", type=str, help="配置文件路径") 34 | parser.add_argument("opts", nargs=argparse.REMAINDER) 35 | args = parser.parse_args() 36 | 37 | if args.cfg_file is not None: 38 | cfg.update_from_file(args.cfg_file) 39 | if args.opts: 40 | cfg.update_from_list(args.opts) 41 | 42 | 43 | def main(): 44 | # 1. 创建输出路径,删除非空的summary表格 45 | if cfg.PREP.PLANE == "xy" and not os.path.exists(cfg.DATA.PREP_PATH): 46 | os.makedirs(cfg.DATA.PREP_PATH) 47 | if cfg.PREP.PLANE == "xz" and not os.path.exists(cfg.DATA.Z_PREP_PATH): 48 | os.makedirs(cfg.DATA.Z_PREP_PATH) 49 | 50 | if os.path.exists(cfg.DATA.SUMMARY_FILE) and os.path.getsize(cfg.DATA.SUMMARY_FILE) != 0: 51 | os.remove(cfg.DATA.SUMMARY_FILE) 52 | 53 | volumes = util.listdir(cfg.DATA.INPUTS_PATH) 54 | labels = util.listdir(cfg.DATA.LABELS_PATH) 55 | 56 | # 获取体位信息 57 | f = open("./directions.csv") 58 | dirs = f.readlines() 59 | # print("dirs: ", dirs) 60 | dirs = [x.rstrip("\n") for x in dirs] 61 | dirs = [x.split(",") for x in dirs] 62 | dic = {} 63 | for dir in dirs: 64 | dic[dir[0].strip()] = dir[1].strip() 65 | f.close() 66 | 67 | vol_npz = [] 68 | lab_npz = [] 69 | npz_count = 0 70 | 71 | pbar = tqdm(range(len(labels)), desc="数据处理中") 72 | for i in range(len(labels)): 73 | pbar.set_postfix(filename=labels[i] + " " + volumes[i]) 74 | pbar.update(1) 75 | 76 | print(volumes[i], labels[i]) 77 | 78 | volf = nib.load(os.path.join(cfg.DATA.INPUTS_PATH, volumes[i])) 79 | labf = nib.load(os.path.join(cfg.DATA.LABELS_PATH, labels[i])) 80 | 81 | util.save_info(volumes[i], volf.header, cfg.DATA.SUMMARY_FILE) 82 | 83 | volume = volf.get_fdata() 84 | label = labf.get_fdata() 85 | label = label.astype(int) 86 | 87 | # plt.imshow(volume[:, :, 0]) 88 | # plt.show() 89 | 90 | if dic[volumes[i]] == "2": 91 | volume = np.rot90(volume, 3) 92 | label = np.rot90(label, 3) 93 | else: 94 | volume = np.rot90(volume, 1) 95 | label = np.rot90(label, 1) 96 | 97 | # plt.imshow(volume[:, :, 0]) 98 | # plt.show() 99 | 100 | if cfg.PREP.INTERP: 101 | print("interping") 102 | header = volf.header.structarr 103 | spacing = cfg.PREP.INTERP_PIXDIM 104 | # pixdim 是 ct 三个维度的间距 105 | pixdim = [header["pixdim"][x] for x in range(1, 4)] 106 | for ind in range(3): 107 | if spacing[ind] == -1: # 如果目标spacing为 -1 ,这个维度不进行插值 108 | spacing[ind] = pixdim[ind] 109 | ratio = [x / y for x, y in zip(spacing, pixdim)] 110 | volume = scipy.ndimage.interpolation.zoom(volume, ratio, order=3) 111 | label = scipy.ndimage.interpolation.zoom(label, ratio, order=0) 112 | 113 | if cfg.PREP.WINDOW: 114 | volume = util.windowlize_image(volume, cfg.PREP.WWWC) 115 | 116 | label = util.clip_label(label, cfg.PREP.FRONT) 117 | 118 | if cfg.PREP.CROP: # 裁到只有前景 119 | bb_min, bb_max = get_bbs(label) 120 | label = crop_to_bbs(label, bb_min, bb_max, 0.5)[0] 121 | volume = crop_to_bbs(volume, bb_min, bb_max)[0] 122 | 123 | label = pad_volume(label, [512, 512, 0], 0) # NOTE: 注意标签使用 0 124 | volume = pad_volume(volume, [512, 512, 0], -1024) 125 | print("after padding", volume.shape, label.shape) 126 | 127 | volume = volume.astype(np.float16) 128 | label = label.astype(np.int8) 129 | 130 | crop_size = list(cfg.PREP.SIZE) 131 | for ind in range(3): 132 | if crop_size[ind] == -1: 133 | crop_size[ind] = volume.shape[ind] 134 | volume, label = aug.crop(volume, label, crop_size) 135 | 136 | # 开始切片 137 | volume = volume.clip(-200, 200) 138 | volume = (volume + 200) / 400 * 255 139 | volume = volume.astype(np.uint8) 140 | print(volume.dtype) 141 | 142 | if cfg.PREP.PLANE == "xy": 143 | for frame in range(1, volume.shape[2] - 1): 144 | if label[:, :, frame].sum() > cfg.PREP.THRESH: 145 | # vol = volume[:, :, frame - 1 : frame + 2] 146 | lab = label[:, :, frame] 147 | 148 | vol = volume[:, :, frame] 149 | lab = lab * 255 150 | 151 | # print(vol.shape) 152 | # print(lab.shape) 153 | cv2.imwrite( 154 | os.path.join( 155 | cfg.DATA.PREP_PATH, 156 | "imgs", 157 | "lits-{}-{}.png".format( 158 | volumes[i].lstrip("volume-").rstrip(".nii"), frame 159 | ), 160 | ), 161 | vol, 162 | ) 163 | 164 | cv2.imwrite( 165 | os.path.join( 166 | cfg.DATA.PREP_PATH, 167 | "labs", 168 | "lits-{}-{}.png".format( 169 | volumes[i].lstrip("volume-").rstrip(".nii"), frame 170 | ), 171 | ), 172 | lab, 173 | ) 174 | 175 | # volimg = Image.fromarray(vol) 176 | # labimg = Image.fromarray(lab, "L") 177 | # volimg.save( 178 | # os.path.join( 179 | # cfg.DATA.PREP_PATH, 180 | # "imgs", 181 | # "lits-{}-{}.png".format( 182 | # volumes[i].lstrip("volume-").rstrip(".nii"), frame 183 | # ), 184 | # ) 185 | # ) 186 | # labimg.save( 187 | # os.path.join( 188 | # cfg.DATA.PREP_PATH, 189 | # "labs", 190 | # "lits-{}-{}.png".format( 191 | # volumes[i].lstrip("volume-").rstrip(".nii"), frame 192 | # ), 193 | # ) 194 | # ) 195 | # volimg.close() 196 | # labimg.close() 197 | 198 | # input("here") 199 | 200 | else: 201 | print(volume.shape, label.shape) 202 | for frame in range(1, volume.shape[0] - 1): 203 | if label[frame, :, :].sum() > cfg.PREP.THRESH: 204 | vol = volume[frame - 1 : frame + 2, :, :] 205 | lab = label[frame, :, :] 206 | lab = lab.reshape([1, lab.shape[0], lab.shape[1]]) 207 | 208 | vol_npz.append(vol.copy()) 209 | lab_npz.append(lab.copy()) 210 | 211 | if len(vol_npz) == cfg.PREP.BATCH_SIZE: 212 | vols = np.array(vol_npz) 213 | labs = np.array(lab_npz) 214 | print(vols.shape) 215 | print(labs.shape) 216 | print("正在存盘") 217 | file_name = "{}_{}_f{}-{}".format( 218 | cfg.DATA.NAME, cfg.PREP.PLANE, cfg.PREP.FRONT, npz_count 219 | ) 220 | file_path = os.path.join(cfg.DATA.Z_PREP_PATH, file_name) 221 | np.savez(file_path, vols=vols, labs=labs) 222 | vol_npz = [] 223 | lab_npz = [] 224 | npz_count += 1 225 | 226 | pbar.close() 227 | 228 | 229 | if __name__ == "__main__": 230 | parse_args() 231 | main() 232 | -------------------------------------------------------------------------------- /medseg/train.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import sys 3 | import os 4 | import argparse 5 | import random 6 | import math 7 | from datetime import datetime 8 | import multiprocessing 9 | 10 | import numpy as np 11 | from tqdm.auto import tqdm 12 | import paddle 13 | import paddle.fluid as fluid 14 | from paddle.fluid.layers import log 15 | from visualdl import LogWriter 16 | 17 | 18 | import utils.util as util 19 | from utils.config import cfg 20 | import loss 21 | import aug 22 | from models.model import create_model 23 | 24 | 25 | def parse_args(): 26 | parser = argparse.ArgumentParser(description="训练") 27 | parser.add_argument("-c", "--cfg_file", type=str, help="配置文件路径") 28 | parser.add_argument("--use_gpu", action="store_true", default=False, help="是否用GPU") 29 | parser.add_argument("--do_eval", action="store_true", default=False, help="是否进行测试") 30 | parser.add_argument("opts", nargs=argparse.REMAINDER) 31 | args = parser.parse_args() 32 | 33 | if args.cfg_file is not None: 34 | cfg.update_from_file(args.cfg_file) 35 | if args.opts: 36 | cfg.update_from_list(args.opts) 37 | if args.use_gpu: # 命令行参数只能从false改成true,不能声明false 38 | cfg.TRAIN.USE_GPU = True 39 | if args.do_eval: 40 | cfg.TRAIN.DO_EVAL = True 41 | 42 | cfg.set_immutable(True) 43 | # TODO: 打印cfg配置 44 | 45 | 46 | npz_names = util.listdir(cfg.TRAIN.DATA_PATH) 47 | random.shuffle(npz_names) 48 | 49 | 50 | def data_reader(part_start=0, part_end=8): 51 | # NOTE: 这种分法效率高好写,但是npz很少的时候分得不准。npz文件至少分10个 52 | npz_part = npz_names[int(len(npz_names) * part_start / 10) : int(len(npz_names) * part_end / 10)] 53 | 54 | def reader(): 55 | # BUG: tqdm每次更新都另起一行,此外要单独测试windows上好不好使 56 | if cfg.TRAIN.DATA_COUNT != -1: 57 | pbar = tqdm(total=cfg.TRAIN.DATA_COUNT, desc="训练进度") 58 | for npz_name in npz_part: 59 | data = np.load(os.path.join(cfg.TRAIN.DATA_PATH, npz_name)) 60 | imgs = data["imgs"] 61 | labs = data["labs"] 62 | assert len(np.where(labs == 1)[0]) + len(np.where(labs == 0)[0]) == labs.size, "非法的label数值" 63 | if cfg.AUG.WINDOWLIZE: 64 | imgs = util.windowlize_image(imgs, cfg.AUG.WWWC) 65 | else: 66 | imgs = util.windowlize_image(imgs, (4096, 0)) 67 | inds = [x for x in range(imgs.shape[0])] 68 | random.shuffle(inds) 69 | for ind in inds: 70 | if cfg.TRAIN.DATA_COUNT != -1: 71 | pbar.update() 72 | vol = imgs[ind].reshape(cfg.TRAIN.THICKNESS, 512, 512).astype("float32") 73 | lab = labs[ind].reshape(1, 512, 512).astype("int32") 74 | yield vol, lab 75 | # TODO: 标签平滑 76 | # https://medium.com/@lessw/label-smoothing-deep-learning-google-brain-explains-why-it-works-and-when-to-use-sota-tips-977733ef020 77 | 78 | return reader 79 | 80 | 81 | def aug_mapper(data): 82 | vol = data[0] 83 | lab = data[1] 84 | ww, wc = cfg.AUG.WWWC 85 | # NOTE: 注意不要增强第0维,那是厚度的方向 86 | vol, lab = aug.flip(vol, lab, cfg.AUG.FLIP.RATIO) 87 | vol, lab = aug.rotate(vol, lab, cfg.AUG.ROTATE.RANGE, cfg.AUG.ROTATE.RATIO, wc - ww / 2) 88 | vol, lab = aug.zoom(vol, lab, cfg.AUG.ZOOM.RANGE, cfg.AUG.ZOOM.RATIO) 89 | vol, lab = aug.crop(vol, lab, cfg.AUG.CROP.SIZE, wc - ww / 2) 90 | return vol, lab 91 | 92 | 93 | def main(): 94 | train_program = fluid.Program() 95 | train_init = fluid.Program() 96 | 97 | with fluid.program_guard(train_program, train_init): 98 | image = fluid.layers.data(name="image", shape=[cfg.TRAIN.THICKNESS, 512, 512], dtype="float32") 99 | label = fluid.layers.data(name="label", shape=[1, 512, 512], dtype="int32") 100 | train_loader = fluid.io.DataLoader.from_generator( 101 | feed_list=[image, label], 102 | capacity=cfg.TRAIN.BATCH_SIZE * 2, 103 | iterable=True, 104 | use_double_buffer=True, 105 | ) 106 | prediction = create_model(image, 2) 107 | avg_loss = loss.create_loss(prediction, label, 2) 108 | miou = loss.mean_iou(prediction, label, 2) 109 | 110 | # 进行正则化 111 | if cfg.TRAIN.REG_TYPE == "L1": 112 | decay = paddle.fluid.regularizer.L1Decay(cfg.TRAIN.REG_COEFF) 113 | elif cfg.TRAIN.REG_TYPE == "L2": 114 | decay = paddle.fluid.regularizer.L2Decay(cfg.TRAIN.REG_COEFF) 115 | else: 116 | decay = None 117 | 118 | # 选择优化器 119 | lr = fluid.layers.piecewise_decay(boundaries=cfg.TRAIN.BOUNDARIES, values=cfg.TRAIN.LR) 120 | if cfg.TRAIN.OPTIMIZER == "adam": 121 | optimizer = fluid.optimizer.AdamOptimizer(learning_rate=lr, regularization=decay,) 122 | elif cfg.TRAIN.OPTIMIZER == "sgd": 123 | optimizer = fluid.optimizer.SGDOptimizer(learning_rate=lr, regularization=decay) 124 | elif cfg.TRAIN.OPTIMIZE == "momentum": 125 | optimizer = fluid.optimizer.Momentum(momentum=0.9, learning_rate=lr, regularization=decay,) 126 | else: 127 | raise Exception("错误的优化器类型: {}".format(cfg.TRAIN.OPTIMIZER)) 128 | optimizer.minimize(avg_loss) 129 | 130 | places = fluid.CUDAPlace(0) if cfg.TRAIN.USE_GPU else fluid.CPUPlace() 131 | exe = fluid.Executor(places) 132 | exe.run(train_init) 133 | exe_test = fluid.Executor(places) 134 | 135 | test_program = train_program.clone(for_test=True) 136 | compiled_train_program = fluid.CompiledProgram(train_program).with_data_parallel(loss_name=avg_loss.name) 137 | 138 | if cfg.TRAIN.PRETRAINED_WEIGHT != "": 139 | print("Loading paramaters") 140 | fluid.io.load_persistables(exe, cfg.TRAIN.PRETRAINED_WEIGHT, train_program) 141 | 142 | # train_reader = fluid.io.xmap_readers( 143 | # aug_mapper, data_reader(0, 8), multiprocessing.cpu_count()/2, 16 144 | # ) 145 | train_reader = data_reader(0, 8) 146 | train_loader.set_sample_generator(train_reader, batch_size=cfg.TRAIN.BATCH_SIZE, places=places) 147 | test_reader = paddle.batch(data_reader(8, 10), cfg.INFER.BATCH_SIZE) 148 | test_feeder = fluid.DataFeeder(place=places, feed_list=[image, label]) 149 | 150 | writer = LogWriter(logdir="/home/aistudio/log/{}".format(datetime.now())) 151 | 152 | step = 0 153 | best_miou = 0 154 | 155 | for pass_id in range(cfg.TRAIN.EPOCHS): 156 | for train_data in train_loader(): 157 | step += 1 158 | avg_loss_value, miou_value = exe.run( 159 | compiled_train_program, feed=train_data, fetch_list=[avg_loss, miou] 160 | ) 161 | writer.add_scalar(tag="train_loss", step=step, value=avg_loss_value[0]) 162 | writer.add_scalar(tag="train_miou", step=step, value=miou_value[0]) 163 | if step % cfg.TRAIN.DISP_BATCH == 0: 164 | print( 165 | "\tTrain pass {}, Step {}, Cost {}, Miou {}".format( 166 | pass_id, step, avg_loss_value[0], miou_value[0] 167 | ) 168 | ) 169 | 170 | if math.isnan(float(avg_loss_value[0])): 171 | sys.exit("Got NaN loss, training failed.") 172 | 173 | if step % cfg.TRAIN.SNAPSHOT_BATCH == 0 and cfg.TRAIN.DO_EVAL: 174 | test_step = 0 175 | eval_miou = 0 176 | test_losses = [] 177 | test_mious = [] 178 | for test_data in test_reader(): 179 | test_step += 1 180 | preds, test_loss, test_miou = exe_test.run( 181 | test_program, 182 | feed=test_feeder.feed(test_data), 183 | fetch_list=[prediction, avg_loss, miou], 184 | ) 185 | test_losses.append(test_loss[0]) 186 | test_mious.append(test_miou[0]) 187 | if test_step % cfg.TRAIN.DISP_BATCH == 0: 188 | print("\t\tTest Loss: {} , Miou: {}".format(test_loss[0], test_miou[0])) 189 | 190 | eval_miou = np.average(np.array(test_mious)) 191 | writer.add_scalar( 192 | tag="test_miou", step=step, value=eval_miou, 193 | ) 194 | print("Test loss: {} ,miou: {}".format(np.average(np.array(test_losses)), eval_miou)) 195 | ckpt_dir = os.path.join(cfg.TRAIN.CKPT_MODEL_PATH, str(step) + "_" + str(eval_miou)) 196 | fluid.io.save_persistables(exe, ckpt_dir, train_program) 197 | 198 | print("此前最高的测试MIOU是: ", best_miou) 199 | 200 | if step % cfg.TRAIN.SNAPSHOT_BATCH == 0 and eval_miou > best_miou: 201 | best_miou = eval_miou 202 | print("正在保存第 {} step的权重".format(step)) 203 | fluid.io.save_inference_model( 204 | cfg.TRAIN.INF_MODEL_PATH, 205 | feeded_var_names=["image"], 206 | target_vars=[prediction], 207 | executor=exe, 208 | main_program=train_program, 209 | ) 210 | 211 | 212 | if __name__ == "__main__": 213 | args = parse_args() 214 | main() 215 | -------------------------------------------------------------------------------- /tool/infer/2d_diameter.py: -------------------------------------------------------------------------------- 1 | # 用2d平面内的数据计算管径 2 | import argparse 3 | import os 4 | import sys 5 | import math 6 | from multiprocessing import Pool 7 | import time 8 | import random 9 | 10 | from tqdm import tqdm 11 | import numpy as np 12 | import cv2 13 | import nibabel as nib 14 | import scipy.ndimage 15 | from skimage import filters 16 | from skimage.segmentation import flood, flood_fill 17 | import matplotlib.pyplot as plt 18 | 19 | from util import blood_sort 20 | 21 | np.set_printoptions(threshold=sys.maxsize) 22 | 23 | 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument("-i", "--seg_dir", type=str, required=True) 26 | parser.add_argument("-o", "--dia_dir", type=str, required=True) 27 | parser.add_argument("--filter_arch", type=bool, default=True) 28 | parser.add_argument("-d", "--demo", default=False, action="store_true") 29 | args = parser.parse_args() 30 | 31 | 32 | class Polygon: 33 | def __init__(self, points, height, pixdim, shape=[512, 512]): 34 | self.idx = 0 35 | self.shape = shape 36 | self.points = points 37 | random.shuffle(self.points) 38 | self.center = [0, 0] 39 | self.height = height 40 | self.diameters = [] 41 | self.pixdim = pixdim 42 | # 计算边缘位置平均数做圆心 43 | for p in self.points: 44 | self.center[0] += p[0] 45 | self.center[1] += p[1] 46 | self.center[0] /= len(self.points) 47 | self.center[1] /= len(self.points) 48 | self.center[0] = int(self.center[0]) 49 | self.center[1] = int(self.center[1]) 50 | if not self.is_inside(): 51 | self.points = [] 52 | 53 | def is_inside(self): 54 | """去掉一些多边形 55 | 1. 点少于30个 56 | 2. 中心不在图形里 57 | 3. 漫水面积过大,不是封闭图形? # TODO: 这个有用吗 58 | Returns 59 | ------- 60 | bool 61 | 是否保留这个多边形 62 | 63 | """ 64 | if len(self.points) < 30: 65 | return False 66 | label = np.zeros(self.shape, dtype="uint8") 67 | for p in self.points: 68 | label[p[0]][p[1]] = 1 69 | flood_fill( 70 | label, 71 | tuple(self.center), 72 | 1, 73 | selem=[[0, 1, 0], [1, 1, 1], [0, 1, 0]], 74 | in_place=True, 75 | ) 76 | c = self.center 77 | if label[c[0]][c[1]] != 1: 78 | print("!!!!!!!!!中心不再图形里") 79 | return False 80 | if label.sum() > label.size / 2: 81 | return False 82 | return True 83 | 84 | def ang_sort(self): 85 | pass 86 | 87 | def plot(self): 88 | img = np.zeros([512, 512]) 89 | for ind in range(len(self.points)): 90 | img[self.points[ind][0]][self.points[ind][1]] = 1 91 | img[self.center[0]][self.center[1]] = 1 92 | plt.imshow(img) 93 | plt.show() 94 | 95 | def cal_diameters(self, ang_range=[0, np.pi], split=10, acc=0.1): 96 | """用类似二分的方法,平行线夹计算管径. 97 | 98 | y - y0 + d = k ( x - x0 ):d是这根线在y轴上移动的距离 99 | 在线上取x=0, x=1 的 a,b两点,通过判断a,b,points上的各个点p是不是都向同一个方向转,判断直线是不是已经移出了多边形 100 | 101 | Parameters 102 | ---------- 103 | ang_range : list 104 | 直线和x轴角度的范围. 105 | split : int 106 | 在这个范围内,平均测量多少个方向. 107 | pixdim : float 108 | 片子的pixdim,从像素换算到实际的mm. 109 | 110 | Returns 111 | ------- 112 | list 113 | ang_range 角度范围内,split 等分个方向上,平行线夹的管径是多少mm. 114 | 115 | """ 116 | if len(self.points) == 0: 117 | self.diameters = 0 118 | print("[Error] Polygon at height {} contains no point".format(self.height)) 119 | return 120 | 121 | def is_right(a, b, c): 122 | # 向量叉积 123 | x = np.array((b[0] - a[0], b[1] - a[1])) 124 | y = np.array((c[0] - b[0], c[1] - b[1])) 125 | res = np.cross(x, y) 126 | return res >= 0 127 | 128 | # print(self.height) 129 | self.diameters = [] 130 | center = self.center 131 | # y - y0 + d = k ( x - x0 ) 132 | for alpha in np.arange( 133 | ang_range[0], ang_range[1], (ang_range[1] - ang_range[0]) / split 134 | ): 135 | # TODO: 这个step是纵轴截距,结合k算成斜距 136 | if alpha == np.pi / 2: 137 | continue 138 | k = math.tan(alpha) 139 | 140 | def binary_search(step): 141 | d = 0 142 | prev_out = False 143 | while True: 144 | ya = center[1] - d + k * (0 - center[0]) # (0,ya) 145 | yb = center[1] - d + k * (1 - center[0]) # (1,yb) 146 | dir = is_right((0, ya), (1, yb), self.points[0]) 147 | same_dir = True 148 | for p in self.points: 149 | if dir != is_right((0, ya), (1, yb), p): # 在两侧 150 | same_dir = False 151 | break 152 | # print(same_dir) 153 | d_old = d 154 | if same_dir: 155 | if not prev_out: 156 | step /= 2 157 | d -= step 158 | prev_out = True # 进一次可以退多步,连续退多步不降step 159 | else: 160 | d += step 161 | prev_out = False 162 | label = np.zeros(self.shape) 163 | for p in self.points: 164 | label[p[0], p[1]] = 1 165 | 166 | ##### 可视化 #### 167 | if args.demo: 168 | plt.title( 169 | f"Step:{step}, d:{d}, Going {'out' if d_old < d else 'in'}." 170 | ) 171 | temp_label = label 172 | temp_label[self.center[0], self.center[1]] = 1 173 | temp_label = (1 - temp_label).reshape([512, 512, 1]) * 255 174 | plt.imshow( 175 | np.tile( 176 | temp_label, 177 | (1, 1, 3), 178 | ) 179 | ) 180 | x = np.linspace(0, 512, 512) 181 | y = [512 - center[1] - d + k * (t - center[0]) for t in x] 182 | plt.plot(x, y) 183 | plt.show() 184 | #### #### 185 | 186 | if abs(step) < acc / 2: 187 | break 188 | return d 189 | 190 | self.diameters.append( 191 | (binary_search(40) - binary_search(-40)) 192 | * np.abs(np.cos(alpha)) 193 | * self.pixdim 194 | ) 195 | # print(self.diameters) 196 | return self.idx, self.height, self.diameters 197 | 198 | 199 | def dist(a, b): 200 | h = a.height - b.height 201 | ca = a.center 202 | cb = b.center 203 | return ((ca[0] - cb[0]) ** 2 + (ca[1] - cb[1]) ** 2 + h ** 2) ** 0.5 204 | 205 | 206 | # print(dist(Polygon([[0, 0]], 0, [0, 0]), Polygon([[0, 0]], 1, [1, 2]))) 207 | 208 | 209 | def cal(polygon): 210 | return polygon.cal_diameters() 211 | 212 | 213 | def cal_diameter(seg_path, filter_arch, dia_dir, thresh=0.9): 214 | """计算seg_path这个nii分割文件的所有管径,返回. 215 | 216 | 过程: 217 | 1. 按照血流反向,获取所有圆 218 | 1.1 按照高度分片层,层内找连通块,可能一块可能两块,计算层中心 219 | 1.2 从最下面的中心开始,找最近的没入序列的中心,对中心按照血流方向反向排序 220 | 221 | 2. 用平行线夹计算血管管径 222 | 223 | 224 | Parameters 225 | ---------- 226 | seg_path : str 227 | 分割标签路径. 228 | 229 | Returns 230 | ------- 231 | type 232 | Description of returned object. 233 | 234 | """ 235 | start = int(time.time()) 236 | segf = nib.load(seg_path) 237 | seg_data = segf.get_fdata() 238 | pixdim = segf.header["pixdim"][1] 239 | seg_data[seg_data > thresh] = 1 240 | seg_data = seg_data.astype("uint8") 241 | print(seg_data.shape) 242 | polygons = [] 243 | for height in range(seg_data.shape[2]): 244 | label = seg_data[:, :, height] # 当前片 245 | label = filters.roberts(label) # 只保留边缘 246 | vol, num = scipy.ndimage.label(label, np.ones([3, 3])) # 联通块 247 | for label_idx in range(1, num + 1): 248 | xs, ys = np.where(vol == label_idx) 249 | points = [] 250 | for x, y in zip(xs, ys): 251 | points.append([int(x), int(y)]) 252 | polygons.append(Polygon(points, height, pixdim, label.shape)) 253 | polygons = [p for p in polygons if len(p.points) != 0] 254 | polygons = blood_sort(polygons) 255 | 256 | # pool = Pool(8) 257 | # diameters = [] 258 | # for res in tqdm(pool.imap_unordered(cal, polygons), total=len(polygons)): 259 | # diameters.append(res) 260 | # pool.close() 261 | # pool.join() 262 | 263 | # 顺序进行 264 | for p in tqdm(polygons): 265 | p.cal_diameters() 266 | 267 | diameters = sorted(diameters, key=lambda x: x[0]) 268 | for d, p in zip(diameters, polygons): 269 | p.diameters = d[2] 270 | 271 | print(os.path.join(dia_dir, seg_path.split("/")[-1])) 272 | f = open( 273 | os.path.join(dia_dir, seg_path.split("/")[-1].rstrip(".gz").rstrip(".nii")) 274 | + ".csv", 275 | "w", 276 | ) 277 | print((int(time.time()) - start) / 60) 278 | print((int(time.time()) - start) / 60, end="\n", file=f) 279 | # for d in diameters: 280 | # print(d[1], end=",", file=f) 281 | # for data in d[2]: 282 | # print(data, end=",", file=f) 283 | # print(file=f) 284 | # 285 | for p in polygons: 286 | print(p.height, end=",", file=f) 287 | # print(p.diameters) 288 | if filter_arch: 289 | if np.max(p.diameters) > np.min(p.diameters) * 2: 290 | continue 291 | for d in p.diameters: 292 | print(d, end=",", file=f) 293 | print(end="\n", file=f) 294 | f.close() 295 | 296 | 297 | if __name__ == "__main__": 298 | names = os.listdir(args.seg_dir) 299 | for name in names: 300 | if os.path.exists(os.path.join(args.dia_dir, name.replace(".nii.gz", ".csv"))): 301 | names.remove(name) 302 | print(names) 303 | print("{} patients to measure in total.".format(len(names))) 304 | 305 | start = int(time.time()) 306 | tot = len(names) 307 | 308 | for count, name in enumerate(names): 309 | cal_diameter(os.path.join(args.seg_dir, name), args.filter_arch, args.dia_dir) 310 | print( 311 | "\t\tFinished {}/{}, expected to finish in {} minutes".format( 312 | count, tot, int(time.time() - start) / count * (tot - count) / 60 313 | ) 314 | ) 315 | -------------------------------------------------------------------------------- /medseg/models/hrnet.py: -------------------------------------------------------------------------------- 1 | # coding: utf8 2 | # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | import sys 20 | 21 | import paddle 22 | import paddle.fluid as fluid 23 | from paddle.fluid.initializer import MSRA 24 | from paddle.fluid.param_attr import ParamAttr 25 | 26 | from utils.config import cfg 27 | 28 | 29 | def conv_bn_layer( 30 | input, filter_size, num_filters, stride=1, padding=1, num_groups=1, if_act=True, name=None 31 | ): 32 | conv = fluid.layers.conv2d( 33 | input=input, 34 | num_filters=num_filters, 35 | filter_size=filter_size, 36 | stride=stride, 37 | padding=(filter_size - 1) // 2, 38 | groups=num_groups, 39 | act=None, 40 | param_attr=ParamAttr(initializer=MSRA(), name=name + "_weights"), 41 | bias_attr=False, 42 | ) 43 | bn_name = name + "_bn" 44 | bn = fluid.layers.batch_norm( 45 | input=conv, 46 | param_attr=ParamAttr(name=bn_name + "_scale", initializer=fluid.initializer.Constant(1.0)), 47 | bias_attr=ParamAttr(name=bn_name + "_offset", initializer=fluid.initializer.Constant(0.0)), 48 | moving_mean_name=bn_name + "_mean", 49 | moving_variance_name=bn_name + "_variance", 50 | ) 51 | if if_act: 52 | bn = fluid.layers.relu(bn) 53 | return bn 54 | 55 | 56 | def basic_block(input, num_filters, stride=1, downsample=False, name=None): 57 | residual = input 58 | conv = conv_bn_layer( 59 | input=input, filter_size=3, num_filters=num_filters, stride=stride, name=name + "_conv1" 60 | ) 61 | conv = conv_bn_layer( 62 | input=conv, filter_size=3, num_filters=num_filters, if_act=False, name=name + "_conv2" 63 | ) 64 | if downsample: 65 | residual = conv_bn_layer( 66 | input=input, 67 | filter_size=1, 68 | num_filters=num_filters, 69 | if_act=False, 70 | name=name + "_downsample", 71 | ) 72 | return fluid.layers.elementwise_add(x=residual, y=conv, act="relu") 73 | 74 | 75 | def bottleneck_block(input, num_filters, stride=1, downsample=False, name=None): 76 | residual = input 77 | conv = conv_bn_layer(input=input, filter_size=1, num_filters=num_filters, name=name + "_conv1") 78 | conv = conv_bn_layer( 79 | input=conv, filter_size=3, num_filters=num_filters, stride=stride, name=name + "_conv2" 80 | ) 81 | conv = conv_bn_layer( 82 | input=conv, filter_size=1, num_filters=num_filters * 4, if_act=False, name=name + "_conv3" 83 | ) 84 | if downsample: 85 | residual = conv_bn_layer( 86 | input=input, 87 | filter_size=1, 88 | num_filters=num_filters * 4, 89 | if_act=False, 90 | name=name + "_downsample", 91 | ) 92 | return fluid.layers.elementwise_add(x=residual, y=conv, act="relu") 93 | 94 | 95 | def fuse_layers(x, channels, multi_scale_output=True, name=None): 96 | out = [] 97 | for i in range(len(channels) if multi_scale_output else 1): 98 | residual = x[i] 99 | shape = residual.shape 100 | width = shape[-1] 101 | height = shape[-2] 102 | for j in range(len(channels)): 103 | if j > i: 104 | y = conv_bn_layer( 105 | x[j], 106 | filter_size=1, 107 | num_filters=channels[i], 108 | if_act=False, 109 | name=name + "_layer_" + str(i + 1) + "_" + str(j + 1), 110 | ) 111 | y = fluid.layers.resize_bilinear(input=y, out_shape=[height, width]) 112 | residual = fluid.layers.elementwise_add(x=residual, y=y, act=None) 113 | elif j < i: 114 | y = x[j] 115 | for k in range(i - j): 116 | if k == i - j - 1: 117 | y = conv_bn_layer( 118 | y, 119 | filter_size=3, 120 | num_filters=channels[i], 121 | stride=2, 122 | if_act=False, 123 | name=name 124 | + "_layer_" 125 | + str(i + 1) 126 | + "_" 127 | + str(j + 1) 128 | + "_" 129 | + str(k + 1), 130 | ) 131 | else: 132 | y = conv_bn_layer( 133 | y, 134 | filter_size=3, 135 | num_filters=channels[j], 136 | stride=2, 137 | name=name 138 | + "_layer_" 139 | + str(i + 1) 140 | + "_" 141 | + str(j + 1) 142 | + "_" 143 | + str(k + 1), 144 | ) 145 | residual = fluid.layers.elementwise_add(x=residual, y=y, act=None) 146 | 147 | residual = fluid.layers.relu(residual) 148 | out.append(residual) 149 | return out 150 | 151 | 152 | def branches(x, block_num, channels, name=None): 153 | out = [] 154 | for i in range(len(channels)): 155 | residual = x[i] 156 | for j in range(block_num): 157 | residual = basic_block( 158 | residual, channels[i], name=name + "_branch_layer_" + str(i + 1) + "_" + str(j + 1) 159 | ) 160 | out.append(residual) 161 | return out 162 | 163 | 164 | def high_resolution_module(x, channels, multi_scale_output=True, name=None): 165 | residual = branches(x, 4, channels, name=name) 166 | out = fuse_layers(residual, channels, multi_scale_output=multi_scale_output, name=name) 167 | return out 168 | 169 | 170 | def transition_layer(x, in_channels, out_channels, name=None): 171 | num_in = len(in_channels) 172 | num_out = len(out_channels) 173 | out = [] 174 | for i in range(num_out): 175 | if i < num_in: 176 | if in_channels[i] != out_channels[i]: 177 | residual = conv_bn_layer( 178 | x[i], 179 | filter_size=3, 180 | num_filters=out_channels[i], 181 | name=name + "_layer_" + str(i + 1), 182 | ) 183 | out.append(residual) 184 | else: 185 | out.append(x[i]) 186 | else: 187 | residual = conv_bn_layer( 188 | x[-1], 189 | filter_size=3, 190 | num_filters=out_channels[i], 191 | stride=2, 192 | name=name + "_layer_" + str(i + 1), 193 | ) 194 | out.append(residual) 195 | return out 196 | 197 | 198 | def stage(x, num_modules, channels, multi_scale_output=True, name=None): 199 | out = x 200 | for i in range(num_modules): 201 | if i == num_modules - 1 and multi_scale_output == False: 202 | out = high_resolution_module( 203 | out, channels, multi_scale_output=False, name=name + "_" + str(i + 1) 204 | ) 205 | else: 206 | out = high_resolution_module(out, channels, name=name + "_" + str(i + 1)) 207 | 208 | return out 209 | 210 | 211 | def layer1(input, name=None): 212 | conv = input 213 | for i in range(4): 214 | conv = bottleneck_block( 215 | conv, num_filters=64, downsample=True if i == 0 else False, name=name + "_" + str(i + 1) 216 | ) 217 | return conv 218 | 219 | 220 | def high_resolution_net(input, num_classes): 221 | 222 | channels_2 = cfg.MODEL.HRNET.STAGE2.NUM_CHANNELS 223 | channels_3 = cfg.MODEL.HRNET.STAGE3.NUM_CHANNELS 224 | channels_4 = cfg.MODEL.HRNET.STAGE4.NUM_CHANNELS 225 | 226 | num_modules_2 = cfg.MODEL.HRNET.STAGE2.NUM_MODULES 227 | num_modules_3 = cfg.MODEL.HRNET.STAGE3.NUM_MODULES 228 | num_modules_4 = cfg.MODEL.HRNET.STAGE4.NUM_MODULES 229 | 230 | x = conv_bn_layer( 231 | input=input, filter_size=3, num_filters=64, stride=2, if_act=True, name="layer1_1" 232 | ) 233 | x = conv_bn_layer( 234 | input=x, filter_size=3, num_filters=64, stride=2, if_act=True, name="layer1_2" 235 | ) 236 | 237 | la1 = layer1(x, name="layer2") 238 | tr1 = transition_layer([la1], [256], channels_2, name="tr1") 239 | st2 = stage(tr1, num_modules_2, channels_2, name="st2") 240 | tr2 = transition_layer(st2, channels_2, channels_3, name="tr2") 241 | st3 = stage(tr2, num_modules_3, channels_3, name="st3") 242 | tr3 = transition_layer(st3, channels_3, channels_4, name="tr3") 243 | st4 = stage(tr3, num_modules_4, channels_4, name="st4") 244 | 245 | # upsample 246 | shape = st4[0].shape 247 | height, width = shape[-2], shape[-1] 248 | st4[1] = fluid.layers.resize_bilinear(st4[1], out_shape=[height, width]) 249 | st4[2] = fluid.layers.resize_bilinear(st4[2], out_shape=[height, width]) 250 | st4[3] = fluid.layers.resize_bilinear(st4[3], out_shape=[height, width]) 251 | 252 | out = fluid.layers.concat(st4, axis=1) 253 | last_channels = sum(channels_4) 254 | 255 | out = conv_bn_layer( 256 | input=out, filter_size=1, num_filters=last_channels, stride=1, if_act=True, name="conv-2" 257 | ) 258 | out = fluid.layers.conv2d( 259 | input=out, 260 | num_filters=num_classes, 261 | filter_size=1, 262 | stride=1, 263 | padding=0, 264 | act=None, 265 | param_attr=ParamAttr(initializer=MSRA(), name="conv-1_weights"), 266 | bias_attr=False, 267 | ) 268 | 269 | out = fluid.layers.resize_bilinear(out, input.shape[2:]) 270 | 271 | return out 272 | 273 | 274 | def hrnet(input, num_classes): 275 | logit = high_resolution_net(input, num_classes) 276 | return logit 277 | 278 | 279 | if __name__ == "__main__": 280 | image_shape = [-1, 3, 769, 769] 281 | image = fluid.data(name="image", shape=image_shape, dtype="float32") 282 | logit = hrnet(image, 4) 283 | print("logit:", logit.shape) 284 | -------------------------------------------------------------------------------- /medseg/models/deeplabv3p.py: -------------------------------------------------------------------------------- 1 | # Deeplabv3p Network is modified from the following link 2 | # https://github.com/PaddlePaddle/models/blob/develop/fluid/PaddleCV/deeplabv3%2B/models.py 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | import paddle.fluid as fluid 8 | 9 | import contextlib 10 | name_scope = "" 11 | 12 | decode_channel = 32 13 | encode_channel = 160 14 | 15 | bn_momentum = 0.99 16 | 17 | op_results = {} 18 | 19 | default_epsilon = 1e-3 20 | default_norm_type = 'bn' 21 | default_group_number = 32 22 | depthwise_use_cudnn = True 23 | 24 | bn_regularizer = fluid.regularizer.L2DecayRegularizer(regularization_coeff=0.0) 25 | depthwise_regularizer = fluid.regularizer.L2DecayRegularizer( 26 | regularization_coeff=0.0) 27 | 28 | 29 | @contextlib.contextmanager 30 | def scope(name): 31 | global name_scope 32 | bk = name_scope 33 | name_scope = name_scope + name + '/' 34 | yield 35 | name_scope = bk 36 | 37 | 38 | def check(data, number): 39 | if type(data) == int: 40 | return [data] * number 41 | assert len(data) == number 42 | return data 43 | 44 | 45 | def clean(): 46 | global op_results 47 | op_results = {} 48 | 49 | 50 | def append_op_result(result, name): 51 | global op_results 52 | op_index = len(op_results) 53 | name = name_scope + name + str(op_index) 54 | op_results[name] = result 55 | return result 56 | 57 | 58 | def conv(*args, **kargs): 59 | if "xception" in name_scope: 60 | init_std = 0.09 61 | elif "logit" in name_scope: 62 | init_std = 0.01 63 | elif name_scope.endswith('depthwise/'): 64 | init_std = 0.33 65 | else: 66 | init_std = 0.06 67 | if name_scope.endswith('depthwise/'): 68 | regularizer = depthwise_regularizer 69 | else: 70 | regularizer = None 71 | 72 | kargs['param_attr'] = fluid.ParamAttr( 73 | name=name_scope + 'weights', 74 | regularizer=regularizer, 75 | initializer=fluid.initializer.TruncatedNormal( 76 | loc=0.0, scale=init_std)) 77 | if 'bias_attr' in kargs and kargs['bias_attr']: 78 | kargs['bias_attr'] = fluid.ParamAttr( 79 | name=name_scope + 'biases', 80 | regularizer=regularizer, 81 | initializer=fluid.initializer.ConstantInitializer(value=0.0)) 82 | else: 83 | kargs['bias_attr'] = False 84 | kargs['name'] = name_scope + 'conv' 85 | return append_op_result(fluid.layers.conv2d(*args, **kargs), 'conv') 86 | 87 | 88 | def group_norm(input, G, eps=1e-5, param_attr=None, bias_attr=None): 89 | N, C, H, W = input.shape 90 | if C % G != 0: 91 | # print "group can not divide channle:", C, G 92 | for d in range(10): 93 | for t in [d, -d]: 94 | if G + t <= 0: continue 95 | if C % (G + t) == 0: 96 | G = G + t 97 | break 98 | if C % G == 0: 99 | # print "use group size:", G 100 | break 101 | assert C % G == 0 102 | x = fluid.layers.group_norm( 103 | input, 104 | groups=G, 105 | param_attr=param_attr, 106 | bias_attr=bias_attr, 107 | name=name_scope + 'group_norm') 108 | return x 109 | 110 | 111 | def bn(*args, **kargs): 112 | if default_norm_type == 'bn': 113 | with scope('BatchNorm'): 114 | return append_op_result( 115 | fluid.layers.batch_norm( 116 | *args, 117 | epsilon=default_epsilon, 118 | momentum=bn_momentum, 119 | param_attr=fluid.ParamAttr( 120 | name=name_scope + 'gamma', regularizer=bn_regularizer), 121 | bias_attr=fluid.ParamAttr( 122 | name=name_scope + 'beta', regularizer=bn_regularizer), 123 | moving_mean_name=name_scope + 'moving_mean', 124 | moving_variance_name=name_scope + 'moving_variance', 125 | **kargs), 126 | 'bn') 127 | elif default_norm_type == 'gn': 128 | with scope('GroupNorm'): 129 | return append_op_result( 130 | group_norm( 131 | args[0], 132 | default_group_number, 133 | eps=default_epsilon, 134 | param_attr=fluid.ParamAttr( 135 | name=name_scope + 'gamma', regularizer=bn_regularizer), 136 | bias_attr=fluid.ParamAttr( 137 | name=name_scope + 'beta', regularizer=bn_regularizer)), 138 | 'gn') 139 | else: 140 | raise "Unsupport norm type:" + default_norm_type 141 | 142 | 143 | def bn_relu(data): 144 | return append_op_result(fluid.layers.relu(bn(data)), 'relu') 145 | 146 | 147 | def relu(data): 148 | return append_op_result( 149 | fluid.layers.relu( 150 | data, name=name_scope + 'relu'), 'relu') 151 | 152 | 153 | def seperate_conv(input, channel, stride, filter, dilation=1, act=None): 154 | with scope('depthwise'): 155 | input = conv( 156 | input, 157 | input.shape[1], 158 | filter, 159 | stride, 160 | groups=input.shape[1], 161 | padding=(filter // 2) * dilation, 162 | dilation=dilation, 163 | use_cudnn=depthwise_use_cudnn) 164 | input = bn(input) 165 | if act: input = act(input) 166 | with scope('pointwise'): 167 | input = conv(input, channel, 1, 1, groups=1, padding=0) 168 | input = bn(input) 169 | if act: input = act(input) 170 | return input 171 | 172 | 173 | def xception_block(input, 174 | channels, 175 | strides=1, 176 | filters=3, 177 | dilation=1, 178 | skip_conv=True, 179 | has_skip=True, 180 | activation_fn_in_separable_conv=False): 181 | repeat_number = 3 182 | channels = check(channels, repeat_number) 183 | filters = check(filters, repeat_number) 184 | strides = check(strides, repeat_number) 185 | data = input 186 | results = [] 187 | for i in range(repeat_number): 188 | with scope('separable_conv' + str(i + 1)): 189 | if not activation_fn_in_separable_conv: 190 | data = relu(data) 191 | data = seperate_conv( 192 | data, 193 | channels[i], 194 | strides[i], 195 | filters[i], 196 | dilation=dilation) 197 | else: 198 | data = seperate_conv( 199 | data, 200 | channels[i], 201 | strides[i], 202 | filters[i], 203 | dilation=dilation, 204 | act=relu) 205 | results.append(data) 206 | if not has_skip: 207 | return append_op_result(data, 'xception_block'), results 208 | if skip_conv: 209 | with scope('shortcut'): 210 | skip = bn( 211 | conv( 212 | input, channels[-1], 1, strides[-1], groups=1, padding=0)) 213 | else: 214 | skip = input 215 | return append_op_result(data + skip, 'xception_block'), results 216 | 217 | 218 | def entry_flow(data): 219 | with scope("entry_flow"): 220 | with scope("conv1"): 221 | data = conv(data, 32, 3, stride=2, padding=1) 222 | data = bn_relu(data) 223 | with scope("conv2"): 224 | data = conv(data, 32, 3, stride=1, padding=1) 225 | data = bn_relu(data) 226 | with scope("block1"): 227 | data, results1 = xception_block(data, 64, [1, 1, 2]) 228 | with scope("block2"): 229 | data, results2 = xception_block(data, 128, [1, 1, 2]) 230 | with scope("block3"): 231 | data, _ = xception_block(data, 256, [1, 1, 2]) 232 | return data, results1[1], results2[2] 233 | 234 | 235 | def middle_flow(data): 236 | with scope("middle_flow"): 237 | for i in range(8): 238 | with scope("block" + str(i + 1)): 239 | data, _ = xception_block(data, 256, [1, 1, 1], skip_conv=False) 240 | return data 241 | 242 | 243 | def exit_flow(data): 244 | with scope("exit_flow"): 245 | with scope('block1'): 246 | data, _ = xception_block(data, [256, 512, 512], [1, 1, 1]) 247 | with scope('block2'): 248 | data, _ = xception_block( 249 | data, [512, 512, 768], [1, 1, 1], 250 | dilation=2, 251 | has_skip=False, 252 | activation_fn_in_separable_conv=True) 253 | return data 254 | 255 | 256 | def encoder(input): 257 | with scope('encoder'): 258 | channel = 192 259 | 260 | with scope("aspp0"): 261 | aspp0 = bn_relu(conv(input, channel, 1, 1, groups=1, padding=0)) 262 | with scope("aspp1"): 263 | aspp1 = seperate_conv(input, channel, 1, 3, dilation=3, act=relu) 264 | with scope("aspp2"): 265 | aspp2 = seperate_conv(input, channel, 1, 3, dilation=6, act=relu) 266 | with scope("aspp3"): 267 | aspp3 = seperate_conv(input, channel, 1, 3, dilation=12, act=relu) 268 | with scope("concat"): 269 | data = append_op_result( 270 | fluid.layers.concat( 271 | [aspp0, aspp1, aspp2, aspp3], axis=1), 272 | 'concat') 273 | data = bn_relu(conv(data, channel, 1, 1, groups=1, padding=0)) 274 | return data 275 | 276 | 277 | def decoder(encode_data, decode_shortcut): 278 | with scope('decoder'): 279 | with scope('concat'): 280 | decode_shortcut = bn_relu( 281 | conv( 282 | decode_shortcut, decode_channel, 1, 1, groups=1, padding=0)) 283 | encode_data = fluid.layers.resize_bilinear( 284 | encode_data, decode_shortcut.shape[2:]) 285 | encode_data = fluid.layers.concat( 286 | [encode_data, decode_shortcut], axis=1) 287 | append_op_result(encode_data, 'concat') 288 | with scope("separable_conv1"): 289 | encode_data = seperate_conv( 290 | encode_data, encode_channel, 1, 3, dilation=1, act=relu) 291 | with scope("separable_conv2"): 292 | encode_data = seperate_conv( 293 | encode_data, encode_channel, 1, 3, dilation=1, act=relu) 294 | return encode_data 295 | 296 | def decoder2(encode_data, decode_shortcut): 297 | with scope('decoder2'): 298 | with scope('concat2'): 299 | decode_shortcut = bn_relu( 300 | conv( 301 | decode_shortcut, decode_channel // 2, 1, 1, groups=1, padding=0)) 302 | encode_data = fluid.layers.resize_bilinear( 303 | encode_data, decode_shortcut.shape[2:]) 304 | encode_data = fluid.layers.concat( 305 | [encode_data, decode_shortcut], axis=1) 306 | append_op_result(encode_data, 'concat2') 307 | with scope("separable_conv12"): 308 | encode_data = seperate_conv( 309 | encode_data, encode_channel // 2, 1, 3, dilation=1, act=relu) 310 | with scope("separable_conv22"): 311 | encode_data = seperate_conv( 312 | encode_data, encode_channel // 2, 1, 3, dilation=1, act=relu) 313 | return encode_data 314 | 315 | 316 | def deeplabv3p(img, label_number): 317 | global default_epsilon 318 | append_op_result(img, 'img') 319 | with scope('xception_65'): 320 | default_epsilon = 1e-3 321 | # Entry flow 322 | data, decode_shortcut1, decode_shortcut2 = entry_flow(img) 323 | print(data.shape) 324 | # Middle flow 325 | data = middle_flow(data) 326 | print(data.shape) 327 | # Exit flow 328 | data = exit_flow(data) 329 | print(data.shape) 330 | default_epsilon = 1e-5 331 | encode_data = encoder(data) 332 | print(encode_data.shape) 333 | encode_data = decoder(encode_data, decode_shortcut2) 334 | print(encode_data.shape) 335 | encode_data = fluid.layers.resize_bilinear(encode_data, (encode_data.shape[2] * 2, encode_data.shape[3] * 2)) 336 | encode_data = decoder2(encode_data, decode_shortcut1) 337 | print(encode_data.shape) 338 | with scope('logit'): 339 | logit = conv( 340 | encode_data, label_number, 1, stride=1, padding=0, bias_attr=True) 341 | logit = fluid.layers.resize_bilinear(logit, img.shape[2:]) 342 | # logit = fluid.layers.resize_bilinear(logit, (3384, 1020)) 343 | print(logit.shape) 344 | return logit -------------------------------------------------------------------------------- /medseg/utils/util.py: -------------------------------------------------------------------------------- 1 | """ 2 | 包含环境变量和常用函数 3 | """ 4 | import os 5 | import sys 6 | import numpy as np 7 | import math 8 | from scipy import ndimage 9 | 10 | # from utils.config import cfg 11 | 12 | # TODO: 清理函数,去除不需要的,对需要的加上清晰注释 13 | 14 | 15 | def listdir(path): 16 | """展示一个路径下所有文件,排序并去除常见辅助文件. 17 | 18 | Parameters 19 | ---------- 20 | path : type 21 | Description of parameter `path`. 22 | 23 | Returns 24 | ------- 25 | type 26 | Description of returned object. 27 | 28 | """ 29 | dirs = os.listdir(path) 30 | if ".DS_Store" in dirs: 31 | dirs.remove(".DS_Store") 32 | dirs.sort() # 通过一样的sort保持vol和seg的对应 33 | return dirs 34 | 35 | 36 | def save_info(name, header, file_name): 37 | """将扫描header中的一些信息保存入csv. 38 | 39 | Parameters 40 | ---------- 41 | name : str 42 | 扫描文件名. 43 | header : dict 44 | 扫描文件文件头. 45 | file_name : str 46 | csv文件的文件名. 47 | 48 | Returns 49 | ------- 50 | type 51 | Description of returned object. 52 | 53 | """ 54 | 55 | """ 56 | sizeof_hdr : 348 57 | data_type : b'' 58 | db_name : b'' 59 | extents : 0 60 | session_error : 0 61 | regular : b'r' 62 | dim_info : 0 63 | dim : [ 3 512 512 75 1 1 1 1] 64 | intent_p1 : 0.0 65 | intent_p2 : 0.0 66 | intent_p3 : 0.0 67 | intent_code : none 68 | datatype : int16 69 | bitpix : 16 70 | slice_start : 0 71 | pixdim : [-1.00000e+00 7.03125e-01 7.03125e-01 5.00000e+00 0.00000e+00 72 | 1.00000e+00 1.00000e+00 5.22410e+04] 73 | vox_offset : 0.0 74 | scl_slope : nan 75 | scl_inter : nan 76 | slice_end : 0 77 | slice_code : unknown 78 | xyzt_units : 10 79 | cal_max : 0.0 80 | cal_min : 0.0 81 | slice_duration : 0.0 82 | toffset : 0.0 83 | glmax : 255 84 | glmin : 0 85 | descrip : b'TE=0;sec=52241.0000;name=' 86 | aux_file : b'!62ABDOMENNATIVUNDVENS' 87 | qform_code : scanner 88 | sform_code : scanner 89 | quatern_b : 0.0 90 | quatern_c : 1.0 91 | quatern_d : 0.0 92 | qoffset_x : 172.9 93 | qoffset_y : -179.29688 94 | qoffset_z : -368.0 95 | srow_x : [ -0.703125 0. 0. 172.9 ] 96 | srow_y : [ 0. 0.703125 0. -179.29688 ] 97 | srow_z : [ 0. 0. 5. -368.] 98 | intent_name : b'' 99 | magic : b'n+1'`` 100 | """ 101 | file = open(file_name, "a+") 102 | print(name, end=", ", file=file) 103 | print( 104 | header["dim"][1], ",", header["dim"][2], ",", header["dim"][3], end=", ", file=file, 105 | ) 106 | print( 107 | header["pixdim"][1], ",", header["pixdim"][2], ",", header["pixdim"][3], end=", ", file=file, 108 | ) 109 | print(header["bitpix"], " , ", header["datatype"], file=file) 110 | file.close() 111 | 112 | 113 | """ 114 | import nibabel as nib 115 | volf = nib.load('/home/aistudio/data/volume/volume-0.nii') 116 | save_info('v1', volf.header.structarr, 'vol_info.csv') 117 | """ 118 | 119 | 120 | """ 体数据处理 """ 121 | 122 | 123 | def windowlize_image(vol, wwwc): 124 | """对扫描按照wwwc进行硬crop. 125 | 126 | Parameters 127 | ---------- 128 | vol : ndarray 129 | 需要进行窗口化的扫描 130 | ww : int 131 | 窗宽 132 | wc : int 133 | 窗位 134 | 135 | Returns 136 | ------- 137 | ndarray 138 | 经过窗口化的扫描 139 | """ 140 | ww = wwwc[0] 141 | wc = wwwc[1] 142 | wl = wc - ww / 2 143 | wh = wc + ww / 2 144 | vol = vol.clip(wl, wh) 145 | return vol 146 | 147 | 148 | def clip_label(label, category): 149 | # 有时候标签会包含多种标注,一般是0背景,从1开始随着数变大标记的东西变小 150 | # label是ndarray,category是最后成为1的类别号,max是最大的类别号 151 | label[label < category] = 0 152 | label[label >= category] = 1 153 | return label 154 | 155 | 156 | def get_bbs(label): 157 | """求一个标签中所有前景的bb. 158 | 159 | Parameters 160 | ---------- 161 | label : ndarray 162 | 标签. 163 | 164 | Returns 165 | ------- 166 | list 167 | list中每一个前景区域一个[bb_min, bb_max],分别是这个前景块bb低和高两个角的坐标. 168 | 169 | """ 170 | # TODO: 目前实现了一个病灶,需要实现多个 171 | one_indexes = np.array(np.where(label == 1)) 172 | if one_indexes.ndim == 0: 173 | raise Exception("label中没有任何前景") 174 | 175 | bb_min = one_indexes.min(axis=1) 176 | bb_max = one_indexes.max(axis=1) 177 | bb_max = bb_max + 1 178 | return bb_min.reshape(-1, 3), bb_max.reshape(-1, 3) 179 | 180 | 181 | def crop_to_bbs(volume, bbs, padding=0.3): 182 | """将一个扫描的背景mute掉,只留下前景及其周围的区域,支持多个前景块. 183 | 具体做法是创建一个mask,对于bbs中的每个前景块,计算中心位置,按照padding计算保留的块范围(不会超出volume),在mask中设成1。所有块都计算完之后mute掉mask中还是0的所有位置 184 | Parameters 185 | ---------- 186 | volume : ndarray 187 | 扫描. 188 | bbs : list 189 | [[bb_min, bb_max], [bb_min, bb_max], ...]. 190 | padding : 191 | Description of parameter `padding`. 192 | 193 | Returns 194 | ------- 195 | type 196 | Description of returned object. 197 | 198 | """ 199 | 200 | # 将一个体切成一个或者多个包含1的区域的bb 201 | # padding 值是在各个维度上向大和小分别拓展多大的视野,一个数就是都一样,列表可以让不同维度不一样 202 | pd = padding 203 | if isinstance(padding, float): 204 | padding = [] 205 | for i in range(volume.ndim): 206 | padding.append(pd) 207 | 208 | volumes = [] 209 | bb_size = bb_max - bb_min 210 | bb_min = np.maximum(np.floor(bb_min - bb_size * padding), 0).astype("int32") 211 | bb_max = np.minimum(np.ceil(bb_max + bb_size * padding), volume.shape).astype("int32") 212 | 213 | for i in range(bb_min.shape[0]): 214 | volumes.append( 215 | volume[bb_min[i][0] : bb_max[i][0], bb_min[i][1] : bb_max[i][1], bb_min[i][2] : bb_max[i][2],] 216 | ) 217 | return volumes 218 | 219 | 220 | def get_pad_len(volume_shape, pad_size, strict=True): 221 | # 1. 计算每个维度当前长度和目标差多少 222 | margin = [] 223 | for x, y in zip(volume_shape, pad_size): 224 | # 1.1 如果目标 -1 ,那这个维度过 225 | if y == -1: 226 | margin.append(0) 227 | continue 228 | # 1.2 如果当前长度比目标长度还大,报错或者过 229 | if x > y: 230 | if strict: 231 | raise Exception( 232 | "Invalid Crop Size", "数据的大小 {} 应小于 pad_size {}".format(volume_shape, pad_size), 233 | ) 234 | else: 235 | margin.append(0) 236 | continue 237 | # 1.3 如果正常,目标大于当前维度,做差 238 | margin.append(y - x) 239 | # 2. 计算每个维度应该补多少 240 | res = [] 241 | for m, p, v in zip(margin, pad_size, volume_shape): 242 | if m == 0: 243 | # 2.1 margin = 0的略过 244 | res.append([0, 0]) 245 | else: 246 | # 2.2 margin分成两份 247 | half = math.floor(m / 2) 248 | res.append([half, p - v - half]) 249 | 250 | return res 251 | 252 | 253 | # print(get_pad_len([3, 512, 300], [3, -1, 512])) 254 | # print(get_pad_len([512, 512, 3], [512, 512, 3])) 255 | # print(get_pad_len([8, 512, 300], [4, -1, 512])) 256 | # print(get_pad_len([8, 512, 300], [4, -1, 512], False)) 257 | 258 | 259 | def pad_volume(volume, pad_size, pad_value=0, strice=True): 260 | """将volume放在中间,用 pad_value 填充到 pad_size 大小 261 | 每个维度一共包含三种情况: 262 | 1. 正常: pad_size大于实际大小,那就计算差多少,在这个维度的两侧均匀的补上 263 | 2. 忽略: 不希望改变这个维度的大小,pad_size这个维度填 -1 264 | 3. 错误: volume的大小比 pad_size 还大,在 strice=true 模式下这个报错,终止执行;strice=false模式下这个维度忽略 265 | 266 | Parameters 267 | ---------- 268 | volume : type 269 | Description of parameter `volume`. 270 | pad_size : int/list/tuple 271 | 如果是一个int,那么做成一个和volume.ndim维的list,三个维度的大小一样,按照这个pad;如果是list,tuple直接按照这个pad 272 | pad_value : type 273 | Description of parameter `pad_value`. 274 | strice : type 275 | Description of parameter `strice`. 276 | 277 | Returns 278 | ------- 279 | type 280 | 经过pad的数据 281 | 282 | """ 283 | if isinstance(pad_size, int): 284 | pad_size = [pad_size for i in range(volume.ndim)] 285 | margin = get_pad_len(volume.shape, pad_size, strice) 286 | # print(margin) 287 | volume = np.pad(volume, margin, "constant", constant_values=(pad_value)) 288 | # print(volume.shape) 289 | return volume 290 | 291 | 292 | def filter_largest_bb(label, ratio=1.2): 293 | """求最大的连通块bb范围,去掉范围外的所有fp。比只保留最大的连通块更保守. 294 | 295 | Parameters 296 | ---------- 297 | label : ndarray 298 | 分割标签,前景为1. 299 | ratio: float 300 | 最终bb范围是最大联通块bb范围的多少倍,比如1.2相当于周围有0.1的拓展 301 | 302 | Returns 303 | ------- 304 | type 305 | 经过处理的标签. 306 | 307 | """ 308 | # 求最大连通块 309 | vol, num = ndimage.label(label, np.ones([3 for _ in range(label.ndim)])) 310 | maxi = 0 311 | maxnum = 0 312 | for i in range(1, num + 1): 313 | count = vol[vol == i].size 314 | if count > maxnum: 315 | maxi = i 316 | maxnum = count 317 | maxind = np.where(vol == maxi) 318 | # 求最大连通块的bb范围 319 | ind_range = [[np.min(maxind[axis]), np.max(maxind[axis]) + 1] for axis in range(label.ndim)] 320 | ind_len = [r[1] - r[0] for r in ind_range] 321 | ext_ratio = (ratio - 1) / 2 322 | # 求加上拓展的边缘 323 | clip_range = [[r[0] - int(l * ext_ratio), r[1] + int(l * ext_ratio)] for r, l in zip(ind_range, ind_len)] 324 | for ind in range(len(clip_range)): 325 | if clip_range[ind][0] < 0: 326 | clip_range[ind][0] = 0 327 | if clip_range[ind][1] > label.shape[ind]: 328 | clip_range[ind][1] = label.shape[ind] 329 | r = clip_range 330 | # print(r) 331 | # 去掉拓展外的fp 332 | new_lab = np.zeros(label.shape) 333 | # 内部所有前景都保留 334 | new_lab[r[0][0] : r[0][1], r[1][0] : r[1][1], r[2][0] : r[2][1]] = label[ 335 | r[0][0] : r[0][1], r[1][0] : r[1][1], r[2][0] : r[2][1] 336 | ] 337 | return new_lab 338 | 339 | 340 | # filter_largest_bb( 341 | # np.array([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 1, 1, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 1]]), 3 342 | # ) 343 | 344 | 345 | def filter_largest_volume(label, ratio=1.2, mode="soft"): 346 | """对输入的一个3D标签进行处理,只保留其中最大的连通块 347 | 348 | Parameters 349 | ---------- 350 | label : ndarray 351 | 3D array:一个分割标签 352 | ratio : float 353 | 分割保留的范围 354 | mode : str 355 | "soft" / "hard" 356 | hard是只保留最大的联通块,soft是保留最大连通块bb内的 357 | 358 | Returns 359 | ------- 360 | type 361 | 只保留最大连通块的标签. 362 | 363 | """ 364 | if mode == "soft": 365 | return filter_largest_bb(label, ratio) 366 | vol, num = ndimage.label(label, np.ones([3, 3, 3])) 367 | maxi = 0 368 | maxnum = 0 369 | for i in range(1, num + 1): 370 | count = vol[vol == i].size 371 | if count > maxnum: 372 | maxi = i 373 | maxnum = count 374 | 375 | vol[vol != maxi] = 0 376 | vol[vol == maxi] = 1 377 | label = vol 378 | return label 379 | 380 | 381 | # vol = np.array([[[0, 0, 1, 0], [0, 0, 0, 0], [0, 1, 1, 0]]]) 382 | # vol = np.array([[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 1, 1, 1, 0], [0, 0, 1, 0, 0], [0, 0, 0, 0, 1]]]) 383 | # print(filter_largest_volume(vol, 3, mode="soft")) 384 | 385 | 386 | def save_nii(vol, lab, name="test"): 387 | import nibabel as nib 388 | 389 | vol = vol.astype("int16") 390 | volf = nib.Nifti1Image(vol, np.eye(4)) 391 | labf = nib.Nifti1Image(lab, np.eye(4)) 392 | nib.save(volf, "/home/aistudio/data/temp/{}-vol.nii".format(name)) 393 | nib.save(labf, "/home/aistudio/data/temp/{}-lab.nii".format(name)) 394 | 395 | 396 | def slice_count(): 397 | """数所有的npz总共包含多少slice. 398 | 399 | Returns 400 | ------- 401 | int 402 | 所有npz中slice总数. 403 | 404 | """ 405 | tot = 0 406 | npz_names = listdir(cfg.TRAIN.DATA_PATH) 407 | for npz_name in npz_names: 408 | data = np.load(os.path.join(cfg.TRAIN.DATA_PATH, npz_name)) 409 | lab = data["labs"] 410 | tot += lab.shape[0] 411 | return tot 412 | 413 | 414 | # print(slice_count()) 415 | 416 | 417 | def cal_direction(fname, scan, label): 418 | """根据预存信息矫正患者体位. 419 | 420 | Parameters 421 | ---------- 422 | fname : str 423 | 患者文件名. 424 | scan : ndarray 425 | 3D扫描. 426 | label : type 427 | 3D标签. 428 | 429 | Returns 430 | ------- 431 | ndarray, ndarray 432 | 校准后的3D数组. 433 | 434 | """ 435 | f = open("./config/directions.csv") 436 | dirs = f.readlines() 437 | f.close() 438 | # print("dirs: ", dirs) 439 | dirs = [x.rstrip("\n") for x in dirs] 440 | dirs = [x.split(",") for x in dirs] 441 | dic = {} 442 | for dir in dirs: 443 | dic[dir[0].strip()] = dir[1].strip() 444 | dirs = dic 445 | try: 446 | if dirs[fname] == "2": 447 | scan = np.rot90(scan, 3) 448 | label = np.rot90(label, 3) 449 | else: 450 | scan = np.rot90(scan, 1) 451 | label = np.rot90(label, 1) 452 | except KeyError: 453 | pass 454 | return scan, label 455 | -------------------------------------------------------------------------------- /tool/infer/util.bk: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import math 4 | import concurrent.futures 5 | import multiprocessing 6 | 7 | import nibabel as nib 8 | import numpy as np 9 | import cv2 10 | import scipy.ndimage 11 | from tqdm import tqdm 12 | import matplotlib.pyplot as plt 13 | import trimesh 14 | from pypinyin import pinyin, Style 15 | 16 | 17 | logging.basicConfig(level=logging.NOTSET) 18 | 19 | 20 | def filter_largest_volume(label, ratio=1.2, mode="soft"): 21 | """对输入的一个3D标签进行处理,只保留其中最大的连通块 22 | 23 | Parameters 24 | ---------- 25 | label : ndarray 26 | 3D array:一个分割标签 27 | ratio : float 28 | 分割保留的范围 29 | mode : str 30 | "soft" / "hard" 31 | hard是只保留最大的联通块,soft是保留最大连通块bb内的 32 | 33 | Returns 34 | ------- 35 | type 36 | 只保留最大连通块的标签. 37 | 38 | """ 39 | if mode == "soft": 40 | return filter_largest_bb(label, ratio) 41 | vol, num = scipy.ndimage.label(label, np.ones([3, 3, 3])) 42 | maxi = 0 43 | maxnum = 0 44 | for i in range(1, num + 1): 45 | count = vol[vol == i].size 46 | if count > maxnum: 47 | maxi = i 48 | maxnum = count 49 | 50 | vol[vol != maxi] = 0 51 | vol[vol == maxi] = 1 52 | label = vol 53 | return label 54 | 55 | 56 | labels = [] 57 | 58 | # TODO: 添加clip到一个前景类型的功能 59 | def nii2png(scan_path, scan_img_dir, label_path=None, label_img_dir=None, rot=0, wwwc=(400, 0), thresh=None): 60 | """将nii格式的扫描转成png. 61 | 扫描和标签一起处理,支持窗口化,旋转,略过没有前景的片 62 | 63 | Parameters 64 | ---------- 65 | scan_path : str 66 | 扫描nii路径. 67 | scan_img_dir : str 68 | 扫描生成png放到这. 69 | label_path : str 70 | 标签nii路径. 71 | label_img_dir : str 72 | 标签生成png放到这. 73 | rot : int 74 | 进行几次旋转,如果有标签会一起. 75 | wwwc : list/tuple 76 | 进行窗口化的窗宽窗位. 77 | thresh : int 78 | 标签中前景数量达到这个数才生成png,否则略过. 79 | 80 | Returns 81 | ------- 82 | type 83 | Description of returned object. 84 | 85 | """ 86 | scanf = nib.load(scan_path) 87 | scan_data = scanf.get_fdata() 88 | name = os.path.basename(scan_path) 89 | # print(name) 90 | if scan_data.shape[0] == 1024: 91 | print("[WARNNING]", name, "is 1024") 92 | # vol = scipy.ndimage.interpolation.zoom(vol, [0.5, 0.5, 1], order=1 if islabel else 3) 93 | 94 | if label_path: 95 | labelf = nib.load(label_path) 96 | label_data = labelf.get_fdata() 97 | if label_data.shape != scan_data.shape: 98 | print("[ERROR] Scan and image dimension mismatch", name, scan_data.shape, label_data.shape) 99 | 100 | for _ in range(rot): 101 | scan_data = np.rot90(scan_data) 102 | if label_path: 103 | label_data = np.rot90(label_data) 104 | 105 | if not os.path.exists(scan_img_dir): 106 | os.makedirs(scan_img_dir) 107 | if label_path and not os.path.exists(label_img_dir): 108 | os.makedirs(label_img_dir) 109 | 110 | wl, wh = (wwwc[1] - wwwc[0] / 2, wwwc[1] + wwwc[0] / 2) 111 | scan_data = scan_data.astype("float32").clip(wl, wh) 112 | scan_data = (scan_data - wl) / (wh - wl) * 256 113 | scan_data = scan_data.astype("uint8") 114 | 115 | with concurrent.futures.ThreadPoolExecutor(max_workers=multiprocessing.cpu_count()) as executor: 116 | for ind in range(1, scan_data.shape[2] - 1): 117 | if label_path: 118 | label_slice = label_data[:, :, ind] 119 | # 如果进行thresh限制而且这一片不达到这个数,那么就直接跳过,label和scan都不存 120 | if thresh and label_slice.sum() < thresh: 121 | continue 122 | file_path = os.path.join(label_img_dir, "{}-{}.png".format(name.rstrip(".gz").rstrip(".nii"), ind)) 123 | executor.submit(save_png, label_slice, file_path) 124 | 125 | scan_slice = scan_data[:, :, ind - 1 : ind + 2] 126 | file_path = os.path.join(scan_img_dir, "{}-{}.png".format(name.rstrip(".gz").rstrip(".nii"), ind)) 127 | executor.submit(save_png, scan_slice, file_path) 128 | if label_path: 129 | label_slice = label_data[:, :, ind] 130 | file_path = os.path.join(label_img_dir, "{}-{}.png".format(name.rstrip(".gz").rstrip(".nii"), ind)) 131 | executor.submit(save_png, label_slice, file_path) 132 | 133 | # input("here") 134 | 135 | 136 | def save_png(slice, file_path): 137 | cv2.imwrite(file_path, slice) 138 | 139 | 140 | def nii2png_single(nii_path, png_folder, rot=1, wwwl=(256, 0), islabel=False, thresh=0): 141 | """将一个nii扫描转换成一系列图片,并进行简单的检查. 142 | # TODO: 检查是否只有一个连通块 143 | # TODO: 检查是否只有一种前景 144 | 145 | Parameters 146 | ---------- 147 | nii_path : str 148 | nii扫描文件的路径. 149 | png_path : type 150 | 图片存在哪个文件夹. 151 | rot : int 152 | 扫描进行几次旋转,摆正体位. 153 | wwwl : (int, int) 154 | 窗宽窗位. 155 | """ 156 | volf = nib.load(nii_path) 157 | nii_name = os.path.basename(nii_path) 158 | vol = volf.get_fdata() 159 | if vol.shape[0] == 1024: 160 | vol = scipy.ndimage.interpolation.zoom(vol, [0.5, 0.5, 1], order=1 if islabel else 3) 161 | for _ in range(rot): 162 | vol = np.rot90(vol) 163 | if not islabel: 164 | # vol = vol + np.random.randn() * 50 - 10 165 | wl, wh = (wwwl[1] - wwwl[0] / 2, wwwl[1] + wwwl[0] / 2) 166 | vol = vol.astype("float32").clip(wl, wh) 167 | vol = (vol - wl) / (wh - wl) * 256 168 | vol = vol.astype("uint8") 169 | if not os.path.exists(png_folder): 170 | os.makedirs(png_folder) 171 | 172 | for ind in range(1, vol.shape[2] - 1): 173 | if islabel: 174 | slice = vol[:, :, ind] 175 | else: 176 | slice = vol[:, :, ind - 1 : ind + 2] 177 | if islabel: 178 | sum = np.sum(slice) 179 | print(sum, thresh) 180 | if sum <= thresh: 181 | continue 182 | slice[slice == 2] = 1 183 | 184 | file_path = os.path.join(png_folder, "{}-{}.png".format(nii_name.rstrip(".gz").rstrip(".nii"), ind)) 185 | # if not islabel: 186 | # if "{}-{}.png".format(nii_name.rstrip(".gz").rstrip(".nii"), ind) not in labels: 187 | # print("{}-{}.png".format(nii_name.rstrip(".gz").rstrip(".nii"), ind)) 188 | # continue 189 | 190 | cv2.imwrite(file_path, slice) 191 | 192 | 193 | def nii2png_folder(nii_folder, png_folder, rot=1, wwwl=(400, 0), subfolder=False, islabel=False, thresh=0): 194 | """将一个文件夹里所有的nii转换成png. 195 | 196 | Parameters 197 | ---------- 198 | nii_folder : str 199 | 放所有nii的文件夹. 200 | png_folder : str 201 | 放所有png的文件夹. 202 | rot : int 203 | 旋转几次. 204 | wwwl : (int, int) 205 | 窗宽窗位. 206 | subfolder : bool 207 | 是否给每一个nii创建单独的文件夹. 208 | 209 | """ 210 | nii_names = os.listdir(nii_folder) 211 | for nii_name in tqdm(nii_names): 212 | if subfolder: 213 | png_folder = os.path.join(png_folder, nii_name) 214 | if len(nii_name) > 12: 215 | nii2png_single(os.path.join(nii_folder, nii_name), png_folder, 1, wwwl, islabel, thresh=thresh) 216 | else: 217 | nii2png_single(os.path.join(nii_folder, nii_name), png_folder, 3, wwwl, islabel, thresh=thresh) 218 | # print(len(nii_name)) 219 | # print(nii_name) 220 | # input("here") 221 | # os.system("rm /home/lin/Desktop/data/aorta/dataset/scan/*") 222 | 223 | 224 | def check_nii_match(scan_dir, label_dir, remove=False): 225 | # TODO: 用片间间隔和大小计算片内方向的边长,太大或者太小报错 226 | """检查两个目录下的扫描和标签是不是对的上. 227 | 228 | Parameters 229 | ---------- 230 | scan_dir : str 231 | 扫描所在路径. 232 | label_dir : str 233 | 标签所在路径. 234 | 235 | Returns 236 | ------- 237 | bool 238 | 比较的结果,对上了返回True,否则False,具体的细节直接打到stdio. 239 | 240 | """ 241 | pass_check = True 242 | scans = os.listdir(scan_dir) 243 | labels = os.listdir(label_dir) 244 | scans = [n for n in scans if n.endswith("nii") or n.endswith("gz")] 245 | labels = [n for n in labels if n.endswith("nii") or n.endswith("gz")] 246 | 247 | scan_names = [n.rstrip(".gz").rstrip(".nii") for n in scans] 248 | label_names = [n.rstrip(".gz").rstrip(".nii") for n in labels] 249 | 250 | if len(scans) != len(labels): 251 | logging.error("Number of scnas({}) and labels ({}) don't match".format(len(scans), len(labels))) 252 | pass_check = False 253 | else: 254 | logging.info("Pass file number check") 255 | 256 | names_match = True 257 | for ind, s in enumerate(scan_names): 258 | if s not in label_names: 259 | logging.error("Scan {} dont have corresponding label".format(s)) 260 | names_match = False 261 | print("removing {}".format(s)) 262 | if remove: 263 | os.remove(os.path.join(scan_dir, scans[ind])) 264 | 265 | for l in label_names: 266 | if l not in scan_names: 267 | logging.error("Label {} dont have corresponding scan".format(l)) 268 | names_match = False 269 | 270 | if names_match: 271 | logging.info("Pass file names check") 272 | else: 273 | pass_check = False 274 | 275 | scans = os.listdir(scan_dir) 276 | labels = os.listdir(label_dir) 277 | scans = [n for n in scans if n.endswith("nii") or n.endswith("gz")] 278 | labels = [n for n in labels if n.endswith("nii") or n.endswith("gz")] 279 | 280 | scan_names = [n.rstrip(".gz").rstrip(".nii") for n in scans] 281 | label_names = [n.rstrip(".gz").rstrip(".nii") for n in labels] 282 | 283 | for scan_name in scans: 284 | scanf = nib.load(os.path.join(scan_dir, scan_name)) 285 | labelf = nib.load(os.path.join(label_dir, scan_name)) 286 | if (scanf.affine == np.eye(4)).all(): 287 | logging.warn("Scan {} have np.eye(4) affine, check the header".format(scan_name)) 288 | if (labelf.affine == np.eye(4)).all(): 289 | logging.warn("Label {} have np.eye(4) affine, check the header".format(scan_name)) 290 | if not (labelf.header["dim"] == scanf.header["dim"]).all(): 291 | logging.error( 292 | "Label and scan dimension mismatch for {}, scan is {}, label is {}".format( 293 | scan_name, scanf.header["dim"][1:4], labelf.header["dim"][1:4] 294 | ) 295 | ) 296 | pass_check = False 297 | return pass_check 298 | 299 | 300 | def inspect_pair(scan_path, label_path): 301 | 302 | # 如果是nii格式 303 | # TODO: 完善 304 | if scan_path.endswith("nii") or scan_path.endswith("gz"): 305 | pass 306 | # TODO: 一对图片放到一个frame 307 | if scan_path.endswith("png"): 308 | scan = cv2.imread(scan_path) 309 | label = cv2.imread(label_path) 310 | plt.imshow(scan) 311 | plt.show() 312 | label = label * 255 313 | plt.imshow(label) 314 | plt.show() 315 | 316 | 317 | def is_right(a, b, c): 318 | a = np.array((b[0] - a[0], b[1] - a[1])) 319 | b = np.array((c[0] - b[0], c[1] - b[1])) 320 | res = np.cross(a, b) 321 | if res >= 0: 322 | return True 323 | return False 324 | 325 | 326 | class Polygon: 327 | points = [] 328 | center = [] 329 | height = 0 330 | base = [] 331 | epsilon = 1e-6 332 | 333 | def __init__(self, p): 334 | if len(p) == 0: 335 | raise RuntimeError("Nan points in polygon") 336 | self.points = p 337 | if len(self.points[0]) == 2: 338 | points = [] 339 | for p in self.points: 340 | points.append(list(p[0])) 341 | unique = [] 342 | for p in points: 343 | if p not in unique: 344 | unique.append(p) 345 | self.points = unique 346 | self.cal_rep() 347 | self.ang_sort() 348 | return 349 | 350 | for ind in range(len(self.points)): 351 | self.points[ind] = list(self.points[ind]) 352 | 353 | self.cal_rep() 354 | self.ang_sort() 355 | try: 356 | self.height = p[0][2] 357 | except: 358 | print(p) 359 | 360 | def cal_rep(self): 361 | """计算中心点和基点. 362 | 两个操作有顺序 363 | Returns 364 | ------- 365 | type 366 | Description of returned object. 367 | 368 | """ 369 | # print("___", np.min(self.points, axis=0)) 370 | # print("---", np.max(self.points, axis=0)) 371 | self.center = list((np.min(self.points, axis=0) + np.max(self.points, axis=0)) / 2) 372 | # print(self.center) 373 | # input("here") 374 | self.points.sort() 375 | self.base = self.points[0] 376 | del self.points[0] 377 | 378 | def ang_sort(self): 379 | """对所有的点进行极角排序. 380 | 381 | Returns 382 | ------- 383 | type 384 | Description of returned object. 385 | 386 | """ 387 | 388 | def cmp(a): 389 | return math.atan((a[1] - self.base[1]) / (a[0] - self.base[0] + self.epsilon)) 390 | 391 | self.points.sort(key=cmp, reverse=True) 392 | 393 | def cal_size(self): 394 | """给一个数组的点,求它构成的多边形的面积. 395 | 注意这里没有做极角排序,点本身需要满足顺时针或者逆时针顺序 396 | 397 | Parameters 398 | ---------- 399 | points : type 400 | Description of parameter `points`. 401 | 402 | Returns 403 | ------- 404 | type 405 | Description of returned object. 406 | 407 | """ 408 | points = self.points 409 | tot = 0 410 | p = self.base 411 | for ind in range(0, len(points) - 1): 412 | a = [t2 - t1 for t1, t2 in zip(p, points[ind])] 413 | b = [t2 - t1 for t1, t2 in zip(p, points[ind + 1])] 414 | a = np.array(a) 415 | b = np.array(b) 416 | # b = b.reshape(b.size, 1) 417 | # print(a, b) 418 | res = np.cross(a, b) 419 | tot += (res[0] ** 2 + res[1] ** 2 + res[2] ** 2) ** (1 / 2) / 2 420 | return tot 421 | 422 | def to_2d(self): 423 | def rot_to_horizontal(p): 424 | """将一个点旋转到水平面上. 425 | 426 | Parameters 427 | ---------- 428 | p : type 429 | Description of parameter `p`. 430 | 431 | Returns 432 | ------- 433 | type 434 | Description of returned object. 435 | 436 | """ 437 | # TODO: 在基本和y轴平行的时候会有div0错误 438 | epsilon = 1e-6 439 | if p[0] == 0 and p[1] == 0 and p[2] == 0: 440 | return [0, 0] 441 | x = p[0] 442 | y = p[1] 443 | z = p[2] 444 | angle = math.atan(z / ((x ** 2 + y ** 2) ** (1 / 2) + epsilon)) 445 | unit = ( 446 | y / ((x ** 2 + y ** 2) ** (1 / 2) + epsilon), 447 | -x / ((x ** 2 + y ** 2) ** (1 / 2) + epsilon), 448 | 0, 449 | ) 450 | matrix = trimesh.transformations.rotation_matrix(-angle, unit, (0, 0, 0)) 451 | p.append(1) 452 | p = np.array(p).reshape([1, 4]) 453 | p = p.transpose() 454 | res = np.dot(matrix, p) 455 | return [float(res[0]), float(res[1])] 456 | 457 | self.points = [[p[0] - self.base[0], p[1] - self.base[1], p[2] - self.base[2]] for p in self.points] 458 | # print("+_+", self.center) 459 | self.center = [b - a for a, b in zip(self.base, self.center)] 460 | self.center = rot_to_horizontal(self.center) 461 | # print("_+_", self.center) 462 | self.points = [rot_to_horizontal(p) for p in self.points] 463 | self.base = [0, 0] 464 | 465 | def cal_diameter(self, ang_range=[0, np.pi], split=30, pixdim=1, step=1): 466 | """计算这个多边形的直径 467 | 1. 将所有点旋转到一个平面内 468 | 2. 将半个圆周分成split份,在center做两条线,分别往向上和线下的方向运动 469 | 3. 这个线第一次让所有多边形端点都在线一侧停止运动 470 | 4. 计算两根线距离,作为直径 471 | 472 | Parameters 473 | ---------- 474 | ang_range : type 475 | Description of parameter `ang_range`. 476 | split : type 477 | Description of parameter `split`. 478 | pixdim : type 479 | Description of parameter `pixdim`. 480 | step : type 481 | Description of parameter `step`. 482 | 483 | Returns 484 | ------- 485 | type 486 | Description of returned object. 487 | 488 | """ 489 | self.to_2d() 490 | # self.plot_2d() 491 | center = self.center 492 | diameters = [self.height] 493 | for alpha in np.arange(ang_range[0], ang_range[1], (ang_range[1] - ang_range[0]) / split): 494 | # TODO: 如果这个线是垂直的 495 | if alpha == np.pi / 2: 496 | continue 497 | k = math.tan(alpha) 498 | # print(alpha) 499 | 500 | d = 0 501 | d1 = 0 502 | d2 = 0 503 | while True: 504 | # print("+:", d) 505 | y0 = center[1] - d + k * (0 - center[0]) 506 | y1 = center[1] - d + k * (1 - center[0]) 507 | dir = is_right((0, y0), (1, y1), self.points[0]) 508 | same_dir = True 509 | for p in self.points: 510 | tmp = is_right((0, y0), (1, y1), p) 511 | if tmp != dir: 512 | same_dir = False 513 | break 514 | if same_dir: 515 | d1 = d 516 | break 517 | d += step 518 | 519 | d = 0 520 | while True: 521 | # print("-:", d) 522 | y0 = center[1] - d + k * (0 - center[0]) 523 | y1 = center[1] - d + k * (1 - center[0]) 524 | dir = is_right((0, y0), (1, y1), self.points[0]) 525 | same_dir = True 526 | for p in self.points: 527 | tmp = is_right((0, y0), (1, y1), p) 528 | if tmp != dir: 529 | same_dir = False 530 | break 531 | if same_dir: 532 | d2 = d 533 | break 534 | d += step 535 | diameters.append((d1 + d2) * np.abs(np.cos(alpha)) * pixdim) 536 | 537 | return diameters 538 | 539 | def plot_2d(self): 540 | plt.scatter([p[0] for p in self.points], [p[1] for p in self.points]) 541 | # plt.plot([self.center[0]], [self.center[1]]) 542 | plt.show() 543 | 544 | 545 | """ 546 | y-y0=k(x-x0)+b 547 | x=0, y=y0+k(x-x0)+b 548 | """ 549 | # po = Polygon([[0, 0, 0], [0, 1, 0], [1, 1, 0], [1, 0, 0]]) 550 | # print(po.cal_diameter(split=2, step=0.1)) 551 | 552 | 553 | def dist(a, b): 554 | res = 0 555 | for p, q in zip(a, b): 556 | res += (p - q) ** 2 557 | res = res ** (1 / 2) 558 | return res 559 | 560 | 561 | # print(dist([1, 0, 1], [0, 0, 3])) 562 | 563 | 564 | def filter_polygon(points, center, thresh=1): 565 | """给一堆点,不一定是一个多边形,返回中心离center比较近的那个多边形. 566 | 567 | Parameters 568 | ---------- 569 | points : type 570 | Description of parameter `points`. 571 | 572 | Returns 573 | ------- 574 | type 575 | Description of returned object. 576 | 577 | """ 578 | # TODO: 点需要极角排序 579 | 580 | curr_point = points[-1] 581 | polygons = [[]] 582 | p_ind = 0 583 | while len(points) != 0: 584 | found = False 585 | for ind in range(len(points) - 1, -1, -1): 586 | if dist(curr_point, points[ind]) < thresh: 587 | curr_point = points[ind] 588 | del points[ind] 589 | polygons[p_ind].append(curr_point) 590 | found = True 591 | break 592 | if not found: 593 | polygons.append([]) 594 | p_ind += 1 595 | curr_point = points[-1] 596 | if center == "all": 597 | return polygons 598 | centers = [] 599 | for polygon in polygons: 600 | ps = np.array(polygon) 601 | centers.append(np.mean(ps, axis=0)) 602 | dists = [dist(p, center) for p in centers] 603 | min_ind = np.argmin(dists) 604 | print(min_ind) 605 | return list(polygons[min_ind]) 606 | 607 | 608 | # print(filter_polygon([[0, 0, 0], [1, 1, 1], [4, 4, 4], [5, 5, 5]], [0, 0, 0])) 609 | 610 | 611 | def sort_line(polygons): 612 | """给一个list的多边形,找到一个开头,之后从这个开头开始dfs的顺序给序列排序. 613 | 614 | Parameters 615 | ---------- 616 | points : type 617 | Description of parameter `points`. 618 | dist : type 619 | Description of parameter `dist`. 620 | 621 | Returns 622 | ------- 623 | type 624 | Description of returned object. 625 | 626 | """ 627 | # 从最低点开始 628 | polygons.sort(key=lambda a: [a.center[2], a.center[1], a.center[0]], reverse=True) 629 | curr_point = polygons[-1].center 630 | ordered = [] 631 | while len(polygons) != 0: 632 | # 找最近的点 633 | min_dist = dist(curr_point, polygons[-1].center) 634 | min_ind = len(polygons) - 1 635 | for ind in range(len([polygons])): 636 | curr_dist = dist(polygons[ind].center, curr_point) 637 | if curr_dist < min_dist: 638 | min_dist = curr_dist 639 | min_ind = ind 640 | curr_point = polygons[min_ind].center 641 | ordered.append(polygons[min_ind]) 642 | polygons.pop(min_ind) 643 | return ordered 644 | 645 | 646 | # print(sort_line([[0, 0, 0], [0, 0, 4], [0, 0, 3], [0, 0, 2], [0, 0, 6], [0, 1, 5.1], [0, 1, 6], [0, 0, 5]])) 647 | 648 | 649 | def to_pinyin(name, nonum=False): 650 | new_name = "" 651 | for ch in name: 652 | if u"\u4e00" <= ch <= u"\u9fff": 653 | new_name += pinyin(ch, style=Style.NORMAL)[0][0] 654 | else: 655 | # if nonum and ("0" <= ch <= "9" or ch == "_"): 656 | # continue 657 | new_name += ch 658 | return new_name 659 | --------------------------------------------------------------------------------