├── medseg
│   ├── __init__.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── model.py
│   │   ├── unet.py
│   │   ├── unet_plain.py
│   │   ├── res_unet.py
│   │   ├── hrnet.py
│   │   └── deeplabv3p.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── config.py
│   │   └── util.py
│   ├── prep_2d.py
│   ├── loss.py
│   ├── eval.py
│   ├── infer.py
│   ├── vis.py
│   ├── aug.py
│   ├── prep_3d.py
│   ├── prep_png.py
│   └── train.py
├── tool
│   ├── infer
│   │   ├── util.py
│   │   ├── connected.py
│   │   ├── merge.py
│   │   ├── vote.py
│   │   ├── aorta.py
│   │   ├── 3d_diameter.py
│   │   ├── slice2nii.py
│   │   ├── 2d_diameter.py
│   │   └── util.bk
│   ├── train
│   │   ├── util.py
│   │   ├── vis.py
│   │   ├── mhd2nii.py
│   │   ├── gen_list.py
│   │   ├── folder_split.py
│   │   ├── resize.py
│   │   ├── to_slice.py
│   │   ├── slice_mp.py
│   │   └── dataset_scan.py
│   ├── SimHei.ttf
│   ├── to_pinyin.py
│   ├── README.md
│   ├── flood_fill.py
│   └── zip_dataset.py
├── .gitattributes
├── requirements.txt
├── config
│   ├── lits-zprep.yaml
│   ├── hrnet_optic.yaml
│   ├── unet_optic.yaml
│   ├── deeplab_optic.yaml
│   ├── resunet_optic.yaml
│   ├── eval.yaml
│   ├── resunet_aorta.yaml
│   ├── test.yaml
│   └── resunet_lits.yaml
├── .gitignore
├── README.md
└── README_en.md
/medseg/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/medseg/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/medseg/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tool/infer/util.py:
--------------------------------------------------------------------------------
1 | ../util.py
--------------------------------------------------------------------------------
/tool/train/util.py:
--------------------------------------------------------------------------------
1 | ../util.py
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.ttf filter=lfs diff=lfs merge=lfs -text
2 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | nibabel
2 | scikit-image
3 | tqdm
4 | numpy
5 | visualdl
6 | medpy
7 | pypinyin
8 | trimesh
9 | SimpleITK
10 | paddleseg
11 |
--------------------------------------------------------------------------------
/config/lits-zprep.yaml:
--------------------------------------------------------------------------------
1 | PREP:
2 | PLANE: 'xz'
3 | FRONT: 1
4 | SIZE: [-1, 512,512]
5 | THRESH: 512
6 | INTERP: True
7 | INTERP_PIXDIM: [-1, -1, 1]
8 |
--------------------------------------------------------------------------------
/tool/SimHei.ttf:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:fd952619133560ec32c44d9b5f95a795d273db6aafa507edbcb2686c26f2c203
3 | size 10061591
4 |
--------------------------------------------------------------------------------
/config/hrnet_optic.yaml:
--------------------------------------------------------------------------------
1 | DATA:
2 | NAME: "optic"
3 | INPUTS_PATH: "/home/aistudio/optic/image"
4 | LABELS_PATH: "/home/aistudio/optic/label"
5 | PREP_PATH: "/home/aistudio/optic/prep"
6 | TRAIN:
7 | ARCHITECTURE: "hrnet"
8 | BATCH_SIZE: 8
9 | SNAPSHOT_BATCH: 30
10 | DISP_BATCH: 1
11 |
--------------------------------------------------------------------------------
/config/unet_optic.yaml:
--------------------------------------------------------------------------------
1 | DATA:
2 | NAME: "optic"
3 | INPUTS_PATH: "/home/aistudio/optic/image"
4 | LABELS_PATH: "/home/aistudio/optic/label"
5 | PREP_PATH: "/home/aistudio/optic/prep"
6 | TRAIN:
7 | ARCHITECTURE: "res_unet"
8 | BATCH_SIZE: 8
9 | SNAPSHOT_BATCH: 30
10 | DISP_BATCH: 1
11 |
--------------------------------------------------------------------------------
/config/deeplab_optic.yaml:
--------------------------------------------------------------------------------
1 | DATA:
2 | NAME: "optic"
3 | INPUTS_PATH: "/home/aistudio/optic/image"
4 | LABELS_PATH: "/home/aistudio/optic/label"
5 | PREP_PATH: "/home/aistudio/optic/prep"
6 | TRAIN:
7 | ARCHITECTURE: "deeplabv3p"
8 | BATCH_SIZE: 8
9 | SNAPSHOT_BATCH: 30
10 | DISP_BATCH: 1
11 |
--------------------------------------------------------------------------------
/config/resunet_optic.yaml:
--------------------------------------------------------------------------------
1 | DATA:
2 | NAME: "optic"
3 | INPUTS_PATH: "/home/aistudio/optic/image"
4 | LABELS_PATH: "/home/aistudio/optic/label"
5 | PREP_PATH: "/home/aistudio/optic/prep"
6 | TRAIN:
7 | ARCHITECTURE: "res_unet"
8 | BATCH_SIZE: 8
9 | SNAPSHOT_BATCH: 30
10 | DISP_BATCH: 1
11 |
--------------------------------------------------------------------------------
/config/eval.yaml:
--------------------------------------------------------------------------------
1 | EVAL:
2 | PATH:
3 | SEG: "/data/aorta/300/label/nii/aunet-best"
4 | GT: "/data/aorta/300/test/label"
5 | NAME: "aunet-best"
6 | METRICS: [
7 | "IOU",
8 | "Dice",
9 | "TP",
10 | "TN",
11 | "Precision",
12 | "Recall",
13 | "Sensitivity",
14 | "Specificity",
15 | "Accuracy",
16 |     "Assd",
17 |     "Ravd",
18 |   ]
19 |
--------------------------------------------------------------------------------
/config/resunet_aorta.yaml:
--------------------------------------------------------------------------------
1 | DATA:
2 | NAME: "lits"
3 | INPUTS_PATH: "/home/aistudio/data/volume"
4 | LABELS_PATH: "/home/aistudio/data/label"
5 | PREP_PATH: "/home/aistudio/data/preprocess"
6 | PREP:
7 | FRONT: 1
8 | WINDOW: False
9 | CROP: False
10 | THRESH: 128
11 | SIZE: (512, 512, -1)
12 | BATCH_SIZE: 16
13 | TRAIN:
14 | ARCHITECTURE: "res_unet"
15 | BATCH_SIZE: 30
16 | SNAPSHOT_BATCH: 300
17 | DISP_BATCH: 10
18 |
19 | AUG:
20 | ROTATE:
21 | RATIO: (0, 0.3, 0)
22 | RANGE: (0,(-10,10),0)
23 | WINDOWLIZE: True
24 | WWWC: (400, 0)
25 |
--------------------------------------------------------------------------------
/tool/infer/connected.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 |
4 | from tqdm import tqdm
5 | import nibabel as nib
6 | import numpy as np
7 |
8 | from util import filter_largest_volume
9 |
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument("-i", "--in_dir", type=str, required=True)
12 | parser.add_argument("-o", "--out_dir", type=str, required=True)
13 | args = parser.parse_args()
14 |
15 | names = os.listdir(args.in_dir)
16 |
17 | for name in tqdm(names):
18 | segf = nib.load(os.path.join(args.in_dir, name))
19 | header = segf.header
20 | data = segf.get_fdata()
21 | data = filter_largest_volume(data, mode="hard")
22 | newf = nib.Nifti1Image(data, segf.affine, header)
23 | nib.save(newf, os.path.join(args.out_dir, name))
24 |
--------------------------------------------------------------------------------
/tool/train/vis.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 |
4 | import cv2
5 | import matplotlib.pyplot as plt
6 |
7 |
8 | for k, v in os.environ.items():
9 | if k.startswith("QT_") and "cv2" in v:
10 | del os.environ[k]
11 |
12 |
13 | base_dir = "/home/lin/Desktop/data/lits/img"
14 | img_dir = osp.join(base_dir, "JPEGImages")
15 | lab_dir = osp.join(base_dir, "Annotations")
16 |
17 | for f in os.listdir(lab_dir)[:30]:
18 | img = osp.join(img_dir, f)
19 | lab = osp.join(lab_dir, f)
20 |
21 | img = cv2.imread(img)
22 | lab = cv2.imread(lab, cv2.IMREAD_UNCHANGED)
23 | print(lab.sum())
24 | lab = lab * 255
25 | # lab = lab.reshape([512, 512, 1])
26 | print(img.shape, lab.shape)
27 |
28 | fig = plt.figure(figsize=(10, 10))
29 | fig.add_subplot(1, 2, 1)
30 | plt.imshow(img)
31 | fig.add_subplot(1, 2, 2)
32 | plt.imshow(lab)
33 | plt.show()
34 |
--------------------------------------------------------------------------------
/tool/to_pinyin.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 |
4 | import argparse
5 | from tqdm import tqdm
6 |
7 | from pypinyin import pinyin, Style  # used by to_pinyin() below
8 |
9 |
10 | parser = argparse.ArgumentParser("zip_dataset")
11 | parser.add_argument("--in_dir", type=str)
12 | parser.add_argument("--out_dir", type=str)
13 | args = parser.parse_args()
14 |
15 |
16 | def to_pinyin(name, nonum=False):
17 | new_name = ""
18 | for ch in name:
19 | if u"\u4e00" <= ch <= u"\u9fff":
20 | new_name += pinyin(ch, style=Style.NORMAL)[0][0]
21 | else:
22 | # if nonum and ("0" <= ch <= "9" or ch == "_"):
23 | # continue
24 | new_name += ch
25 | return new_name
26 |
27 |
28 | def main():
29 | files = os.listdir(args.in_dir)
30 | for f in tqdm(files):
31 | shutil.copy(
32 | os.path.join(args.in_dir, f),
33 | os.path.join(args.out_dir, to_pinyin(f)),
34 | )
35 | # input("here")
36 |
37 |
38 | if __name__ == "__main__":
39 | main()
40 |
--------------------------------------------------------------------------------
/tool/train/mhd2nii.py:
--------------------------------------------------------------------------------
1 | # Convert mhd scans to nii format
2 | import os
3 | import argparse
4 |
5 | import SimpleITK as sitk
6 | import nibabel as nib
7 | import numpy as np
8 | from tqdm import tqdm
9 |
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument("--mhd_dir", type=str, required=True)
12 | parser.add_argument("--nii_dir", type=str, required=True)
13 | parser.add_argument("--rot", type=int, default=0)
14 | args = parser.parse_args()
15 |
16 | if not os.path.exists(args.nii_dir):
17 | os.makedirs(args.nii_dir)
18 |
19 | # TODO: add multiprocessing
20 | for fname in tqdm(os.listdir(args.mhd_dir)):
21 | if not fname.endswith(".mhd"):
22 | continue
23 | scan = sitk.GetArrayFromImage(sitk.ReadImage(os.path.join(args.mhd_dir, fname)))
24 | scan = scan.swapaxes(0, 1).swapaxes(1, 2)
25 | for _ in range(args.rot):
26 | scan = np.rot90(scan)
27 |
28 | # TODO: check whether the mhd/raw format carries more header information
29 | new_scan = nib.Nifti1Image(scan, np.eye(4))
30 | nib.save(new_scan, os.path.join(args.nii_dir, fname.replace("mhd", "nii.gz")))
31 |
--------------------------------------------------------------------------------
/tool/infer/merge.py:
--------------------------------------------------------------------------------
1 | # Merge liver and tumor segmentation labels into one label file
2 | import nibabel as nib
3 | import os
4 | from tqdm import tqdm
5 | import numpy as np
6 |
7 |
8 | liver_path = "/home/aistudio/data/liver"
9 | tumor_path = "/home/aistudio/data/tumor"
10 | merge_path = "/home/aistudio/data/merge"
11 |
12 |
13 | assert len(os.listdir(liver_path)) == len(os.listdir(tumor_path)), "Number of liver labels and tumor labels differ"
14 |
15 | for liver_fname in tqdm(os.listdir(liver_path)):
16 | liverf = nib.load(os.path.join(liver_path, liver_fname))
17 | tumorf = nib.load(os.path.join(tumor_path, liver_fname))
18 |
19 | liver = liverf.get_fdata()
20 | tumor = tumorf.get_fdata()
21 |
22 |     assert len(np.where(tumor == 1)[0]) == 0, "Tumor label contains foreground value 1; liver and tumor inputs may be swapped"
23 |     assert len(np.where(liver == 2)[0]) == 0, "Liver label contains foreground value 2; liver and tumor inputs may be swapped"
24 |
25 |     print("liver volume {}, tumor volume {}".format(liver.sum(), tumor.sum()))
26 |
27 | tumor = tumor / 2
28 | liver += tumor
29 |
30 | merge_file = nib.Nifti1Image(liver, liverf.affine)
31 | nib.save(merge_file, os.path.join(merge_path, liver_fname))
32 |
--------------------------------------------------------------------------------
/config/test.yaml:
--------------------------------------------------------------------------------
1 | DATA:
2 | NAME: "lits"
3 | INPUTS_PATH: "/home/aistudio/data/scan_temp"
4 | LABELS_PATH: "/home/aistudio/data/label_temp"
5 | PREP_PATH: "/home/aistudio/data/preprocess"
6 | PREP:
7 | FRONT: 1
8 | WINDOW: False
9 | CROP: False
10 | THRESH: 1024
11 | SIZE: (512, 512, -1)
12 | BATCH_SIZE: 16
13 | TRAIN:
14 | ARCHITECTURE: "unet_plain"
15 | BATCH_SIZE: 1
16 | SNAPSHOT_BATCH: 600
17 | DISP_BATCH: 1
18 | REG_TYPE: L1
19 | REG_COEFF: 1e-6
20 | AUG:
21 | ROTATE:
22 | RATIO: (0, 0.2, 0)
23 | RANGE: (0,(-15, -16),0)
24 | FLIP:
25 | RATIO: (0, 0, 0)
26 | WINDOWLIZE: True
27 | WWWC: (400, 0)
28 | ZOOM:
29 | RATIO: (0, 0.3, 0.3)
30 | RANGE: (0, (0.8, 1.0), (0.8, 1.0))
31 | INFER:
32 | PATH:
33 | INPUT: "/home/aistudio/data/inference"
34 | OUTPUT: "/home/aistudio/data/infer_lab"
35 | PARAM: "/home/aistudio/param/liver_inf"
36 | BATCH_SIZE: 196
37 | EVAL:
38 | PATH:
39 | SEG: "/home/aistudio/data/infer_lab"
40 | GT: "/home/aistudio/data/label"
41 | METRICS: [
42 | "IOU",
43 | "Dice",
44 | "TP",
45 | "TN",
46 | "Precision",
47 | "Recall",
48 | "Sensitivity",
49 | "Specificity",
50 | "Accuracy",
51 | ]
52 |
--------------------------------------------------------------------------------
/medseg/models/model.py:
--------------------------------------------------------------------------------
1 | from utils.config import cfg
2 | from models.res_unet import res_unet
3 | from models.unet_plain import unet_plain
4 | from models.hrnet import hrnet
5 | from models.deeplabv3p import deeplabv3p
6 |
7 | # models = {
8 | # "unet_simple": unet_simple(volume, 2, [512, 512]),
9 | # "res_unet": unet_base(volume, 2, [512, 512]),
10 | # "deeplabv3": deeplabv3p(volume, 2),
11 | # "hrnet": hrnet(volume, 2),
12 | # }
13 |
14 |
15 | def create_model(input, num_class=2):
16 | """构建训练模型.
17 |
18 | Parameters
19 | ----------
20 | input : paddle.data
21 | 输入的 placeholder.
22 | num_class : int
23 | 输出分类有几类.
24 |
25 | Returns
26 | -------
27 | type
28 | 构建好的模型.
29 |
30 | """
31 | if cfg.TRAIN.ARCHITECTURE == "unet_plain":
32 | return unet_plain(input, num_class, cfg.TRAIN.INPUT_SIZE)
33 | if cfg.TRAIN.ARCHITECTURE == "res_unet":
34 | return res_unet(input, num_class, cfg.TRAIN.INPUT_SIZE)
35 | if cfg.TRAIN.ARCHITECTURE == "hrnet":
36 | return hrnet(input, num_class)
37 | if cfg.TRAIN.ARCHITECTURE == "deeplabv3p":
38 | return deeplabv3p(input, num_class)
39 | raise Exception("错误的网络类型: {}".format(cfg.TRAIN.ARCHITECTURE))
40 |
--------------------------------------------------------------------------------
/config/resunet_lits.yaml:
--------------------------------------------------------------------------------
1 | DATA:
2 | NAME: "lits"
3 | INPUTS_PATH: "/home/aistudio/data/scan"
4 | LABELS_PATH: "/home/aistudio/data/label"
5 | PREP_PATH: "/home/aistudio/data/preprocess"
6 | PREP:
7 | FRONT: 1
8 | WINDOW: False
9 | CROP: False
10 | THRESH: 0
11 | SIZE: (512, 512, -1)
12 | BATCH_SIZE: 20
13 | THICKNESS: 5
14 |
15 | TRAIN:
16 | ARCHITECTURE: "res_unet"
17 | BATCH_SIZE: 32
18 | EPOCHS: 30
19 | SNAPSHOT_BATCH: 600
20 | DISP_BATCH: 20
21 | OPTIMIZER: "adam"
22 | REG_TYPE: L1
23 | REG_COEFF: 3e-6
24 | BOUNDARIES: [80000, 20000]
25 | LR: [0.002, 0.001, 0.0005]
26 |
27 | AUG:
28 | ROTATE:
29 | RATIO: (0, 0, 0)
30 | RANGE: (0,(-15, -16),0)
31 | FLIP:
32 | RATIO: (0, 0, 0)
33 | WINDOWLIZE: True
34 | WWWC: (400, 0)
35 | ZOOM:
36 | RATIO: (0, 0.3, 0.3)
37 | RANGE: (0, (0.8, 1.1), (0.8, 1.1))
38 |
39 | INFER:
40 | PATH:
41 | INPUT: "/home/aistudio/data/inference"
42 | OUTPUT: "/home/aistudio/data/infer_lab"
43 | PARAM: "/home/aistudio/liverSeg/model/lits/inf"
44 | BATCH_SIZE: 64
45 | THRESH: 0.5
46 | FILTER_LARGES: True
47 | WWWC: (400, 0)
48 |
49 | EVAL:
50 | PATH:
51 | SEG: "/home/aistudio/data/infer_lab"
52 | GT: "/home/aistudio/data/label"
53 | METRICS: [
54 | "IOU",
55 | "Dice",
56 | "TP",
57 | "TN",
58 | "Precision",
59 | "Recall",
60 | "Sensitivity",
61 | "Specificity",
62 | "Accuracy",
63 | ]
64 |
--------------------------------------------------------------------------------
/tool/README.md:
--------------------------------------------------------------------------------
1 | Useful scripts for pre- and post-processing data; what each one does is described in [../README.md](../README.md).
2 |
3 | # Data format conversion
4 | From a coding point of view, nii is generally the most convenient format to work with. dcm carries more information, but one folder is not guaranteed to hold exactly one series, and storing one file per slice makes reading and writing many files I/O-inefficient. This project mostly uses nii; some conversion recipes are recorded here.
5 |
6 | [dcm2niix](https://github.com/rordenlab/dcm2niix) can convert dcm to nii; below is a command line example.
7 | ```shell
8 | dcm2niix -f "output file name, fields from the scan can be filled in" -d 9 -c "comment written into the output nii" dcm_folder
9 | ```
10 | [mhd2nii.py](./train/mhd2nii.py) in the train directory converts all mhd scans under a directory to nii, for example:
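Run from the repository root; the directories below are placeholders, and `--rot` (optional, default 0) applies that many 90-degree rotations to each scan:
```shell
python tool/train/mhd2nii.py --mhd_dir /path/to/mhd --nii_dir /path/to/nii --rot 1
```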
11 |
12 | # Data annotation
13 | ITK-SNAP is very convenient for annotating data; using its command line arguments to open files automatically saves time.
14 | ```shell
15 | count=0 ;
16 | tot=`ls -l | wc -l`
17 | for f in `ls`;
18 | do count=`expr $count + 1`;
19 | echo $count / $tot;
20 | echo $f;
21 | echo -e "\n";
22 | itksnap -s /path/to/label/${f} -g /path/to/scan/${f} --geometry 1920x1080+0+0;
23 | done
24 | ```
25 |
26 | The beep function uses the speaker-test program to play a beep through the speaker; combined with the timers below it helps keep track of annotation time.
27 | ```shell
28 | beep1()
29 | {
30 | ( \speaker-test --frequency $1 --test sine )&
31 | pid=$!
32 | \sleep ${2}s
33 | \kill -9 $pid
34 | }
35 |
36 | beep()
37 | {
38 | beep1 350 0.2
39 | beep1 350 0.2
40 | beep1 350 0.2
41 | beep1 350 0.4
42 | sleep 1
43 | }
44 |
45 | count=0 ;
46 | for f in `ls`;
47 | do
48 |     # four timed beeps
49 | (sleep 2m; beep;) &
50 | pid2=$!
51 | (sleep 4m; beep; beep;) &
52 | pid4=$!
53 | (sleep 6m; beep; beep; beep;) &
54 | pid6=$!
55 | (sleep 8m; beep; beep; beep; beep; ) &
56 | pid8=$!
57 |
58 |     # count
59 | count=`expr $count + 1`;
60 | echo $count / `ls -l | wc -l`;
61 | echo ${f};
62 | echo -e "\n";
63 |
64 |     # open the scan and its label
65 | itksnap -s ./${f} -g ../nii/${f} --geometry 1920x1080+0+0;
66 |
67 |     # archive the file
68 |     # cp ./${f} ../manual-label/
69 |     mv ./${f} ../manual-finished/
70 |
71 |     # kill the beeps that have not fired yet
72 | kill -9 $pid2
73 | kill -9 $pid4
74 | kill -9 $pid6
75 | kill -9 $pid8
76 | done
77 | ```
78 |
--------------------------------------------------------------------------------
/tool/infer/vote.py:
--------------------------------------------------------------------------------
1 | import nibabel as nib
2 | import numpy as np
3 | from multiprocessing import Pool
4 | from multiprocessing import cpu_count
5 | import os
6 |
7 |
8 | def listdir(path):
9 | dirs = os.listdir(path)
10 | if ".DS_Store" in dirs:
11 | dirs.remove(".DS_Store")
12 | if "checkpoint" in dirs:
13 | dirs.remove("checkpoint")
14 |
15 |     dirs.sort()  # identical sorting keeps volumes and segmentations aligned
16 | return dirs
17 |
18 |
19 | voters_base = "/home/aistudio/data/voting"
20 | voter_paths = ["voter1", "voter2", "voter3"]
21 | voter_paths = [os.path.join(voters_base, path) for path in voter_paths]
22 |
23 |
24 | def voting(fname):
25 | voter0 = nib.load(os.path.join(voter_paths[0], fname))
26 | merged = voter0.get_fdata()
27 | print("voting {}".format(fname))
28 | # print(merged.sum())
29 | for ind in range(1, len(voter_paths)):
30 | voterf = nib.load(os.path.join(voter_paths[ind], fname))
31 | merged += voterf.get_fdata()
32 | merged[merged > len(voter_paths) / 2] = 1
33 | merged[merged != 1] = 0
34 |
35 |
36 | def main():
37 | voter_names = [listdir(path) for path in voter_paths]
38 | print(voter_names)
39 | for data_ind in range(len(voter_names[0])):
40 | for voter_ind in range(1, len(voter_names)):
41 | # print(data_ind, voter_ind)
42 |             assert voter_names[voter_ind][data_ind] == voter_names[0][data_ind], "Group {}: name {} does not match {}".format(
43 | data_ind, voter_names[voter_ind][data_ind], voter_names[0][data_ind]
44 | )
45 | # for fname in voter_names[0]:
46 | # voting(fname)
47 | with Pool(cpu_count()) as p:
48 | p.map(voting, voter_names[0])
49 |
50 |
51 | if __name__ == "__main__":
52 | main()
53 |
--------------------------------------------------------------------------------
/tool/train/gen_list.py:
--------------------------------------------------------------------------------
1 | # Generate file lists, sorted by file name
2 | import os
3 | import os.path as osp
4 | import argparse
5 | import logging
6 |
7 | from tqdm import tqdm
8 | import numpy as np
9 |
10 | import util
11 |
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument("--base_dir", type=str, help="数据集路径", default=None)
14 | parser.add_argument(
15 | "--img_fdr",
16 | type=str,
17 | help="扫描文件夹",
18 | default="JPEGImages",
19 | )
20 | parser.add_argument(
21 | "--lab_fdr",
22 | type=str,
23 | help="标签文件路径",
24 | default="Annotations",
25 | )
26 | parser.add_argument("-d", "--delimiter", type=str, help="分隔符", default=" ")
27 | parser.add_argument(
28 | "-s",
29 | "--split",
30 | nargs=3,
31 | help="训练/验证/测试划分比例,比如 7 2 1",
32 | default=["7", "2", "1"],
33 | )
34 | args = parser.parse_args()
35 |
36 | imgs = util.listdir(osp.join(args.base_dir, args.img_fdr))
37 | labs = util.listdir(osp.join(args.base_dir, args.lab_fdr))
38 | assert len(imgs) == len(
39 | labs
40 | ), f"Scan slice number ({len(imgs)}) isn't equal to label slice number({len(labs)})"
41 |
42 | names = [[i, l] for i, l in zip(imgs, labs)]
43 | file_names = ["train_list.txt", "eval_list.txt", "test_list.txt"]
44 | split = util.toint(args.split)
45 | tot = np.sum(split)
46 | split = [int(s / tot * len(names)) for s in split]
47 | print(f"Train/Eval/Test split is {split}")
48 | split[1] += split[0]
49 | split[2] = len(names)  # the last boundary must cover all samples
50 | part = 0
51 | f = open(osp.join(args.base_dir, file_names[part]), "w")
52 | for idx, (img, lab) in enumerate(names):
53 | if idx == split[part] and idx != len(names) - 1:
54 | f.close()
55 | part += 1
56 | f = open(osp.join(args.base_dir, file_names[part]), "w")
57 |
58 | print(
59 | "{:s}{:s}{:s}".format(
60 | osp.join(args.img_fdr, img),
61 | args.delimiter,
62 | osp.join(args.lab_fdr, lab),
63 | ),
64 | file=f,
65 | )
66 |
--------------------------------------------------------------------------------
/tool/infer/aorta.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 | import math
5 | from multiprocessing import Pool
6 | import time
7 | import random
8 |
9 | import numpy as np
10 |
11 |
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument("--in_dir", type=str, default="./train-diameter")
14 | args = parser.parse_args()
15 |
16 |
17 | def score(data, split=0.2):
18 | if len(data) < 10:
19 | return -1, -1, -1, -1, -1
20 |
21 | mins = []
22 | # print("\n\n\n\n\n", data)
23 | for idx in range(len(data)):
24 | if len(data[idx]) == 1:
25 | continue
26 | mins.append(np.min(data[idx][1:]))
27 | # print(mins)
28 | abdomin = np.max(mins[: int(len(mins) * split)])
29 | chest = np.max(mins[int(len(mins) * split) :])
30 | if abdomin > 30:
31 | abdomin += 20
32 | res = max(abdomin, chest)
33 | if 40 < res < 50:
34 | cat1 = 1
35 | elif res > 50:
36 | cat1 = 2
37 | else:
38 | cat1 = 0
39 | if res > 39.5:
40 | cat2 = 1
41 | else:
42 | cat2 = 0
43 | if res > 50:
44 | cat3 = 1
45 | else:
46 | cat3 = 0
47 | # print(abdomin, chest, cat1, cat2, cat3)
48 | return abdomin, chest, cat1, cat2, cat3
49 |
50 |
51 | if __name__ == "__main__":
52 | names = os.listdir(args.in_dir)
53 | for name in names:
54 | path = os.path.join(args.in_dir, name)
55 | with open(path, "r") as f:
56 | data = f.readlines()
57 | data = [d.split(",") for d in data[1:]]
58 |         for i in range(len(data)):
59 |             # keep only the fields that parse as numbers
60 |             row = []
61 |             for v in data[i]:
62 |                 try:
63 |                     row.append(float(v))
64 |                 except ValueError:
65 |                     pass
66 |             data[i] = row
67 |         print(name.split("_")[0], end="\t")
68 |         for d in score(data):
69 |             print(d, end="\t")
70 |         print()
68 |
--------------------------------------------------------------------------------
/medseg/prep_2d.py:
--------------------------------------------------------------------------------
1 | # Group images into batches and save them as npz, the same format the 3D preprocessing produces
2 | import os
3 |
4 | import cv2
5 | import argparse
6 | import numpy as np
7 |
8 | import utils.util as util
9 | from utils.config import cfg
10 | import aug
11 |
12 |
13 | def parse_args():
14 | parser = argparse.ArgumentParser(description="数据预处理")
15 | parser.add_argument("-c", "--cfg_file", type=str, help="配置文件路径")
16 | parser.add_argument("opts", nargs=argparse.REMAINDER)
17 | args = parser.parse_args()
18 |
19 | if args.cfg_file is not None:
20 | cfg.update_from_file(args.cfg_file)
21 | if args.opts:
22 | cfg.update_from_list(args.opts)
23 |
24 |
25 | def main():
26 | images = util.listdir(cfg.DATA.INPUTS_PATH)
27 | labels = util.listdir(cfg.DATA.LABELS_PATH)
28 | print(images)
29 |
30 | if not os.path.exists(cfg.DATA.PREP_PATH):
31 | os.makedirs(cfg.DATA.PREP_PATH)
32 | npz_count = 0
33 | img_npz = []
34 | lab_npz = []
35 | for ind in range(len(images)):
36 | img = cv2.imread(os.path.join(cfg.DATA.INPUTS_PATH, images[ind]))
37 | lab = cv2.imread(os.path.join(cfg.DATA.LABELS_PATH, labels[ind]))
38 | img = img.swapaxes(0, 2)
39 | lab = lab.swapaxes(0, 2)
40 |
41 | lab = lab[0] / 255
42 | lab = lab[np.newaxis, :, :]
43 |
44 | img, lab = aug.crop(img, lab, size=[3, 512, 512])
45 |
46 | img_npz.append(img)
47 | lab_npz.append(lab)
48 |
49 | print(img.shape, lab.shape)
50 |
51 | if len(img_npz) == cfg.PREP.BATCH_SIZE or ind == len(images) - 1:
52 | imgs = np.array(img_npz)
53 | labs = np.array(lab_npz)
54 | file_name = "{}-{}".format(cfg.DATA.NAME, npz_count)
55 | file_path = os.path.join(cfg.DATA.PREP_PATH, file_name)
56 | np.savez(file_path, imgs=imgs, labs=labs)
57 | img_npz = []
58 | lab_npz = []
59 | npz_count += 1
60 |
61 |
62 | if __name__ == "__main__":
63 | parse_args()
64 | main()
65 |
--------------------------------------------------------------------------------
/tool/flood_fill.py:
--------------------------------------------------------------------------------
1 | """
2 | 对分割标签进行漫水填充
3 | """
4 | import numpy as np
5 | import matplotlib.pyplot as plt
6 | from skimage.morphology import flood_fill
7 | import numpy as np
8 | import nibabel as nib
9 | import queue
10 | import os
11 |
12 | # TODO: -60以下前景的要去掉
13 |
14 | lab_path = "/home/lin/Desktop/data/ann/flood/"
15 | fill_path = "/home/lin/Desktop/data/ann/flooded/"
16 | processed = 0
17 | lab_names = os.listdir(lab_path)
18 | for lab_name in lab_names:
19 | print("--------")
20 | print(lab_name)
21 |
22 | if os.path.exists(os.path.join(fill_path, lab_name)):
23 | print("[skp]跳过已经手工refine过的标签")
24 | continue
25 |
26 | labf = nib.load(os.path.join(lab_path, lab_name))
27 | lab = labf.get_fdata()
28 |
29 | # print(np.where(lab == 2))
30 | if len(np.where(lab == 2)[0]) == 0:
31 | print("[skp]跳过没有lab2的标签")
32 | continue
33 |
34 | print(lab.shape)
35 | processed += 1
36 |
37 | # 第 2 个维度是层数,从0开始,和itk里面下表是 i 对应 i + 1
38 | # 片内是 itk 显示的片子顺时针转90度
39 | # 第 0 个维度在数组中是上到下,在itk中是左到右
40 | # 第 1 个维度在数组中是左到右,在itk中是上到下
41 |
42 | for sli_ind in range(lab.shape[2]):
43 |
44 | seeds = np.where(lab[:, :, sli_ind] == 2)
45 | # print(seeds)
46 |
47 | # plt.imshow(lab[:, :, sli_ind])
48 | # plt.show()
49 |
50 | for i in range(len(seeds[0])):
51 | lab[seeds[0][i], seeds[1][i], sli_ind] = 0
52 | slice = lab[:, :, sli_ind]
53 | flood_fill(
54 | slice,
55 | (seeds[1][i], seeds[0][i]),
56 | 1,
57 | selem=[[0, 1, 0], [1, 0, 1], [0, 1, 0]],
58 | inplace=True,
59 | )
60 |
61 | lab[seeds[0][i], seeds[1][i], sli_ind] = 1
62 |
63 | # plt.imshow(lab[:, :, sli_ind])
64 | # plt.show()
65 |
66 | filled = nib.Nifti1Image(lab, np.eye(4))
67 | nib.save(filled, os.path.join(fill_path, lab_name))
68 | print("\t处理完成")
69 |
70 | print("共: ", len(lab_names), "处理了: ", processed)
71 |
--------------------------------------------------------------------------------
/tool/train/folder_split.py:
--------------------------------------------------------------------------------
1 | # Images and labels each start out in their own directory; randomly split them and arrange them in the pdseg directory layout
2 | import os
3 | import argparse
4 | import random
5 |
6 | import util
7 |
8 |
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument("--dst_dir", type=str, required=True)
11 | parser.add_argument("--img_folder", type=str, required=True)
12 | parser.add_argument("--lab_folder", type=str, required=True)
13 | args = parser.parse_args()
14 |
15 |
16 | def mv(curr, dest):
17 |     # make sure the destination directory exists before moving
18 |     if not os.path.exists(os.path.dirname(dest)):
19 |         os.makedirs(os.path.dirname(dest))
20 |     os.rename(curr, dest)
21 |
22 |
23 | def split_dataset(
24 | img_folder,
25 | lab_folder,
26 | dst_dir,
27 | split=[8, 2, 0],
28 | folders=["imgs", "annotations"],
29 | sub_folders=["train", "val", "test"],
30 | ):
31 | # 1. 创建目标目录
32 | for fd1 in folders:
33 | for fd2 in sub_folders:
34 | dir = os.path.join(dst_dir, fd1, fd2)
35 | if not os.path.exists(dir):
36 | os.makedirs(dir)
37 |
38 | # 2. 获取图像和标签文件名,打乱
39 | img_names = util.listdir(img_folder)
40 | lab_names = util.listdir(lab_folder)
41 | names = [[i, l] for i, l in zip(img_names, lab_names)]
42 | random.shuffle(names)
43 |
44 | for idx in range(10):
45 | print(names[idx])
46 |
47 | # 3. 计算划分点
48 | split.insert(0, 0)
49 | for ind in range(1, len(split)):
50 | split[ind] += split[ind - 1]
51 |
52 | split = [x / split[-1] for x in split]
53 | split = [int(len(img_names) * split[ind]) for ind in range(4)]
54 | print(split)
55 |
56 | # 4. 进行移动
57 | for part in range(3):
58 |         print(f"Processing {sub_folders[part]}")
59 | for idx in range(split[part], split[part + 1]):
60 | img, lab = names[idx]
61 | mv(
62 | os.path.join(img_folder, img),
63 | os.path.join(dst_dir, folders[0], sub_folders[part], img),
64 | )
65 | mv(
66 | os.path.join(lab_folder, lab),
67 | os.path.join(dst_dir, folders[1], sub_folders[part], lab),
68 | )
69 |
70 |
71 | if __name__ == "__main__":
72 | split_dataset(
73 | img_folder=args.img_folder,
74 | lab_folder=args.lab_folder,
75 | dst_dir=args.dst_dir,
76 | )
77 |
--------------------------------------------------------------------------------
/medseg/loss.py:
--------------------------------------------------------------------------------
1 | import paddle
2 | import paddle.fluid as fluid
3 | from paddle.fluid.layers import log
4 |
5 |
6 | def mean_iou(pred, label, num_classes=2):
7 | """
8 |     Compute mean IoU.
9 | """
10 | pred = fluid.layers.argmax(pred, axis=1)
11 | pred = fluid.layers.cast(pred, "int32")
12 | label = fluid.layers.cast(label, "int32")
13 | miou, wrong, correct = fluid.layers.mean_iou(pred, label, num_classes)
14 | return miou
15 |
16 |
17 | def weighed_binary_cross_entropy(y, y_predict, beta=2, epsilon=1e-6):
18 | """
19 |     Return the weighted cross entropy (WCE) loss.
20 |     beta is the weight given to the positive class; when positives are rare, a beta greater than 1 makes them count more than the background class.
21 | """
22 | y = fluid.layers.clip(y, epsilon, 1 - epsilon)
23 | y_predict = fluid.layers.clip(y_predict, epsilon, 1 - epsilon)
24 |
25 | ylogp = fluid.layers.elementwise_mul(y, log(y_predict))
26 | betas = fluid.layers.fill_constant(ylogp.shape, "float32", beta)
27 | ylogp = fluid.layers.elementwise_mul(betas, ylogp)
28 |
29 | ones = fluid.layers.fill_constant(y_predict.shape, "float32", 1)
30 | ylogp = fluid.layers.elementwise_add(
31 |         ylogp, fluid.layers.elementwise_mul(fluid.layers.elementwise_sub(ones, y), log(fluid.layers.elementwise_sub(ones, y_predict)))
32 | )
33 |
34 | zeros = fluid.layers.fill_constant(y_predict.shape, "float32", 0)
35 | return fluid.layers.elementwise_sub(zeros, ylogp)
36 |
37 |
38 | def focal_loss(y_predict, y, alpha=0.85, gamma=2, epsilon=1e-6):
39 | """
40 |     A larger alpha penalizes foreground-class errors more, putting more emphasis on the foreground.
41 |     A larger gamma down-weights confident examples and focuses learning on hard ones.
42 | """
43 | y = fluid.layers.clip(y, epsilon, 1 - epsilon)
44 | y_predict = fluid.layers.clip(y_predict, epsilon, 1 - epsilon)
45 |
46 | return -1 * (
47 | alpha * fluid.layers.pow((1 - y_predict), gamma) * y * log(y_predict)
48 | + (1 - alpha) * fluid.layers.pow(y_predict, gamma) * (1 - y) * log(1 - y_predict)
49 | )
50 |
51 |
52 | def create_loss(predict, label, num_classes=2):
53 | predict = fluid.layers.transpose(predict, perm=[0, 2, 3, 1])
54 | predict = fluid.layers.reshape(predict, shape=[-1, num_classes])
55 | predict = fluid.layers.softmax(predict)
56 | label = fluid.layers.reshape(label, shape=[-1, 1])
57 | label = fluid.layers.cast(label, "int64")
58 | dice_loss = fluid.layers.dice_loss(predict, label)
59 |
60 | # label = fluid.layers.cast(label, "int64")
61 |
62 | ce_loss = fluid.layers.cross_entropy(predict, label)
63 | # focal = focal_loss(predict, label)
64 |
65 | return fluid.layers.reduce_mean(ce_loss + dice_loss)
66 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # project cache
2 | lits.csv
3 | test.py
4 | *.csv
5 | summary.txt
6 |
7 | # BOS
8 | bos_conf.py
9 |
10 | # Byte-compiled / optimized / DLL files
11 | __pycache__/
12 | *.py[cod]
13 | *$py.class
14 |
15 | # C extensions
16 | *.so
17 |
18 | # Distribution / packaging
19 | .Python
20 | build/
21 | develop-eggs/
22 | dist/
23 | downloads/
24 | eggs/
25 | .eggs/
26 | lib64/
27 | parts/
28 | sdist/
29 | var/
30 | wheels/
31 | pip-wheel-metadata/
32 | share/python-wheels/
33 | *.egg-info/
34 | .installed.cfg
35 | *.egg
36 | MANIFEST
37 |
38 | # PyInstaller
39 | # Usually these files are written by a python script from a template
40 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
41 | *.manifest
42 | *.spec
43 |
44 | # Installer logs
45 | pip-log.txt
46 | pip-delete-this-directory.txt
47 |
48 | # Unit test / coverage reports
49 | htmlcov/
50 | .tox/
51 | .nox/
52 | .coverage
53 | .coverage.*
54 | .cache
55 | nosetests.xml
56 | coverage.xml
57 | *.cover
58 | *.py,cover
59 | .hypothesis/
60 | .pytest_cache/
61 |
62 | # Translations
63 | *.mo
64 | *.pot
65 |
66 | # Django stuff:
67 | *.log
68 | local_settings.py
69 | db.sqlite3
70 | db.sqlite3-journal
71 |
72 | # Flask stuff:
73 | instance/
74 | .webassets-cache
75 |
76 | # Scrapy stuff:
77 | .scrapy
78 |
79 | # Sphinx documentation
80 | docs/_build/
81 |
82 | # PyBuilder
83 | target/
84 |
85 | # Jupyter Notebook
86 | .ipynb_checkpoints
87 |
88 | # IPython
89 | profile_default/
90 | ipython_config.py
91 |
92 | # pyenv
93 | .python-version
94 |
95 | # pipenv
96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
99 | # install all needed dependencies.
100 | #Pipfile.lock
101 |
102 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
103 | __pypackages__/
104 |
105 | # Celery stuff
106 | celerybeat-schedule
107 | celerybeat.pid
108 |
109 | # SageMath parsed files
110 | *.sage.py
111 |
112 | # Environments
113 | .env
114 | .venv
115 | env/
116 | venv/
117 | ENV/
118 | env.bak/
119 | venv.bak/
120 |
121 | # Spyder project settings
122 | .spyderproject
123 | .spyproject
124 |
125 | # Rope project settings
126 | .ropeproject
127 |
128 | # mkdocs documentation
129 | /site
130 |
131 | # mypy
132 | .mypy_cache/
133 | .dmypy.json
134 | dmypy.json
135 |
136 | # Pyre type checker
137 | .pyre/
138 |
--------------------------------------------------------------------------------
/tool/train/resize.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import argparse
4 | from multiprocessing import Pool
5 |
6 | import nibabel as nib
7 | import numpy as np
8 | import scipy.ndimage
9 |
10 | import util
11 |
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument("--scan_dir", type=str, required=True, help="扫描或标签路径")
14 | parser.add_argument("--out_dir", type=str, required=True, help="resize后输出路径")
15 | parser.add_argument("--size", nargs=2, default=[512, 512], help="resize目标大小")
16 | parser.add_argument("-t", "--thickness", type=float, default=None, help="统一的层间距")
17 | parser.add_argument(
18 | "-l",
19 | "--is_label",
20 | default=False,
21 | action="store_true",
22 | help="是否是标签,如果是标签会使用零阶插值",
23 | )
24 | args = parser.parse_args()
25 |
26 | # TODO: 区分标签和扫描
27 | def resize(path):
28 | name = osp.basename(path)
29 | scanf = nib.load(path)
30 | header = scanf.header.copy()
31 | scan_data = scanf.get_fdata()
32 | old_pixdim = scanf.header.copy()["pixdim"]
33 | old_shape = scan_data.shape
34 | if scan_data.shape[:2] != args.size:
35 | scale = [t / c for c, t in zip(scan_data.shape[:2], args.size)]
36 | s = scale
37 | header["pixdim"][1] /= s[0]
38 | header["pixdim"][2] /= s[1]
39 | else:
40 | scale = [1, 1]
41 |
42 | if args.thickness and header["pixdim"][3] != args.thickness:
43 | scale.append(header["pixdim"][3] / args.thickness)
44 | header["pixdim"][3] = args.thickness
45 | else:
46 | scale.append(1)
47 |
48 | if scale != [1, 1, 1]:
49 | s = scale
50 | scan_data = scipy.ndimage.interpolation.zoom(
51 | scan_data, (s[0], s[1], s[2]), order=0 if args.is_label else 3
52 | )
53 | # if args.is_label:
54 | # scan_data = scan_data.astype("uint8")
55 |
56 | newf = nib.Nifti1Image(scan_data.astype(np.float32), scanf.affine, header)
57 | nib.save(newf, osp.join(args.out_dir, name))
58 | print(
59 | name,
60 | ":",
61 | old_pixdim[1:4],
62 | old_shape,
63 | scale,
64 | header["pixdim"][1:4],
65 | scan_data.shape,
66 | )
67 |
68 |
69 | if __name__ == "__main__":
70 | args.size = util.toint(args.size)
71 | names = os.listdir(args.scan_dir)
72 | names = [
73 | osp.join(args.scan_dir, n)
74 | for n in names
75 | if n.endswith("nii") or n.endswith("nii.gz")
76 | ]
77 | if not osp.exists(args.out_dir):
78 | os.makedirs(args.out_dir)
79 |
80 | p = Pool(8)
81 | p.map(resize, names)
82 |
83 | # for name in names:
84 | # to_512(name)
85 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # medSeg
2 | 中文 | [English](./README_en.md)
3 |
4 | A development kit for medical image segmentation modeled after Baidu's [PaddleSeg](https://github.com/paddlepaddle/paddleseg). The main goal is to provide multiple 2D and 3D networks (3D networks are still under development), multiple losses and data augmentation strategies. The project is still a work in progress, but it already reaches an IOU of .94 on liver segmentation. The 2.5D P-Unet project is built on top of this kit. See the [Project board](https://github.com/davidlinhl/medSeg/projects/1) for the development plan.
5 |
6 | ## Project structure
7 | #### medseg: the main package
8 |
9 | - prep_3d.py / prep_2d.py: preprocess 3D scans or 2D slices
10 | - loss.py: loss definitions
11 | - models: model definitions
12 | - aug.py: data augmentation methods
13 | - train.py: train the network
14 | - infer.py: run inference with a trained model
15 | - vis.py: visualize data
16 | - eval.py: evaluate segmentation results, supporting essentially all 2D/3D metrics commonly used in medical imaging
17 |
18 | #### tool: utility scripts
19 | tool provides a set of handy scripts: the ones under [train](./tool/train) are mainly for preprocessing before training, and the ones under [infer](./tool/infer) are mainly for inference and post-processing.
20 |
21 | - train
22 |   - mhd2nii.py: convert mhd files to nii
23 |   - resize.py: resize nii scans or labels to 512
24 |   - dataset_scan.py: generate a dataset overview, including the intensity distribution and the mean/median needed for normalization
25 |   - to_slice.py: turn 3D scans and labels into 2D slices; the multiprocessing version is roughly twice as fast in practice (see the example below)
26 |   - vis.py: randomly pick slices, or slices from a 3D series, for visualization
27 |   - gen_list.py: generate file lists, split into train/val/test by ratio
28 |   - folder_split.py: randomly split a whole dataset into train, val and test sets
29 | - infer
30 |   - 2d_diameter.py: measure vessel diameters in segmentation labels, slice by slice
31 | - zip_dataset.py: pack the files under a path into archives, each below a given size; unlike a split zip, every archive is a standalone zip that can be unpacked on its own
32 | - flood_fill.py: flood fill segmentation labels
33 | - to_pinyin.py: convert Chinese file names to pinyin
34 |
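As an example of the slicing step referenced above, the call below is only a sketch: every path is a placeholder, and the window/foreground values (400/0, front 1) follow the LiTS configs in this repo.
```shell
python tool/train/to_slice.py \
    --scan_dir /path/to/volume \
    --label_dir /path/to/label \
    --out_dir /path/to/dataset \
    --thick 3 --wwwc 400 0 --front 1 --ext png
```
slice_mp.py takes the same arguments plus -p/--process to set the number of worker processes.
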
35 | #### config: configuration files
36 | All options are documented in [config.py](https://github.com/davidlinhl/medSeg/blob/master/medseg/utils/config.py)
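As a rough sketch, a config pairs a DATA section with a TRAIN section; the snippet below just mirrors config/resunet_optic.yaml and is not exhaustive (other configs add PREP, AUG, INFER and EVAL sections):
```yaml
DATA:
  NAME: "optic"
  INPUTS_PATH: "/home/aistudio/optic/image"
  LABELS_PATH: "/home/aistudio/optic/label"
  PREP_PATH: "/home/aistudio/optic/prep"
TRAIN:
  ARCHITECTURE: "res_unet"  # or unet_plain / hrnet / deeplabv3p
  BATCH_SIZE: 8
  SNAPSHOT_BATCH: 30
  DISP_BATCH: 1
```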
37 |
38 | ## Usage
39 | ### Environment setup
40 | Install the dependencies:
41 | ```shell
42 | pip install -r requirements.txt
43 | ```
44 | For installing the paddle framework, see the [paddle website](https://www.paddlepaddle.org.cn/)
45 | Training requires data; the project is currently developed mainly against LiTS, which can be found on AI Studio: [training set](https://aistudio.baidu.com/aistudio/datasetDetail/10273) [test set](https://aistudio.baidu.com/aistudio/datasetDetail/10292)
46 |
47 | After downloading and unpacking the dataset, put all training volumes into one folder, all training labels into another folder, and the test volumes into a third folder, then update the corresponding paths in lits.yaml.
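For example, with the paths in config/resunet_lits.yaml the layout would be the following; the directory names are only the values in that config and can be anything, as long as the yaml matches:
```
/home/aistudio/data
├── scan        # training volumes (INPUTS_PATH)
├── label       # training labels (LABELS_PATH)
├── preprocess  # written by prep_3d.py (PREP_PATH)
└── inference   # test volumes to predict (INFER.PATH.INPUT)
```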
48 |
49 | ### Preprocessing
50 | Once configured, run data preprocessing first. This mainly converts the data into a unified npz format that is convenient for training. Preprocessing steps such as windowing and 3D rotation of the 3D CT data can also be applied at this stage.
51 | ```shell
52 | python medseg/prep_3d.py -c config/lits.yaml
53 | ```
54 | ### Training
55 | The network is trained on the preprocessed data. The training script takes several arguments; -h lists them. If you are using the CPU build of paddle, do not add the --use_gpu flag.
56 | ```shell
57 | python medseg/train.py -c config/lits.yaml --use_gpu --do_eval
58 | ```
59 | ### Inference
60 | The last step is running prediction with the trained network. Set the model weight path to the one actually produced by the previous step. The code reads every nii file under the inference path and predicts them one by one. Supported formats are .nii and .nii.gz.
61 | ```shell
62 | python medseg/infer.py -c config/lits.yaml --use_gpu
63 | ```
64 |
65 |
66 | The project has a complete environment on AI Studio; fork the [project](https://aistudio.baidu.com/aistudio/projectdetail/250994) and it can be run directly.
67 |
68 | # Changelog
69 | * 2020.6.21
70 | **v0.2.0**
71 | * Switched the whole project to configuration files, extended functionality and added visualization
72 |
73 | * 2020.5.1
74 | **v0.1.0**
75 | * Preprocessing, training and inference scripts run end to end on the LiTS dataset
76 |
77 | If you run into any problem while using the project, you are welcome to join the AI Studio medical interest group, ask questions there, and learn and improve together with others.
78 |
79 |
80 |
81 |
82 |
--------------------------------------------------------------------------------
/README_en.md:
--------------------------------------------------------------------------------
1 | # MedSeg
2 | English | [简体中文](./README.md)
3 |
4 | Medical image segmentation toolkit based on the PaddlePaddle framework. Our target is to implement various 2D and 3D model architectures, loss functions and data augmentation methods. This is still a work in progress but has already achieved promising results on liver and aorta segmentation. The development plan can be seen in the [Project](https://github.com/davidlinhl/medSeg/projects/1)
5 |
6 | ## Project Structure
7 | Currently this project contains only 2D segmentation models. The structure is as follows.
8 |
9 | - medseg: Primary code
10 | - train.py: Training pipeline
11 | - aug.py: Data augmentation
12 |   - loss.py: Various loss functions
13 | - eval.py: Various metrics to evaluate segmentation result
14 | - vis.py: Visualize results
15 | - models: Currently only 2D models
16 | - tool: Useful scripts
17 | - train: Tools for converting scan format, generating 2D slices from scan etc.
18 | - infer: Tools used after inference, merging segmentation results from slices, voting model fusion etc.
19 | - config: Training configurations
20 | All configurations can be found in [utils/config.py](https://github.com/davidlinhl/medSeg/blob/master/medseg/utils/config.py)
21 |
22 | ## Usage
23 | ### Environment Set Up
24 | Install project dependencies with
25 | ```shell
26 | pip install -r requirements.txt
27 | ```
28 | Instructions for installing PaddlePaddle-GPU can be found on PaddlePaddle's [official home page](https://www.paddlepaddle.org.cn/)
29 |
30 | ### Preprocess
31 | Preprocess 3D scans either into 2D slices or 3D patches. Applying WWWC or other slice-wise augmentation can also be done here.
32 | ```shell
33 | python medseg/prep_3d.py -c config/lits.yaml
34 | ```
35 |
36 | ### Training
37 | The training script takes several options; run it with -h to see details. If you are training with CPU only, don't include the --use_gpu flag.
38 | ```shell
39 | python medseg/train.py -c config/lits.yaml --use_gpu --do_eval
40 | ```
41 |
42 | ### Inference
43 | The last step is running inference with the previously trained model. The script performs inference on all data under the specified path. Currently only the nii format is supported.
44 | ```shell
45 | python medseg/infer.py -c config/lits.yaml --use_gpu
46 | ```
47 |
48 | ### Evaluation and Else
49 | After getting inference results, you may want to know how well the model performs. We have an evaluation script with multiple metrics implemented with medpy.
50 | ```shell
51 | python medseg/eval.py -c config/eval.yaml
52 | ```
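The eval config only needs the two label folders and the list of metrics; the snippet below is a minimal sketch mirroring config/eval.yaml in this repo (both paths are placeholders):
```yaml
EVAL:
  PATH:
    SEG: "/path/to/predicted/labels"
    GT: "/path/to/ground/truth/labels"
  NAME: "experiment-name"
  METRICS: ["IOU", "Dice", "Precision", "Recall", "Assd", "Ravd"]
```
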
53 | For aorta, specifically, we also have scripts for measuring blood vessel diameter and reporting aorta aneurysm.
54 | ```shell
55 | python tool/infer/2d_diameter.py
56 | python tool/infer/aorta.py
57 | ```
58 | These two scripts combined can calculate aorta diameter and report aorta aneurysm based on it.
59 |
60 | We have an [Aistudio](https://aistudio.baidu.com/aistudio/projectdetail/250994) project that has all the data and environment ready.
61 |
62 | Should you have any questions about using this toolkit, you can contact the developer at linhandev@qq.com
63 |
--------------------------------------------------------------------------------
/tool/train/to_slice.py:
--------------------------------------------------------------------------------
1 | """
2 | Convert nii scans and labels to 2D slices (png/npy) in batch.
3 | Scans and labels must sort into the same lexicographic order (identical file names with different extensions is enough).
4 | """
5 |
6 | import os
7 | import os.path as osp
8 | import argparse
9 | import logging
10 | from tqdm import tqdm
11 |
12 | import util
13 |
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument("--scan_dir", type=str, help="扫描文件路径", required=True)
16 | parser.add_argument("--label_dir", type=str, help="标签文件路径", default=None)
17 | parser.add_argument("--out_dir", type=str, help="数据集输出路径", required=True)
18 | parser.add_argument(
19 | "--thick",
20 | type=int,
21 | help="切片厚度,默认3。如果是保存成png格式必须为3",
22 | default=3,
23 | )
24 | parser.add_argument(
25 | "-t",
26 | "--thresh",
27 | type=int,
28 | help="前景像素数量大于这个数才包含到数据集里,否则这个slice跳过",
29 | default=None,
30 | )
31 | parser.add_argument(
32 | "-s",
33 | "--size",
34 | nargs=2,
35 | help="输出片的大小,不声明这个参数不进行任何插值,否则扫描3阶插值,标签0阶缩放到这个大小",
36 | default=None,
37 | )
38 | parser.add_argument("--wwwc", nargs=2, help="窗宽窗位", default=["1000", "0"])
39 | parser.add_argument(
40 | "-r",
41 | "--rot",
42 | type=int,
43 | help="逆时针90度转多少次,可以为负",
44 | default=0,
45 | ) # TODO: 用库做体位校正
46 | parser.add_argument("-f", "--front", type=int, help="如果标签有多种前景,要保留的前景值", default=None)
47 | parser.add_argument(
48 | "-fm",
49 | "--front_mode",
50 | type=str,
51 | help="多个前景保留一个的策略。stack:把大于front的标签都设成front,小于front的标签设成背景。single:只保留front,其他的都设成背景",
52 | default=None,
53 | )
54 | parser.add_argument(
55 | "-itv",
56 | "--interval",
57 | type=int,
58 | help="每隔这个数量取1片,重建层间距很小,片层很多的时候可以用这个跳过一些片",
59 | default=1,
60 | )
61 | parser.add_argument("-c", "--check", default=False, action="store_true", help="是否检查数据集")
62 | parser.add_argument("--ext", type=str, help="文件保存的拓展名,不带点", default="png")
63 | parser.add_argument("--transpose", default=False, action="store_true", help="transpose the data axes")
64 | parser.add_argument("--prefix", type=str, help="文件保存的前缀", default=None)
65 | args = parser.parse_args()
66 |
67 | logging.basicConfig(
68 | level=logging.DEBUG,
69 | format="%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s",
70 | )
71 |
72 | if args.thick % 2 != 1:
73 | logging.error(
74 |         f"The thickness argument {args.thick} is not odd, please use an odd number."
75 | )
76 | exit()
77 |
78 | if args.ext == "png" and args.thick not in [1, 3]:
79 | logging.error(
80 | f"Can't save {args.thick} channel image with png format. Png format only supports 1 or 3 channels. Switch to save in npy format instead."
81 | )
82 | exit()
83 |
84 |
85 | # TODO: 完善对扫描和标签的检查
86 | if args.check:
87 | util.check_nii_match(args.scan_dir, args.label_dir)
88 |
89 | scans = util.listdir(args.scan_dir)
90 | if args.label_dir is not None:
91 | labels = util.listdir(args.label_dir)
92 | assert len(labels) == len(
93 | scans
94 | ), f"The number of labels {len(labels)} is not equal to number of scans {len(scans)}"
95 | logging.info("Discovered scan/label pairs:")
96 | for s, l in zip(scans, labels):
97 | logging.info(f" {s} \t {l}")
98 | cmd = input(
99 | f"""Totally {len(scans)} pairs, plz check for any mismatch.
100 | Input Y/y to continue, input anything else to stop: """
101 | )
102 | else:
103 | cmd = input(
104 | f"""Totally {len(scans)} scans.
105 | Input Y/y to continue, input anything else to stop: """
106 | )
107 | labels = [None for _ in range(len(scans))]
108 | if cmd.lower() != "y":
109 | exit("Exit on user command")
110 |
111 | progress = tqdm(range(len(scans)))
112 |
113 | for scan, label in zip(scans, labels):
114 | progress.set_description(f"Processing {osp.basename(scan)}")
115 | util.slice_med(
116 | osp.join(args.scan_dir, scan),
117 | osp.join(args.out_dir, "JPEGImages"),
118 | osp.join(args.label_dir, label) if args.label_dir else None,
119 | osp.join(args.out_dir, "Annotations") if args.label_dir else None,
120 | args.thick,
121 | rot=args.rot,
122 | wwwc=util.toint(args.wwwc),
123 | thresh=args.thresh,
124 | front=args.front,
125 | front_mode=args.front_mode,
126 | itv=args.interval,
127 | ext=args.ext,
128 | transpose=args.transpose,
129 | prefix=args.prefix,
130 | )
131 | progress.update(n=1)
132 |
--------------------------------------------------------------------------------
/tool/train/slice_mp.py:
--------------------------------------------------------------------------------
1 | """
2 | Convert nii scans and labels to 2D slices (png/npy) in batch.
3 | Scans and labels must sort into the same lexicographic order (identical file names with different extensions is enough).
4 | """
5 | import multiprocessing
6 | import os
7 | import os.path as osp
8 | import argparse
9 | import logging
10 | from tqdm import tqdm
11 | import functools
12 |
13 | import util
14 |
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument("--scan_dir", type=str, help="扫描文件路径", required=True)
17 | parser.add_argument("--label_dir", type=str, help="标签文件路径", default=None)
18 | parser.add_argument("--out_dir", type=str, help="数据集输出路径", required=True)
19 | parser.add_argument(
20 | "--thick",
21 | type=int,
22 | help="切片厚度,默认3。如果是保存成png格式必须为3",
23 | default=3,
24 | )
25 | parser.add_argument(
26 | "-t",
27 | "--thresh",
28 | type=int,
29 | help="前景像素数量大于这个数才包含到数据集里,否则这个slice跳过",
30 | default=None,
31 | )
32 | parser.add_argument(
33 | "-s",
34 | "--size",
35 | nargs=2,
36 | help="输出片的大小,不声明这个参数不进行任何插值,否则扫描3阶插值,标签0阶缩放到这个大小",
37 | default=None,
38 | )
39 | parser.add_argument("--wwwc", nargs=2, help="窗宽窗位", default=["1000", "0"])
40 | parser.add_argument(
41 | "-r",
42 | "--rot",
43 | type=int,
44 | help="逆时针90度转多少次,可以为负",
45 | default=0,
46 | ) # TODO: 用库做体位校正
47 | parser.add_argument("-f", "--front", type=int, help="如果标签有多种前景,要保留的前景值", default=None)
48 | parser.add_argument(
49 | "-fm",
50 | "--front_mode",
51 | type=str,
52 | help="多个前景保留一个的策略。stack:把大于front的标签都设成front,小于front的标签设成背景。single:只保留front,其他的都设成背景",
53 | default=None,
54 | )
55 | parser.add_argument(
56 | "-itv",
57 | "--interval",
58 | type=int,
59 | help="每隔这个数量取1片,重建层间距很小,片层很多的时候可以用这个跳过一些片",
60 | default=1,
61 | )
62 | parser.add_argument("-c", "--check", default=False, action="store_true", help="是否检查数据集")
63 | parser.add_argument("--ext", type=str, help="文件保存的拓展名,不带点", default="png")
64 | parser.add_argument(
65 | "--transpose", default=False, action="store_true", help="是否调整数据维度顺序"
66 | )
67 | parser.add_argument("--prefix", type=str, help="文件保存的前缀", default="")
68 | parser.add_argument("-p", "--process", type=int, help="进程数", default=2)
69 |
70 | args = parser.parse_args()
71 |
72 | logging.basicConfig(
73 | level=logging.DEBUG,
74 | format="%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s",
75 | )
76 |
77 | if args.thick % 2 != 1:
78 | logging.error(
79 |         f"The thickness argument {args.thick} is not odd, please use an odd number."
80 | )
81 | exit()
82 |
83 | if args.ext == "png" and args.thick not in [1, 3]:
84 | logging.error(
85 | f"Can't save {args.thick} channel image with png format. Png format only supports 1 or 3 channels. Switch to save in npy format instead."
86 | )
87 | exit()
88 |
89 |
90 | # TODO: 完善对扫描和标签的检查
91 | if args.check:
92 | util.check_nii_match(args.scan_dir, args.label_dir)
93 |
94 | scans = util.listdir(args.scan_dir)
95 | if args.label_dir is not None:
96 | labels = util.listdir(args.label_dir)
97 | assert len(labels) == len(
98 | scans
99 | ), f"The number of labels {len(labels)} is not equal to number of scans {len(scans)}"
100 | logging.info("Discovered scan/label pairs:")
101 | for s, l in zip(scans, labels):
102 | logging.info(f" {s} \t {l}")
103 | cmd = input(
104 | f"""Totally {len(scans)} pairs, plz check for any mismatch.
105 | Input Y/y to continue, input anything else to stop: """
106 | )
107 | else:
108 | cmd = input(
109 | f"""Totally {len(scans)} scans.
110 | Input Y/y to continue, input anything else to stop: """
111 | )
112 | labels = [None for _ in range(len(scans))]
113 | if cmd.lower() != "y":
114 | exit("Exit on user command")
115 |
116 | pool = multiprocessing.Pool(processes=args.process)
117 | tasks = []
118 | for scan, label in zip(scans, labels):
119 | tasks.append(
120 | (
121 | osp.join(args.scan_dir, scan),
122 | osp.join(args.out_dir, "JPEGImages"),
123 | osp.join(args.label_dir, label) if args.label_dir else None,
124 | osp.join(args.out_dir, "Annotations") if args.label_dir else None,
125 | args.thick,
126 | args.rot,
127 | util.toint(args.wwwc),
128 | args.thresh,
129 | args.front,
130 | args.front_mode,
131 | args.interval,
132 | None,
133 | args.ext,
134 | args.transpose,
135 | args.prefix,
136 | )
137 | )
138 | pool.starmap(util.slice_med, tasks)
139 |
140 | pool.close()
141 | pool.join()
142 |
--------------------------------------------------------------------------------
/medseg/models/unet.py:
--------------------------------------------------------------------------------
1 | # coding: utf8
2 | # copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from __future__ import absolute_import
17 | from __future__ import division
18 | from __future__ import print_function
19 | import contextlib
20 | import paddle
21 | import paddle.fluid as fluid
22 | from utils.config import cfg
23 | from models.libs.model_libs import scope, name_scope
24 | from models.libs.model_libs import bn, bn_relu, relu
25 | from models.libs.model_libs import conv, max_pool, deconv
26 |
27 |
28 | def double_conv(data, out_ch):
29 | param_attr = fluid.ParamAttr(
30 | name='weights',
31 | regularizer=fluid.regularizer.L2DecayRegularizer(
32 | regularization_coeff=0.0),
33 | initializer=fluid.initializer.TruncatedNormal(loc=0.0, scale=0.33))
34 | with scope("conv0"):
35 | data = bn_relu(
36 | conv(data, out_ch, 3, stride=1, padding=1, param_attr=param_attr))
37 | with scope("conv1"):
38 | data = bn_relu(
39 | conv(data, out_ch, 3, stride=1, padding=1, param_attr=param_attr))
40 | return data
41 |
42 |
43 | def down(data, out_ch):
44 | # 下采样:max_pool + 2个卷积
45 | with scope("down"):
46 | data = max_pool(data, 2, 2, 0)
47 | data = double_conv(data, out_ch)
48 | return data
49 |
50 |
51 | def up(data, short_cut, out_ch):
52 | # 上采样:data上采样(resize或deconv), 并与short_cut concat
53 | param_attr = fluid.ParamAttr(
54 | name='weights',
55 | regularizer=fluid.regularizer.L2DecayRegularizer(
56 | regularization_coeff=0.0),
57 | initializer=fluid.initializer.XavierInitializer(),
58 | )
59 | with scope("up"):
60 | if cfg.MODEL.UNET.UPSAMPLE_MODE == 'bilinear':
61 | data = fluid.layers.resize_bilinear(data, short_cut.shape[2:])
62 | else:
63 | data = deconv(
64 | data,
65 | out_ch // 2,
66 | filter_size=2,
67 | stride=2,
68 | padding=0,
69 | param_attr=param_attr)
70 | data = fluid.layers.concat([data, short_cut], axis=1)
71 | data = double_conv(data, out_ch)
72 | return data
73 |
74 |
75 | def encode(data):
76 | # 编码器设置
77 | short_cuts = []
78 | with scope("encode"):
79 | with scope("block1"):
80 | data = double_conv(data, 64)
81 | short_cuts.append(data)
82 | with scope("block2"):
83 | data = down(data, 128)
84 | short_cuts.append(data)
85 | with scope("block3"):
86 | data = down(data, 256)
87 | short_cuts.append(data)
88 | with scope("block4"):
89 | data = down(data, 512)
90 | short_cuts.append(data)
91 | with scope("block5"):
92 | data = down(data, 512)
93 | return data, short_cuts
94 |
95 |
96 | def decode(data, short_cuts):
97 | # 解码器设置,与编码器对称
98 | with scope("decode"):
99 | with scope("decode1"):
100 | data = up(data, short_cuts[3], 256)
101 | with scope("decode2"):
102 | data = up(data, short_cuts[2], 128)
103 | with scope("decode3"):
104 | data = up(data, short_cuts[1], 64)
105 | with scope("decode4"):
106 | data = up(data, short_cuts[0], 64)
107 | return data
108 |
109 |
110 | def get_logit(data, num_classes):
111 | # 根据类别数设置最后一个卷积层输出
112 | param_attr = fluid.ParamAttr(
113 | name='weights',
114 | regularizer=fluid.regularizer.L2DecayRegularizer(
115 | regularization_coeff=0.0),
116 | initializer=fluid.initializer.TruncatedNormal(loc=0.0, scale=0.01))
117 | with scope("logit"):
118 | data = conv(
119 | data, num_classes, 3, stride=1, padding=1, param_attr=param_attr)
120 | return data
121 |
122 |
123 | def unet(input, num_classes):
124 | # UNET网络配置,对称的编码器解码器
125 | encode_data, short_cuts = encode(input)
126 | decode_data = decode(encode_data, short_cuts)
127 | logit = get_logit(decode_data, num_classes)
128 | return logit
129 |
130 |
131 | if __name__ == '__main__':
132 | image_shape = [3, 320, 320]
133 | image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
134 | logit = unet(image, 4)
135 | print("logit:", logit.shape)
136 |
--------------------------------------------------------------------------------
/tool/infer/3d_diameter.py:
--------------------------------------------------------------------------------
1 | # Reconstruct a 3D mesh from the label and measure vessel diameters on it
2 | import argparse
3 | import os
4 |
5 | import skimage.measure
6 | import scipy.ndimage
7 | import numpy as np
8 | import nibabel as nib
9 | import trimesh
10 | from util import filter_polygon, sort_line, Polygon
11 |
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument(
14 | "--in_dir", type=str, default="/home/lin/Desktop/aorta/private/label/test"
15 | )
16 | parser.add_argument("--out_dir", type=str, default="./img")
17 | args = parser.parse_args()
18 |
19 |
20 | vol_names = os.listdir(args.in_dir)
21 | for vol_name in vol_names:
22 | # 1. Load the label volume to measure; optionally interpolate towards 1024 resolution for more precision
23 | volf = nib.load(os.path.join(args.in_dir, vol_name))
24 | vol = volf.get_fdata()
25 | print(vol.shape)
26 | # if vol.shape[0] < 1024:
27 | # vol = scipy.ndimage.interpolation.zoom(vol, (2, 2, 1), order=3)
28 | # 2. 3D reconstruction
29 | verts, faces, normals, values = skimage.measure.marching_cubes(vol)
30 | # print(verts)
31 | verts = [[v[0], v[1], v[2] * 10] for v in verts]
32 | # print(verts)
33 | # move the lower-left corner of the bounding box to the origin
34 | vert_min = np.min(verts, axis=0)
35 | print(vert_min)
36 | verts = verts - vert_min
37 | # print(verts)
38 | mesh = trimesh.Trimesh(vertices=verts, faces=faces)
39 |
40 | # print(mesh.is_watertight)
41 | # print(mesh.bounds)
42 |
43 | # mesh.export("aorta.stl")
44 | # mesh = trimesh.load("./aorta.stl")
45 |
46 | scene = trimesh.Scene()
47 | scene.add_geometry(mesh)
48 | axis = trimesh.creation.axis(origin_size=10, origin_color=[1.0, 0, 0])
49 | scene.add_geometry(axis)
50 | scene.show()
51 |
52 | # 3. Build the measurement path: take a line of points on the vessel wall
53 | # 3.1 Collect all distinct heights; each height forms one slice
54 | heights = []
55 | for v in verts:
56 | if v[2] not in heights:
57 | heights.append(v[2])
58 | print("heights", heights)
59 |
60 | slices = [[] for _ in range(len(heights))]
61 | for ind, h in enumerate(heights):
62 | for v in verts:
63 | if v[2] == h:
64 | slices[ind].append(v)
65 |
66 | # 3.2 Compute the center of every slice
67 | centers = []
68 | polygons = []
69 | for ind in range(len(heights)):
70 | res = filter_polygon(slices[ind], "all", 15)
71 | for poly in res:
72 | polygons.append(Polygon(poly))
73 |
74 | polygons = sort_line(polygons)
75 |
76 | # center_cloud = trimesh.points.PointCloud([a.center for a in polygons], [0, 255, 0, 100])
77 | # scene.add_geometry(center_cloud)
78 | #
79 | # base_cloud = trimesh.points.PointCloud([a.base for a in polygons], [255, 0, 0, 100])
80 | # scene.add_geometry(base_cloud)
81 | # scene.show()
82 |
83 | # 4. For each base point, find the smallest cross-section of the vessel through it
84 | for polygon in polygons[10:]:
85 | base = polygon.base
86 | last_prin = [0, 0, 1]
87 | last_size = 65535
88 | diameters = []
89 | while True:
90 | stride = 0.01
91 | tweak = [
92 | [0, 0, stride],
93 | [0, 0, -stride],
94 | [0, stride, 0],
95 | [0, -stride, 0],
96 | [stride, 0, 0],
97 | [-stride, 0, 0],
98 | ]
99 | new_prins = np.array(last_prin) + np.array(tweak)
100 | print(new_prins)
101 | sizes = []
102 |
103 | for prin in new_prins:
104 | print(prin)
105 | lines = trimesh.intersections.mesh_plane(mesh, prin, base)
106 | size = Polygon(lines).cal_size()
107 | if size < 5:
108 | size = 65535
109 | sizes.append(size)
110 | print(sizes)
111 |
112 | min_ind = np.array(sizes).argmin()
113 | if last_size > sizes[min_ind]:
114 | last_prin = new_prins[min_ind]
115 | last_size = sizes[min_ind]
116 | print("+_+_+_+_", last_size)
117 | else:
118 | break
119 |
120 | lines = trimesh.intersections.mesh_plane(mesh, last_prin, base)
121 | min_polygon = Polygon(lines)
122 | diameters.append(min_polygon.cal_diameter())
123 | center_cloud = trimesh.points.PointCloud(polygon.points, [0, 255, 0, 100])
124 | scene.add_geometry(center_cloud)
125 | center_cloud = trimesh.points.PointCloud(min_polygon.points, [255, 0, 0, 100])
126 | scene.add_geometry(center_cloud)
127 | scene.show()
128 | print("get_min")
129 | input("here")
130 |
131 | # points = []
132 | # for p in lines:
133 | # points.append([p[0][0], p[0][1], p[0][2]])
134 | # points = ang_sort(points)
135 | # print(points)
136 | # points = filter_polygon(points, [0, 0, 0], 10)
137 |
138 | # points = trimesh.points.PointCloud(points, [200, 200, 250, 100])
139 |
140 | # faces = []
141 | # for ind in range(len(points)):
142 | # faces.append([ind, int((ind + len(points) / 2) % len(points)), int((ind + len(points) / 2 + 1) % len(points))])
143 | # plane = trimesh.creation.extrude_triangulation(points, [[ind, ind + 1, ind + 2] for ind in range(20)], height=-1)
144 |
--------------------------------------------------------------------------------
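The minimum-diameter search in step 4 above is a plain coordinate descent on the plane normal: perturb the current normal by ±stride along each axis, keep whichever candidate gives a smaller cross-section, and stop when nothing improves. Below is a self-contained sketch of that loop, with a stand-in cross_section_size() in place of the trimesh plane intersection; the function and the cylinder toy model are illustrative only, not part of this repo.

import numpy as np

def cross_section_size(normal):
    # stand-in for Polygon(trimesh.intersections.mesh_plane(mesh, normal, base)).cal_size();
    # the "vessel" here is a unit-radius cylinder along z, so the cut grows as the normal tilts off z
    normal = np.asarray(normal, dtype=float)
    cos_tilt = abs(normal[2]) / np.linalg.norm(normal)
    return np.pi / max(cos_tilt, 1e-6)

def find_min_section(start_normal=(0, 0, 1), stride=0.01, max_iter=10000):
    last_normal = np.asarray(start_normal, dtype=float)
    last_size = cross_section_size(last_normal)
    tweaks = stride * np.vstack([np.eye(3), -np.eye(3)])  # +/- stride along each axis
    for _ in range(max_iter):
        candidates = last_normal + tweaks
        sizes = [cross_section_size(n) for n in candidates]
        best = int(np.argmin(sizes))
        if sizes[best] >= last_size:  # no tweak improves any more: stop, same as the break above
            break
        last_normal, last_size = candidates[best], sizes[best]
    return last_normal, last_size

normal, size = find_min_section(start_normal=(0.2, -0.1, 1.0))
print(normal, size)  # the normal drifts back towards (0, 0, 1), the cylinder axis
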
/medseg/eval.py:
--------------------------------------------------------------------------------
1 | # Evaluate the model on the validation set with multiple metrics
2 | import os
3 | from multiprocessing import Pool, cpu_count
4 | from datetime import datetime
5 | import argparse
6 |
7 | from medpy import metric
8 | import nibabel as nib
9 | from tqdm import tqdm
10 | import scipy.ndimage
11 |
12 | import utils.util as util
13 | from utils.config import cfg
14 |
15 |
16 | def parse_args():
17 | parser = argparse.ArgumentParser(description="预测")
18 | parser.add_argument("-c", "--cfg_file", type=str, help="配置文件路径")
19 | parser.add_argument("opts", nargs=argparse.REMAINDER)
20 | args = parser.parse_args()
21 |
22 | if args.cfg_file is not None:
23 | cfg.update_from_file(args.cfg_file)
24 | if args.opts:
25 | cfg.update_from_list(args.opts)
26 |
27 | cfg.set_immutable(True)
28 |
29 |
30 | def main():
31 | headers = []
32 | # TODO: find a cleaner way to keep the header row and the value order below consistent
33 | metrics = [
34 | "FP",
35 | "FN",
36 | "TP",
37 | "TN",
38 | "Precision",
39 | "Recall",
40 | "Sensitivity",
41 | "Specificity",
42 | "Accuracy",
43 | "Kappa",
44 | "Dice",
45 | "IOU",
46 | # 3D
47 | "Assd",
48 | "Ravd",
49 | ]
50 | for m in metrics:
51 | if m in cfg.EVAL.METRICS:
52 | headers.append(m)
53 |
54 | preds = util.listdir(cfg.EVAL.PATH.SEG)
55 | labels = util.listdir(cfg.EVAL.PATH.GT)
56 |
57 | f = open(cfg.EVAL.PATH.NAME + "-" + str(datetime.now()) + ".csv", "w")
58 | print("文件", end=",", file=f)
59 | for ind, header in enumerate(headers):
60 | print(header, end="," if ind != len(headers) - 1 else "\n", file=f)
61 | with Pool(cpu_count()) as p:
62 | res = p.map(calculate, [(preds[idx], labels[idx]) for idx in range(len(preds))])
63 | for pred, r in res:
64 | print(pred, end=",", file=f)
65 | for ind, x in enumerate(r):
66 | print(x, end="," if ind != len(headers) - 1 else "\n", file=f)
67 | f.close()
68 |
69 |
70 | # def write_res(pred, res):
71 | # f = open(cfg.EVAL.PATH.RESULT + "-" + str(datetime.now()) + ".csv", "w+")
72 | # print(pred, end=",", file=f)
73 | # for ind, x in enumerate(res[ind]):
74 | # print(x, end="," if ieval.csv-2020-12-16 18:59:40.731539.csvnd != len(headers) - 1 else "\n", file=f)
75 | # print("\n", file=f)
76 | # f.close()
77 |
78 |
79 | def calculate(input):
80 | pred_name = input[0]
81 | lab_name = input[1]
82 |
83 | predf = nib.load(os.path.join(cfg.EVAL.PATH.SEG, pred_name))
84 | labf = nib.load(os.path.join(cfg.EVAL.PATH.GT, lab_name))
85 |
86 | pred = predf.get_fdata()
87 | lab = labf.get_fdata()
88 | print(pred_name, lab_name, pred.shape, lab.shape)
89 | if lab.shape[0] != pred.shape[0]:
90 | ratio = [a / b for a, b in zip(pred.shape, lab.shape)]
91 | lab = scipy.ndimage.interpolation.zoom(lab, ratio, order=1)
92 | print("插值后大小: ", pred.shape, lab.shape)
93 | assert pred.shape == lab.shape, "分割结果和GT大小不同: {},{}, {}".format(
94 | pred.shape, lab.shape, pred_name
95 | )
96 |
97 | temp = []
98 | if "FP" in cfg.EVAL.METRICS:
99 | # fp = metric.binary.obj_fpr(pred, lab)
100 | # temp.append(fp)
101 | pass
102 |
103 | if "FN" in cfg.EVAL.METRICS:
104 | pass
105 |
106 | if "TP" in cfg.EVAL.METRICS:
107 | tp = metric.binary.true_positive_rate(pred, lab)
108 | temp.append(tp)
109 |
110 | if "TN" in cfg.EVAL.METRICS:
111 | tn = metric.binary.true_negative_rate(pred, lab)
112 | temp.append(tn)
113 |
114 | if "Precision" in cfg.EVAL.METRICS:
115 | prec = metric.binary.precision(pred, lab)
116 | temp.append(prec)
117 |
118 | if "Recall" in cfg.EVAL.METRICS:
119 | rec = metric.binary.recall(pred, lab)
120 | temp.append(rec)
121 |
122 | if "Sensitivity" in cfg.EVAL.METRICS:
123 | rec = metric.binary.sensitivity(pred, lab) # same as recall
124 | temp.append(rec)
125 |
126 | if "Specificity" in cfg.EVAL.METRICS:
127 | spec = metric.binary.specificity(pred, lab)
128 | temp.append(spec)
129 |
130 | if "Accuracy" in cfg.EVAL.METRICS:
131 | tp = metric.binary.true_positive_rate(pred, lab)
132 | tn = metric.binary.true_negative_rate(pred, lab)
133 | acc = (tp + tn) / 2
134 | temp.append(acc)
135 |
136 | if "Kappa" in cfg.EVAL.METRICS:
137 | pass
138 |
139 | if "Dice" in cfg.EVAL.METRICS:
140 | dice = metric.dc(pred, lab)
141 | temp.append(dice)
142 |
143 | if "IOU" in cfg.EVAL.METRICS:
144 | iou = metric.binary.jc(pred, lab)
145 | temp.append(iou)
146 |
147 | if "Assd" in cfg.EVAL.METRICS:
148 | assd = metric.binary.assd(pred, lab)
149 | temp.append(assd)
150 |
151 | if "Ravd" in cfg.EVAL.METRICS:
152 | ravd = metric.binary.ravd(pred, lab)
153 | temp.append(ravd)
154 |
155 | return pred_name, temp
156 |
157 |
158 | # TODO: draw box-and-whisker plots
159 | # https://matplotlib.org/gallery/pyplots/boxplot_demo_pyplot.html#sphx-glr-gallery-pyplots-boxplot-demo-pyplot-py
160 |
161 | if __name__ == "__main__":
162 | parse_args()
163 | main()
164 |
--------------------------------------------------------------------------------
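calculate() relies on medpy's binary metrics, which take two same-shaped arrays interpreted as boolean masks. A tiny sanity check with synthetic cubes, only to illustrate the calls used above; the shapes and overlap are made up for the example.

import numpy as np
from medpy import metric

pred = np.zeros((8, 8, 8), dtype=np.uint8)
gt = np.zeros((8, 8, 8), dtype=np.uint8)
pred[2:6, 2:6, 2:6] = 1  # 64 predicted voxels
gt[3:7, 3:7, 3:7] = 1    # 64 ground-truth voxels, 27 of them overlapping

print(metric.dc(pred, gt))            # Dice: 2 * 27 / (64 + 64)
print(metric.binary.jc(pred, gt))     # IOU:  27 / (64 + 64 - 27)
print(metric.binary.precision(pred, gt), metric.binary.recall(pred, gt))
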
/tool/infer/slice2nii.py:
--------------------------------------------------------------------------------
1 | """
2 | Merge per-slice inference results back into volumes.
3 | Results are expected to be named name-idx.png; if the format differs, modify get_name.
4 | """
5 | import os
6 | import os.path as osp
7 | import concurrent.futures
8 | from time import sleep
9 | from queue import Queue
10 | import argparse
11 | import multiprocessing
12 | from multiprocessing import Pool
13 |
14 | import nibabel as nib
15 | import cv2
16 | import matplotlib.pyplot as plt
17 | import numpy as np
18 | from scipy import ndimage
19 | from tqdm import tqdm
20 |
21 | from util import to_pinyin
22 | import util
23 |
24 |
25 | parser = argparse.ArgumentParser()
26 | parser.add_argument(
27 | "--scan_dir",
28 | type=str,
29 | required=True,
30 | help="扫描路径,会去找头文件信息",
31 | )
32 | parser.add_argument(
33 | "--seg_dir",
34 | type=str,
35 | required=True,
36 | help="nii分割标签输出路径",
37 | )
38 | parser.add_argument(
39 | "--png_dir",
40 | type=str,
41 | required=True,
42 | help="png格式分割推理结果路径",
43 | )
44 | parser.add_argument(
45 | "--rot",
46 | type=int,
47 | default=0,
48 | help="对结果进行几次旋转",
49 | )
50 | parser.add_argument(
51 | "--filter",
52 | default=False,
53 | action="store_true",
54 | help="是否过滤最大连通块",
55 | )
56 | parser.add_argument(
57 | "--percent",
58 | type=str,
59 | default=None,
60 | help="最大连通块占所有前景标签比例,可以估计分割结果质量,不写不进行统计",
61 | )
62 | args = parser.parse_args()
63 |
64 |
65 | def get_name(name):
66 | """从文件名中解析扫描序列的名字和片层下标
67 |
68 | Parameters
69 | ----------
70 | name : str
71 | 片层标签文件名
72 |
73 | Returns
74 | -------
75 | str, int
76 | 序列名,用于group一个序列的所有推理结果
77 | 序列下标,表示这个片层在序列中的下标
78 |
79 | """
80 | # find the position of the last "-"
81 | pos = name.rfind("-")
82 | # print(name, name[:pos], int(name[pos + 1 :].split(".")[0]))
83 | return name[:pos], int(name[pos + 1 :].split(".")[0])
84 |
85 |
86 | # TODO: this should use multiple processes, not threads
87 | class ThreadPool(concurrent.futures.ThreadPoolExecutor):
88 | def __init__(self, maxsize=6, *args, **kwargs):
89 | super(ThreadPool, self).__init__(*args, **kwargs)
90 | self._work_queue = Queue(maxsize=maxsize)
91 |
92 |
93 | # Check how scans and slice results match up
94 | img_names = util.listdir(args.png_dir, sort=False)
95 | patient_names = []
96 | for n in img_names:
97 | n, _ = get_name(n)
98 | if n not in patient_names:
99 | patient_names.append(n)
100 | patient_names = [n + ".nii.gz" for n in patient_names]
101 |
102 | nii_names_set = set(os.listdir(args.scan_dir))
103 | patient_names_set = set(patient_names)
104 | for n in patient_names_set - nii_names_set:
105 | print(n, "dont have nii")
106 | for n in nii_names_set - patient_names_set:
107 | print(n, "dont have segmentation result")
108 | patient_names.sort()
109 | print(patient_names)
110 |
111 | input("Press any key to start!")
112 | if args.percent:
113 | percent_file = open(args.percent, "a+")
114 | if not os.path.exists(args.seg_dir):
115 | os.makedirs(args.seg_dir)
116 |
117 |
118 | def run(patient):
119 | if osp.exists(osp.join(args.seg_dir, patient)):
120 | print(patient, "already finished, skipping")
121 | return
122 |
123 | patient_imgs = [n for n in img_names if get_name(n)[0] == patient.split(".")[0]]
124 | patient_imgs.sort(key=lambda n: int(get_name(n)[1]))
125 | # print(patient, patient_imgs, len(patient_imgs))
126 | label = cv2.imread(
127 | os.path.join(args.png_dir, patient_imgs[0]), cv2.IMREAD_UNCHANGED
128 | )
129 | s = label.shape
130 | label_data = np.zeros([s[0], s[1], len(patient_imgs)], dtype="uint8")
131 |
132 | try:
133 | # print(os.path.join(args.scan_dir, patient))
134 | scanf = nib.load(os.path.join(args.scan_dir, patient))
135 | scan_header = scanf.header
136 | except:
137 | print(f"[ERROR] {patient}'s scan is not found! Skipping {patient}")
138 | return
139 | # scanf = nib.load(os.path.join(args.scan_dir, "张金华_20201024213424575a.nii"))
140 | # scan_header = scanf.header
141 |
142 | for img_name in patient_imgs:
143 | img = cv2.imread(os.path.join(args.png_dir, img_name), cv2.IMREAD_UNCHANGED)
144 | ind = int(get_name(img_name)[1])
145 | label_data[:, :, ind] = img
146 |
147 | save_nii(
148 | label_data,
149 | scanf.affine,
150 | scan_header,
151 | os.path.join(args.seg_dir, patient),
152 | )  # BUG: the last one or two worker processes sometimes seem to hang and never finish saving
153 |
154 |
155 | # NOTE: percent_file must stay open here; run(), executed by the Pool below, still writes to it.
156 | # It is closed after the pool finishes, at the bottom of this file.
157 |
158 |
159 | def save_nii(label_data, affine, header, dir):
160 | print("++++", dir)
161 | print(label_data.shape)
162 | label_data = np.rot90(label_data, args.rot, axes=(0, 1))
163 | label_data = np.transpose(label_data, [1, 0, 2])
164 | if args.filter:
165 | tot = label_data.sum()
166 | label_data = util.filter_largest_volume(label_data, mode="hard")
167 | largest = label_data.sum()
168 | if args.percent:
169 | print(osp.basename(dir), largest / tot, file=percent_file)
170 | percent_file.flush()
171 | newf = nib.Nifti1Image(label_data.astype(np.float64), affine, header)
172 | nib.save(newf, dir)
173 | print("--------", "finish", dir)
174 |
175 |
176 | print(patient_names)
177 | with Pool(multiprocessing.cpu_count()) as p:
178 | p.map(run, patient_names)
179 | if args.percent:
180 | percent_file.close()
179 |
--------------------------------------------------------------------------------
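run() and save_nii() boil down to stacking one PNG per slice into a 3D array and writing it with the affine and header of the original scan, so the result lines up with the scan in a viewer. Below is a stripped-down sketch of that round trip; scan.nii.gz and the scan-*.png names are placeholders, not files from this repo.

import numpy as np
import cv2
import nibabel as nib

scan_path = "scan.nii.gz"                           # placeholder: original scan, provides affine + header
png_paths = [f"scan-{i}.png" for i in range(300)]   # placeholder: one uint8 label image per slice

scanf = nib.load(scan_path)
first = cv2.imread(png_paths[0], cv2.IMREAD_UNCHANGED)
label = np.zeros((first.shape[0], first.shape[1], len(png_paths)), dtype="uint8")
for idx, p in enumerate(png_paths):
    label[:, :, idx] = cv2.imread(p, cv2.IMREAD_UNCHANGED)

# same orientation fix-ups as save_nii: optional np.rot90 plus a transpose of the first two axes
label = np.transpose(label, [1, 0, 2])
nib.save(nib.Nifti1Image(label.astype(np.float64), scanf.affine, scanf.header), "scan-seg.nii.gz")
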
/medseg/models/unet_plain.py:
--------------------------------------------------------------------------------
1 | # Author: Jingxiao Gu
2 | # Baidu Account: Seigato
3 | # Description: Unet Base Network for Lane Segmentation Competition
4 | # 80 unet
5 |
6 |
7 | import paddle.fluid as fluid
8 | from paddle.fluid.initializer import MSRA
9 | from paddle.fluid.param_attr import ParamAttr
10 |
11 |
12 | def conv_bn_layer(input, num_filters, filter_size, stride=1, groups=1, act=None, bn=True, bias_attr=False):
13 | conv = fluid.layers.conv2d(
14 | input=input,
15 | num_filters=num_filters,
16 | filter_size=filter_size,
17 | stride=stride,
18 | padding=(filter_size - 1) // 2,
19 | groups=groups,
20 | act=None,
21 | bias_attr=bias_attr,
22 | param_attr=ParamAttr(initializer=MSRA()),
23 | )
24 | if bn == True:
25 | conv = fluid.layers.batch_norm(input=conv, act=act)
26 | return conv
27 |
28 |
29 | def conv_layer(input, num_filters, filter_size, stride=1, groups=1, act=None):
30 | conv = fluid.layers.conv2d(
31 | input=input,
32 | num_filters=num_filters,
33 | filter_size=filter_size,
34 | stride=stride,
35 | padding=(filter_size - 1) // 2,
36 | groups=groups,
37 | act=act,
38 | bias_attr=ParamAttr(initializer=MSRA()),
39 | param_attr=ParamAttr(initializer=MSRA()),
40 | )
41 | return conv
42 |
43 |
44 | def shortcut(input, ch_out, stride):
45 | ch_in = input.shape[1]
46 | if ch_in != ch_out or stride != 1:
47 | return conv_bn_layer(input, ch_out, 1, stride)
48 | else:
49 | return input
50 |
51 |
52 | def bottleneck_block(input, num_filters, stride):
53 | conv_bn = conv_bn_layer(input=input, num_filters=num_filters, filter_size=1, act="relu")
54 | conv_bn = conv_bn_layer(input=conv_bn, num_filters=num_filters, filter_size=3, stride=stride, act=None)
55 | short_bn = shortcut(input, num_filters, stride)
56 | return fluid.layers.elementwise_add(x=short_bn, y=conv_bn, act="relu")
57 |
58 |
59 | def encoder_block(input, encoder_depths, encoder_filters, block):
60 | conv_bn = input
61 | for i in range(encoder_depths[block]):
62 | conv_bn = bottleneck_block(
63 | input=conv_bn, num_filters=encoder_filters[block], stride=2 if i == 0 and block != 0 else 1,
64 | )
65 | print("| Encoder Block", block, conv_bn.shape)
66 | return conv_bn
67 |
68 |
69 | def decoder_block(input, concat_input, decoder_depths, decoder_filters, block):
70 | deconv_bn = input
71 | deconv_bn = fluid.layers.resize_bilinear(
72 | input=deconv_bn, out_shape=(deconv_bn.shape[2] * 2, deconv_bn.shape[3] * 2)
73 | )
74 | deconv_bn = bottleneck_block(input=deconv_bn, num_filters=decoder_filters[block], stride=1)
75 |
76 | concat_input = conv_bn_layer(
77 | input=concat_input, num_filters=concat_input.shape[1] // 2, filter_size=1, act="relu"
78 | )
79 |
80 | deconv_bn = fluid.layers.concat([deconv_bn, concat_input], axis=1)
81 | for i in range(decoder_depths[block]):
82 | deconv_bn = bottleneck_block(input=deconv_bn, num_filters=decoder_filters[block], stride=1)
83 | print("| Decoder Block", block, deconv_bn.shape)
84 | return deconv_bn
85 |
86 |
87 | def unet_plain(img, label_number, img_size):
88 | print("| Build Custom-Designed Resnet-Unet:")
89 | encoder_depth = [3, 4, 6, 4]
90 | encoder_filters = [64, 128, 256, 512]
91 | decoder_depth = [4, 3, 3, 2]
92 | decoder_filters = [256, 128, 64, 32]
93 | print("| Input Image Data", img.shape)
94 | """
95 | Encoder
96 | """
97 | # Start Conv
98 | start_conv = conv_bn_layer(input=img, num_filters=32, filter_size=3, stride=2, act="relu")
99 | start_conv = conv_bn_layer(input=start_conv, num_filters=32, filter_size=3, stride=1, act="relu")
100 | start_pool = fluid.layers.pool2d(
101 | input=start_conv, pool_size=3, pool_stride=2, pool_padding=1, pool_type="max"
102 | )
103 | print("| Start Convolution", start_conv.shape)
104 |
105 | conv0 = encoder_block(start_pool, encoder_depth, encoder_filters, block=0)
106 | conv1 = encoder_block(conv0, encoder_depth, encoder_filters, block=1)
107 | conv2 = encoder_block(conv1, encoder_depth, encoder_filters, block=2)
108 | conv3 = encoder_block(conv2, encoder_depth, encoder_filters, block=3)
109 |
110 | """
111 | Decoder
112 | """
113 | decode_conv1 = decoder_block(conv3, conv2, decoder_depth, decoder_filters, block=0)
114 | decode_conv2 = decoder_block(decode_conv1, conv1, decoder_depth, decoder_filters, block=1)
115 | decode_conv3 = decoder_block(decode_conv2, conv0, decoder_depth, decoder_filters, block=2)
116 | decode_conv4 = decoder_block(decode_conv3, start_conv, decoder_depth, decoder_filters, block=3)
117 |
118 | """
119 | Output Coder
120 | """
121 | decode_conv5 = fluid.layers.resize_bilinear(input=decode_conv4, out_shape=img_size)
122 | decode_conv5 = bottleneck_block(input=decode_conv5, num_filters=32, stride=1)
123 | decode_conv5 = bottleneck_block(input=decode_conv5, num_filters=16, stride=1)
124 | logit = conv_layer(input=decode_conv5, num_filters=label_number, filter_size=1, act=None)
125 | print("| Output Predictions:", logit.shape)
126 | # logit = fluid.layers.resize_bilinear(input=logit, out_shape=(3384, 1020))
127 | print("| Final Predictions:", logit.shape)
128 |
129 | return logit
130 |
--------------------------------------------------------------------------------
/medseg/models/res_unet.py:
--------------------------------------------------------------------------------
1 | # Author: Jingxiao Gu
2 | # Baidu Account: Seigato
3 | # Description: Unet Simple Network for Lane Segmentation Competition
4 |
5 | import paddle.fluid as fluid
6 | from paddle.fluid.initializer import MSRA
7 | from paddle.fluid.param_attr import ParamAttr
8 |
9 | # Experiment: try replacing every relu with leaky_relu
10 |
11 |
12 | def conv_bn_layer(input, num_filters, filter_size, stride=1, groups=1, act=None, bn=True, bias_attr=False):
13 | conv = fluid.layers.conv2d(
14 | input=input,
15 | num_filters=num_filters,
16 | filter_size=filter_size,
17 | stride=stride,
18 | padding=(filter_size - 1) // 2, # same
19 | groups=groups,
20 | act=act,
21 | bias_attr=bias_attr,
22 | param_attr=ParamAttr(initializer=MSRA()),
23 | )
24 | if bn == True:
25 | conv = fluid.layers.batch_norm(input=conv, act=act)
26 | return conv
27 |
28 |
29 | def conv_layer(input, num_filters, filter_size, stride=1, groups=1, act=None):
30 | conv = fluid.layers.conv2d(
31 | input=input,
32 | num_filters=num_filters,
33 | filter_size=filter_size,
34 | stride=stride,
35 | padding=(filter_size - 1) // 2,
36 | groups=groups,
37 | act=act,
38 | bias_attr=ParamAttr(initializer=MSRA()),
39 | param_attr=ParamAttr(initializer=MSRA()),
40 | )
41 | return conv
42 |
43 |
44 | def shortcut(input, ch_out, stride):
45 | ch_in = input.shape[1]
46 | if ch_in != ch_out or stride != 1:
47 | return conv_bn_layer(input, ch_out, 1, stride)
48 | else:
49 | return input
50 |
51 |
52 | def bottleneck_block(input, num_filters, stride):
53 | conv_bn = conv_bn_layer(input=input, num_filters=num_filters, filter_size=3, act="leaky_relu")
54 | conv_bn = conv_bn_layer(input=conv_bn, num_filters=num_filters, filter_size=3, stride=stride, act=None)
55 | short_bn = shortcut(input, num_filters, stride)
56 | return fluid.layers.elementwise_add(x=short_bn, y=conv_bn, act="leaky_relu")
57 |
58 |
59 | def encoder_block(input, encoder_depths, encoder_filters, block):
60 | conv_bn = input
61 | for i in range(encoder_depths[block]):
62 | conv_bn = bottleneck_block(
63 | input=conv_bn, num_filters=encoder_filters[block], stride=2 if i == 0 and block != 0 else 1,
64 | )
65 | print("| Encoder Block", block, conv_bn.shape)
66 | return conv_bn
67 |
68 |
69 | def decoder_block(input, concat_input, decoder_depths, decoder_filters, block):
70 | deconv_bn = input
71 | deconv_bn = fluid.layers.resize_bilinear(
72 | input=deconv_bn, out_shape=(deconv_bn.shape[2] * 2, deconv_bn.shape[3] * 2)
73 | )
74 | deconv_bn = bottleneck_block(input=deconv_bn, num_filters=decoder_filters[block], stride=1)
75 |
76 | concat_input = conv_bn_layer(
77 | input=concat_input, num_filters=concat_input.shape[1] // 2, filter_size=1, act="leaky_relu"
78 | )
79 |
80 | deconv_bn = fluid.layers.concat([deconv_bn, concat_input], axis=1)
81 | for i in range(decoder_depths[block]):
82 | deconv_bn = bottleneck_block(input=deconv_bn, num_filters=decoder_filters[block], stride=1)
83 | print("| Decoder Block", block, deconv_bn.shape)
84 | return deconv_bn
85 |
86 |
87 | def res_unet(img, label_number, img_size):
88 | print("| Build Custom-Designed Resnet-Unet:")
89 | encoder_depth = [3, 4, 5, 3]
90 | encoder_filters = [64, 128, 256, 512]
91 | decoder_depth = [2, 3, 3, 2]
92 | decoder_filters = [256, 128, 64, 32]
93 | print("| Input Image Data", img.shape)
94 | """
95 | Encoder
96 | """
97 | # Start Conv
98 | start_conv = conv_bn_layer(input=img, num_filters=32, filter_size=3, stride=2, act="leaky_relu")
99 | start_conv = conv_bn_layer(input=start_conv, num_filters=32, filter_size=3, stride=1, act="leaky_relu")
100 | start_pool = fluid.layers.pool2d(
101 | input=start_conv, pool_size=3, pool_stride=2, pool_padding=1, pool_type="max"
102 | )
103 | print("| Start Convolution", start_conv.shape)
104 |
105 | conv0 = encoder_block(start_pool, encoder_depth, encoder_filters, block=0)
106 | conv1 = encoder_block(conv0, encoder_depth, encoder_filters, block=1)
107 | conv2 = encoder_block(conv1, encoder_depth, encoder_filters, block=2)
108 | conv3 = encoder_block(conv2, encoder_depth, encoder_filters, block=3)
109 |
110 | """
111 | Decoder
112 | """
113 | decode_conv1 = decoder_block(conv3, conv2, decoder_depth, decoder_filters, block=0)
114 | decode_conv2 = decoder_block(decode_conv1, conv1, decoder_depth, decoder_filters, block=1)
115 | decode_conv3 = decoder_block(decode_conv2, conv0, decoder_depth, decoder_filters, block=2)
116 | decode_conv4 = decoder_block(decode_conv3, start_conv, decoder_depth, decoder_filters, block=3)
117 |
118 | """
119 | Output Coder
120 | """
121 | print("+_+_+")
122 | print(decode_conv4.shape)
123 | print(img_size)
124 | print("+_+_+")
125 | decode_conv5 = fluid.layers.resize_bilinear(input=decode_conv4, out_shape=img_size)
126 | decode_conv5 = bottleneck_block(input=decode_conv5, num_filters=32, stride=1)
127 | decode_conv5 = bottleneck_block(input=decode_conv5, num_filters=16, stride=1)
128 | logit = conv_layer(input=decode_conv5, num_filters=label_number, filter_size=1, act=None)
129 | print("| Output Predictions:", logit.shape)
130 | # logit = fluid.layers.resize_bilinear(input=logit, out_shape=(3384, 1020))
131 | print("| Final Predictions:", logit.shape)
132 |
133 | return logit
134 |
--------------------------------------------------------------------------------
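Like the __main__ block in unet.py, res_unet can be shape-checked by pushing a static-graph placeholder through it. A minimal sketch, assuming Paddle 1.x with the fluid static-graph API used throughout this repo and that medseg/models is on the import path:

import paddle.fluid as fluid
from res_unet import res_unet

image = fluid.layers.data(name="image", shape=[3, 512, 512], dtype="float32")
logit = res_unet(image, label_number=2, img_size=(512, 512))
print("logit:", logit.shape)  # expected: [-1, 2, 512, 512]
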
/tool/train/dataset_scan.py:
--------------------------------------------------------------------------------
1 | """
2 | Scan a dataset and generate a summary
3 | """
4 | # TODO: add window width/level (wwwc) handling
5 | # TODO: check whether scan and label sizes match
6 |
7 | import argparse
8 | import os
9 | import os.path as osp
10 |
11 | import nibabel as nib
12 | import numpy as np
13 | from matplotlib import pyplot as plt
14 | from matplotlib.font_manager import FontProperties
15 | from tqdm import tqdm
16 |
17 | import util
18 |
19 | parser = argparse.ArgumentParser()
20 | parser.add_argument(
21 | "-s",
22 | "--scan_dir",
23 | type=str,
24 | required=True,
25 | help="扫描路径",
26 | )
27 | parser.add_argument("-l", "--label_dir", type=str, help="标签路径")
28 | parser.add_argument("-p", "--plt_dir", type=str, help="强度分布输出路径,不写不进行绘制", default=None)
29 | parser.add_argument(
30 | "--wwwc",
31 | nargs=2,
32 | default=None,
33 | help="窗宽窗位,不写不进行窗宽窗位处理",
34 | )
35 | parser.add_argument(
36 | "--skip",
37 | nargs=2,
38 | default=["", ""],
39 | help="在检查文件配对的时候,扫描和标签文件开头略过两个字符串,用于匹配 scan-1.nii 和 label-1.nii 这种情况",
40 | )
41 | args = parser.parse_args()
42 |
43 | font = FontProperties(fname="../SimHei.ttf", size=16)
44 |
45 | util.check_nii_match(args.scan_dir, args.label_dir, args.skip)
46 |
47 | scans = util.listdir(args.scan_dir)
48 | labels = util.listdir(args.label_dir)
49 | assert len(scans) == len(labels), "扫描和标签数量不相等"
50 |
51 | print(f"数据集中共{len(scans)}组扫描和标签,对应情况:")
52 | for idx in range(len(scans)):
53 | print(f"{scans[idx]} \t {labels[idx]}")
54 |
55 | if not osp.exists(osp.join(args.plt_dir, "scan")):
56 | os.makedirs(osp.join(args.plt_dir, "scan"))
57 | if not osp.exists(osp.join(args.plt_dir, "label")):
58 | os.makedirs(osp.join(args.plt_dir, "label"))
59 |
60 | pixdims = []
61 | shapes = []
62 | norms = []
63 |
64 | pbar = tqdm(range(len(scans)), desc="正在统计")
65 | for idx in range(len(scans)):
66 | pbar.set_postfix(filename=scans[idx].split(".")[0])
67 | pbar.update(1)
68 |
69 | scanf = nib.load(os.path.join(args.scan_dir, scans[idx]))
70 | labelf = nib.load(os.path.join(args.label_dir, labels[idx]))
71 |
72 | header = scanf.header.structarr
73 | shape = scanf.header.get_data_shape()
74 | shapes.append([shape[0], shape[1], shape[2]])
75 | pixdims.append(header["pixdim"][1:4])
76 |
77 | scan = scanf.get_fdata()
78 | norms.append([scan.min(), np.median(scan), scan.max()])
79 |
80 | if args.plt_dir:
81 | scan_plt = osp.join(args.plt_dir, "scan")
82 | label_plt = osp.join(args.plt_dir, "label")
83 |
84 | scan = scan.reshape([scan.size])
85 |
86 | plt.title(scans[idx].split(".")[0], fontproperties=font)
87 | plt.xlabel(
88 | "size:[{},{},{}] pixdims:[{},{},{}] ".format(
89 | shape[0],
90 | shape[1],
91 | shape[2],
92 | header["pixdim"][1],
93 | header["pixdim"][2],
94 | header["pixdim"][3],
95 | )
96 | )
97 | nums, bins, patchs = plt.hist(scan, bins=1000)
98 | plt.savefig(osp.join(scan_plt, scans[idx].split(".")[0] + ".png"))
99 | plt.close()
100 |
101 | file = open(osp.join(scan_plt, f"{scans[idx].split('.')[0]}.txt"), "w")
102 | print("--------- {} --------".format(scans[idx]), file=file)
103 |
104 | sum = 0
105 | for num in nums:
106 | sum += num
107 | nowsum = 0
108 | for i in range(0, len(nums)):
109 | nowsum += nums[i]
110 | print(
111 | "[{:<10f},{:<10f}] : {:>10} percentage : {}".format(
112 | bins[i], bins[i + 1], nums[i], nowsum / sum
113 | ),
114 | file=file,
115 | )
116 | file.close()
117 |
118 | label = labelf.get_fdata()
119 | label = np.reshape(label, [label.size])
120 | plt.title(
121 | f"{scans[idx].split('.')[0]} [{np.min(label)},{np.max(label)}]",
122 | fontproperties=font,
123 | )
124 | plt.xlabel(
125 | "size:[{},{},{}] pixdims:[{},{},{}] ".format(
126 | shape[0],
127 | shape[1],
128 | shape[2],
129 | header["pixdim"][1],
130 | header["pixdim"][2],
131 | header["pixdim"][3],
132 | )
133 | )
134 | nums, bins, patchs = plt.hist(label, bins=5)
135 | plt.savefig(osp.join(label_plt, scans[idx].split(".")[0] + ".png"))
136 | plt.close()
137 |
138 | file = open(
139 | os.path.join(label_plt, "{}.txt".format(scans[idx].split(".")[0])),
140 | "w",
141 | )
142 | print("--------- {} --------".format(scans[idx]), file=file)
143 |
144 | sum = 0
145 | for num in nums:
146 | sum += num
147 | nowsum = 0
148 | for i in range(0, len(nums)):
149 | nowsum += nums[i]
150 | print(
151 | "[{:<10f},{:<10f}] : {:>10} percentage : {}".format(
152 | bins[i], bins[i + 1], nums[i], nowsum / sum
153 | ),
154 | file=file,
155 | )
156 | file.close()
157 | pbar.close()
158 |
159 | spacing = np.median(pixdims, axis=0)
160 | size = np.median(shapes, axis=0)
161 |
162 | print(norms)
163 | norm = []
164 | norms = np.array(norms)
165 | norm.append(np.min(norms[:, 0]))
166 | norm.append(np.median(norms[:, 1]))
167 | norm.append(np.max(norms[:, 2]))
168 |
169 | print(spacing, size, norm)
170 |
171 | file = open("./summary.txt", "w")
172 |
173 | print("spacing", file=file)
174 | for dat in spacing:
175 | print(dat, file=file)
176 |
177 | print("\nsize", file=file)
178 | for dat in size:
179 | print(dat, file=file)
180 |
181 | print("\nnorm", file=file)
182 | for dat in norm:
183 | print(dat, file=file)
184 | file.close()
185 |
--------------------------------------------------------------------------------
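The summary written at the end is just per-axis medians over all headers (voxel spacing and shape) plus a min/median/max of the intensities; those numbers are typically what ends up in the PREP section of a config. A standalone illustration of that aggregation with invented header values:

import numpy as np

# invented per-scan values, in the same form as collected in the loop above
pixdims = [(0.74, 0.74, 1.0), (0.82, 0.82, 2.5), (0.70, 0.70, 1.5)]
shapes = [(512, 512, 120), (512, 512, 75), (512, 512, 301)]
norms = np.array([(-1024, 45, 2200), (-1000, 60, 1800), (-1024, 30, 3000)])  # (min, median, max)

spacing = np.median(pixdims, axis=0)   # -> [0.74 0.74 1.5 ]
size = np.median(shapes, axis=0)       # -> [512. 512. 120.]
norm = [norms[:, 0].min(), np.median(norms[:, 1]), norms[:, 2].max()]
print(spacing, size, norm)
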
/medseg/infer.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import argparse
3 | import time
4 | import os
5 | from multiprocessing import Process, Queue
6 |
7 | from tqdm import tqdm
8 | import numpy as np
9 | import cv2
10 | import nibabel as nib
11 | import paddle
12 | from paddle import fluid
13 |
14 | import utils.util as util
15 | from utils.config import cfg
16 | import scipy.ndimage
17 |
18 | def parse_args():
19 | parser = argparse.ArgumentParser(description="预测")
20 | parser.add_argument("-c", "--cfg_file", type=str, help="配置文件路径")
21 | parser.add_argument("--use_gpu", action="store_true", default=False, help="使用GPU推理")
22 | parser.add_argument("opts", nargs=argparse.REMAINDER)
23 | args = parser.parse_args()
24 |
25 | if args.cfg_file is not None:
26 | cfg.update_from_file(args.cfg_file)
27 | if args.opts:
28 | cfg.update_from_list(args.opts)
29 | if args.use_gpu:  # the CLI flag can only switch this from False to True, it cannot force False
30 | cfg.TRAIN.USE_GPU = True
31 |
32 | cfg.set_immutable(True)
33 |
34 |
35 | def read_data(file_path, q):
36 | """读取数据并进行预处理.
37 |
38 | Parameters
39 | ----------
40 | file_path : str
41 | 需要读取的数据文件名.
42 | q : queue
43 | 读入的数据放入这个队列.
44 |
45 | """
46 | print("Start Reading: ", file_path)
47 | volf = nib.load(file_path)
48 | volume = np.array(volf.get_fdata())
49 |
50 | if cfg.INFER.WINDOWLIZE:
51 | volume = util.windowlize_image(volume, cfg.INFER.WWWC)
52 | if cfg.INFER.DO_INTERP:
53 | header = volf.header.structarr
54 | # pixdim is the voxel spacing of the three CT axes
55 | pixdim = [header["pixdim"][ind] for ind in range(1, 4)]
56 | spacing = list(cfg.INFER.SPACING)
57 | for ind in range(3):
58 | if spacing[ind] == -1:
59 | spacing[ind] = pixdim[ind]
60 | ratio = [pixdim[0] / spacing[0], pixdim[1] / spacing[1], pixdim[2] / spacing[2]]
61 | volume = scipy.ndimage.interpolation.zoom(volume, ratio, order=3)
62 | q.put([volume, volf.affine, ratio if cfg.INFER.DO_INTERP else None])
63 | print("Finish Reading: ", file_path)
64 |
65 |
66 | def post_process(fpath, inference, affine, ratio=None):
67 | """Post-process the prediction and save it, so the GPU does not sit idle while waiting.
68 |
69 | Parameters
70 | ----------
71 | fpath : str
72 | Path to save the result to.
73 | inference : ndarray
74 | Predicted segmentation array; ratio, if given, is the preprocessing zoom to undo.
75 | """
76 | print("Start Postprocess: ", fpath)
77 | inference[inference >= cfg.INFER.THRESH] = 1
78 | inference[inference < cfg.INFER.THRESH] = 0
79 | if cfg.INFER.FILTER_LARGES:
80 | inference = util.filter_largest_volume(inference, 1, "soft")
81 | if cfg.INFER.DO_INTERP and ratio is not None:
82 | ratio = [1 / x for x in ratio]  # invert the preprocessing zoom
83 | inference = scipy.ndimage.interpolation.zoom(inference, ratio, order=3)
84 |
85 | inference = inference.astype("int8")
86 | inference_file = nib.Nifti1Image(inference, affine)
87 | inference_path = fpath
88 | nib.save(inference_file, inference_path)
89 | print("Finish Postprocess: ", fpath)
90 |
91 |
92 | def main():
93 | places = fluid.CUDAPlace(0) if cfg.TRAIN.USE_GPU else fluid.CPUPlace()
94 | exe = fluid.Executor(places)
95 |
96 | infer_exe = fluid.Executor(places)
97 | inference_scope = fluid.core.Scope()
98 | vol_queue = Queue()
99 |
100 | if not os.path.exists(cfg.INFER.PATH.OUTPUT):
101 | os.makedirs(cfg.INFER.PATH.OUTPUT)
102 |
103 | with fluid.scope_guard(inference_scope):
104 | [inference_program, feed_target_names, fetch_targets] = fluid.io.load_inference_model(
105 | cfg.INFER.PATH.PARAM, infer_exe
106 | )
107 |
108 | inf_volumes = os.listdir(cfg.INFER.PATH.INPUT)
109 | for vol_ind in tqdm(range(len(inf_volumes)), position=0):
110 | inf_volume = inf_volumes[vol_ind]
111 | if vol_queue.empty():
112 | read_data(os.path.join(cfg.INFER.PATH.INPUT, inf_volumes[vol_ind]), vol_queue)
113 |
114 | volume, affine, ratio = vol_queue.get()
115 | print(volume.shape)
116 |
117 | # kick off reading the next volume asynchronously; this must come after get(), since the queue is locked
118 | if vol_ind != len(inf_volumes) - 1:
119 | p = Process(
120 | target=read_data,
121 | args=(os.path.join(cfg.INFER.PATH.INPUT, inf_volumes[vol_ind + 1]), vol_queue),
122 | )
123 | p.start()
124 | # read_data(inf_path, vol_queue)
125 |
126 | inference = np.zeros(volume.shape)
127 |
128 | batch_size = cfg.INFER.BATCH_SIZE
129 | ind = 0
130 | flag = True
131 | pbar = tqdm(total=volume.shape[2] - 2, position=1, leave=False)
132 | pbar.set_postfix(file=inf_volume)
133 | while flag:
134 | batch_data = []
135 | for j in range(0, batch_size):
136 | ind = ind + 1
137 | pbar.update(1)
138 | data = volume[:, :, ind - 1 : ind + 2]
139 | data = data.swapaxes(0, 2).reshape([3, data.shape[1], data.shape[0]]).astype("float32")
140 | batch_data.append(data)
141 |
142 | if ind == volume.shape[2] - 2:
143 | flag = False
144 | pbar.refresh()
145 | break
146 | batch_data = np.array(batch_data)
147 |
148 | result = infer_exe.run(
149 | inference_program, feed={feed_target_names[0]: batch_data}, fetch_list=fetch_targets,
150 | )
151 |
152 | result = np.array(result)
153 | result = result.reshape([-1, 2, 512, 512])
154 |
155 | ii = ind
156 | for j in range(result.shape[0] - 1, -1, -1):
157 | resp = result[j, 1, :, :].reshape([512, 512])
158 | resp = resp.swapaxes(0, 1)
159 | inference[:, :, ii] = resp
160 | ii = ii - 1
161 |
162 | # post-process and save in a separate process
163 | p = Process(
164 | target=post_process,
165 | args=(
166 | os.path.join(cfg.INFER.PATH.OUTPUT, inf_volume).replace("volume", "segmentation"),
167 | inference,
168 | affine, ratio,
169 | ),
170 | )
171 | p.start()
172 | # post_process(
173 | # os.path.join(cfg.INFER.PATH.OUTPUT, inf_volume).replace("volume", "segmentation"),
174 | # inference,
175 | # affine,
176 | # )
177 | pbar.close()
178 |
179 |
180 | if __name__ == "__main__":
181 | parse_args()
182 | main()
183 |
--------------------------------------------------------------------------------
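The inner loop of main() builds 2.5D inputs: each sample is three adjacent axial slices stacked as channels and transposed from the (W, H, D) storage order to (C, H, W) before being batched. A standalone sketch of that batch assembly on a random volume; the shapes and batch size are illustrative only.

import numpy as np

volume = np.random.rand(512, 512, 40).astype("float32")  # (W, H, D), as loaded from nibabel
batch_size = 8

samples = []
for ind in range(1, volume.shape[2] - 1):          # every slice that has a neighbour on both sides
    data = volume[:, :, ind - 1 : ind + 2]         # (W, H, 3): the slice plus its two neighbours
    samples.append(np.transpose(data, (2, 1, 0)))  # -> (3, H, W), mirroring the swapaxes/reshape above

batches = [np.array(samples[i : i + batch_size]) for i in range(0, len(samples), batch_size)]
print(len(batches), batches[0].shape)              # 5 batches here; the first one is (8, 3, 512, 512)
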
/medseg/vis.py:
--------------------------------------------------------------------------------
1 | """
2 | Visualize in-memory ndarrays, npz files, and nii files
3 | """
4 | import sys
5 | import os
6 | import argparse
7 | import time
8 |
9 | import SimpleITK as sitk
10 | import nibabel as nib
11 | import numpy as np
12 | import matplotlib.pyplot as plt
13 |
14 | from utils.config import cfg
15 | import utils.util as util
16 | import train
17 |
18 |
19 | def parse_args():
20 | parser = argparse.ArgumentParser(description="数据预处理")
21 | parser.add_argument("-c", "--cfg_file", type=str, help="配置文件路径")
22 | parser.add_argument("opts", nargs=argparse.REMAINDER)
23 | args = parser.parse_args()
24 |
25 | if args.cfg_file is not None:
26 | cfg.update_from_file(args.cfg_file)
27 | if args.opts:
28 | cfg.update_from_list(args.opts)
29 |
30 |
31 | def show_slice(vol, lab):
32 | """展示一个2.5D的数据对.
33 |
34 | Parameters
35 | ----------
36 | vol : ndarray
37 | 2.5D的扫描slice.
38 | lab : ndarray
39 | 1片分割标签.
40 | """
41 | if vol.shape[0] <= 3:  # CWH, convert to WHC
42 | vol = vol.swapaxes(0, 2)
43 | lab = lab.swapaxes(0, 2)
44 | if lab.ndim == 2:
45 | lab = lab[:, :, np.newaxis]
46 | if len(vol.shape) == 3 and vol.shape[2] == 3:  # if the input has 3 channels, take the middle one
47 | vol = vol[:, :, 1]
48 | if len(vol.shape) == 2:
49 | vol = vol[:, :, np.newaxis]
50 |
51 | vol = np.tile(vol, (1, 1, 3))
52 | lab = np.tile(lab, (1, 1, 3))
53 | print("vis shape", vol.shape, lab.shape)
54 | vmax = vol.max()
55 | vmin = vol.min()
56 | vol = (vol - vmin) / (vmax - vmin) * 255
57 | lab = lab * 255
58 |
59 | vol = vol.astype("uint8")
60 | lab = lab.astype("uint8")
61 |
62 | plt.figure(figsize=(15, 15))
63 | plt.subplot(121)
64 | plt.imshow(vol)
65 | plt.subplot(122)
66 | plt.imshow(lab)
67 | plt.show()
68 | plt.close()
69 |
70 |
71 | def show_nii():
72 | scans = util.listdir(cfg.DATA.INPUTS_PATH)
73 | labels = util.listdir(cfg.DATA.LABELS_PATH)
74 | records = []
75 | for ind in range(len(scans)):
76 | # for ind in range(3):
77 | print(scans[ind], labels[ind])
78 |
79 | scanf = nib.load(os.path.join(cfg.DATA.INPUTS_PATH, scans[ind]))
80 | labelf = nib.load(os.path.join(cfg.DATA.LABELS_PATH, labels[ind]))
81 | scan = scanf.get_fdata()
82 | scan = util.windowlize_image(scan, cfg.PREP.WWWC)
83 | label = labelf.get_fdata()
84 | scan, label = util.cal_direction(scans[ind], scan, label)
85 | print(scan.shape)
86 | print(label.shape)
87 | sli_ind = int(scan.shape[2] / 6)
88 | # for sli_ind in range(vol.shape[2]):
89 | show_slice(scan[:, :, sli_ind * 2], label[:, :, sli_ind * 2])
90 | show_slice(scan[:, :, sli_ind * 3], label[:, :, sli_ind * 3])
91 | show_slice(scan[:, :, sli_ind * 4], label[:, :, sli_ind * 4])
92 | t = input("是否左右翻转: ")
93 | records.append([scans[ind], t])
94 | time.sleep(1)
95 | print(records)
96 |
97 | f = open("./flip.csv", "w")
98 | for record in records:
99 | print(record[0] + "," + record[1], file=f)
100 | f.close()
101 | # 1: rotate once, 2: rotate three times
102 | # 0: no left-right flip, 1: left-right flip
103 |
104 |
105 | def show_npz():
106 | """对训练数据npz进行可视化.
107 |
108 | """
109 | for npz in os.listdir(cfg.TRAIN.DATA_PATH):
110 | data = np.load(os.path.join(cfg.TRAIN.DATA_PATH, npz))
111 | vol = data["imgs"]
112 | lab = data["labs"]
113 | # for ind in range(vol.shape[0]):
114 | # show_slice(vol[ind], lab[ind])
115 | show_slice(vol[0], lab[0])
116 | show_slice(vol[vol.shape[0] - 1], lab[vol.shape[0] - 1])
117 |
118 |
119 | def show_aug():
120 | """在读取npz基础上做aug之后展示.
121 |
122 | """
123 | for npz in os.listdir(cfg.TRAIN.DATA_PATH):
124 | data = np.load(os.path.join(cfg.TRAIN.DATA_PATH, npz))
125 | vol = data["imgs"]
126 | lab = data["labs"]
127 | vol = vol.astype("float32")
128 | lab = lab.astype("int32")
129 | if cfg.AUG.WINDOWLIZE:
130 | vol = util.windowlize_image(vol, cfg.AUG.WWWC)  # commonly used window for liver
131 | # for ind in range(vol.shape[0]):
132 | vol_slice, lab_slice = train.aug_mapper([vol[0], lab[0]])
133 | show_slice(vol_slice, lab_slice)
134 |
135 | vol_slice, lab_slice = train.aug_mapper([vol[vol.shape[0] - 1], lab[vol.shape[0] - 1]])
136 | show_slice(vol_slice, lab_slice)
137 |
138 |
139 | if __name__ == "__main__":
140 | parse_args()
141 | show_nii()
142 | # show_npz()
143 | # show_aug()
144 |
145 |
146 | # import os
147 | # import matplotlib.pyplot as plt
148 | # from nibabel.orientations import aff2axcodes
149 |
150 | # vols = "/home/aistudio/data/volume"
151 | # for voln in os.listdir(vols):
152 | # print("--------")
153 | # print(voln)
154 | #
155 | # volf = sitk.ReadImage(os.path.join(vols, voln))
156 | #
157 | # vold = sitk.GetArrayFromImage(volf)
158 | # print(vold.shape)
159 | # vold[500:512, 250:260, 0] = 2048
160 | #
161 | # plt.imshow(vold[0, :, :])
162 | # plt.show()
163 |
164 | # vols = "/home/aistudio/data/volume"
165 | # directions = []
166 | # for voln in os.listdir(vols):
167 | # print("--------")
168 | # print(voln)
169 | #
170 | # volf = nib.load(os.path.join(vols, voln))
171 | # print(volf.affine)
172 | # print("codes", aff2axcodes(volf.affine))
173 | #
174 | # vold = volf.get_fdata()
175 | # vold[500:512, 250:260, 0] = 2048
176 | #
177 | # plt.imshow(vold[:, :, 0])
178 | # plt.show()
179 | #
180 | # cmd = input("direction: ")
181 | # if cmd == "a":
182 | # dir = [voln, 1] # 床在左边
183 | # else:
184 | # dir = [voln, 2] # 床在右边
185 | # directions.append(dir)
186 | #
187 | #
188 | # f = open("./directions.csv", "w")
189 | # for dir in directions:
190 | # print(dir[0], ",", dir[1], file=f)
191 | # f.close()
192 |
193 |
194 | # vols = "/home/aistudio/data/volume"
195 | #
196 | #
197 | # # 获取体位信息
198 | # f = open("./directions.csv")
199 | # dirs = f.readlines()
200 | # print("dirs: ", dirs)
201 | # dirs = [x.rstrip("\n") for x in dirs]
202 | # dirs = [x.split(",") for x in dirs]
203 | # dic = {}
204 | # for dir in dirs:
205 | # dic[dir[0].strip()] = dir[1].strip()
206 | # f.close()
207 | #
208 | # print(dic)
209 | #
210 | #
211 | # for voln in os.listdir(vols):
212 | # print("--------")
213 | # print(voln)
214 | #
215 | # volf = nib.load(os.path.join(vols, voln))
216 | # vold = volf.get_fdata()
217 | # print(dic[voln])
218 | # if dic[voln] == "2":
219 | # vold = np.rot90(vold, 3)
220 | # else:
221 | # vold = np.rot90(vold, 1)
222 | #
223 | # plt.imshow(vold[:, :, 50])
224 | # plt.show()
225 |
--------------------------------------------------------------------------------
/tool/zip_dataset.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | """
3 | Pack all files and folders under a directory into zip archives that keep the original structure; each archive stays under a given size
4 | """
5 |
6 | # TODO: test the delete-after-compression feature
7 | # TODO: show progress by the size of the zip file instead
8 |
9 | import zipfile
10 | import os
11 | from tqdm import tqdm
12 | import platform
13 | import argparse
14 | import logging
15 |
16 |
17 | def get_args():
18 | parser = argparse.ArgumentParser("zip_dataset")
19 | parser.add_argument(
20 | "-i", "--dataset_dir", type=str, required=True, help="[必填] 需要压缩的数据集路径,所有文件所在的文件夹"
21 | )
22 | parser.add_argument(
23 | "-o",
24 | "--zip_dir",
25 | type=str,
26 | required=True,
27 | help="[必填] 压缩后的压缩包保存路径,如果有条件可以和待压缩数据放到不同硬件上,能加快一点速度",
28 | )
29 | parser.add_argument("--size", type=float, default=10.0, help="[可选] 压缩文件过程中每个包不超过这个大小,G为单位")
30 | parser.add_argument(
31 | "-m", "--method", type=str, default="zip", help="[可选] 压缩方法,可选的有:store(只打包不压缩),zip,bz2,lzma",
32 | )
33 | parser.add_argument(
34 | "-v", "--verbos", action="store_true", default=False, help="[可选] 执行过程中显示详细信息"
35 | )
36 | parser.add_argument("--debug", action="store_true", default=False, help="[可选] 执行过程中显示详细信息")
37 | parser.add_argument(
38 | "-d",
39 | "--delete",
40 | action="store_true",
41 | default=False,
42 | help="[慎用] 在压缩完文件后删除对应的原文件,如果盘上空间不够压完就删掉原文件不会炸盘。!!慎用此功能!!",
43 | )
44 |
45 | return parser.parse_args()
46 |
47 |
48 | def do_zip(args):
49 | # 1. Validate and set up the arguments
50 | if args.verbose:
51 | level = logging.INFO
52 | elif args.debug:
53 | level = logging.DEBUG
54 | else:
55 | level = logging.CRITICAL
56 | logging.basicConfig(format="%(asctime)s : %(message)s", level=level)
57 |
58 | # TODO: use a different file suffix for each compression method
59 | methods = {
60 | "zip": zipfile.ZIP_DEFLATED,
61 | "store": zipfile.ZIP_STORED,
62 | "bz2": zipfile.ZIP_BZIP2,
63 | "lzma": zipfile.ZIP_LZMA,
64 | }
65 | try:
66 | mode = methods[args.method]
67 | except KeyError:
68 | raise RuntimeError("mode 参数 {} 不合法".format(mode))
69 |
70 | # if args.method == "zip":
71 | # mode = zipfile.ZIP_DEFLATED
72 | # elif args.method == "store":
73 | # mode = zipfile.ZIP_STORED
74 | # elif args.method == "bz2":
75 | # mode = zipfile.ZIP_BZIP2
76 | # elif argms.mode == "lzma":
77 | # mode = zipfile.ZIP_LZMA
78 | # else:
79 | # raise RuntimeError("mode参数 {} 不合法".format(mode))
80 |
81 | # 2. Confirm the paths and the dataset name
82 | print("\n", os.listdir(args.dataset_dir)[:10], "\n")
83 | print("以上是您指定的待压缩路径 {} 下的前10个文件(夹),请确定该路径是否正确".format(args.dataset_dir))
84 | cmd = input("如果 是 请输入 y/Y ,按其他任意键退出执行: ")
85 | if cmd != "y" and cmd != "Y":
86 | logging.error("用户退出执行")
87 | exit(0)
88 |
89 | if not os.path.exists(args.zip_dir):  # create the zip output folder if it does not exist
90 | logging.info("创建zip文件夹: {}".format(args.zip_dir))
91 | os.makedirs(args.zip_dir)
92 |
93 | dataset_name = os.path.basename(args.dataset_dir.rstrip("\\").rstrip("/"))  # the folder name is used as the dataset name
94 | logging.info("默认数据集名称为: {}".format(dataset_name))
95 | print("默认数据集名称为: {}".format(dataset_name))
96 | cmd = input("确认使用该名称请输入 y/Y,如想使用其他名称请输入: ")
97 | if cmd != "y" and cmd != "Y":
98 | dataset_name = cmd
99 | logging.info("最终使用数据集名称为: {}".format(dataset_name))
100 |
101 | # 3. Build the name of the current archive and create the archive file
102 | zip_num = 1
103 | curr_name = "{}-{}.zip".format(dataset_name, zip_num)
104 | curr_zip_path = os.path.join(args.zip_dir, curr_name)
105 | f = zipfile.ZipFile(curr_zip_path, "a", mode)
106 |
107 | files_list = []  # (source path, path inside the archive) pairs; files are written in batches to avoid checking the archive size too often
108 | list_size = 0  # total size of the files currently in files_list, in bytes
109 | zip_tot_size = args.size * 1000 ** 3  # size limit per archive in bytes; 1 GB is taken as 1000**3 B to stay safe on devices that count in powers of 1000
110 | zip_left_size = zip_tot_size  # remaining capacity of the current archive
111 |
112 | """
113 | The overall strategy:
114 | 1. Append file paths to files_list until adding one more file would exceed the space left in the current archive (zip_left_size).
115 | Because files_list tracks uncompressed sizes, writing all of these files can never push the archive over the limit.
116 | 2. Write everything in files_list into the archive, check the archive size, and decide whether to open a new archive; then keep building files_list.
117 | 3. After several rounds of step 2 the archive approaches the limit. A new archive is opened when the current archive size plus the size of the next file would exceed the limit.
118 | This can waste a little space, since the true compressed size is unknown before the file is written, so a new archive is opened conservatively; with a high compression ratio this may open new archives while they are still small.
119 | """
120 | for dirpath, dirnames, filenames in os.walk(os.path.join(args.dataset_dir)):
121 | for filename in filenames:
122 | # get this file's size and check whether adding it to the list would exceed the limit
123 | curr_file_size = os.path.getsize(os.path.join(dirpath, filename))
124 | logging.debug("Name: {}, Size: {}".format(filename, curr_file_size / 1024 ** 2))
125 | list_size += curr_file_size
126 |
127 | if list_size >= zip_left_size:  # the listed (uncompressed) files no longer fit in the current archive, so write them now
128 | logging.info("当前列表中文件大小是: {} M ".format(list_size / 1024 ** 2))
129 | logging.info("当前压缩包剩余大小: {} M".format(zip_left_size / 1024 ** 2))
130 |
131 | logging.critical("正在将 {} 个文件写入压缩包".format(len(files_list)))
132 | logging.debug("前三个文件是: {}".format(str(files_list[:3])))
133 | logging.debug("最后三个文件是: {}".format(str(files_list[-3:])))
134 | # write every file in the list into the zip
135 | for pair in tqdm(files_list, ascii=True):
136 | f.write(pair[0], pair[1])
137 | if args.delete:
138 | os.remove(pair[0])
139 |
140 | files_list = []  # writing done, clear the list
141 | # the file handled in this iteration is appended to the list after this if-block, so the new list size starts from this file's size
142 | list_size = curr_file_size
143 |
144 | curr_zip_size = os.path.getsize(curr_zip_path)
145 | logging.info("当前压缩包的大小是: {} M\n".format(curr_zip_size / 1024 ** 2))
146 |
147 | if curr_zip_size + curr_file_size > zip_tot_size:  # adding this file would push the archive over the limit, so open a new one
148 | f.close()
149 | zip_num += 1
150 | curr_name = "{}-{}.zip".format(dataset_name, zip_num)
151 | curr_zip_path = os.path.join(args.zip_dir, curr_name)
152 | f = zipfile.ZipFile(curr_zip_path, "a", mode)
153 | logging.critical("\n\n\n正创建新的压缩包: {} ".format(curr_name))
154 | zip_left_size = zip_tot_size
155 | else:
156 | zip_left_size = zip_tot_size - curr_zip_size
157 |
158 | # first item is the source path, second is the path inside the archive; the original folder structure is preserved
159 | files_list.append(
160 | [
161 | os.path.join(dirpath, filename),
162 | os.path.join(dataset_name, dirpath[len(args.dataset_dir) + 1 :], filename),
163 | ]
164 | )
165 |
166 | # the last archive usually never reaches the size limit, so write all remaining files into it
167 | if len(files_list) != 0:
168 | logging.critical("正在将 {} 个文件写入最后一个压缩包".format(len(files_list)))
169 | for pair in tqdm(files_list, ascii=True):
170 | f.write(pair[0], pair[1])
171 | files_list = []
172 | list_size = 0
173 | f.close()
174 |
175 | logging.critical("压缩结束,共 {} 个压缩包".format(zip_num))
176 |
177 |
178 | if __name__ == "__main__":
179 | args = get_args()
180 | do_zip(args)
181 |
--------------------------------------------------------------------------------
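The size control in do_zip() is a greedy binning of uncompressed file sizes: accumulate files until the next one would overflow the remaining capacity, flush the batch, and open a new archive when needed. A toy, filesystem-free sketch of just that grouping step; the sizes are invented, and the real script additionally re-checks the actual zip size after every flush.

def group_by_size(file_sizes, limit):
    """Greedily split (name, size) pairs into groups whose summed size stays under limit."""
    groups, current, current_size = [], [], 0
    for name, size in file_sizes:
        if current and current_size + size > limit:
            groups.append(current)
            current, current_size = [], 0
        current.append(name)
        current_size += size
    if current:
        groups.append(current)  # the last group rarely reaches the limit, like the final flush above
    return groups

files = [("a.nii", 4000), ("b.nii", 3500), ("c.nii", 2000), ("d.nii", 5000), ("e.nii", 900)]
print(group_by_size(files, limit=8000))
# [['a.nii', 'b.nii'], ['c.nii', 'd.nii', 'e.nii']]
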
/medseg/aug.py:
--------------------------------------------------------------------------------
1 | # Data augmentation
2 | # 1. Must handle numpy arrays, not just images
3 | # 2. Must handle both 2D and 3D data
4 | # 3. Whether each operation runs is controlled by a probability
5 | # 4. Get a usable version working first, then gradually reimplement everything with basic matrix ops instead of library calls
6 | # 5. The default arguments of every function mean "apply no change"
7 |
8 | import random
9 | import math
10 | import time
11 |
12 | import cv2
13 | import numpy as np
14 | from scipy.ndimage.interpolation import map_coordinates
15 | from scipy.ndimage.filters import gaussian_filter
16 | import scipy.ndimage
17 | import matplotlib.pyplot as plt
18 | import skimage.io
19 |
20 | from utils.util import pad_volume
21 |
22 | random.seed(time.time())
23 |
24 |
25 | def flip(volume, label=None, chance=(0, 0, 0)):
26 | """.
27 |
28 | Parameters
29 | ----------
30 | volume : type
31 | Description of parameter `volume`.
32 | label : type
33 | Description of parameter `label`.
34 | chance : type
35 | Description of parameter `chance`.
36 |
37 | Returns
38 | -------
39 | flip(volume, label=None,
40 | Description of returned object.
41 |
42 | """
43 | if random.random() < chance[0]:
44 | volume = volume[::-1, :, :]
45 | if label is not None:
46 | label = label[::-1, :, :]
47 | if random.random() < chance[1]:
48 | volume = volume[:, ::-1, :]
49 | if label is not None:
50 | label = label[:, ::-1, :]
51 | if random.random() < chance[2]:
52 | volume = volume[:, :, ::-1]
53 | if label is not None:
54 | label = label[:, :, ::-1]
55 |
56 | if label is not None:
57 | return volume, label
58 | return volume
59 |
60 |
61 | # rotate by an arbitrary angle around x, y, z; background fill options: mirror, 0, extend
62 | # https://stackoverflow.com/questions/20161175/how-can-i-use-scipy-ndimage-interpolation-affine-transform-to-rotate-an-image-ab worth studying
63 | def rotate(volume, label=None, angel=([0, 0], [0, 0], [0, 0]), chance=(0, 0, 0), cval=0):
64 | """ 按照指定象限旋转
65 | angel:是角度不是弧度
66 | """
67 |
68 | for axes in range(3):
69 | if random.random() < chance[axes]:
70 | rand_ang = angel[axes][0] + random.random() * (angel[axes][1] - angel[axes][0])
71 | volume = scipy.ndimage.rotate(
72 | volume,
73 | rand_ang,
74 | axes=(axes, (axes + 1) % 3),
75 | reshape=False,
76 | mode="constant",
77 | cval=cval,
78 | )
79 | if label is not None:
80 | label = scipy.ndimage.rotate(
81 | label, rand_ang, axes=(axes, (axes + 1) % 3), reshape=False,
82 | )
83 | if label is not None:
84 | return volume, label
85 | return volume
86 |
87 |
88 | # Random zoom; volume and label use spline interpolation, and ratio gives a (min, max) scale-factor range per axis
89 | def zoom(volume, label=None, ratio=[(1, 1), (1, 1), (1, 1)], chance=(0, 0, 0)):
90 | ratio = list(ratio)
91 | chance = list(chance)
92 | for axes in range(3):
93 | if random.random() < chance[axes]:  # if the random draw exceeds the zoom probability, this axis is left unscaled
94 | ratio[axes] = ratio[axes][0] + random.random() * (ratio[axes][1] - ratio[axes][0])
95 | else:
96 | ratio[axes] = 1
97 | volume = scipy.ndimage.zoom(volume, ratio, order=3, mode="constant")
98 | if label is not None:
99 | label = scipy.ndimage.zoom(label, ratio, order=3, mode="constant")
100 | return volume, label
101 | return volume
102 |
103 |
104 | def crop(volume, label=None, size=[3, 512, 512], pad_value=0):
105 | """在随机位置裁剪出一个指定大小的体积
106 | 每个维度都有输入图片更大或者size更大两种情况:
107 | - 如果输入图片更大,保证不会裁剪出图片,位置随机;
108 | - 如果size更大,只进行pad操作,体积在正中间.
109 | 对于标签,标签中是1的维度不会进行pad;不是1的和volume都一样
110 | Parameters
111 | ----------
112 | volume : np.ndarray
113 | Description of parameter `volume`.
114 | label : np.ndarray
115 | Description of parameter `label`.
116 | size : 需要裁出的体积,list
117 | Description of parameter `size`.
118 | pad_value : int
119 | volume用pad_value填充,标签默认用0填充.
120 |
121 | Returns
122 | -------
123 | type
124 | size大小的ndarray.
125 |
126 | """
127 | # 1. pad first so the data is at least `size` big
128 | volume = pad_volume(volume, size, pad_value, False)
129 | if label is not None:
130 | lab_size = list(size)
131 | for ind, s in enumerate(label.shape):
132 | if s == 1:  # dimensions of size 1 are left untouched
133 | lab_size[ind] = -1
134 | label = pad_volume(label, lab_size, 0, False)
135 | # 2. pick a random crop start, then crop
136 | crop_low = [int(random.random() * (x - y)) for x, y in zip(volume.shape, size)]
137 | r = [[l, l + s] for l, s in zip(crop_low, size)]
138 | volume = volume[r[0][0] : r[0][1], r[1][0] : r[1][1], r[2][0] : r[2][1]]
139 | if label is not None:
140 | for ind in range(3):
141 | if label.shape[ind] == 1:
142 | r[ind][0] = 0
143 | r[ind][1] = 1
144 | label = label[r[0][0] : r[0][1], r[1][0] : r[1][1], r[2][0] : r[2][1]]
145 | return volume, label
146 | return volume
147 |
148 |
149 | def elastic_transform(image, alpha, sigma, random_state=None):
150 | if random_state is None:
151 | random_state = np.random.RandomState(None)
152 | shape = image.shape
153 | print("randomstate", random_state.rand(*shape) * 2 - 1, "end")
154 | dx = (
155 | gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0) * alpha
156 | )
157 | print(dx.shape)
158 | dy = (
159 | gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0) * alpha
160 | )
161 | dz = np.zeros_like(dx)
162 |
163 | x, y, z = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), np.arange(shape[2]))
164 | indices = np.reshape(y + dy, (-1, 1)), np.reshape(x + dx, (-1, 1)), np.reshape(z, (-1, 1))
165 |
166 | distored_image = map_coordinates(image, indices, order=1, mode="reflect")
167 | return distored_image.reshape(image.shape)
168 |
169 |
170 | # TODO: add a random intensity-shift augmentation for contrast agent: add one random value to the whole input, roughly normal in [-80, 0] for liver
171 | # TODO: add a random-noise augmentation
172 | # TODO: add an augmentation that randomly combines images
173 | # TODO: wrap the whole pipeline into one function
174 |
175 | # cat = skimage.io.imread("~/Desktop/cat.jpg")
176 | # checker = skimage.io.imread("~/Desktop/checker.png")
177 | # img = np.ones([10, 10])
178 | #
179 | # img[0:5, 0:5] = 0
180 | #
181 | #
182 | # plt.imshow(cat)
183 | # plt.show()
184 | # if img.ndim == 2:
185 | # img = img[:, :, np.newaxis]
186 |
187 | # print(img.shape)
188 | # img, lab = flip(cat, cat, (1, 1, 0))
189 | # plt.imshow(img)
190 | # plt.show()
191 | #
192 | # img, lab = rotate(cat, cat, ([-45, 45], 0, [0, 0]), (1, 0, 0))
193 | # plt.imshow(img)
194 | # plt.show()
195 | #
196 | # img, lab = zoom(cat, cat, [(0.2, 0.3), (0.7, 0.8), (0.9, 1)], (0.5, 1, 0))
197 | # plt.imshow(img)
198 | # plt.show()
199 |
200 | # print(cat.shape)
201 | # img, label = crop(cat, cat, [400, 500, 3])
202 | # plt.imshow(img)
203 | # plt.show()
204 | #
205 | # plt.imshow(checker)
206 | # plt.show()
207 | # checker = checker[:, :, np.newaxis]
208 | # img = crop(cat, None, [512, 512, 3])
209 | # img = elastic_transform(checker, 900, 8)
210 | # img = img.reshape(img.shape[0], img.shape[1])
211 | # plt.imshow(img)
212 | # plt.show()
213 |
214 | #
215 | # print(img)
216 | # if img.shape[2] == 1:
217 | # img = img.reshape(img.shape[0], img.shape[1])
218 | # plt.imshow(img)
219 | # plt.show()
220 |
221 | # plt.imshow(lab)
222 | # plt.show()
223 |
224 |
225 | """
226 | PaddleClas augmentation strategies
227 |
228 | Image transforms:
229 | rotation
230 | hue
231 | background blur
232 | transparency
233 | saturation
234 |
235 | Image cropping:
236 | occlusion / cutout
237 |
238 | Image mixing:
239 | blend two images with fixed weights
240 | cut a patch from one image and paste it into another
241 | """
242 |
--------------------------------------------------------------------------------
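The functions above are meant to be chained per training sample (which is what train.aug_mapper does); a chance of 1 forces a transform and 0 disables it. A small usage sketch on random data, with made-up parameter values; it assumes this module and the utils package are importable, e.g. when run from medseg/.

import numpy as np
import aug

volume = np.random.rand(3, 256, 256).astype("float32")          # CWH, like one 2.5D training slice
label = (np.random.rand(1, 256, 256) > 0.5).astype("int32")

# force a flip along both in-plane axes, never along the channel axis
volume, label = aug.flip(volume, label, chance=(0, 1, 1))
# with probability 0.5, rotate in the H-W plane by a random angle in [-15, 15] degrees
volume, label = aug.rotate(volume, label, angel=([0, 0], [-15, 15], [0, 0]), chance=(0, 0.5, 0))
# crop / pad to a fixed size; label dimensions of size 1 are left alone
volume, label = aug.crop(volume, label, size=[3, 224, 224], pad_value=0)
print(volume.shape, label.shape)  # (3, 224, 224) (1, 224, 224)
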
/medseg/prep_3d.py:
--------------------------------------------------------------------------------
1 | # encoding=utf-8
2 |
3 |
4 | import os
5 |
6 |
7 | import numpy as np
8 | import nibabel as nib
9 | from tqdm import tqdm
10 | import scipy
11 | import matplotlib.pyplot as plt
12 |
13 | import utils.util as util
14 | from utils.config import cfg
15 | import scipy.ndimage
16 |
17 | import argparse
18 | import aug
19 |
20 | # import vis
21 |
22 | np.set_printoptions(threshold=np.inf)
23 |
24 |
25 | """
26 | Preprocess 3D volumes and save them as npz files.
27 | Each npz file holds a volume array and a label array, each containing n scans; the files are compressed.
28 | """
29 | # TODO: support more image formats
30 | # TODO: add a gzip option for the preprocessed npz files
31 | # https://stackoverflow.com/questions/54238670/what-is-the-advantage-of-saving-npz-files-instead-of-npy-in-python-regard
32 |
33 |
34 | def parse_args():
35 | parser = argparse.ArgumentParser(description="数据预处理")
36 | parser.add_argument("-c", "--cfg_file", type=str, help="配置文件路径")
37 | parser.add_argument("opts", nargs=argparse.REMAINDER)
38 | args = parser.parse_args()
39 |
40 | if args.cfg_file is not None:
41 | cfg.update_from_file(args.cfg_file)
42 | if args.opts:
43 | cfg.update_from_list(args.opts)
44 |
45 |
46 | def main():
47 | # 1. Create the output dirs and delete any non-empty summary file
48 | if cfg.PREP.PLANE == "xy" and not os.path.exists(cfg.DATA.PREP_PATH):
49 | os.makedirs(cfg.DATA.PREP_PATH)
50 | if cfg.PREP.PLANE == "xz" and not os.path.exists(cfg.DATA.Z_PREP_PATH):
51 | os.makedirs(cfg.DATA.Z_PREP_PATH)
52 |
53 | if os.path.exists(cfg.DATA.SUMMARY_FILE) and os.path.getsize(cfg.DATA.SUMMARY_FILE) != 0:
54 | os.remove(cfg.DATA.SUMMARY_FILE)
55 |
56 | volumes = util.listdir(cfg.DATA.INPUTS_PATH)
57 | labels = util.listdir(cfg.DATA.LABELS_PATH)
58 |
59 | # read the patient orientation info
60 | f = open("./directions.csv")
61 | dirs = f.readlines()
62 | # print("dirs: ", dirs)
63 | dirs = [x.rstrip("\n") for x in dirs]
64 | dirs = [x.split(",") for x in dirs]
65 | dic = {}
66 | for dir in dirs:
67 | dic[dir[0].strip()] = dir[1].strip()
68 | f.close()
69 |
70 | vol_npz = []
71 | lab_npz = []
72 | npz_count = 0
73 | thick = (cfg.TRAIN.THICKNESS - 1) // 2  # half thickness on each side of the center slice; must be an integer to be usable as a slice index
74 | pbar = tqdm(range(len(labels)), desc="processing data")
75 | for i in range(len(labels)):
76 | pbar.set_postfix(filename=labels[i] + " " + volumes[i])
77 | pbar.update(1)
78 |
79 | print(volumes[i], labels[i])
80 |
81 | volf = nib.load(os.path.join(cfg.DATA.INPUTS_PATH, volumes[i]))
82 | labf = nib.load(os.path.join(cfg.DATA.LABELS_PATH, labels[i]))
83 |
84 | util.save_info(volumes[i], volf.header, cfg.DATA.SUMMARY_FILE)
85 |
86 | volume = volf.get_fdata()
87 | label = labf.get_fdata()
88 | label = label.astype(int)
89 | # plt.imshow(volume[:, :, 0])
90 | # plt.show()
91 | if dic[volumes[i]] == "2":
92 | volume = np.rot90(volume, 3)
93 | label = np.rot90(label, 3)
94 | else:
95 | volume = np.rot90(volume, 1)
96 | label = np.rot90(label, 1)
97 |
98 | # plt.imshow(volume[:, :, 0])
99 | # plt.show()
100 |
101 | if cfg.PREP.INTERP:
102 | print("interpolating")
103 | header = volf.header.structarr
104 | spacing = list(cfg.PREP.INTERP_PIXDIM)  # copy so the config value is not mutated in place
105 | # pixdim is the voxel spacing of the CT along its three axes
106 | pixdim = [header["pixdim"][x] for x in range(1, 4)]
107 | for ind in range(3):
108 | if spacing[ind] == -1: # a target spacing of -1 means no interpolation along this axis
109 | spacing[ind] = pixdim[ind]
110 | ratio = [x / y for x, y in zip(spacing, pixdim)]
111 | volume = scipy.ndimage.interpolation.zoom(volume, ratio, order=3)
112 | label = scipy.ndimage.interpolation.zoom(label, ratio, order=0)
113 |
114 | if cfg.PREP.WINDOW:
115 | volume = util.windowlize_image(volume, cfg.PREP.WWWC)
116 |
117 | label = util.clip_label(label, cfg.PREP.FRONT)
118 |
119 | if cfg.PREP.CROP: # crop down to the foreground only
120 | bb_min, bb_max = util.get_bbs(label)
121 | label = util.crop_to_bbs(label, bb_min, bb_max, 0.5)[0]
122 | volume = util.crop_to_bbs(volume, bb_min, bb_max)[0]
123 |
124 | label = pad_volume(label, [512, 512, 0], 0) # NOTE: pad labels with 0
125 | volume = pad_volume(volume, [512, 512, 0], -1024)
126 | print("after padding", volume.shape, label.shape)
127 |
128 | volume = volume.astype(np.float16)
129 | label = label.astype(np.int8)
130 |
131 | crop_size = list(cfg.PREP.SIZE)
132 | for ind in range(3):
133 | if crop_size[ind] == -1:
134 | crop_size[ind] = volume.shape[ind]
135 | volume, label = aug.crop(volume, label, crop_size)
136 |
137 | # start slicing
138 | if cfg.PREP.PLANE == "xy":
139 | for frame in range(1, volume.shape[2] - 1):
140 | if label[:, :, frame].sum() > cfg.PREP.THRESH:
141 | vol = volume[:, :, frame - thick : frame + thick + 1]
142 | lab = label[:, :, frame]
143 | lab = lab[:, :, np.newaxis]
144 |
145 | vol = np.swapaxes(vol, 0, 2)
146 | lab = np.swapaxes(lab, 0, 2) # CWH order (vol becomes [3, 512, 512])
147 |
148 | vol_npz.append(vol.copy())
149 | lab_npz.append(lab.copy())
150 | print("slice {} accepted, {} slices collected so far".format(frame, len(vol_npz)))
151 |
152 | if len(vol_npz) == cfg.PREP.BATCH_SIZE or (
153 | i == (len(labels) - 1) and frame == volume.shape[2] - 2  # last frame of the loop; flush what is left
154 | ):
155 | imgs = np.array(vol_npz)
156 | labs = np.array(lab_npz)
157 | print(imgs.shape)
158 | print(labs.shape)
159 | print("saving to disk")
160 | file_name = "{}_{}_f{}-{}".format(
161 | cfg.DATA.NAME, cfg.PREP.PLANE, cfg.PREP.FRONT, npz_count
162 | )
163 | file_path = os.path.join(cfg.DATA.PREP_PATH, file_name)
164 | np.savez(file_path, imgs=imgs, labs=labs)
165 | vol_npz = []
166 | lab_npz = []
167 | npz_count += 1
168 | else:
169 | print(volume.shape, label.shape)
170 | for frame in range(1, volume.shape[0] - 1):
171 | if label[frame, :, :].sum() > cfg.PREP.THRESH:
172 | vol = volume[frame - 1 : frame + 2, :, :]
173 | lab = label[frame, :, :]
174 | lab = lab.reshape([1, lab.shape[0], lab.shape[1]])
175 |
176 | vol_npz.append(vol.copy())
177 | lab_npz.append(lab.copy())
178 |
179 | if len(vol_npz) == cfg.PREP.BATCH_SIZE:
180 | vols = np.array(vol_npz)
181 | labs = np.array(lab_npz)
182 | print(vols.shape)
183 | print(labs.shape)
184 | print("saving to disk")
185 | file_name = "{}_{}_f{}-{}".format(
186 | cfg.DATA.NAME, cfg.PREP.PLANE, cfg.PREP.FRONT, npz_count
187 | )
188 | file_path = os.path.join(cfg.DATA.Z_PREP_PATH, file_name)
189 | np.savez(file_path, vols=vols, labs=labs)
190 | vol_npz = []
191 | lab_npz = []
192 | npz_count += 1
193 |
194 | pbar.close()
195 |
196 |
197 | if __name__ == "__main__":
198 | parse_args()
199 | main()
200 |
--------------------------------------------------------------------------------
/medseg/utils/config.py:
--------------------------------------------------------------------------------
1 | import six
2 | from ast import literal_eval
3 | import codecs
4 | import yaml
5 |
6 | # When the config object is handed out by direct assignment it is treated as immutable by default; be careful if you need to assign to it again
7 | class PjConfig(dict):
8 | def __init__(self, *args, **kwargs):
9 | super(PjConfig, self).__init__(*args, **kwargs)
10 | self.immutable = False
11 |
12 | def __setattr__(self, key, value, create_if_not_exist=True):
13 | if key in ["immutable"]:
14 | self.__dict__[key] = value
15 | return
16 |
17 | t = self
18 | keylist = key.split(".")
19 | for k in keylist[:-1]:
20 | t = t.__getattr__(k, create_if_not_exist)
21 |
22 | t.__getattr__(keylist[-1], create_if_not_exist)
23 | t[keylist[-1]] = value
24 |
25 | def __getattr__(self, key, create_if_not_exist=True):
26 | if key in ["immutable"]:
27 | return self.__dict__[key]
28 |
29 | if not key in self:
30 | if not create_if_not_exist:
31 | raise KeyError
32 | self[key] = PjConfig()
33 | return self[key]
34 |
35 | def __setitem__(self, key, value):
36 | if self.immutable:
37 | raise AttributeError(
38 | 'Attempted to set "{}" to "{}", but PjConfig is immutable'.format(
39 | key, value
40 | )
41 | )
42 | if isinstance(value, six.string_types):
43 | try:
44 | value = literal_eval(value)
45 | except ValueError:
46 | pass
47 | except SyntaxError:
48 | pass
49 | super(PjConfig, self).__setitem__(key, value)
50 |
51 | def update_from_Config(self, other):
52 | if isinstance(other, dict):
53 | other = PjConfig(other)
54 | assert isinstance(other, PjConfig)
55 | diclist = [("", other)]
56 | while len(diclist):
57 | prefix, tdic = diclist[0]
58 | diclist = diclist[1:]
59 | for key, value in tdic.items():
60 | key = "{}.{}".format(prefix, key) if prefix else key
61 | if isinstance(value, dict):
62 | diclist.append((key, value))
63 | continue
64 | try:
65 | self.__setattr__(key, value, create_if_not_exist=False)
66 | except KeyError:
67 | raise KeyError("Non-existent config key: {}".format(key))
68 | self.check()
69 |
70 | def update_from_list(self, config_list):
71 | if len(config_list) % 2 != 0:
72 | raise ValueError(
73 | "Command line options config format error! Please check it: {}".format(
74 | config_list
75 | )
76 | )
77 | for key, value in zip(config_list[0::2], config_list[1::2]):
78 | try:
79 | self.__setattr__(key, value, create_if_not_exist=False)
80 | except KeyError:
81 | raise KeyError("Non-existent config key: {}".format(key))
82 | self.check()
83 |
84 | def update_from_file(self, config_file):
85 | with codecs.open(config_file, "r", "utf-8") as file:
86 | dic = yaml.load(file, Loader=yaml.FullLoader)
87 | self.update_from_Config(dic)
88 |
89 | def set_immutable(self, immutable):
90 | self.immutable = immutable
91 | for value in self.values():
92 | if isinstance(value, PjConfig):
93 | value.set_immutable(immutable)
94 |
95 | def is_immutable(self):
96 | return self.immutable
97 |
98 | def check(self):
99 | if cfg.PREP.THICKNESS % 2 != 1:
100 | raise ValueError("2.5D preprocessing thickness {} is not odd".format(cfg.PREP.THICKNESS))
101 |
102 |
103 | cfg = PjConfig()
104 |
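# Illustrative behaviour of PjConfig (not executed here): dotted attribute access
# creates nested configs on the fly, string values go through literal_eval, and
# set_immutable(True) freezes every level.
#
#     c = PjConfig()
#     c.TRAIN.BATCH_SIZE = 32            # creates c["TRAIN"]["BATCH_SIZE"]
#     c.TRAIN.LR = "[0.003, 0.001]"      # parsed by literal_eval into a real list
#     c.set_immutable(True)
#     c.TRAIN.BATCH_SIZE = 64            # now raises AttributeError: PjConfig is immutable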
105 | """Dataset config"""
106 | # dataset name
107 | cfg.DATA.NAME = "lits"
108 | # path of the input 2D or 3D images
109 | cfg.DATA.INPUTS_PATH = "/home/aistudio/data/scan"
110 | # path of the labels
111 | cfg.DATA.LABELS_PATH = "/home/aistudio/data/label"
112 | # output path of the preprocessed npz files
113 | cfg.DATA.PREP_PATH = "/home/aistudio/data/preprocess"
114 | # preprocessing along the z direction can be given a separate output path
115 | cfg.DATA.Z_PREP_PATH = cfg.DATA.PREP_PATH
116 | # data information gathered during preprocessing is written to this file
117 | cfg.DATA.SUMMARY_FILE = "./{}.csv".format(cfg.DATA.NAME)
118 |
119 | """ Preprocessing config """
120 | # plane the preprocessing operates on
121 | cfg.PREP.PLANE = "xy"
122 | # labels greater than or equal to this value are treated as foreground
123 | cfg.PREP.FRONT = 1
124 | # whether to crop the data down to the foreground only
125 | cfg.PREP.CROP = False
126 | # whether to resize the data by interpolation
127 | cfg.PREP.INTERP = False
128 | # if interpolating, the target spacing in mm; dimensions set to -1 are not interpolated
129 | cfg.PREP.INTERP_PIXDIM = (-1, -1, 1.0)
130 | # whether to apply windowing; not recommended at the preprocessing stage because it is too inflexible
131 | cfg.PREP.WINDOW = False
132 | # window width and window center
133 | cfg.PREP.WWWC = (400, 0)
134 | # discard slices whose foreground pixel count is below this threshold
135 | cfg.PREP.THRESH = 256
136 | # 3D data is padded to this size before slicing; dimensions set to -1 are left untouched
137 | cfg.PREP.SIZE = (512, 512, -1)
138 | # thickness of one 2.5D slice
139 | cfg.PREP.THICKNESS = 3
140 | # how many samples go into one npz file
141 | # run with bs=1 first to see how large one pair is; try to split the training data into at least 10 npz files, otherwise the train/val split becomes very inaccurate
142 | # avoid setting this value to a power of two; a different value makes random shuffling of the data work better
143 | cfg.PREP.BATCH_SIZE = 128
144 |
145 | """Training config"""
146 | cfg.TRAIN.DATA_PATH = "/home/aistudio/data/preprocess"
147 | # number of training samples, used for the progress bar and time estimate; use -1 if unknown
148 | cfg.TRAIN.DATA_COUNT = -1
149 | # path of pretrained weights; leave empty if there are none, otherwise loading is attempted
150 | cfg.TRAIN.PRETRAINED_WEIGHT = ""
151 | # save path of the pruned inference model
152 | cfg.TRAIN.INF_MODEL_PATH = "./model/lits/inf"
153 | # save path of the checkpoint model that training can resume from
154 | cfg.TRAIN.CKPT_MODEL_PATH = "./model/lits/ckpt"
155 | # save path of the best model so far
156 | cfg.TRAIN.BEST_MODEL_PATH = "./model/lits/best"
157 | # input image size during training, without the channel dimension
158 | cfg.TRAIN.INPUT_SIZE = (512, 512)
159 | # batch size used during training
160 | cfg.TRAIN.BATCH_SIZE = 32
161 | # total number of epochs to train
162 | cfg.TRAIN.EPOCHS = 20
163 | # model architecture to use
164 | cfg.TRAIN.ARCHITECTURE = "res_unet"
165 | # regularization method; L1 and L2 are supported, any other value disables regularization
166 | cfg.TRAIN.REG_TYPE = "L1"
167 | # regularization weight
168 | cfg.TRAIN.REG_COEFF = 1e-6
169 | # gradient descent method
170 | cfg.TRAIN.OPTIMIZER = "adam"
171 | # learning rates
172 | cfg.TRAIN.LR = [0.003, 0.002, 0.001]
173 | # steps at which the learning rate changes
174 | cfg.TRAIN.BOUNDARIES = [10000, 20000]
175 | # losses; ce, dice, miou, wce and focal are supported
176 | cfg.TRAIN.LOSS = ["ce", "dice"]
177 | # whether to train on GPU
178 | cfg.TRAIN.USE_GPU = False
179 | # whether to run evaluation
180 | cfg.TRAIN.DO_EVAL = False
181 | # run eval and save the model every SNAPSHOT_BATCH steps
182 | cfg.TRAIN.SNAPSHOT_BATCH = 500
183 | # print the training progress every DISP_BATCH steps
184 | cfg.TRAIN.DISP_BATCH = 10
185 | # VisualDL log path
186 | cfg.TRAIN.VDL_LOG = "/home/aistudio/log"
187 |
188 | """ HRNET settings """
189 | # HRNET STAGE2 settings
190 | cfg.MODEL.HRNET.STAGE2.NUM_MODULES = 1
191 | cfg.MODEL.HRNET.STAGE2.NUM_CHANNELS = [40, 80]
192 | # HRNET STAGE3 settings
193 | cfg.MODEL.HRNET.STAGE3.NUM_MODULES = 4
194 | cfg.MODEL.HRNET.STAGE3.NUM_CHANNELS = [40, 80, 160]
195 | # HRNET STAGE4 settings
196 | cfg.MODEL.HRNET.STAGE4.NUM_MODULES = 3
197 | cfg.MODEL.HRNET.STAGE4.NUM_CHANNELS = [40, 80, 160, 320]
198 |
199 | """Data augmentation"""
200 | # there is no separate on/off switch per augmentation; set the probability to 0 to disable one; note the CWH order
201 | # probability of flipping along each dimension, CWH
202 | # whether to apply windowing
203 | cfg.AUG.WINDOWLIZE = True
204 | # window width and window center
205 | cfg.AUG.WWWC = cfg.PREP.WWWC
206 | # probability of flip augmentation
207 | cfg.AUG.FLIP.RATIO = (0, 0, 0)
208 | # probability of rotation augmentation
209 | cfg.AUG.ROTATE.RATIO = (0, 0, 0)
210 | # rotation angle range, in degrees
211 | cfg.AUG.ROTATE.RANGE = (0, (0, 0), 0)
212 | # probability of zoom augmentation
213 | cfg.AUG.ZOOM.RATIO = (0, 0, 0)
214 | # zoom scale range
215 | cfg.AUG.ZOOM.RANGE = ((1, 1), (1, 1), (1, 1))
216 | # target size of the random crop
217 | cfg.AUG.CROP.SIZE = (3, 512, 512)
218 |
219 | """Inference config"""
220 | # input data path for inference
221 | cfg.INFER.PATH.INPUT = "/home/aistudio/data/inference"
222 | # output path of the inference results
223 | cfg.INFER.PATH.OUTPUT = "/home/aistudio/data/infer_lab"
224 | # path of the model weights used for inference
225 | cfg.INFER.PATH.PARAM = "/home/aistudio/weight/liver/inf"
226 | # whether to run inference on GPU
227 | cfg.INFER.USE_GPU = False
228 | # batch size during inference
229 | cfg.INFER.BATCH_SIZE = 128
230 | # whether to apply windowing; this should match the training configuration
231 | cfg.INFER.WINDOWLIZE = True
232 | # window width and window center
233 | cfg.INFER.WWWC = cfg.PREP.WWWC
234 | # whether to interpolate
235 | cfg.INFER.DO_INTERP = False
236 | # if interpolating, the target spacing; dimensions set to -1 are ignored
237 | cfg.INFER.SPACING = [-1, -1, 1]
238 | # whether to keep only the largest connected component
239 | cfg.INFER.FILTER_LARGES = True
240 | # threshold that separates foreground from background during inference
241 | cfg.INFER.THRESH = 0.5
242 |
243 | """ Evaluation config """
244 | # path of the segmentation results
245 | cfg.EVAL.PATH.SEG = "/home/aistudio/data/infer_lab"
246 | # path of the segmentation ground-truth labels
247 | cfg.EVAL.PATH.GT = "/home/aistudio/data/eval_lab"
248 | # file the evaluation results are written to
249 | cfg.EVAL.PATH.NAME = "eval"
250 | # metrics to compute during evaluation, including
251 | # FP, FN, TP, TN (absolute counts)
252 | # Precision, Recall/Sensitivity, Specificity, Accuracy, Kappa
253 | # Dice, IOU/VOE
254 | cfg.EVAL.METRICS = [
255 | "IOU",
256 | "Dice",
257 | "TP",
258 | "TN",
259 | "Precision",
260 | "Recall",
261 | "Sensitivity",
262 | "Specificity",
263 | "Accuracy",
264 | ]
265 |
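# How these defaults are typically overridden (a sketch based on the parse_args()
# functions in prep_3d.py / train.py; the yaml file name below is a placeholder):
#
#     python medseg/train.py -c config/some_config.yaml TRAIN.BATCH_SIZE 16 --use_gpu
#
# The yaml file is applied first via cfg.update_from_file, then the trailing
# key/value pairs via cfg.update_from_list; both refuse keys that do not already
# exist in this file, and check() enforces an odd 2.5D thickness.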
--------------------------------------------------------------------------------
/medseg/prep_png.py:
--------------------------------------------------------------------------------
1 | # encoding=utf-8
2 |
3 |
4 | import os
5 | import argparse
6 |
7 | import numpy as np
8 | import nibabel as nib
9 | from tqdm import tqdm
10 | import scipy
11 | import cv2
12 | import matplotlib.pyplot as plt
13 |
14 | import utils.util as util
15 | from utils.config import cfg
16 |
17 | import aug
18 |
19 | np.set_printoptions(threshold=np.inf)
20 |
21 |
22 | """
23 | Preprocess the 3D volume data; for the xy plane the slices are written out as PNG image/label pairs, for the xz plane they are collected into npz files.
24 | Each npz file contains two arrays, volume and label, each holding n scan records.
25 | """
26 | # TODO: support more medical image formats
27 | # TODO: provide a gzip option for the preprocessed npz files
28 | # https://stackoverflow.com/questions/54238670/what-is-the-advantage-of-saving-npz-files-instead-of-npy-in-python-regard
29 |
30 |
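# Expected output layout for the xy plane (derived from the cv2.imwrite calls below;
# the ids shown are only an example):
#
#     <PREP_PATH>/imgs/lits-<id>-<frame>.png   # windowed 8-bit slice
#     <PREP_PATH>/labs/lits-<id>-<frame>.png   # matching label, values 0/255
#
# Note that the imgs/ and labs/ sub-directories are assumed to exist already;
# main() only creates PREP_PATH itself.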
31 | def parse_args():
32 | parser = argparse.ArgumentParser(description="data preprocessing")
33 | parser.add_argument("-c", "--cfg_file", type=str, help="path to the config file")
34 | parser.add_argument("opts", nargs=argparse.REMAINDER)
35 | args = parser.parse_args()
36 |
37 | if args.cfg_file is not None:
38 | cfg.update_from_file(args.cfg_file)
39 | if args.opts:
40 | cfg.update_from_list(args.opts)
41 |
42 |
43 | def main():
44 | # 1. Create the output directory and delete any non-empty summary file
45 | if cfg.PREP.PLANE == "xy" and not os.path.exists(cfg.DATA.PREP_PATH):
46 | os.makedirs(cfg.DATA.PREP_PATH)
47 | if cfg.PREP.PLANE == "xz" and not os.path.exists(cfg.DATA.Z_PREP_PATH):
48 | os.makedirs(cfg.DATA.Z_PREP_PATH)
49 |
50 | if os.path.exists(cfg.DATA.SUMMARY_FILE) and os.path.getsize(cfg.DATA.SUMMARY_FILE) != 0:
51 | os.remove(cfg.DATA.SUMMARY_FILE)
52 |
53 | volumes = util.listdir(cfg.DATA.INPUTS_PATH)
54 | labels = util.listdir(cfg.DATA.LABELS_PATH)
55 |
56 | # read the patient orientation information
57 | f = open("./directions.csv")
58 | dirs = f.readlines()
59 | # print("dirs: ", dirs)
60 | dirs = [x.rstrip("\n") for x in dirs]
61 | dirs = [x.split(",") for x in dirs]
62 | dic = {}
63 | for dir in dirs:
64 | dic[dir[0].strip()] = dir[1].strip()
65 | f.close()
66 |
67 | vol_npz = []
68 | lab_npz = []
69 | npz_count = 0
70 |
71 | pbar = tqdm(range(len(labels)), desc="processing data")
72 | for i in range(len(labels)):
73 | pbar.set_postfix(filename=labels[i] + " " + volumes[i])
74 | pbar.update(1)
75 |
76 | print(volumes[i], labels[i])
77 |
78 | volf = nib.load(os.path.join(cfg.DATA.INPUTS_PATH, volumes[i]))
79 | labf = nib.load(os.path.join(cfg.DATA.LABELS_PATH, labels[i]))
80 |
81 | util.save_info(volumes[i], volf.header, cfg.DATA.SUMMARY_FILE)
82 |
83 | volume = volf.get_fdata()
84 | label = labf.get_fdata()
85 | label = label.astype(int)
86 |
87 | # plt.imshow(volume[:, :, 0])
88 | # plt.show()
89 |
90 | if dic[volumes[i]] == "2":
91 | volume = np.rot90(volume, 3)
92 | label = np.rot90(label, 3)
93 | else:
94 | volume = np.rot90(volume, 1)
95 | label = np.rot90(label, 1)
96 |
97 | # plt.imshow(volume[:, :, 0])
98 | # plt.show()
99 |
100 | if cfg.PREP.INTERP:
101 | print("interpolating")
102 | header = volf.header.structarr
103 | spacing = list(cfg.PREP.INTERP_PIXDIM)  # copy so the config value is not mutated in place
104 | # pixdim is the voxel spacing of the CT along its three axes
105 | pixdim = [header["pixdim"][x] for x in range(1, 4)]
106 | for ind in range(3):
107 | if spacing[ind] == -1: # a target spacing of -1 means no interpolation along this axis
108 | spacing[ind] = pixdim[ind]
109 | ratio = [x / y for x, y in zip(spacing, pixdim)]
110 | volume = scipy.ndimage.interpolation.zoom(volume, ratio, order=3)
111 | label = scipy.ndimage.interpolation.zoom(label, ratio, order=0)
112 |
113 | if cfg.PREP.WINDOW:
114 | volume = util.windowlize_image(volume, cfg.PREP.WWWC)
115 |
116 | label = util.clip_label(label, cfg.PREP.FRONT)
117 |
118 | if cfg.PREP.CROP: # crop down to the foreground only
119 | bb_min, bb_max = util.get_bbs(label)
120 | label = util.crop_to_bbs(label, bb_min, bb_max, 0.5)[0]
121 | volume = util.crop_to_bbs(volume, bb_min, bb_max)[0]
122 |
123 | label = pad_volume(label, [512, 512, 0], 0) # NOTE: pad labels with 0
124 | volume = pad_volume(volume, [512, 512, 0], -1024)
125 | print("after padding", volume.shape, label.shape)
126 |
127 | volume = volume.astype(np.float16)
128 | label = label.astype(np.int8)
129 |
130 | crop_size = list(cfg.PREP.SIZE)
131 | for ind in range(3):
132 | if crop_size[ind] == -1:
133 | crop_size[ind] = volume.shape[ind]
134 | volume, label = aug.crop(volume, label, crop_size)
135 |
136 | # start slicing; window HU values to [-200, 200] and rescale them to 8-bit
137 | volume = volume.clip(-200, 200)
138 | volume = (volume + 200) / 400 * 255
139 | volume = volume.astype(np.uint8)
140 | print(volume.dtype)
141 |
142 | if cfg.PREP.PLANE == "xy":
143 | for frame in range(1, volume.shape[2] - 1):
144 | if label[:, :, frame].sum() > cfg.PREP.THRESH:
145 | # vol = volume[:, :, frame - 1 : frame + 2]
146 | lab = label[:, :, frame]
147 |
148 | vol = volume[:, :, frame]
149 | lab = lab * 255
150 |
151 | # print(vol.shape)
152 | # print(lab.shape)
153 | cv2.imwrite(
154 | os.path.join(
155 | cfg.DATA.PREP_PATH,
156 | "imgs",
157 | "lits-{}-{}.png".format(
158 | volumes[i].lstrip("volume-").rstrip(".nii"), frame
159 | ),
160 | ),
161 | vol,
162 | )
163 |
164 | cv2.imwrite(
165 | os.path.join(
166 | cfg.DATA.PREP_PATH,
167 | "labs",
168 | "lits-{}-{}.png".format(
169 | volumes[i].lstrip("volume-").rstrip(".nii"), frame
170 | ),
171 | ),
172 | lab,
173 | )
174 |
175 | # volimg = Image.fromarray(vol)
176 | # labimg = Image.fromarray(lab, "L")
177 | # volimg.save(
178 | # os.path.join(
179 | # cfg.DATA.PREP_PATH,
180 | # "imgs",
181 | # "lits-{}-{}.png".format(
182 | # volumes[i].lstrip("volume-").rstrip(".nii"), frame
183 | # ),
184 | # )
185 | # )
186 | # labimg.save(
187 | # os.path.join(
188 | # cfg.DATA.PREP_PATH,
189 | # "labs",
190 | # "lits-{}-{}.png".format(
191 | # volumes[i].lstrip("volume-").rstrip(".nii"), frame
192 | # ),
193 | # )
194 | # )
195 | # volimg.close()
196 | # labimg.close()
197 |
198 | # input("here")
199 |
200 | else:
201 | print(volume.shape, label.shape)
202 | for frame in range(1, volume.shape[0] - 1):
203 | if label[frame, :, :].sum() > cfg.PREP.THRESH:
204 | vol = volume[frame - 1 : frame + 2, :, :]
205 | lab = label[frame, :, :]
206 | lab = lab.reshape([1, lab.shape[0], lab.shape[1]])
207 |
208 | vol_npz.append(vol.copy())
209 | lab_npz.append(lab.copy())
210 |
211 | if len(vol_npz) == cfg.PREP.BATCH_SIZE:
212 | vols = np.array(vol_npz)
213 | labs = np.array(lab_npz)
214 | print(vols.shape)
215 | print(labs.shape)
216 | print("saving to disk")
217 | file_name = "{}_{}_f{}-{}".format(
218 | cfg.DATA.NAME, cfg.PREP.PLANE, cfg.PREP.FRONT, npz_count
219 | )
220 | file_path = os.path.join(cfg.DATA.Z_PREP_PATH, file_name)
221 | np.savez(file_path, vols=vols, labs=labs)
222 | vol_npz = []
223 | lab_npz = []
224 | npz_count += 1
225 |
226 | pbar.close()
227 |
228 |
229 | if __name__ == "__main__":
230 | parse_args()
231 | main()
232 |
--------------------------------------------------------------------------------
/medseg/train.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | import sys
3 | import os
4 | import argparse
5 | import random
6 | import math
7 | from datetime import datetime
8 | import multiprocessing
9 |
10 | import numpy as np
11 | from tqdm.auto import tqdm
12 | import paddle
13 | import paddle.fluid as fluid
14 | from paddle.fluid.layers import log
15 | from visualdl import LogWriter
16 |
17 |
18 | import utils.util as util
19 | from utils.config import cfg
20 | import loss
21 | import aug
22 | from models.model import create_model
23 |
24 |
25 | def parse_args():
26 | parser = argparse.ArgumentParser(description="training")
27 | parser.add_argument("-c", "--cfg_file", type=str, help="path to the config file")
28 | parser.add_argument("--use_gpu", action="store_true", default=False, help="whether to use GPU")
29 | parser.add_argument("--do_eval", action="store_true", default=False, help="whether to run evaluation")
30 | parser.add_argument("opts", nargs=argparse.REMAINDER)
31 | args = parser.parse_args()
32 |
33 | if args.cfg_file is not None:
34 | cfg.update_from_file(args.cfg_file)
35 | if args.opts:
36 | cfg.update_from_list(args.opts)
37 | if args.use_gpu: # command-line flags can only switch this from False to True, they cannot declare False
38 | cfg.TRAIN.USE_GPU = True
39 | if args.do_eval:
40 | cfg.TRAIN.DO_EVAL = True
41 |
42 | cfg.set_immutable(True)
43 | # TODO: print the cfg configuration
44 |
45 |
46 | npz_names = util.listdir(cfg.TRAIN.DATA_PATH)
47 | random.shuffle(npz_names)
48 |
49 |
50 | def data_reader(part_start=0, part_end=8):
51 | # NOTE: this split is fast and easy to write, but it is inaccurate when there are only a few npz files; use at least 10 npz files
52 | npz_part = npz_names[int(len(npz_names) * part_start / 10) : int(len(npz_names) * part_end / 10)]
53 |
54 | def reader():
55 | # BUG: tqdm starts a new line on every update; it also needs a separate test on Windows
56 | if cfg.TRAIN.DATA_COUNT != -1:
57 | pbar = tqdm(total=cfg.TRAIN.DATA_COUNT, desc="training progress")
58 | for npz_name in npz_part:
59 | data = np.load(os.path.join(cfg.TRAIN.DATA_PATH, npz_name))
60 | imgs = data["imgs"]
61 | labs = data["labs"]
62 | assert len(np.where(labs == 1)[0]) + len(np.where(labs == 0)[0]) == labs.size, "illegal label value"
63 | if cfg.AUG.WINDOWLIZE:
64 | imgs = util.windowlize_image(imgs, cfg.AUG.WWWC)
65 | else:
66 | imgs = util.windowlize_image(imgs, (4096, 0))
67 | inds = [x for x in range(imgs.shape[0])]
68 | random.shuffle(inds)
69 | for ind in inds:
70 | if cfg.TRAIN.DATA_COUNT != -1:
71 | pbar.update()
72 | vol = imgs[ind].reshape(cfg.TRAIN.THICKNESS, 512, 512).astype("float32")
73 | lab = labs[ind].reshape(1, 512, 512).astype("int32")
74 | yield vol, lab
75 | # TODO: label smoothing
76 | # https://medium.com/@lessw/label-smoothing-deep-learning-google-brain-explains-why-it-works-and-when-to-use-sota-tips-977733ef020
77 |
78 | return reader
79 |
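# data_reader splits the shuffled npz list into tenths: data_reader(0, 8) yields the
# first 8/10 of the files (training) and data_reader(8, 10) the last 2/10 (validation).
# A small worked example, assuming 20 npz files:
#
#     train part = npz_names[0:16]   # int(20 * 0 / 10) : int(20 * 8 / 10)
#     val part   = npz_names[16:20]  # int(20 * 8 / 10) : int(20 * 10 / 10)
#
# This is why the NOTE above asks for at least 10 npz files: with fewer files the
# integer boundaries make the split very coarse.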
80 |
81 | def aug_mapper(data):
82 | vol = data[0]
83 | lab = data[1]
84 | ww, wc = cfg.AUG.WWWC
85 | # NOTE: do not augment dimension 0, that is the thickness direction
86 | vol, lab = aug.flip(vol, lab, cfg.AUG.FLIP.RATIO)
87 | vol, lab = aug.rotate(vol, lab, cfg.AUG.ROTATE.RANGE, cfg.AUG.ROTATE.RATIO, wc - ww / 2)
88 | vol, lab = aug.zoom(vol, lab, cfg.AUG.ZOOM.RANGE, cfg.AUG.ZOOM.RATIO)
89 | vol, lab = aug.crop(vol, lab, cfg.AUG.CROP.SIZE, wc - ww / 2)
90 | return vol, lab
91 |
92 |
93 | def main():
94 | train_program = fluid.Program()
95 | train_init = fluid.Program()
96 |
97 | with fluid.program_guard(train_program, train_init):
98 | image = fluid.layers.data(name="image", shape=[cfg.TRAIN.THICKNESS, 512, 512], dtype="float32")
99 | label = fluid.layers.data(name="label", shape=[1, 512, 512], dtype="int32")
100 | train_loader = fluid.io.DataLoader.from_generator(
101 | feed_list=[image, label],
102 | capacity=cfg.TRAIN.BATCH_SIZE * 2,
103 | iterable=True,
104 | use_double_buffer=True,
105 | )
106 | prediction = create_model(image, 2)
107 | avg_loss = loss.create_loss(prediction, label, 2)
108 | miou = loss.mean_iou(prediction, label, 2)
109 |
110 | # set up regularization
111 | if cfg.TRAIN.REG_TYPE == "L1":
112 | decay = paddle.fluid.regularizer.L1Decay(cfg.TRAIN.REG_COEFF)
113 | elif cfg.TRAIN.REG_TYPE == "L2":
114 | decay = paddle.fluid.regularizer.L2Decay(cfg.TRAIN.REG_COEFF)
115 | else:
116 | decay = None
117 |
118 | # choose the optimizer
119 | lr = fluid.layers.piecewise_decay(boundaries=cfg.TRAIN.BOUNDARIES, values=cfg.TRAIN.LR)
120 | if cfg.TRAIN.OPTIMIZER == "adam":
121 | optimizer = fluid.optimizer.AdamOptimizer(learning_rate=lr, regularization=decay,)
122 | elif cfg.TRAIN.OPTIMIZER == "sgd":
123 | optimizer = fluid.optimizer.SGDOptimizer(learning_rate=lr, regularization=decay)
124 | elif cfg.TRAIN.OPTIMIZER == "momentum":
125 | optimizer = fluid.optimizer.Momentum(momentum=0.9, learning_rate=lr, regularization=decay,)
126 | else:
127 | raise Exception("unknown optimizer type: {}".format(cfg.TRAIN.OPTIMIZER))
128 | optimizer.minimize(avg_loss)
129 |
130 | places = fluid.CUDAPlace(0) if cfg.TRAIN.USE_GPU else fluid.CPUPlace()
131 | exe = fluid.Executor(places)
132 | exe.run(train_init)
133 | exe_test = fluid.Executor(places)
134 |
135 | test_program = train_program.clone(for_test=True)
136 | compiled_train_program = fluid.CompiledProgram(train_program).with_data_parallel(loss_name=avg_loss.name)
137 |
138 | if cfg.TRAIN.PRETRAINED_WEIGHT != "":
139 | print("Loading parameters")
140 | fluid.io.load_persistables(exe, cfg.TRAIN.PRETRAINED_WEIGHT, train_program)
141 |
142 | # train_reader = fluid.io.xmap_readers(
143 | # aug_mapper, data_reader(0, 8), multiprocessing.cpu_count()/2, 16
144 | # )
145 | train_reader = data_reader(0, 8)
146 | train_loader.set_sample_generator(train_reader, batch_size=cfg.TRAIN.BATCH_SIZE, places=places)
147 | test_reader = paddle.batch(data_reader(8, 10), cfg.INFER.BATCH_SIZE)
148 | test_feeder = fluid.DataFeeder(place=places, feed_list=[image, label])
149 |
150 | writer = LogWriter(logdir="/home/aistudio/log/{}".format(datetime.now()))
151 |
152 | step = 0
153 | best_miou = 0
154 |
155 | for pass_id in range(cfg.TRAIN.EPOCHS):
156 | for train_data in train_loader():
157 | step += 1
158 | avg_loss_value, miou_value = exe.run(
159 | compiled_train_program, feed=train_data, fetch_list=[avg_loss, miou]
160 | )
161 | writer.add_scalar(tag="train_loss", step=step, value=avg_loss_value[0])
162 | writer.add_scalar(tag="train_miou", step=step, value=miou_value[0])
163 | if step % cfg.TRAIN.DISP_BATCH == 0:
164 | print(
165 | "\tTrain pass {}, Step {}, Cost {}, Miou {}".format(
166 | pass_id, step, avg_loss_value[0], miou_value[0]
167 | )
168 | )
169 |
170 | if math.isnan(float(avg_loss_value[0])):
171 | sys.exit("Got NaN loss, training failed.")
172 |
173 | if step % cfg.TRAIN.SNAPSHOT_BATCH == 0 and cfg.TRAIN.DO_EVAL:
174 | test_step = 0
175 | eval_miou = 0
176 | test_losses = []
177 | test_mious = []
178 | for test_data in test_reader():
179 | test_step += 1
180 | preds, test_loss, test_miou = exe_test.run(
181 | test_program,
182 | feed=test_feeder.feed(test_data),
183 | fetch_list=[prediction, avg_loss, miou],
184 | )
185 | test_losses.append(test_loss[0])
186 | test_mious.append(test_miou[0])
187 | if test_step % cfg.TRAIN.DISP_BATCH == 0:
188 | print("\t\tTest Loss: {} , Miou: {}".format(test_loss[0], test_miou[0]))
189 |
190 | eval_miou = np.average(np.array(test_mious))
191 | writer.add_scalar(
192 | tag="test_miou", step=step, value=eval_miou,
193 | )
194 | print("Test loss: {} ,miou: {}".format(np.average(np.array(test_losses)), eval_miou))
195 | ckpt_dir = os.path.join(cfg.TRAIN.CKPT_MODEL_PATH, str(step) + "_" + str(eval_miou))
196 | fluid.io.save_persistables(exe, ckpt_dir, train_program)
197 |
198 | print("Best test MIOU so far: ", best_miou)
199 |
200 | if step % cfg.TRAIN.SNAPSHOT_BATCH == 0 and eval_miou > best_miou:
201 | best_miou = eval_miou
202 | print("Saving the weights at step {}".format(step))
203 | fluid.io.save_inference_model(
204 | cfg.TRAIN.INF_MODEL_PATH,
205 | feeded_var_names=["image"],
206 | target_vars=[prediction],
207 | executor=exe,
208 | main_program=train_program,
209 | )
210 |
211 |
212 | if __name__ == "__main__":
213 | parse_args()  # updates the global cfg in place and returns nothing
214 | main()
215 |
--------------------------------------------------------------------------------
/tool/infer/2d_diameter.py:
--------------------------------------------------------------------------------
1 | # compute vessel diameters from the data within each 2D plane
2 | import argparse
3 | import os
4 | import sys
5 | import math
6 | from multiprocessing import Pool
7 | import time
8 | import random
9 |
10 | from tqdm import tqdm
11 | import numpy as np
12 | import cv2
13 | import nibabel as nib
14 | import scipy.ndimage
15 | from skimage import filters
16 | from skimage.segmentation import flood, flood_fill
17 | import matplotlib.pyplot as plt
18 |
19 | from util import blood_sort
20 |
21 | np.set_printoptions(threshold=sys.maxsize)
22 |
23 |
24 | parser = argparse.ArgumentParser()
25 | parser.add_argument("-i", "--seg_dir", type=str, required=True)
26 | parser.add_argument("-o", "--dia_dir", type=str, required=True)
27 | parser.add_argument("--filter_arch", type=bool, default=True)
28 | parser.add_argument("-d", "--demo", default=False, action="store_true")
29 | args = parser.parse_args()
30 |
31 |
32 | class Polygon:
33 | def __init__(self, points, height, pixdim, shape=[512, 512]):
34 | self.idx = 0
35 | self.shape = shape
36 | self.points = points
37 | random.shuffle(self.points)
38 | self.center = [0, 0]
39 | self.height = height
40 | self.diameters = []
41 | self.pixdim = pixdim
42 | # use the mean of the edge point positions as the center
43 | for p in self.points:
44 | self.center[0] += p[0]
45 | self.center[1] += p[1]
46 | self.center[0] /= len(self.points)
47 | self.center[1] /= len(self.points)
48 | self.center[0] = int(self.center[0])
49 | self.center[1] = int(self.center[1])
50 | if not self.is_inside():
51 | self.points = []
52 |
53 | def is_inside(self):
54 | """Drop some polygons:
55 | 1. fewer than 30 points
56 | 2. the center is not inside the shape
57 | 3. the flooded area is too large, i.e. the shape is probably not closed # TODO: is this check useful?
58 | Returns
59 | -------
60 | bool
61 | whether to keep this polygon
62 |
63 | """
64 | if len(self.points) < 30:
65 | return False
66 | label = np.zeros(self.shape, dtype="uint8")
67 | for p in self.points:
68 | label[p[0]][p[1]] = 1
69 | flood_fill(
70 | label,
71 | tuple(self.center),
72 | 1,
73 | selem=[[0, 1, 0], [1, 1, 1], [0, 1, 0]],
74 | in_place=True,
75 | )
76 | c = self.center
77 | if label[c[0]][c[1]] != 1:
78 | print("!!!!!!!!! the center is not inside the shape")
79 | return False
80 | if label.sum() > label.size / 2:
81 | return False
82 | return True
83 |
84 | def ang_sort(self):
85 | pass
86 |
87 | def plot(self):
88 | img = np.zeros([512, 512])
89 | for ind in range(len(self.points)):
90 | img[self.points[ind][0]][self.points[ind][1]] = 1
91 | img[self.center[0]][self.center[1]] = 1
92 | plt.imshow(img)
93 | plt.show()
94 |
95 | def cal_diameters(self, ang_range=[0, np.pi], split=10, acc=0.1):
96 | """Measure the diameter with a pair of parallel "caliper" lines, using a bisection-like search.
97 |
98 | y - y0 + d = k (x - x0): d is how far this line has been shifted along the y axis.
99 | Take the points a, b on the line at x=0 and x=1; by checking whether a, b and every point p of the polygon all turn the same way, decide whether the line has moved completely out of the polygon.
100 |
101 | Parameters
102 | ----------
103 | ang_range : list
104 | range of angles between the line and the x axis.
105 | split : int
106 | how many directions to measure, evenly spaced within this range.
107 | acc : float
108 | stop the bisection once the step size drops below this accuracy; self.pixdim converts pixels to mm.
109 |
110 | Returns
111 | -------
112 | list
113 | the caliper diameter in mm for each of the split directions within ang_range.
114 |
115 | """
116 | if len(self.points) == 0:
117 | self.diameters = 0
118 | print("[Error] Polygon at height {} contains no point".format(self.height))
119 | return
120 |
121 | def is_right(a, b, c):
122 | # cross product of the two vectors
123 | x = np.array((b[0] - a[0], b[1] - a[1]))
124 | y = np.array((c[0] - b[0], c[1] - b[1]))
125 | res = np.cross(x, y)
126 | return res >= 0
127 |
128 | # print(self.height)
129 | self.diameters = []
130 | center = self.center
131 | # y - y0 + d = k ( x - x0 )
132 | for alpha in np.arange(
133 | ang_range[0], ang_range[1], (ang_range[1] - ang_range[0]) / split
134 | ):
135 | # TODO: this step is a y-axis intercept; combine it with k to convert it into a perpendicular distance
136 | if alpha == np.pi / 2:
137 | continue
138 | k = math.tan(alpha)
139 |
140 | def binary_search(step):
141 | d = 0
142 | prev_out = False
143 | while True:
144 | ya = center[1] - d + k * (0 - center[0]) # (0,ya)
145 | yb = center[1] - d + k * (1 - center[0]) # (1,yb)
146 | dir = is_right((0, ya), (1, yb), self.points[0])
147 | same_dir = True
148 | for p in self.points:
149 | if dir != is_right((0, ya), (1, yb), p): # points lie on both sides of the line
150 | same_dir = False
151 | break
152 | # print(same_dir)
153 | d_old = d
154 | if same_dir:
155 | if not prev_out:
156 | step /= 2
157 | d -= step
158 | prev_out = True # one step in may be followed by several steps out; consecutive steps out do not shrink the step
159 | else:
160 | d += step
161 | prev_out = False
162 | label = np.zeros(self.shape)
163 | for p in self.points:
164 | label[p[0], p[1]] = 1
165 |
166 | ##### visualization ####
167 | if args.demo:
168 | plt.title(
169 | f"Step:{step}, d:{d}, Going {'out' if d_old < d else 'in'}."
170 | )
171 | temp_label = label
172 | temp_label[self.center[0], self.center[1]] = 1
173 | temp_label = (1 - temp_label).reshape([512, 512, 1]) * 255
174 | plt.imshow(
175 | np.tile(
176 | temp_label,
177 | (1, 1, 3),
178 | )
179 | )
180 | x = np.linspace(0, 512, 512)
181 | y = [512 - center[1] - d + k * (t - center[0]) for t in x]
182 | plt.plot(x, y)
183 | plt.show()
184 | #### ####
185 |
186 | if abs(step) < acc / 2:
187 | break
188 | return d
189 |
190 | self.diameters.append(
191 | (binary_search(40) - binary_search(-40))
192 | * np.abs(np.cos(alpha))
193 | * self.pixdim
194 | )
195 | # print(self.diameters)
196 | return self.idx, self.height, self.diameters
197 |
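# A small worked illustration of the geometry used by Polygon.cal_diameters above
# (a sketch, not part of the measuring code). The two parallel "caliper" lines have
# slope k = tan(alpha) and vertical intercept offsets d_plus and d_minus; their
# perpendicular distance is (d_plus - d_minus) * |cos(alpha)|, which the method then
# multiplies by pixdim to get millimetres. The numbers below are made up.
import numpy as np

def side(a, b, c):
    # same cross-product test as is_right(): on which side of line a->b does c lie?
    return np.cross((b[0] - a[0], b[1] - a[1]), (c[0] - b[0], c[1] - b[1])) >= 0

alpha = np.pi / 6                       # direction of the caliper lines
k = np.tan(alpha)
d_plus, d_minus = 12.0, -8.0            # intercept offsets found by the bisection
perpendicular = (d_plus - d_minus) * abs(np.cos(alpha))
print(side((0, 0), (1, k), (0, 5)), perpendicular)   # True, ~17.32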
198 |
199 | def dist(a, b):
200 | h = a.height - b.height
201 | ca = a.center
202 | cb = b.center
203 | return ((ca[0] - cb[0]) ** 2 + (ca[1] - cb[1]) ** 2 + h ** 2) ** 0.5
204 |
205 |
206 | # print(dist(Polygon([[0, 0]], 0, [0, 0]), Polygon([[0, 0]], 1, [1, 2])))
207 |
208 |
209 | def cal(polygon):
210 | return polygon.cal_diameters()
211 |
212 |
213 | def cal_diameter(seg_path, filter_arch, dia_dir, thresh=0.9):
214 | """Compute all vessel diameters for the nii segmentation file at seg_path and write them out.
215 |
216 | Procedure:
217 | 1. Collect all cross-sections in reverse blood-flow order
218 | 1.1 Split into slices by height; within a slice find the connected components (there may be one or two) and compute the center of each
219 | 1.2 Starting from the lowest center, repeatedly take the nearest center not yet in the sequence, sorting the centers against the blood-flow direction
220 |
221 | 2. Measure the vessel diameter with the parallel-line caliper
222 |
223 |
224 | Parameters
225 | ----------
226 | seg_path : str
227 | path of the segmentation label.
228 |
229 | Returns
230 | -------
231 | None
232 | the diameters are written to a csv file under dia_dir.
233 |
234 | """
235 | start = int(time.time())
236 | segf = nib.load(seg_path)
237 | seg_data = segf.get_fdata()
238 | pixdim = segf.header["pixdim"][1]
239 | seg_data[seg_data > thresh] = 1
240 | seg_data = seg_data.astype("uint8")
241 | print(seg_data.shape)
242 | polygons = []
243 | for height in range(seg_data.shape[2]):
244 | label = seg_data[:, :, height] # current slice
245 | label = filters.roberts(label) # keep only the edges
246 | vol, num = scipy.ndimage.label(label, np.ones([3, 3])) # connected components
247 | for label_idx in range(1, num + 1):
248 | xs, ys = np.where(vol == label_idx)
249 | points = []
250 | for x, y in zip(xs, ys):
251 | points.append([int(x), int(y)])
252 | polygons.append(Polygon(points, height, pixdim, label.shape))
253 | polygons = [p for p in polygons if len(p.points) != 0]
254 | polygons = blood_sort(polygons)
255 |
256 | # pool = Pool(8)
257 | # diameters = []
258 | # for res in tqdm(pool.imap_unordered(cal, polygons), total=len(polygons)):
259 | # diameters.append(res)
260 | # pool.close()
261 | # pool.join()
262 |
263 | # run sequentially, collecting the (idx, height, diameters) results for the sort below
264 | diameters = [p.cal_diameters() for p in tqdm(polygons)]
266 |
267 | diameters = sorted(diameters, key=lambda x: x[0])
268 | for d, p in zip(diameters, polygons):
269 | p.diameters = d[2]
270 |
271 | print(os.path.join(dia_dir, seg_path.split("/")[-1]))
272 | f = open(
273 | os.path.join(dia_dir, seg_path.split("/")[-1].rstrip(".gz").rstrip(".nii"))
274 | + ".csv",
275 | "w",
276 | )
277 | print((int(time.time()) - start) / 60)
278 | print((int(time.time()) - start) / 60, end="\n", file=f)
279 | # for d in diameters:
280 | # print(d[1], end=",", file=f)
281 | # for data in d[2]:
282 | # print(data, end=",", file=f)
283 | # print(file=f)
284 | #
285 | for p in polygons:
286 | print(p.height, end=",", file=f)
287 | # print(p.diameters)
288 | if filter_arch:
289 | if np.max(p.diameters) > np.min(p.diameters) * 2:
290 | continue
291 | for d in p.diameters:
292 | print(d, end=",", file=f)
293 | print(end="\n", file=f)
294 | f.close()
295 |
296 |
297 | if __name__ == "__main__":
298 | names = os.listdir(args.seg_dir)
299 | for name in names:
300 | if os.path.exists(os.path.join(args.dia_dir, name.replace(".nii.gz", ".csv"))):
301 | names.remove(name)
302 | print(names)
303 | print("{} patients to measure in total.".format(len(names)))
304 |
305 | start = int(time.time())
306 | tot = len(names)
307 |
308 | for count, name in enumerate(names, 1): # start counting at 1 so the time estimate below never divides by zero
309 | cal_diameter(os.path.join(args.seg_dir, name), args.filter_arch, args.dia_dir)
310 | print(
311 | "\t\tFinished {}/{}, expected to finish in {} minutes".format(
312 | count, tot, int(time.time() - start) / count * (tot - count) / 60
313 | )
314 | )
315 |
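# Typical invocation (flags as defined by the argparse block at the top of this file;
# the directory names are placeholders):
#
#     python tool/infer/2d_diameter.py -i /path/to/seg_nii_dir -o /path/to/diameter_csv_dir
#
# Add -d / --demo to plot each bisection step. One csv per input nii is written to the
# output directory, with the measurement time on its first line and roughly one
# "height, d1, d2, ..." row per cross-section after it.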
--------------------------------------------------------------------------------
/medseg/models/hrnet.py:
--------------------------------------------------------------------------------
1 | # coding: utf8
2 | # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # from __future__ import absolute_import
17 | from __future__ import division
18 | from __future__ import print_function
19 | import sys
20 |
21 | import paddle
22 | import paddle.fluid as fluid
23 | from paddle.fluid.initializer import MSRA
24 | from paddle.fluid.param_attr import ParamAttr
25 |
26 | from utils.config import cfg
27 |
28 |
29 | def conv_bn_layer(
30 | input, filter_size, num_filters, stride=1, padding=1, num_groups=1, if_act=True, name=None
31 | ):
32 | conv = fluid.layers.conv2d(
33 | input=input,
34 | num_filters=num_filters,
35 | filter_size=filter_size,
36 | stride=stride,
37 | padding=(filter_size - 1) // 2,
38 | groups=num_groups,
39 | act=None,
40 | param_attr=ParamAttr(initializer=MSRA(), name=name + "_weights"),
41 | bias_attr=False,
42 | )
43 | bn_name = name + "_bn"
44 | bn = fluid.layers.batch_norm(
45 | input=conv,
46 | param_attr=ParamAttr(name=bn_name + "_scale", initializer=fluid.initializer.Constant(1.0)),
47 | bias_attr=ParamAttr(name=bn_name + "_offset", initializer=fluid.initializer.Constant(0.0)),
48 | moving_mean_name=bn_name + "_mean",
49 | moving_variance_name=bn_name + "_variance",
50 | )
51 | if if_act:
52 | bn = fluid.layers.relu(bn)
53 | return bn
54 |
55 |
56 | def basic_block(input, num_filters, stride=1, downsample=False, name=None):
57 | residual = input
58 | conv = conv_bn_layer(
59 | input=input, filter_size=3, num_filters=num_filters, stride=stride, name=name + "_conv1"
60 | )
61 | conv = conv_bn_layer(
62 | input=conv, filter_size=3, num_filters=num_filters, if_act=False, name=name + "_conv2"
63 | )
64 | if downsample:
65 | residual = conv_bn_layer(
66 | input=input,
67 | filter_size=1,
68 | num_filters=num_filters,
69 | if_act=False,
70 | name=name + "_downsample",
71 | )
72 | return fluid.layers.elementwise_add(x=residual, y=conv, act="relu")
73 |
74 |
75 | def bottleneck_block(input, num_filters, stride=1, downsample=False, name=None):
76 | residual = input
77 | conv = conv_bn_layer(input=input, filter_size=1, num_filters=num_filters, name=name + "_conv1")
78 | conv = conv_bn_layer(
79 | input=conv, filter_size=3, num_filters=num_filters, stride=stride, name=name + "_conv2"
80 | )
81 | conv = conv_bn_layer(
82 | input=conv, filter_size=1, num_filters=num_filters * 4, if_act=False, name=name + "_conv3"
83 | )
84 | if downsample:
85 | residual = conv_bn_layer(
86 | input=input,
87 | filter_size=1,
88 | num_filters=num_filters * 4,
89 | if_act=False,
90 | name=name + "_downsample",
91 | )
92 | return fluid.layers.elementwise_add(x=residual, y=conv, act="relu")
93 |
94 |
95 | def fuse_layers(x, channels, multi_scale_output=True, name=None):
96 | out = []
97 | for i in range(len(channels) if multi_scale_output else 1):
98 | residual = x[i]
99 | shape = residual.shape
100 | width = shape[-1]
101 | height = shape[-2]
102 | for j in range(len(channels)):
103 | if j > i:
104 | y = conv_bn_layer(
105 | x[j],
106 | filter_size=1,
107 | num_filters=channels[i],
108 | if_act=False,
109 | name=name + "_layer_" + str(i + 1) + "_" + str(j + 1),
110 | )
111 | y = fluid.layers.resize_bilinear(input=y, out_shape=[height, width])
112 | residual = fluid.layers.elementwise_add(x=residual, y=y, act=None)
113 | elif j < i:
114 | y = x[j]
115 | for k in range(i - j):
116 | if k == i - j - 1:
117 | y = conv_bn_layer(
118 | y,
119 | filter_size=3,
120 | num_filters=channels[i],
121 | stride=2,
122 | if_act=False,
123 | name=name
124 | + "_layer_"
125 | + str(i + 1)
126 | + "_"
127 | + str(j + 1)
128 | + "_"
129 | + str(k + 1),
130 | )
131 | else:
132 | y = conv_bn_layer(
133 | y,
134 | filter_size=3,
135 | num_filters=channels[j],
136 | stride=2,
137 | name=name
138 | + "_layer_"
139 | + str(i + 1)
140 | + "_"
141 | + str(j + 1)
142 | + "_"
143 | + str(k + 1),
144 | )
145 | residual = fluid.layers.elementwise_add(x=residual, y=y, act=None)
146 |
147 | residual = fluid.layers.relu(residual)
148 | out.append(residual)
149 | return out
150 |
151 |
152 | def branches(x, block_num, channels, name=None):
153 | out = []
154 | for i in range(len(channels)):
155 | residual = x[i]
156 | for j in range(block_num):
157 | residual = basic_block(
158 | residual, channels[i], name=name + "_branch_layer_" + str(i + 1) + "_" + str(j + 1)
159 | )
160 | out.append(residual)
161 | return out
162 |
163 |
164 | def high_resolution_module(x, channels, multi_scale_output=True, name=None):
165 | residual = branches(x, 4, channels, name=name)
166 | out = fuse_layers(residual, channels, multi_scale_output=multi_scale_output, name=name)
167 | return out
168 |
169 |
170 | def transition_layer(x, in_channels, out_channels, name=None):
171 | num_in = len(in_channels)
172 | num_out = len(out_channels)
173 | out = []
174 | for i in range(num_out):
175 | if i < num_in:
176 | if in_channels[i] != out_channels[i]:
177 | residual = conv_bn_layer(
178 | x[i],
179 | filter_size=3,
180 | num_filters=out_channels[i],
181 | name=name + "_layer_" + str(i + 1),
182 | )
183 | out.append(residual)
184 | else:
185 | out.append(x[i])
186 | else:
187 | residual = conv_bn_layer(
188 | x[-1],
189 | filter_size=3,
190 | num_filters=out_channels[i],
191 | stride=2,
192 | name=name + "_layer_" + str(i + 1),
193 | )
194 | out.append(residual)
195 | return out
196 |
197 |
198 | def stage(x, num_modules, channels, multi_scale_output=True, name=None):
199 | out = x
200 | for i in range(num_modules):
201 | if i == num_modules - 1 and multi_scale_output == False:
202 | out = high_resolution_module(
203 | out, channels, multi_scale_output=False, name=name + "_" + str(i + 1)
204 | )
205 | else:
206 | out = high_resolution_module(out, channels, name=name + "_" + str(i + 1))
207 |
208 | return out
209 |
210 |
211 | def layer1(input, name=None):
212 | conv = input
213 | for i in range(4):
214 | conv = bottleneck_block(
215 | conv, num_filters=64, downsample=True if i == 0 else False, name=name + "_" + str(i + 1)
216 | )
217 | return conv
218 |
219 |
220 | def high_resolution_net(input, num_classes):
221 |
222 | channels_2 = cfg.MODEL.HRNET.STAGE2.NUM_CHANNELS
223 | channels_3 = cfg.MODEL.HRNET.STAGE3.NUM_CHANNELS
224 | channels_4 = cfg.MODEL.HRNET.STAGE4.NUM_CHANNELS
225 |
226 | num_modules_2 = cfg.MODEL.HRNET.STAGE2.NUM_MODULES
227 | num_modules_3 = cfg.MODEL.HRNET.STAGE3.NUM_MODULES
228 | num_modules_4 = cfg.MODEL.HRNET.STAGE4.NUM_MODULES
229 |
230 | x = conv_bn_layer(
231 | input=input, filter_size=3, num_filters=64, stride=2, if_act=True, name="layer1_1"
232 | )
233 | x = conv_bn_layer(
234 | input=x, filter_size=3, num_filters=64, stride=2, if_act=True, name="layer1_2"
235 | )
236 |
237 | la1 = layer1(x, name="layer2")
238 | tr1 = transition_layer([la1], [256], channels_2, name="tr1")
239 | st2 = stage(tr1, num_modules_2, channels_2, name="st2")
240 | tr2 = transition_layer(st2, channels_2, channels_3, name="tr2")
241 | st3 = stage(tr2, num_modules_3, channels_3, name="st3")
242 | tr3 = transition_layer(st3, channels_3, channels_4, name="tr3")
243 | st4 = stage(tr3, num_modules_4, channels_4, name="st4")
244 |
245 | # upsample
246 | shape = st4[0].shape
247 | height, width = shape[-2], shape[-1]
248 | st4[1] = fluid.layers.resize_bilinear(st4[1], out_shape=[height, width])
249 | st4[2] = fluid.layers.resize_bilinear(st4[2], out_shape=[height, width])
250 | st4[3] = fluid.layers.resize_bilinear(st4[3], out_shape=[height, width])
251 |
252 | out = fluid.layers.concat(st4, axis=1)
253 | last_channels = sum(channels_4)
254 |
255 | out = conv_bn_layer(
256 | input=out, filter_size=1, num_filters=last_channels, stride=1, if_act=True, name="conv-2"
257 | )
258 | out = fluid.layers.conv2d(
259 | input=out,
260 | num_filters=num_classes,
261 | filter_size=1,
262 | stride=1,
263 | padding=0,
264 | act=None,
265 | param_attr=ParamAttr(initializer=MSRA(), name="conv-1_weights"),
266 | bias_attr=False,
267 | )
268 |
269 | out = fluid.layers.resize_bilinear(out, input.shape[2:])
270 |
271 | return out
272 |
273 |
274 | def hrnet(input, num_classes):
275 | logit = high_resolution_net(input, num_classes)
276 | return logit
277 |
278 |
279 | if __name__ == "__main__":
280 | image_shape = [-1, 3, 769, 769]
281 | image = fluid.data(name="image", shape=image_shape, dtype="float32")
282 | logit = hrnet(image, 4)
283 | print("logit:", logit.shape)
284 |
--------------------------------------------------------------------------------
/medseg/models/deeplabv3p.py:
--------------------------------------------------------------------------------
1 | # Deeplabv3p Network is modified from the following link
2 | # https://github.com/PaddlePaddle/models/blob/develop/fluid/PaddleCV/deeplabv3%2B/models.py
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 | import paddle.fluid as fluid
8 |
9 | import contextlib
10 | name_scope = ""
11 |
12 | decode_channel = 32
13 | encode_channel = 160
14 |
15 | bn_momentum = 0.99
16 |
17 | op_results = {}
18 |
19 | default_epsilon = 1e-3
20 | default_norm_type = 'bn'
21 | default_group_number = 32
22 | depthwise_use_cudnn = True
23 |
24 | bn_regularizer = fluid.regularizer.L2DecayRegularizer(regularization_coeff=0.0)
25 | depthwise_regularizer = fluid.regularizer.L2DecayRegularizer(
26 | regularization_coeff=0.0)
27 |
28 |
29 | @contextlib.contextmanager
30 | def scope(name):
31 | global name_scope
32 | bk = name_scope
33 | name_scope = name_scope + name + '/'
34 | yield
35 | name_scope = bk
36 |
37 |
38 | def check(data, number):
39 | if type(data) == int:
40 | return [data] * number
41 | assert len(data) == number
42 | return data
43 |
44 |
45 | def clean():
46 | global op_results
47 | op_results = {}
48 |
49 |
50 | def append_op_result(result, name):
51 | global op_results
52 | op_index = len(op_results)
53 | name = name_scope + name + str(op_index)
54 | op_results[name] = result
55 | return result
56 |
57 |
58 | def conv(*args, **kargs):
59 | if "xception" in name_scope:
60 | init_std = 0.09
61 | elif "logit" in name_scope:
62 | init_std = 0.01
63 | elif name_scope.endswith('depthwise/'):
64 | init_std = 0.33
65 | else:
66 | init_std = 0.06
67 | if name_scope.endswith('depthwise/'):
68 | regularizer = depthwise_regularizer
69 | else:
70 | regularizer = None
71 |
72 | kargs['param_attr'] = fluid.ParamAttr(
73 | name=name_scope + 'weights',
74 | regularizer=regularizer,
75 | initializer=fluid.initializer.TruncatedNormal(
76 | loc=0.0, scale=init_std))
77 | if 'bias_attr' in kargs and kargs['bias_attr']:
78 | kargs['bias_attr'] = fluid.ParamAttr(
79 | name=name_scope + 'biases',
80 | regularizer=regularizer,
81 | initializer=fluid.initializer.ConstantInitializer(value=0.0))
82 | else:
83 | kargs['bias_attr'] = False
84 | kargs['name'] = name_scope + 'conv'
85 | return append_op_result(fluid.layers.conv2d(*args, **kargs), 'conv')
86 |
87 |
88 | def group_norm(input, G, eps=1e-5, param_attr=None, bias_attr=None):
89 | N, C, H, W = input.shape
90 | if C % G != 0:
91 | # print "group can not divide channle:", C, G
92 | for d in range(10):
93 | for t in [d, -d]:
94 | if G + t <= 0: continue
95 | if C % (G + t) == 0:
96 | G = G + t
97 | break
98 | if C % G == 0:
99 | # print "use group size:", G
100 | break
101 | assert C % G == 0
102 | x = fluid.layers.group_norm(
103 | input,
104 | groups=G,
105 | param_attr=param_attr,
106 | bias_attr=bias_attr,
107 | name=name_scope + 'group_norm')
108 | return x
109 |
110 |
111 | def bn(*args, **kargs):
112 | if default_norm_type == 'bn':
113 | with scope('BatchNorm'):
114 | return append_op_result(
115 | fluid.layers.batch_norm(
116 | *args,
117 | epsilon=default_epsilon,
118 | momentum=bn_momentum,
119 | param_attr=fluid.ParamAttr(
120 | name=name_scope + 'gamma', regularizer=bn_regularizer),
121 | bias_attr=fluid.ParamAttr(
122 | name=name_scope + 'beta', regularizer=bn_regularizer),
123 | moving_mean_name=name_scope + 'moving_mean',
124 | moving_variance_name=name_scope + 'moving_variance',
125 | **kargs),
126 | 'bn')
127 | elif default_norm_type == 'gn':
128 | with scope('GroupNorm'):
129 | return append_op_result(
130 | group_norm(
131 | args[0],
132 | default_group_number,
133 | eps=default_epsilon,
134 | param_attr=fluid.ParamAttr(
135 | name=name_scope + 'gamma', regularizer=bn_regularizer),
136 | bias_attr=fluid.ParamAttr(
137 | name=name_scope + 'beta', regularizer=bn_regularizer)),
138 | 'gn')
139 | else:
140 | raise ValueError("Unsupported norm type: " + default_norm_type)
141 |
142 |
143 | def bn_relu(data):
144 | return append_op_result(fluid.layers.relu(bn(data)), 'relu')
145 |
146 |
147 | def relu(data):
148 | return append_op_result(
149 | fluid.layers.relu(
150 | data, name=name_scope + 'relu'), 'relu')
151 |
152 |
153 | def seperate_conv(input, channel, stride, filter, dilation=1, act=None):
154 | with scope('depthwise'):
155 | input = conv(
156 | input,
157 | input.shape[1],
158 | filter,
159 | stride,
160 | groups=input.shape[1],
161 | padding=(filter // 2) * dilation,
162 | dilation=dilation,
163 | use_cudnn=depthwise_use_cudnn)
164 | input = bn(input)
165 | if act: input = act(input)
166 | with scope('pointwise'):
167 | input = conv(input, channel, 1, 1, groups=1, padding=0)
168 | input = bn(input)
169 | if act: input = act(input)
170 | return input
171 |
172 |
173 | def xception_block(input,
174 | channels,
175 | strides=1,
176 | filters=3,
177 | dilation=1,
178 | skip_conv=True,
179 | has_skip=True,
180 | activation_fn_in_separable_conv=False):
181 | repeat_number = 3
182 | channels = check(channels, repeat_number)
183 | filters = check(filters, repeat_number)
184 | strides = check(strides, repeat_number)
185 | data = input
186 | results = []
187 | for i in range(repeat_number):
188 | with scope('separable_conv' + str(i + 1)):
189 | if not activation_fn_in_separable_conv:
190 | data = relu(data)
191 | data = seperate_conv(
192 | data,
193 | channels[i],
194 | strides[i],
195 | filters[i],
196 | dilation=dilation)
197 | else:
198 | data = seperate_conv(
199 | data,
200 | channels[i],
201 | strides[i],
202 | filters[i],
203 | dilation=dilation,
204 | act=relu)
205 | results.append(data)
206 | if not has_skip:
207 | return append_op_result(data, 'xception_block'), results
208 | if skip_conv:
209 | with scope('shortcut'):
210 | skip = bn(
211 | conv(
212 | input, channels[-1], 1, strides[-1], groups=1, padding=0))
213 | else:
214 | skip = input
215 | return append_op_result(data + skip, 'xception_block'), results
216 |
217 |
218 | def entry_flow(data):
219 | with scope("entry_flow"):
220 | with scope("conv1"):
221 | data = conv(data, 32, 3, stride=2, padding=1)
222 | data = bn_relu(data)
223 | with scope("conv2"):
224 | data = conv(data, 32, 3, stride=1, padding=1)
225 | data = bn_relu(data)
226 | with scope("block1"):
227 | data, results1 = xception_block(data, 64, [1, 1, 2])
228 | with scope("block2"):
229 | data, results2 = xception_block(data, 128, [1, 1, 2])
230 | with scope("block3"):
231 | data, _ = xception_block(data, 256, [1, 1, 2])
232 | return data, results1[1], results2[2]
233 |
234 |
235 | def middle_flow(data):
236 | with scope("middle_flow"):
237 | for i in range(8):
238 | with scope("block" + str(i + 1)):
239 | data, _ = xception_block(data, 256, [1, 1, 1], skip_conv=False)
240 | return data
241 |
242 |
243 | def exit_flow(data):
244 | with scope("exit_flow"):
245 | with scope('block1'):
246 | data, _ = xception_block(data, [256, 512, 512], [1, 1, 1])
247 | with scope('block2'):
248 | data, _ = xception_block(
249 | data, [512, 512, 768], [1, 1, 1],
250 | dilation=2,
251 | has_skip=False,
252 | activation_fn_in_separable_conv=True)
253 | return data
254 |
255 |
256 | def encoder(input):
257 | with scope('encoder'):
258 | channel = 192
259 |
260 | with scope("aspp0"):
261 | aspp0 = bn_relu(conv(input, channel, 1, 1, groups=1, padding=0))
262 | with scope("aspp1"):
263 | aspp1 = seperate_conv(input, channel, 1, 3, dilation=3, act=relu)
264 | with scope("aspp2"):
265 | aspp2 = seperate_conv(input, channel, 1, 3, dilation=6, act=relu)
266 | with scope("aspp3"):
267 | aspp3 = seperate_conv(input, channel, 1, 3, dilation=12, act=relu)
268 | with scope("concat"):
269 | data = append_op_result(
270 | fluid.layers.concat(
271 | [aspp0, aspp1, aspp2, aspp3], axis=1),
272 | 'concat')
273 | data = bn_relu(conv(data, channel, 1, 1, groups=1, padding=0))
274 | return data
275 |
276 |
277 | def decoder(encode_data, decode_shortcut):
278 | with scope('decoder'):
279 | with scope('concat'):
280 | decode_shortcut = bn_relu(
281 | conv(
282 | decode_shortcut, decode_channel, 1, 1, groups=1, padding=0))
283 | encode_data = fluid.layers.resize_bilinear(
284 | encode_data, decode_shortcut.shape[2:])
285 | encode_data = fluid.layers.concat(
286 | [encode_data, decode_shortcut], axis=1)
287 | append_op_result(encode_data, 'concat')
288 | with scope("separable_conv1"):
289 | encode_data = seperate_conv(
290 | encode_data, encode_channel, 1, 3, dilation=1, act=relu)
291 | with scope("separable_conv2"):
292 | encode_data = seperate_conv(
293 | encode_data, encode_channel, 1, 3, dilation=1, act=relu)
294 | return encode_data
295 |
296 | def decoder2(encode_data, decode_shortcut):
297 | with scope('decoder2'):
298 | with scope('concat2'):
299 | decode_shortcut = bn_relu(
300 | conv(
301 | decode_shortcut, decode_channel // 2, 1, 1, groups=1, padding=0))
302 | encode_data = fluid.layers.resize_bilinear(
303 | encode_data, decode_shortcut.shape[2:])
304 | encode_data = fluid.layers.concat(
305 | [encode_data, decode_shortcut], axis=1)
306 | append_op_result(encode_data, 'concat2')
307 | with scope("separable_conv12"):
308 | encode_data = seperate_conv(
309 | encode_data, encode_channel // 2, 1, 3, dilation=1, act=relu)
310 | with scope("separable_conv22"):
311 | encode_data = seperate_conv(
312 | encode_data, encode_channel // 2, 1, 3, dilation=1, act=relu)
313 | return encode_data
314 |
315 |
316 | def deeplabv3p(img, label_number):
317 | global default_epsilon
318 | append_op_result(img, 'img')
319 | with scope('xception_65'):
320 | default_epsilon = 1e-3
321 | # Entry flow
322 | data, decode_shortcut1, decode_shortcut2 = entry_flow(img)
323 | print(data.shape)
324 | # Middle flow
325 | data = middle_flow(data)
326 | print(data.shape)
327 | # Exit flow
328 | data = exit_flow(data)
329 | print(data.shape)
330 | default_epsilon = 1e-5
331 | encode_data = encoder(data)
332 | print(encode_data.shape)
333 | encode_data = decoder(encode_data, decode_shortcut2)
334 | print(encode_data.shape)
335 | encode_data = fluid.layers.resize_bilinear(encode_data, (encode_data.shape[2] * 2, encode_data.shape[3] * 2))
336 | encode_data = decoder2(encode_data, decode_shortcut1)
337 | print(encode_data.shape)
338 | with scope('logit'):
339 | logit = conv(
340 | encode_data, label_number, 1, stride=1, padding=0, bias_attr=True)
341 | logit = fluid.layers.resize_bilinear(logit, img.shape[2:])
342 | # logit = fluid.layers.resize_bilinear(logit, (3384, 1020))
343 | print(logit.shape)
344 | return logit
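
# A minimal smoke test in the same spirit as the __main__ block of hrnet.py
# (illustrative only; the 512x512 input shape is an assumption matching the slice
# size used elsewhere in this repo).
if __name__ == "__main__":
    image = fluid.data(name="image", shape=[-1, 3, 512, 512], dtype="float32")
    logit = deeplabv3p(image, 2)
    print("logit:", logit.shape)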
--------------------------------------------------------------------------------
/medseg/utils/util.py:
--------------------------------------------------------------------------------
1 | """
2 | Environment variables and commonly used helper functions
3 | """
4 | import os
5 | import sys
6 | import numpy as np
7 | import math
8 | from scipy import ndimage
9 |
10 | # from utils.config import cfg
11 |
12 | # TODO: clean up these functions, remove the unused ones and add clear comments to the rest
13 |
14 |
15 | def listdir(path):
16 | """List all files under a path, sorted, with common auxiliary files removed.
17 |
18 | Parameters
19 | ----------
20 | path : str
21 | directory to list.
22 |
23 | Returns
24 | -------
25 | list
26 | sorted file names with entries such as .DS_Store removed.
27 |
28 | """
29 | dirs = os.listdir(path)
30 | if ".DS_Store" in dirs:
31 | dirs.remove(".DS_Store")
32 | dirs.sort() # the same sort keeps volumes and segmentations aligned
33 | return dirs
34 |
35 |
36 | def save_info(name, header, file_name):
37 | """Save some information from a scan header into a csv file.
38 |
39 | Parameters
40 | ----------
41 | name : str
42 | scan file name.
43 | header : dict
44 | scan file header.
45 | file_name : str
46 | file name of the csv.
47 |
48 | Returns
49 | -------
50 | None
51 | the information is appended to file_name.
52 |
53 | """
54 |
55 | """
56 | sizeof_hdr : 348
57 | data_type : b''
58 | db_name : b''
59 | extents : 0
60 | session_error : 0
61 | regular : b'r'
62 | dim_info : 0
63 | dim : [ 3 512 512 75 1 1 1 1]
64 | intent_p1 : 0.0
65 | intent_p2 : 0.0
66 | intent_p3 : 0.0
67 | intent_code : none
68 | datatype : int16
69 | bitpix : 16
70 | slice_start : 0
71 | pixdim : [-1.00000e+00 7.03125e-01 7.03125e-01 5.00000e+00 0.00000e+00
72 | 1.00000e+00 1.00000e+00 5.22410e+04]
73 | vox_offset : 0.0
74 | scl_slope : nan
75 | scl_inter : nan
76 | slice_end : 0
77 | slice_code : unknown
78 | xyzt_units : 10
79 | cal_max : 0.0
80 | cal_min : 0.0
81 | slice_duration : 0.0
82 | toffset : 0.0
83 | glmax : 255
84 | glmin : 0
85 | descrip : b'TE=0;sec=52241.0000;name='
86 | aux_file : b'!62ABDOMENNATIVUNDVENS'
87 | qform_code : scanner
88 | sform_code : scanner
89 | quatern_b : 0.0
90 | quatern_c : 1.0
91 | quatern_d : 0.0
92 | qoffset_x : 172.9
93 | qoffset_y : -179.29688
94 | qoffset_z : -368.0
95 | srow_x : [ -0.703125 0. 0. 172.9 ]
96 | srow_y : [ 0. 0.703125 0. -179.29688 ]
97 | srow_z : [ 0. 0. 5. -368.]
98 | intent_name : b''
99 | magic : b'n+1'
100 | """
101 | file = open(file_name, "a+")
102 | print(name, end=", ", file=file)
103 | print(
104 | header["dim"][1], ",", header["dim"][2], ",", header["dim"][3], end=", ", file=file,
105 | )
106 | print(
107 | header["pixdim"][1], ",", header["pixdim"][2], ",", header["pixdim"][3], end=", ", file=file,
108 | )
109 | print(header["bitpix"], " , ", header["datatype"], file=file)
110 | file.close()
111 |
112 |
113 | """
114 | import nibabel as nib
115 | volf = nib.load('/home/aistudio/data/volume/volume-0.nii')
116 | save_info('v1', volf.header.structarr, 'vol_info.csv')
117 | """
118 |
119 |
120 | """ Volume processing """
121 |
122 |
123 | def windowlize_image(vol, wwwc):
124 | """Apply a hard intensity window (clip) to a scan according to wwwc.
125 |
126 | Parameters
127 | ----------
128 | vol : ndarray
129 | Scan to window.
130 | wwwc : (int, int)
131 | Window width (ww) and window center / level (wc); voxels are
132 | clipped to [wc - ww / 2, wc + ww / 2].
133 |
134 | Returns
135 | -------
136 | ndarray
137 | The windowed scan
138 |
139 | """
140 | ww = wwwc[0]
141 | wc = wwwc[1]
142 | wl = wc - ww / 2
143 | wh = wc + ww / 2
144 | vol = vol.clip(wl, wh)
145 | return vol
146 |
147 |
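# A usage sketch (the window values are just a common soft-tissue example):
# windowed = windowlize_image(ct_volume, (400, 0))  # clips intensities to [-200, 200]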
148 | def clip_label(label, category):
149 | # Labels sometimes contain several classes: usually 0 is background and, starting from 1, larger values mark smaller structures
150 | # label is an ndarray; category is the class id that becomes foreground: values below it are set to 0, values at or above it to 1
151 | label[label < category] = 0
152 | label[label >= category] = 1
153 | return label
154 |
155 |
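# Sketch, assuming LiTS-style labels (0 background, 1 liver, 2 lesion):
# lesion_only = clip_label(label.copy(), 2)  # liver becomes background, the lesion becomes 1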
156 | def get_bbs(label):
157 | """Compute the bounding box of the foreground in a label.
158 |
159 | Parameters
160 | ----------
161 | label : ndarray
162 | Segmentation label.
163 |
164 | Returns
165 | -------
166 | tuple
167 | (bb_min, bb_max) arrays of shape (-1, 3): the lower and upper corners of each foreground bounding box.
168 |
169 | """
170 | # TODO: currently only a single lesion is handled; multiple foreground blobs still need to be supported
171 | one_indexes = np.array(np.where(label == 1))
172 | if one_indexes.size == 0:
173 | raise Exception("The label contains no foreground")
174 |
175 | bb_min = one_indexes.min(axis=1)
176 | bb_max = one_indexes.max(axis=1)
177 | bb_max = bb_max + 1
178 | return bb_min.reshape(-1, 3), bb_max.reshape(-1, 3)
179 |
180 |
181 | def crop_to_bbs(volume, bbs, padding=0.3):
182 | """Crop a scan down to its foreground region(s) plus surrounding context; multiple foreground boxes are supported.
183 | For each bounding box the kept region is expanded by `padding` (clamped to the volume extent) and the corresponding sub-volume is cut out.
184 | Parameters
185 | ----------
186 | volume : ndarray
187 | The scan.
188 | bbs : tuple
189 | (bb_min, bb_max) corner arrays, e.g. as returned by get_bbs.
190 | padding : float or list
191 | Fraction of the box size to extend on each side; a float applies to every dimension, a list sets it per dimension.
192 |
193 | Returns
194 | -------
195 | list
196 | One cropped ndarray per bounding box.
197 |
198 | """
199 |
200 | # Cut a volume into one or more bounding-box regions that contain foreground
201 | # padding sets how far the view extends on both sides in each dimension: a single number applies everywhere, a list sets it per dimension
202 | pd = padding
203 | if isinstance(padding, float):
204 | padding = []
205 | for i in range(volume.ndim):
206 | padding.append(pd)
207 | bb_min, bb_max = bbs  # (N, 3) corner arrays, e.g. as returned by get_bbs
208 | volumes = []
209 | bb_size = bb_max - bb_min
210 | bb_min = np.maximum(np.floor(bb_min - bb_size * padding), 0).astype("int32")
211 | bb_max = np.minimum(np.ceil(bb_max + bb_size * padding), volume.shape).astype("int32")
212 |
213 | for i in range(bb_min.shape[0]):
214 | volumes.append(
215 | volume[bb_min[i][0] : bb_max[i][0], bb_min[i][1] : bb_max[i][1], bb_min[i][2] : bb_max[i][2],]
216 | )
217 | return volumes
218 |
219 |
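# Sketch: crop a scan to the padded foreground bounding box(es) found by get_bbs
# (volume and label are hypothetical 3D arrays of the same shape):
# bbs = get_bbs(label)
# rois = crop_to_bbs(volume, bbs, padding=0.3)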
220 | def get_pad_len(volume_shape, pad_size, strict=True):
221 | # 1. For every dimension, work out how far the current length is from the target
222 | margin = []
223 | for x, y in zip(volume_shape, pad_size):
224 | # 1.1 A target of -1 means this dimension is skipped
225 | if y == -1:
226 | margin.append(0)
227 | continue
228 | # 1.2 If the current length already exceeds the target, raise or skip
229 | if x > y:
230 | if strict:
231 | raise Exception(
232 | "Invalid Crop Size", "data size {} should not exceed pad_size {}".format(volume_shape, pad_size),
233 | )
234 | else:
235 | margin.append(0)
236 | continue
237 | # 1.3 Normal case: the target is larger than the current length, take the difference
238 | margin.append(y - x)
239 | # 2. Work out how much padding each dimension needs on each side
240 | res = []
241 | for m, p, v in zip(margin, pad_size, volume_shape):
242 | if m == 0:
243 | # 2.1 Skip dimensions whose margin is 0
244 | res.append([0, 0])
245 | else:
246 | # 2.2 Split the margin into two halves
247 | half = math.floor(m / 2)
248 | res.append([half, p - v - half])
249 |
250 | return res
251 |
252 |
253 | # print(get_pad_len([3, 512, 300], [3, -1, 512]))
254 | # print(get_pad_len([512, 512, 3], [512, 512, 3]))
255 | # print(get_pad_len([8, 512, 300], [4, -1, 512]))
256 | # print(get_pad_len([8, 512, 300], [4, -1, 512], False))
257 |
258 |
259 | def pad_volume(volume, pad_size, pad_value=0, strice=True):
260 | """Center `volume` and pad it with `pad_value` up to `pad_size`.
261 | For every dimension there are three cases:
262 | 1. Normal: pad_size is larger than the actual size; the difference is split evenly between both sides of that dimension
263 | 2. Ignore: set pad_size to -1 for a dimension whose size should not change
264 | 3. Error: the volume is larger than pad_size; with strice=True this raises and aborts, with strice=False the dimension is ignored
265 |
266 | Parameters
267 | ----------
268 | volume : ndarray
269 | Data to pad.
270 | pad_size : int/list/tuple
271 | An int is expanded into a list of length volume.ndim (the same target size for every dimension); a list or tuple is used as-is
272 | pad_value : int/float
273 | Fill value used for the padding.
274 | strice : bool
275 | Whether an oversized dimension raises an error (True) or is silently ignored (False).
276 |
277 | Returns
278 | -------
279 | ndarray
280 | The padded data
281 |
282 | """
283 | if isinstance(pad_size, int):
284 | pad_size = [pad_size for i in range(volume.ndim)]
285 | margin = get_pad_len(volume.shape, pad_size, strice)
286 | # print(margin)
287 | volume = np.pad(volume, margin, "constant", constant_values=(pad_value))
288 | # print(volume.shape)
289 | return volume
290 |
291 |
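# Sketch: pad only the last axis of a hypothetical (75, 512, 300) volume up to 512,
# filling with air (-1024 HU) and leaving the other axes untouched:
# padded = pad_volume(volume, [-1, -1, 512], pad_value=-1024)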
292 | def filter_largest_bb(label, ratio=1.2):
293 | """Find the bounding box of the largest connected component and remove every false positive outside it. More conservative than keeping only the largest component itself.
294 |
295 | Parameters
296 | ----------
297 | label : ndarray
298 | Segmentation label with foreground set to 1.
299 | ratio: float
300 | Size of the final box relative to the largest component's box; e.g. 1.2 extends it by 0.1 on each side
301 |
302 | Returns
303 | -------
304 | ndarray
305 | The filtered label.
306 |
307 | """
308 | # Find the largest connected component
309 | vol, num = ndimage.label(label, np.ones([3 for _ in range(label.ndim)]))
310 | maxi = 0
311 | maxnum = 0
312 | for i in range(1, num + 1):
313 | count = vol[vol == i].size
314 | if count > maxnum:
315 | maxi = i
316 | maxnum = count
317 | maxind = np.where(vol == maxi)
318 | # Bounding box of the largest connected component
319 | ind_range = [[np.min(maxind[axis]), np.max(maxind[axis]) + 1] for axis in range(label.ndim)]
320 | ind_len = [r[1] - r[0] for r in ind_range]
321 | ext_ratio = (ratio - 1) / 2
322 | # Extend the box edges according to the ratio
323 | clip_range = [[r[0] - int(l * ext_ratio), r[1] + int(l * ext_ratio)] for r, l in zip(ind_range, ind_len)]
324 | for ind in range(len(clip_range)):
325 | if clip_range[ind][0] < 0:
326 | clip_range[ind][0] = 0
327 | if clip_range[ind][1] > label.shape[ind]:
328 | clip_range[ind][1] = label.shape[ind]
329 | r = clip_range
330 | # print(r)
331 | # Drop false positives outside the extended box
332 | new_lab = np.zeros(label.shape)
333 | # All foreground inside the box is kept
334 | new_lab[r[0][0] : r[0][1], r[1][0] : r[1][1], r[2][0] : r[2][1]] = label[
335 | r[0][0] : r[0][1], r[1][0] : r[1][1], r[2][0] : r[2][1]
336 | ]
337 | return new_lab
338 |
339 |
340 | # filter_largest_bb(
341 | # np.array([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 1, 1, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 1]]), 3
342 | # )
343 |
344 |
345 | def filter_largest_volume(label, ratio=1.2, mode="soft"):
346 | """Post-process a 3D label so that only its largest connected component is kept
347 |
348 | Parameters
349 | ----------
350 | label : ndarray
351 | 3D array: a segmentation label
352 | ratio : float
353 | How much of the surrounding box to keep (used by soft mode)
354 | mode : str
355 | "soft" / "hard"
356 | hard keeps only the largest connected component itself; soft keeps everything inside that component's bounding box
357 |
358 | Returns
359 | -------
360 | ndarray
361 | Label with only the largest connected component retained.
362 |
363 | """
364 | if mode == "soft":
365 | return filter_largest_bb(label, ratio)
366 | vol, num = ndimage.label(label, np.ones([3, 3, 3]))
367 | maxi = 0
368 | maxnum = 0
369 | for i in range(1, num + 1):
370 | count = vol[vol == i].size
371 | if count > maxnum:
372 | maxi = i
373 | maxnum = count
374 |
375 | vol[vol != maxi] = 0
376 | vol[vol == maxi] = 1
377 | label = vol
378 | return label
379 |
380 |
381 | # vol = np.array([[[0, 0, 1, 0], [0, 0, 0, 0], [0, 1, 1, 0]]])
382 | # vol = np.array([[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 1, 1, 1, 0], [0, 0, 1, 0, 0], [0, 0, 0, 0, 1]]])
383 | # print(filter_largest_volume(vol, 3, mode="soft"))
384 |
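# Sketch: "hard" mode keeps only the largest connected component itself, while the default
# "soft" mode keeps every foreground voxel inside that component's (ratio-expanded) bounding box:
# cleaned = filter_largest_volume(pred, ratio=1.2, mode="hard")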
385 |
386 | def save_nii(vol, lab, name="test"):
387 | import nibabel as nib
388 |
389 | vol = vol.astype("int16")
390 | volf = nib.Nifti1Image(vol, np.eye(4))
391 | labf = nib.Nifti1Image(lab, np.eye(4))
392 | nib.save(volf, "/home/aistudio/data/temp/{}-vol.nii".format(name))
393 | nib.save(labf, "/home/aistudio/data/temp/{}-lab.nii".format(name))
394 |
395 |
396 | def slice_count():
397 | """Count how many slices all the npz files contain in total.
398 | Note: relies on `cfg` from utils.config (see the commented-out import at the top of this file).
399 |
400 | Returns
401 | -------
402 | int
403 | Total number of slices across all npz files.
404 | """
405 | tot = 0
406 | npz_names = listdir(cfg.TRAIN.DATA_PATH)
407 | for npz_name in npz_names:
408 | data = np.load(os.path.join(cfg.TRAIN.DATA_PATH, npz_name))
409 | lab = data["labs"]
410 | tot += lab.shape[0]
411 | return tot
412 |
413 |
414 | # print(slice_count())
415 |
416 |
417 | def cal_direction(fname, scan, label):
418 | """Correct the patient orientation using pre-stored information.
419 |
420 | Parameters
421 | ----------
422 | fname : str
423 | Patient file name.
424 | scan : ndarray
425 | 3D scan.
426 | label : ndarray
427 | 3D label.
428 |
429 | Returns
430 | -------
431 | ndarray, ndarray
432 | The reoriented 3D arrays.
433 |
434 | """
435 | f = open("./config/directions.csv")
436 | dirs = f.readlines()
437 | f.close()
438 | # print("dirs: ", dirs)
439 | dirs = [x.rstrip("\n") for x in dirs]
440 | dirs = [x.split(",") for x in dirs]
441 | dic = {}
442 | for dir in dirs:
443 | dic[dir[0].strip()] = dir[1].strip()
444 | dirs = dic
445 | try:
446 | if dirs[fname] == "2":
447 | scan = np.rot90(scan, 3)
448 | label = np.rot90(label, 3)
449 | else:
450 | scan = np.rot90(scan, 1)
451 | label = np.rot90(label, 1)
452 | except KeyError:
453 | pass
454 | return scan, label
455 |
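# A sketch of the ./config/directions.csv format read above (file names and codes are hypothetical):
# each line is "<scan file name>, <direction code>", where code 2 triggers a 270-degree rotation
# and any other code a 90-degree rotation, for example:
# volume-0.nii, 2
# volume-1.nii, 1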
--------------------------------------------------------------------------------
/tool/infer/util.bk:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import math
4 | import concurrent.futures
5 | import multiprocessing
6 |
7 | import nibabel as nib
8 | import numpy as np
9 | import cv2
10 | import scipy.ndimage
11 | from tqdm import tqdm
12 | import matplotlib.pyplot as plt
13 | import trimesh
14 | from pypinyin import pinyin, Style
15 |
16 |
17 | logging.basicConfig(level=logging.NOTSET)
18 |
19 |
20 | def filter_largest_volume(label, ratio=1.2, mode="soft"):
21 | """Post-process a 3D label so that only its largest connected component is kept
22 |
23 | Parameters
24 | ----------
25 | label : ndarray
26 | 3D array: a segmentation label
27 | ratio : float
28 | How much of the surrounding box to keep (used by soft mode)
29 | mode : str
30 | "soft" / "hard"
31 | hard keeps only the largest connected component itself; soft keeps everything inside that component's bounding box
32 |
33 | Returns
34 | -------
35 | ndarray
36 | Label with only the largest connected component retained.
37 |
38 | """
39 | if mode == "soft":
40 | return filter_largest_bb(label, ratio)  # NOTE: filter_largest_bb is not defined in this backup file; see medseg/utils/util.py
41 | vol, num = scipy.ndimage.label(label, np.ones([3, 3, 3]))
42 | maxi = 0
43 | maxnum = 0
44 | for i in range(1, num + 1):
45 | count = vol[vol == i].size
46 | if count > maxnum:
47 | maxi = i
48 | maxnum = count
49 |
50 | vol[vol != maxi] = 0
51 | vol[vol == maxi] = 1
52 | label = vol
53 | return label
54 |
55 |
56 | labels = []
57 |
58 | # TODO: add an option to clip the label to a single foreground class
59 | def nii2png(scan_path, scan_img_dir, label_path=None, label_img_dir=None, rot=0, wwwc=(400, 0), thresh=None):
60 | """Convert a nii scan into png slices.
61 | The scan and its label are processed together; windowing, rotation and skipping slices without foreground are supported
62 |
63 | Parameters
64 | ----------
65 | scan_path : str
66 | Path of the scan nii.
67 | scan_img_dir : str
68 | Directory the scan pngs are written to.
69 | label_path : str
70 | Path of the label nii.
71 | label_img_dir : str
72 | Directory the label pngs are written to.
73 | rot : int
74 | Number of 90-degree rotations; applied to the label as well when one is given.
75 | wwwc : list/tuple
76 | Window width and window center used for windowing.
77 | thresh : int
78 | A slice is written only if its label contains at least this many foreground pixels; otherwise it is skipped.
79 |
80 | Returns
81 | -------
82 | None
83 | The png files are written to the given directories.
84 |
85 | """
86 | scanf = nib.load(scan_path)
87 | scan_data = scanf.get_fdata()
88 | name = os.path.basename(scan_path)
89 | # print(name)
90 | if scan_data.shape[0] == 1024:
91 | print("[WARNING]", name, "is 1024")
92 | # vol = scipy.ndimage.interpolation.zoom(vol, [0.5, 0.5, 1], order=1 if islabel else 3)
93 |
94 | if label_path:
95 | labelf = nib.load(label_path)
96 | label_data = labelf.get_fdata()
97 | if label_data.shape != scan_data.shape:
98 | print("[ERROR] Scan and label dimension mismatch", name, scan_data.shape, label_data.shape)
99 |
100 | for _ in range(rot):
101 | scan_data = np.rot90(scan_data)
102 | if label_path:
103 | label_data = np.rot90(label_data)
104 |
105 | if not os.path.exists(scan_img_dir):
106 | os.makedirs(scan_img_dir)
107 | if label_path and not os.path.exists(label_img_dir):
108 | os.makedirs(label_img_dir)
109 |
110 | wl, wh = (wwwc[1] - wwwc[0] / 2, wwwc[1] + wwwc[0] / 2)
111 | scan_data = scan_data.astype("float32").clip(wl, wh)
112 | scan_data = (scan_data - wl) / (wh - wl) * 256
113 | scan_data = scan_data.astype("uint8")
114 |
115 | with concurrent.futures.ThreadPoolExecutor(max_workers=multiprocessing.cpu_count()) as executor:
116 | for ind in range(1, scan_data.shape[2] - 1):
117 | if label_path:
118 | label_slice = label_data[:, :, ind]
119 | # If a thresh is set and this slice does not reach it, skip it entirely: neither the label nor the scan slice is saved
120 | if thresh and label_slice.sum() < thresh:
121 | continue
122 | file_path = os.path.join(label_img_dir, "{}-{}.png".format(name.rstrip(".gz").rstrip(".nii"), ind))
123 | executor.submit(save_png, label_slice, file_path)
124 |
125 | scan_slice = scan_data[:, :, ind - 1 : ind + 2]
126 | file_path = os.path.join(scan_img_dir, "{}-{}.png".format(name.rstrip(".gz").rstrip(".nii"), ind))
127 | executor.submit(save_png, scan_slice, file_path)
128 | if label_path:
129 | label_slice = label_data[:, :, ind]
130 | file_path = os.path.join(label_img_dir, "{}-{}.png".format(name.rstrip(".gz").rstrip(".nii"), ind))
131 | executor.submit(save_png, label_slice, file_path)
132 |
133 | # input("here")
134 |
135 |
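# A usage sketch (hypothetical paths): convert one scan/label pair to png slices with a
# soft-tissue window, skipping slices whose label has fewer than 100 foreground pixels:
# nii2png("vol/volume-0.nii", "png/scan", "lab/segmentation-0.nii", "png/label", rot=1, wwwc=(400, 0), thresh=100)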
136 | def save_png(slice, file_path):
137 | cv2.imwrite(file_path, slice)
138 |
139 |
140 | def nii2png_single(nii_path, png_folder, rot=1, wwwl=(256, 0), islabel=False, thresh=0):
141 | """Convert one nii scan into a series of png images, with some simple checks.
142 | # TODO: check that there is only one connected component
143 | # TODO: check that there is only one foreground class
144 |
145 | Parameters
146 | ----------
147 | nii_path : str
148 | Path of the nii scan file.
149 | png_folder : str
150 | Folder the images are written to.
151 | rot : int
152 | Number of 90-degree rotations applied to straighten the orientation.
153 | wwwl : (int, int)
154 | Window width and window level; only used when islabel is False.
155 | """
156 | volf = nib.load(nii_path)
157 | nii_name = os.path.basename(nii_path)
158 | vol = volf.get_fdata()
159 | if vol.shape[0] == 1024:
160 | vol = scipy.ndimage.interpolation.zoom(vol, [0.5, 0.5, 1], order=1 if islabel else 3)
161 | for _ in range(rot):
162 | vol = np.rot90(vol)
163 | if not islabel:
164 | # vol = vol + np.random.randn() * 50 - 10
165 | wl, wh = (wwwl[1] - wwwl[0] / 2, wwwl[1] + wwwl[0] / 2)
166 | vol = vol.astype("float32").clip(wl, wh)
167 | vol = (vol - wl) / (wh - wl) * 256
168 | vol = vol.astype("uint8")
169 | if not os.path.exists(png_folder):
170 | os.makedirs(png_folder)
171 |
172 | for ind in range(1, vol.shape[2] - 1):
173 | if islabel:
174 | slice = vol[:, :, ind]
175 | else:
176 | slice = vol[:, :, ind - 1 : ind + 2]
177 | if islabel:
178 | sum = np.sum(slice)
179 | print(sum, thresh)
180 | if sum <= thresh:
181 | continue
182 | slice[slice == 2] = 1  # merge label class 2 into the single foreground class
183 |
184 | file_path = os.path.join(png_folder, "{}-{}.png".format(nii_name.rstrip(".gz").rstrip(".nii"), ind))
185 | # if not islabel:
186 | # if "{}-{}.png".format(nii_name.rstrip(".gz").rstrip(".nii"), ind) not in labels:
187 | # print("{}-{}.png".format(nii_name.rstrip(".gz").rstrip(".nii"), ind))
188 | # continue
189 |
190 | cv2.imwrite(file_path, slice)
191 |
192 |
193 | def nii2png_folder(nii_folder, png_folder, rot=1, wwwl=(400, 0), subfolder=False, islabel=False, thresh=0):
194 | """Convert every nii file in a folder into pngs.
195 |
196 | Parameters
197 | ----------
198 | nii_folder : str
199 | Folder containing all the nii files.
200 | png_folder : str
201 | Folder the pngs are written to.
202 | rot : int
203 | Number of rotations; note that the loop below overrides this per file based on the file-name length.
204 | wwwl : (int, int)
205 | Window width and window level.
206 | subfolder : bool
207 | Whether to create a separate sub-folder for each nii.
208 |
209 | """
210 | nii_names = os.listdir(nii_folder)
211 | for nii_name in tqdm(nii_names):
212 | # use a per-file output folder without mutating png_folder across iterations
213 | out_folder = os.path.join(png_folder, nii_name) if subfolder else png_folder
214 | if len(nii_name) > 12:
215 | nii2png_single(os.path.join(nii_folder, nii_name), out_folder, 1, wwwl, islabel, thresh=thresh)
216 | else:
217 | nii2png_single(os.path.join(nii_folder, nii_name), out_folder, 3, wwwl, islabel, thresh=thresh)
218 | # print(len(nii_name))
219 | # print(nii_name)
220 | # input("here")
221 | # os.system("rm /home/lin/Desktop/data/aorta/dataset/scan/*")
222 |
223 |
224 | def check_nii_match(scan_dir, label_dir, remove=False):
225 | # TODO: use the slice spacing and size to compute the in-plane edge length, and raise if it is too large or too small
226 | """Check whether the scans and labels in two directories match up.
227 |
228 | Parameters
229 | ----------
230 | scan_dir : str
231 | Directory containing the scans.
232 | label_dir : str
233 | Directory containing the labels.
234 |
235 | Returns
236 | -------
237 | bool
238 | True if everything matches, otherwise False; the details are logged to stdio.
239 |
240 | """
241 | pass_check = True
242 | scans = os.listdir(scan_dir)
243 | labels = os.listdir(label_dir)
244 | scans = [n for n in scans if n.endswith("nii") or n.endswith("gz")]
245 | labels = [n for n in labels if n.endswith("nii") or n.endswith("gz")]
246 |
247 | scan_names = [n.rstrip(".gz").rstrip(".nii") for n in scans]
248 | label_names = [n.rstrip(".gz").rstrip(".nii") for n in labels]
249 |
250 | if len(scans) != len(labels):
251 | logging.error("Number of scans ({}) and labels ({}) don't match".format(len(scans), len(labels)))
252 | pass_check = False
253 | else:
254 | logging.info("Pass file number check")
255 |
256 | names_match = True
257 | for ind, s in enumerate(scan_names):
258 | if s not in label_names:
259 | logging.error("Scan {} doesn't have a corresponding label".format(s))
260 | names_match = False
261 | if remove:
262 | print("removing {}".format(s))
263 | os.remove(os.path.join(scan_dir, scans[ind]))
264 |
265 | for l in label_names:
266 | if l not in scan_names:
267 | logging.error("Label {} doesn't have a corresponding scan".format(l))
268 | names_match = False
269 |
270 | if names_match:
271 | logging.info("Pass file names check")
272 | else:
273 | pass_check = False
274 |
275 | scans = os.listdir(scan_dir)
276 | labels = os.listdir(label_dir)
277 | scans = [n for n in scans if n.endswith("nii") or n.endswith("gz")]
278 | labels = [n for n in labels if n.endswith("nii") or n.endswith("gz")]
279 |
280 | scan_names = [n.rstrip(".gz").rstrip(".nii") for n in scans]
281 | label_names = [n.rstrip(".gz").rstrip(".nii") for n in labels]
282 |
283 | for scan_name in scans:
284 | scanf = nib.load(os.path.join(scan_dir, scan_name))
285 | labelf = nib.load(os.path.join(label_dir, scan_name))
286 | if (scanf.affine == np.eye(4)).all():
287 | logging.warning("Scan {} has an identity (np.eye(4)) affine, check the header".format(scan_name))
288 | if (labelf.affine == np.eye(4)).all():
289 | logging.warning("Label {} has an identity (np.eye(4)) affine, check the header".format(scan_name))
290 | if not (labelf.header["dim"] == scanf.header["dim"]).all():
291 | logging.error(
292 | "Label and scan dimension mismatch for {}, scan is {}, label is {}".format(
293 | scan_name, scanf.header["dim"][1:4], labelf.header["dim"][1:4]
294 | )
295 | )
296 | pass_check = False
297 | return pass_check
298 |
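# Sketch: sanity-check a dataset before slicing it (hypothetical directories):
# assert check_nii_match("dataset/scan", "dataset/label")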
299 |
300 | def inspect_pair(scan_path, label_path):
301 |
302 | # If the input is in nii format
303 | # TODO: flesh this out
304 | if scan_path.endswith("nii") or scan_path.endswith("gz"):
305 | pass
306 | # TODO: show the scan/label pair in a single figure
307 | if scan_path.endswith("png"):
308 | scan = cv2.imread(scan_path)
309 | label = cv2.imread(label_path)
310 | plt.imshow(scan)
311 | plt.show()
312 | label = label * 255
313 | plt.imshow(label)
314 | plt.show()
315 |
316 |
317 | def is_right(a, b, c):
318 | a = np.array((b[0] - a[0], b[1] - a[1]))
319 | b = np.array((c[0] - b[0], c[1] - b[1]))
320 | res = np.cross(a, b)
321 | if res >= 0:
322 | return True
323 | return False
324 |
325 |
326 | class Polygon:
327 | points = []
328 | center = []
329 | height = 0
330 | base = []
331 | epsilon = 1e-6
332 |
333 | def __init__(self, p):
334 | if len(p) == 0:
335 | raise RuntimeError("No points in polygon")
336 | self.points = p
337 | if len(self.points[0]) == 2:
338 | points = []
339 | for p in self.points:
340 | points.append(list(p[0]))
341 | unique = []
342 | for p in points:
343 | if p not in unique:
344 | unique.append(p)
345 | self.points = unique
346 | self.cal_rep()
347 | self.ang_sort()
348 | return
349 |
350 | for ind in range(len(self.points)):
351 | self.points[ind] = list(self.points[ind])
352 |
353 | self.cal_rep()
354 | self.ang_sort()
355 | try:
356 | self.height = p[0][2]
357 | except:
358 | print(p)
359 |
360 | def cal_rep(self):
361 | """Compute the center point and the base point.
362 | The two steps must happen in this order: the center uses all points, then the lowest point is split off as the base
363 | Returns
364 | -------
365 | None
366 | Sets self.center and self.base and removes the base from self.points.
367 |
368 | """
369 | # print("___", np.min(self.points, axis=0))
370 | # print("---", np.max(self.points, axis=0))
371 | self.center = list((np.min(self.points, axis=0) + np.max(self.points, axis=0)) / 2)
372 | # print(self.center)
373 | # input("here")
374 | self.points.sort()
375 | self.base = self.points[0]
376 | del self.points[0]
377 |
378 | def ang_sort(self):
379 | """Sort all points by polar angle around the base point.
380 |
381 | Returns
382 | -------
383 | None
384 | self.points is sorted in place.
385 |
386 | """
387 |
388 | def cmp(a):
389 | return math.atan((a[1] - self.base[1]) / (a[0] - self.base[0] + self.epsilon))
390 |
391 | self.points.sort(key=cmp, reverse=True)
392 |
393 | def cal_size(self):
394 | """Compute the area of the polygon formed by the stored points.
395 | Note that no polar-angle sort is done here; the points must already be in clockwise or counter-clockwise order
396 |
397 | Parameters
398 | ----------
399 | (none)
400 | Uses self.base and self.points; the area is accumulated from triangles fanned out from the base.
401 |
402 | Returns
403 | -------
404 | float
405 | The polygon area.
406 |
407 | """
408 | points = self.points
409 | tot = 0
410 | p = self.base
411 | for ind in range(0, len(points) - 1):
412 | a = [t2 - t1 for t1, t2 in zip(p, points[ind])]
413 | b = [t2 - t1 for t1, t2 in zip(p, points[ind + 1])]
414 | a = np.array(a)
415 | b = np.array(b)
416 | # b = b.reshape(b.size, 1)
417 | # print(a, b)
418 | res = np.cross(a, b)
419 | tot += (res[0] ** 2 + res[1] ** 2 + res[2] ** 2) ** (1 / 2) / 2
420 | return tot
421 |
422 | def to_2d(self):
423 | def rot_to_horizontal(p):
424 | """Rotate a point onto the horizontal plane.
425 |
426 | Parameters
427 | ----------
428 | p : list
429 | 3D point [x, y, z], expressed relative to the polygon base.
430 |
431 | Returns
432 | -------
433 | list
434 | The rotated point as 2D [x, y].
435 |
436 | """
437 | # TODO: there is a div0 error when the vector is almost parallel to the y axis
438 | epsilon = 1e-6
439 | if p[0] == 0 and p[1] == 0 and p[2] == 0:
440 | return [0, 0]
441 | x = p[0]
442 | y = p[1]
443 | z = p[2]
444 | angle = math.atan(z / ((x ** 2 + y ** 2) ** (1 / 2) + epsilon))
445 | unit = (
446 | y / ((x ** 2 + y ** 2) ** (1 / 2) + epsilon),
447 | -x / ((x ** 2 + y ** 2) ** (1 / 2) + epsilon),
448 | 0,
449 | )
450 | matrix = trimesh.transformations.rotation_matrix(-angle, unit, (0, 0, 0))
451 | p.append(1)
452 | p = np.array(p).reshape([1, 4])
453 | p = p.transpose()
454 | res = np.dot(matrix, p)
455 | return [float(res[0]), float(res[1])]
456 |
457 | self.points = [[p[0] - self.base[0], p[1] - self.base[1], p[2] - self.base[2]] for p in self.points]
458 | # print("+_+", self.center)
459 | self.center = [b - a for a, b in zip(self.base, self.center)]
460 | self.center = rot_to_horizontal(self.center)
461 | # print("_+_", self.center)
462 | self.points = [rot_to_horizontal(p) for p in self.points]
463 | self.base = [0, 0]
464 |
465 | def cal_diameter(self, ang_range=[0, np.pi], split=30, pixdim=1, step=1):
466 | """Compute the diameter of this polygon.
467 | 1. Rotate all points into one plane
468 | 2. Split the half circle into `split` directions; through the center, sweep a pair of parallel lines upwards and downwards
469 | 3. Each line stops moving as soon as every polygon vertex lies on one side of it
470 | 4. The distance between the two lines is taken as the diameter in that direction
471 |
472 | Parameters
473 | ----------
474 | ang_range : list
475 | Angular range [start, end] in radians that is swept.
476 | split : int
477 | Number of directions the range is split into.
478 | pixdim : float
479 | Pixel spacing used to convert the distance into physical units.
480 | step : float
481 | Step size the sweep lines move by per iteration.
482 |
483 | Returns
484 | -------
485 | list
486 | [height, d_0, d_1, ...]: the polygon height followed by one diameter per direction.
487 |
488 | """
489 | self.to_2d()
490 | # self.plot_2d()
491 | center = self.center
492 | diameters = [self.height]
493 | for alpha in np.arange(ang_range[0], ang_range[1], (ang_range[1] - ang_range[0]) / split):
494 | # TODO: handle the case where this sweep line is vertical
495 | if alpha == np.pi / 2:
496 | continue
497 | k = math.tan(alpha)
498 | # print(alpha)
499 |
500 | d = 0
501 | d1 = 0
502 | d2 = 0
503 | while True:
504 | # print("+:", d)
505 | y0 = center[1] - d + k * (0 - center[0])
506 | y1 = center[1] - d + k * (1 - center[0])
507 | dir = is_right((0, y0), (1, y1), self.points[0])
508 | same_dir = True
509 | for p in self.points:
510 | tmp = is_right((0, y0), (1, y1), p)
511 | if tmp != dir:
512 | same_dir = False
513 | break
514 | if same_dir:
515 | d1 = d
516 | break
517 | d += step
518 |
519 | d = 0
520 | while True:
521 | # print("-:", d)
522 | y0 = center[1] - d + k * (0 - center[0])
523 | y1 = center[1] - d + k * (1 - center[0])
524 | dir = is_right((0, y0), (1, y1), self.points[0])
525 | same_dir = True
526 | for p in self.points:
527 | tmp = is_right((0, y0), (1, y1), p)
528 | if tmp != dir:
529 | same_dir = False
530 | break
531 | if same_dir:
532 | d2 = d
533 | break
534 | d += step
535 | diameters.append((d1 + d2) * np.abs(np.cos(alpha)) * pixdim)
536 |
537 | return diameters
538 |
539 | def plot_2d(self):
540 | plt.scatter([p[0] for p in self.points], [p[1] for p in self.points])
541 | # plt.plot([self.center[0]], [self.center[1]])
542 | plt.show()
543 |
544 |
545 | """
546 | y-y0=k(x-x0)+b
547 | x=0, y=y0+k(x-x0)+b
548 | """
549 | # po = Polygon([[0, 0, 0], [0, 1, 0], [1, 1, 0], [1, 0, 0]])
550 | # print(po.cal_diameter(split=2, step=0.1))
551 |
552 |
553 | def dist(a, b):
554 | res = 0
555 | for p, q in zip(a, b):
556 | res += (p - q) ** 2
557 | res = res ** (1 / 2)
558 | return res
559 |
560 |
561 | # print(dist([1, 0, 1], [0, 0, 3]))
562 |
563 |
564 | def filter_polygon(points, center, thresh=1):
565 | """Given a pile of points that may form several polygons, return the polygon whose center is closest to `center`.
566 |
567 | Parameters
568 | ----------
569 | points : list
570 | Points to group; nearby points (within `thresh`) end up in the same polygon.
571 | center : list or str
572 | Reference center point, or "all" to return every grouped polygon.
573 |
574 | Returns
575 | -------
576 | list
577 | """
578 | # TODO: the points should be sorted by polar angle
579 |
580 | curr_point = points[-1]
581 | polygons = [[]]
582 | p_ind = 0
583 | while len(points) != 0:
584 | found = False
585 | for ind in range(len(points) - 1, -1, -1):
586 | if dist(curr_point, points[ind]) < thresh:
587 | curr_point = points[ind]
588 | del points[ind]
589 | polygons[p_ind].append(curr_point)
590 | found = True
591 | break
592 | if not found:
593 | polygons.append([])
594 | p_ind += 1
595 | curr_point = points[-1]
596 | if center == "all":
597 | return polygons
598 | centers = []
599 | for polygon in polygons:
600 | ps = np.array(polygon)
601 | centers.append(np.mean(ps, axis=0))
602 | dists = [dist(p, center) for p in centers]
603 | min_ind = np.argmin(dists)
604 | print(min_ind)
605 | return list(polygons[min_ind])
606 |
607 |
608 | # print(filter_polygon([[0, 0, 0], [1, 1, 1], [4, 4, 4], [5, 5, 5]], [0, 0, 0]))
609 |
610 |
611 | def sort_line(polygons):
612 | """Given a list of polygons, pick a starting polygon and order the rest by walking to the nearest remaining center each step.
613 |
614 | Parameters
615 | ----------
616 | polygons : list
617 | Polygon objects to order; the walk starts from the polygon with the lowest center.
618 |
619 | Returns
620 | -------
621 | list
622 | The same polygons reordered so that consecutive entries are spatial neighbours.
623 |
624 |
625 |
626 | """
627 | # start from the lowest point
628 | polygons.sort(key=lambda a: [a.center[2], a.center[1], a.center[0]], reverse=True)
629 | curr_point = polygons[-1].center
630 | ordered = []
631 | while len(polygons) != 0:
632 | # find the nearest remaining polygon
633 | min_dist = dist(curr_point, polygons[-1].center)
634 | min_ind = len(polygons) - 1
635 | for ind in range(len(polygons)):
636 | curr_dist = dist(polygons[ind].center, curr_point)
637 | if curr_dist < min_dist:
638 | min_dist = curr_dist
639 | min_ind = ind
640 | curr_point = polygons[min_ind].center
641 | ordered.append(polygons[min_ind])
642 | polygons.pop(min_ind)
643 | return ordered
644 |
645 |
646 | # print(sort_line([[0, 0, 0], [0, 0, 4], [0, 0, 3], [0, 0, 2], [0, 0, 6], [0, 1, 5.1], [0, 1, 6], [0, 0, 5]]))
647 |
648 |
649 | def to_pinyin(name, nonum=False):
650 | new_name = ""
651 | for ch in name:
652 | if u"\u4e00" <= ch <= u"\u9fff":
653 | new_name += pinyin(ch, style=Style.NORMAL)[0][0]
654 | else:
655 | # if nonum and ("0" <= ch <= "9" or ch == "_"):
656 | # continue
657 | new_name += ch
658 | return new_name
659 |
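# Sketch: to_pinyin("肝脏_01.nii") would give something like "ganzang_01.nii";
# CJK characters are converted to pinyin while everything else is kept unchanged.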
--------------------------------------------------------------------------------