├── README.md ├── get_small_npy.py ├── global_.py ├── global_annos.py ├── mask_xml.py ├── mask_xml_def.py ├── train.py ├── train_def.py ├── vnet.py └── vnet_def.py /README.md: -------------------------------------------------------------------------------- 1 | https://www.yuque.com/u41648611/qg2m6n/kao3f71fry3iukbs?singleDoc 《运行vnet操作记录》包含可下载数据 2 | -------------------------------------------------------------------------------- /get_small_npy.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import os 3 | 4 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # 屏蔽通知和警告信息 5 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" # 使用gpu0 6 | 7 | from train_def import * 8 | import torch.utils.data 9 | 10 | BATCH_SIZE = 1 11 | EPOCH = 1 12 | 13 | torch.cuda.empty_cache() # 时不时清下内存 14 | 15 | data_path = [] # 装图所在subset的绝对地址,如 [D:\datasets\sk_output\bbox_image\subset0,D:\datasets\sk_output\bbox_image\subset1,..] 16 | label_path = [] # 装标签所在subset的绝对地址,与上一行一致,为对应关系 17 | for i in range(0,10): # 0,1,2,3,4,5,6,7 训练集 18 | data_path.append(str(Path(bbox_img_path)/f'subset{i}')) # 放入对应的训练集subset的绝对地址 19 | label_path.append(str(Path(bbox_msk_path)/f'subset{i}')) 20 | dataset_train = cutDataset(data_path, label_path) # 送入dataset 21 | print(len(dataset_train)) 22 | train_loader = torch.utils.data.DataLoader(dataset_train, # 生成dataloader 23 | batch_size=BATCH_SIZE, shuffle=False, 24 | num_workers=0)#16) # 警告页面文件太小时可改为0 25 | print("train_dataloader_ok") 26 | 27 | all_msg_list = [] 28 | for epoch in range(1, EPOCH + 1): # 每一个epoch 训练一轮 检测一轮 29 | tqdr = tqdm(enumerate(train_loader)) # 用一下tqdm函数,也就是进度条工具(枚举) 30 | 31 | for batch_index, one_list in tqdr: 32 | all_msg_list.append([i[0] for i in one_list]) 33 | df = pd.DataFrame(all_msg_list, columns=['img_path', 'lbl_path','msg']) # msg是结节的中心 z,y,x 34 | df.to_excel(msg_path) 35 | -------------------------------------------------------------------------------- /global_.py: 
-------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import torch 3 | 4 | # 修改区域 5 | luna_path = r"G:\datasets\LUNA16" 6 | xml_file_path = r'G:\datasets\LIDC-IDRI\LIDC-XML-only\tcia-lidc-xml' 7 | annos_csv = r'G:\datasets\LUNA16\CSVFILES\annotations.csv' 8 | new_bbox_annos_path = r"G:\datasets\sk_output\bbox_annos\bbox_annos.xlsx" 9 | mask_path = r'G:\datasets\LUNA16\seg-lungs-LUNA16' 10 | output_path = r"G:\datasets\sk_output" 11 | bbox_img_path = r"G:\datasets\sk_output\bbox_image" 12 | bbox_msk_path = r"G:\datasets\sk_output\bbox_mask" 13 | wrong_img_path = r"G:\datasets\wrong_img.xlsx" 14 | zhibiao_path = r'G:\datasets\sk_output\zhibiao' 15 | model_path = r'G:\datasets\sk_output\model' 16 | msg_path = r'G:\datasets\sk_output\msgs.xlsx' 17 | 18 | # 训练设置 19 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 没gpu就用cpu 20 | valid_epoch_each = 5 # 每几轮验证一次 21 | 22 | # 建立文件夹结构 23 | Path(new_bbox_annos_path).parent.mkdir(exist_ok=True, parents=True) 24 | Path(bbox_img_path).mkdir(exist_ok=True, parents=True) 25 | Path(bbox_msk_path).mkdir(exist_ok=True, parents=True) 26 | Path(model_path).mkdir(exist_ok=True, parents=True) 27 | Path(zhibiao_path).mkdir(exist_ok=True, parents=True) 28 | -------------------------------------------------------------------------------- /global_annos.py: -------------------------------------------------------------------------------- 1 | ## 公共区域函数 2 | import pandas as pd 3 | import numpy as np 4 | from global_ import new_bbox_annos_path 5 | 6 | 7 | def annos(): # 收集有结节图的名字 8 | annos = pd.read_excel(new_bbox_annos_path) # 读取bbox_annos.xls 9 | a = [] 10 | for ind,val in annos.iterrows(): 11 | if val['annos'] != '[]': 12 | a.append(list(val)[1:]) #(val['name']) 13 | return a # 返回所有 有结节的图名 14 | annos = annos() 15 | -------------------------------------------------------------------------------- /mask_xml.py: 
-------------------------------------------------------------------------------- 1 | from global_ import * 2 | from mask_xml_def import for_one_,read_data,resample,bbox_annos_ 3 | from global_annos import * 4 | import numpy as np 5 | import pandas 6 | import os 7 | """ 8 | 功能:数据预处理主函数 9 | 包含生成信息表格和具体预处理操作 10 | bbox_annos_() 主函数 11 | """ 12 | 13 | 14 | 15 | anno_name_list = annos # 有结节图的名字 16 | print(len(anno_name_list)) 17 | wrony = [] # 染色失败的标签 18 | 19 | for name in anno_name_list: # 遍历有结节图的名字 20 | mask,ct_image_path,wrony = for_one_(name,wrony) # 输入:单图,空列表wrong, 输出:单图染色mask,mhd文件的绝对地址,错图名字列表wrong 21 | path = ct_image_path.split("LUNA16")[1].split(".m")[0] # 取LUNA16后,.mhd前的字符串 22 | # 如 D:\datasets\LUNA16\subset1\1.3.6.1.4.1.14519.5.2.1.6279.6001.287966244644280690737019247886.mhd 23 | # 则 path = \subset1\1.3.6.1.4.1.14519.5.2.1.6279.6001.287966244644280690737019247886 一会要用 24 | ct_image, origin, spacing, isflip = read_data(ct_image_path) # 读取mhd文件得到图 25 | ct_image1, origin1, spacing1, isflip1 = read_data(str(Path(mask_path)/(name+".mhd"))) # 读取肺部掩膜 26 | ct_image1[ct_image1>1]=1 # LUNA16提供的肺部掩膜分左右肺,左肺为3右肺为4,我们需要统一为1 27 | ct_image = ct_image * ct_image1 # 图与肺mask相乘,肺外区域归0 28 | image = resample(ct_image, spacing) # 图 重采样 29 | msk = resample(mask, spacing) # 标签 重采样,标签就弄好了。 30 | print(image.shape) # 这俩一样大 31 | print(msk.shape) # 这俩一样大 32 | # LUNA16竞赛中常用来做归一化处理的阈值集是-1000和400 33 | max_num = 400 # 阈值最高 34 | min_num = -1000 # 阈值最低 35 | image = (image - min_num) / (max_num - min_num) # 归一化公式 36 | image[image > 1] = 1. # 高于1的归1,float格式 37 | image[image < 0] = 0. 
# 低于0的归0,float格式 38 | ## LUNA16竞赛中的均值大约是0.25 39 | img = image - 0.25 # 去均值,图也弄好了 40 | 41 | Path(output_path).mkdir(exist_ok=True,parents=True) 42 | Path(str(Path(output_path)/"bbox_image")).mkdir(exist_ok=True,parents=True) 43 | Path(str(Path(output_path)/"bbox_mask")).mkdir(exist_ok=True,parents=True) 44 | sub_path = Path(ct_image_path).parent.name # 取LUNA16后,.mhd前的字符串 45 | # 如 D:\datasets\LUNA16\subset1\1.3.6.1.4.1.14519.5.2.1.6279.6001.287966244644280690737019247886.mhd 46 | # 则 path = \subset1\ 47 | Path(str(Path(output_path)/("bbox_image"+sub_path))).mkdir(exist_ok=True,parents=True) 48 | Path(str(Path(output_path)/("bbox_mask"+sub_path))).mkdir(exist_ok=True,parents=True) 49 | 50 | 51 | np.save(str(Path(output_path)/("bbox_image"+path)),img) # 图存到 如 D:\datasets\sk_output\bbox_image\subset1\1.3.6.1.4.1.14519.5.2.1.6279.6001.287966244644280690737019247886.npy 52 | np.save(str(Path(output_path)/("bbox_mask"+path)),msk) # 标签存到 如 D:\datasets\sk_output\bbox_mask\subset1\1.3.6.1.4.1.14519.5.2.1.6279.6001.287966244644280690737019247886.npy 53 | del ct_image,ct_image1,image,msk 54 | 55 | wrong_img = pandas.DataFrame(wrony) # 保存 未染色图 56 | wrong_img.to_excel(wrong_img_path) # 保存到 wrong_img_path 57 | print("wrony",wrony) # 打印看一下是否为空 58 | -------------------------------------------------------------------------------- /mask_xml_def.py: -------------------------------------------------------------------------------- 1 | from global_ import * 2 | import xml 3 | from xml.dom.minidom import parse # 用来读xml文件的 4 | import numpy as np 5 | import cv2 as cv # 用来做形态填充的 6 | import scipy 7 | from scipy import ndimage # 用来补洞的 8 | import SimpleITK as sitk 9 | import matplotlib.pyplot as plt 10 | import os 11 | import pandas as pd 12 | from scipy.ndimage.interpolation import zoom 13 | import xlrd 14 | """ 15 | 功能:数据预处理主函数 16 | 包含生成信息表格和具体预处理操作 17 | bbox_annos_() 主函数 18 | """ 19 | 20 | # 在开始之前,由于我们需要只保留有结节的图像,因此从annotations中取得有结节的图像名字。①进行结节中心的坐标变换,并保存到 bbox_annos 21 | # ②提取结节不为空的图的名字 
# (continuation of the comment above) ... the annos() helper lives in the main module.

