├── .idea ├── .gitignore ├── Unet.iml ├── inspectionProfiles │ ├── Project_Default.xml │ └── profiles_settings.xml ├── misc.xml ├── modules.xml └── vcs.xml ├── data ├── test.txt ├── test_images │ └── test.txt ├── test_mask │ └── test.txt ├── train_images │ └── test.txt └── train_mask │ └── test.txt ├── dataload.py ├── debug.py ├── declare ├── demo_contents.png ├── kaggle-carvana.png ├── lr_loss_curve.png └── unet.png ├── model.py ├── readme.md ├── saved_images ├── img_0.png ├── img_1.png ├── img_10.png ├── img_11.png ├── img_12.png ├── img_13.png ├── img_14.png ├── img_15.png ├── img_16.png ├── img_17.png ├── img_18.png ├── img_19.png ├── img_2.png ├── img_20.png ├── img_21.png ├── img_22.png ├── img_23.png ├── img_24.png ├── img_25.png ├── img_26.png ├── img_27.png ├── img_28.png ├── img_29.png ├── img_3.png ├── img_30.png ├── img_31.png ├── img_32.png ├── img_33.png ├── img_34.png ├── img_35.png ├── img_36.png ├── img_37.png ├── img_4.png ├── img_5.png ├── img_6.png ├── img_7.png ├── img_8.png ├── img_9.png ├── img_input_0_0.png ├── img_input_0_1.png ├── img_input_0_2.png └── img_input_0_3.png ├── test.png ├── train.py └── utils.py /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Editor-based HTTP Client requests 5 | /httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | /dataSources.local.xml 9 | /parameters/ 10 | *.tar 11 | ../parameters/ 12 | ../declare/ 13 | /declare/ 14 | ../saved_images/ 15 | /saved_images/ 16 | ../__pycache__/ 17 | /__pycache__/ 18 | -------------------------------------------------------------------------------- /.idea/Unet.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 20 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /data/test.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /data/test_images/test.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /data/test_mask/test.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /data/train_images/test.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /data/train_mask/test.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /dataload.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Author : Chaser 3 | @Time : 2023/3/3 14:50 4 | @File : dataload.py 5 | @software : PyCharm 6 | """ 7 | """ 8 | @update: 9 | 2023/3/5 10 | 2023/3/10 11 | """ 12 | 13 | import os 14 | import torch 15 | from PIL import Image 16 | from torch.utils.data import Dataset 17 | import numpy as np 18 | 19 | #车的数据集加载器 20 | class CarDataset(Dataset): 21 | # 传入图像地址文件夹名和标签文件夹名 22 | def __init__(self, image_dir, mask_dir, transform=None): 23 | self.image_dir = image_dir 24 | self.mask_dir = mask_dir 25 | self.transform = transform 26 | # 包含在image_dir文件夹下的所有文件 27 | self.images = os.listdir(image_dir) 28 | 29 | def __len__(self): 30 | return len(self.images) 31 | 32 | def __getitem__(self, index): 33 | # 遍历读取图片的相对位置 34 | image_path = os.path.join(self.image_dir, self.images[index]) 35 | mask_path = os.path.join(self.mask_dir, self.images[index].replace(".jpg", "_mask.gif")) 36 | 37 | # 将图片转换为RGB图,将标签图片转换为灰度图 38 | image = np.array(Image.open(image_path).convert("RGB"),dtype=float) 39 | mask = np.array(Image.open(mask_path).convert("L"),dtype=float) 40 | #令像素为255.0的为1,使所以像素为0或1 41 | mask[mask == 255.0 ] = 1.0 42 | 43 | # 对image,mask进行空间转换等操作 44 | if self.transform is not None: 45 | augumentations = self.transform(image=image, mask=mask) 46 | image = augumentations["image"] 47 | mask = augumentations["mask"] 48 | 49 | # print(image,mask) 50 | # print(mask.max()) 51 | return image, mask -------------------------------------------------------------------------------- /debug.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Author : Chaser 3 | @Time : 2023/3/3 14:59 4 | @File : debug.py 5 | @software : PyCharm 6 | """ 7 | 8 | """ 9 | @update: 10 | 2023/3/3 11 | 2023/3/5 12 | 2023/3/10 13 | """ 14 | 15 | # import numpy as np 16 | # import PIL.Image as Image 17 | # 18 | # mask_path = "./data/train_mask/0cdf5b5d0ce1_01_mask.gif" 19 | # 20 | # 21 | # mask = np.array(Image.open(mask_path).convert("L"),dtype=float) 22 | # 23 | # print(mask) 24 | # import os 25 | # folder = "./saved_images/" 26 | # idx = "0" 27 | # 28 | # os.remove("%spred_%s.png" % (folder, idx)) 29 | # import torchvision 30 | # import torch 31 | # y = torch.randn(1,1,224,224) 32 | # y = (y>0.5).float()*255 33 | # x = torch.ones(1,1,224,224) 34 | # z = torch.zeros(1,1,224,224) 35 | # print(y) 36 | # torchvision.utils.save_image(y, "test.png") 37 | 38 | 39 | # import torch 40 | # a = torch.ones(3,3,dtype=torch.int32) 41 | # b = torch.zeros(3,3,dtype=torch.int32) 42 | # a = a.int() 43 | # print(a|b) 44 | # print(a&b) 45 | 46 | 47 | # import numpy as np 48 | # import matplotlib.pyplot as plt 49 | # 50 | # a = np.random.normal(10,5,500) 51 | # b = np.random.normal(5,5,500) 52 | # cluster1 = np.array([[a,b,-1] for a,b in zip(a,b)]) 53 | # 54 | # a = np.random.normal(30,5,500) 55 | # b = np.random.normal(35,5,500) 56 | # cluster2 = np.array([[a,b,1] for a,b in zip(a,b)]) 57 | # 58 | # #axis 表示轴线,为1表示垂直,为0表示水平 59 | # dataset = np.append(cluster1,cluster2,axis=0) 60 | # 61 | # # print(a) 62 | # # print(b) 63 | # # print(dataset) 64 | # 65 | # 66 | # for i in dataset: 67 | # if i[2] == 1: 68 | # plt.scatter(i[0],i[1],c='r',s=8) 69 | # else: 70 | # plt.scatter(i[0],i[1],c='g',s=8) 71 | # plt.show() 72 | 73 | 74 | '''图像分割评价指标''' 75 | 76 | 77 | # print('a\nb') 78 | 79 | from train import evaluation_curve 80 | 81 | a = [95.651878,95.933640,95.941444] 82 | b = [90.1212,90.41232,90.141234] 83 | c = [83.12312,82.123124,84.21312] 84 | 85 | evaluation_curve(a,b,c) -------------------------------------------------------------------------------- /declare/demo_contents.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaser682/Unet-caravan-demo/2cb107e2771fb699239e81243b690291ff18880e/declare/demo_contents.png -------------------------------------------------------------------------------- /declare/kaggle-carvana.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaser682/Unet-caravan-demo/2cb107e2771fb699239e81243b690291ff18880e/declare/kaggle-carvana.png -------------------------------------------------------------------------------- /declare/lr_loss_curve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaser682/Unet-caravan-demo/2cb107e2771fb699239e81243b690291ff18880e/declare/lr_loss_curve.png -------------------------------------------------------------------------------- /declare/unet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaser682/Unet-caravan-demo/2cb107e2771fb699239e81243b690291ff18880e/declare/unet.png -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Author : Chaser 3 | @Time : 2023/3/3 14:50 4 | @File : dataload.py 5 | @software : PyCharm 6 | """ 7 | """ 8 | @update: 9 | 2023/3/5 10 | 2023/3/10 11 | 2023/5/10 12 | """ 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | import torch.nn.init as init 18 | 19 | # print(torch.__version__) 20 | 21 | 22 | # 连续两次卷积过程 23 | class DoubleConv(nn.Module): 24 | '''Convolution -> BN -> ReLU''' 25 | 26 | def __init__(self, in_channels, out_channels, mid_channels=None): 27 | super(DoubleConv, self).__init__() 28 | 29 | if not mid_channels: 30 | mid_channels = out_channels 31 | self.double_conv = nn.Sequential( 32 | # 经过卷积后,输出层h/w等于输入层 33 | nn.Conv2d(in_channels=in_channels, out_channels=mid_channels, kernel_size=3, stride=1, padding=1, 34 | bias=False), 35 | nn.BatchNorm2d(mid_channels), 36 | nn.ReLU(inplace=True), 37 | nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1, 38 | bias=False), 39 | nn.BatchNorm2d(out_channels), 40 | nn.ReLU(inplace=True) 41 | ) 42 | 43 | def forward(self, x): 44 | return self.double_conv(x) 45 | 46 | 47 | # 下采样过程 48 | class Down(nn.Module): 49 | """Maxpool -> DoubleConv""" 50 | 51 | def __init__(self, in_channels, out_channels): 52 | super(Down, self).__init__() 53 | 54 | self.maxpool_conv = nn.Sequential( 55 | nn.MaxPool2d(2, 2), 56 | DoubleConv(in_channels, out_channels) 57 | ) 58 | 59 | def forward(self, x): 60 | return self.maxpool_conv(x) 61 | 62 | 63 | # 上采样过程 64 | class Up(nn.Module): 65 | """ConvolutionTranspose -> DoubleConv""" 66 | 67 | def __init__(self, in_channels, out_channels, transpose=False): 68 | super(Up, self).__init__() 69 | 70 | '''conv: out_shape = (in_shape + 2*padding - kernel_size)/stride + 1 71 | conv_transpose: out_shape = (in_shape - 1)*strde + kernel_size - 2*padding 72 | ''' 73 | if transpose: 74 | self.up = nn.ConvTranspose2d(in_channels=in_channels, out_channels=in_channels // 2, kernel_size=2, 75 | stride=2) 76 | else: 77 | self.up = nn.Sequential( 78 | nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True), 79 | nn.Conv2d(in_channels, in_channels // 2, kernel_size=1, stride=2, padding=0), 80 | nn.ReLU(inplace=True) 81 | ) 82 | self.conv = DoubleConv(in_channels, out_channels) 83 | # 传递参数 84 | self.up.apply(self.init_weights) 85 | 86 | def forward(self, x1, x2): 87 | """ 88 | x1为上采样单元,x2为与x1同层的下采样单元 89 | """ 90 | 91 | x1 = self.up(x1) 92 | 93 | diffX = x2.size()[2] - x1.size()[2] 94 | diffY = x2.size()[3] - x1.size()[3] 95 | 96 | x1 = F.pad(x1, (diffX // 2, diffX - diffX // 2, 97 | diffY // 2, diffY - diffY // 2) # 分别表示上下左右 98 | ) 99 | 100 | x = torch.cat([x2, x1], dim=1) 101 | return self.conv(x) 102 | 103 | # 设置参数权重 104 | @staticmethod 105 | def init_weights(m): 106 | if type(m) == nn.Conv2d: 107 | init.xavier_normal_(m.weight) 108 | init.constant_(m.bias, 0) 109 | 110 | # 最后一次卷积得到输出 111 | class OutConv(nn.Module): 112 | def __init__(self, in_channels, out_channels): 113 | super(OutConv, self).__init__() 114 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) 115 | 116 | def forward(self, x): 117 | return self.conv(x) 118 | 119 | 120 | DEVICE = "cuda" if torch.cuda.is_available() else "cpu" 121 | 122 | 123 | # Unet网络结构 124 | class Unet(nn.Module): 125 | def __init__(self, in_channels, out_channels): 126 | super(Unet, self).__init__() 127 | # 输入卷积 128 | self.inc = DoubleConv(in_channels, 64) 129 | 130 | # 下采样 131 | self.down1 = Down(64, 128) 132 | self.down2 = Down(128, 256) 133 | 134 | # 最后两次下采样进行dropout操作,防止过拟合 135 | self.down3 = Down(256, 512) 136 | self.drop3 = nn.Dropout2d(0.3) 137 | self.down4 = Down(512, 1024) 138 | self.drop4 = nn.Dropout2d(0.4) 139 | 140 | # 上采样 141 | self.up1 = Up(1024, 512, False) 142 | self.up2 = Up(512, 256, False) 143 | self.up3 = Up(256, 128, False) 144 | self.up4 = Up(128, 64, False) 145 | 146 | # 得到卷积输出 147 | self.outc = OutConv(64, out_channels) 148 | 149 | def forward(self, x): 150 | x1 = self.inc(x) 151 | 152 | x2 = self.down1(x1) 153 | x3 = self.down2(x2) 154 | x4 = self.down3(x3) 155 | x4 = self.drop3(x4) 156 | x5 = self.down4(x4) 157 | x5 = self.drop4(x5) 158 | 159 | x = self.up1(x5, x4) 160 | x = self.up2(x, x3) 161 | x = self.up3(x, x2) 162 | x = self.up4(x, x1) 163 | 164 | return self.outc(x) 165 | 166 | if __name__ == "__main__": 167 | model = Unet(in_channels=3,out_channels=1) 168 | 169 | inp = torch.randn(1,3,224,224) 170 | outp = model(inp) 171 | print(outp) -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | ## 文件目录说明 2 | ![目录结构](/declare/demo_contents.png) 3 | 4 | 5 | * __data包__ 6 | test_images为测试集input 7 | test_mask为测试集target 8 | train_images为训练集input 9 | train_mask为训练集target 10 | 11 | 12 | * __parameters包__ 13 | 用于保存学习率为lr训练的模型参数 14 | 15 | 16 | * __saved_images包__ 17 | 用于保存模型训练完成后的predicted图像和target图像 18 | 19 | 20 | * __dataload.py文件__ 21 | 定义了CarDataset类,Car数据集加载器 22 | 23 | 24 | * __model.py文件__ 25 | 分别定义了DoubleConv类、Down类、Up类、OutConv类和Unet类 26 | **DoubleConv类**:进行连续的两次卷积,每次卷积的过程为Convolution -> BN -> ReLU 27 | **Down类**:下采样过程(downsampling),由最大池化(Maxpool)和连续两次卷积(DoubleConv)组成 28 | **Up类**:上采样过程(upsampling),由反卷积(转置卷积或双线性插值)和连续两次卷积(DoubleConv)组成 29 | **OutConv类**:最后一次卷积输出结果 30 | **Unet类**:Unet网络的主体结构,具体过程是输入卷积->四次下采样->四次上采样->得到卷积输出 31 | 32 | 33 | * __utils.py文件__ 34 | 一共定义了save_checkpoint、load_checkpoint、get_loaders、check_accuracy和save_predictions5个方法 35 | **save_checkpoint方法**: 用于保存模型的参数 36 | **load_checkpoint**:用于加载模型的参数 37 | **get_loaders**:用于加载数据集 38 | **check_accuracy**:用于检查训练模型的精度 39 | **save_predictions**:用于保存预测影像的结果 40 | 41 | 42 | * __train.py文件__ 43 | 一共定义了train_fn、find_lr、loss_lr_curve和main四个方法 44 | **train_fn方法**:用于模型的训练 45 | **find_lr方法**:找到最佳的学习率 46 | **loss_lr_curve方法**:用户绘制损失函数与学习率的关联曲线 47 | * **loss_curve方法**:绘制损失函数与数据训练批次的关系 48 | * **evaluation_curve方法**:绘制评估函数随着训练次数的变化 49 | **main方法**:主程序的入口,调用各个方法和api,完成模型的训练和预测以及评估 50 | 51 | 52 | ## 数据集准备 53 | ![数据集网站](/declare/kaggle-carvana.png) 54 | 网站链接:[Kaggle-carvana数据集下载网站](https://www.kaggle.com/competitions/carvana-image-masking-challenge/data) 55 | 56 | 57 | ## 数据处理 58 | **先对.jpg和.tif图片转换为RGB图片何灰度图片,再使用albumentations库对数据进行增强** 59 | ### 数据加载 60 | 1.遍历读取图片的相对位置 61 | 2.将训练数据转换为RGB图,测试数据转换为灰度图 62 | 3.对数据进行增强操作 63 | 64 | ### 数据增强 65 | 导入albumentations数据增强库,该库是负责处理图像的一个库,可用于所有数据类型,支持各类 66 | 图像处理的方法,图像翻转、裁剪、填充、组合序列化等等操作,与此同时,该库也是处理图像最快的库 67 | * **Compose方法:** 将所有图形转换操作组合起来 68 | * **Resize方法:** 重置图形的尺寸 69 | * **Rotate方法:** 对图形进行旋转 70 | * **HorizontalFlip方法:** 图形围绕Y轴水平翻转 71 | * **VerticalFlip方法:** 图形围绕X轴垂直翻转 72 | * **Normalize方法:** 图形进行归一化处理 73 | * **ToTensorV2方法:** 将图形数据转换为tensor数据 74 | 75 | 76 | ## 模型构建 77 | 该模型使用的是语义分割的经典网络Unet网络。UNet是一种用于图像分割的卷积神经网络,它结合了编码器和解码器, 78 | 可以有效地将输入图像分割成多个部分。UNet最初是用于生物医学图像分割,但现在已经广泛应用于其他领域, 79 | 如自然图像分割、语义分割等。 80 | UNet的编码器部分由卷积层和池化层组成,用于提取图像的特征。解码器部分由反卷积层和跳跃连接组成, 81 | 用于将编码器提取的特征映射还原为原始图像大小。跳跃连接是指将编码器的特征图与解码器的特征图连接起来, 82 | 以便解码器可以使用更多的上下文信息进行分割。 83 | 84 | **Unet网络** 85 | ![Unet网络结构](/declare/unet.png) 86 | 87 | 该网络结构分为Encoder和Decoder两部分: 88 | ### Encoder 89 | Encoder由卷积操作和下采样操作组成,每次卷积的卷积结构为 3x3 的卷积核,padding 为 0 ,striding 为 1; 90 | 下采样操作为最大池化(max pooling),stride为2,输出大小为1/2 *(H, W)。下采样操作进行四次,同时每层进行 91 | 两次卷积操作,将得到的特征图输入Decoder。 92 | 93 | ### Decoder 94 | Decoder由卷积操作和上采样操作组成,上采样的方式为双线性插值方式实现;上采样操作完成后,将得到的特征图与Encoder同层的 95 | 特征图进行跳跃连接,即进行拼接,再进行两次卷积操作。重复四次该过程,最后再进行一次卷积,得到最终的特征图。 96 | 97 | ## 模型训练 98 | 1.定义超参数(lr、device、batch_size、epochs等等) 99 | 2.实例化模型、实例化损失函数和优化器 100 | 3.训练集和数据集的加载 101 | 4.是否使用预训练参数 102 | 5.模型训练 103 | 6.模型评估 104 | 7.参数保存 105 | 8.可视化结果 106 | 107 | 108 | ## 模型评估 109 | * **像素准确率:** 在图像分割中,分类正确的像素数量占总像素数量的比例 110 | 计算公式:**像素准确率 = 分类正确的像素数 / 总像素数** 111 | * **Dice分数:** 图像分割评估指标,用于评估分割结果的准确性,它计算了预测分割结果和真实分割结果之间的相似度 112 | 计算公式:**DICE = \frac{2TP}{2TP+FP+FN}** 113 | * **IOU分数:** 一种用于评估图像分割结果的指标。它是通过计算分割结果与真实标注之间的交集与并集之比来衡量分割的准确性 114 | 计算公式:**IOU = (分割结果与真实标注的交集面积) / (分割结果与真实标注的并集面积)** 115 | 116 | 117 | ## 项目运行 118 | python train.py 119 | 120 | 121 | ## update 122 | 如果对该项目有疑问,可以创建issue 123 | 目前数据集还没上传 124 | 125 | 126 | -------------------------------------------------------------------------------- /saved_images/img_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaser682/Unet-caravan-demo/2cb107e2771fb699239e81243b690291ff18880e/saved_images/img_0.png -------------------------------------------------------------------------------- /saved_images/img_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaser682/Unet-caravan-demo/2cb107e2771fb699239e81243b690291ff18880e/saved_images/img_1.png -------------------------------------------------------------------------------- /saved_images/img_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaser682/Unet-caravan-demo/2cb107e2771fb699239e81243b690291ff18880e/saved_images/img_10.png -------------------------------------------------------------------------------- /saved_images/img_11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaser682/Unet-caravan-demo/2cb107e2771fb699239e81243b690291ff18880e/saved_images/img_11.png -------------------------------------------------------------------------------- /saved_images/img_12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaser682/Unet-caravan-demo/2cb107e2771fb699239e81243b690291ff18880e/saved_images/img_12.png -------------------------------------------------------------------------------- /saved_images/img_13.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaser682/Unet-caravan-demo/2cb107e2771fb699239e81243b690291ff18880e/saved_images/img_13.png -------------------------------------------------------------------------------- /saved_images/img_14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaser682/Unet-caravan-demo/2cb107e2771fb699239e81243b690291ff18880e/saved_images/img_14.png -------------------------------------------------------------------------------- /saved_images/img_15.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaser682/Unet-caravan-demo/2cb107e2771fb699239e81243b690291ff18880e/saved_images/img_15.png -------------------------------------------------------------------------------- /saved_images/img_16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaser682/Unet-caravan-demo/2cb107e2771fb699239e81243b690291ff18880e/saved_images/img_16.png -------------------------------------------------------------------------------- /saved_images/img_17.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaser682/Unet-caravan-demo/2cb107e2771fb699239e81243b690291ff18880e/saved_images/img_17.png -------------------------------------------------------------------------------- /saved_images/img_18.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaser682/Unet-caravan-demo/2cb107e2771fb699239e81243b690291ff18880e/saved_images/img_18.png -------------------------------------------------------------------------------- /saved_images/img_19.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaser682/Unet-caravan-demo/2cb107e2771fb699239e81243b690291ff18880e/saved_images/img_19.png -------------------------------------------------------------------------------- /saved_images/img_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaser682/Unet-caravan-demo/2cb107e2771fb699239e81243b690291ff18880e/saved_images/img_2.png -------------------------------------------------------------------------------- /saved_images/img_20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaser682/Unet-caravan-demo/2cb107e2771fb699239e81243b690291ff18880e/saved_images/img_20.png -------------------------------------------------------------------------------- /saved_images/img_21.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaser682/Unet-caravan-demo/2cb107e2771fb699239e81243b690291ff18880e/saved_images/img_21.png -------------------------------------------------------------------------------- /saved_images/img_22.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaser682/Unet-caravan-demo/2cb107e2771fb699239e81243b690291ff18880e/saved_images/img_22.png -------------------------------------------------------------------------------- /saved_images/img_23.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaser682/Unet-caravan-demo/2cb107e2771fb699239e81243b690291ff18880e/saved_images/img_23.png -------------------------------------------------------------------------------- /saved_images/img_24.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaser682/Unet-caravan-demo/2cb107e2771fb699239e81243b690291ff18880e/saved_images/img_24.png -------------------------------------------------------------------------------- /saved_images/img_25.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaser682/Unet-caravan-demo/2cb107e2771fb699239e81243b690291ff18880e/saved_images/img_25.png -------------------------------------------------------------------------------- /saved_images/img_26.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaser682/Unet-caravan-demo/2cb107e2771fb699239e81243b690291ff18880e/saved_images/img_26.png -------------------------------------------------------------------------------- /saved_images/img_27.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaser682/Unet-caravan-demo/2cb107e2771fb699239e81243b690291ff18880e/saved_images/img_27.png -------------------------------------------------------------------------------- /saved_images/img_28.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaser682/Unet-caravan-demo/2cb107e2771fb699239e81243b690291ff18880e/saved_images/img_28.png -------------------------------------------------------------------------------- /saved_images/img_29.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaser682/Unet-caravan-demo/2cb107e2771fb699239e81243b690291ff18880e/saved_images/img_29.png -------------------------------------------------------------------------------- /saved_images/img_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaser682/Unet-caravan-demo/2cb107e2771fb699239e81243b690291ff18880e/saved_images/img_3.png -------------------------------------------------------------------------------- /saved_images/img_30.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaser682/Unet-caravan-demo/2cb107e2771fb699239e81243b690291ff18880e/saved_images/img_30.png -------------------------------------------------------------------------------- /saved_images/img_31.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaser682/Unet-caravan-demo/2cb107e2771fb699239e81243b690291ff18880e/saved_images/img_31.png -------------------------------------------------------------------------------- /saved_images/img_32.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaser682/Unet-caravan-demo/2cb107e2771fb699239e81243b690291ff18880e/saved_images/img_32.png -------------------------------------------------------------------------------- /saved_images/img_33.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaser682/Unet-caravan-demo/2cb107e2771fb699239e81243b690291ff18880e/saved_images/img_33.png -------------------------------------------------------------------------------- /saved_images/img_34.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaser682/Unet-caravan-demo/2cb107e2771fb699239e81243b690291ff18880e/saved_images/img_34.png -------------------------------------------------------------------------------- /saved_images/img_35.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaser682/Unet-caravan-demo/2cb107e2771fb699239e81243b690291ff18880e/saved_images/img_35.png -------------------------------------------------------------------------------- /saved_images/img_36.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaser682/Unet-caravan-demo/2cb107e2771fb699239e81243b690291ff18880e/saved_images/img_36.png -------------------------------------------------------------------------------- /saved_images/img_37.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaser682/Unet-caravan-demo/2cb107e2771fb699239e81243b690291ff18880e/saved_images/img_37.png -------------------------------------------------------------------------------- /saved_images/img_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaser682/Unet-caravan-demo/2cb107e2771fb699239e81243b690291ff18880e/saved_images/img_4.png -------------------------------------------------------------------------------- /saved_images/img_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaser682/Unet-caravan-demo/2cb107e2771fb699239e81243b690291ff18880e/saved_images/img_5.png -------------------------------------------------------------------------------- /saved_images/img_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaser682/Unet-caravan-demo/2cb107e2771fb699239e81243b690291ff18880e/saved_images/img_6.png -------------------------------------------------------------------------------- /saved_images/img_7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaser682/Unet-caravan-demo/2cb107e2771fb699239e81243b690291ff18880e/saved_images/img_7.png -------------------------------------------------------------------------------- /saved_images/img_8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaser682/Unet-caravan-demo/2cb107e2771fb699239e81243b690291ff18880e/saved_images/img_8.png -------------------------------------------------------------------------------- /saved_images/img_9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaser682/Unet-caravan-demo/2cb107e2771fb699239e81243b690291ff18880e/saved_images/img_9.png -------------------------------------------------------------------------------- /saved_images/img_input_0_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaser682/Unet-caravan-demo/2cb107e2771fb699239e81243b690291ff18880e/saved_images/img_input_0_0.png -------------------------------------------------------------------------------- /saved_images/img_input_0_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaser682/Unet-caravan-demo/2cb107e2771fb699239e81243b690291ff18880e/saved_images/img_input_0_1.png -------------------------------------------------------------------------------- /saved_images/img_input_0_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaser682/Unet-caravan-demo/2cb107e2771fb699239e81243b690291ff18880e/saved_images/img_input_0_2.png -------------------------------------------------------------------------------- /saved_images/img_input_0_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaser682/Unet-caravan-demo/2cb107e2771fb699239e81243b690291ff18880e/saved_images/img_input_0_3.png -------------------------------------------------------------------------------- /test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaser682/Unet-caravan-demo/2cb107e2771fb699239e81243b690291ff18880e/test.png -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Author : Chaser 3 | @Time : 2023/3/3 14:51 4 | @File : train.py 5 | @software : PyCharm 6 | """ 7 | """ 8 | @update: 9 | 2023/3/5 10 | 2023/3/10 11 | 2023/4/21 12 | 2023/5/10 13 | 2023/5/20 14 | 2023/5/25 15 | """ 16 | import math 17 | 18 | import torch 19 | import albumentations as A # 数据增强库,对目标进行多种空间转换 20 | from albumentations.pytorch import ToTensorV2 # 只会[h, w, c] -> [c, h, w],不会将数据归一化到[0, 1] 21 | from tqdm import tqdm # python进度条库 22 | import torch.nn as nn 23 | import torch.optim as optim # 导入优化器 24 | import matplotlib.pyplot as plt 25 | import import_ipynb 26 | from model import Unet 27 | from utils import (load_checkpoint, 28 | save_checkpoint, 29 | get_loaders, 30 | check_accuracy, 31 | save_predictions) 32 | # 防止内核挂掉 33 | import os 34 | 35 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' 36 | 37 | # 训练的参数设置 38 | """ 39 | 根据学习率lr与损失值loss之间的关系曲线 40 | lr取值适合1e-3,1e-4,1e-5,1e-6 41 | """ 42 | lr = 0.000001 # 学习率 43 | device = "cuda" if torch.cuda.is_available() else "cpu" # 在gpu上训练 44 | batch_size = 4 # 批处理大小 45 | epochs = 3 # 训练的次数 46 | num_workers = 1 # 工作线程数 47 | # 图片尺寸 48 | image_height = 224 49 | image_width = 224 50 | pin_menory = True 51 | load_model = True 52 | train_img_dir = "./data/train_images/" 53 | train_mask_dir = "./data/train_mask/" 54 | test_img_dir = "./data/test_images/" 55 | test_mask_idr = "./data/test_mask/" 56 | 57 | 58 | # 训练函数 59 | def train_fn(loader, model, optimizer, loss_fn, scaler, losses): # 数据读取,网络模型,优化器,损失函数,缩放标量,损失值列表 60 | # 读取数据 61 | loop = tqdm(loader) 62 | 63 | for idx, (data, targets) in enumerate(loop): 64 | # 将数据展开到GPU中 65 | data = data.to(device) 66 | targets = targets.float().unsqueeze(1).to(device) 67 | 68 | # optimizer.zero_grad() # 梯度清零 69 | # # forward 70 | # with torch.cuda.amp.autocast(): # 混合精度加速训练 71 | # preds = model(data) 72 | # # print(data.size()) 73 | # # print(preds) 74 | # loss = loss_fn(preds, targets) # 计算损失 75 | # # print(loss) 76 | # # backward 77 | # scaler.scale(loss).backward() # 反向传播 78 | # scaler.step(optimizer) # 优化器参数更新 79 | # scaler.update() # 更新缩放标量以使其适应训练的梯度 80 | 81 | preds = model(data) 82 | loss = loss_fn(preds, targets) 83 | # print(preds) 84 | # print(targets) 85 | # print(targets.max()) 86 | # 梯度清零 87 | optimizer.zero_grad() 88 | # 反向传播,梯度更新 89 | loss.backward() 90 | optimizer.step() 91 | 92 | losses.append(loss.item()) 93 | loop.set_postfix(loss=loss.item()) # 设置进度条右边显示的信息 94 | # loss值为nan,则说明梯度爆炸或者学习率过高 95 | # print(data) 96 | # print(targets) 97 | 98 | # break#跑一批次数据,debug 99 | 100 | 101 | # 找到最佳的学习率 102 | def find_lr(loader, model, optimizer, loss_fn, device=device, init_value=1e-8, final_value=10., beta=0.98): 103 | num = len(loader) - 1 104 | mult = (final_value / init_value) ** (1 / num) 105 | lr = init_value 106 | # 动态调整学习率 107 | # 长度为6的字典,分别为['amsgrad', 'params', 'lr', 'betas', 'weight_decay', 'eps'] 108 | optimizer.param_groups[0]['lr'] = lr 109 | avg_loss = 0 110 | best_loss = 0 111 | batch_num = 0 112 | losses = [] 113 | log_lrs = [] 114 | for x, y in loader: 115 | batch_num += 1 116 | x = x.to(device) 117 | # 给y在dim=1加上一维 118 | y = y.to(device).unsqueeze(1) 119 | optimizer.zero_grad() 120 | preds = model(x) 121 | loss = loss_fn(preds, y) 122 | # 得到平滑的损失函数值 123 | avg_loss = beta * avg_loss + (1 - beta) * loss.item() 124 | smooth_loss = avg_loss / (1 - beta ** batch_num) 125 | # 如果损失值爆炸,则结束 126 | if batch_num > 1 and smooth_loss > 4 * best_loss: 127 | return log_lrs, losses 128 | if smooth_loss < best_loss or batch_num == 1: 129 | best_loss = smooth_loss 130 | # 将损失值和学习率指数保存 131 | losses.append(smooth_loss) 132 | log_lrs.append(math.log10(lr)) 133 | # 梯度下降 134 | loss.backward() 135 | optimizer.step() 136 | lr *= mult 137 | optimizer.param_groups[0]['lr'] = lr 138 | return log_lrs, losses 139 | 140 | 141 | # 损失函数与学习率的关联曲线 142 | def loss_lr_curve(loader, model, optimizer, loss_fn): 143 | logs, losses = find_lr(loader=loader, model=model, optimizer=optimizer, loss_fn=loss_fn) 144 | plt.xlabel("lr") 145 | plt.ylabel("loss") 146 | plt.plot(logs[10:], losses[10:]) 147 | plt.show() 148 | 149 | #损失函数随着训练批次的变化 150 | def loss_curve(losses): 151 | plt.xlabel("epochs") 152 | plt.ylabel("loss") 153 | plt.plot(losses) 154 | plt.show() 155 | 156 | #准确率、dice分数、iou分数随着训练次数的变化 157 | def evaluation_curve(accuracy,dice,iou): 158 | 159 | fig, ax = plt.subplots() # 创建图实例 160 | ax.plot(accuracy, label='accuracy') 161 | ax.plot(dice, label='dice') 162 | ax.plot(iou, label='iou') 163 | ax.set_xlabel('epochs') 164 | ax.set_ylabel('score') 165 | ax.set_title("evaluation curve") 166 | ax.legend() 167 | plt.show() 168 | 169 | 170 | # 主程序入口 171 | def main(): 172 | # 对训练的数据进行处理,数据增强处理 173 | train_transform = A.Compose( 174 | [ 175 | A.Resize(height=image_height, width=image_width), # 重置图片尺寸 176 | A.Rotate(limit=35, p=1.0), # 旋转,limit表示旋转范围,p表示概率 177 | A.HorizontalFlip(p=0.5), # 围绕Y轴水平翻转 178 | A.VerticalFlip(p=0.5), # 围绕X轴垂直翻转 179 | A.Normalize( 180 | mean=[0.0, 0.0, 0.0], # 归一化处理 181 | std=[1.0, 1.0, 1.0], 182 | max_pixel_value=255.0), 183 | ToTensorV2(), 184 | ] 185 | ) 186 | 187 | # 对测试的数据进行处理,数据增强处理 188 | test_transform = A.Compose( 189 | [ 190 | A.Resize(height=image_height, width=image_width), 191 | A.Rotate(limit=35, p=1.0), 192 | A.HorizontalFlip(p=0.5), 193 | A.VerticalFlip(p=0.5), 194 | A.Normalize( 195 | mean=[0.0, 0.0, 0.0], 196 | std=[1.0, 1.0, 1.0], 197 | max_pixel_value=255.0, 198 | ), 199 | ToTensorV2(), 200 | ] 201 | ) 202 | 203 | # 定义模型的一些参数 204 | model = Unet(in_channels=3, out_channels=1).to(device) # 实例化模型 205 | # 相比于BCE(),会进行sigmoid操作 206 | loss_fn = nn.BCEWithLogitsLoss() # 二分类交叉熵损失函数实例化 207 | optimizer = optim.Adam(model.parameters(), lr=lr) # 自适应动量算法优化器 208 | 209 | # 训练集和测试集的加载 210 | train_loader, test_loader = get_loaders( 211 | train_dir=train_img_dir, 212 | train_maskdir=train_mask_dir, 213 | test_dir=test_img_dir, 214 | test_maskdir=test_mask_idr, 215 | batch_size=batch_size, 216 | train_transform=train_transform, 217 | test_transform=test_transform, 218 | num_workers=num_workers, 219 | pin_memory=pin_menory 220 | ) 221 | 222 | '''由于开始没有训练好的参数pth文件,故需要注释掉''' 223 | # 使用之前训练的参数 224 | if load_model: 225 | load_checkpoint(torch.load("./parameters/%fmy_checkpoint.pth.tar" % (lr,)), model) 226 | 227 | scaler = torch.cuda.amp.GradScaler() # torch.cuda.amp提供了可以使用混合精度的方便方法,以加速训练 228 | 229 | #损失值列表 230 | losses = [] 231 | #准确率、dice分数、iou分数 232 | accuracy = [] 233 | dice = [] 234 | iou = [] 235 | 236 | # 训练模型,训练次数为epochs 237 | for epoch in range(epochs): 238 | train_fn(train_loader, model, optimizer, loss_fn, scaler, losses) 239 | 240 | # 保存模型、优化器参数 241 | checkpoint = { 242 | "state_dict": model.state_dict(), 243 | "optimizer": optimizer.state_dict(), 244 | } 245 | save_checkpoint(state=checkpoint, filename="./parameters/%fmy_checkpoint.pth.tar" % (lr,)) 246 | 247 | # 检验模型的精度 248 | x,y,z = check_accuracy(test_loader, model, device) 249 | 250 | accuracy.append(x.cpu()) 251 | dice.append(y.cpu()) 252 | iou.append(z.cpu()) 253 | 254 | # 可持久化存储,保存预测影像 255 | save_predictions(test_loader, model, folder="./saved_images/", device=device) 256 | 257 | # 画出学习率与损失函数之间的曲线,以便发现最佳学习率 258 | # loss_lr_curve(loader=train_loader,model=model,optimizer=optimizer,loss_fn=loss_fn) 259 | 260 | # 画出损失函数与训练次数之间的曲线 261 | loss_curve(losses) 262 | 263 | #绘制评估曲线 264 | evaluation_curve(accuracy,dice,iou) 265 | 266 | 267 | 268 | if __name__ == "__main__": 269 | main() 270 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Author : Chaser 3 | @Time : 2023/3/3 14:53 4 | @File : utils.py 5 | @software : PyCharm 6 | """ 7 | """ 8 | @update: 9 | 2023/3/5 10 | 2023/3/10 11 | 2023/5/20 12 | 2023/5/25 13 | """ 14 | 15 | import import_ipynb 16 | import torch 17 | import torchvision 18 | import torchvision.transforms as transforms 19 | from PIL import Image 20 | from dataload import CarDataset 21 | from torch.utils.data import DataLoader 22 | import os 23 | # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 24 | 25 | # print(device) 26 | 27 | # 保存训练模型参数 28 | def save_checkpoint(state, filename): 29 | print("-> Saving checkpoint") 30 | torch.save(state, filename) 31 | 32 | 33 | # 加载模型参数 34 | def load_checkpoint(checkpoint, model): 35 | print("-> Loading checkpoint") 36 | model.load_state_dict(checkpoint["state_dict"]) 37 | 38 | 39 | # 加载数据集 40 | def get_loaders( 41 | train_dir, 42 | train_maskdir, 43 | test_dir, 44 | test_maskdir, 45 | batch_size, # 每个批次的大小 46 | train_transform, 47 | test_transform, 48 | num_workers=4, # 线程数 49 | pin_memory=True # 拷贝数据到 CUDA Pinned Memory 50 | ): 51 | # 训练集 52 | train_dataset = CarDataset(image_dir=train_dir, 53 | mask_dir=train_maskdir, 54 | transform=train_transform) 55 | # 读取训练数据 56 | train_loader = DataLoader(train_dataset, 57 | batch_size=batch_size, 58 | num_workers=num_workers, 59 | pin_memory=pin_memory, 60 | shuffle=True) # 打乱数据集顺序 61 | 62 | # 测试集 63 | test_dataset = CarDataset(image_dir=test_dir, 64 | mask_dir=test_maskdir, 65 | transform=test_transform) 66 | # 读取测试数据 67 | test_loader = DataLoader(test_dataset, 68 | batch_size=batch_size, 69 | num_workers=num_workers, 70 | pin_memory=pin_memory, 71 | shuffle=True) 72 | 73 | # idata = iter(train_loader) 74 | # print(next(idata)) 75 | return train_loader, test_loader 76 | 77 | 78 | # 检测训练模型的精度 79 | def check_accuracy(loader, model, device="cuda"): 80 | # 初始化正确率 81 | num_correct = 0 82 | # 初始化总像素 83 | num_pixels = 0 84 | # 每次训练的得分 85 | dice_score = 0 86 | # 位置偏差得分 87 | iou_score = 0 88 | 89 | # 不启用 Batch Normalization 和 Dropout(预测之前一定要进行这一步) 90 | model.eval() 91 | 92 | # 评估模型时不需要记录梯度数据 93 | with torch.no_grad(): 94 | for x, y in loader: 95 | # 将tensor转换为cuda张量 96 | x = x.to(device) 97 | y = y.to(device).unsqueeze(1) # 在维度为1的位置插入一个维度 98 | preds = torch.sigmoid(model(x)) 99 | # print(preds.max()) 100 | preds = (preds > 0.5).float() # 将大于0.5的设置为1,否则为0 101 | 102 | num_correct += (preds == y).sum() # 统计相同结果的像素量 103 | num_pixels += torch.numel(preds) # 总像素 104 | 105 | # 防止0除 106 | smooth = 1e-8 107 | # Dice相似系数计算公式:dice = (2*tp)/(fp+2*tp+fn) 108 | dice_score += (2.0 * (preds * y).sum() + smooth) / ((preds + y).sum() + smooth) 109 | # print(dice_score) 110 | 111 | # iou计算公式:iou = tp/(tp+fp+fn) 112 | preds = preds.int() 113 | y = y.int() 114 | iou_score += (1.0 * (preds & y).sum() + smooth) / ((preds | y).sum() + smooth) 115 | 116 | print("The accuracy of the Unet is %.6f" % (num_correct / num_pixels * 100)) 117 | print("Dice score : %.6f" % (dice_score / len(loader)*100)) 118 | print("IOU score : %.6f" % (iou_score / len(loader)*100)) 119 | 120 | # 启用 Batch Normalization 和 Dropout 121 | model.train() 122 | 123 | return num_correct / num_pixels * 100,dice_score / len(loader)*100,iou_score / len(loader)*100 124 | 125 | 126 | # 保存预测影像结果 127 | def save_predictions(loader, model, folder="./saved_images/", device="cuda"): 128 | model.eval() 129 | 130 | for idx, (x, y) in enumerate(loader): 131 | x = x.to(device) 132 | y = y.unsqueeze(1).float().to(device) 133 | with torch.no_grad(): 134 | x = model(x) 135 | preds = torch.sigmoid(x) 136 | preds = (preds > 0.5).float() 137 | # # 保存结果影像 138 | # if os.path.exists("%spred_%s.png" % (folder, idx)): 139 | # os.remove("%spred_%s.png" % (folder, idx)) 140 | # torchvision.utils.save_image(preds, "%spred_%s.png" % (folder, idx)) 141 | # # 标签影像 142 | # if os.path.exists("%s%s.png" % (folder, idx)): 143 | # os.remove("%s%s.png" % (folder, idx)) 144 | # # print(x.size(),x.max()) 145 | # # print(x) 146 | # torchvision.utils.save_image(y, "%s%s.png" % (folder, idx)) 147 | 148 | # 保存结果影像和标签影像 149 | # print(preds.size(),y.size()) 150 | img = torch.cat([preds, y], 0) 151 | if os.path.exists(os.path.join(folder, f"img_{idx}.png")): 152 | os.remove(os.path.join(folder, f"img_{idx}.png")) 153 | torchvision.utils.save_image(img.cpu(), os.path.join(folder, f"img_{idx}.png")) 154 | 155 | # 查看标签图像 156 | # for i in range(x.size(0)): 157 | # transforms.ToPILImage()(x[i]).save(os.path.join(folder, f"img_input_{idx}_{i}.png")) 158 | # break 159 | --------------------------------------------------------------------------------