├── .idea ├── .gitignore ├── inspectionProfiles │ └── profiles_settings.xml ├── misc.xml ├── modules.xml ├── preprocess.iml └── vcs.xml ├── CBAM.py ├── FreeSurfer+FSL.py ├── FreeSurfer.py ├── README.md ├── ants.py ├── antspy.py ├── crop_resize.py ├── data_augmentation.py ├── dcm2nii.py ├── image_combine_example.py ├── img2nii.py ├── jacobian.py ├── phi ├── .idea │ ├── NII.iml │ ├── deployment.xml │ ├── inspectionProfiles │ │ └── profiles_settings.xml │ ├── misc.xml │ ├── modules.xml │ └── workspace.xml ├── f_img.jpg ├── img.jpg ├── img1.jpg ├── img2.jpg ├── m_img.jpg ├── my_show_grid.py └── show_grid.py ├── resize_img.py └── torch_datagenerators.py /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Datasource local storage ignored files 5 | /../../../../../../:\Data\1.Work\02.Code\preprocess\.idea/dataSources/ 6 | /dataSources.local.xml 7 | # Editor-based HTTP Client requests 8 | /httpRequests/ 9 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/preprocess.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | 13 | 15 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /CBAM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class CBAM_Module(nn.Module): 6 | def __init__(self, dim, in_channels, ratio, kernel_size): 7 | super(CBAM_Module, self).__init__() 8 | self.avg_pool = getattr(nn, "AdaptiveAvgPool{0}d".format(dim))(1) 9 | self.max_pool = getattr(nn, "AdaptiveMaxPool{0}d".format(dim))(1) 10 | conv_fn = getattr(nn, "Conv{0}d".format(dim)) 11 | self.fc1 = conv_fn(in_channels, in_channels // ratio, kernel_size=1, padding=0) 12 | self.relu = nn.ReLU() 13 | self.fc2 = conv_fn(in_channels // ratio, in_channels, kernel_size=1, padding=0) 14 | self.sigmoid = nn.Sigmoid() 15 | assert kernel_size in (3, 7), 'kernel size must be 3 or 7' 16 | padding = 3 if kernel_size == 7 else 1 17 | self.conv = conv_fn(2, 1, kernel_size=kernel_size, stride=1, padding=padding) 18 | 19 | def forward(self, x): 20 | print("CBAM") 21 | # Channel attention module:(Mc(f) = σ(MLP(AvgPool(f)) + MLP(MaxPool(f)))) 22 | module_input = x 23 | avg = self.fc2(self.relu(self.fc1(self.avg_pool(x)))) 24 | mx = self.fc2(self.relu(self.fc1(self.max_pool(x)))) 25 | x = self.sigmoid(avg + mx) 26 | x = module_input * x 27 | # Spatial attention module:Ms (f) = σ( f7×7( AvgPool(f) ; MaxPool(F)] ))) 28 | module_input = x 29 | avg = torch.mean(x, dim=1, keepdim=True) 30 | mx, _ = torch.max(x, dim=1, keepdim=True) 31 | x = torch.cat((avg, mx), dim=1) 32 | x = self.sigmoid(self.conv(x)) 33 | x = module_input * x 34 | return x -------------------------------------------------------------------------------- /FreeSurfer+FSL.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | 4 | 5 | # 使用FreeSurfer对图像进行颅骨剥离 6 | print("FreeSurfer start......\n") 7 | # 图像坐在的目录 8 | #------------------------图像路径需更改------------------------# 9 | path ="/home/syzhou/zuzhiang/Dataset/MGH10/Heads" 10 | # 读取目录下的.img文件列表,*.img表示该目录下所有以.img结尾的文件 11 | #-----------------------图像后缀名需更改--- -------------------# 12 | files = glob.glob(os.path.join(path,"*.img")) 13 | # 输出路径 14 | #------------------------输出路径需更改------------------------# 15 | out_path="/home/syzhou/zuzhiang/MIP/FSL_img/MGH10_results" 16 | print("number: ",len(files)) 17 | # 下面为freesurfer的环境配置命令 18 | a = "export FREESURFER_HOME=/home/syzhou/zuzhiang/freesurfer;" 19 | b = "source $FREESURFER_HOME/SetUpFreeSurfer.sh;" 20 | # 数据所在的目录 21 | c = "export SUBJECTS_DIR="+path+";" 22 | 23 | for file in files: 24 | # 将文件路径和文件名分离 25 | filename = os.path.split(file)[1] # 将路径名和文件名分开 26 | filename = filename.split(".")[0] # 去除所有扩展名 27 | #recon-all是颅骨去除的命令 28 | # mri_convert是进行格式转换,从mgz转到nii.gz,只是为了方便查看 29 | filename=filename[:] #根据扩展名的不同,这里需要做更改,只保留文件名即可 30 | # 当前输出文件路径,以.nii.gz格式保存 31 | cur_out_path=os.path.join(out_path,filename+".nii.gz") 32 | print("file name: ",file) 33 | cmd = a + b + c + "mri_watershed "+file+" "+ cur_out_path 34 | #print(cmd,"\n") 35 | os.system(cmd) 36 | 37 | 38 | # 使用FSL对图像和对应的label进行仿射对齐 39 | print("FSL start......\n") 40 | # fixed图像的路径 41 | #---------------去除头骨后的fixed图像名需更改-------------------# 42 | f_path= os.path.join(out_path,"g1.nii.gz") 43 | # moving图像的路径 44 | m_path=out_path 45 | # label的路径 46 | #-----------------------label路径需更改-----------------------# 47 | label_path="/home/syzhou/zuzhiang/Dataset/MGH10/Atlases" 48 | files=glob.glob(os.path.join(m_path,"*.nii.gz")) 49 | print("number: ",len(files)) 50 | for file in files: 51 | print("file: ",file) 52 | # 根据图像名找到对应的label名 53 | filename=os.path.split(file)[1] 54 | filename = filename.split(".")[0] # 去除所有扩展名 55 | #---------------------label后缀名需更改--------------------# 56 | label=os.path.join(label_path,filename+".img") 57 | # 下面分别是输出图像名/转换矩阵名/label名, 58 | out_img=os.path.join(out_path,filename+"_img.nii.gz") 59 | out_mat=os.path.join(out_path,filename+".mat") 60 | out_label=os.path.join(out_path,filename+"_label.nii.gz") 61 | # 如果当前文件和fixed图像一样则只将对应label格式转换为.nii.gz 62 | if f_path==file: 63 | convert="mri_convert " + label +" " + out_label 64 | os.system(convert) 65 | print("continue.........") 66 | continue 67 | # 将moving图像对齐到fixed图像 68 | flirt_img="flirt -in "+file+ " -ref "+f_path+" -out "+out_img+" -omat "+out_mat+ " -dof 12" 69 | # 将上一步的仿射变换矩阵作用在图像对应的label上 70 | flirt_label="flirt -in "+label+" -ref "+f_path+" -out "+out_label+" -init "+out_mat+" -applyxfm -interp nearestneighbour" 71 | #print(flirt_img,"\n") 72 | #print(flirt_label,"\n") 73 | os.system(flirt_img) 74 | os.system(flirt_label) 75 | 76 | print("\n\nEnd") -------------------------------------------------------------------------------- /FreeSurfer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | 4 | path = r"/home/syzhou/zuzhiang/Dataset/MGH10/Heads" 5 | # 读取目录下的nii.gz文件 6 | images = glob.glob(os.path.join(path,"*.img")) 7 | # 下面为freesurfer的环境配置命令 8 | a = "export FREESURFER_HOME=/home/syzhou/zuzhiang/freesurfer;" 9 | b = "source $FREESURFER_HOME/SetUpFreeSurfer.sh;" 10 | # 数据所在的目录 11 | c = "export SUBJECTS_DIR="+path+";" 12 | 13 | #images=['/home/syzhou/zuzhiang/Dataset/MGH10/Heads/1127.img'] 14 | for image in images: 15 | # 将文件路径和文件名分离 16 | filename = os.path.split(image)[1] # 将路径名和文件名分开 17 | filename = os.path.splitext(filename)[0] #将文件名和扩展名分开,如果为.nii.gz,则认为扩展名是.gz 18 | # freesurfer环境配置、颅骨去除、未仿射对齐mpz转nii、仿射对齐、仿射对齐mpz转nii.gz格式 19 | #recon-all是颅骨去除的命令 20 | # mri_convert是进行格式转换,从mgz转到nii.gz,只是为了方便查看 21 | # --apply_transform:仿射对齐操作 22 | # 转格式 23 | filename=filename[:] #根据扩展名的不同,这里需要做更改,只保留文件名即可 24 | cur_path=os.path.join(path,filename) 25 | print("file name: ",cur_path) 26 | cmd = a + b + c \ 27 | + "recon-all -parallel -i " + image + " -autorecon1 -subjid " + cur_path + "&&" \ 28 | + "mri_convert " + cur_path + "/mri/brainmask.mgz " +cur_path + "/mri/"+filename+".nii.gz;"\ 29 | + "mri_convert " + cur_path + "/mri/brainmask.mgz --apply_transform " + cur_path + "/mri/transforms/talairach.xfm -o " + cur_path + "/mri/brainmask_affine.mgz&&" \ 30 | + "mri_convert " + cur_path + "/mri/brainmask_affine.mgz " + cur_path + "/mri/"+filename+"_affine.nii.gz;" 31 | #print("cmd:\n",cmd) 32 | os.system(cmd) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 数据很杂,就懒得逐一整理了,很多都是实验性的,可能存在问题,不过好久没弄也忘得差不多,有问题的话还请大家自行阅读代码解决吧,下面只把每个文件是做什么用的大体说一下: 2 | 3 | - `phi`: 4 | - `my_show_grid.py`:先产生规则的网格图片,然后使用STN根据读入的形变场对其变形 5 | - `show_grid.py`:读入形变场的.nii文件,然后生成可视化的形变场图片 6 | - `show_nii`: 7 | - `show_nii.py`:展示三维的.nii图像 8 | - `show_nii3.py`:同时展示三维的.nii格式的器官、器官分割mask、病灶分割mask图像 9 | - `ants.py`:使用ANTs包内的SyN对图像进行配准 10 | - `antspy.py`:使用基于python版的ANTs——antspy对图像和其对应的标签同时进行配准 11 | - `CBAM`:CBAM注意力模块的实现 12 | - `crop_resize.py`:先找到包含脑部区域的最小矩形框,然后手动计算矩形框的大小(每一维应该是16的倍数),然后用已经注释掉的部分进行resize 13 | - `data_augmentation.py`:数据增强的代码,包括B样条采样、反转、平移、缩放、旋转、灰度值均衡化等,但B样条采样貌似有问题 14 | - `dcm2nii.py`:将dcm格式的三维图像转换为.nii格式的 15 | - `FreeSurfer+FSL.py`:先用FreeSurfer对脑部图像进行颅骨剥离,然后用FSL对图像和标签同时进行仿射对齐 16 | - `FreeSurfer.py`:用FreeSurfer对脑部图像进行颅骨剥离 17 | - `image_combine_example.py`:别人给的,没用过 18 | - `img2nii.py`:将.img格式的三维图像转化为.nii格式的 19 | - `jacobian.py`:计算形变场的雅克比行列式 20 | - `resize_img.py`:对图像进行resize 21 | - `tor_datagenerators.py`:基于pytorch的数据生成器 22 | -------------------------------------------------------------------------------- /ants.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | 4 | path="/home/syzhou/zuzhiang/Dataset/LPBA40" 5 | out_path="/home/syzhou/zuzhiang/MIP/ANTs" 6 | f_name="/home/syzhou/zuzhiang/Dataset/LPBA40/1.nii.gz" 7 | for i in range(2,41): 8 | m_name=os.path.join(path,str(i)+".nii.gz") 9 | out_name=str(i)+"m" 10 | cmd= "antsRegistrationSyN.sh -d 3 -f " + f_name + " -m " + m_name + " -o " + out_name 11 | print("cmd: ",cmd) 12 | os.system(cmd) 13 | print("End") 14 | -------------------------------------------------------------------------------- /antspy.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import ants 4 | import numpy as np 5 | import SimpleITK as sitk 6 | 7 | # ants图片的读取 8 | f_img = ants.image_read("./data/f_img.nii.gz") 9 | m_img = ants.image_read("./data/m_img.nii.gz") 10 | f_label = ants.image_read("./data/f_label.nii.gz") 11 | m_label = ants.image_read("./data/m_label.nii.gz") 12 | 13 | ''' 14 | ants.registration()函数的返回值是一个字典: 15 | warpedmovout: 配准到fixed图像后的moving图像 16 | warpedfixout: 配准到moving图像后的fixed图像 17 | fwdtransforms: 从moving到fixed的形变场 18 | invtransforms: 从fixed到moving的形变场 19 | 20 | type_of_transform参数的取值可以为: 21 | Rigid:刚体 22 | Affine:仿射配准,即刚体+缩放 23 | ElasticSyN:仿射配准+可变形配准,以MI为优化准则,以elastic为正则项 24 | SyN:仿射配准+可变形配准,以MI为优化准则 25 | SyNCC:仿射配准+可变形配准,以CC为优化准则 26 | ''' 27 | # 图像配准 28 | mytx = ants.registration(fixed=f_img, moving=m_img, type_of_transform='SyN') 29 | # 将形变场作用于moving图像,得到配准后的图像,interpolator也可以选择"nearestNeighbor"等 30 | warped_img = ants.apply_transforms(fixed=f_img, moving=m_img, transformlist=mytx['fwdtransforms'], 31 | interpolator="linear") 32 | # 对moving图像对应的label图进行配准 33 | warped_label = ants.apply_transforms(fixed=f_img, moving=m_label, transformlist=mytx['fwdtransforms'], 34 | interpolator="linear") 35 | # 将配准后图像的direction/origin/spacing和原图保持一致 36 | warped_img.set_direction(f_img.direction) 37 | warped_img.set_origin(f_img.origin) 38 | warped_img.set_spacing(f_img.spacing) 39 | warped_label.set_direction(f_img.direction) 40 | warped_label.set_origin(f_img.origin) 41 | warped_label.set_spacing(f_img.spacing) 42 | img_name = "./result/warped_img.nii.gz" 43 | label_name = "./result/warped_label.nii.gz" 44 | # 图像的保存 45 | ants.image_write(warped_img, img_name) 46 | ants.image_write(warped_label, label_name) 47 | 48 | # 将antsimage转化为numpy数组 49 | warped_img_arr = warped_img.numpy(single_components=False) 50 | # 从numpy数组得到antsimage 51 | img = ants.from_numpy(warped_img_arr, origin=None, spacing=None, direction=None, has_components=False, is_rgb=False) 52 | # 生成图像的雅克比行列式 53 | jac = ants.create_jacobian_determinant_image(domain_image=f_img, tx=mytx["fwdtransforms"][0], do_log=False, geom=False) 54 | ants.image_write(jac, "./result/jac.nii.gz") 55 | # 生成带网格的moving图像,实测效果不好 56 | m_grid = ants.create_warped_grid(m_img) 57 | m_grid = ants.create_warped_grid(m_grid, grid_directions=(False, False), transform=mytx['fwdtransforms'], 58 | fixed_reference_image=f_img) 59 | ants.image_write(m_grid, "./result/m_grid.nii.gz") 60 | 61 | ''' 62 | 以下为其他不常用的函数: 63 | 64 | ANTsTransform.apply_to_image(image, reference=None, interpolation='linear') 65 | ants.read_transform(filename, dimension=2, precision='float') 66 | # transform的格式是".mat" 67 | ants.write_transform(transform, filename) 68 | # field是ANTsImage类型 69 | ants.transform_from_displacement_field(field) 70 | ''' 71 | 72 | print("End") 73 | -------------------------------------------------------------------------------- /crop_resize.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import numpy as np 4 | import SimpleITK as sitk 5 | 6 | 7 | # 获取可以包裹mask的最小bounding box 8 | def get_bbox_from_mask(mask, outside_value=0): 9 | mask_voxel_coords = np.where(mask != outside_value) 10 | minzidx = int(np.min(mask_voxel_coords[0])) 11 | maxzidx = int(np.max(mask_voxel_coords[0])) + 1 12 | minxidx = int(np.min(mask_voxel_coords[1])) 13 | maxxidx = int(np.max(mask_voxel_coords[1])) + 1 14 | minyidx = int(np.min(mask_voxel_coords[2])) 15 | maxyidx = int(np.max(mask_voxel_coords[2])) + 1 16 | return [[minzidx, maxzidx], [minxidx, maxxidx], [minyidx, maxyidx]] 17 | 18 | 19 | # 根据bbox截取图片 20 | def crop_to_bbox(image, bbox): 21 | assert len(image.shape) == 3, "only supports 3d images" 22 | # slice是切片函数,参数为:起始值,终止值,[步长] 23 | resizer = (slice(bbox[0][0], bbox[0][1]), slice(bbox[1][0], bbox[1][1]), slice(bbox[2][0], bbox[2][1])) 24 | return image[resizer] 25 | 26 | 27 | out_path = r"C:\Users\zuzhiang\Desktop" 28 | files = glob.glob(os.path.join(r"C:\Data\1.Work\03.DataSet\LPBA40\train\256-256-256", "*.nii.gz")) 29 | print("Number of images: ", len(files)) 30 | 31 | minx, miny, minz, maxx, maxy, maxz = 300, 300, 300, 0, 0, 0 32 | for file in files: 33 | _, name = os.path.split(file) 34 | print("Name: ", name) 35 | old_img = sitk.ReadImage(file) 36 | img_arr = sitk.GetArrayFromImage(old_img) 37 | bbox = get_bbox_from_mask(img_arr) 38 | if bbox[0][0] < minx: 39 | minx = bbox[0][0] 40 | if bbox[0][1] > maxx: 41 | maxx = bbox[0][1] 42 | if bbox[1][0] < miny: 43 | miny = bbox[1][0] 44 | if bbox[1][1] > maxy: 45 | maxy = bbox[1][1] 46 | if bbox[2][0] < minz: 47 | minz = bbox[2][0] 48 | if bbox[2][1] > maxz: 49 | maxz = bbox[2][1] 50 | print("bbox: ", bbox) 51 | print("\nminx: %d maxx: %d miny: %d maxy: %d minz: %d maxz: %d" % (minx, maxx, miny, maxy, minz, maxz)) 52 | 53 | # # 根据以上6个值手工得到bbox 54 | # bbox = [[0, 240], [21, 229], [40, 216]] 55 | # print("bbox: ", bbox) 56 | # for file in files: 57 | # _, name = os.path.split(file) 58 | # print("Name: ", name) 59 | # old_img = sitk.ReadImage(file) 60 | # img_arr = sitk.GetArrayFromImage(old_img) 61 | # img_arr = crop_to_bbox(img_arr, bbox) 62 | # print("new shape: ", img_arr.shape) 63 | # img = sitk.GetImageFromArray(img_arr) 64 | # img.SetOrigin(old_img.GetOrigin()) 65 | # img.SetDirection(old_img.GetDirection()) 66 | # img.SetSpacing(old_img.GetSpacing()) 67 | # sitk.WriteImage(img, os.path.join(out_path, name)) 68 | -------------------------------------------------------------------------------- /data_augmentation.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import SimpleITK as sitk 4 | import matplotlib.pyplot as plt 5 | 6 | np.set_printoptions(threshold=np.inf) 7 | 8 | 9 | def resample(image, transform): 10 | # Output image Origin, Spacing, Size, Direction are taken from the reference 11 | # image in this call to Resample 12 | reference_image = image 13 | interpolator = sitk.sitkCosineWindowedSinc 14 | default_value = -100.0 15 | return sitk.Resample(image, reference_image, transform, 16 | interpolator, default_value) 17 | 18 | 19 | # 数据增强函数,可以对2D/3D图像随机进行直方图均衡化、翻转、平移、放缩、旋转、灰度值归一化等操作 20 | def data_augmentation(path=None, image=None, do_bspline=True, do_flip=False, do_trslt=True, do_scale=True, do_rota=True, 21 | show_help_info=False): 22 | ''' 23 | :param path: 图像路径 24 | :param image: 输入图像,当path为None使用 25 | :param do_bspline: 是否做B样条转换,只在为3D图像时使用,因2D时有问题(没效果) 26 | :param do_flip: 是否做翻转操作 27 | :param do_trslt: 是否做平移操作 28 | :param do_scale: 是否做缩放操作 29 | :param do_rota: 是否做旋转操作 30 | :param show_help_info: 是否显示帮助信息 31 | :return: 32 | ''' 33 | if path is not None: 34 | image = sitk.ReadImage(path) 35 | image_arr = sitk.GetArrayFromImage(image) 36 | 37 | # 去除灰度直方图中最亮的0.2%的灰度值 38 | value, _, count = np.unique(image_arr, return_inverse=True, return_counts=True) 39 | # axis=0,按照行累加。axis=1,按照列累加。axis不给定具体值,就把numpy数组当成一个一维数组。 40 | # 累加即当前元素值,等于之前所有元素值之和+当前元素值 41 | quantiles = np.cumsum(count).astype(np.float32) 42 | quantiles /= quantiles[-1] 43 | # 找到灰度直方图中小于0.998的最大位置,并将大于其对应值的值设为该值 44 | pos = np.where(quantiles < 0.998)[0] 45 | val_max = value[np.max(pos)] 46 | image_arr[image_arr > val_max] = val_max 47 | 48 | # 计算窗宽和窗位 49 | min_val, max_val = np.min(image_arr), np.max(image_arr) 50 | window_width, window_center = max_val - min_val, (min_val + max_val) / 2 51 | 52 | dim = image.GetDimension() 53 | if show_help_info: 54 | print("dim: ", dim) 55 | transform = sitk.AffineTransform(dim) 56 | matrix = np.array(transform.GetMatrix()).reshape((dim, dim)) 57 | if dim == 2: 58 | if do_flip: 59 | # 翻转(flip) 60 | if np.random.rand() > 0.5: 61 | image = image[::-1] # 沿x轴反转 62 | if show_help_info: 63 | print("x axis flip.") 64 | if np.random.rand() > 0.5: 65 | image = image[:, ::-1] # 沿y轴反转 66 | if show_help_info: 67 | print("y axis flip.") 68 | if do_trslt: 69 | # 平移(translation) 70 | x_trslt, y_trslt = np.random.randint(-20, 20, 2) # 每个轴的平移范围在[-20,20]像素之间 71 | transform.SetTranslation((float(x_trslt), float(y_trslt))) 72 | if show_help_info: 73 | print("x_trslt: ", x_trslt, " y_trslt: ", y_trslt) 74 | if do_scale: 75 | # 缩放(scale) 76 | x_scale = 1.0 + random.uniform(-0.1, 0.1) # 缩放范围为原来的[0.9,1.1] 77 | y_scale = 1.0 + random.uniform(-0.1, 0.1) 78 | # x_scale, y_scale表示原图与结果图的倍数关系,如scale为2时缩小为原来的0.5 79 | matrix[0, 0] = x_scale 80 | matrix[1, 1] = y_scale 81 | if show_help_info: 82 | print("x_scale: ", x_scale, " y_scale: ", y_scale) 83 | if do_rota: 84 | # 旋转(rotation) 85 | degree = np.random.randint(-15, 15) # 旋转角度范围为[-15°,15°] 86 | radians = -np.pi * degree / 180. 87 | rotation = np.array([[np.cos(radians), -np.sin(radians)], [np.sin(radians), np.cos(radians)]]) 88 | matrix = np.dot(rotation, matrix) 89 | if show_help_info: 90 | print("degree: ", degree) 91 | 92 | elif dim == 3: 93 | if do_bspline: 94 | # B样条变换(B Spline) 95 | m=5 96 | spline_order = 3 97 | bspline = sitk.BSplineTransform(dim, spline_order) 98 | bspline.SetTransformDomainPhysicalDimensions(image.GetSize()) 99 | mesh_size = [m, m, m] 100 | bspline.SetTransformDomainMeshSize(mesh_size) 101 | # Random displacement of the control points. 102 | # [13,18]为变形的强度,值越大变形越大 103 | originalControlPointDisplacements = np.random.random(len(bspline.GetParameters())) * np.random.randint(13,18) 104 | bspline.SetParameters(originalControlPointDisplacements) 105 | image = resample(image, bspline) 106 | if do_flip: 107 | # 翻转(flip) 108 | if np.random.rand() > 0.5: 109 | image = image[::-1] # 沿x轴反转 110 | if show_help_info: 111 | print("x axis flip.") 112 | if np.random.rand() > 0.5: 113 | image = image[:, ::-1] # 沿y轴反转 114 | if show_help_info: 115 | print("y axis flip.") 116 | if np.random.rand() > 0.5: 117 | image = image[:, :, ::-1] # 沿z轴反转 118 | if show_help_info: 119 | print("z axis flip.") 120 | if do_trslt: 121 | # 平移(translation) 122 | x_trslt, y_trslt, z_trslt = np.random.randint(-20, 20, 3) # 每个轴的平移范围在[-20,20]像素之间 123 | transform.SetTranslation((float(x_trslt), float(y_trslt), float(z_trslt))) 124 | if show_help_info: 125 | print("x_trslt: ", x_trslt, " y_trslt: ", y_trslt, " z_trslt: ", z_trslt) 126 | if do_scale: 127 | # 缩放(scale) 128 | x_scale = 1.0 + random.uniform(-0.1, 0.1) # 缩放范围为原来的[0.9,1.1] 129 | y_scale = 1.0 + random.uniform(-0.1, 0.1) 130 | z_scale = 1.0 + random.uniform(-0.1, 0.1) 131 | # x_scale, y_scale, z_scale表示原图与结果图的倍数关系,如scale为2时缩小为原来的0.5 132 | matrix[0, 0] = x_scale 133 | matrix[1, 1] = y_scale 134 | matrix[2, 2] = z_scale 135 | if show_help_info: 136 | print("x_scale: ", x_scale, " y_scale: ", y_scale, " z_scale: ", z_scale) 137 | if do_rota: 138 | # 旋转(rotation) 139 | x_dgr, y_dgr, z_dgr = np.random.randint(-15, 15, 3) # 旋转角度范围为[-15°,15°] 140 | x_rad = -np.pi * x_dgr / 180. 141 | y_rad = -np.pi * y_dgr / 180. 142 | z_rad = -np.pi * z_dgr / 180. 143 | rotation = np.array([[np.cos(y_rad) * np.cos(z_rad), np.cos(y_rad) * np.sin(z_rad), -np.sin(y_rad)], 144 | [-np.cos(x_rad) * np.sin(z_rad) + np.sin(x_rad) * np.sin(y_rad) * np.cos(z_rad), 145 | np.cos(x_rad) * np.cos(z_rad) + np.sin(x_rad) * np.sin(y_rad) * np.sin(z_rad), 146 | np.sin(x_rad) * np.cos(y_rad)], 147 | [np.sin(x_rad) * np.sin(z_rad) + np.cos(x_rad) * np.sin(y_rad) * np.cos(z_rad), 148 | -np.sin(x_rad) * np.cos(z_rad) + np.cos(x_rad) * np.sin(y_rad) * np.sin(z_rad), 149 | np.cos(x_rad) * np.cos(y_rad)]]) 150 | matrix = np.dot(rotation, matrix) 151 | if show_help_info: 152 | print("x_dgr: ", x_dgr, " y_dgr: ", y_dgr, " z_dgr: ", z_dgr) 153 | 154 | transform.SetMatrix(matrix.ravel()) 155 | # 以下两行是为了让图像的中心点为物体的中心点 156 | center = image.GetOrigin() + np.array(image.GetSize()) * np.array(image.GetSpacing()) * 0.5 157 | transform.SetCenter(center) 158 | out = resample(image, transform) 159 | 160 | # 保持和原图像相同的窗宽和窗位,同时做了归一化 161 | # 不加此步,则图像会在缩放和旋转的时候导致灰度值改变 162 | out_arr = sitk.GetArrayFromImage(out) 163 | min_window = float(window_center) - 0.5 * float(window_width) 164 | out_arr = (out_arr - min_window) / float(window_width) 165 | out_arr[out_arr < 0] = 0. 166 | out_arr[out_arr > 1] = 1. 167 | if dim == 2: 168 | return out_arr, dim 169 | elif dim == 3: 170 | return sitk.GetImageFromArray(out_arr), dim 171 | 172 | 173 | if __name__ == "__main__": 174 | path_2D = r"C:\Users\zuzhiang\Desktop\1.jpg" 175 | path_3D = r"C:\Data\1.Work\02.Code\DDR\results\f_img.nii" 176 | out_path_3D = r"C:\Users\zuzhiang\Desktop\out.nii" 177 | # out, dim = data_augmentation(path_3D) 178 | out, dim = data_augmentation(image=sitk.ReadImage(path_3D)) 179 | if dim == 2: 180 | plt.imshow(out) 181 | plt.title("result image") 182 | plt.show() 183 | elif dim == 3: 184 | sitk.WriteImage(out, out_path_3D) 185 | print("end") 186 | -------------------------------------------------------------------------------- /dcm2nii.py: -------------------------------------------------------------------------------- 1 | import os 2 | import SimpleITK as sitk 3 | 4 | 5 | def dcm2nii(dcm, nii): 6 | # GetGDCMSeriesIDs读取序列号相同的dcm文件 7 | series_id = sitk.ImageSeriesReader.GetGDCMSeriesIDs(dcm) 8 | # GetGDCMSeriesFileNames读取序列号相同dcm文件的路径,series[0]代表第一个序列号对应的文件 9 | series_file_names = sitk.ImageSeriesReader.GetGDCMSeriesFileNames(dcm, series_id[0]) 10 | print(len(series_file_names)) 11 | series_reader = sitk.ImageSeriesReader() 12 | series_reader.SetFileNames(series_file_names) 13 | image3d = series_reader.Execute() 14 | print("type: ",type(image3d)) 15 | sitk.WriteImage(image3d, nii) 16 | 17 | 18 | if __name__=="__main__": 19 | ''' 20 | dcm对应的文件夹下有很多子文件夹,每个子文件夹是一套dicom图像,将所有的dicom 21 | 图像转换为.nii格式的图像 22 | ''' 23 | dcm=input("dcm_dir:\n") 24 | nii = input("nii_dir:\n") 25 | for root,dirs,files in os.walk(dcm): 26 | print(dirs) 27 | break 28 | for dir in dirs: 29 | dcm_dir=dcm+"\\"+dir 30 | if not os.listdir(dcm_dir): #若文件夹为空则不处理 31 | continue 32 | nii_file=nii+"\\"+dir+".nii" 33 | dcm2nii(dcm_dir,nii_file) #将dicom转为nii -------------------------------------------------------------------------------- /image_combine_example.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import skimage.util as su 3 | import matplotlib.pyplot as plt 4 | 5 | 6 | def paste(img, canvas, i, j, method='replace', export_dtype='float'): 7 | """ paste `img` on `canvas` with its left-top corner at (i, j) """ 8 | # check dtypes 9 | img = su.img_as_float(img) 10 | canvas = su.img_as_float(canvas) 11 | # check shapes 12 | if len(img.shape) != 2 or len(img.shape) != 3: 13 | if len(canvas.shape) != 2 or len(canvas.shape) != 3: 14 | assert AttributeError('dimensions of input images not all equal to 2 or 3!') 15 | # check channels 16 | # all grayscale image 17 | if len(img.shape) == 2 and len(canvas.shape) == 2: 18 | pass 19 | # `img` color image, possible with alpha channel; `canvas` grayscale image 20 | elif len(img.shape) == 3 and len(canvas.shape) == 2: 21 | c = img.shape[-1] 22 | if c == 3: 23 | canvas = np.stack([canvas]*c, axis=-1) 24 | if c == 4: 25 | canvas = np.stack([canvas]*(c-1)+[np.ones((canvas.shape[0], canvas.shape[1]))], axis=-1) 26 | # `canvas` color image, possible with alpha channel; `img` grayscale image 27 | elif len(img.shape) == 2 and len(canvas.shape) == 3: 28 | c = canvas.shape[-1] 29 | if c == 3: 30 | img = np.stack([img]*c, axis=-1) 31 | if c == 4: 32 | img = np.stack([img]*(c-1)+[np.ones((img.shape[0], img.shape[1]))], axis=-1) 33 | # all color image 34 | elif len(img.shape) == 3 and len(canvas.shape) == 3: 35 | if img.shape[-1] == 3 and canvas.shape[-1] == 4: 36 | img = np.concatenate([img, np.ones((img.shape[0], img.shape[1], 1))], -1) 37 | elif img.shape[-1] == 4 and canvas.shape[-1] == 3: 38 | canvas = np.concatenate([canvas, np.ones((canvas.shape[0], canvas.shape[1], 1))], -1) 39 | elif img.shape[-1] == canvas.shape[-1]: 40 | pass 41 | else: 42 | assert ValueError('channel number should equal to 3 or 4!') 43 | # get shapes 44 | h_i, w_i = img.shape[:2] 45 | h_c, w_c = canvas.shape[:2] 46 | # find extent of `img` on `canvas` 47 | i_min = np.max([0, i]) 48 | i_max = np.min([h_c, i+h_i]) 49 | j_min = np.max([0, j]) 50 | j_max = np.min([w_c, j+w_i]) 51 | # paste `img` on `canvas` 52 | if method == 'replace': 53 | canvas[i_min:i_max, j_min:j_max] = img[i_min-i:i_max-i, j_min-j:j_max-j] 54 | elif method == 'add': 55 | canvas[i_min:i_max, j_min:j_max] += img[i_min-i:i_max-i, j_min-j:j_max-j] 56 | else: 57 | raise ValueError('no such method!') 58 | # return `canvas` 59 | if export_dtype == 'float': 60 | return canvas 61 | elif export_dtype == 'ubyte': 62 | return su.img_as_ubyte(canvas) 63 | else: 64 | raise ValueError('no such data type for exporting!') 65 | 66 | 67 | def combine_avg(imgs, num_w=10, strides=(10, 10), padding=5, bg_level_1=1.0, bg_level_2=1.0, export_dtype='float'): 68 | """ paste contents of `imgs` on a single image with `strides` """ 69 | # dtypes check 70 | imgs = [su.img_as_float(img) for img in imgs] 71 | # shapes check 72 | shapes = [img.shape for img in imgs] 73 | if not all([len(s) == 2 or len(s) == 3 for s in shapes]): 74 | assert AttributeError('dimensions of imgs not all 2 or 3!') 75 | # find the shape of canvas 76 | n = len(imgs) 77 | num_h = (n - 1) // num_w + 1 78 | h = strides[0]*(num_h-1) + shapes[-1][0] 79 | w = strides[1]*(num_w-1) + shapes[-1][1] 80 | lt_poses = [(strides[0]*i, strides[1]*j) for i in range(num_h) for j in range(num_w) if i*num_w+j 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /phi/.idea/deployment.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /phi/.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /phi/.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | -------------------------------------------------------------------------------- /phi/.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /phi/.idea/workspace.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 11 | 12 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 37 | 38 | 39 | 40 | 41 | 61 | 62 | 63 | 83 | 84 | 85 | 105 | 106 | 107 | 127 | 128 | 129 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 1592118601422 166 | 184 | 185 | 186 | 187 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | -------------------------------------------------------------------------------- /phi/f_img.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zuzhiang/MedicalImageProcess/4361482e453ff8f949f1df0383541b8f72f33a72/phi/f_img.jpg -------------------------------------------------------------------------------- /phi/img.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zuzhiang/MedicalImageProcess/4361482e453ff8f949f1df0383541b8f72f33a72/phi/img.jpg -------------------------------------------------------------------------------- /phi/img1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zuzhiang/MedicalImageProcess/4361482e453ff8f949f1df0383541b8f72f33a72/phi/img1.jpg -------------------------------------------------------------------------------- /phi/img2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zuzhiang/MedicalImageProcess/4361482e453ff8f949f1df0383541b8f72f33a72/phi/img2.jpg -------------------------------------------------------------------------------- /phi/m_img.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zuzhiang/MedicalImageProcess/4361482e453ff8f949f1df0383541b8f72f33a72/phi/m_img.jpg -------------------------------------------------------------------------------- /phi/my_show_grid.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import SimpleITK as sitk 3 | import matplotlib.pyplot as plt 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | # 空间转换网络 11 | class SpatialTransformer(nn.Module): 12 | # 1.生成网格grid;2.new_grid=grid+flow,即旧网格加上一个位移;3.将网格规范化到[-1,1];4.根据新网格对原图进行采样 13 | def __init__(self, size, mode='bilinear'): 14 | """ 15 | Instiatiate the block 16 | :param size: size of input to the spatial transformer block 17 | :param mode: method of interpolation for grid_sampler 18 | """ 19 | super(SpatialTransformer, self).__init__() 20 | 21 | # Create sampling grid 22 | vectors = [torch.arange(0, s) for s in size] 23 | grids = torch.meshgrid(vectors) 24 | grid = torch.stack(grids) # y, x, z 25 | grid = torch.unsqueeze(grid, 0) # add batch 26 | grid = grid.type(torch.FloatTensor) 27 | self.register_buffer('grid', grid) 28 | 29 | self.mode = mode 30 | 31 | def forward(self, src, flow): 32 | """ 33 | Push the src and flow through the spatial transform block 34 | :param src: the original moving image 35 | :param flow: the output from the U-Net 36 | """ 37 | new_locs = self.grid + flow 38 | 39 | shape = flow.shape[2:] 40 | 41 | # Need to normalize grid values to [-1, 1] for resampler 42 | for i in range(len(shape)): 43 | new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[i] - 1) - 0.5) 44 | 45 | if len(shape) == 2: 46 | new_locs = new_locs.permute(0, 2, 3, 1) # 维度置换,变为0,2,3,1 47 | new_locs = new_locs[..., [1, 0]] 48 | elif len(shape) == 3: 49 | new_locs = new_locs.permute(0, 2, 3, 4, 1) 50 | new_locs = new_locs[..., [2, 1, 0]] 51 | 52 | return F.grid_sample(src, new_locs, mode=self.mode) 53 | 54 | 55 | # 生成网格图片 56 | def create_grid(size, path): 57 | num1, num2 = (size[0] + 10) // 10, (size[1] + 10) // 10 # 改变除数(10),即可改变网格的密度 58 | x, y = np.meshgrid(np.linspace(-2, 2, num1), np.linspace(-2, 2, num2)) 59 | 60 | plt.figure(figsize=((size[0] + 10) / 100.0, (size[1] + 10) / 100.0)) # 指定图像大小 61 | plt.plot(x, y, color="black") 62 | plt.plot(x.transpose(), y.transpose(), color="black") 63 | plt.axis('off') # 不显示坐标轴 64 | # 去除白色边框 65 | plt.gca().xaxis.set_major_locator(plt.NullLocator()) 66 | plt.gca().yaxis.set_major_locator(plt.NullLocator()) 67 | plt.subplots_adjust(top=1, bottom=0, left=0, right=1, hspace=0, wspace=0) 68 | plt.margins(0, 0) 69 | plt.savefig(path) # 保存图像 70 | # plt.show() 71 | 72 | 73 | if __name__ == "__main__": 74 | out_path = r"C:\Users\zuzhiang\Desktop\new_img.nii" # 图片保存路径 75 | # 读入形变场 76 | phi = sitk.ReadImage("./2D1.nii") # [324,303,2] 77 | phi_arr = torch.from_numpy(sitk.GetArrayFromImage(phi)).float() 78 | phi_shape = phi_arr.shape 79 | # 产生网格图片 80 | create_grid(phi_shape, out_path) 81 | img = sitk.GetArrayFromImage(sitk.ReadImage(out_path))[..., 0] 82 | img = np.squeeze(img)[np.newaxis, np.newaxis, :phi_shape[0], :phi_shape[1]] 83 | # 用STN根据形变场对网格图片进行变形 84 | STN = SpatialTransformer(phi_shape[:2]) 85 | phi_arr = phi_arr.permute(2, 0, 1)[np.newaxis, ...] 86 | warp = STN(torch.from_numpy(img).float(), phi_arr) 87 | # 保存图片 88 | warp_img = sitk.GetImageFromArray(warp[0, 0, ...].numpy().astype(np.uint8)) 89 | sitk.WriteImage(warp_img, out_path) 90 | print("end") 91 | -------------------------------------------------------------------------------- /phi/show_grid.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import SimpleITK as sitk 3 | import numpy as np 4 | 5 | 6 | def grid2contour(grid, title): 7 | ''' 8 | grid--image_grid used to show deform field 9 | type: numpy ndarray, shape: (h, w, 2), value range:(-1, 1) 10 | ''' 11 | assert grid.ndim == 3 12 | x = np.arange(-1, 1, 2.0 / grid.shape[1]) 13 | y = np.arange(-1, 1, 2.0 / grid.shape[0]) 14 | X, Y = np.meshgrid(x, y) 15 | Z1 = grid[:, :, 0] + 2 # remove the dashed line 16 | Z1 = Z1[::-1] # vertical flip 17 | Z2 = grid[:, :, 1] + 2 18 | 19 | plt.figure() 20 | plt.contour(X, Y, Z1, 15, levels=50, colors='k') #改变levels的值,可以改变形变场的外貌 21 | plt.contour(X, Y, Z2, 15, levels=50, colors='k') 22 | plt.xticks(()), plt.yticks(()) # remove x, y ticks 23 | plt.title(title) 24 | plt.show() 25 | 26 | 27 | def show_grid(): 28 | img = sitk.ReadImage(r"C:\Users\zuzhiang\Desktop\7_flow.nii.gz") 29 | img_arr = sitk.GetArrayFromImage(img)[:,:,0,:2] 30 | img_shape = img_arr.shape 31 | print("shape: ", img_shape) 32 | 33 | # 起点、终点、步长(可为小数) 34 | x = np.arange(-1, 1, 2 / img_shape[1]) 35 | y = np.arange(-1, 1, 2 / img_shape[0]) 36 | X, Y = np.meshgrid(x, y) 37 | regular_grid = np.stack((X, Y), axis=2) 38 | grid2contour(regular_grid, "regular_grid") 39 | 40 | rand_field = np.random.rand(*img_shape[:2], 2) # 参数前加*是以元组形式导入 41 | rand_field_norm = rand_field.copy() 42 | rand_field_norm[:, :, 0] = rand_field_norm[:, :, 0] * 2 / img_shape[1] 43 | rand_field_norm[:, :, 1] = rand_field_norm[:, :, 1] * 2 / img_shape[0] 44 | sampling_grid = regular_grid + rand_field_norm 45 | grid2contour(sampling_grid, "sampling_grid") 46 | 47 | img_arr[..., 0] = img_arr[..., 0] * 2 / img_shape[1] 48 | img_arr[..., 1] = img_arr[..., 1] * 2 / img_shape[0] 49 | img_grid = regular_grid + img_arr 50 | grid2contour(img_grid, "img_grid") 51 | 52 | 53 | if __name__ == "__main__": 54 | show_grid() 55 | print("end") 56 | -------------------------------------------------------------------------------- /resize_img.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import numpy as np 4 | import SimpleITK as sitk 5 | from skimage.transform import resize 6 | import models.data_augmentation as da 7 | 8 | path = r"C:\Data\1.Work\03.DataSet\LPBA40" 9 | out_path = r"C:\Data\1.Work\03.DataSet\LPBA40_DA" 10 | #file_lst = glob.glob(os.path.join(path, '*/*bfc.nii')) 11 | file_lst=[r"C:\Data\1.Work\02.Code\DDR\fixed_LPBA40.nii"] 12 | num = len(file_lst) 13 | # for file in file_lst: 14 | # img, dim = da.data_augmentation(file, do_bspline=False, do_trslt=True, do_scale=True, do_rota=True, 15 | # show_help_info=True) 16 | # _, name = os.path.split(file) 17 | # print("path: ", os.path.join(out_path, name)) 18 | # sitk.WriteImage(img, os.path.join(out_path, name)) 19 | 20 | for file in file_lst: 21 | # img, dim = da.data_augmentation(file, do_bspline=False, do_trslt=False, do_scale=False, do_rota=False, 22 | # show_help_info=True) 23 | img = sitk.ReadImage(file) 24 | img_arr = sitk.GetArrayFromImage(img) 25 | new_shape = [256, 128, 256] 26 | new_img = resize(img_arr, new_shape, order=3, mode='constant', cval=0, preserve_range=bool) 27 | # order几次样条插值,cval外部补零,保留原数据(否则被默认标准化、归一化之类的乱七八糟) 28 | ''' 29 | The order of interpolation. The order has to be in the range 0-5: 30 | 0: Nearest-neighbor 31 | 1: Bi-linear (default) 32 | 2: Bi-quadratic 33 | 3: Bi-cubic 34 | 4: Bi-quartic 35 | 5: Bi-quintic 36 | ''' 37 | new_img = sitk.GetImageFromArray(new_img) 38 | new_img.SetOrigin(img.GetOrigin()) 39 | new_img.SetSpacing(img.GetSpacing()) 40 | new_img.SetDirection(img.GetDirection()) 41 | print("file: ", file) 42 | #temp_path = r"C:\Users\zuzhiang\Desktop\out.nii" 43 | sitk.WriteImage(new_img, file) 44 | -------------------------------------------------------------------------------- /torch_datagenerators.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import numpy as np 3 | import SimpleITK as sitk 4 | import torch.utils.data as Data 5 | 6 | ''' 7 | 通过继承Data.Dataset,实现将一组Tensor数据对封装成Tensor数据集 8 | 至少要重载__init__,__len__和__getitem__方法 9 | ''' 10 | class Dataset(Data.Dataset): 11 | def __init__(self, path): 12 | # 初始化 13 | self.files = glob.glob(path) 14 | 15 | def __len__(self): 16 | # 返回数据集的大小 17 | return len(self.files) 18 | 19 | def __getitem__(self, index): 20 | # 索引数据集中的某个数据,还可以对数据进行预处理 21 | # 下标index参数是必须有的,名字任意 22 | img_arr = sitk.GetArrayFromImage(sitk.ReadImage(self.files[index]))[..., np.newaxis] 23 | return img_arr --------------------------------------------------------------------------------