def get_raw_label(img_name, img, annos, origin, zheng):
    """Convert one scan's nodule annotations from world to image coordinates.

    Args:
        img_name: Series UID of the scan (first column of the annotation table).
        img: The scan volume; only ``img.shape`` (z, y, x) is read, and only in
            the flip branch.
        annos: Rows of ``[name, x, y, z, diameter]`` (the LUNA16
            ``annotations.csv`` loaded as an array).
        origin: World coordinates (x, y, z) of the volume origin.
        zheng: Orientation flag produced by ``read_data``.

    Returns:
        A list with one ``[x, y, z, radius]`` entry per nodule of this scan.
        Every value is a built-in ``float`` so that the text written to the
        spreadsheet (``str`` of this list) does not depend on how the installed
        NumPy prints its scalar types.
    """
    origin_xyz = np.asarray(origin, dtype=float)

    # Keep only the annotation rows that belong to this scan.
    annos_these_list = [list(row) for row in annos if row[0] == img_name]
    print(annos_these_list)

    return_list = []
    for one_annos_list in annos_these_list:
        print("one_annos_list:", one_annos_list)
        w_center = [one_annos_list[1], one_annos_list[2], one_annos_list[3]]
        print("世界坐标的 结节中心(x,y,z) ", w_center)
        # Offset from the origin in millimetres. It is not divided by the voxel
        # spacing because the volumes are resampled to 1 mm spacing afterwards.
        # abs() is used because the offset can be negative.
        offset = np.abs(np.asarray(w_center, dtype=float) - origin_xyz)
        v_center = [float(value) for value in offset]
        print("图像坐标的 结节中心(x,y,z) ", v_center)
        # NOTE(review): with the read_data() visible in this file the flag is
        # either True or the raw header string, so this branch never fires.
        # Confirm the intended behaviour before changing the test.
        if zheng is False:
            v_center = [img.shape[2] - 1 - v_center[0],
                        img.shape[1] - 1 - v_center[1],
                        v_center[2]]
        diam = one_annos_list[4]
        print("结节直径", diam)
        one_annos = [v_center[0], v_center[1], v_center[2], float(diam) / 2]
        return_list.append(one_annos)
        print("one_annos:", one_annos, "[坐标(x,y,z)]")
    return return_list

def bbox_annos_():
    """Write the bbox_annos spreadsheet: image-space nodules for every scan.

    Walks ``subset0`` .. ``subset9`` under ``luna_path`` (all ten subsets are
    assumed to be present), converts the annotations of every ``.mhd`` scan
    with ``get_raw_label`` and saves ``[name, annos]`` rows to
    ``new_bbox_annos_path``.
    """
    c = np.array(pd.read_csv(annos_csv))  # annotations.csv as an array
    d = []  # rows of [scan name, [[x, y, z, radius], ...]]
    for i in range(10):
        subset_dir = Path(luna_path) / f"subset{i}"
        for ii in os.listdir(str(subset_dir)):
            if not ii.endswith(".mhd"):
                continue
            scan_name = ii[:-len(".mhd")]
            # The file is already in hand, so build its path directly instead
            # of searching every subset again for each scan.
            ct_image_path = str(subset_dir / ii)
            print(ct_image_path)
            numpyImage, origin, spacing, fanzhuan = read_data(ct_image_path)
            one_annos = get_raw_label(scan_name, numpyImage, c, origin, fanzhuan)
            d.append([scan_name, one_annos])
    bbox_annos = pd.DataFrame(d, columns=['name', 'annos'])
    bbox_annos.to_excel(new_bbox_annos_path)


def name(xml_path):
    """Return the series UID stored in an annotation XML file.

    Args:
        xml_path: Path of the XML file, or an already opened file object.

    Returns:
        The text of ``SeriesInstanceUid``, or of ``CTSeriesInstanceUid`` when
        the former is absent (the dataset uses both tags).

    Raises:
        IndexError: If the file has no ``ResponseHeader`` or neither UID tag.
    """
    dom = xml.dom.minidom.parse(xml_path)
    header = dom.documentElement.getElementsByTagName("ResponseHeader")[0]
    for tag in ("SeriesInstanceUid", "CTSeriesInstanceUid"):
        nodes = header.getElementsByTagName(tag)
        if nodes:
            return nodes[0].childNodes[0].nodeValue
    raise IndexError("no SeriesInstanceUid / CTSeriesInstanceUid in ResponseHeader")


def find_xml_path(name1):
    """Return the paths of all annotation XML files of scan ``name1``.

    Folders under ``xml_file_path`` are searched one by one, and the search
    stops at the first folder containing at least one match.

    Returns:
        A list of absolute paths; an empty list when nothing matches, so
        callers can iterate the result unconditionally.
    """
    root = Path(xml_file_path)
    for file_list in os.listdir(xml_file_path):
        print(file_list)  # progress
        list1 = []
        for ii in os.listdir(str(root / file_list)):
            aim_path = str(root / file_list / ii)
            # Hand the path to the parser so it reads the bytes itself; a
            # text-mode handle would be decoded with the locale encoding first.
            if name(aim_path) == name1:
                list1.append(aim_path)
                print(aim_path)
        if list1:
            return list1
    return []


def find_mhd_path(name1):
    """Return the absolute path of ``<name1>.mhd`` inside a LUNA16 subset.

    Only entries of ``luna_path`` whose name contains ``"subset"`` are searched,
    which keeps the lookup out of the ``seg-lungs-LUNA16`` folder.

    Returns:
        The path as a string, or ``None`` when no subset holds the file.
    """
    target = f"{name1}.mhd"
    for file_list in os.listdir(luna_path):
        if "subset" not in file_list:
            continue
        candidate = Path(luna_path) / file_list / target
        if candidate.is_file():
            path = str(candidate)
            print(path)
            return path
    return None

# Example:
# one_name = "1.3.6.1.4.1.14519.5.2.1.6279.6001.287966244644280690737019247886"
# find_xml_path(one_name)
# find_mhd_path(one_name)


def point(xml_path, origin2):
    """Collect the nodule outlines of one annotation XML file.

    Args:
        xml_path: Path of the XML file.
        origin2: World z coordinate of the volume origin.

    Returns:
        A list of ``[z, pts]`` pairs, one per ``roi``. ``z`` is the
        ``imageZposition`` minus ``origin2``. ``pts`` is an ``(n, 2)`` int32
        array of ``[x, y]`` edge points; x and y are used as stored in the
        file. int32 is explicit because OpenCV's contour functions require
        32-bit points and NumPy's default integer differs between platforms.
    """
    a = []
    root = xml.dom.minidom.parse(xml_path).documentElement
    for session in root.getElementsByTagName("readingSession"):
        for nodule in session.getElementsByTagName("unblindedReadNodule"):
            nodule_id = nodule.getElementsByTagName("noduleID")[0].childNodes[0].nodeValue
            if not nodule_id:
                continue
            for roi in nodule.getElementsByTagName("roi"):
                z_node = roi.getElementsByTagName("imageZposition")[0]
                z1 = float(z_node.childNodes[0].nodeValue) - origin2
                xs = [int(k.childNodes[0].nodeValue) for k in roi.getElementsByTagName("xCoord")]
                ys = [int(k.childNodes[0].nodeValue) for k in roi.getElementsByTagName("yCoord")]
                # [[x1, x2, ...], [y1, y2, ...]] -> [[x1, y1], [x2, y2], ...]
                point1 = np.transpose(np.array([xs, ys], dtype=np.int32))
                a.append([z1, point1])
    return a

# Header value of an axis-aligned, non-flipped scan.
_IDENTITY_TRANSFORM = ['1', '0', '0', '0', '1', '0', '0', '0', '1']


