├── .idea
│   ├── .gitignore
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── preprocess.iml
│   └── vcs.xml
├── CBAM.py
├── FreeSurfer+FSL.py
├── FreeSurfer.py
├── README.md
├── ants.py
├── antspy.py
├── crop_resize.py
├── data_augmentation.py
├── dcm2nii.py
├── image_combine_example.py
├── img2nii.py
├── jacobian.py
├── phi
│   ├── .idea
│   │   ├── NII.iml
│   │   ├── deployment.xml
│   │   ├── inspectionProfiles
│   │   │   └── profiles_settings.xml
│   │   ├── misc.xml
│   │   ├── modules.xml
│   │   └── workspace.xml
│   ├── f_img.jpg
│   ├── img.jpg
│   ├── img1.jpg
│   ├── img2.jpg
│   ├── m_img.jpg
│   ├── my_show_grid.py
│   └── show_grid.py
├── resize_img.py
└── torch_datagenerators.py
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Datasource local storage ignored files
5 | /../../../../../../:\Data\1.Work\02.Code\preprocess\.idea/dataSources/
6 | /dataSources.local.xml
7 | # Editor-based HTTP Client requests
8 | /httpRequests/
9 |
--------------------------------------------------------------------------------
/CBAM.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class CBAM_Module(nn.Module):
6 | def __init__(self, dim, in_channels, ratio, kernel_size):
7 | super(CBAM_Module, self).__init__()
8 | self.avg_pool = getattr(nn, "AdaptiveAvgPool{0}d".format(dim))(1)
9 | self.max_pool = getattr(nn, "AdaptiveMaxPool{0}d".format(dim))(1)
10 | conv_fn = getattr(nn, "Conv{0}d".format(dim))
11 | self.fc1 = conv_fn(in_channels, in_channels // ratio, kernel_size=1, padding=0)
12 | self.relu = nn.ReLU()
13 | self.fc2 = conv_fn(in_channels // ratio, in_channels, kernel_size=1, padding=0)
14 | self.sigmoid = nn.Sigmoid()
15 | assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
16 | padding = 3 if kernel_size == 7 else 1
17 | self.conv = conv_fn(2, 1, kernel_size=kernel_size, stride=1, padding=padding)
18 |
19 | def forward(self, x):
20 | print("CBAM")
21 | # Channel attention module:(Mc(f) = σ(MLP(AvgPool(f)) + MLP(MaxPool(f))))
22 | module_input = x
23 | avg = self.fc2(self.relu(self.fc1(self.avg_pool(x))))
24 | mx = self.fc2(self.relu(self.fc1(self.max_pool(x))))
25 | x = self.sigmoid(avg + mx)
26 | x = module_input * x
27 | # Spatial attention module: Ms(f) = σ(f7×7([AvgPool(f); MaxPool(f)]))
28 | module_input = x
29 | avg = torch.mean(x, dim=1, keepdim=True)
30 | mx, _ = torch.max(x, dim=1, keepdim=True)
31 | x = torch.cat((avg, mx), dim=1)
32 | x = self.sigmoid(self.conv(x))
33 | x = module_input * x
34 | return x
--------------------------------------------------------------------------------
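A minimal usage sketch for CBAM.py (not part of the original file), assuming the CBAM_Module class above is importable; the tensor shape and hyper-parameters below are illustrative values only.

    import torch
    from CBAM import CBAM_Module

    # hypothetical 3D feature map of shape (batch, channels, D, H, W)
    cbam = CBAM_Module(dim=3, in_channels=32, ratio=8, kernel_size=7)
    feat = torch.randn(2, 32, 16, 16, 16)
    out = cbam(feat)
    print(out.shape)  # attention does not change the shape: (2, 32, 16, 16, 16)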
/FreeSurfer+FSL.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 |
4 |
5 | # Skull-strip the images with FreeSurfer
6 | print("FreeSurfer start......\n")
7 | # Directory that contains the images
8 | #------------------------ change the image path here ------------------------#
9 | path = "/home/syzhou/zuzhiang/Dataset/MGH10/Heads"
10 | # List the .img files in the directory; "*.img" matches every file ending in .img
11 | #------------------------ change the image suffix here ----------------------#
12 | files = glob.glob(os.path.join(path,"*.img"))
13 | # Output directory
14 | #------------------------ change the output path here -----------------------#
15 | out_path = "/home/syzhou/zuzhiang/MIP/FSL_img/MGH10_results"
16 | print("number: ",len(files))
17 | # FreeSurfer environment setup commands
18 | a = "export FREESURFER_HOME=/home/syzhou/zuzhiang/freesurfer;"
19 | b = "source $FREESURFER_HOME/SetUpFreeSurfer.sh;"
20 | # Directory that holds the data
21 | c = "export SUBJECTS_DIR="+path+";"
22 |
23 | for file in files:
24 | # Split the file path from the file name
25 | filename = os.path.split(file)[1] # separate the directory part from the file name
26 | filename = filename.split(".")[0] # drop every extension
27 | # recon-all is the skull-stripping command
28 | # mri_convert converts the format from mgz to nii.gz, just to make viewing easier
29 | filename=filename[:] # adjust this depending on the extension; only the bare file name should be kept
30 | # Path of the current output file, saved in .nii.gz format
31 | cur_out_path=os.path.join(out_path,filename+".nii.gz")
32 | print("file name: ",file)
33 | cmd = a + b + c + "mri_watershed "+file+" "+ cur_out_path
34 | #print(cmd,"\n")
35 | os.system(cmd)
36 |
37 |
38 | # Affinely align the images and their corresponding labels with FSL
39 | print("FSL start......\n")
40 | # Path of the fixed image
41 | #--------------- change the name of the skull-stripped fixed image ---------------#
42 | f_path= os.path.join(out_path,"g1.nii.gz")
43 | # Path of the moving images
44 | m_path=out_path
45 | # Path of the labels
46 | #----------------------- change the label path -----------------------#
47 | label_path="/home/syzhou/zuzhiang/Dataset/MGH10/Atlases"
48 | files=glob.glob(os.path.join(m_path,"*.nii.gz"))
49 | print("number: ",len(files))
50 | for file in files:
51 | print("file: ",file)
52 | # Find the label that corresponds to the image name
53 | filename=os.path.split(file)[1]
54 | filename = filename.split(".")[0] # drop every extension
55 | #--------------------- change the label suffix here --------------------#
56 | label=os.path.join(label_path,filename+".img")
57 | # Names of the output image / transform matrix / label
58 | out_img=os.path.join(out_path,filename+"_img.nii.gz")
59 | out_mat=os.path.join(out_path,filename+".mat")
60 | out_label=os.path.join(out_path,filename+"_label.nii.gz")
61 | # If the current file is the fixed image itself, only convert its label to .nii.gz
62 | if f_path==file:
63 | convert="mri_convert " + label +" " + out_label
64 | os.system(convert)
65 | print("continue.........")
66 | continue
67 | # Register the moving image to the fixed image
68 | flirt_img="flirt -in "+file+ " -ref "+f_path+" -out "+out_img+" -omat "+out_mat+ " -dof 12"
69 | # Apply the affine matrix from the previous step to the label of the image
70 | flirt_label="flirt -in "+label+" -ref "+f_path+" -out "+out_label+" -init "+out_mat+" -applyxfm -interp nearestneighbour"
71 | #print(flirt_img,"\n")
72 | #print(flirt_label,"\n")
73 | os.system(flirt_img)
74 | os.system(flirt_label)
75 |
76 | print("\n\nEnd")
--------------------------------------------------------------------------------
/FreeSurfer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 |
4 | path = r"/home/syzhou/zuzhiang/Dataset/MGH10/Heads"
5 | # List the image files in the directory
6 | images = glob.glob(os.path.join(path,"*.img"))
7 | # FreeSurfer environment setup commands
8 | a = "export FREESURFER_HOME=/home/syzhou/zuzhiang/freesurfer;"
9 | b = "source $FREESURFER_HOME/SetUpFreeSurfer.sh;"
10 | # Directory that holds the data
11 | c = "export SUBJECTS_DIR="+path+";"
12 |
13 | #images=['/home/syzhou/zuzhiang/Dataset/MGH10/Heads/1127.img']
14 | for image in images:
15 | # Split the file path from the file name
16 | filename = os.path.split(image)[1] # separate the directory part from the file name
17 | filename = os.path.splitext(filename)[0] # split off the extension; for .nii.gz the extension is treated as .gz
18 | # FreeSurfer environment setup, skull stripping, unaligned mgz to nii, affine alignment, aligned mgz to nii.gz
19 | # recon-all is the skull-stripping command
20 | # mri_convert converts the format from mgz to nii.gz, just to make viewing easier
21 | # --apply_transform: applies the affine alignment
22 | # Format conversion
23 | filename=filename[:] # adjust this depending on the extension; only the bare file name should be kept
24 | cur_path=os.path.join(path,filename)
25 | print("file name: ",cur_path)
26 | cmd = a + b + c \
27 | + "recon-all -parallel -i " + image + " -autorecon1 -subjid " + cur_path + "&&" \
28 | + "mri_convert " + cur_path + "/mri/brainmask.mgz " +cur_path + "/mri/"+filename+".nii.gz;"\
29 | + "mri_convert " + cur_path + "/mri/brainmask.mgz --apply_transform " + cur_path + "/mri/transforms/talairach.xfm -o " + cur_path + "/mri/brainmask_affine.mgz&&" \
30 | + "mri_convert " + cur_path + "/mri/brainmask_affine.mgz " + cur_path + "/mri/"+filename+"_affine.nii.gz;"
31 | #print("cmd:\n",cmd)
32 | os.system(cmd)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | The files here are fairly miscellaneous and I have not tidied them up one by one. Much of the code is experimental and may contain problems, and it has been a while, so I have forgotten most of the details. If something does not work, please read the code and sort it out yourself. Below is a rough description of what each file does:
2 |
3 | - `phi`:
4 |   - `my_show_grid.py`: first generates a regular grid image, then uses an STN to warp it according to a deformation field read from disk
5 |   - `show_grid.py`: reads a deformation field from a .nii file and renders a visualization of it
6 | - `show_nii`:
7 |   - `show_nii.py`: displays a 3D .nii image
8 |   - `show_nii3.py`: displays a 3D .nii organ image together with its organ segmentation mask and lesion segmentation mask
9 | - `ants.py`: registers images with SyN from the ANTs package
10 | - `antspy.py`: registers an image and its corresponding label together using ANTsPy, the Python version of ANTs
11 | - `CBAM.py`: implementation of the CBAM attention module
12 | - `crop_resize.py`: first finds the smallest bounding box that contains the brain region, then the box size is worked out by hand (each dimension should be a multiple of 16), and finally the commented-out part does the resizing
13 | - `data_augmentation.py`: data augmentation code, including B-spline deformation, flipping, translation, scaling, rotation, intensity equalization and so on; the B-spline deformation seems to have problems
14 | - `dcm2nii.py`: converts 3D images in DICOM format to .nii
15 | - `FreeSurfer+FSL.py`: first skull-strips brain images with FreeSurfer, then affinely aligns the images and their labels with FSL
16 | - `FreeSurfer.py`: skull-strips brain images with FreeSurfer
17 | - `image_combine_example.py`: provided by someone else; never used
18 | - `img2nii.py`: converts 3D images in .img format to .nii
19 | - `jacobian.py`: computes the Jacobian determinant of a deformation field
20 | - `resize_img.py`: resizes images
21 | - `torch_datagenerators.py`: a PyTorch-based data generator
22 |
--------------------------------------------------------------------------------
/ants.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 |
4 | path="/home/syzhou/zuzhiang/Dataset/LPBA40"
5 | out_path="/home/syzhou/zuzhiang/MIP/ANTs"
6 | f_name="/home/syzhou/zuzhiang/Dataset/LPBA40/1.nii.gz"
7 | for i in range(2,41):
8 | m_name=os.path.join(path,str(i)+".nii.gz")
9 | out_name=str(i)+"m"
10 | cmd= "antsRegistrationSyN.sh -d 3 -f " + f_name + " -m " + m_name + " -o " + out_name
11 | print("cmd: ",cmd)
12 | os.system(cmd)
13 | print("End")
14 |
--------------------------------------------------------------------------------
/antspy.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import ants
4 | import numpy as np
5 | import SimpleITK as sitk
6 |
7 | # Read the images with ANTs
8 | f_img = ants.image_read("./data/f_img.nii.gz")
9 | m_img = ants.image_read("./data/m_img.nii.gz")
10 | f_label = ants.image_read("./data/f_label.nii.gz")
11 | m_label = ants.image_read("./data/m_label.nii.gz")
12 |
13 | '''
14 | ants.registration() returns a dictionary:
15 | warpedmovout: the moving image registered to the fixed image
16 | warpedfixout: the fixed image registered to the moving image
17 | fwdtransforms: the transforms from moving to fixed
18 | invtransforms: the transforms from fixed to moving
19 |
20 | type_of_transform can take the following values:
21 | Rigid: rigid registration
22 | Affine: affine registration, i.e. rigid + scaling
23 | ElasticSyN: affine + deformable registration, MI as the optimization metric, elastic regularization
24 | SyN: affine + deformable registration, MI as the optimization metric
25 | SyNCC: affine + deformable registration, CC as the optimization metric
26 | '''
27 | # Image registration
28 | mytx = ants.registration(fixed=f_img, moving=m_img, type_of_transform='SyN')
29 | # Apply the deformation field to the moving image to get the registered image; interpolator can also be "nearestNeighbor", etc.
30 | warped_img = ants.apply_transforms(fixed=f_img, moving=m_img, transformlist=mytx['fwdtransforms'],
31 | interpolator="linear")
32 | # Apply the same transforms to the label of the moving image
33 | warped_label = ants.apply_transforms(fixed=f_img, moving=m_label, transformlist=mytx['fwdtransforms'],
34 | interpolator="nearestNeighbor")  # nearest-neighbour interpolation keeps the label values integer
35 | # Keep direction/origin/spacing of the registered images consistent with the fixed image
36 | warped_img.set_direction(f_img.direction)
37 | warped_img.set_origin(f_img.origin)
38 | warped_img.set_spacing(f_img.spacing)
39 | warped_label.set_direction(f_img.direction)
40 | warped_label.set_origin(f_img.origin)
41 | warped_label.set_spacing(f_img.spacing)
42 | img_name = "./result/warped_img.nii.gz"
43 | label_name = "./result/warped_label.nii.gz"
44 | # Save the images
45 | ants.image_write(warped_img, img_name)
46 | ants.image_write(warped_label, label_name)
47 |
48 | # Convert the ANTsImage to a numpy array
49 | warped_img_arr = warped_img.numpy(single_components=False)
50 | # Build an ANTsImage from a numpy array
51 | img = ants.from_numpy(warped_img_arr, origin=None, spacing=None, direction=None, has_components=False, is_rgb=False)
52 | # Compute the Jacobian determinant image of the deformation field
53 | jac = ants.create_jacobian_determinant_image(domain_image=f_img, tx=mytx["fwdtransforms"][0], do_log=False, geom=False)
54 | ants.image_write(jac, "./result/jac.nii.gz")
55 | # Generate a warped-grid version of the moving image; in practice the result did not look good
56 | m_grid = ants.create_warped_grid(m_img)
57 | m_grid = ants.create_warped_grid(m_grid, grid_directions=(False, False), transform=mytx['fwdtransforms'],
58 | fixed_reference_image=f_img)
59 | ants.image_write(m_grid, "./result/m_grid.nii.gz")
60 |
61 | '''
62 | Other, less commonly used functions:
63 |
64 | ANTsTransform.apply_to_image(image, reference=None, interpolation='linear')
65 | ants.read_transform(filename, dimension=2, precision='float')
66 | # the transform file format is ".mat"
67 | ants.write_transform(transform, filename)
68 | # field is an ANTsImage
69 | ants.transform_from_displacement_field(field)
70 | '''
71 |
72 | print("End")
73 |
--------------------------------------------------------------------------------
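A brief follow-up sketch (not in the original script), assuming the registration above has already been run: mytx['fwdtransforms'] is a list of transform file paths written to disk, so the same transforms can later be re-applied to any other image in the moving image's space. The file names below are placeholders.

    import ants

    f_img = ants.image_read("./data/f_img.nii.gz")
    other = ants.image_read("./data/other_image_in_moving_space.nii.gz")  # placeholder path
    # placeholder transform files; use the paths actually returned in mytx['fwdtransforms'],
    # in the same order in which ants.registration() returned them
    saved_transforms = ["./result/moving_to_fixed_1Warp.nii.gz", "./result/moving_to_fixed_0GenericAffine.mat"]
    warped_other = ants.apply_transforms(fixed=f_img, moving=other, transformlist=saved_transforms, interpolator="linear")
    ants.image_write(warped_other, "./result/warped_other.nii.gz")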
/crop_resize.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import numpy as np
4 | import SimpleITK as sitk
5 |
6 |
7 | # Get the smallest bounding box that encloses the mask
8 | def get_bbox_from_mask(mask, outside_value=0):
9 | mask_voxel_coords = np.where(mask != outside_value)
10 | minzidx = int(np.min(mask_voxel_coords[0]))
11 | maxzidx = int(np.max(mask_voxel_coords[0])) + 1
12 | minxidx = int(np.min(mask_voxel_coords[1]))
13 | maxxidx = int(np.max(mask_voxel_coords[1])) + 1
14 | minyidx = int(np.min(mask_voxel_coords[2]))
15 | maxyidx = int(np.max(mask_voxel_coords[2])) + 1
16 | return [[minzidx, maxzidx], [minxidx, maxxidx], [minyidx, maxyidx]]
17 |
18 |
19 | # Crop the image to the bbox
20 | def crop_to_bbox(image, bbox):
21 | assert len(image.shape) == 3, "only supports 3d images"
22 | # slice(start, stop[, step]) builds the index range for each axis
23 | resizer = (slice(bbox[0][0], bbox[0][1]), slice(bbox[1][0], bbox[1][1]), slice(bbox[2][0], bbox[2][1]))
24 | return image[resizer]
25 |
26 |
27 | out_path = r"C:\Users\zuzhiang\Desktop"
28 | files = glob.glob(os.path.join(r"C:\Data\1.Work\03.DataSet\LPBA40\train\256-256-256", "*.nii.gz"))
29 | print("Number of images: ", len(files))
30 |
31 | minx, miny, minz, maxx, maxy, maxz = 300, 300, 300, 0, 0, 0
32 | for file in files:
33 | _, name = os.path.split(file)
34 | print("Name: ", name)
35 | old_img = sitk.ReadImage(file)
36 | img_arr = sitk.GetArrayFromImage(old_img)
37 | bbox = get_bbox_from_mask(img_arr)
38 | if bbox[0][0] < minx:
39 | minx = bbox[0][0]
40 | if bbox[0][1] > maxx:
41 | maxx = bbox[0][1]
42 | if bbox[1][0] < miny:
43 | miny = bbox[1][0]
44 | if bbox[1][1] > maxy:
45 | maxy = bbox[1][1]
46 | if bbox[2][0] < minz:
47 | minz = bbox[2][0]
48 | if bbox[2][1] > maxz:
49 | maxz = bbox[2][1]
50 | print("bbox: ", bbox)
51 | print("\nminx: %d maxx: %d miny: %d maxy: %d minz: %d maxz: %d" % (minx, maxx, miny, maxy, minz, maxz))
52 |
53 | # # Build the bbox by hand from the six values above
54 | # bbox = [[0, 240], [21, 229], [40, 216]]
55 | # print("bbox: ", bbox)
56 | # for file in files:
57 | # _, name = os.path.split(file)
58 | # print("Name: ", name)
59 | # old_img = sitk.ReadImage(file)
60 | # img_arr = sitk.GetArrayFromImage(old_img)
61 | # img_arr = crop_to_bbox(img_arr, bbox)
62 | # print("new shape: ", img_arr.shape)
63 | # img = sitk.GetImageFromArray(img_arr)
64 | # img.SetOrigin(old_img.GetOrigin())
65 | # img.SetDirection(old_img.GetDirection())
66 | # img.SetSpacing(old_img.GetSpacing())
67 | # sitk.WriteImage(img, os.path.join(out_path, name))
68 |
--------------------------------------------------------------------------------
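An optional sketch (not part of the original file): instead of working the crop size out by hand so that every dimension is a multiple of 16, the bounding box from get_bbox_from_mask can be grown automatically. The helper name pad_bbox_to_multiple and the clamping behaviour are my own assumptions.

    import numpy as np

    def pad_bbox_to_multiple(bbox, shape, multiple=16):
        # Grow each [lo, hi) interval so that (hi - lo) becomes a multiple of `multiple`,
        # splitting the extra voxels over both sides and clamping to the image bounds.
        padded = []
        for (lo, hi), dim in zip(bbox, shape):
            target = int(np.ceil((hi - lo) / multiple)) * multiple
            extra = target - (hi - lo)
            lo = max(0, lo - extra // 2)
            hi = min(dim, lo + target)
            lo = max(0, hi - target)  # shift back down if the upper image bound was hit
            padded.append([lo, hi])
        return padded

    # hypothetical raw bbox on a 256^3 volume:
    # pad_bbox_to_multiple([[3, 240], [21, 226], [40, 215]], shape=(256, 256, 256))
    # -> [[2, 242], [20, 228], [40, 216]]   (every side length is now a multiple of 16)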
/data_augmentation.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import SimpleITK as sitk
4 | import matplotlib.pyplot as plt
5 |
6 | np.set_printoptions(threshold=np.inf)
7 |
8 |
9 | def resample(image, transform):
10 | # Output image Origin, Spacing, Size, Direction are taken from the reference
11 | # image in this call to Resample
12 | reference_image = image
13 | interpolator = sitk.sitkCosineWindowedSinc
14 | default_value = -100.0
15 | return sitk.Resample(image, reference_image, transform,
16 | interpolator, default_value)
17 |
18 |
19 | # Data augmentation function; randomly applies B-spline deformation, flipping, translation, scaling, rotation and intensity windowing/normalization to 2D/3D images
20 | def data_augmentation(path=None, image=None, do_bspline=True, do_flip=False, do_trslt=True, do_scale=True, do_rota=True,
21 | show_help_info=False):
22 | '''
23 | :param path: path of the image
24 | :param image: input image, used when path is None
25 | :param do_bspline: whether to apply a B-spline transform; only used for 3D images, because it has problems (no visible effect) in 2D
26 | :param do_flip: whether to flip
27 | :param do_trslt: whether to translate
28 | :param do_scale: whether to scale
29 | :param do_rota: whether to rotate
30 | :param show_help_info: whether to print helper information
31 | :return:
32 | '''
33 | if path is not None:
34 | image = sitk.ReadImage(path)
35 | image_arr = sitk.GetArrayFromImage(image)
36 |
37 | # Clip away the brightest 0.2% of the grey values in the histogram
38 | value, _, count = np.unique(image_arr, return_inverse=True, return_counts=True)
39 | # axis=0 accumulates along rows, axis=1 along columns; with no axis the array is treated as flattened
40 | # the cumulative sum at each element is the sum of all previous elements plus the current one
41 | quantiles = np.cumsum(count).astype(np.float32)
42 | quantiles /= quantiles[-1]
43 | # Find the largest position whose quantile is still below 0.998 and clip every grey value above it
44 | pos = np.where(quantiles < 0.998)[0]
45 | val_max = value[np.max(pos)]
46 | image_arr[image_arr > val_max] = val_max
47 |
48 | # Compute the window width and window center
49 | min_val, max_val = np.min(image_arr), np.max(image_arr)
50 | window_width, window_center = max_val - min_val, (min_val + max_val) / 2
51 |
52 | dim = image.GetDimension()
53 | if show_help_info:
54 | print("dim: ", dim)
55 | transform = sitk.AffineTransform(dim)
56 | matrix = np.array(transform.GetMatrix()).reshape((dim, dim))
57 | if dim == 2:
58 | if do_flip:
59 | # Flip
60 | if np.random.rand() > 0.5:
61 | image = image[::-1] # flip along the x axis
62 | if show_help_info:
63 | print("x axis flip.")
64 | if np.random.rand() > 0.5:
65 | image = image[:, ::-1] # flip along the y axis
66 | if show_help_info:
67 | print("y axis flip.")
68 | if do_trslt:
69 | # Translation
70 | x_trslt, y_trslt = np.random.randint(-20, 20, 2) # translation range of [-20, 20] pixels per axis
71 | transform.SetTranslation((float(x_trslt), float(y_trslt)))
72 | if show_help_info:
73 | print("x_trslt: ", x_trslt, " y_trslt: ", y_trslt)
74 | if do_scale:
75 | # Scale
76 | x_scale = 1.0 + random.uniform(-0.1, 0.1) # scaling range of [0.9, 1.1] of the original size
77 | y_scale = 1.0 + random.uniform(-0.1, 0.1)
78 | # x_scale, y_scale relate the original image to the result; e.g. scale = 2 shrinks the image to half its size
79 | matrix[0, 0] = x_scale
80 | matrix[1, 1] = y_scale
81 | if show_help_info:
82 | print("x_scale: ", x_scale, " y_scale: ", y_scale)
83 | if do_rota:
84 | # Rotation
85 | degree = np.random.randint(-15, 15) # rotation range of [-15°, 15°]
86 | radians = -np.pi * degree / 180.
87 | rotation = np.array([[np.cos(radians), -np.sin(radians)], [np.sin(radians), np.cos(radians)]])
88 | matrix = np.dot(rotation, matrix)
89 | if show_help_info:
90 | print("degree: ", degree)
91 |
92 | elif dim == 3:
93 | if do_bspline:
94 | # B-spline transform
95 | m=5
96 | spline_order = 3
97 | bspline = sitk.BSplineTransform(dim, spline_order)
98 | bspline.SetTransformDomainPhysicalDimensions(image.GetSize())
99 | mesh_size = [m, m, m]
100 | bspline.SetTransformDomainMeshSize(mesh_size)
101 | # Random displacement of the control points.
102 | # [13, 18] controls the deformation strength; larger values give a stronger deformation
103 | originalControlPointDisplacements = np.random.random(len(bspline.GetParameters())) * np.random.randint(13,18)
104 | bspline.SetParameters(originalControlPointDisplacements)
105 | image = resample(image, bspline)
106 | if do_flip:
107 | # Flip
108 | if np.random.rand() > 0.5:
109 | image = image[::-1] # flip along the x axis
110 | if show_help_info:
111 | print("x axis flip.")
112 | if np.random.rand() > 0.5:
113 | image = image[:, ::-1] # flip along the y axis
114 | if show_help_info:
115 | print("y axis flip.")
116 | if np.random.rand() > 0.5:
117 | image = image[:, :, ::-1] # flip along the z axis
118 | if show_help_info:
119 | print("z axis flip.")
120 | if do_trslt:
121 | # Translation
122 | x_trslt, y_trslt, z_trslt = np.random.randint(-20, 20, 3) # translation range of [-20, 20] pixels per axis
123 | transform.SetTranslation((float(x_trslt), float(y_trslt), float(z_trslt)))
124 | if show_help_info:
125 | print("x_trslt: ", x_trslt, " y_trslt: ", y_trslt, " z_trslt: ", z_trslt)
126 | if do_scale:
127 | # Scale
128 | x_scale = 1.0 + random.uniform(-0.1, 0.1) # scaling range of [0.9, 1.1] of the original size
129 | y_scale = 1.0 + random.uniform(-0.1, 0.1)
130 | z_scale = 1.0 + random.uniform(-0.1, 0.1)
131 | # x_scale, y_scale, z_scale relate the original image to the result; e.g. scale = 2 shrinks the image to half its size
132 | matrix[0, 0] = x_scale
133 | matrix[1, 1] = y_scale
134 | matrix[2, 2] = z_scale
135 | if show_help_info:
136 | print("x_scale: ", x_scale, " y_scale: ", y_scale, " z_scale: ", z_scale)
137 | if do_rota:
138 | # Rotation
139 | x_dgr, y_dgr, z_dgr = np.random.randint(-15, 15, 3) # rotation range of [-15°, 15°]
140 | x_rad = -np.pi * x_dgr / 180.
141 | y_rad = -np.pi * y_dgr / 180.
142 | z_rad = -np.pi * z_dgr / 180.
143 | rotation = np.array([[np.cos(y_rad) * np.cos(z_rad), np.cos(y_rad) * np.sin(z_rad), -np.sin(y_rad)],
144 | [-np.cos(x_rad) * np.sin(z_rad) + np.sin(x_rad) * np.sin(y_rad) * np.cos(z_rad),
145 | np.cos(x_rad) * np.cos(z_rad) + np.sin(x_rad) * np.sin(y_rad) * np.sin(z_rad),
146 | np.sin(x_rad) * np.cos(y_rad)],
147 | [np.sin(x_rad) * np.sin(z_rad) + np.cos(x_rad) * np.sin(y_rad) * np.cos(z_rad),
148 | -np.sin(x_rad) * np.cos(z_rad) + np.cos(x_rad) * np.sin(y_rad) * np.sin(z_rad),
149 | np.cos(x_rad) * np.cos(y_rad)]])
150 | matrix = np.dot(rotation, matrix)
151 | if show_help_info:
152 | print("x_dgr: ", x_dgr, " y_dgr: ", y_dgr, " z_dgr: ", z_dgr)
153 |
154 | transform.SetMatrix(matrix.ravel())
155 | # The next two lines put the transform center at the center of the image
156 | center = image.GetOrigin() + np.array(image.GetSize()) * np.array(image.GetSpacing()) * 0.5
157 | transform.SetCenter(center)
158 | out = resample(image, transform)
159 |
160 | # Keep the same window width and window center as the original image, and normalize at the same time
161 | # Without this step, scaling and rotation would change the grey values
162 | out_arr = sitk.GetArrayFromImage(out)
163 | min_window = float(window_center) - 0.5 * float(window_width)
164 | out_arr = (out_arr - min_window) / float(window_width)
165 | out_arr[out_arr < 0] = 0.
166 | out_arr[out_arr > 1] = 1.
167 | if dim == 2:
168 | return out_arr, dim
169 | elif dim == 3:
170 | return sitk.GetImageFromArray(out_arr), dim
171 |
172 |
173 | if __name__ == "__main__":
174 | path_2D = r"C:\Users\zuzhiang\Desktop\1.jpg"
175 | path_3D = r"C:\Data\1.Work\02.Code\DDR\results\f_img.nii"
176 | out_path_3D = r"C:\Users\zuzhiang\Desktop\out.nii"
177 | # out, dim = data_augmentation(path_3D)
178 | out, dim = data_augmentation(image=sitk.ReadImage(path_3D))
179 | if dim == 2:
180 | plt.imshow(out)
181 | plt.title("result image")
182 | plt.show()
183 | elif dim == 3:
184 | sitk.WriteImage(out, out_path_3D)
185 | print("end")
186 |
--------------------------------------------------------------------------------
/dcm2nii.py:
--------------------------------------------------------------------------------
1 | import os
2 | import SimpleITK as sitk
3 |
4 |
5 | def dcm2nii(dcm, nii):
6 | # GetGDCMSeriesIDs collects the IDs of the DICOM series found in the directory
7 | series_id = sitk.ImageSeriesReader.GetGDCMSeriesIDs(dcm)
8 | # GetGDCMSeriesFileNames returns the paths of the dcm files that share a series ID; series_id[0] is the first series
9 | series_file_names = sitk.ImageSeriesReader.GetGDCMSeriesFileNames(dcm, series_id[0])
10 | print(len(series_file_names))
11 | series_reader = sitk.ImageSeriesReader()
12 | series_reader.SetFileNames(series_file_names)
13 | image3d = series_reader.Execute()
14 | print("type: ",type(image3d))
15 | sitk.WriteImage(image3d, nii)
16 |
17 |
18 | if __name__=="__main__":
19 | '''
20 | The folder passed as dcm contains many sub-folders, each holding one DICOM series;
21 | every DICOM series is converted to a .nii image
22 | '''
23 | dcm=input("dcm_dir:\n")
24 | nii = input("nii_dir:\n")
25 | for root,dirs,files in os.walk(dcm):
26 | print(dirs)
27 | break
28 | for dir in dirs:
29 | dcm_dir=dcm+"\\"+dir
30 | if not os.listdir(dcm_dir): # skip empty folders
31 | continue
32 | nii_file=nii+"\\"+dir+".nii"
33 | dcm2nii(dcm_dir,nii_file) # convert the DICOM series to nii
--------------------------------------------------------------------------------
/image_combine_example.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import skimage.util as su
3 | import matplotlib.pyplot as plt
4 |
5 |
6 | def paste(img, canvas, i, j, method='replace', export_dtype='float'):
7 | """ paste `img` on `canvas` with its left-top corner at (i, j) """
8 | # check dtypes
9 | img = su.img_as_float(img)
10 | canvas = su.img_as_float(canvas)
11 | # check shapes
12 | if len(img.shape) not in (2, 3) or len(canvas.shape) not in (2, 3):
13 | # both inputs must be 2-D (grayscale) or 3-D (color) arrays
14 | raise AttributeError('dimensions of input images not all equal to 2 or 3!')
15 | # check channels
16 | # all grayscale image
17 | if len(img.shape) == 2 and len(canvas.shape) == 2:
18 | pass
19 | # `img` color image, possible with alpha channel; `canvas` grayscale image
20 | elif len(img.shape) == 3 and len(canvas.shape) == 2:
21 | c = img.shape[-1]
22 | if c == 3:
23 | canvas = np.stack([canvas]*c, axis=-1)
24 | if c == 4:
25 | canvas = np.stack([canvas]*(c-1)+[np.ones((canvas.shape[0], canvas.shape[1]))], axis=-1)
26 | # `canvas` color image, possible with alpha channel; `img` grayscale image
27 | elif len(img.shape) == 2 and len(canvas.shape) == 3:
28 | c = canvas.shape[-1]
29 | if c == 3:
30 | img = np.stack([img]*c, axis=-1)
31 | if c == 4:
32 | img = np.stack([img]*(c-1)+[np.ones((img.shape[0], img.shape[1]))], axis=-1)
33 | # all color image
34 | elif len(img.shape) == 3 and len(canvas.shape) == 3:
35 | if img.shape[-1] == 3 and canvas.shape[-1] == 4:
36 | img = np.concatenate([img, np.ones((img.shape[0], img.shape[1], 1))], -1)
37 | elif img.shape[-1] == 4 and canvas.shape[-1] == 3:
38 | canvas = np.concatenate([canvas, np.ones((canvas.shape[0], canvas.shape[1], 1))], -1)
39 | elif img.shape[-1] == canvas.shape[-1]:
40 | pass
41 | else:
42 | raise ValueError('channel number should equal to 3 or 4!')
43 | # get shapes
44 | h_i, w_i = img.shape[:2]
45 | h_c, w_c = canvas.shape[:2]
46 | # find extent of `img` on `canvas`
47 | i_min = np.max([0, i])
48 | i_max = np.min([h_c, i+h_i])
49 | j_min = np.max([0, j])
50 | j_max = np.min([w_c, j+w_i])
51 | # paste `img` on `canvas`
52 | if method == 'replace':
53 | canvas[i_min:i_max, j_min:j_max] = img[i_min-i:i_max-i, j_min-j:j_max-j]
54 | elif method == 'add':
55 | canvas[i_min:i_max, j_min:j_max] += img[i_min-i:i_max-i, j_min-j:j_max-j]
56 | else:
57 | raise ValueError('no such method!')
58 | # return `canvas`
59 | if export_dtype == 'float':
60 | return canvas
61 | elif export_dtype == 'ubyte':
62 | return su.img_as_ubyte(canvas)
63 | else:
64 | raise ValueError('no such data type for exporting!')
65 |
66 |
67 | def combine_avg(imgs, num_w=10, strides=(10, 10), padding=5, bg_level_1=1.0, bg_level_2=1.0, export_dtype='float'):
68 | """ paste contents of `imgs` on a single image with `strides` """
69 | # dtypes check
70 | imgs = [su.img_as_float(img) for img in imgs]
71 | # shapes check
72 | shapes = [img.shape for img in imgs]
73 | if not all([len(s) == 2 or len(s) == 3 for s in shapes]):
74 | raise AttributeError('dimensions of imgs not all 2 or 3!')
75 | # find the shape of canvas
76 | n = len(imgs)
77 | num_h = (n - 1) // num_w + 1
78 | h = strides[0]*(num_h-1) + shapes[-1][0]
79 | w = strides[1]*(num_w-1) + shapes[-1][1]
81 | lt_poses = [(strides[0]*i, strides[1]*j) for i in range(num_h) for j in range(num_w) if i*num_w+j < n]
--------------------------------------------------------------------------------
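A tiny, hypothetical check of the paste() helper above; the array sizes are made up for illustration only.

    import numpy as np

    canvas = np.ones((100, 100))        # white grayscale canvas
    patch = np.zeros((20, 30))          # black patch
    out = paste(patch, canvas, 10, 40)  # left-top corner at row 10, column 40
    print(out.shape, out[10:30, 40:70].max())  # (100, 100) 0.0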
/phi/f_img.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zuzhiang/MedicalImageProcess/4361482e453ff8f949f1df0383541b8f72f33a72/phi/f_img.jpg
--------------------------------------------------------------------------------
/phi/img.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zuzhiang/MedicalImageProcess/4361482e453ff8f949f1df0383541b8f72f33a72/phi/img.jpg
--------------------------------------------------------------------------------
/phi/img1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zuzhiang/MedicalImageProcess/4361482e453ff8f949f1df0383541b8f72f33a72/phi/img1.jpg
--------------------------------------------------------------------------------
/phi/img2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zuzhiang/MedicalImageProcess/4361482e453ff8f949f1df0383541b8f72f33a72/phi/img2.jpg
--------------------------------------------------------------------------------
/phi/m_img.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zuzhiang/MedicalImageProcess/4361482e453ff8f949f1df0383541b8f72f33a72/phi/m_img.jpg
--------------------------------------------------------------------------------
/phi/my_show_grid.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import SimpleITK as sitk
3 | import matplotlib.pyplot as plt
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 |
10 | # Spatial transformer network
11 | class SpatialTransformer(nn.Module):
12 | # 1. build a sampling grid; 2. new_grid = grid + flow, i.e. add a displacement to the old grid; 3. normalize the grid to [-1, 1]; 4. sample the source image at the new grid
13 | def __init__(self, size, mode='bilinear'):
14 | """
15 | Instantiate the block
16 | :param size: size of input to the spatial transformer block
17 | :param mode: method of interpolation for grid_sampler
18 | """
19 | super(SpatialTransformer, self).__init__()
20 |
21 | # Create sampling grid
22 | vectors = [torch.arange(0, s) for s in size]
23 | grids = torch.meshgrid(vectors)
24 | grid = torch.stack(grids) # y, x, z
25 | grid = torch.unsqueeze(grid, 0) # add batch
26 | grid = grid.type(torch.FloatTensor)
27 | self.register_buffer('grid', grid)
28 |
29 | self.mode = mode
30 |
31 | def forward(self, src, flow):
32 | """
33 | Push the src and flow through the spatial transform block
34 | :param src: the original moving image
35 | :param flow: the output from the U-Net
36 | """
37 | new_locs = self.grid + flow
38 |
39 | shape = flow.shape[2:]
40 |
41 | # Need to normalize grid values to [-1, 1] for resampler
42 | for i in range(len(shape)):
43 | new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[i] - 1) - 0.5)
44 |
45 | if len(shape) == 2:
46 | new_locs = new_locs.permute(0, 2, 3, 1) # permute the dimensions to (0, 2, 3, 1)
47 | new_locs = new_locs[..., [1, 0]]
48 | elif len(shape) == 3:
49 | new_locs = new_locs.permute(0, 2, 3, 4, 1)
50 | new_locs = new_locs[..., [2, 1, 0]]
51 |
52 | return F.grid_sample(src, new_locs, mode=self.mode)
53 |
54 |
55 | # Generate the grid image
56 | def create_grid(size, path):
57 | num1, num2 = (size[0] + 10) // 10, (size[1] + 10) // 10 # change the divisor (10) to change the grid density
58 | x, y = np.meshgrid(np.linspace(-2, 2, num1), np.linspace(-2, 2, num2))
59 |
60 | plt.figure(figsize=((size[0] + 10) / 100.0, (size[1] + 10) / 100.0)) # set the figure size
61 | plt.plot(x, y, color="black")
62 | plt.plot(x.transpose(), y.transpose(), color="black")
63 | plt.axis('off') # hide the axes
64 | # remove the white border
65 | plt.gca().xaxis.set_major_locator(plt.NullLocator())
66 | plt.gca().yaxis.set_major_locator(plt.NullLocator())
67 | plt.subplots_adjust(top=1, bottom=0, left=0, right=1, hspace=0, wspace=0)
68 | plt.margins(0, 0)
69 | plt.savefig(path) # save the figure
70 | # plt.show()
71 |
72 |
73 | if __name__ == "__main__":
74 | out_path = r"C:\Users\zuzhiang\Desktop\new_img.nii" # output path of the warped grid image
75 | # Read the deformation field
76 | phi = sitk.ReadImage("./2D1.nii") # [324,303,2]
77 | phi_arr = torch.from_numpy(sitk.GetArrayFromImage(phi)).float()
78 | phi_shape = phi_arr.shape
79 | # Generate the grid image
80 | create_grid(phi_shape, out_path)
81 | img = sitk.GetArrayFromImage(sitk.ReadImage(out_path))[..., 0]
82 | img = np.squeeze(img)[np.newaxis, np.newaxis, :phi_shape[0], :phi_shape[1]]
83 | # Warp the grid image with the STN according to the deformation field
84 | STN = SpatialTransformer(phi_shape[:2])
85 | phi_arr = phi_arr.permute(2, 0, 1)[np.newaxis, ...]
86 | warp = STN(torch.from_numpy(img).float(), phi_arr)
87 | # Save the image
88 | warp_img = sitk.GetImageFromArray(warp[0, 0, ...].numpy().astype(np.uint8))
89 | sitk.WriteImage(warp_img, out_path)
90 | print("end")
91 |
--------------------------------------------------------------------------------
/phi/show_grid.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import SimpleITK as sitk
3 | import numpy as np
4 |
5 |
6 | def grid2contour(grid, title):
7 | '''
8 | grid--image_grid used to show deform field
9 | type: numpy ndarray, shape: (h, w, 2), value range:(-1, 1)
10 | '''
11 | assert grid.ndim == 3
12 | x = np.arange(-1, 1, 2.0 / grid.shape[1])
13 | y = np.arange(-1, 1, 2.0 / grid.shape[0])
14 | X, Y = np.meshgrid(x, y)
15 | Z1 = grid[:, :, 0] + 2 # remove the dashed line
16 | Z1 = Z1[::-1] # vertical flip
17 | Z2 = grid[:, :, 1] + 2
18 |
19 | plt.figure()
20 | plt.contour(X, Y, Z1, 15, levels=50, colors='k') # changing levels changes the appearance of the deformation field
21 | plt.contour(X, Y, Z2, 15, levels=50, colors='k')
22 | plt.xticks(()), plt.yticks(()) # remove x, y ticks
23 | plt.title(title)
24 | plt.show()
25 |
26 |
27 | def show_grid():
28 | img = sitk.ReadImage(r"C:\Users\zuzhiang\Desktop\7_flow.nii.gz")
29 | img_arr = sitk.GetArrayFromImage(img)[:,:,0,:2]
30 | img_shape = img_arr.shape
31 | print("shape: ", img_shape)
32 |
33 | # start, stop, step (the step may be fractional)
34 | x = np.arange(-1, 1, 2 / img_shape[1])
35 | y = np.arange(-1, 1, 2 / img_shape[0])
36 | X, Y = np.meshgrid(x, y)
37 | regular_grid = np.stack((X, Y), axis=2)
38 | grid2contour(regular_grid, "regular_grid")
39 |
40 | rand_field = np.random.rand(*img_shape[:2], 2) # the leading * unpacks the tuple into separate arguments
41 | rand_field_norm = rand_field.copy()
42 | rand_field_norm[:, :, 0] = rand_field_norm[:, :, 0] * 2 / img_shape[1]
43 | rand_field_norm[:, :, 1] = rand_field_norm[:, :, 1] * 2 / img_shape[0]
44 | sampling_grid = regular_grid + rand_field_norm
45 | grid2contour(sampling_grid, "sampling_grid")
46 |
47 | img_arr[..., 0] = img_arr[..., 0] * 2 / img_shape[1]
48 | img_arr[..., 1] = img_arr[..., 1] * 2 / img_shape[0]
49 | img_grid = regular_grid + img_arr
50 | grid2contour(img_grid, "img_grid")
51 |
52 |
53 | if __name__ == "__main__":
54 | show_grid()
55 | print("end")
56 |
--------------------------------------------------------------------------------
/resize_img.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import numpy as np
4 | import SimpleITK as sitk
5 | from skimage.transform import resize
6 | import models.data_augmentation as da
7 |
8 | path = r"C:\Data\1.Work\03.DataSet\LPBA40"
9 | out_path = r"C:\Data\1.Work\03.DataSet\LPBA40_DA"
10 | #file_lst = glob.glob(os.path.join(path, '*/*bfc.nii'))
11 | file_lst=[r"C:\Data\1.Work\02.Code\DDR\fixed_LPBA40.nii"]
12 | num = len(file_lst)
13 | # for file in file_lst:
14 | # img, dim = da.data_augmentation(file, do_bspline=False, do_trslt=True, do_scale=True, do_rota=True,
15 | # show_help_info=True)
16 | # _, name = os.path.split(file)
17 | # print("path: ", os.path.join(out_path, name))
18 | # sitk.WriteImage(img, os.path.join(out_path, name))
19 |
20 | for file in file_lst:
21 | # img, dim = da.data_augmentation(file, do_bspline=False, do_trslt=False, do_scale=False, do_rota=False,
22 | # show_help_info=True)
23 | img = sitk.ReadImage(file)
24 | img_arr = sitk.GetArrayFromImage(img)
25 | new_shape = [256, 128, 256]
26 | new_img = resize(img_arr, new_shape, order=3, mode='constant', cval=0, preserve_range=True)
27 | # order: spline order of the interpolation; cval: fill value outside the image; preserve_range=True keeps the original value range (otherwise the data is rescaled/normalized by default)
28 | '''
29 | The order of interpolation. The order has to be in the range 0-5:
30 | 0: Nearest-neighbor
31 | 1: Bi-linear (default)
32 | 2: Bi-quadratic
33 | 3: Bi-cubic
34 | 4: Bi-quartic
35 | 5: Bi-quintic
36 | '''
37 | new_img = sitk.GetImageFromArray(new_img)
38 | new_img.SetOrigin(img.GetOrigin())
39 | new_img.SetSpacing(img.GetSpacing())
40 | new_img.SetDirection(img.GetDirection())
41 | print("file: ", file)
42 | #temp_path = r"C:\Users\zuzhiang\Desktop\out.nii"
43 | sitk.WriteImage(new_img, file)
44 |
--------------------------------------------------------------------------------
/torch_datagenerators.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import numpy as np
3 | import SimpleITK as sitk
4 | import torch.utils.data as Data
5 |
6 | '''
7 | By subclassing Data.Dataset, a group of samples is wrapped into a dataset;
8 | at minimum the __init__, __len__ and __getitem__ methods must be overridden
9 | '''
10 | class Dataset(Data.Dataset):
11 | def __init__(self, path):
12 | # initialization
13 | self.files = glob.glob(path)
14 |
15 | def __len__(self):
16 | # return the size of the dataset
17 | return len(self.files)
18 |
19 | def __getitem__(self, index):
20 | # Index one sample of the dataset; preprocessing can also be done here
21 | # the index parameter is required, but its name is arbitrary
22 | img_arr = sitk.GetArrayFromImage(sitk.ReadImage(self.files[index]))[..., np.newaxis]
23 | return img_arr
--------------------------------------------------------------------------------
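A brief usage sketch (not part of the original file): wrapping the Dataset above in a standard torch.utils.data.DataLoader. The glob pattern and batch size are placeholder values.

    import torch.utils.data as Data
    from torch_datagenerators import Dataset  # assumes this file is importable as a module

    dataset = Dataset("./data/*.nii.gz")  # placeholder glob pattern
    loader = Data.DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0)
    for batch in loader:
        # __getitem__ returns a (D, H, W, 1) numpy array; the default collate stacks them into a tensor
        print(batch.shape)
        break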