def read_data(mhd_file):
    """Read a .mhd scan: voxel array, origin, spacing and orientation flag.

    Args:
        mhd_file: Path of the ``.mhd`` header file.

    Returns:
        ``(numpyImage, origin, spacing, zheng)``. ``numpyImage`` is indexed
        (z, y, x); ``origin`` and ``spacing`` are (x, y, z) tuples. ``zheng`` is
        ``True`` when the header's ``TransformMatrix`` is the identity, or when
        the header has no such line (treated as identity here). Otherwise it is
        the raw header value.

    NOTE(review): for a non-identity matrix the raw string is returned, as
    before. It is truthy and never ``is False``, so callers testing
    ``is False`` never see a flipped scan. Confirm the intended contract
    before turning this into a real boolean.
    """
    zheng = True
    with open(mhd_file) as f:
        for line in f:
            if line.startswith('TransformMatrix'):
                raw_value = line.split(' = ')[1]
                # Compare tokens so trailing whitespace cannot break the match.
                zheng = True if raw_value.split() == _IDENTITY_TRANSFORM else raw_value
    itkimage = sitk.ReadImage(mhd_file)
    numpyImage = sitk.GetArrayFromImage(itkimage)
    print("读取数据,读取的图片大小(zyx):", numpyImage.shape)
    origin = itkimage.GetOrigin()
    print("读取数据,读取的坐标原点(xyz):", origin)
    spacing = itkimage.GetSpacing()
    print("读取数据,读取的像素间隔(xyz):", spacing)
    return numpyImage, origin, spacing, zheng


def for_one_(name, wrong):
    """Paint the nodule mask of one scan from its XML outlines.

    Args:
        name: Series UID of the scan.
        wrong: List collecting scans whose mask stayed empty; it is appended
            to in place and also returned.

    Returns:
        ``(mask, mhd_path, wrong)``. ``mask`` is an int32 volume with the same
        (z, y, x) shape as the scan, 1 inside outlines and 0 elsewhere.
    """
    xml_path_list = find_xml_path(name) or []  # tolerate a "nothing found" result
    ct_image_path = find_mhd_path(name)

    ct_image, origin, spacing, fanzhuan = read_data(ct_image_path)
    mm = np.zeros(ct_image.shape, dtype=np.int32)  # all-zero mask, same zyx layout

    for xml_path in xml_path_list:
        list1 = point(xml_path, origin[2])  # [[z offset, outline points], ...]
        print(len(list1))
        for ceng, pts in list1:
            print("ceng", ceng)
            # ceng is a distance along z, not a slice number; divide by the z
            # spacing to get the slice.
            # NOTE(review): the "- 1" shifts the slice down by one, and a
            # result of -1 addresses the last slice. Confirm this offset
            # against the data before changing it.
            z = int(ceng / spacing[2] - 1)
            contour = np.asarray(pts, dtype=np.int32)  # OpenCV needs 32-bit points
            mm[z, :, :] = cv.drawContours(mm[z, :, :], [contour], -1, color=1, thickness=-1)
            mm[z, :, :] = scipy.ndimage.binary_fill_holes(mm[z, :, :], structure=None, output=None, origin=0)
    if not mm.any():  # nothing was painted
        wrong.append(name)
    return mm, ct_image_path, wrong

# Example:
# one_name = "1.3.6.1.4.1.14519.5.2.1.6279.6001.287966244644280690737019247886"
# a,b,c = for_one_(one_name,wrong=[])
# print("a",a.shape,"b",b,"c",c)


# Visual check
def plot_2d(image, z=132):
    """Show slice ``z`` of a (z, y, x) volume with matplotlib."""
    plt.figure()
    plt.imshow(image[z, :, :])
    plt.show()
# z = 240
# plot_2d(a,z=int(z/2.5))
# ct_image_path = find_mhd_path(one_name)
# ct_image,origin,spacing,fanzhuan = read_data(ct_image_path)
# plot_2d(ct_image,z = int(z/2.5))

# At this point a single scan's mask can be produced and it lines up with the
# scan. For preprocessing, both scan and mask are resampled; only the scan is
# normalised and mean-shifted.


def resample(imgs, spacing, new_spacing=(1, 1, 1), order=3):
    """Resample a 3-D volume to a new voxel spacing.

    Args:
        imgs: Volume indexed (z, y, x).
        spacing: Current voxel spacing in (x, y, z) order.
        new_spacing: Target voxel spacing in (x, y, z) order.
        order: Spline order passed to ``zoom``. The default 3 is what ``zoom``
            used before this parameter existed; pass 0 to keep a label volume
            strictly within its original values.

    Returns:
        The resampled volume. Each axis has
        ``round(size * spacing / new_spacing)`` samples, e.g. 4 slices at
        2.5 spacing become 10 slices at spacing 1.

    Raises:
        ValueError: If ``imgs`` is not 3-dimensional.
    """
    if len(imgs.shape) != 3:
        raise ValueError('wrong shape')

    new_shape = []
    for i in range(3):  # 0, 1, 2 -> z, y, x
        # spacing is (x, y, z); index -i-1 walks it in (z, y, x) order.
        print("(zyx)像素间隔", i, ":", spacing[-i-1])
        new_shape.append(np.round(imgs.shape[i] * spacing[-i-1] / new_spacing[-i-1]))
    print("(zyx)新图大小:", new_shape)
    # Per-axis ratio of new size to old size, as zoom expects.
    resize_factor = [new_shape[i] / imgs.shape[i] for i in range(3)]
    return zoom(imgs, resize_factor, order=order, mode='nearest')
resize_factor.append(resize_zyx) # 放缩比例 存入 resize_factor ,zoom函数要用 208 | imgs = zoom(imgs, resize_factor, mode = 'nearest') # 放缩,边缘使用最近邻,插值默认为三线性插值 209 | return imgs 210 | else: 211 | raise ValueError('wrong shape') # 本代码只能处理3维数据 212 | 213 | bbox_annos_() 214 | 215 | 216 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # 屏蔽通知和警告信息 4 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" # 使用gpu0 5 | 6 | 7 | from global_ import * 8 | from global_annos import * 9 | from train_def import * 10 | from vnet import VNet 11 | import time 12 | import torch.utils.data 13 | import torch.optim as optim 14 | 15 | 16 | BATCH_SIZE = 4 # 2 17 | EPOCH = 200 # 共跑200轮 18 | 19 | 20 | print(DEVICE) 21 | 22 | model = VNet(2) # 模型 23 | model = model.to(DEVICE) # 模型部署到gpu或cpu里 24 | 25 | torch.cuda.empty_cache() # 时不时清下内存 26 | 27 | 28 | ###### 数据准备 29 | 30 | data_path = [] # 装图所在subset的绝对地址,如 [D:\datasets\sk_output\bbox_image\subset0,D:\datasets\sk_output\bbox_image\subset1,..] 
label_path = []  # label directories, index-aligned with data_path
for i in range(0, 8):  # subsets 0-7: training split
    data_path.append(str(Path(bbox_img_path) / f'subset{i}'))
    label_path.append(str(Path(bbox_msk_path) / f'subset{i}'))
dataset_train = myDataset(data_path, label_path)
print(len(dataset_train))
train_loader = torch.utils.data.DataLoader(dataset_train,
                                           batch_size=BATCH_SIZE, shuffle=False,
                                           num_workers=0)  # raise (e.g. 16) if the machine allows it
print("train_dataloader_ok")


data_valid_path = []
label_valid_path = []
for j in range(8, 9):  # subset 8: validation split
    data_valid_path.append(str(Path(bbox_img_path) / f'subset{j}'))
    label_valid_path.append(str(Path(bbox_msk_path) / f'subset{j}'))
dataset_valid = myDataset(data_valid_path, label_valid_path)
valid_loader = torch.utils.data.DataLoader(dataset_valid,
                                           batch_size=BATCH_SIZE, shuffle=False,
                                           num_workers=0)
print("valid_dataloader_ok")

data_test_path = []
label_test_path = []
for ii in range(9, 10):  # subset 9: test split
    data_test_path.append(str(Path(bbox_img_path) / f'subset{ii}'))
    label_test_path.append(str(Path(bbox_msk_path) / f'subset{ii}'))
dataset_test = myDataset(data_test_path, label_test_path)
test_loader = torch.utils.data.DataLoader(dataset_test,
                                          batch_size=BATCH_SIZE, shuffle=False,
                                          num_workers=0)
print("Test_dataloader_ok")


###### Data is ready; start training

start = time.perf_counter()  # training start time

train_loss_list = []  # one entry per epoch
valid_loss_list = []  # one entry per validation round


minnum = 0  # lowest validation loss so far (set at the first validation)
mome = 0.99  # SGD momentum
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=mome, weight_decay=1e-8)
# StepLR multiplies the learning rate by gamma=0.9 every 3 scheduler steps.
# NOTE(review): the call to scheduler.step() is commented out below, so the
# learning rate only changes if train_model steps the scheduler — confirm.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 3, gamma=0.9,
                                            last_epoch=-1)

train_loss1 = 0.0
lr = 1e-1  # NOTE(review): not read in the visible code; the optimizer uses lr=0.001
for epoch in range(1, EPOCH + 1):  # each epoch: one training pass, validation every few epochs
    if epoch == 180:
        # Lower the momentum for the final epochs. Rebinding `mome` alone does
        # not reach the optimizer; the value it uses lives in its param_groups.
        mome = 0.9
        for group in optimizer.param_groups:
            group['momentum'] = mome
    # train_model returns (mean loss, list of per-batch losses).
    train_loss, batch_losses = train_model(model, DEVICE, train_loader, optimizer, scheduler, epoch)
    train_loss1 = train_loss
    # Both values are recorded, so the spreadsheet layout stays as before.
    train_loss_list.append((train_loss, batch_losses))
    train_loss_pd = pd.DataFrame(train_loss_list)
    train_loss_pd.to_excel(zhibiao_path + "/第%d个epoch的训练损失.xlsx" % (epoch))

    torch.save(model, str(Path(model_path) / 'train_model.pth'))  # latest model
    torch.cuda.empty_cache()

    if epoch % valid_epoch_each == 0:  # e.g. validate every 5 epochs

        valid_loss, valid_zhibiao = test_model(model, DEVICE, valid_loader, epoch, test=False)
        dice1 = valid_zhibiao[2]  # DICE of this validation round
        valid_loss_list.append(valid_loss)
        valid_loss_pd = pd.DataFrame(valid_loss_list)
        valid_loss_pd.to_excel(zhibiao_path + "/第%d个epoch的验证损失.xlsx" % (epoch))

        if epoch == valid_epoch_each:  # first validation: it is the best by definition
            torch.save(model, str(Path(model_path) / 'best_model.pth'))
            minnum = valid_loss
            print("minnum", minnum)

        elif valid_loss < minnum:  # new best: overwrite best_model and record its metrics

            print("valid_loss < minnum", valid_loss, "<", minnum)
            minnum = valid_loss
            torch.save(model, str(Path(model_path) / 'best_model.pth'))
            zhibiao = valid_zhibiao
            zhibiao_pd = pd.DataFrame(zhibiao)
            zhibiao_pd.to_excel(zhibiao_path + "/目前为止最合适的model指标:第%d个epoch的验证指标[PA, IOU, DICE, P, R, F1].xlsx" % epoch)

    torch.cuda.empty_cache()
    # optimizer.step()
    # scheduler.step()
scheduler.step() 123 | end = time.perf_counter() # 记录训练结束时间 124 | train_time = end-start # 记录总耗时 125 | print('Running time: %s Seconds' % train_time) # 打印总耗时 126 | time_list = list([train_time]) # 总耗时转化为列表 127 | train_time_pd = pd.DataFrame(time_list) # 存成excel格式 128 | train_time_pd.to_excel(zhibiao_path + "/总epoch的训练时间(不包含测试).xlsx") 129 | 130 | 131 | # 训练和验证 结束,保存的最好模型在 model_path +fengefu +'best_model.pth',用它进行测试 132 | 133 | test_start = time.perf_counter() # 记录测试开始时间 134 | torch.cuda.empty_cache() # 清一下内存 135 | 136 | test_loss_list = [] # 准备放测试损失 137 | test_zhibiao_list = [] # 准备放测试指标 138 | 139 | 140 | model = torch.load(str(Path(model_path)/'best_model.pth')) # 载入最好模型 141 | model = model.to(DEVICE) # 部署到gpu或cpu上 142 | 143 | test_loss, test_zhibiao = test_model(model, DEVICE, test_loader,EPOCH,test=True) # 测试 144 | test_loss_list.append(test_loss) # 测试损失 145 | test_zhibiao_list.append(test_zhibiao) # 测试指标 146 | 147 | test_loss_pd = pd.DataFrame(test_loss_list) # 存成excel格式 148 | test_loss_pd.to_excel(zhibiao_path + "/测试损失.xlsx") 149 | test_zhibiao_pd = pd.DataFrame(test_zhibiao_list) # 存成excel格式 150 | test_zhibiao_pd.to_excel(zhibiao_path + "/测试验证指标[PA, IOU, DICE, P, R, F1].xlsx") 151 | 152 | 153 | test_end = time.perf_counter() # 记录测试结束时间 154 | test_time =test_end-test_start # 记录测试耗时 155 | print('Running time: %s Seconds' % test_time) # 打印总耗时 156 | test_time_list = list([test_time]) # 测试时间转化为列表 157 | test_time_pd = pd.DataFrame(test_time_list) # 存成excel格式 158 | test_time_pd.to_excel(zhibiao_path + "/测试时间.xlsx") 159 | -------------------------------------------------------------------------------- /train_def.py: -------------------------------------------------------------------------------- 1 | from global_annos import * 2 | from global_ import * 3 | import torch 4 | import torch.nn as nn 5 | import torch.utils.data 6 | import os 7 | from torch.utils.data import Dataset 8 | import numpy as np 9 | from sklearn.metrics import confusion_matrix 10 | from tqdm import 
tqdm 11 | import matplotlib.pyplot as plt 12 | 13 | 14 | class dice_loss(nn.Module): # dice损失,做反向传播 15 | def __init__(self,c_num=2): # 格式需要 16 | super(dice_loss, self).__init__() 17 | def forward(self,data,label): # 格式需要 18 | n = data.size(0) # data.size(0)指 batch_size 的值,也就是一个批次几个 19 | dice_list = [] # 用来放本批次中的每一个图的dice 20 | all_dice = 0. # 一会 算本批次的平均dice 用 21 | for i in range(n): # 本批次内,拿一个图出来 22 | 23 | my_label11 = label[i] # my_label11为取得的对应label,也可以说是前景为结节的分割图 24 | my_label1 = torch.abs(1 - my_label11) # my_label1为 前景为非结节的分割图 1-1=0,1-0=1,这样就互换了 25 | 26 | my_data1 = data[i][0] # my_data1为我的模型预测出的 前景为非结节的分割图 27 | my_data11 = data[i][1] # my_data11为我的模型预测出的 前景为结节的分割图 28 | 29 | m1 = my_data1.view(-1) # 把my_data1拉成一维 ps:前景为非结节的分割图 30 | m2 = my_label1.view(-1) # 把my_label1拉成一维 ps:前景为非结节的分割图 31 | 32 | m11 = my_data11.view(-1) # 把my_data1拉成一维 ps:前景为结节的分割图 33 | m22 = my_label11.view(-1) # 把my_label1拉成一维 ps:前景为结节的分割图 34 | 35 | dice = 0 # dice初始化为0 36 | dice += (1-(( 2. * (m1 * m2).sum() +1 ) / (m1.sum() + m2.sum() +1))) # dice loss = 1-DSC的公式,比较的是 前景为非结节的分割图 37 | dice += (1-(( 2. 
class dice_loss(nn.Module):
    """Two-channel soft Dice loss used for back-propagation.

    Channel 0 of the prediction is compared with the inverted label
    (background) and channel 1 with the label itself (nodule).
    """

    def __init__(self, c_num=2):
        super(dice_loss, self).__init__()
        self.c_num = c_num  # kept for interface compatibility; forward() does not read it

    def forward(self, data, label):
        """Return the batch-mean Dice loss.

        Args:
            data: Predictions; dimension 0 is the batch, dimension 1 holds the
                two channels (index 0 and 1 are read).
            label: Labels with one entry per batch element and as many values
                per element as one prediction channel.

        Returns:
            Scalar tensor: per sample, ``(1 - DSC_background) +
            (1 - DSC_nodule)``, summed over the batch and divided by
            ``2 * batch_size``. Each DSC is smoothed with +1 in numerator and
            denominator.
        """
        n = data.size(0)
        # reshape (not view) so non-contiguous inputs are accepted too.
        target_fg = label.reshape(n, -1)
        target_bg = torch.abs(1 - target_fg)  # swap 0 and 1
        pred_bg = data[:, 0].reshape(n, -1)
        pred_fg = data[:, 1].reshape(n, -1)

        def soft_dice(pred, target):
            overlap = (pred * target).sum(dim=1)
            return (2. * overlap + 1) / (pred.sum(dim=1) + target.sum(dim=1) + 1)

        per_sample = (1 - soft_dice(pred_bg, target_bg)) + (1 - soft_dice(pred_fg, target_fg))
        return per_sample.sum() / n / 2


Loss = dice_loss().to(DEVICE)  # shared loss instance on the training device


def train_model(model, device, train_loader, optimizer, scheduler, epoch):
    """Train ``model`` for one epoch.

    Args:
        model: Network to train.
        device: Device the batches are moved to.
        train_loader: Iterable of ``(data, target)`` batches.
        optimizer: Optimizer stepped once per batch.
        scheduler: Only inspected for logging; it is not stepped here.
        epoch: Epoch number, used in the progress text.

    Returns:
        ``(mean_loss, batch_losses)``: the mean of the per-batch losses and the
        list of those losses.
    """
    model.train()
    loss_need = []  # per-batch losses
    tqdr = tqdm(enumerate(train_loader), total=len(train_loader))  # total lets the bar show progress
    for batch_index, (data, target) in tqdr:
        data, target = data.to(device), target.to(device)
        output = model(data)
        loss = Loss(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss = loss.item()
        loss_need.append(train_loss)
        tqdr.set_description("Train Epoch : {} \t train Loss : {:.6f} ".format(epoch, train_loss))
    #scheduler.step()
    print(optimizer.param_groups[0]['lr'], scheduler.get_last_lr()[0])
    train_loss = np.mean(loss_need)
    print("train_loss", train_loss)
    return train_loss, loss_need


def test_model(model, device, test_loader, epoch, test):
    """Evaluate ``model`` on a validation or test loader.

    Args:
        model: Network to evaluate.
        device: Device the batches are moved to.
        test_loader: Iterable of ``(data, target)`` batches.
        epoch: Epoch number, used in the progress text.
        test: ``True`` for the test run (also saves label and prediction
            pictures through ``use_plot_2d``), ``False`` for validation.

    Returns:
        ``(mean_loss, [PA, IOU, DICE, P, R, F1])``, each averaged over batches.

    Raises:
        ValueError: If the loader yields no batches, since no average exists.
    """
    n_batches = len(test_loader)
    if n_batches == 0:
        raise ValueError("test_loader is empty; nothing to evaluate")

    model.eval()
    name = 'Test' if test else 'Valid'  # bound up front so the summary can always use it
    test_loss = 0.0
    PA = IOU = DICE = P = R = F1 = 0
    tqrr = tqdm(enumerate(test_loader), total=n_batches)
    with torch.no_grad():
        for batch_index, (data, target) in tqrr:
            if test:
                # Save the ground-truth pictures first.
                data_cpu = data.clone().cpu()
                my_label_cpu = target.clone().cpu()
                for i in range(len(data_cpu)):
                    true_img_tensor = data_cpu[i][0]  # [0]: the channel axis added by the dataset
                    true_label_tensor = my_label_cpu[i]
                    use_plot_2d(true_img_tensor, true_label_tensor, z=8, batch_index=batch_index, i=i, true_label=True)

            data, target = data.to(device), target.to(device)
            torch.cuda.empty_cache()
            output = model(data)
            loss = Loss(output, target)
            test_loss += loss.item()

            PA0, IOU0, DICE0, P0, R0, F10, tn, fp, fn, tp = zhibiao(output, target)
            PA += PA0
            IOU += IOU0
            DICE += DICE0
            P += P0
            R += R0
            F1 += F10
            tqrr.set_description("{} Epoch : {} \t {} Loss : {:.6f} \t tn, fp, fn, tp: {:.0f} {:.0f} {:.0f} {:.0f} ".format(name, epoch, name, loss.item(), tn, fp, fn, tp))
            if test:
                # Then the predicted pictures: a voxel is nodule where channel 1 beats channel 0.
                data_cpu = data.clone().cpu()
                my_output_cpu = output.clone().cpu()
                for i in range(len(data_cpu)):
                    img_tensor = data_cpu[i][0]
                    label_tensor = torch.gt(my_output_cpu[i][1], my_output_cpu[i][0])
                    use_plot_2d(img_tensor, label_tensor, z=8, batch_index=batch_index, i=i)

        test_loss /= n_batches
        PA /= n_batches
        IOU /= n_batches
        DICE /= n_batches
        P /= n_batches
        R /= n_batches
        F1 /= n_batches

        print(" Epoch : {} \t {} Loss : {:.6f} \t DICE :{:.6f} PA: {:.6f} ".format(epoch, name, test_loss, DICE, PA))

        return test_loss, [PA, IOU, DICE, P, R, F1]
取出改batch中的单张图 85 | true_img_tensor = data_cpu[i][0] # 取图得到张量tensor,注意这里的[0]是因为我们在dataset部分给图增加了一个维度 86 | true_label_tensor = my_label_cpu[i] # 取得预测的二值分割张量tensor 87 | use_plot_2d(true_img_tensor,true_label_tensor,z=8,batch_index=batch_index,i=i,true_label=True) # 存图,这里存标签图到pic 88 | 89 | data, target = data.to(device), target.to(device) 90 | torch.cuda.empty_cache() 91 | output = model(data) #(output.shape) torch.Size([4, 2, 96, 96, 96]) 92 | loss = Loss(output, target)#*nllLoss(out1, target) 93 | test_loss += loss.item() 94 | 95 | PA0, IOU0, DICE0, P0, R0, F10,tn, fp, fn, tp = zhibiao(output, target) 96 | PA += PA0 97 | IOU += IOU0 98 | DICE += DICE0 99 | P += P0 100 | R += R0 101 | F1 += F10 102 | if test: 103 | name = 'Test' 104 | else: 105 | name = 'Valid' 106 | tqrr.set_description("{} Epoch : {} \t {} Loss : {:.6f} \t tn, fp, fn, tp: {:.0f} {:.0f} {:.0f} {:.0f} ".format(name,epoch,name, loss.item(),tn, fp, fn, tp)) 107 | if test: 108 | data_cpu = data.clone().cpu() 109 | my_output_cpu = output.clone().cpu() 110 | for i in range(len(data_cpu)): 111 | img_tensor = data_cpu[i][0] # 96 * 96 * 96 112 | label_tensor = torch.gt(my_output_cpu[i][1], my_output_cpu[i][0]) # 96 * 96 * 96 113 | use_plot_2d(img_tensor,label_tensor,z=8,batch_index=batch_index,i=i) 114 | 115 | test_loss /= len(test_loader) 116 | PA /= len(test_loader) 117 | IOU /= len(test_loader) 118 | DICE /= len(test_loader) 119 | P /= len(test_loader) 120 | R /= len(test_loader) 121 | F1 /= len(test_loader) 122 | 123 | print(" Epoch : {} \t {} Loss : {:.6f} \t DICE :{:.6f} PA: {:.6f} ".format(epoch, name,test_loss,DICE,PA)) 124 | 125 | return test_loss, [PA, IOU, DICE, P, R, F1] 126 | 127 | 128 | class myDataset(Dataset): 129 | def __init__(self, data_path, label_path): ### transform 我没写 130 | self.annos_img, self.annos_label = self.get_img_label(data_path, label_path) 131 | 132 | def __getitem__(self, index): 133 | img_all = self.annos_img[index] 134 | label_all = self.annos_label[index] 135 | img = 
class myDataset(Dataset):
    """Dataset of pre-cut patches listed in the ``msg_path`` spreadsheet."""

    def __init__(self, data_path, label_path):  # no transform implemented
        self.annos_img, self.annos_label = self.get_img_label(data_path, label_path)

    def __getitem__(self, index):
        """Return ``(img, label)`` for one patch.

        ``img`` is a float32 tensor with a leading channel axis added;
        ``label`` is an int64 tensor with the stored shape (no channel axis).
        """
        img = np.load(self.annos_img[index])
        label = np.load(self.annos_label[index])

        img = torch.from_numpy(np.expand_dims(img, 0)).float()
        # Convert straight to int64 instead of going through float first.
        label = torch.from_numpy(label).long()
        # The per-item torch.cuda.empty_cache() call was dropped: loading a
        # sample allocates nothing on the GPU.
        return img, label

    def __len__(self):
        return len(self.annos_img)

    @staticmethod
    def get_img_label(data_path, label_path):
        """Return the image and label paths whose subset is in ``data_path``.

        Rows of the ``msg_path`` spreadsheet are kept when the parent folder
        name of ``img_path`` equals the last component of one of the
        ``data_path`` entries. ``label_path`` is accepted for symmetry but not
        read; label paths come from the same spreadsheet rows.
        """
        wanted = {Path(i).name for i in data_path}
        msgs = pd.read_excel(msg_path)
        img_paths = []
        lbl_paths = []
        for img_path, lbl_path in zip(msgs['img_path'], msgs['lbl_path']):
            if Path(img_path).parent.name in wanted:
                img_paths.append(img_path)
                lbl_paths.append(lbl_path)
        return img_paths, lbl_paths


class cutDataset(Dataset):
    """Cuts one patch per nodule out of the full volumes and saves it.

    Iterating this dataset is what writes the patch files; each item is the
    list ``[image patch path, label patch path, nodule centre as text]``.
    """

    # Patch size along (z, y, x).
    crop_size = (16, 96, 96)

    def __init__(self, data_path, label_path):  # no transform implemented
        self.data = self.get_img_label(data_path)    # volume paths
        self.label = self.get_img_label(label_path)  # mask paths

        # One [path, centre, radius, nodule index] entry per nodule.
        self.annos_img = self.get_annos_label(self.data)
        self.annos_label = self.get_annos_label(self.label)

    @staticmethod
    def _crop_window(center, dim, size):
        """Return ``(start, stop)`` of a ``size``-wide window around ``center``.

        The window is shifted to stay inside ``[0, dim]``. When the axis is
        shorter than ``size`` it starts at 0 and slicing simply yields the
        whole axis.
        """
        start = int(round(center - size / 2))
        start = max(0, min(start, dim - size))
        return start, start + size

    def __getitem__(self, index):
        """Cut, save and describe the patch of nodule ``index``.

        The same window, centred on the nodule, is cut from the volume and
        from its mask. The window has the class's ``crop_size`` on every axis
        that is at least that long, including at the volume borders. Patches
        are written under ``output_path`` in ``bbox_image_npy`` and
        ``bbox_mask_npy``, named ``<scan>_<nodule index>.npy``.
        """
        img_all = self.annos_img[index]
        label_all = self.annos_label[index]
        img = np.load(img_all[0])
        label = np.load(label_all[0])

        center_zyx = img_all[1][::-1]  # the centre is stored as (x, y, z)
        window = tuple(
            slice(*self._crop_window(c, dim, size))
            for c, dim, size in zip(center_zyx, img.shape, self.crop_size)
        )
        img = img[window]
        label = label[window]

        patch_name = Path(img_all[0]).stem + f'_{img_all[-1]}.npy'
        subset_name = Path(img_all[0]).parent.name
        one_path_img = str(Path(output_path) / "bbox_image_npy" / subset_name / patch_name)
        Path(one_path_img).parent.mkdir(exist_ok=True, parents=True)
        np.save(one_path_img, img)
        one_path_label = str(Path(output_path) / "bbox_mask_npy" / subset_name / patch_name)
        Path(one_path_label).parent.mkdir(exist_ok=True, parents=True)
        np.save(one_path_label, label)
        return [str(one_path_img), str(one_path_label), str(img_all[1])]

    def __len__(self):
        return len(self.annos_img)

    @staticmethod
    def get_img_label(data_path):
        """Return the sorted paths of every file inside the given folders."""
        img_path = []
        for t in data_path:  # one subset folder at a time
            img_path += [os.path.join(t, j) for j in os.listdir(t)]
        img_path.sort()
        return img_path

    @staticmethod
    def get_annos_label(img_path):
        """Expand volume paths into one entry per annotated nodule.

        Returns:
            A list of ``[path, [x, y, z], radius, nodule_index]``.
        """
        # NOTE(review): `annos_list` is not defined in the visible part of this
        # file (global_annos exposes the table as `annos`) — confirm where it
        # comes from.
        by_name = {}
        for one in annos_list:  # rows of [scan name, stringified nodule list]
            by_name.setdefault(one[0], []).append(one[1])

        annos_path = []
        for u in img_path:
            for raw in by_name.get(Path(u).stem, ()):
                # NOTE(review): eval() executes whatever text the spreadsheet
                # holds; only use it on spreadsheets produced by this project.
                one1_list = eval(raw)
                for l, nodule in enumerate(one1_list):
                    annos_path.append([u, [nodule[0], nodule[1], nodule[2]], nodule[3], l])
        return annos_path
def zhibiao(data, label):
    """Compute batch-averaged segmentation metrics from two-channel scores.

    Args:
        data: tensor of shape ``(n, 2, ...)``; a voxel is predicted as
            foreground when channel 1 is greater than channel 0.
        label: tensor of shape ``(n, ...)`` with 0 for background and 1 for
            foreground.  Voxels holding any other value are ignored, which is
            what the former ``confusion_matrix(..., labels=[0, 1])`` call did.

    Returns:
        ``(PA, IOU, DICE, P, R, F1, TN, FP, FN, TP)``.  PA, IOU, DICE, P and R
        are per-sample values averaged over the batch, F1 is derived from the
        averaged P and R, and TN/FP/FN/TP are the mean counts per sample.
        Every ratio with a zero denominator counts as 0 - the convention the
        original code already used for precision and recall - so an empty
        sample can no longer inject NaN into the averages.

    Raises:
        ValueError: if the batch is empty.
    """
    n = data.size(0)
    if n == 0:
        raise ValueError("zhibiao() needs at least one sample")

    def ratio(numerator, denominator):
        # Undefined ratios are reported as 0 instead of NaN / ZeroDivisionError.
        return numerator / denominator if denominator else 0.0

    PA = IOU = DICE = P = R = 0.0
    TN = FP = FN = TP = 0

    for i in range(n):
        # Counting with tensor operations replaces the per-voxel Python
        # flattening and the two sklearn confusion-matrix passes.
        predicted = torch.gt(data[i][1], data[i][0])
        target = label[i].to(predicted.device)
        foreground = target == 1
        background = target == 0

        tp = int((predicted & foreground).sum())
        fp = int((predicted & background).sum())
        fn = int((~predicted & foreground).sum())
        tn = int((~predicted & background).sum())

        PA += ratio(tn + tp, tn + fp + fn + tp)    # pixel accuracy
        IOU += ratio(tp, tp + fp + fn)             # intersection over union
        DICE += ratio(2 * tp, 2 * tp + fp + fn)
        P += ratio(tp, tp + fp)                    # precision
        R += ratio(tp, tp + fn)                    # recall

        TN += tn
        FP += fp
        FN += fn
        TP += tp

    PA, IOU, DICE, P, R = PA / n, IOU / n, DICE / n, P / n, R / n
    F1 = ratio(2 * P * R, P + R)
    return PA, IOU, DICE, P, R, F1, TN / n, FP / n, FN / n, TP / n
def numpy_list(numpy):
    """Return every scalar of *numpy* as one flat list in row-major order."""
    x = []
    numpy_to_list(x, numpy)
    return x


def numpy_to_list(x, numpy):
    """Append every scalar of *numpy* to the list *x* in row-major order.

    Args:
        x: list that is extended in place.
        numpy: an ndarray (any rank, including ndarray subclasses such as
            memmap) or a sequence whose items may themselves be ndarrays.
    """
    # Plain numeric arrays are flattened in one step.  np.asarray strips
    # ndarray subclasses, which the former ``type(...) is np.ndarray`` test
    # silently failed to recognise (their rows were appended unflattened).
    if isinstance(numpy, np.ndarray) and numpy.dtype != object:
        x.extend(np.asarray(numpy).ravel())
        return
    for item in numpy:
        if isinstance(item, np.ndarray):
            numpy_to_list(x, item)
        else:
            x.append(item)


def show_loss(loss_list, STR, path):
    """Plot one curve and save it as ``<path>/<STR>.jpg``.

    Args:
        loss_list: values to plot; the x axis is simply the list index.
        STR: legend label, also used as the file name stem.
        path: target directory, as ``str`` or any ``os.PathLike``.
    """
    plt.plot(range(len(loss_list)), loss_list, "-", label=STR)
    plt.legend()
    # os.fspath keeps the result identical for str input and adds Path support.
    plt.savefig(os.fspath(path) + '/%s.jpg' % STR)
    plt.close()
def use_plot_2d(image, output, z=132, batch_index=0, i=0, true_label=False):
    """Save one slice of a volume as an RGB picture with the mask highlighted.

    Args:
        image: tensor indexed as ``image[z, :, :]``.
        output: mask tensor indexed the same way (prediction or ground truth).
        z: index of the slice to draw.
        batch_index, i: only used to build the file name
            ``{batch_index}_{i}.jpg``.
        true_label: when true the picture goes to ``<zhibiao_path>/true_pic``,
            otherwise to ``<zhibiao_path>/pic``.
    """
    plt.figure()
    base = torch.unsqueeze(image[z, :, :] + 0.25, dim=2)              # brightened slice
    highlight = torch.unsqueeze((output[z, :, :] * 0.2).float(), dim=2)
    green = base + highlight
    green[green > 1] = 1
    # Red and blue carry the plain slice, green carries slice + mask, so the
    # masked region shows up with a green tint.
    cat_pic = torch.cat([base, green, base], dim=2)
    # Hand matplotlib a detached CPU array so GPU / autograd tensors also work.
    plt.imshow(cat_pic.detach().cpu().numpy())

    # The prediction branch used to read
    #     plt.savefig(str(path)/'pic'/f'...'))
    # which has unbalanced parentheses and divides a str by a str.  Both
    # branches now build the target the same way.
    out_dir = Path(zhibiao_path) / ('true_pic' if true_label else 'pic')
    out_dir.mkdir(parents=True, exist_ok=True)
    plt.savefig(str(out_dir / f'{batch_index}_{i}.jpg'))
    plt.close()
class VNet(nn.Module):
    """Encoder/decoder segmentation network assembled from ``vnet_def.res_block``.

    Four strided-convolution stages go down, four transposed-convolution
    stages come back up, and the output of every encoder level is
    concatenated (after 3-D dropout) with the upsampled decoder features.
    The result is a per-voxel softmax over ``num_classes`` channels.
    """

    # (attribute name, in channels, out channels, res_block mode, dropout),
    # listed in registration order.
    _STAGES = (
        ("layer0", 1, 16, "forward0", False),
        ("layer11", 16, 32, "deconv", False),
        ("layer2", 32, 32, "forward2", False),
        ("layer22", 32, 64, "deconv", False),
        ("layer3", 64, 64, "forward3", False),
        ("layer33", 64, 128, "deconv", True),
        ("layer4", 128, 128, "forward3", False),
        ("layer44", 128, 256, "deconv", True),
        ("layer5", 256, 256, "forward3", False),
        ("layer55", 256, 128, "upconv", False),
        ("layer6", 256, 256, "forward3", False),
        ("layer66", 256, 64, "upconv", False),
        ("layer7", 128, 128, "forward3", False),
        ("layer77", 128, 32, "upconv", False),
        ("layer8", 64, 64, "forward2", False),
        ("layer88", 64, 16, "upconv", False),
        ("layer9", 32, 32, "forward1", False),
    )

    def __init__(self, num_classes=2):
        super().__init__()
        for attr, c_in, c_out, mode, use_dropout in self._STAGES:
            self.add_module(attr, vnet_def.res_block(c_in, c_out, mode, dropout=use_dropout))

        self.layer10 = vnet_def.res_block(32, num_classes, "forward10")
        self.softmax = nn.Softmax(dim=1)
        self.dropv = nn.Dropout3d()

    def forward(self, x):
        """Return class probabilities with the same spatial size as ``x``."""
        out = self.layer0(x)
        skips = [out]                                   # 16 channels

        # Encoder: downsample, refine, remember the result for the decoder.
        encoder = (
            (self.layer11, self.layer2),                # 32 channels
            (self.layer22, self.layer3),                # 64 channels
            (self.layer33, self.layer4),                # 128 channels
        )
        for downsample, refine in encoder:
            out = refine(downsample(out))
            skips.append(out)

        # Bottleneck.
        out = self.layer5(self.layer44(out))

        # Decoder: upsample, prepend the matching (dropped-out) encoder
        # output on the channel axis, refine.
        decoder = (
            (self.layer55, self.layer6),
            (self.layer66, self.layer7),
            (self.layer77, self.layer8),
            (self.layer88, self.layer9),
        )
        for upsample, refine in decoder:
            upsampled = upsample(out)
            out = refine(torch.cat((self.dropv(skips.pop()), upsampled), 1))

        return self.softmax(self.layer10(out))
膨胀卷积 22 | # self.ave_pooling = nn.AdaptiveAvgPool3d(1) # .cuda() # 全局平均池化 23 | # d = max(out_channel // r, L) 24 | # self.fc1 = nn.Linear(out_channel, d) # .cuda() 25 | # self.fc2 = nn.Linear(d, out_channel) # .cuda() 26 | # self.softmax = f.softmax 27 | # self.prelu = nn.PReLU() 28 | # self.bn = nn.BatchNorm3d(out_channel) 29 | # self.bn1 = nn.BatchNorm1d(d) 30 | # self.point = nn.Conv3d(in_channels=in_channel, out_channels=out_channel, kernel_size=1, stride=1) 31 | # 32 | # def forward(self, x): 33 | # res = self.point(x) 34 | # out1 = self.k_3_conv(x) ## 这里 通道数 变了 (BS,C,SHAPE) 35 | # out1 = self.bn(out1) 36 | # out1 = self.prelu(out1) 37 | # out2 = self.dilated_conv(x) 38 | # out2 = self.bn(out2) 39 | # out2 = self.prelu(out2) 40 | # out = out1.add(out2) 41 | # out1d = self.ave_pooling(out) ## (BD,C,1*1*1) 42 | # out1d = torch.flatten(out1d, start_dim=1) 43 | # out = self.fc1(out1d) 44 | # # out = self.bn1(out) 45 | # out = self.prelu(out) 46 | # outfc1 = self.fc2(out) 47 | # # 48 | # outfc1 = self.prelu(outfc1) 49 | # outfc2 = self.fc2(out) 50 | # # 51 | # outfc2 = self.prelu(outfc2) 52 | # outfc = torch.cat((outfc1, outfc2), 0) 53 | # 54 | # out = self.softmax(outfc, 1) # 55 | # k_3_out = out[0, :].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(0) 56 | # dil_out = out[1, :].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(0) 57 | # se1 = torch.mul(k_3_out, out1) ### 这里两个不同大小的张量要相乘了 先把一个张量扩张一下 再点乘 58 | # se2 = torch.mul(dil_out, out2) 59 | # out = se1.add(se2) 60 | # out = res + out 61 | # return out # 有正有负,在0附近 62 | # # 由于数据量少,小网效果更好 63 | # 64 | # class mv2_block(nn.Module): 65 | # def __init__(self,in_channel,out_channel): 66 | # super(mv2_block,self).__init__() 67 | # self.in_channel = in_channel 68 | # self.out_channel = out_channel 69 | # 70 | # # self.d_conv = nn.Conv3d(in_channels=2 * in_channel,out_channels=2 * in_channel,kernel_size=5,stride=1,padding=2,groups=2*in_channel) # 深度卷积 71 | # self.d_conv = nn.Conv3d(in_channels=2 * 
in_channel,out_channels=2 * in_channel,kernel_size=3,stride=1,padding=1,groups=2*in_channel) # 深度卷积 72 | # # self.d_conv = se_block(2*in_channel, 2 * in_channel) # 深度卷积 73 | # 74 | # self.p_conv1 = nn.Conv3d(in_channels=in_channel, out_channels=2 * in_channel, kernel_size=1, stride=1,groups=1) # 点卷积1 75 | # # self.p_conv1 = se_block(in_channel, 3 * in_channel) # 点卷积1 76 | # 77 | # self.p_conv2 = nn.Conv3d(in_channels=2 * in_channel, out_channels=out_channel, kernel_size=1, stride=1,groups=1) # 点卷积2 78 | # # self.p_conv2 = se_block(3 * in_channel, out_channel) 79 | # 80 | # self.prelu = nn.PReLU() 81 | # 82 | # 83 | # def forward(self,x): 84 | # resres = x 85 | # mv2res = res_block(self.in_channel, self.out_channel, "pointconv") 86 | # resres = mv2res(resres) 87 | # 88 | # outupc = self.p_conv1(x) 89 | # outupc = self.prelu(outupc) 90 | # 91 | # out = self.d_conv(outupc) # 分了16组 每组 4,1,96,96,96 92 | # out = self.prelu(out) 93 | # out = self.p_conv2(out) 94 | # ### 线性激活函数我没找到 用 y=x代替咯 95 | # out = resres.add(out) 96 | # return out 97 | # 98 | # class se_block(nn.Module): 99 | # def __init__(self,in_channel,out_channel,r=16,L=4): 100 | # super(se_block, self).__init__() 101 | # self.in_channel = in_channel 102 | # self.out_channel = out_channel 103 | # self.point = nn.Conv3d(in_channels=in_channel, out_channels=out_channel, kernel_size=1, stride=1) 104 | # self.k_3_conv = nn.Conv3d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, stride=1,padding=1) 105 | # self.bn = nn.BatchNorm3d(out_channel) 106 | # self.prelu = nn.PReLU() 107 | # self.ave_pooling = nn.AdaptiveAvgPool3d(1) 108 | # d = max(out_channel // r,L) 109 | # self.fc1 = nn.Linear(out_channel, d) # .cuda() 110 | # self.fc2 = nn.Linear(d, out_channel) # .cuda() 111 | # 112 | # def forward(self,x): 113 | # res = self.point(x) 114 | # # res = self.bn(res) 115 | # res = self.prelu(res) 116 | # 117 | # out = self.k_3_conv(x) 118 | # out = self.bn(out) 119 | # out = self.prelu(out) 120 | # # 
print(out.shape) 121 | # out1d = self.ave_pooling(out) ## (BD,C,1*1*1) 122 | # out1d = torch.flatten(out1d, start_dim=1) 123 | # # print(out1d.shape) 124 | # out_mid = self.fc1(out1d) 125 | # out_mid = self.prelu(out_mid) 126 | # out_out = self.fc2(out_mid) 127 | # out_out = self.prelu(out_out) 128 | # out_out = out_out.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) 129 | # # print(out_out.shape) 130 | # out = torch.mul(out, out_out) 131 | # out = out + res 132 | # return out 133 | # 134 | # class res_block(nn.Module): ##nn.Module 135 | # def __init__(self, i_channel, o_channel,lei): 136 | # super(res_block, self).__init__() 137 | # self.in_c = i_channel 138 | # self.out_c = o_channel 139 | # 140 | # self.conv1 = nn.Conv3d(in_channels=self.in_c, out_channels=self.out_c, kernel_size=5, stride=1,padding=2).cuda()#.to(device) ### 从 输入channel 到 输出channel 141 | # self.conv2 = nn.Conv3d(in_channels=self.out_c, out_channels=self.out_c, kernel_size=5, stride=1,padding=2).cuda()#.to(device) ### 从 输出channel 到 输出channel (叠加层) 142 | # # self.conv1 = sk_block(in_channel=i_channel, out_channel=o_channel).cuda() 143 | # # self.conv2 = sk_block(in_channel=o_channel, out_channel=o_channel).cuda() 144 | # 145 | # self.conv3 = nn.Conv3d(in_channels=self.out_c, out_channels=self.out_c, kernel_size=2, stride=2).cuda()#.to(device) ### 卷积下采样 146 | # 147 | # self.conv4 = nn.ConvTranspose3d(in_channels=self.out_c, out_channels=self.out_c, kernel_size=2, stride=2).cuda()#.to(device) ### 反卷积上采样 148 | # 149 | # self.conv5 = nn.Conv3d(in_channels=self.in_c, out_channels=self.out_c, kernel_size=1, stride=1).cuda()#.to(device) ### 点卷积 150 | # 151 | # self.bn = nn.BatchNorm3d(o_channel).cuda()#.to(device) 152 | # self.prelu = nn.PReLU().cuda()#.to(device) 153 | # self.lei = lei 154 | # 155 | # 156 | # def forward(self,x): 157 | # if self.lei == "forward1": 158 | # # out = self.forward1(x) 159 | # x = x.to(torch.float32) 160 | # res = x ### 记录下输入时的 x 161 | # res1 = res_block(self.in_c, self.out_c, 
"pointconv") 162 | # res = res1(res) 163 | # out = self.conv1(x) 164 | # out = self.bn(out) 165 | # out = res.add(out) 166 | # out = self.prelu(out) 167 | # elif self.lei == "forward2": 168 | # # out = self.forward2(x) 169 | # res = x ### 记录下输入时的 x 170 | # res1 = res_block(self.in_c, self.out_c, "pointconv") 171 | # res = res1(res) 172 | # out = self.conv1(x) 173 | # out = self.bn(out) 174 | # out = self.prelu(out) 175 | # out = self.conv2(out) 176 | # out = self.bn(out) 177 | # 178 | # out = res.add(out) 179 | # out = self.prelu(out) 180 | # elif self.lei == "forward3": 181 | # # out = self.forward3(x) 182 | # res = x ### 记录下输入时的 x 183 | # res1 = res_block(self.in_c, self.out_c, "pointconv") 184 | # res = res1(res) 185 | # out = self.conv1(x) 186 | # out = self.bn(out) 187 | # out = self.prelu(out) 188 | # out = self.conv2(out) 189 | # out = self.bn(out) 190 | # out = self.prelu(out) 191 | # out = self.conv2(out) 192 | # out = self.bn(out) 193 | # out = res.add(out) 194 | # out = self.prelu(out) 195 | # elif self.lei == "deconv": 196 | # # out = self.deconv(x) 197 | # out = self.conv3(x) 198 | # out = self.bn(out) 199 | # out = self.prelu(out) 200 | # elif self.lei == "upconv": 201 | # # out = self.upconv(x) 202 | # out = self.conv4(x) 203 | # out = self.bn(out) 204 | # out = self.prelu(out) 205 | # elif self.lei == "pointconv": 206 | # # out = self.pointconv(x) 207 | # out = self.conv5(x) 208 | # out = self.bn(out) 209 | # out = self.prelu(out) 210 | # else: 211 | # print("有问题") 212 | # out = x 213 | # return out 214 | 215 | 216 | import torch.nn as nn 217 | import torch 218 | import torch.nn.functional as f 219 | 220 | 221 | class res_block(nn.Module): ##nn.Module 222 | def __init__(self, i_channel, o_channel, lei, dropout=False): 223 | super(res_block, self).__init__() 224 | self.in_c = i_channel 225 | self.out_c = o_channel 226 | 227 | self.conv1 = nn.Conv3d(in_channels=i_channel, out_channels=i_channel, kernel_size=5, stride=1, padding=2) 228 | self.conv2 = 
class res_block(nn.Module):  ##nn.Module
    """Building block of the V-Net; ``lei`` selects which computation runs.

    Modes (value of ``lei``):
        ``forward0``  input stage: conv2 -> bn1 -> ELU, added to the input
                      tiled along the channel axis up to ``o_channel``.
        ``forward1``  one 5x5x5 convolution plus a learned 1x1x1 shortcut.
        ``forward2``  two 5x5x5 convolutions plus the shortcut.
        ``forward3``  three 5x5x5 convolutions plus the shortcut.
        ``forward10`` output head: conv1 -> bn -> ELU -> 1x1x1 convolution.
        ``deconv``    stride-2 convolution (downsampling).
        ``upconv``    stride-2 transposed convolution (upsampling).
        anything else ELU -> 1x1x1 convolution -> bn1 (``pointconv``).

    ``forward1/2/3`` reuse ``conv1`` (``i_channel`` -> ``i_channel``) and
    normalise with ``bn1`` (``o_channel``), so they only work when
    ``i_channel == o_channel``.

    Args:
        i_channel: number of input channels.
        o_channel: number of output channels.
        lei: mode name, see above.
        dropout: apply ``Dropout3d`` in the ``deconv`` / ``upconv`` modes.

    Changes relative to the previous version:
    * The residual shortcut is now a registered submodule.  Previously every
      forward call built a new, randomly initialised
      ``res_block(..., "pointconv")``, so the projection was never trained,
      never stored in the state dict and differed between calls.  Checkpoints
      written before this change contain no ``shortcut.*`` entries and must
      be loaded with ``strict=False``.
    * No layer calls ``.cuda()`` any more; move the finished model with
      ``.to(device)``.  That was already required, because ``conv1`` and
      ``conv2`` were never moved by the constructor.
    """

    # Modes that add a learned projection of the block input.
    _RESIDUAL_MODES = ("forward1", "forward2", "forward3")

    def __init__(self, i_channel, o_channel, lei, dropout=False):
        super(res_block, self).__init__()
        self.in_c = i_channel
        self.out_c = o_channel

        self.conv1 = nn.Conv3d(in_channels=i_channel, out_channels=i_channel, kernel_size=5, stride=1, padding=2)
        self.conv2 = nn.Conv3d(in_channels=i_channel, out_channels=o_channel, kernel_size=5, stride=1, padding=2)
        self.conv3 = nn.Conv3d(in_channels=i_channel, out_channels=o_channel, kernel_size=2, stride=2)  # strided downsampling
        self.conv4 = nn.ConvTranspose3d(in_channels=i_channel, out_channels=o_channel, kernel_size=2, stride=2)  # transposed upsampling
        self.conv5 = nn.Conv3d(in_channels=i_channel, out_channels=o_channel, kernel_size=1, stride=1)  # pointwise

        self.bn = nn.BatchNorm3d(i_channel)
        self.bn1 = nn.BatchNorm3d(o_channel)
        self.prelu = nn.ELU()  # attribute name kept for compatibility; the activation is ELU
        self.lei = lei
        self.dropout = dropout
        self.drop = nn.Dropout3d()

        if lei in self._RESIDUAL_MODES:
            # Same operations as the "pointconv" mode (ELU -> 1x1x1 conv ->
            # batch norm), but with weights of its own that are trained and
            # saved together with the block.
            self.shortcut = nn.Sequential(
                nn.ELU(),
                nn.Conv3d(in_channels=i_channel, out_channels=o_channel, kernel_size=1, stride=1),
                nn.BatchNorm3d(o_channel),
            )
        else:
            self.shortcut = None

    def forward(self, x):
        """Dispatch to the computation selected by ``self.lei``."""
        handlers = {
            "forward0": self.forward0,
            "forward1": self.forward1,
            "forward2": self.forward2,
            "forward3": self.forward3,
            "forward10": self.forward10,
            "deconv": self.deconv,
            "upconv": self.upconv,
        }
        # Unknown mode names fall back to the pointwise convolution, as before.
        return handlers.get(self.lei, self.pointconv)(x)

    def forward0(self, x):
        x = x.to(torch.float32)
        # Tile the input along the channel axis so it can be added to the
        # convolution output (16 copies for the 1 -> 16 input stage).
        res = x.repeat(1, self.out_c // x.size(1), 1, 1, 1)

        out = self.conv2(x)
        out = self.bn1(out)
        out = self.prelu(out)
        out = res.add(out)
        out = self.prelu(out)
        return out

    def forward1(self, x):
        x = x.to(torch.float32)
        res = self.shortcut(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.prelu(out)

        out = res.add(out)
        out = self.prelu(out)
        return out

    def forward2(self, x):
        res = self.shortcut(x)

        out = self.conv1(x)
        out = self.bn(out)
        out = self.prelu(out)
        out = self.conv1(out)
        out = self.bn1(out)
        out = self.prelu(out)

        out = res.add(out)
        out = self.prelu(out)
        return out

    def forward3(self, x):
        res = self.shortcut(x)

        out = self.conv1(x)
        out = self.bn(out)
        out = self.prelu(out)
        out = self.conv1(out)
        out = self.bn1(out)
        out = self.prelu(out)
        out = self.conv1(out)
        out = self.bn1(out)
        out = self.prelu(out)

        out = res.add(out)
        out = self.prelu(out)
        return out

    def forward10(self, x):
        out = self.conv1(x)
        out = self.bn(out)
        out = self.prelu(out)
        out = self.conv5(out)
        return out

    def deconv(self, x):
        out = self.conv3(x)
        out = self.bn1(out)
        out = self.prelu(out)
        if self.dropout:
            out = self.drop(out)
        return out

    def upconv(self, out):
        if self.dropout:
            out = self.drop(out)

        out = self.conv4(out)
        out = self.bn1(out)
        out = self.prelu(out)
        return out

    def pointconv(self, x):
        out = self.prelu(x)
        out = self.conv5(out)
        out = self.bn1(out)
        return out