├── WSLDatasets ├── configs │ ├── class10_train5valid5 │ │ └── testing.txt │ ├── class1_train5valid5 │ │ ├── testing.txt │ │ ├── training.txt │ │ └── validation.txt │ ├── class2_train5valid5 │ │ ├── testing.txt │ │ ├── training.txt │ │ └── validation.txt │ ├── class3_train5valid5 │ │ ├── testing.txt │ │ ├── training.txt │ │ └── validation.txt │ ├── class4_train5valid5 │ │ ├── testing.txt │ │ ├── training.txt │ │ └── validation.txt │ ├── class5_train5valid5 │ │ ├── testing.txt │ │ ├── training.txt │ │ └── validation.txt │ ├── class6_train5valid5 │ │ ├── testing.txt │ │ ├── training.txt │ │ └── validation.txt │ ├── class7_train5valid5 │ │ └── testing.txt │ ├── class8_train5valid5 │ │ └── testing.txt │ ├── class9_train5valid5 │ │ └── testing.txt │ └── config200424_train5valid5 │ │ └── testing.txt ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── wsl_dataset.cpython-36.pyc │ └── wsl_dataset.cpython-37.pyc ├── __init__.py ├── gen_data_config.py └── wsl_dataset.py ├── utils ├── data │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── gen_defect.cpython-36.pyc │ │ ├── gen_defect.cpython-37.pyc │ │ ├── base_dataset.cpython-36.pyc │ │ ├── base_dataset.cpython-37.pyc │ │ ├── data_manager.cpython-36.pyc │ │ ├── data_manager.cpython-37.pyc │ │ ├── data_managers.cpython-36.pyc │ │ ├── data_managers.cpython-37.pyc │ │ ├── my_transforms.cpython-36.pyc │ │ └── my_transforms.cpython-37.pyc │ ├── gen_defect.py │ ├── data_managers.py │ ├── base_dataset.py │ ├── my_transforms.py │ └── data_manager.py ├── __init__.py ├── __pycache__ │ ├── crf.cpython-36.pyc │ ├── crf.cpython-37.pyc │ ├── param.cpython-36.pyc │ ├── param.cpython-37.pyc │ ├── timer.cpython-36.pyc │ ├── utils.cpython-36.pyc │ ├── utils.cpython-37.pyc │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── cv_utiles.cpython-36.pyc │ ├── cv_utiles.cpython-37.pyc │ ├── my_metrics.cpython-36.pyc │ ├── my_metrics.cpython-37.pyc │ ├── np_utils.cpython-36.pyc │ ├── np_utils.cpython-37.pyc │ ├── pil_func.cpython-36.pyc │ ├── pil_func.cpython-37.pyc │ ├── plt_utils.cpython-36.pyc │ ├── plt_utils.cpython-37.pyc │ ├── segment_metrics.cpython-36.pyc │ ├── segment_metrics.cpython-37.pyc │ └── scikit_image_tools.cpython-36.pyc ├── .idea │ ├── misc.xml │ ├── vcs.xml │ ├── inspectionProfiles │ │ └── profiles_settings.xml │ ├── modules.xml │ ├── utils.iml │ └── workspace.xml ├── np_utils.py ├── plt_utils.py ├── pil_func.py ├── segment_metrics.py ├── param.py └── cv_utiles.py ├── unets ├── __init__.py ├── __pycache__ │ ├── resunet.cpython-36.pyc │ ├── resunet.cpython-37.pyc │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── unet_blocks.cpython-36.pyc │ ├── unet_blocks.cpython-37.pyc │ ├── resnet_blocks.cpython-36.pyc │ └── resnet_blocks.cpython-37.pyc ├── resunet.py ├── unet_blocks.py └── resnet_blocks.py ├── photo ├── Train_0576.PNG ├── Train_0588.PNG ├── Train_0609.PNG ├── iou_train.png ├── iou_valid.png └── step_loss.png ├── README.md ├── params.py ├── train.py ├── funcs.py └── visual.py /WSLDatasets/configs/class10_train5valid5/testing.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /WSLDatasets/configs/class1_train5valid5/testing.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /WSLDatasets/configs/class2_train5valid5/testing.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /WSLDatasets/configs/class3_train5valid5/testing.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /WSLDatasets/configs/class4_train5valid5/testing.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /WSLDatasets/configs/class5_train5valid5/testing.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /WSLDatasets/configs/class6_train5valid5/testing.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /WSLDatasets/configs/class7_train5valid5/testing.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /WSLDatasets/configs/class8_train5valid5/testing.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /WSLDatasets/configs/class9_train5valid5/testing.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /WSLDatasets/configs/config200424_train5valid5/testing.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_dataset import * 2 | 3 | from . import data_managers -------------------------------------------------------------------------------- /unets/__init__.py: -------------------------------------------------------------------------------- 1 | from .resunet import UNet,Res18_UNet,Res50_UNet 2 | from . import unet_blocks -------------------------------------------------------------------------------- /photo/Train_0576.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuaiLYU/baseline_segment_Resunet_pytorch/HEAD/photo/Train_0576.PNG -------------------------------------------------------------------------------- /photo/Train_0588.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuaiLYU/baseline_segment_Resunet_pytorch/HEAD/photo/Train_0588.PNG -------------------------------------------------------------------------------- /photo/Train_0609.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuaiLYU/baseline_segment_Resunet_pytorch/HEAD/photo/Train_0609.PNG -------------------------------------------------------------------------------- /photo/iou_train.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuaiLYU/baseline_segment_Resunet_pytorch/HEAD/photo/iou_train.png -------------------------------------------------------------------------------- /photo/iou_valid.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuaiLYU/baseline_segment_Resunet_pytorch/HEAD/photo/iou_valid.png -------------------------------------------------------------------------------- /photo/step_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuaiLYU/baseline_segment_Resunet_pytorch/HEAD/photo/step_loss.png -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .pil_func import * 3 | from .param import Param 4 | from . import np_utils,cv_utiles,plt_utils,segment_metrics -------------------------------------------------------------------------------- /utils/__pycache__/crf.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuaiLYU/baseline_segment_Resunet_pytorch/HEAD/utils/__pycache__/crf.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/crf.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuaiLYU/baseline_segment_Resunet_pytorch/HEAD/utils/__pycache__/crf.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/param.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuaiLYU/baseline_segment_Resunet_pytorch/HEAD/utils/__pycache__/param.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/param.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuaiLYU/baseline_segment_Resunet_pytorch/HEAD/utils/__pycache__/param.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/timer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuaiLYU/baseline_segment_Resunet_pytorch/HEAD/utils/__pycache__/timer.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuaiLYU/baseline_segment_Resunet_pytorch/HEAD/utils/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuaiLYU/baseline_segment_Resunet_pytorch/HEAD/utils/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /unets/__pycache__/resunet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuaiLYU/baseline_segment_Resunet_pytorch/HEAD/unets/__pycache__/resunet.cpython-36.pyc -------------------------------------------------------------------------------- /unets/__pycache__/resunet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuaiLYU/baseline_segment_Resunet_pytorch/HEAD/unets/__pycache__/resunet.cpython-37.pyc -------------------------------------------------------------------------------- /unets/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuaiLYU/baseline_segment_Resunet_pytorch/HEAD/unets/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /unets/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuaiLYU/baseline_segment_Resunet_pytorch/HEAD/unets/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuaiLYU/baseline_segment_Resunet_pytorch/HEAD/utils/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuaiLYU/baseline_segment_Resunet_pytorch/HEAD/utils/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/cv_utiles.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuaiLYU/baseline_segment_Resunet_pytorch/HEAD/utils/__pycache__/cv_utiles.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/cv_utiles.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuaiLYU/baseline_segment_Resunet_pytorch/HEAD/utils/__pycache__/cv_utiles.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/my_metrics.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuaiLYU/baseline_segment_Resunet_pytorch/HEAD/utils/__pycache__/my_metrics.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/my_metrics.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuaiLYU/baseline_segment_Resunet_pytorch/HEAD/utils/__pycache__/my_metrics.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/np_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuaiLYU/baseline_segment_Resunet_pytorch/HEAD/utils/__pycache__/np_utils.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/np_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuaiLYU/baseline_segment_Resunet_pytorch/HEAD/utils/__pycache__/np_utils.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/pil_func.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuaiLYU/baseline_segment_Resunet_pytorch/HEAD/utils/__pycache__/pil_func.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/pil_func.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuaiLYU/baseline_segment_Resunet_pytorch/HEAD/utils/__pycache__/pil_func.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/plt_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuaiLYU/baseline_segment_Resunet_pytorch/HEAD/utils/__pycache__/plt_utils.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/plt_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuaiLYU/baseline_segment_Resunet_pytorch/HEAD/utils/__pycache__/plt_utils.cpython-37.pyc -------------------------------------------------------------------------------- /unets/__pycache__/unet_blocks.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuaiLYU/baseline_segment_Resunet_pytorch/HEAD/unets/__pycache__/unet_blocks.cpython-36.pyc -------------------------------------------------------------------------------- /unets/__pycache__/unet_blocks.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuaiLYU/baseline_segment_Resunet_pytorch/HEAD/unets/__pycache__/unet_blocks.cpython-37.pyc -------------------------------------------------------------------------------- /WSLDatasets/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuaiLYU/baseline_segment_Resunet_pytorch/HEAD/WSLDatasets/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /WSLDatasets/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuaiLYU/baseline_segment_Resunet_pytorch/HEAD/WSLDatasets/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /unets/__pycache__/resnet_blocks.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuaiLYU/baseline_segment_Resunet_pytorch/HEAD/unets/__pycache__/resnet_blocks.cpython-36.pyc -------------------------------------------------------------------------------- /unets/__pycache__/resnet_blocks.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuaiLYU/baseline_segment_Resunet_pytorch/HEAD/unets/__pycache__/resnet_blocks.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/segment_metrics.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuaiLYU/baseline_segment_Resunet_pytorch/HEAD/utils/__pycache__/segment_metrics.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/segment_metrics.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuaiLYU/baseline_segment_Resunet_pytorch/HEAD/utils/__pycache__/segment_metrics.cpython-37.pyc -------------------------------------------------------------------------------- /utils/data/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuaiLYU/baseline_segment_Resunet_pytorch/HEAD/utils/data/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /utils/data/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuaiLYU/baseline_segment_Resunet_pytorch/HEAD/utils/data/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /utils/data/__pycache__/gen_defect.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuaiLYU/baseline_segment_Resunet_pytorch/HEAD/utils/data/__pycache__/gen_defect.cpython-36.pyc -------------------------------------------------------------------------------- /utils/data/__pycache__/gen_defect.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuaiLYU/baseline_segment_Resunet_pytorch/HEAD/utils/data/__pycache__/gen_defect.cpython-37.pyc -------------------------------------------------------------------------------- /WSLDatasets/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | def get_cur_path(): 3 | return os.path.abspath(os.path.dirname(__file__)) 4 | from .wsl_dataset import WSLDataset_train,WSLDataset_split -------------------------------------------------------------------------------- /WSLDatasets/__pycache__/wsl_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuaiLYU/baseline_segment_Resunet_pytorch/HEAD/WSLDatasets/__pycache__/wsl_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /WSLDatasets/__pycache__/wsl_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuaiLYU/baseline_segment_Resunet_pytorch/HEAD/WSLDatasets/__pycache__/wsl_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /utils/data/__pycache__/base_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuaiLYU/baseline_segment_Resunet_pytorch/HEAD/utils/data/__pycache__/base_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /utils/data/__pycache__/base_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuaiLYU/baseline_segment_Resunet_pytorch/HEAD/utils/data/__pycache__/base_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /utils/data/__pycache__/data_manager.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuaiLYU/baseline_segment_Resunet_pytorch/HEAD/utils/data/__pycache__/data_manager.cpython-36.pyc -------------------------------------------------------------------------------- /utils/data/__pycache__/data_manager.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuaiLYU/baseline_segment_Resunet_pytorch/HEAD/utils/data/__pycache__/data_manager.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/scikit_image_tools.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuaiLYU/baseline_segment_Resunet_pytorch/HEAD/utils/__pycache__/scikit_image_tools.cpython-36.pyc -------------------------------------------------------------------------------- /utils/data/__pycache__/data_managers.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuaiLYU/baseline_segment_Resunet_pytorch/HEAD/utils/data/__pycache__/data_managers.cpython-36.pyc -------------------------------------------------------------------------------- /utils/data/__pycache__/data_managers.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuaiLYU/baseline_segment_Resunet_pytorch/HEAD/utils/data/__pycache__/data_managers.cpython-37.pyc -------------------------------------------------------------------------------- /utils/data/__pycache__/my_transforms.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuaiLYU/baseline_segment_Resunet_pytorch/HEAD/utils/data/__pycache__/my_transforms.cpython-36.pyc -------------------------------------------------------------------------------- /utils/data/__pycache__/my_transforms.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuaiLYU/baseline_segment_Resunet_pytorch/HEAD/utils/data/__pycache__/my_transforms.cpython-37.pyc -------------------------------------------------------------------------------- /utils/.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | -------------------------------------------------------------------------------- /utils/.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /utils/.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /utils/.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /utils/.idea/utils.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /WSLDatasets/gen_data_config.py: -------------------------------------------------------------------------------- 1 | from datasets.WSLDatasets import * 2 | from utils.data.base_dataset import * 3 | if __name__=="__main__": 4 | cls=9 5 | # root = "/home/gdut/disk/datasets/wsl_datasets/configs/Class{}".format(cls) 6 | root=r"G:\数据集\Weakly Supervised Learning for Industrial Optical Inspection\Class{}".format(cls) 7 | data=WSLDataset_train(root) 8 | 9 | data_split=divide_dataset(data.cls_dict,0.5,0.5) 10 | 11 | config_name="class{}_train5valid5".format(cls) 12 | config_path=os.path.join(get_cur_path(),config_name) 13 | write_txt(config_path,data_split) 14 | -------------------------------------------------------------------------------- /utils/np_utils.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | import logging 4 | import os 5 | import time 6 | 7 | 8 | def gray2RGB(img): 9 | img = np.array(img).squeeze() 10 | assert img.ndim==2 11 | return np.tile(img[:, :, np.newaxis], (1, 1, 3)).astype(np.uint8) 12 | 13 | def one_hot(input,len): 14 | one_hot = np.eye(len)[input] 15 | 16 | def normalize(x): 17 | axis = (x.ndim - 2, x.ndim - 1) 18 | min_value = np.min(x, axis=axis, keepdims=True) 19 | x = x - min_value 20 | max_value = np.max(x, axis=axis, keepdims=True) 21 | x = x / max_value 22 | return x 23 | 24 | def sigmoid(x): 25 | s = 1 / (1 + np.exp(-x)) 26 | return s 27 | 28 | def relu(x): 29 | # relu函数 30 | return np.maximum(0, x) 31 | 32 | 33 | def tanh(x): 34 | s1 = np.exp(x) - np.exp(-x) 35 | s2 = np.exp(x) + np.exp(-x) 36 | s = s1 / s2 37 | return s 38 | 39 | def transform(image, reverse=False): 40 | if not reverse: 41 | image=np.array(image).astype(np.float) 42 | image=image / 255.0 43 | return image 44 | else: 45 | image = np.array(image) 46 | image=(image)*255 47 | image=image.astype("uint8") 48 | return image 49 | 50 | 51 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 一个缺陷分割的baseline(缺陷检测, 语义分割) 2 | 3 | ## 数据集 4 | 1. Weakly Supervised Learning for Industrial Optical Inspection(https://hci.iwr.uni-heidelberg.de/node/3616) 5 | 这个链接中有10个不同的数据集,我们对2,4,8,10进行了像素级别标注,数据集和像素标注放在下面的百度网盘链接中: 6 | 链接:https://pan.baidu.com/s/1fkmUTPH0Di8p2C7A8P0fbg 提取码:tpnb (使用请注明数据来源:本仓库地址) 7 | 8 | 2. 参考本仓库配置自己的数据集。 9 | 10 | 11 | 12 | ## 实现功能 13 | 1. 数据读取 14 | 2. loss 可视化(tensorboard) 15 | 3. metrics(缺陷的iou 和 pa) 16 | 4. 训练过程中保存预测结果 17 | 5. 参数结构化管理 18 | 6. 通过将残差网络作为编码器,改进UNet (improving the unet by using the resnet as the encoder) 19 | 20 | 21 | ## visualization: 22 | ### 分割结果: 23 | 1.输入图像 2.像素标注 3.分割结果 24 | ![1](https://github.com/Wslsdx/baseline_segment_Resunet_pytorch/blob/master/photo/Train_0576.PNG) 25 | 26 | ![2](https://github.com/Wslsdx/baseline_segment_Resunet_pytorch/blob/master/photo/Train_0588.PNG) 27 | 28 | ![3](https://github.com/Wslsdx/baseline_segment_Resunet_pytorch/blob/master/photo/Train_0609.PNG) 29 | 30 | ### 损失曲线: 31 | 32 | ![step_loss](https://github.com/Wslsdx/baseline_segment_Resunet_pytorch/blob/master/photo/step_loss.png) 33 | 34 | ### 训练集IOU: 35 | 36 | ![iou_train](https://github.com/Wslsdx/baseline_segment_Resunet_pytorch/blob/master/photo/iou_train.png) 37 | 38 | ### 验证集IOU: 39 | 40 | ![iou_valid](https://github.com/Wslsdx/baseline_segment_Resunet_pytorch/blob/master/photo/iou_valid.png) 41 | 42 | 43 | 44 | 45 | -------------------------------------------------------------------------------- /params.py: -------------------------------------------------------------------------------- 1 | from utils.param import Param 2 | 3 | 4 | PARAM=Param() 5 | 6 | PARAM.dataset_train = Param( 7 | root="./dataset/Class2", 8 | data_config="class2_train5valid5", 9 | phase="training", 10 | return_numpy=True) 11 | PARAM.dataloader_train = Param(batch_size=32, 12 | shuffle=True, 13 | num_workers=8, 14 | drop_last=False) 15 | 16 | PARAM.dataset_valid = Param( 17 | root="./dataset/Class2", 18 | data_config="class2_train5valid5", 19 | phase="validation", 20 | return_numpy=True) 21 | 22 | PARAM.dataloader_valid = Param(batch_size=32, 23 | shuffle=False, 24 | num_workers=8, 25 | drop_last=False) 26 | 27 | PARAM.model = Param(n_classes=1, #类别数,二分类所以为1 28 | level=4, 29 | b_RGB=False, 30 | base_channels=32 31 | ) # 0-1值 32 | 33 | PARAM.Adam = Param( 34 | lr=0.001, 35 | weight_decay=0.001, 36 | betas=(0.9, 0.999)) 37 | 38 | PARAM.train=Param( 39 | epoch=100, 40 | valid_frequency=1, #几个epoch 验证一次 41 | save_frequency=10, #几个epoch 保存一次模型 42 | model_dir="./save/pth/", 43 | log_dir="./save/log/", 44 | ) 45 | PARAM.visualer=Param( 46 | save_dir="./save/visualizaiton/", #保存路径 47 | visual_frequency=3, #几个epoch保存一次 48 | visual_batchs = 5, #保存的记录 49 | ) 50 | -------------------------------------------------------------------------------- /unets/resunet.py: -------------------------------------------------------------------------------- 1 | from unets.unet_blocks import * 2 | from unets.resnet_blocks import _resnet,BasicBlock,Bottleneck 3 | """ 4 | 1. resnet_net 采用了5个不同尺度的特征图图 level:5 5 | 2. 用三个3*3卷积代替 7*7卷积,并且步长全部为1,得到与原始图片尺寸相同的特征 6 | 3. base_channels控制着网络的宽度 7 | 4. stride:1 网络输出与输入尺寸相同 8 | """ 9 | class Res18_UNet(UNet): 10 | def __init__(self,n_classes,norm_layer=None,bilinear=True,**kwargs): 11 | self.base_channels = kwargs.get("base_channels",32) # resnet18 和resnet34 这里为 32 , 64 12 | level=kwargs.get("level",5) 13 | self.b_RGB = kwargs.get("level", True) 14 | 15 | padding = 1 16 | super(Res18_UNet,self).__init__(n_classes, self.base_channels,level,padding,norm_layer,bilinear) 17 | 18 | def build_encoder(self): 19 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2],base_planes= self.base_channels,b_RGB=self.b_RGB ) 20 | 21 | 22 | 23 | 24 | class Res50_UNet(UNet): 25 | def __init__(self,n_classes,norm_layer=None,bilinear=True): 26 | self.base_channels = 64 # resnet50 ,resnet101和resnet152 这里为 64, 128,256 27 | level = 5 28 | padding = 1 29 | super(Res50_UNet,self).__init__(n_classes, self.base_channels,level,padding,norm_layer,bilinear) 30 | def build_encoder(self): 31 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3],base_planes=self.base_channels,) 32 | 33 | 34 | if __name__=="__main__": 35 | 36 | ipt=torch.rand(1,3,512,512) 37 | res18net=Res18_UNet(n_classes=10,level=4) 38 | opt=res18net(ipt) 39 | print(opt.shape) 40 | 41 | # res50net=Res50_UNet(n_classes=10) 42 | # opt=res50net(ipt) 43 | # print(opt.shape) -------------------------------------------------------------------------------- /utils/plt_utils.py: -------------------------------------------------------------------------------- 1 | # import matplotlib 2 | # matplotlib.use('TKAgg') 3 | import matplotlib.pyplot as plt 4 | 5 | 6 | def plt_show_img(image,title=None): 7 | assert image is not None 8 | plt.figure() 9 | if title is not None: 10 | plt.title(title) 11 | plt.imshow(image,cmap="gray") 12 | plt.show() 13 | 14 | def plt_show_imgs(imgs,title=None,): 15 | assert isinstance(imgs,(list,tuple)) 16 | plt.figure() 17 | length=len(imgs) 18 | for i in range(length): 19 | plt.subplot(1,length,i+1) 20 | plt.imshow(imgs[i], cmap="gray") 21 | # plt.imshow(imgs[i],cmap="gray") 22 | plt.show() 23 | 24 | def show_rects_on_img(img,rects,tittle=""): 25 | """ 26 | :param img_path: 27 | :param rects: (x,y,w,h) 28 | :param tittle: 29 | :return: 30 | """ 31 | 32 | plt.figure(figsize=img.shape[:2]) 33 | plt.imshow(img) 34 | plt.title(tittle) 35 | # 解决中文显示问题 36 | plt.rcParams['font.sans-serif'] = ['KaiTi'] # 指定默认字体 37 | plt.rcParams['axes.unicode_minus'] = False # 解决保存图像是负号'-'显示为方块的问题 38 | for rect in rects: 39 | plt.gca().add_patch(plt.Rectangle(xy=(rect[0],rect[1]),width=rect[2],height=rect[3],)) 40 | # fill=False, edgecolor="red",linewidth=1)) 41 | # for bbox,category_id,category in zip(bboxs,category_ids,categorys): 42 | # """ 43 | # 当前的图表和子图可以使用plt.gcf()和plt.gca()获得,分别表示GetCurrentFigure和GetCurrentAxes。 44 | # 在pyplot模块中,许多函数都是对当前的Figure或Axes对象进行处理,比如说:plt.plot()实际上会通过plt.gca() 45 | # 获得当前的Axes对象ax,然后再调用ax.plot()方法实现真正的绘图。 46 | # """ 47 | # 48 | # plt.gca().add_patch(plt.Rectangle(xy=(bbox[0],bbox[1]),width=bbox[2],height=bbox[3], 49 | # fill=False, edgecolor="red",linewidth=1)) 50 | # plt.text(x=bbox[0],y=bbox[1],s=category_id,ha='center',va='bottom',fontsize=10,color='red') 51 | plt.show() -------------------------------------------------------------------------------- /utils/.idea/workspace.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 13 | 14 | 16 | 17 | 18 | 19 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 1591498666548 34 | 39 | 40 | 41 | 42 | 44 | -------------------------------------------------------------------------------- /utils/pil_func.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | import os 4 | 5 | def concatImage(images,mode="Adapt",scale=0.5,offset=None): 6 | """ 7 | :param images: 图片列表 8 | :param mode: 图片排列方式["Row" ,"Col","Adapt"] 9 | :param scale: 10 | :param offset: 图片间距 11 | :return: 12 | """ 13 | if not isinstance(images, list): 14 | raise Exception('images must be a list ') 15 | if mode not in ["Row" ,"Col","Adapt"]: 16 | raise Exception('mode must be "Row" ,"Adapt",or "Col"') 17 | images=[np.uint8(img) for img in images] # if Gray [H,W] else if RGB [H,W,3] 18 | images = [img.squeeze(2) if len(img.shape)>2 and img.shape[2]==1 else img for img in images] 19 | count = len(images) 20 | img_ex = Image.fromarray(images[0]) 21 | size=img_ex.size #[W,H] 22 | if mode=="Adapt": 23 | mode= "Row" if size[0]<=size[1] else "Col" 24 | if offset is None:offset = int(np.floor(size[0] * 0.02)) 25 | if mode=="Row": 26 | target = Image.new(img_ex.mode, (size[0] * count+offset*(count-1), size[1] * 1),100) 27 | for i in range(count): 28 | image = Image.fromarray(images[i]).resize(size, Image.BILINEAR).convert(img_ex.mode) 29 | target.paste(image, (i*(size[0]+offset), 0)) 30 | #target.paste(image, (i * (size[0] + offset), 0, i * (size[0] + offset) + size[0], size[1])) 31 | return target 32 | if mode=="Col": 33 | target = Image.new(img_ex.mode, (size[0] , size[1]* count+offset*(count-1)),100) 34 | for i in range(count): 35 | image = Image.fromarray(images[i]).resize(size, Image.BILINEAR).convert(img_ex.mode) 36 | target.paste(image, (0,i*(size[1]+offset))) 37 | #target.paste(image, (0, i * (size[1] + offset), size[0], i * (size[1] + offset) + size[1])) 38 | return target 39 | 40 | def visualization(list_batchs,filenames,save_dirs): 41 | """ 42 | :param list_batchs: list[ array[b,h,w,c] ,array[b,h,w,c] ] or [ [imags],[images] ] 43 | :param filenames: list [filename] 44 | :param save_dir: 45 | :return: 46 | """ 47 | #batch_num= filenames.shape(0) 48 | for i, filename in enumerate(filenames): 49 | if not isinstance(save_dirs, list): 50 | #save_dirs = [save_dirs for i in len(batch_num)] 51 | save_dir = save_dirs 52 | else: 53 | save_dir=save_dirs[i] 54 | if not os.path.exists(save_dir): 55 | os.makedirs(save_dir) 56 | if not isinstance(filename,str): filename=filename.decode("utf-8") 57 | filename =filename.replace("/","_") 58 | list_images=[] 59 | for batchs in list_batchs: 60 | image=np.array(batchs[i]) 61 | if len(image.shape)>2 and image.shape[2]==1: image=image.squeeze(2) 62 | list_images.append(image) 63 | img_visual=concatImage(list_images,offset=10) 64 | visualization_path = os.path.join(save_dir,filename) 65 | try: 66 | img_visual.save(visualization_path) 67 | except: 68 | print("图片保存失败【[]】".format(visualization_path)) 69 | 70 | 71 | 72 | 73 | def save_image(image,save_dir,filename): 74 | image = Image.fromarray(np.uint8(image)) if not isinstance(image, Image.Image) else image 75 | if not os.path.exists(save_dir): 76 | os.makedirs(save_dir) 77 | visualization_path = os.path.join(save_dir, filename) 78 | image.save(visualization_path) -------------------------------------------------------------------------------- /utils/data/gen_defect.py: -------------------------------------------------------------------------------- 1 | import cv2 as cv 2 | import numpy as np 3 | import os 4 | import random 5 | 6 | 7 | class DefectiveGenerator(object): 8 | def __init__(self,dir_Database,shape_Img,Limit_ROI,withDatabase=True): 9 | """ 10 | 11 | :param dir_Database: 缺陷ROI路径 12 | :param shape_Img: 图片大小[height,width] 13 | :param Limit_ROI: ROI外接矩形大小[lower ,upper] 14 | :param withDatabase: true:从硬盘读入ROI false:算法生成ROI 15 | """ 16 | self.dir_Database=dir_Database 17 | self.height_Img = shape_Img[0] 18 | self.width_Img=shape_Img[1] 19 | self.lowerLimit_ROI=Limit_ROI[0] 20 | self.upperLimit_ROI=Limit_ROI[1] 21 | #从数据库读入ROI 22 | self.names_ROIs,self.num_ROIs=self.loadROIs(self.dir_Database) 23 | if self.num_ROIs<1: 24 | print("the dataset is empty!") 25 | def loadROIs(self,dir): 26 | ROIs=os.listdir(dir) 27 | num_ROI=len(ROIs) 28 | return ROIs,num_ROI 29 | 30 | def genDefect(self,img): 31 | ROI=self.randReadROI() 32 | ROI_new=self.randMoveROI(ROI) 33 | #Rows,Cols = np.nonzero(ROI_new) 34 | #随机设置灰度值 35 | rand=random.randint(0,200) 36 | img_rand=self.genRandImg(rand,20,[self.height_Img, self.width_Img]) 37 | img_new=img.copy() 38 | img_new=img*(1-ROI_new)+img_rand*ROI_new 39 | return img_new,ROI_new 40 | def randReadROI(self): 41 | while(True): 42 | rand=random.randint(0,self.num_ROIs-1) 43 | name_Img=self.names_ROIs[rand] 44 | img_Label=cv.imread(self.dir_Database+"/"+name_Img,0) 45 | _,ROI=cv.threshold(img_Label,100,255,cv.THRESH_BINARY) 46 | if(np.sum(ROI)>5): 47 | return ROI 48 | def randMoveROI(self,ROI): 49 | #求图像的域的大小 50 | Height_Domain = self.height_Img 51 | Width_Domain= self.width_Img 52 | #求ROI区域的坐标 53 | Rows,Cols = np.nonzero(ROI) 54 | #求ROI区域的外接矩形大小 55 | Width_ROI=np.max(Cols)-np.min(Cols) 56 | Height_ROI=np.max(Rows)-np.min(Rows) 57 | #随机设置ROI的起始坐标 58 | Row_Upleft=random.randint(0,Height_Domain-Height_ROI-1) 59 | Col_Upleft = random.randint(0, Width_Domain - Width_ROI-1) 60 | Rows=Rows-np.min(Rows)+Row_Upleft 61 | Cols=Cols-np.min(Cols)+Col_Upleft 62 | ROI_new=np.zeros([Height_Domain,Width_Domain]) 63 | ROI_new[Rows,Cols]=1 64 | return ROI_new 65 | 66 | def genRandImg(self,mean,fluct,size): 67 | 68 | low=mean-fluct+(mean-fluct<0)*abs(mean-fluct) 69 | height=mean+fluct-(mean+fluct>255)*abs(255-(mean+fluct)) 70 | img=np.random.randint(low,height,size) 71 | img=img.astype("uint8") 72 | return img 73 | 74 | # if __name__=="__main__": 75 | # fig = plt.figure() 76 | # 77 | # img=cv.imread("Part4.jpg",0) 78 | # shape_Img=img.shape 79 | # print(shape_Img) 80 | # gen=DefectiveGenerator("./label",shape_Img,[0,10000]) 81 | # new_img=gen.genDefect(img) 82 | # plt.subplot(121),plt.imshow(img,'gray'),plt.title('img') 83 | # plt.subplot(122), plt.imshow(new_img, 'gray'), plt.title('new img') 84 | # plt.show() 85 | 86 | 87 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils.data import DataLoader 4 | from unets.resunet import Res18_UNet 5 | from WSLDatasets.wsl_dataset import WSLDataset_split 6 | from torch.utils.tensorboard import SummaryWriter 7 | from torch.utils.data import WeightedRandomSampler 8 | from params import PARAM 9 | from funcs import iter_on_a_epoch 10 | from utils.data import my_transforms 11 | import utils 12 | from visual import Visualer 13 | import os 14 | ###定义数据 15 | transform={"train": my_transforms.ComposeJoint([ 16 | my_transforms.ToPIL(), 17 | my_transforms.GroupRandomHorizontalFlip(), 18 | my_transforms.GroupRandomVerticalFlip(), 19 | my_transforms.GroupResize(size=(256,256)), 20 | ]), 21 | "valid": my_transforms.ComposeJoint([ 22 | my_transforms.ToPIL(), 23 | my_transforms.GroupResize(size=(256,256)), 24 | ]) 25 | } 26 | 27 | train_data = WSLDataset_split(transform_PIL=transform["train"],**(PARAM.dataset_train)) 28 | #使用WeightedRandomSampler解决训练样本不平衡的问题 29 | weights=[ 1 if data[2]==0 else 6 for data in train_data ] #正负样本采样6:1 30 | sampler=WeightedRandomSampler(weights=weights,num_samples=len(train_data),replacement=True) 31 | PARAM.dataloader_train.shuffle=False #自定义sampler和DataLoader的shuffle参数互斥 32 | train_loader = DataLoader(train_data,sampler=sampler, **(PARAM.dataloader_train)) 33 | valid_data = WSLDataset_split(transform_PIL=transform["valid"],**(PARAM.dataset_valid)) 34 | valid_loader = DataLoader(train_data, **(PARAM.dataloader_valid)) 35 | DATA_LOADERS={ "train":train_loader, "valid":valid_loader } 36 | 37 | #设备 38 | DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else "cpu" 39 | print(DEVICE) 40 | #定义模型 41 | MODEL = Res18_UNet(**(PARAM.model)) 42 | MODEL.to(DEVICE) 43 | #定义二分类交叉熵损失函数: 44 | LOSSES={"supervise":nn.BCELoss()} 45 | #使用Adam优化器 46 | parameters = filter(lambda p: p.requires_grad, MODEL.parameters()) 47 | OPTIM=torch.optim.Adam(params=parameters,**(PARAM.Adam)) 48 | 49 | #定义一个metrics评价网络的性能 50 | METRICS=utils.segment_metrics.SegmentationMetric(numClass=2) 51 | #定义一个writer保存输出结果 52 | WRITER=SummaryWriter(log_dir=PARAM.train.log_dir) 53 | 54 | #定义一个对象保存图像 55 | VISUALER=Visualer(**(PARAM.visualer)) 56 | TRAIN_PARAM=PARAM.train 57 | 58 | def train(train_param,data_loaders,model,losses,optim,metrics,writer,visualer,device): 59 | for epo in range(1,train_param.epoch+1): 60 | # print("epoch:{}......".format(epo)) 61 | iter_on_a_epoch(epo,"train",data_loaders["train"],model,losses,optim,metrics,writer,visualer,device) 62 | #验证 63 | with_valid=True if epo%train_param.valid_frequency==0 else False 64 | if with_valid: 65 | iter_on_a_epoch(epo,"valid",data_loaders["valid"],model,losses,optim,metrics,writer,visualer,device) 66 | #保存模型 67 | if epo%train_param.save_frequency==0: 68 | if not os.path.exists(train_param.model_dir): 69 | os.makedirs(train_param.model_dir) 70 | model_path=os.path.join(train_param.model_dir,"epoch-{}.pth".format(epo)) 71 | torch.save(model.state_dict(),model_path) 72 | 73 | if __name__=="__main__": 74 | train(TRAIN_PARAM,DATA_LOADERS,MODEL,LOSSES,OPTIM,METRICS,WRITER,VISUALER,DEVICE) 75 | 76 | 77 | 78 | -------------------------------------------------------------------------------- /funcs.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import numpy as np 4 | def iter_on_a_batch(batch, model, losses, optim, phase, device): 5 | assert isinstance(losses, dict) 6 | assert phase in ["train", "valid", "test", "infer"] 7 | img_batch, label_pixel_batch, label_batch, file_name_batch = batch 8 | img_rensor = torch.tensor(img_batch).float().to(device) 9 | label_tensor = torch.tensor(label_batch).float().to(device) 10 | # forward 11 | 12 | mask_tensor = model(img_rensor) # 在devie 13 | mask_batch = mask_tensor.detach().cpu().numpy() # cpu 14 | 15 | ###### cul loss 16 | if phase in ["train", "valid", "test"]: 17 | 18 | label_pixel_tensor = torch.tensor(label_pixel_batch).float().to(device) # gpu 19 | loss_segment = losses["supervise"](mask_tensor.squeeze(1), label_pixel_tensor.squeeze(1)) 20 | loss_dict = {"segment": loss_segment.mean()} 21 | ##### backward 22 | if phase in ["train"]: 23 | assert isinstance(loss_dict, dict) 24 | model.zero_grad() 25 | if len(loss_dict)==1: 26 | loss_sum=list(loss_dict.values())[0] 27 | else: 28 | loss_sum = sum(*list(loss_dict.values())) 29 | loss_sum.backward() 30 | optim.step() 31 | #### return 32 | result = {"mask_batch": mask_batch, } 33 | if phase in ["train", "valid", "test"]: 34 | for key, loss in loss_dict.items(): 35 | loss_dict[key] = float(loss) 36 | result["loss"] = loss_dict 37 | return result 38 | 39 | 40 | def iter_on_a_epoch(epo,phase,data_loader,model,losses,optim,metrics,writer,visualer,device): 41 | # train 42 | metrics.reset() 43 | epo_loss = {} 44 | for cnt_batch, batch in enumerate(data_loader): 45 | 46 | result = iter_on_a_batch(batch, model, 47 | losses=losses, 48 | optim=optim, 49 | phase=phase, device=device) 50 | 51 | img_batch, label_pixel_batch, _, file_name_batch = batch 52 | #可视化图像 53 | img_batch = img_batch.detach().cpu().numpy() 54 | label_pixel_batch = label_pixel_batch.detach().cpu().numpy() 55 | visual_list = [img_batch.transpose(0, 2, 3, 1) * 255, 56 | label_pixel_batch.transpose(0, 2, 3, 1) * 255, 57 | result["mask_batch"].transpose(0, 2, 3, 1) * 255] 58 | visualer.write(epo,cnt_batch,visual_list, 59 | child_dir="epo-{}_{}".format(epo,phase), 60 | file_name_batch=file_name_batch) 61 | 62 | # 每次迭代打印损失函数 63 | loss_dict = result["loss"] 64 | # s = "epoch:{},batch:{},lr:{:.4f}".format(epo, cnt_batch, float(optim.state_dict()['param_groups'][0]['lr'])) 65 | # for key, val in loss_dict.items(): 66 | # s += ",{}_loss:{:.4f}".format(key, float(val)) 67 | # print(s) 68 | # 添加到writer 69 | writer.add_scalars("step_loss", loss_dict,(epo-1)*len(data_loader)+cnt_batch) 70 | # 添加到epo_loss 71 | for key, val in loss_dict.items(): 72 | key = key + "_loss" 73 | if key not in epo_loss.keys(): 74 | epo_loss[key] = list() 75 | else: 76 | epo_loss[key].append(val) 77 | # 添加结果到metrcis 78 | metrics.addBatch(np.where(result["mask_batch"]>0.3,1,0), label_pixel_batch.astype(np.int64)) 79 | 80 | iou_defect = metrics.clsIntersectionOverUnion(1) 81 | #对epoch_loss求平均 82 | for key, val in epo_loss.items(): 83 | epo_loss[key] = np.array(val).sum()/ len(val) 84 | 85 | s="----epoch:{},{},iou:{:.4f}".format(epo,phase, iou_defect) 86 | for key, val in epo_loss.items(): 87 | s+=",{}:{:.4f}".format(key,val) 88 | print(s) 89 | writer.add_scalar("iou_defect_{}".format(phase), iou_defect, global_step=epo) 90 | writer.add_scalars("loss_{}".format(phase), epo_loss, global_step=epo) 91 | -------------------------------------------------------------------------------- /visual.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import numpy as np 4 | from PIL import Image 5 | import random 6 | def concatImage(images,mode="Adapt",scale=0.5,offset=None): 7 | """ 8 | :param images: 图片列表 9 | :param mode: 图片排列方式["Row" ,"Col","Adapt"] 10 | :param scale: 11 | :param offset: 图片间距 12 | :return: 13 | """ 14 | if not isinstance(images, list): 15 | raise Exception('images must be a list ') 16 | if mode not in ["Row" ,"Col","Adapt"]: 17 | raise Exception('mode must be "Row" ,"Adapt",or "Col"') 18 | images=[np.uint8(img) for img in images] # if Gray [H,W] else if RGB [H,W,3] 19 | images = [img.squeeze(2) if len(img.shape)>2 and img.shape[2]==1 else img for img in images] 20 | count = len(images) 21 | img_ex = Image.fromarray(images[0]) 22 | size=img_ex.size #[W,H] 23 | if mode=="Adapt": 24 | mode= "Row" if size[0]<=size[1] else "Col" 25 | if offset is None:offset = int(np.floor(size[0] * 0.02)) 26 | if mode=="Row": 27 | target = Image.new(img_ex.mode, (size[0] * count+offset*(count-1), size[1] * 1),100) 28 | for i in range(count): 29 | image = Image.fromarray(images[i]).resize(size, Image.BILINEAR).convert(img_ex.mode) 30 | target.paste(image, (i*(size[0]+offset), 0)) 31 | #target.paste(image, (i * (size[0] + offset), 0, i * (size[0] + offset) + size[0], size[1])) 32 | return target 33 | if mode=="Col": 34 | target = Image.new(img_ex.mode, (size[0] , size[1]* count+offset*(count-1)),100) 35 | for i in range(count): 36 | image = Image.fromarray(images[i]).resize(size, Image.BILINEAR).convert(img_ex.mode) 37 | target.paste(image, (0,i*(size[1]+offset))) 38 | #target.paste(image, (0, i * (size[1] + offset), size[0], i * (size[1] + offset) + size[1])) 39 | return target 40 | 41 | def visualization(list_batchs,filenames,save_dirs): 42 | """ 43 | :param list_batchs: list[ array[b,h,w,c] ,array[b,h,w,c] ] or [ [imags],[images] ] 44 | :param filenames: list [filename] 45 | :param save_dir: 46 | :return: 47 | """ 48 | #batch_num= filenames.shape(0) 49 | for i, filename in enumerate(filenames): 50 | if not isinstance(save_dirs, list): 51 | #save_dirs = [save_dirs for i in len(batch_num)] 52 | save_dir = save_dirs 53 | else: 54 | save_dir=save_dirs[i] 55 | if not os.path.exists(save_dir): 56 | os.makedirs(save_dir) 57 | if not isinstance(filename,str): filename=filename.decode("utf-8") 58 | filename =filename.replace("/","_") 59 | list_images=[] 60 | for batchs in list_batchs: 61 | image=np.array(batchs[i]) 62 | if len(image.shape)>2 and image.shape[2]==1: image=image.squeeze(2) 63 | list_images.append(image) 64 | img_visual=concatImage(list_images,offset=10) 65 | visualization_path = os.path.join(save_dir,filename) 66 | try: 67 | img_visual.save(visualization_path) 68 | except: 69 | print("图片保存失败【[]】".format(visualization_path)) 70 | 71 | 72 | class Visualer(object): 73 | 74 | def __init__(self,save_dir,visual_frequency,visual_batchs): 75 | self.save_dir=save_dir 76 | self.visual_batchs=visual_batchs 77 | self.visual_frequency=visual_frequency 78 | if not os.path.exists(save_dir): 79 | print("create directory:{} ".format(save_dir)) 80 | os.makedirs(save_dir) 81 | 82 | 83 | def write(self,epoch,cnt_batch,batch_list,child_dir,file_name_batch): 84 | 85 | 86 | if epoch%self.visual_frequency!=0: 87 | return 88 | if cnt_batch>self.visual_batchs: 89 | return 90 | file_names=[ file_name_batch[i] for i in range(len(file_name_batch))] 91 | save_dir=os.path.join(self.save_dir,child_dir) 92 | visualization(batch_list,file_names,save_dir) 93 | -------------------------------------------------------------------------------- /utils/segment_metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | refer 3 | to 4 | https: // github.com / jfzhang95 / pytorch - deeplab - xception / blob / master / utils / metrics.py 5 | """ 6 | import numpy as np 7 | __all__ = ['SegmentationMetric'] 8 | 9 | """ 10 | confusionMetric 11 | P\L 12 | P 13 | N 14 | 15 | P 16 | TP 17 | FP 18 | 19 | N 20 | FN 21 | TN 22 | 23 | """ 24 | class SegmentationMetric(object): 25 | def __init__(self, numClass): 26 | self.numClass = numClass 27 | self.confusionMatrix = np.zeros((self.numClass,)*2) 28 | 29 | def pixelAccuracy(self): 30 | # return all class overall pixel accuracy 31 | # acc = (TP + TN) / (TP + TN + FP + TN) 32 | acc = np.diag(self.confusionMatrix).sum() / self.confusionMatrix.sum() 33 | return acc 34 | 35 | def classPixelAccuracy(self): 36 | # return each category pixel accuracy(A more accurate way to call it precision) 37 | # acc = (TP) / TP + FP 38 | classAcc = np.diag(self.confusionMatrix) / self.confusionMatrix.sum(axis=1) 39 | return classAcc 40 | 41 | def meanPixelAccuracy(self): 42 | classAcc = self.classPixelAccuracy() 43 | meanAcc = np.nanmean(classAcc) 44 | return meanAcc 45 | 46 | def meanIntersectionOverUnion(self): 47 | # Intersection = TP Union = TP + FP + FN 48 | # IoU = TP / (TP + FP + FN) 49 | intersection = np.diag(self.confusionMatrix) 50 | union = np.sum(self.confusionMatrix, axis=1) + np.sum(self.confusionMatrix, axis=0) - np.diag(self.confusionMatrix) 51 | IoU = intersection / union 52 | mIoU = np.nanmean(IoU) 53 | return mIoU 54 | 55 | def clsIntersectionOverUnion(self,cls): 56 | assert cls= 0) & (imgLabel < self.numClass) 68 | label = self.numClass * imgLabel[mask] + imgPredict[mask] 69 | count = np.bincount(label, minlength=self.numClass**2) 70 | confusionMatrix = count.reshape(self.numClass, self.numClass) 71 | return confusionMatrix 72 | 73 | def Frequency_Weighted_Intersection_over_Union(self): 74 | # FWIOU = [(TP+FN)/(TP+FP+TN+FN)] *[TP / (TP + FP + FN)] 75 | freq = np.sum(self.confusion_matrix, axis=1) / np.sum(self.confusion_matrix) 76 | iu = np.diag(self.confusion_matrix) / ( 77 | np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) - 78 | np.diag(self.confusion_matrix)) 79 | FWIoU = (freq[freq > 0] * iu[freq > 0]).sum() 80 | return FWIoU 81 | 82 | 83 | def addBatch(self, imgPredict, imgLabel): 84 | #print(imgPredict.shape, imgLabel.shape) 85 | assert imgPredict.shape == imgLabel.shape,print("imgPredict shape:{}" "imgLabel shape:{}" 86 | .format(imgPredict.shape,imgLabel.shape)) 87 | 88 | self.confusionMatrix += self.genConfusionMatrix(imgPredict, imgLabel) 89 | 90 | def reset(self): 91 | self.confusionMatrix = np.zeros((self.numClass, self.numClass)) 92 | 93 | 94 | if __name__ == '__main__': 95 | imgPredict = np.array([0, 0, 1, 1, 2, 2]) 96 | imgLabel = np.array([0, 0, 1, 1, 2, 2]) 97 | metric = SegmentationMetric(3) 98 | metric.addBatch(imgPredict, imgLabel) 99 | acc = metric.pixelAccuracy() 100 | mIoU = metric.meanIntersectionOverUnion() 101 | IoU = metric.clsIntersectionOverUnion(1) 102 | print(acc, mIoU,IoU) -------------------------------------------------------------------------------- /utils/param.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2019/12/12 8:31 4 | # @Author : Wslsdx 5 | # @FileName: param.py 6 | # @Software: PyCharm 7 | # @Github :https://github.com/Wslsdx 8 | class Param(object): 9 | def __init__(self,**kargs): 10 | self._name="param" 11 | self.regist_from_dict(kargs) 12 | 13 | def regist_from_parser(self,parser): 14 | for key,val in parser.__dict__.items(): 15 | self.__setitem__(key, val) 16 | 17 | def regist_from_dict(self,_dict): 18 | assert isinstance(_dict,dict) 19 | for key,val in _dict.items(): 20 | self.__setitem__(key, val) 21 | 22 | def regist(self, key, val): 23 | self.__setitem__(key, val) 24 | 25 | def update_name(self,last_name,key): 26 | self._name=last_name+"."+key 27 | for key,val in self.__dict__.items(): 28 | if isinstance(val,Param): 29 | val.update_name(self._name,key) 30 | # 功能 A["a"] 31 | def __setitem__(self, key, value): 32 | super(Param,self).__setattr__( key, value) 33 | if isinstance(value,Param): 34 | value.update_name(self._name,key) 35 | #self.__dict__[key] = value 36 | def __getitem__(self, attr): 37 | return super(Param, self).__getattribute__(attr) 38 | def __delitem__(self, key): 39 | try: 40 | del self.__dict__[key] 41 | except KeyError as k: 42 | return None 43 | 44 | # 功能 A.a 45 | def __setattr__(self, key, value): 46 | super(Param,self).__setattr__( key, value) 47 | if isinstance(value,Param): 48 | value.update_name(self._name,key) 49 | #self.__dict__[key] = value 50 | def __getattribute__(self, attr): 51 | return super(Param, self).__getattribute__(attr) 52 | def __getattr__(self, attr): 53 | """ 54 | 重载此函数防止属性不存在时__getattribute__报错,而是返回None 55 | 那“_ getattribute_”与“_ getattr_”的最大差异在于: 56 | 1. 无论调用对象的什么属性,包括不存在的属性,都会首先调用“_ getattribute_”方法; 57 | 2. 只有找不到对象的属性时,才会调用“_ getattr_”方法; 58 | :param attr: 59 | :return: 60 | """ 61 | return None 62 | def __delattr__(self, key): 63 | try: 64 | del self.__dict__[key] 65 | except KeyError as k: 66 | return None 67 | # def __str__(self): 68 | # string="" 69 | # for key,val in self.__dict__.items(): 70 | # if key is "_name": continue 71 | # if isinstance(val,Param): 72 | # string += self._name + "{}=Param()\n".format(key) 73 | # string +="{}".format(val) 74 | # else: 75 | # string +=self._name+"{}={}\n".format(key,val) 76 | # return string 77 | def __str__(self): 78 | string=self._name + "=Param()\n" 79 | for key,val in self.__dict__.items(): 80 | if key is "_name": continue 81 | if isinstance(val,Param): 82 | string +=str(val) 83 | else: 84 | string +=self._name+".{}={}\n".format(key,val) 85 | return string 86 | def __len__(self): 87 | return len(self.__dict__) 88 | 89 | 90 | def keys(self): 91 | keys=[ key for key in self.__dict__.keys() if key !="_name"] 92 | return keys 93 | 94 | def values(self): 95 | return [ self[key] for key in self.keys() ] 96 | def items(self): 97 | return [ item for item in self if item[0] in self.keys()] 98 | 99 | def get(self,key,defaut): 100 | if key in self.keys(): 101 | return self[key] 102 | else: 103 | return defaut 104 | if __name__=="__main__": 105 | 106 | # a=dict() 107 | # a["b"]=1 108 | # print(a.__dict__) 109 | # pass 110 | c = Param() 111 | c.crf=Param( 112 | PGauss_sxy=80,#15 113 | PGauss_compat=15,#3,15,50 114 | PBila_sxy=50, #80 115 | PBila_srgb=5,#30 116 | PBila_compat=50,) 117 | print(c.crf.PGauss_sxy) 118 | print (c.__dict__) 119 | # c.regist("z", 3) 120 | # c.regist("x", 4) 121 | # c.regist("y", 4) 122 | # c.regist("func", lambda x: "".join(["=>", str(x), "<="])) 123 | # c["x"]=1 124 | # print (c.__dict__) 125 | # print (c.x, c.y, c.z,) 126 | # print (c["x"], ) 127 | # c["d"]=100 128 | # c.d=100 129 | # print(c.d) 130 | # print(c["d"]) 131 | # print(c.adc) 132 | # print("---") 133 | # print(c.d) 134 | # del c.d 135 | # print(c.d) 136 | # print(list(c.items())) 137 | # for key,val in c.items(): 138 | # print(key) 139 | # print(val) -------------------------------------------------------------------------------- /utils/data/data_managers.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 3 | # import tensorflow as tf 4 | from random import shuffle 5 | import math 6 | import numpy as np 7 | class DataManager(object): 8 | def __init__(self, dataset,param,shuffle=True): 9 | """ 10 | """ 11 | self.shuffle=shuffle 12 | self.dataset=dataset 13 | self.data_size=len(dataset) 14 | self.epochs_num=10 15 | self.batch_size = param["batch_size"] 16 | self.next_batch=self.get_next() 17 | self._session= tf.Session() 18 | self.num_batch=param["num_batch_train"] 19 | self.num_batch=math.ceil(self.data_size / self.batch_size) if self.num_batch==None else self.num_batch 20 | def get_next(self): 21 | dataset = tf.data.Dataset.from_generator(self.generator, (tf.float32, tf.int32,tf.int32, tf.string)) 22 | dataset = dataset.repeat(self.epochs_num) 23 | if self.shuffle: 24 | dataset = dataset.shuffle(self.batch_size*3) 25 | dataset = dataset.batch(self.batch_size) 26 | iterator = dataset.make_one_shot_iterator() 27 | out_batch = iterator.get_next() 28 | return out_batch 29 | 30 | def generator(self): 31 | while True: 32 | for index in range(self.data_size): 33 | yield self.dataset[index] 34 | 35 | def __iter__(self): 36 | self.cnt_batch=0 37 | return self 38 | def __next__(self): 39 | if self.cnt_batch < len(self): 40 | self.cnt_batch+=1 41 | next_batch = self._session.run(self.next_batch) 42 | return next_batch 43 | else: 44 | raise StopIteration 45 | 46 | def __len__(self): 47 | return self.num_batch 48 | 49 | 50 | class DataManager_balance(DataManager): 51 | def __init__(self, dataset, param): 52 | shuffle = param.get("shuffle", True) 53 | super(DataManager_balance,self).__init__(dataset, param, shuffle) 54 | self.num_batch = param.get("num_batch_train", -1) 55 | if self.num_batch ==-1: 56 | self.num_batch =len(self.dataset.cls_dict[1])*len(self.dataset.cls_dict)//param["batch_size"] 57 | if self.num_batch ==0: 58 | self.num_batch =len(self.dataset.cls_dict[0])*len(self.dataset.cls_dict)//param["batch_size"] 59 | 60 | def get_next(self): 61 | dataset = tf.data.Dataset.from_generator(self.generator, (tf.float32, tf.int32,tf.int32, tf.string)) 62 | dataset = dataset.repeat(self.epochs_num) 63 | # if self.shuffle: 64 | # dataset = dataset.shuffle(self.batch_size*3+200) 65 | dataset = dataset.batch(self.batch_size) 66 | iterator = dataset.make_one_shot_iterator() 67 | out_batch = iterator.get_next() 68 | return out_batch 69 | def generator(self): 70 | step=0 71 | cls_idxs_map={ key:list(range(len(val))) for key,val in self.dataset.cls_dict.items()} 72 | while(True): 73 | for cls, idxs, in cls_idxs_map.items(): 74 | iidx=step%len(idxs) 75 | if iidx==0 and self.shuffle: 76 | shuffle(cls_idxs_map[cls]) 77 | idx = cls_idxs_map[cls][iidx] 78 | yield self.dataset.getitem_cls(idx,cls) 79 | step+=1 80 | def __len__(self): 81 | return self.num_batch 82 | 83 | 84 | class DataManager_valid(object): 85 | def __init__(self,dataset,batch_size): 86 | """ 87 | """ 88 | self.dataset=dataset 89 | self.data_size=len(dataset) 90 | self.batch_size =batch_size 91 | self.num_batch=math.ceil(self.data_size/self.batch_size) 92 | 93 | def get_a_batch(self,cnt_batch): 94 | assert cnt_batch 1: 61 | raise Exception("wrong params...") 62 | print(u"划分数据集......") 63 | ratio_dict = {"training": train_ratio, "validation": val_ratio, "testing": 1 - train_ratio - val_ratio} 64 | dataset_dict = {key: [] for key, val in ratio_dict.items()} 65 | for cls, data_list in cls_dict.items(): 66 | sample_num = len(data_list) 67 | train_offset = int(np.floor(sample_num * ratio_dict["training"])) 68 | val_offset = int(np.floor(sample_num * (ratio_dict["training"] + ratio_dict["validation"]))) 69 | # print (u" 类别[{}]中,训练集:{},验证集:{},测试集:{}" \ 70 | # .format(cls, train_offset, val_offset - train_offset, len(data_list) - val_offset)) 71 | Keys = ["training"] * train_offset \ 72 | + ["validation"] * (val_offset - train_offset) \ 73 | + ["testing"] * (len(data_list) - val_offset) 74 | if shuffle: 75 | random.shuffle(data_list) 76 | for key, item in zip(Keys, data_list): 77 | dataset_dict[key].append(item) 78 | return dataset_dict 79 | 80 | def write_txt(dir, dataset_dict, prefix='%s %s %d', shuffle=False, clear=True): 81 | """ 82 | :param dir: 写入目标文件夹 83 | :param dataset_dict: {"training" : list1[]...} 84 | :param prefix: 写入格式 85 | :param shuffle: 是否打乱列表 86 | :param clear: 是否清空之间的写入对象 87 | :return: 88 | """ 89 | if not os.path.exists(dir): os.makedirs(dir) 90 | mode = 'w' if clear else 'a' 91 | for key, sample_list in dataset_dict.items(): # 每个类别 92 | if shuffle: random.shuffle(sample_list) 93 | write_path = os.path.join(dir, key) + '.txt' 94 | with open(write_path, 'a', encoding='utf-8') as file: 95 | lines = [(prefix % (image[0], image[1], image[2])) for image in sample_list] 96 | file.write('\n'.join(lines)) 97 | print(u"数据集配置保存到txt:【{}】".format(dir)) 98 | 99 | def read_txt(data_dir,phase_re=None): 100 | """Read the content of the text file and store it into lists.""" 101 | phases = ["training", "validation", "testing",None] 102 | assert phase_re in phases 103 | assert os.path.exists(data_dir),"{}".format(data_dir) 104 | data_dict = {phase: [] for phase in phases[:-1]} 105 | for phase, data_list in data_dict.items(): 106 | txt_file = data_dir + "/" + phase + ".txt" 107 | if not os.path.exists(txt_file): 108 | continue 109 | with open(txt_file, 'r', encoding='utf-8') as f: 110 | lines = f.readlines() 111 | for line in lines: 112 | items = list(line.strip().split(' ')) 113 | items=[item if item !="None" else None for item in items ] 114 | data_list.append(items) 115 | data_dict[phase] = data_list 116 | if phase_re ==None: 117 | return data_dict 118 | else: 119 | return data_dict[phase_re] 120 | 121 | 122 | class BDataset(object): 123 | def __init__(self,root): 124 | self.root=root 125 | self.samples=[] 126 | self.cls_dict={} 127 | def list_dir(self,root,use_absPath, func): 128 | return list_folder(root,use_absPath,func) 129 | 130 | def make_dataset(self): 131 | raise Exception("") 132 | 133 | def split_dataset_by_cls(self,samples,loc): 134 | """ 135 | :param samples: samples= [ sample1, sample2, ... ] 136 | :param loc: the location index of class in sample 137 | :return: 138 | """ 139 | # 根据类别生成字典 140 | cls_dict = {} 141 | for sample in samples: 142 | cls = int(sample[loc]) 143 | if cls not in cls_dict.keys(): 144 | cls_dict[cls] = [] 145 | cls_dict[cls].append(sample) 146 | # for cls, sample_list in cls_dict.items(): 147 | # print(" 类别{},样本数为:{}".format(cls, len(sample_list))) 148 | return cls_dict 149 | 150 | def gen_a_sample(self,sample): 151 | raise Exception("") 152 | 153 | def __len__(self): 154 | return len(self.samples) 155 | 156 | def __getitem__(self, idx): 157 | sample = self.samples[idx] 158 | return self.gen_a_sample(sample) 159 | 160 | def getitem_cls(self, idx, cls): 161 | sample = self.cls_dict[cls][idx] 162 | return self.gen_a_sample(sample) 163 | 164 | class BaseDataset(object): 165 | def __init__(self,Datadir,ConfigDir): 166 | self.Datadir=os.path.abspath(Datadir) 167 | self.ConfigDir=os.path.abspath(ConfigDir) 168 | self.list_folder(self.Datadir) 169 | self.sample_list,self.cls_dict=self.make_dataset() 170 | self.Config=os.path.exists(ConfigDir) 171 | 172 | def make_dataset(self, useAbsDir=False): 173 | print("生成数据集......") 174 | root=self.Datadir 175 | samples=[] 176 | #添加样本列表 177 | for img in sorted(self.imgs): 178 | label = img.split(".")[-2] + "_label.bmp" 179 | # 过滤没有语义标签的图片 180 | if label not in self.imgs_pixel: 181 | continue 182 | # 过滤不符合类别的图片 183 | target = 1 if np.sum(cv2.imread(label)) > 10 else 0 184 | if not useAbsDir: 185 | img=os.path.relpath(img,root).replace('\\','/') 186 | label=os.path.relpath(label,root).replace('\\','/') 187 | item = (img, label, target) 188 | samples.append(item) 189 | print(" 总样本数:{}".format(len(samples))) 190 | #根据类别生成字典 191 | cls_dict={} 192 | for sample in samples: 193 | cls=sample[2] 194 | if cls not in cls_dict.keys(): 195 | cls_dict[cls]=[] 196 | cls_dict[cls].append(sample) 197 | for cls,sample_list in cls_dict.items(): 198 | print(" 类别{},样本数为:{}".format(cls,len(sample_list))) 199 | return samples,cls_dict 200 | 201 | @staticmethod 202 | def list_folder(root,extensions=EXTENSIONS): 203 | return list_folder(root,has_file_allowed_extension) 204 | @staticmethod 205 | def divide_dataset(cls_dict, train_ratio, val_ratio, shuffle=True): 206 | return divide_dataset(cls_dict, train_ratio, val_ratio, shuffle) 207 | @staticmethod 208 | def write_txt(dir, dataset_dict, prefix='%s %d %d', shuffle=False, clear=True): 209 | return write_txt(dir, dataset_dict, prefix, shuffle, clear) 210 | @staticmethod 211 | def read_txt(data_dir): 212 | return read_txt(data_dir) 213 | @staticmethod 214 | def has_file_allowed_extension(filename, extensions=EXTENSIONS): 215 | return has_file_allowed_extension(filename, extensions) 216 | 217 | 218 | -------------------------------------------------------------------------------- /WSLDatasets/configs/class1_train5valid5/validation.txt: -------------------------------------------------------------------------------- 1 | Train/0759.PNG None 0 2 | Train/0923.PNG None 0 3 | Train/1096.PNG None 0 4 | Train/1077.PNG None 0 5 | Train/0960.PNG None 0 6 | Train/0808.PNG None 0 7 | Train/0853.PNG None 0 8 | Train/0647.PNG None 0 9 | Train/0855.PNG None 0 10 | Train/1055.PNG None 0 11 | Train/0890.PNG None 0 12 | Train/1059.PNG None 0 13 | Train/1053.PNG None 0 14 | Train/0859.PNG None 0 15 | Train/0731.PNG None 0 16 | Train/0831.PNG None 0 17 | Train/0989.PNG None 0 18 | Train/0925.PNG None 0 19 | Train/0636.PNG None 0 20 | Train/0576.PNG None 0 21 | Train/0866.PNG None 0 22 | Train/0834.PNG None 0 23 | Train/0671.PNG None 0 24 | Train/0963.PNG None 0 25 | Train/0847.PNG None 0 26 | Train/0795.PNG None 0 27 | Train/0594.PNG None 0 28 | Train/0848.PNG None 0 29 | Train/0727.PNG None 0 30 | Train/0734.PNG None 0 31 | Train/1075.PNG None 0 32 | Train/1029.PNG None 0 33 | Train/1071.PNG None 0 34 | Train/0630.PNG None 0 35 | Train/0969.PNG None 0 36 | Train/1085.PNG None 0 37 | Train/1024.PNG None 0 38 | Train/0880.PNG None 0 39 | Train/0926.PNG None 0 40 | Train/0651.PNG None 0 41 | Train/0809.PNG None 0 42 | Train/0596.PNG None 0 43 | Train/0946.PNG None 0 44 | Train/0802.PNG None 0 45 | Train/0661.PNG None 0 46 | Train/0718.PNG None 0 47 | Train/0683.PNG None 0 48 | Train/1091.PNG None 0 49 | Train/1043.PNG None 0 50 | Train/1113.PNG None 0 51 | Train/0845.PNG None 0 52 | Train/0793.PNG None 0 53 | Train/0826.PNG None 0 54 | Train/0814.PNG None 0 55 | Train/0732.PNG None 0 56 | Train/0904.PNG None 0 57 | Train/0961.PNG None 0 58 | Train/0833.PNG None 0 59 | Train/1012.PNG None 0 60 | Train/0820.PNG None 0 61 | Train/1039.PNG None 0 62 | Train/0875.PNG None 0 63 | Train/0895.PNG None 0 64 | Train/0878.PNG None 0 65 | Train/0799.PNG None 0 66 | Train/0656.PNG None 0 67 | Train/0990.PNG None 0 68 | Train/0804.PNG None 0 69 | Train/0935.PNG None 0 70 | Train/0659.PNG None 0 71 | Train/0633.PNG None 0 72 | Train/0939.PNG None 0 73 | Train/0921.PNG None 0 74 | Train/0998.PNG None 0 75 | Train/1099.PNG None 0 76 | Train/0884.PNG None 0 77 | Train/0867.PNG None 0 78 | Train/1050.PNG None 0 79 | Train/1035.PNG None 0 80 | Train/0896.PNG None 0 81 | Train/1021.PNG None 0 82 | Train/1070.PNG None 0 83 | Train/0888.PNG None 0 84 | Train/0771.PNG None 0 85 | Train/0619.PNG None 0 86 | Train/1120.PNG None 0 87 | Train/0668.PNG None 0 88 | Train/0754.PNG None 0 89 | Train/0577.PNG None 0 90 | Train/0694.PNG None 0 91 | Train/0621.PNG None 0 92 | Train/0874.PNG None 0 93 | Train/0877.PNG None 0 94 | Train/0712.PNG None 0 95 | Train/0724.PNG None 0 96 | Train/0908.PNG None 0 97 | Train/0993.PNG None 0 98 | Train/1124.PNG None 0 99 | Train/1032.PNG None 0 100 | Train/1092.PNG None 0 101 | Train/0927.PNG None 0 102 | Train/1067.PNG None 0 103 | Train/0821.PNG None 0 104 | Train/0723.PNG None 0 105 | Train/0928.PNG None 0 106 | Train/0832.PNG None 0 107 | Train/0632.PNG None 0 108 | Train/0872.PNG None 0 109 | Train/0972.PNG None 0 110 | Train/1082.PNG None 0 111 | Train/0685.PNG None 0 112 | Train/1118.PNG None 0 113 | Train/0952.PNG None 0 114 | Train/0991.PNG None 0 115 | Train/0782.PNG None 0 116 | Train/0700.PNG None 0 117 | Train/0650.PNG None 0 118 | Train/0686.PNG None 0 119 | Train/1052.PNG None 0 120 | Train/0684.PNG None 0 121 | Train/0806.PNG None 0 122 | Train/0980.PNG None 0 123 | Train/0584.PNG None 0 124 | Train/0966.PNG None 0 125 | Train/1020.PNG None 0 126 | Train/0968.PNG None 0 127 | Train/0889.PNG None 0 128 | Train/0745.PNG None 0 129 | Train/0714.PNG None 0 130 | Train/0974.PNG None 0 131 | Train/0979.PNG None 0 132 | Train/0752.PNG None 0 133 | Train/1116.PNG None 0 134 | Train/0830.PNG None 0 135 | Train/0767.PNG None 0 136 | Train/0588.PNG None 0 137 | Train/0728.PNG None 0 138 | Train/0800.PNG None 0 139 | Train/0856.PNG None 0 140 | Train/0777.PNG None 0 141 | Train/0838.PNG None 0 142 | Train/0623.PNG None 0 143 | Train/0788.PNG None 0 144 | Train/0852.PNG None 0 145 | Train/0648.PNG None 0 146 | Train/1049.PNG None 0 147 | Train/0794.PNG None 0 148 | Train/1044.PNG None 0 149 | Train/1110.PNG None 0 150 | Train/1009.PNG None 0 151 | Train/0883.PNG None 0 152 | Train/0891.PNG None 0 153 | Train/0825.PNG None 0 154 | Train/0643.PNG None 0 155 | Train/1125.PNG None 0 156 | Train/0864.PNG None 0 157 | Train/0947.PNG None 0 158 | Train/0592.PNG None 0 159 | Train/0857.PNG None 0 160 | Train/1038.PNG None 0 161 | Train/0716.PNG None 0 162 | Train/0854.PNG None 0 163 | Train/1037.PNG None 0 164 | Train/0604.PNG None 0 165 | Train/0711.PNG None 0 166 | Train/0681.PNG None 0 167 | Train/0900.PNG None 0 168 | Train/0612.PNG None 0 169 | Train/0959.PNG None 0 170 | Train/1093.PNG None 0 171 | Train/1026.PNG None 0 172 | Train/0978.PNG None 0 173 | Train/0944.PNG None 0 174 | Train/1013.PNG None 0 175 | Train/0682.PNG None 0 176 | Train/1057.PNG None 0 177 | Train/1102.PNG None 0 178 | Train/0751.PNG None 0 179 | Train/0994.PNG None 0 180 | Train/0719.PNG None 0 181 | Train/0746.PNG None 0 182 | Train/0901.PNG None 0 183 | Train/1017.PNG None 0 184 | Train/0827.PNG None 0 185 | Train/0945.PNG None 0 186 | Train/0679.PNG None 0 187 | Train/1056.PNG None 0 188 | Train/0995.PNG None 0 189 | Train/0585.PNG None 0 190 | Train/1126.PNG None 0 191 | Train/0818.PNG None 0 192 | Train/1076.PNG None 0 193 | Train/0933.PNG None 0 194 | Train/0792.PNG None 0 195 | Train/0813.PNG None 0 196 | Train/0829.PNG None 0 197 | Train/0737.PNG None 0 198 | Train/0850.PNG None 0 199 | Train/1083.PNG None 0 200 | Train/0938.PNG None 0 201 | Train/0708.PNG None 0 202 | Train/1022.PNG None 0 203 | Train/0696.PNG None 0 204 | Train/0985.PNG None 0 205 | Train/1074.PNG None 0 206 | Train/0758.PNG None 0 207 | Train/0975.PNG None 0 208 | Train/1027.PNG None 0 209 | Train/0769.PNG None 0 210 | Train/0587.PNG None 0 211 | Train/0611.PNG None 0 212 | Train/0841.PNG None 0 213 | Train/0912.PNG None 0 214 | Train/1008.PNG None 0 215 | Train/0687.PNG None 0 216 | Train/0644.PNG None 0 217 | Train/0783.PNG None 0 218 | Train/0851.PNG None 0 219 | Train/0911.PNG None 0 220 | Train/1002.PNG None 0 221 | Train/0876.PNG None 0 222 | Train/0622.PNG None 0 223 | Train/0673.PNG None 0 224 | Train/0620.PNG None 0 225 | Train/0992.PNG None 0 226 | Train/1081.PNG None 0 227 | Train/0987.PNG None 0 228 | Train/0669.PNG None 0 229 | Train/0913.PNG None 0 230 | Train/0609.PNG None 0 231 | Train/0680.PNG None 0 232 | Train/1072.PNG None 0 233 | Train/0688.PNG None 0 234 | Train/0773.PNG None 0 235 | Train/0887.PNG None 0 236 | Train/0790.PNG None 0 237 | Train/0742.PNG None 0 238 | Train/1005.PNG None 0 239 | Train/0753.PNG None 0 240 | Train/0590.PNG None 0 241 | Train/0598.PNG None 0 242 | Train/1079.PNG None 0 243 | Train/0699.PNG None 0 244 | Train/0785.PNG None 0 245 | Train/0922.PNG None 0 246 | Train/0997.PNG None 0 247 | Train/0748.PNG None 0 248 | Train/0798.PNG None 0 249 | Train/0840.PNG Train/Label/0840_label.PNG 1 250 | Train/0892.PNG Train/Label/0892_label.PNG 1 251 | Train/1132.PNG Train/Label/1132_label.PNG 1 252 | Train/0781.PNG Train/Label/0781_label.PNG 1 253 | Train/0860.PNG Train/Label/0860_label.PNG 1 254 | Train/0635.PNG Train/Label/0635_label.PNG 1 255 | Train/1140.PNG Train/Label/1140_label.PNG 1 256 | Train/0690.PNG Train/Label/0690_label.PNG 1 257 | Train/0672.PNG Train/Label/0672_label.PNG 1 258 | Train/0973.PNG Train/Label/0973_label.PNG 1 259 | Train/1129.PNG Train/Label/1129_label.PNG 1 260 | Train/1036.PNG Train/Label/1036_label.PNG 1 261 | Train/0784.PNG Train/Label/0784_label.PNG 1 262 | Train/0634.PNG Train/Label/0634_label.PNG 1 263 | Train/0914.PNG Train/Label/0914_label.PNG 1 264 | Train/0902.PNG Train/Label/0902_label.PNG 1 265 | Train/0755.PNG Train/Label/0755_label.PNG 1 266 | Train/1143.PNG Train/Label/1143_label.PNG 1 267 | Train/1114.PNG Train/Label/1114_label.PNG 1 268 | Train/0626.PNG Train/Label/0626_label.PNG 1 269 | Train/0822.PNG Train/Label/0822_label.PNG 1 270 | Train/0761.PNG Train/Label/0761_label.PNG 1 271 | Train/1087.PNG Train/Label/1087_label.PNG 1 272 | Train/1130.PNG Train/Label/1130_label.PNG 1 273 | Train/0760.PNG Train/Label/0760_label.PNG 1 274 | Train/1141.PNG Train/Label/1141_label.PNG 1 275 | Train/0837.PNG Train/Label/0837_label.PNG 1 276 | Train/1133.PNG Train/Label/1133_label.PNG 1 277 | Train/0741.PNG Train/Label/0741_label.PNG 1 278 | Train/0740.PNG Train/Label/0740_label.PNG 1 279 | Train/1131.PNG Train/Label/1131_label.PNG 1 280 | Train/0873.PNG Train/Label/0873_label.PNG 1 281 | Train/1018.PNG Train/Label/1018_label.PNG 1 282 | Train/0729.PNG Train/Label/0729_label.PNG 1 283 | Train/1142.PNG Train/Label/1142_label.PNG 1 284 | Train/0846.PNG Train/Label/0846_label.PNG 1 285 | Train/1148.PNG Train/Label/1148_label.PNG 1 286 | Train/0658.PNG Train/Label/0658_label.PNG 1 287 | Train/1034.PNG Train/Label/1034_label.PNG 1 288 | Train/0897.PNG Train/Label/0897_label.PNG 1 -------------------------------------------------------------------------------- /WSLDatasets/configs/class4_train5valid5/validation.txt: -------------------------------------------------------------------------------- 1 | Train/0684.PNG Train/Label/0684_label.PNG 1 2 | Train/1077.PNG Train/Label/1077_label.PNG 1 3 | Train/0836.PNG Train/Label/0836_label.PNG 1 4 | Train/0975.PNG Train/Label/0975_label.PNG 1 5 | Train/0945.PNG Train/Label/0945_label.PNG 1 6 | Train/0685.PNG Train/Label/0685_label.PNG 1 7 | Train/0796.PNG Train/Label/0796_label.PNG 1 8 | Train/1135.PNG Train/Label/1135_label.PNG 1 9 | Train/0741.PNG Train/Label/0741_label.PNG 1 10 | Train/0587.PNG Train/Label/0587_label.PNG 1 11 | Train/0826.PNG Train/Label/0826_label.PNG 1 12 | Train/1105.PNG Train/Label/1105_label.PNG 1 13 | Train/1058.PNG Train/Label/1058_label.PNG 1 14 | Train/1020.PNG Train/Label/1020_label.PNG 1 15 | Train/0597.PNG Train/Label/0597_label.PNG 1 16 | Train/1108.PNG Train/Label/1108_label.PNG 1 17 | Train/1144.PNG Train/Label/1144_label.PNG 1 18 | Train/0961.PNG Train/Label/0961_label.PNG 1 19 | Train/1147.PNG Train/Label/1147_label.PNG 1 20 | Train/1120.PNG Train/Label/1120_label.PNG 1 21 | Train/1041.PNG Train/Label/1041_label.PNG 1 22 | Train/1101.PNG Train/Label/1101_label.PNG 1 23 | Train/0704.PNG Train/Label/0704_label.PNG 1 24 | Train/1146.PNG Train/Label/1146_label.PNG 1 25 | Train/1145.PNG Train/Label/1145_label.PNG 1 26 | Train/0635.PNG Train/Label/0635_label.PNG 1 27 | Train/1143.PNG Train/Label/1143_label.PNG 1 28 | Train/0736.PNG Train/Label/0736_label.PNG 1 29 | Train/0689.PNG Train/Label/0689_label.PNG 1 30 | Train/1139.PNG Train/Label/1139_label.PNG 1 31 | Train/0628.PNG Train/Label/0628_label.PNG 1 32 | Train/0971.PNG Train/Label/0971_label.PNG 1 33 | Train/0748.PNG Train/Label/0748_label.PNG 1 34 | Train/0774.PNG Train/Label/0774_label.PNG 1 35 | Train/1100.PNG Train/Label/1100_label.PNG 1 36 | Train/0639.PNG Train/Label/0639_label.PNG 1 37 | Train/1138.PNG Train/Label/1138_label.PNG 1 38 | Train/1049.PNG Train/Label/1049_label.PNG 1 39 | Train/1097.PNG Train/Label/1097_label.PNG 1 40 | Train/0912.PNG Train/Label/0912_label.PNG 1 41 | Train/0934.PNG Train/Label/0934_label.PNG 1 42 | Train/0792.PNG None 0 43 | Train/1031.PNG None 0 44 | Train/0858.PNG None 0 45 | Train/0871.PNG None 0 46 | Train/0984.PNG None 0 47 | Train/0856.PNG None 0 48 | Train/1114.PNG None 0 49 | Train/0758.PNG None 0 50 | Train/0993.PNG None 0 51 | Train/0829.PNG None 0 52 | Train/0828.PNG None 0 53 | Train/0866.PNG None 0 54 | Train/1078.PNG None 0 55 | Train/1076.PNG None 0 56 | Train/0727.PNG None 0 57 | Train/1072.PNG None 0 58 | Train/1025.PNG None 0 59 | Train/1012.PNG None 0 60 | Train/0761.PNG None 0 61 | Train/0653.PNG None 0 62 | Train/0911.PNG None 0 63 | Train/0834.PNG None 0 64 | Train/1056.PNG None 0 65 | Train/0620.PNG None 0 66 | Train/1123.PNG None 0 67 | Train/0622.PNG None 0 68 | Train/0929.PNG None 0 69 | Train/0603.PNG None 0 70 | Train/0662.PNG None 0 71 | Train/0632.PNG None 0 72 | Train/0898.PNG None 0 73 | Train/0771.PNG None 0 74 | Train/0948.PNG None 0 75 | Train/1066.PNG None 0 76 | Train/0815.PNG None 0 77 | Train/0825.PNG None 0 78 | Train/0648.PNG None 0 79 | Train/0617.PNG None 0 80 | Train/0645.PNG None 0 81 | Train/0671.PNG None 0 82 | Train/0809.PNG None 0 83 | Train/0676.PNG None 0 84 | Train/0946.PNG None 0 85 | Train/0750.PNG None 0 86 | Train/0884.PNG None 0 87 | Train/0878.PNG None 0 88 | Train/0873.PNG None 0 89 | Train/0672.PNG None 0 90 | Train/0724.PNG None 0 91 | Train/1106.PNG None 0 92 | Train/0810.PNG None 0 93 | Train/0923.PNG None 0 94 | Train/0875.PNG None 0 95 | Train/0733.PNG None 0 96 | Train/1103.PNG None 0 97 | Train/0601.PNG None 0 98 | Train/0881.PNG None 0 99 | Train/0977.PNG None 0 100 | Train/1112.PNG None 0 101 | Train/1037.PNG None 0 102 | Train/0956.PNG None 0 103 | Train/0668.PNG None 0 104 | Train/0751.PNG None 0 105 | Train/0782.PNG None 0 106 | Train/0734.PNG None 0 107 | Train/0842.PNG None 0 108 | Train/0608.PNG None 0 109 | Train/0698.PNG None 0 110 | Train/1003.PNG None 0 111 | Train/1092.PNG None 0 112 | Train/1035.PNG None 0 113 | Train/0952.PNG None 0 114 | Train/1015.PNG None 0 115 | Train/0699.PNG None 0 116 | Train/0818.PNG None 0 117 | Train/1116.PNG None 0 118 | Train/0778.PNG None 0 119 | Train/0711.PNG None 0 120 | Train/0656.PNG None 0 121 | Train/1080.PNG None 0 122 | Train/0730.PNG None 0 123 | Train/0588.PNG None 0 124 | Train/0805.PNG None 0 125 | Train/0888.PNG None 0 126 | Train/0957.PNG None 0 127 | Train/1082.PNG None 0 128 | Train/0833.PNG None 0 129 | Train/0747.PNG None 0 130 | Train/0902.PNG None 0 131 | Train/0660.PNG None 0 132 | Train/0930.PNG None 0 133 | Train/1044.PNG None 0 134 | Train/0599.PNG None 0 135 | Train/0702.PNG None 0 136 | Train/0797.PNG None 0 137 | Train/0996.PNG None 0 138 | Train/0821.PNG None 0 139 | Train/0802.PNG None 0 140 | Train/0830.PNG None 0 141 | Train/0723.PNG None 0 142 | Train/0725.PNG None 0 143 | Train/0780.PNG None 0 144 | Train/1102.PNG None 0 145 | Train/1019.PNG None 0 146 | Train/0807.PNG None 0 147 | Train/0619.PNG None 0 148 | Train/0994.PNG None 0 149 | Train/0922.PNG None 0 150 | Train/0637.PNG None 0 151 | Train/0701.PNG None 0 152 | Train/0703.PNG None 0 153 | Train/0627.PNG None 0 154 | Train/0583.PNG None 0 155 | Train/0578.PNG None 0 156 | Train/0715.PNG None 0 157 | Train/1038.PNG None 0 158 | Train/1000.PNG None 0 159 | Train/0963.PNG None 0 160 | Train/0799.PNG None 0 161 | Train/0642.PNG None 0 162 | Train/0759.PNG None 0 163 | Train/0989.PNG None 0 164 | Train/0882.PNG None 0 165 | Train/1045.PNG None 0 166 | Train/0709.PNG None 0 167 | Train/0816.PNG None 0 168 | Train/1131.PNG None 0 169 | Train/0770.PNG None 0 170 | Train/0686.PNG None 0 171 | Train/0861.PNG None 0 172 | Train/0793.PNG None 0 173 | Train/0737.PNG None 0 174 | Train/0863.PNG None 0 175 | Train/0819.PNG None 0 176 | Train/0824.PNG None 0 177 | Train/0616.PNG None 0 178 | Train/1052.PNG None 0 179 | Train/0855.PNG None 0 180 | Train/0710.PNG None 0 181 | Train/0966.PNG None 0 182 | Train/0714.PNG None 0 183 | Train/0909.PNG None 0 184 | Train/0641.PNG None 0 185 | Train/1107.PNG None 0 186 | Train/0955.PNG None 0 187 | Train/1054.PNG None 0 188 | Train/0595.PNG None 0 189 | Train/0853.PNG None 0 190 | Train/1029.PNG None 0 191 | Train/0914.PNG None 0 192 | Train/0661.PNG None 0 193 | Train/1001.PNG None 0 194 | Train/0735.PNG None 0 195 | Train/0854.PNG None 0 196 | Train/0784.PNG None 0 197 | Train/0868.PNG None 0 198 | Train/0757.PNG None 0 199 | Train/0786.PNG None 0 200 | Train/0883.PNG None 0 201 | Train/0651.PNG None 0 202 | Train/0673.PNG None 0 203 | Train/1067.PNG None 0 204 | Train/0669.PNG None 0 205 | Train/0752.PNG None 0 206 | Train/0921.PNG None 0 207 | Train/0933.PNG None 0 208 | Train/1011.PNG None 0 209 | Train/0728.PNG None 0 210 | Train/1115.PNG None 0 211 | Train/0596.PNG None 0 212 | Train/0745.PNG None 0 213 | Train/0894.PNG None 0 214 | Train/1030.PNG None 0 215 | Train/0625.PNG None 0 216 | Train/0667.PNG None 0 217 | Train/1121.PNG None 0 218 | Train/0990.PNG None 0 219 | Train/0915.PNG None 0 220 | Train/0928.PNG None 0 221 | Train/0960.PNG None 0 222 | Train/0972.PNG None 0 223 | Train/0880.PNG None 0 224 | Train/0679.PNG None 0 225 | Train/0720.PNG None 0 226 | Train/1118.PNG None 0 227 | Train/1086.PNG None 0 228 | Train/1111.PNG None 0 229 | Train/0794.PNG None 0 230 | Train/0982.PNG None 0 231 | Train/1050.PNG None 0 232 | Train/0650.PNG None 0 233 | Train/0772.PNG None 0 234 | Train/0584.PNG None 0 235 | Train/0939.PNG None 0 236 | Train/0986.PNG None 0 237 | Train/0823.PNG None 0 238 | Train/0775.PNG None 0 239 | Train/0936.PNG None 0 240 | Train/1023.PNG None 0 241 | Train/1024.PNG None 0 242 | Train/0609.PNG None 0 243 | Train/0754.PNG None 0 244 | Train/0636.PNG None 0 245 | Train/0841.PNG None 0 246 | Train/0920.PNG None 0 247 | Train/0721.PNG None 0 248 | Train/0582.PNG None 0 249 | Train/1026.PNG None 0 250 | Train/0852.PNG None 0 251 | Train/1063.PNG None 0 252 | Train/0692.PNG None 0 253 | Train/0766.PNG None 0 254 | Train/0579.PNG None 0 255 | Train/0600.PNG None 0 256 | Train/0700.PNG None 0 257 | Train/1071.PNG None 0 258 | Train/0773.PNG None 0 259 | Train/0717.PNG None 0 260 | Train/0722.PNG None 0 261 | Train/0891.PNG None 0 262 | Train/0670.PNG None 0 263 | Train/0712.PNG None 0 264 | Train/0999.PNG None 0 265 | Train/0731.PNG None 0 266 | Train/0822.PNG None 0 267 | Train/1085.PNG None 0 268 | Train/0978.PNG None 0 269 | Train/1088.PNG None 0 270 | Train/1128.PNG None 0 271 | Train/0941.PNG None 0 272 | Train/0760.PNG None 0 273 | Train/0705.PNG None 0 274 | Train/0630.PNG None 0 275 | Train/0769.PNG None 0 276 | Train/1093.PNG None 0 277 | Train/0716.PNG None 0 278 | Train/0983.PNG None 0 279 | Train/0893.PNG None 0 280 | Train/0862.PNG None 0 281 | Train/1034.PNG None 0 282 | Train/1081.PNG None 0 283 | Train/1002.PNG None 0 284 | Train/0682.PNG None 0 285 | Train/0765.PNG None 0 286 | Train/0749.PNG None 0 287 | Train/0585.PNG None 0 288 | Train/0767.PNG None 0 -------------------------------------------------------------------------------- /WSLDatasets/configs/class6_train5valid5/validation.txt: -------------------------------------------------------------------------------- 1 | Train/0602.PNG None 0 2 | Train/0851.PNG None 0 3 | Train/0996.PNG None 0 4 | Train/1058.PNG None 0 5 | Train/1138.PNG None 0 6 | Train/1011.PNG None 0 7 | Train/0699.PNG None 0 8 | Train/0900.PNG None 0 9 | Train/1025.PNG None 0 10 | Train/1047.PNG None 0 11 | Train/0626.PNG None 0 12 | Train/0847.PNG None 0 13 | Train/0590.PNG None 0 14 | Train/0874.PNG None 0 15 | Train/1081.PNG None 0 16 | Train/0804.PNG None 0 17 | Train/0709.PNG None 0 18 | Train/0724.PNG None 0 19 | Train/0610.PNG None 0 20 | Train/0651.PNG None 0 21 | Train/0831.PNG None 0 22 | Train/0704.PNG None 0 23 | Train/0672.PNG None 0 24 | Train/0753.PNG None 0 25 | Train/0650.PNG None 0 26 | Train/0781.PNG None 0 27 | Train/0689.PNG None 0 28 | Train/0968.PNG None 0 29 | Train/0944.PNG None 0 30 | Train/0779.PNG None 0 31 | Train/1141.PNG None 0 32 | Train/0922.PNG None 0 33 | Train/1046.PNG None 0 34 | Train/0706.PNG None 0 35 | Train/0703.PNG None 0 36 | Train/0637.PNG None 0 37 | Train/0600.PNG None 0 38 | Train/0757.PNG None 0 39 | Train/1085.PNG None 0 40 | Train/0661.PNG None 0 41 | Train/0803.PNG None 0 42 | Train/0624.PNG None 0 43 | Train/1074.PNG None 0 44 | Train/1110.PNG None 0 45 | Train/1064.PNG None 0 46 | Train/1102.PNG None 0 47 | Train/0627.PNG None 0 48 | Train/0722.PNG None 0 49 | Train/0786.PNG None 0 50 | Train/1055.PNG None 0 51 | Train/0788.PNG None 0 52 | Train/0578.PNG None 0 53 | Train/0762.PNG None 0 54 | Train/0730.PNG None 0 55 | Train/0605.PNG None 0 56 | Train/0952.PNG None 0 57 | Train/0876.PNG None 0 58 | Train/0744.PNG None 0 59 | Train/0842.PNG None 0 60 | Train/0859.PNG None 0 61 | Train/1131.PNG None 0 62 | Train/1039.PNG None 0 63 | Train/0966.PNG None 0 64 | Train/0894.PNG None 0 65 | Train/0959.PNG None 0 66 | Train/0745.PNG None 0 67 | Train/0883.PNG None 0 68 | Train/1050.PNG None 0 69 | Train/1012.PNG None 0 70 | Train/0608.PNG None 0 71 | Train/0609.PNG None 0 72 | Train/0658.PNG None 0 73 | Train/1109.PNG None 0 74 | Train/0789.PNG None 0 75 | Train/1094.PNG None 0 76 | Train/0675.PNG None 0 77 | Train/1149.PNG None 0 78 | Train/0733.PNG None 0 79 | Train/0765.PNG None 0 80 | Train/0929.PNG None 0 81 | Train/1068.PNG None 0 82 | Train/0751.PNG None 0 83 | Train/0871.PNG None 0 84 | Train/1123.PNG None 0 85 | Train/0947.PNG None 0 86 | Train/1118.PNG None 0 87 | Train/1142.PNG None 0 88 | Train/0593.PNG None 0 89 | Train/0877.PNG None 0 90 | Train/0841.PNG None 0 91 | Train/0969.PNG None 0 92 | Train/1093.PNG None 0 93 | Train/0824.PNG None 0 94 | Train/0678.PNG None 0 95 | Train/0865.PNG None 0 96 | Train/1075.PNG None 0 97 | Train/1111.PNG None 0 98 | Train/0712.PNG None 0 99 | Train/1105.PNG None 0 100 | Train/1121.PNG None 0 101 | Train/0774.PNG None 0 102 | Train/0693.PNG None 0 103 | Train/0880.PNG None 0 104 | Train/0735.PNG None 0 105 | Train/0958.PNG None 0 106 | Train/1013.PNG None 0 107 | Train/0787.PNG None 0 108 | Train/0888.PNG None 0 109 | Train/1009.PNG None 0 110 | Train/0879.PNG None 0 111 | Train/0955.PNG None 0 112 | Train/0790.PNG None 0 113 | Train/1147.PNG None 0 114 | Train/0681.PNG None 0 115 | Train/0869.PNG None 0 116 | Train/1023.PNG None 0 117 | Train/0777.PNG None 0 118 | Train/1063.PNG None 0 119 | Train/0956.PNG None 0 120 | Train/0692.PNG None 0 121 | Train/1069.PNG None 0 122 | Train/0631.PNG None 0 123 | Train/0583.PNG None 0 124 | Train/0677.PNG None 0 125 | Train/0707.PNG None 0 126 | Train/0905.PNG None 0 127 | Train/0700.PNG None 0 128 | Train/0731.PNG None 0 129 | Train/0764.PNG None 0 130 | Train/0948.PNG None 0 131 | Train/0977.PNG None 0 132 | Train/1045.PNG None 0 133 | Train/0964.PNG None 0 134 | Train/0597.PNG None 0 135 | Train/0870.PNG None 0 136 | Train/0665.PNG None 0 137 | Train/0585.PNG None 0 138 | Train/0945.PNG None 0 139 | Train/0810.PNG None 0 140 | Train/0793.PNG None 0 141 | Train/0808.PNG None 0 142 | Train/0588.PNG None 0 143 | Train/0686.PNG None 0 144 | Train/0873.PNG None 0 145 | Train/0819.PNG None 0 146 | Train/0584.PNG None 0 147 | Train/1127.PNG None 0 148 | Train/0987.PNG None 0 149 | Train/0942.PNG None 0 150 | Train/1037.PNG None 0 151 | Train/1022.PNG None 0 152 | Train/0976.PNG None 0 153 | Train/0884.PNG None 0 154 | Train/0736.PNG None 0 155 | Train/0685.PNG None 0 156 | Train/0782.PNG None 0 157 | Train/0813.PNG None 0 158 | Train/1145.PNG None 0 159 | Train/0737.PNG None 0 160 | Train/0937.PNG None 0 161 | Train/1133.PNG None 0 162 | Train/0783.PNG None 0 163 | Train/0965.PNG None 0 164 | Train/0822.PNG None 0 165 | Train/0695.PNG None 0 166 | Train/0660.PNG None 0 167 | Train/0934.PNG None 0 168 | Train/0816.PNG None 0 169 | Train/0997.PNG None 0 170 | Train/0815.PNG None 0 171 | Train/0738.PNG None 0 172 | Train/1117.PNG None 0 173 | Train/1014.PNG None 0 174 | Train/1042.PNG None 0 175 | Train/0743.PNG None 0 176 | Train/0970.PNG None 0 177 | Train/0664.PNG None 0 178 | Train/0628.PNG None 0 179 | Train/0680.PNG None 0 180 | Train/0984.PNG None 0 181 | Train/0604.PNG None 0 182 | Train/0849.PNG None 0 183 | Train/0645.PNG None 0 184 | Train/0732.PNG None 0 185 | Train/1004.PNG None 0 186 | Train/0655.PNG None 0 187 | Train/0683.PNG None 0 188 | Train/0690.PNG None 0 189 | Train/0853.PNG None 0 190 | Train/0924.PNG None 0 191 | Train/0960.PNG None 0 192 | Train/1113.PNG None 0 193 | Train/0746.PNG None 0 194 | Train/0882.PNG None 0 195 | Train/0930.PNG None 0 196 | Train/0836.PNG None 0 197 | Train/0632.PNG None 0 198 | Train/0860.PNG None 0 199 | Train/0713.PNG None 0 200 | Train/0657.PNG None 0 201 | Train/0827.PNG None 0 202 | Train/1033.PNG None 0 203 | Train/1018.PNG None 0 204 | Train/0802.PNG None 0 205 | Train/0982.PNG None 0 206 | Train/0867.PNG None 0 207 | Train/0684.PNG None 0 208 | Train/0647.PNG None 0 209 | Train/0852.PNG None 0 210 | Train/0986.PNG None 0 211 | Train/0611.PNG None 0 212 | Train/0656.PNG None 0 213 | Train/0670.PNG None 0 214 | Train/0629.PNG None 0 215 | Train/1134.PNG None 0 216 | Train/0748.PNG None 0 217 | Train/1092.PNG None 0 218 | Train/1006.PNG None 0 219 | Train/0586.PNG None 0 220 | Train/0740.PNG None 0 221 | Train/1106.PNG None 0 222 | Train/0912.PNG None 0 223 | Train/0625.PNG None 0 224 | Train/1148.PNG None 0 225 | Train/0902.PNG None 0 226 | Train/0794.PNG None 0 227 | Train/0875.PNG None 0 228 | Train/0796.PNG None 0 229 | Train/0715.PNG None 0 230 | Train/0636.PNG None 0 231 | Train/0823.PNG None 0 232 | Train/1017.PNG None 0 233 | Train/0897.PNG None 0 234 | Train/0866.PNG None 0 235 | Train/1146.PNG None 0 236 | Train/0673.PNG None 0 237 | Train/0594.PNG None 0 238 | Train/0971.PNG None 0 239 | Train/1026.PNG None 0 240 | Train/0973.PNG None 0 241 | Train/0904.PNG None 0 242 | Train/0992.PNG None 0 243 | Train/0993.PNG None 0 244 | Train/0784.PNG None 0 245 | Train/0705.PNG None 0 246 | Train/0697.PNG None 0 247 | Train/0770.PNG Train/Label/0770_label.PNG 1 248 | Train/0607.PNG Train/Label/0607_label.PNG 1 249 | Train/0599.PNG Train/Label/0599_label.PNG 1 250 | Train/0856.PNG Train/Label/0856_label.PNG 1 251 | Train/0766.PNG Train/Label/0766_label.PNG 1 252 | Train/0848.PNG Train/Label/0848_label.PNG 1 253 | Train/0640.PNG Train/Label/0640_label.PNG 1 254 | Train/1056.PNG Train/Label/1056_label.PNG 1 255 | Train/1062.PNG Train/Label/1062_label.PNG 1 256 | Train/0990.PNG Train/Label/0990_label.PNG 1 257 | Train/1124.PNG Train/Label/1124_label.PNG 1 258 | Train/0933.PNG Train/Label/0933_label.PNG 1 259 | Train/0908.PNG Train/Label/0908_label.PNG 1 260 | Train/0752.PNG Train/Label/0752_label.PNG 1 261 | Train/0649.PNG Train/Label/0649_label.PNG 1 262 | Train/0979.PNG Train/Label/0979_label.PNG 1 263 | Train/0728.PNG Train/Label/0728_label.PNG 1 264 | Train/0838.PNG Train/Label/0838_label.PNG 1 265 | Train/0581.PNG Train/Label/0581_label.PNG 1 266 | Train/0974.PNG Train/Label/0974_label.PNG 1 267 | Train/0579.PNG Train/Label/0579_label.PNG 1 268 | Train/0891.PNG Train/Label/0891_label.PNG 1 269 | Train/0749.PNG Train/Label/0749_label.PNG 1 270 | Train/0806.PNG Train/Label/0806_label.PNG 1 271 | Train/1112.PNG Train/Label/1112_label.PNG 1 272 | Train/0981.PNG Train/Label/0981_label.PNG 1 273 | Train/0889.PNG Train/Label/0889_label.PNG 1 274 | Train/0750.PNG Train/Label/0750_label.PNG 1 275 | Train/0896.PNG Train/Label/0896_label.PNG 1 276 | Train/1000.PNG Train/Label/1000_label.PNG 1 277 | Train/1038.PNG Train/Label/1038_label.PNG 1 278 | Train/0698.PNG Train/Label/0698_label.PNG 1 279 | Train/0618.PNG Train/Label/0618_label.PNG 1 280 | Train/0985.PNG Train/Label/0985_label.PNG 1 281 | Train/0846.PNG Train/Label/0846_label.PNG 1 282 | Train/0907.PNG Train/Label/0907_label.PNG 1 283 | Train/0616.PNG Train/Label/0616_label.PNG 1 284 | Train/0938.PNG Train/Label/0938_label.PNG 1 285 | Train/0648.PNG Train/Label/0648_label.PNG 1 286 | Train/0755.PNG Train/Label/0755_label.PNG 1 287 | Train/0925.PNG Train/Label/0925_label.PNG 1 288 | Train/0696.PNG Train/Label/0696_label.PNG 1 -------------------------------------------------------------------------------- /utils/cv_utiles.py: -------------------------------------------------------------------------------- 1 | #-*-coding:utf-8-*- 2 | import os 3 | import cv2 4 | import numpy as np 5 | 6 | def sub_bp_MOG2(img): 7 | """ 8 | :param img: 输入图像 9 | :return:返回 10 | """ 11 | mog = cv2.createBackgroundSubtractorMOG2() 12 | return mog.apply(img,None,0.01) 13 | 14 | def cv_equalizeHist(img): 15 | """ 16 | :param img: 输入图像 17 | :return: 直方图均衡后的图像 18 | """ 19 | return cv2.equalizeHist(img) 20 | 21 | 22 | def high_pass_fft(img,filter_size=None,power_thred=None): 23 | assert filter_size!=None or power_thred!=None 24 | if(filter_size !=None and power_thred !=None): 25 | raise Exception("filter_size and power_thred are incompatible!") 26 | img_float32 = np.float32(img) 27 | dft = cv2.dft(img_float32, flags=cv2.DFT_COMPLEX_OUTPUT) 28 | # 将低频信息转换至图像中心 29 | dft_shift = np.fft.fftshift(dft) 30 | if power_thred !=None: 31 | # # 获取图像尺寸 与 中心坐标 32 | features = cv2.magnitude(dft_shift[:, :, 0], dft_shift[:, :, 1])/np.sqrt(img.shape[0]*img.shape[1]) 33 | mask = np.where(features > power_thred, 1, 0)[:, :, np.newaxis] 34 | if filter_size!=None: 35 | crow, ccol = int(img.shape[0] / 2), int(img.shape[1] / 2) # 求得图像的中心点位置 36 | mask = np.zeros((img.shape[0], img.shape[1], 2), np.uint8) 37 | mask[crow-filter_size:crow+filter_size, ccol-filter_size:ccol+filter_size] = 1 38 | # 掩码与傅里叶图像按位相乘 去除低频区域 39 | fshift = dft_shift * mask# 40 | # 之前把低频转换到了图像中间,现在需要重新转换回去 41 | f_ishift = np.fft.ifftshift(fshift) 42 | # 傅里叶逆变换 43 | img_back = cv2.idft(f_ishift) 44 | img_back = cv2.magnitude(img_back[:, :, 0], img_back[:, :, 1]) 45 | img_back=(img_back-np.min(img_back))/(np.max(img_back)-np.min(img_back))*255 46 | return mask[:, :, 0],img_back 47 | def split(image,num,axis=1,offset=0): 48 | assert axis==0 or axis==1 49 | h, w = image.shape 50 | if axis==0: 51 | h=(h+offset)//num-offset 52 | return [image[i*(h+offset):i*(h+offset)+h,:] for i in range(num)] 53 | if axis==1: 54 | w =(w+offset)//num-offset 55 | 56 | return [image[:,i * (w+offset):i * (w+offset)+w] for i in range(num)] 57 | 58 | def mask2rect(mask): 59 | cols,rows=np.where(mask>0) 60 | col1=np.amin(cols) 61 | col2 = np.amax(cols) 62 | row1=np.amin(rows) 63 | row2 = np.amax(rows) 64 | return (row1,col1,row2-row1,col2-col1) 65 | 66 | 67 | 68 | def cv_dilate(mask, ksize=5, struct="ellipse"): 69 | assert struct in ["rect", "ellipse"] 70 | if struct == "rect": struct = cv2.MORPH_RECT 71 | if struct == "ellipse": struct = cv2.MORPH_ELLIPSE 72 | elment = cv2.getStructuringElement(struct, (ksize, ksize)) 73 | mask_dilate = cv2.morphologyEx(mask, cv2.MORPH_DILATE, elment) 74 | return mask_dilate 75 | 76 | def show_cams_on_images(img_batch, mask_batch,filenames,save_dirs): 77 | if img_batch.ndim!=4:img_batch=img_batch.unsqueeze(0) 78 | if img_batch.shape[-1] != 3: raise Exception("image[{}] must be RGB!".format(img_batch.shape)) 79 | mask_batch=mask_batch.squeeze(1) 80 | batch= len(img_batch) if isinstance(img_batch, list) else img_batch.shape[0] 81 | img_height,img_width=img_batch.shape[1:3] 82 | save_dirs=[save_dirs]*batch if not isinstance(save_dirs, list) else save_dirs 83 | for i, filename in enumerate(filenames): 84 | save_dir=save_dirs[i] 85 | if not os.path.exists(save_dir): 86 | os.makedirs(save_dir) 87 | #filename = str(filename).split("'")[-2].replace("/","_") 88 | filename=filename.decode("utf-8") if not isinstance(filename, str) else filename 89 | heatmap = cv2.applyColorMap(np.uint8(255 * mask_batch[i]), cv2.COLORMAP_JET) 90 | heatmap=cv2.resize(heatmap,(img_width,img_height)) 91 | heatmap = np.float32(heatmap) / 255 92 | #img_show=cv2.cvtColor( np.uint8(255 * img_batch[i]), cv2.COLOR_GRAY2BGR) 93 | img_show =np.uint8(255 * img_batch[i]) 94 | cam = heatmap + np.float32(img_show)/255 95 | #cam=np.float32(img_show) / 255 96 | cam = cam / np.max(cam) 97 | cam=np.uint8(255 * cam) 98 | visualization_path = os.path.join(save_dir,filename) 99 | print("write to {}".format(visualization_path)) 100 | cv2.imwrite(visualization_path, cam) 101 | 102 | def grub_cut_on_mask(img,mask,thred=0.2,n_iter=1): 103 | mask=np.where(mask>thred,3,0).astype(np.uint8) 104 | # if mask.sum()==0: 105 | # return mask 106 | rect=(0,0,0,0) 107 | # rect=mask2rect(mask) 108 | #assert mask.sum()>0 109 | mode = cv2.GC_INIT_WITH_MASK 110 | # mode=cv2.GC_INIT_WITH_RECT 111 | bgdModel = np.zeros((1, 65), np.float64) 112 | fgdModel = np.zeros((1, 65), np.float64) 113 | try: 114 | cv2.grabCut(img, mask, rect, bgdModel, fgdModel, n_iter, mode=mode) 115 | except Exception as error: 116 | pass 117 | #print("错误:{}".format(error)) 118 | mask = np.where((mask == 2) | (mask == 0), 0, 1).astype("uint8") 119 | return mask 120 | 121 | 122 | def cv_resize(img,size=None,fxy=None): 123 | """ 124 | :param img: 输入图像 125 | :param size: 目标尺寸 (x,y) 126 | :param fxy: 放缩比例 (fx,fy) 和 size参数互斥 127 | :return: resize 之后的图片 128 | """ 129 | if size==None and fxy!=None: 130 | assert isinstance(fxy, tuple) and len(fxy) == 2 131 | return cv2.resize(img, (0, 0), fx=fxy[0], fy=fxy[1]) 132 | elif fxy==None and size!=None: 133 | assert isinstance(size, tuple) and len(fxy) == 2 134 | return cv2.resize(img, size) 135 | else: 136 | raise Exception("custom error!") 137 | 138 | def cv_imread(file_path, flag=-1): 139 | """ 140 | 解决cv包含中文路径的问题 141 | :param file_path: 路径 142 | :param flag: 143 | :return: 144 | """ 145 | cv_img = cv2.imdecode(np.fromfile(file_path, dtype=np.uint8), flag) 146 | return cv_img 147 | 148 | def show_cams_on_images(img_batch, mask_batch,filenames,save_dirs): 149 | if img_batch.ndim!=4:img_batch=img_batch.unsqueeze(0) 150 | if img_batch.shape[-1] != 3: raise Exception("image[{}] must be RGB!".format(img_batch.shape)) 151 | mask_batch=mask_batch.squeeze(1) 152 | batch= len(img_batch) if isinstance(img_batch, list) else img_batch.shape[0] 153 | img_height,img_width=img_batch.shape[1:3] 154 | save_dirs=[save_dirs]*batch if not isinstance(save_dirs, list) else save_dirs 155 | for i, filename in enumerate(filenames): 156 | save_dir=save_dirs[i] 157 | if not os.path.exists(save_dir): 158 | os.makedirs(save_dir) 159 | #filename = str(filename).split("'")[-2].replace("/","_") 160 | filename=filename.decode("utf-8") if not isinstance(filename, str) else filename 161 | heatmap = cv2.applyColorMap(np.uint8(255 * mask_batch[i]), cv2.COLORMAP_JET) 162 | heatmap=cv2.resize(heatmap,(img_width,img_height)) 163 | heatmap = np.float32(heatmap) / 255 164 | #img_show=cv2.cvtColor( np.uint8(255 * img_batch[i]), cv2.COLOR_GRAY2BGR) 165 | img_show =np.uint8(255 * img_batch[i]) 166 | cam = heatmap + np.float32(img_show)/255 167 | #cam=np.float32(img_show) / 255 168 | cam = cam / np.max(cam) 169 | cam=np.uint8(255 * cam) 170 | visualization_path = os.path.join(save_dir,filename) 171 | print("write to {}".format(visualization_path)) 172 | cv2.imwrite(visualization_path, cam) 173 | 174 | 175 | def Ostu(array): 176 | array = np.array(array * 255, dtype=np.uint8) 177 | best_threshold, binary_output = cv2.threshold(array, 100, 1, cv2.THRESH_BINARY) # cv2.THRESH_OTSU 178 | area = np.sum(np.array(binary_output)) 179 | predict =(area > 1) 180 | return predict 181 | def cv_open(mask,ksize=5,struct="ellipse"): 182 | assert struct in ["rect","ellipse"] 183 | if struct=="rect": struct=cv2.MORPH_RECT 184 | if struct=="ellipse": struct=cv2.MORPH_ELLIPSE 185 | elment=cv2.getStructuringElement(struct, (ksize, ksize)) 186 | mask_open = cv2.morphologyEx(mask, cv2.MORPH_OPEN, elment) 187 | return mask_open 188 | 189 | def cv_close(mask,ksize=5,struct="ellipse"): 190 | assert struct in ["rect","ellipse"] 191 | if struct=="rect": struct=cv2.MORPH_RECT 192 | if struct=="ellipse": struct=cv2.MORPH_ELLIPSE 193 | elment=cv2.getStructuringElement(struct, (ksize, ksize)) 194 | mask_open = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, elment) 195 | return mask_open 196 | def cv_dyn_threshold(img,thred,ksize=15): 197 | img_blur=cv2.blur(img,ksize=(ksize,ksize)) 198 | arr_blur=np.array(img_blur,dtype=np.float) 199 | arr=np.array(img,dtype=np.float) 200 | mask=np.where(np.abs(arr-arr_blur)>thred,1,0) 201 | return mask.astype(np.uint8) 202 | 203 | 204 | def origin_LBP(img): 205 | dst = np.zeros(img.shape, dtype=img.dtype) 206 | h, w = img.shape 207 | for i in range(1, h - 1): 208 | for j in range(1, w - 1): 209 | center = img[i][j] 210 | code = 0 211 | code |= (img[i - 1][j - 1] >= center) << (np.uint8)(7) 212 | code |= (img[i - 1][j] >= center) << (np.uint8)(6) 213 | code |= (img[i - 1][j + 1] >= center) << (np.uint8)(5) 214 | code |= (img[i][j + 1] >= center) << (np.uint8)(4) 215 | code |= (img[i + 1][j + 1] >= center) << (np.uint8)(3) 216 | code |= (img[i + 1][j] >= center) << (np.uint8)(2) 217 | code |= (img[i + 1][j - 1] >= center) << (np.uint8)(1) 218 | code |= (img[i][j - 1] >= center) << (np.uint8)(0) 219 | 220 | dst[i - 1][j - 1] = code 221 | return dst 222 | 223 | 224 | def match(img1,img2,filter_sz=3): 225 | dst = np.zeros(img1.shape, dtype=img1.dtype) 226 | h, w = img1.shape 227 | for i in range(0, h ): 228 | for j in range(0, w ): 229 | val=255 230 | for k in range(filter_sz): 231 | y=i+k-(filter_sz//2) 232 | if y<0 or y>h-1: 233 | continue 234 | for l in range(filter_sz): 235 | x = j + l - (filter_sz // 2) 236 | if x < 0 or x > w - 1: 237 | continue 238 | val=min(val,abs(img1[i,j]-img2[y,x])) 239 | dst[i,j]=val 240 | dst=dst.astype(np.float) 241 | dst=(dst-np.min(dst))/(np.max(dst)-np.min(dst)) 242 | return (dst*255).astype(np.uint8) 243 | 244 | if __name__=="__main__": 245 | dir=r"C:\Datasets\KolektorSDD" 246 | -------------------------------------------------------------------------------- /WSLDatasets/wsl_dataset.py: -------------------------------------------------------------------------------- 1 | from utils.data.base_dataset import * 2 | from utils.cv_utiles import cv_imread 3 | from utils.data import my_transforms 4 | from utils.param import Param 5 | import utils 6 | from utils import plt_utils 7 | from torchvision import transforms 8 | import cv2 9 | import os 10 | import numpy as np 11 | from random import shuffle 12 | from torch.utils.data import DataLoader 13 | 14 | """ 15 | ## dataloader 将数据打包为batch 16 | 1. 自己写也是可以,锻炼下 17 | 2, 数据读取是在cpu上运行的, 训练在GPU上运行的 (木桶定理) 18 | 3. 官方提供的接口:有多线程的功能。 19 | """ 20 | 21 | Trans = { 22 | "numpy":my_transforms.ComposeJoint([ 23 | [transforms.ToTensor(), transforms.ToTensor()], #ToTensor 转化为【0-1】 24 | # [transforms.Normalize(*mean_std), None], 25 | my_transforms.Tensor2Numpy(), 26 | [my_transforms.ToFloatNumpy(), my_transforms.ToMask(0.2)] 27 | ]), 28 | "torch":my_transforms.ComposeJoint([ 29 | [transforms.ToTensor(), transforms.ToTensor()], 30 | # [transforms.Normalize(*mean_std), None], 31 | [None, my_transforms.ToMask(0.2)] 32 | ]) 33 | } 34 | 35 | 36 | class WSLDataset(BDataset): 37 | def __init__(self,root,transform_PIL=None,return_numpy=True): 38 | self.root=root 39 | 40 | 41 | #把所有图片地址导入内存中 42 | self.list_dir(self.root,use_absPath=False) 43 | 44 | 45 | 46 | #把数据打包为样本: ( 图片地址, 像素标签的地址, 类别) 47 | self.make_dataset() 48 | 49 | 50 | # if transform_PIL is None: 51 | # self.transform_PIL = my_transforms.ComposeJoint([ 52 | # my_transforms.ToPIL(), 53 | # my_transforms.GroupRandomHorizontalFlip(), 54 | # my_transforms.GroupRandomVerticalFlip(), 55 | # my_transforms.GroupResize(size=(512,512)), 56 | # ]) 57 | # else: 58 | self.transform_PIL=transform_PIL 59 | 60 | self.transform_array = Trans["numpy"] if return_numpy else Trans["torch"] 61 | 62 | def get_label(self,path): 63 | img=cv_imread(path) 64 | if img is None: 65 | raise Exception("read image wrong") 66 | mask=np.where(img>0,1,0).astype(np.uint8) 67 | target = 1 if np.sum(mask) > 10 else 0 68 | return target 69 | def make_dataset(self): 70 | print("生成数据集......") 71 | samples = [] 72 | # 添加样本列表 73 | for img in sorted(self.imgs): 74 | label_pixel=os.path.join(os.path.dirname(img),"Label",os.path.basename(img).replace(".","_label.")) 75 | #print(label_pixel) 76 | # 过滤没有语义标签的图片 77 | label= 0 if label_pixel not in self.imgs_pixel else 1 78 | label_pixel=label_pixel if label_pixel in self.imgs_pixel else None 79 | # 过滤不符合类别的图片 80 | item = (img, label_pixel, label) 81 | samples.append(item) 82 | print(" 总样本数:{}".format(len(samples))) 83 | cls_dict=self.split_dataset_by_cls(samples,2) 84 | self.samples, self.cls_dict= samples, cls_dict 85 | for sample in self.samples: 86 | print(sample) 87 | 88 | def list_dir(self,root,use_absPath=False, func=None): 89 | def is_train(path): 90 | return True if "Train" in path else False 91 | def is_test(path): 92 | return True if "Test" in path else False 93 | def is_img(path): 94 | return True if path.endswith(".PNG") and "_label.PNG" not in path else False 95 | def is_imgPixel(path): 96 | return True if path.endswith("_label.PNG") else False 97 | self.imgs=super(WSLDataset,self).list_dir(root,use_absPath,is_img) 98 | # for img in self.imgs : print( img) 99 | self.imgs_pixel=super(WSLDataset, self).list_dir(root, use_absPath, is_imgPixel) 100 | # for img in self.imgs_pixel : print( img) 101 | self.imgs_train=[img_path for img_path in self.imgs if is_train(img_path)] 102 | self.imgs_test = [img_path for img_path in self.imgs if is_test(img_path)] 103 | self.imgs_pixel_train=[img_path for img_path in self.imgs_pixel if is_train(img_path)] 104 | self.imgs_pixel_test = [img_path for img_path in self.imgs_pixel if is_test(img_path)] 105 | 106 | def gen_a_sample(self,sample): 107 | """ 108 | 过程: 1. 数据增强 3. 读图片 2. 把数据转化为tensor 109 | :param sample: 110 | :return: 111 | """ 112 | 113 | file_basename_image, file_basename_label, label =sample 114 | # 1. 读图 115 | image_path = os.path.join(self.root, file_basename_image) 116 | image=cv_imread(image_path,-1) 117 | #print(image.shape) 118 | image = np.array(image).astype(np.uint8) 119 | if file_basename_label is not None: 120 | label_path = os.path.join(self.root, file_basename_label) 121 | pixel_label = cv_imread(label_path, -1) 122 | label_pixel = np.array(pixel_label).astype(np.uint8) 123 | else: 124 | label_pixel=np.zeros_like(image).astype(np.uint8) 125 | # 2. 数据增强 126 | if self.transform_PIL is not None: 127 | image,label_pixel=self.transform_PIL([image,label_pixel]) 128 | 129 | # 3. 数据格式转化 130 | # utils.plt_utils.plt_show_imgs([image, label_pixel]) 131 | image, label_pixel = self.transform_array([image, label_pixel]) 132 | # utils.plt_utils.plt_show_imgs([image.squeeze(), label_pixel.squeeze()]) 133 | return image, label_pixel, int(label), file_basename_image 134 | 135 | 136 | class WSLDataset_train(WSLDataset): 137 | 138 | def make_dataset(self): 139 | imgs=self.imgs_train 140 | imgs_pixel=self.imgs_pixel_train 141 | #print("生成数据集......") 142 | samples = [] 143 | # 添加样本列表 144 | for img in sorted(imgs): 145 | label_pixel=os.path.join(os.path.dirname(img),"Label",os.path.basename(img).replace(".","_label.")) 146 | # 过滤没有语义标签的图片 147 | label= 0 if label_pixel not in imgs_pixel else 1 148 | label_pixel=label_pixel if label_pixel in imgs_pixel else None 149 | # 过滤不符合类别的图片 150 | item = (img, label_pixel, label) 151 | samples.append(item) 152 | print(" 总样本数:{}".format(len(samples))) 153 | cls_dict=self.split_dataset_by_cls(samples,2) 154 | self.samples, self.cls_dict= samples, cls_dict 155 | 156 | class WSLDataset_test(WSLDataset): 157 | def make_dataset(self): 158 | imgs=self.imgs_test 159 | imgs_pixel=self.imgs_pixel_test 160 | print("生成数据集......") 161 | samples = [] 162 | # 添加样本列表 163 | for img in sorted(imgs): 164 | label_pixel=os.path.join(os.path.dirname(img),"Label",os.path.basename(img).replace(".","_label.")) 165 | # 过滤没有语义标签的图片 166 | label= 0 if label_pixel not in imgs_pixel else 1 167 | label_pixel=label_pixel if label_pixel in imgs_pixel else None 168 | # 过滤不符合类别的图片 169 | item = (img, label_pixel, label) 170 | samples.append(item) 171 | print(" 总样本数:{}".format(len(samples))) 172 | cls_dict=self.split_dataset_by_cls(samples,2) 173 | self.samples, self.cls_dict= samples, cls_dict 174 | 175 | 176 | class WSLDataset_split(WSLDataset): 177 | def __init__(self,root,data_config,phase="training",transform_PIL=None,return_numpy=True,**kwargs): 178 | self.root=root 179 | self.data_config=data_config 180 | self.phase=phase 181 | self.transform_PIL=transform_PIL 182 | self.transform_array = Trans["numpy"] if return_numpy else Trans["torch"] 183 | self.make_dataset() 184 | 185 | def make_dataset(self): 186 | from WSLDatasets import get_cur_path 187 | config_path=os.path.join(get_cur_path(),"configs",self.data_config) 188 | self.samples=read_txt(config_path,self.phase) 189 | for sample in self.samples: sample[2]=int(sample[2]) 190 | # print(" 总样本数:{}".format(len(self.samples))) 191 | self.cls_dict=self.split_dataset_by_cls(self.samples,loc=2) 192 | 193 | #####从新排列样本 194 | sample=[] 195 | max_len=max([ len(self.cls_dict[i]) for i in range(len(self.cls_dict)) ]) 196 | for i in range(max_len): 197 | for j in range(len(self.cls_dict)): 198 | if iself.thred,torch.full_like(img,1),torch.full_like(img,0)).int() 104 | return np.where(np.array(img)>self.thred,1,0).astype(np.int8) 105 | if isinstance(imgs, collections.Iterable): 106 | for idx in range(len(imgs)): 107 | imgs[idx]=func(imgs[idx]) 108 | else: 109 | imgs=func(imgs) 110 | return imgs 111 | 112 | class ToPIL(object): 113 | def __call__(self, imgs): 114 | 115 | def func(img): 116 | if isinstance(img, np.ndarray) and img.ndim == 2: 117 | # if 2D image, add channel dimension (HWC) 118 | img = np.expand_dims(img, 2) 119 | return F.to_pil_image(img) 120 | if isinstance(imgs, collections.Iterable): 121 | for idx in range(len(imgs)): 122 | imgs[idx]=func(imgs[idx]) 123 | else: 124 | imgs=func(imgs) 125 | # for img in imgs: 126 | # #print(img.size) 127 | # print(np.array(img).shape) 128 | return imgs 129 | 130 | 131 | 132 | class ToLong(object): 133 | def __call__(self, x): 134 | return torch.LongTensor(np.asarray(x)) 135 | 136 | class GroupResize(transforms.Resize): 137 | """Resize the input PIL Image to the given size. 138 | 139 | Args: 140 | size (sequence or int): Desired output size. If size is a sequence like 141 | (h, w), output size will be matched to this. If size is an int, 142 | smaller edge of the image will be matched to this number. 143 | i.e, if height > width, then image will be rescaled to 144 | (size * height / width, size) 145 | interpolation (int, optional): Desired interpolation. Default is 146 | ``PIL.Image.BILINEAR`` 147 | """ 148 | 149 | def __call__(self, imgs): 150 | """ 151 | Args: 152 | img (PIL Image): Image to be scaled. 153 | 154 | Returns: 155 | PIL Image: Rescaled image. 156 | """ 157 | def func(img): 158 | return F.resize(img, self.size, self.interpolation) 159 | if isinstance(imgs, collections.Iterable): 160 | for idx in range(len(imgs)): 161 | imgs[idx]=func(imgs[idx]) 162 | else: 163 | imgs=func(imgs) 164 | return imgs 165 | 166 | class GropuRandomCropAndScale(transforms.RandomResizedCrop): 167 | def __init__(self,area_limit=1e+8, size_ratio=(3. / 4., 4. / 3.), crop_scale=(0.5, 1.0), crop_ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR): 168 | 169 | self.interpolation = interpolation 170 | self.crop_scale = crop_scale 171 | self.crop_ratio = crop_ratio 172 | self.size_ratio=size_ratio 173 | self.area_limit=area_limit 174 | 175 | @staticmethod 176 | def get_params(img, scale, ratio): 177 | """Get parameters for ``crop`` for a random sized crop. 178 | 179 | Args: 180 | img (PIL Image): Image to be cropped. 181 | scale (tuple): range of size of the origin size cropped 182 | ratio (tuple): range of aspect ratio of the origin aspect ratio cropped 183 | 184 | Returns: 185 | tuple: params (i, j, h, w) to be passed to ``crop`` for a random 186 | sized crop. 187 | """ 188 | #宽高比 189 | img_ratio=img.size[0]/img.size[1] 190 | img_area=img.size[0]*img.size[1] 191 | for attempt in range(10): 192 | target_area = random.uniform(*scale) * img_area 193 | aspect_ratio = random.uniform(*ratio) * img_ratio 194 | 195 | w = int(round(math.sqrt(target_area * aspect_ratio))) 196 | h = int(round(math.sqrt(target_area / aspect_ratio))) 197 | 198 | 199 | if w <= img.size[0] and h <= img.size[1]: 200 | i = random.randint(0, img.size[1] - h) 201 | j = random.randint(0, img.size[0] - w) 202 | return i, j, h, w 203 | 204 | # Fallback 205 | # w = min(img.size[0], img.size[1]) 206 | # i = (img.size[1] - w) // 2 207 | # j = (img.size[0] - w) // 2 208 | return 0,0,img.size[0],img.size[1] 209 | 210 | def get_output_size(self,input_size,size_ratio,area_limit): 211 | img_ratio = input_size[0] / input_size[1] 212 | img_area = input_size[0] * input_size[1] 213 | target_area = random.uniform(*size_ratio) * img_area 214 | target_area = min(target_area,area_limit) 215 | w = int(round(math.sqrt(target_area * img_ratio))) 216 | h = int(round(math.sqrt(target_area / img_ratio))) 217 | out_size=(w,h) 218 | return out_size 219 | def __call__(self, imgs): 220 | """ 221 | Args: 222 | img (PIL Image): Image to be cropped and resized. 223 | 224 | Returns: 225 | PIL Image: Randomly cropped and resized image. 226 | """ 227 | i, j, h, w = self.get_params(imgs[0], self.crop_scale, self.crop_ratio) 228 | out_size=self.get_output_size((h,w),self.size_ratio,self.area_limit) 229 | imgs=[F.resized_crop(img, i, j, h, w, out_size, self.interpolation) for img in imgs] 230 | 231 | return imgs 232 | 233 | class GroupRandomRotation(transforms.RandomRotation): 234 | def __init__(self, degrees=[0,90,180,270], resample=False, expand=True, center=None): 235 | if not isinstance(degrees, list): 236 | raise ValueError("it must be a list.") 237 | self.degrees = degrees 238 | self.resample = resample 239 | self.expand = expand 240 | self.center = center 241 | @staticmethod 242 | def get_params(degrees): 243 | """Get parameters for ``rotate`` for a random rotation. 244 | 245 | Returns: 246 | sequence: params to be passed to ``rotate`` for random rotation. 247 | """ 248 | rand = random.randint(0, len(degrees)-1) 249 | angle=degrees[rand] 250 | return angle 251 | def __call__(self, img): 252 | """ 253 | Args: 254 | img (PIL Image): Image to be rotated. 255 | 256 | Returns: 257 | PIL Image: Rotated image. 258 | """ 259 | angle =self.get_params(self.degrees) 260 | if isinstance(img,list): 261 | return [F.rotate(item, angle, self.resample, self.expand, self.center) 262 | for item in img ] 263 | else: 264 | return F.rotate(img, angle, self.resample, self.expand, self.center) 265 | 266 | class GroupRandomHorizontalFlip(transforms.RandomHorizontalFlip): 267 | def __init__(self, p=0.5): 268 | super(GroupRandomHorizontalFlip,self).__init__(p) 269 | 270 | def __call__(self, img): 271 | """ 272 | Args: 273 | img (PIL Image): Image to be flipped. 274 | 275 | Returns: 276 | PIL Image: Randomly flipped image. 277 | """ 278 | rand= random.random() 279 | if isinstance(img,list): 280 | img= [F.hflip(item) if rand< self.p else item for item in img ] 281 | else: 282 | img = F.hflip(img) if rand < self.p else img 283 | return img 284 | 285 | 286 | class GroupRandomVerticalFlip(transforms.RandomVerticalFlip): 287 | def __init__(self, p=0.5): 288 | super(GroupRandomVerticalFlip, self).__init__(p) 289 | 290 | def __call__(self, img): 291 | """ 292 | Args: 293 | img (PIL Image): Image to be flipped. 294 | 295 | Returns: 296 | PIL Image: Randomly flipped image. 297 | """ 298 | rand = random.random() 299 | if isinstance(img, list): 300 | img = [F.vflip(item) if rand < self.p else item for item in img] 301 | else: 302 | img = F.vflip(img) if rand < self.p else img 303 | return img 304 | 305 | 306 | class ComposeJoint(object): 307 | def __init__(self, transforms): 308 | self.transforms = transforms 309 | 310 | def __call__(self, x): 311 | assert isinstance(x, collections.Iterable) 312 | for transform in self.transforms: 313 | x = self._iterate_transforms(transform, x) 314 | 315 | return x 316 | 317 | # def _iterate_transforms(self, transforms, x): 318 | # if isinstance(transforms, collections.Iterable): 319 | # for i, transform in enumerate(transforms): 320 | # x[i] = self._iterate_transforms(transform, x[i]) 321 | # else: 322 | # if transforms is not None: 323 | # x = transforms(x) 324 | # return x 325 | def _iterate_transforms(self, transforms, x): 326 | 327 | if isinstance(transforms, collections.Iterable): 328 | x=[ self._iterate_transforms(transforms[i], x_i) for i, x_i in enumerate(x) ] 329 | else: 330 | if transforms is not None: 331 | x = transforms(x) 332 | return x -------------------------------------------------------------------------------- /utils/data/data_manager.py: -------------------------------------------------------------------------------- 1 | import os 2 | from random import shuffle 3 | 4 | import os 5 | from random import shuffle 6 | import cv2 7 | import numpy as np 8 | from PIL import Image 9 | from torchvision import transforms 10 | from utils import utils 11 | from utils.data import my_transforms 12 | from utils.data.gen_defect import DefectiveGenerator 13 | 14 | class DataManager(object): 15 | def __init__(self, dataList,param,shuffle=True): 16 | """ 17 | """ 18 | self.shuffle=shuffle 19 | self.data_list=dataList 20 | self.data_size=len(dataList) 21 | self.sample_dict =self.splitData(self.data_list) 22 | self.data_dir=param["data_dir"] 23 | self.epochs_num=param["epochs_num"] 24 | self.batch_size = param["batch_size"] 25 | self.image_scale=param["image_scale"] 26 | self.image_size =param["image_size"] 27 | self.with_RGB=image_size =param["with_RGB"] 28 | self.image_size = [self.image_size[0] // self.image_scale, self.image_size[1] // self.image_scale] 29 | self.set() 30 | def set(self): 31 | self.number_batch =len(self.data_list)//self.batch_size 32 | self.next_batch=self.get_next() 33 | 34 | def get_next(self): 35 | dataset = tf.data.Dataset.from_generator(self.generator, (tf.float32, tf.int32,tf.int32, tf.string)) 36 | dataset = dataset.repeat(self.epochs_num) 37 | if self.shuffle: 38 | dataset = dataset.shuffle(self.batch_size*3+200) 39 | dataset = dataset.batch(self.batch_size) 40 | iterator = dataset.make_one_shot_iterator() 41 | out_batch = iterator.get_next() 42 | return out_batch 43 | 44 | def generator(self): 45 | while True: 46 | for index in range(len(self.data_list)): 47 | yield self.get_one_sample(self.data_list[index]) 48 | 49 | def get_one_sample(self,sample): 50 | file_basename_image, file_basename_label, label =sample 51 | image_path = os.path.join(self.data_dir, file_basename_image) 52 | label_path = os.path.join(self.data_dir, file_basename_label) 53 | image = self.read_data(image_path) 54 | label_pixel = self.read_data(label_path) 55 | label_pixel = self.label_preprocess(label_pixel) 56 | if not self.with_RGB: 57 | image = (np.array(image[:, :, np.newaxis])) 58 | label_pixel = (np.array(label_pixel[:, :, np.newaxis])) 59 | image = utils.transform(image) 60 | return image, label_pixel, int(label), file_basename_image 61 | 62 | def __iter__(self): 63 | for index in range(self.number_batch): 64 | next_batch=SESSION.run(self.next_batch) 65 | yield next_batch 66 | 67 | def read_data(self, data_name): 68 | flag=1 if self.with_RGB else 0 69 | img = cv2.imread(data_name, flag) # /255.#read the gray image 70 | img = cv2.resize(img, (int(self.image_size[1]), int(self.image_size[0]))) 71 | return img 72 | 73 | def label_preprocess(self,label): 74 | #label = cv2.resize(label, (int(self.image_size[1]/8), int(self.image_size[0]/8))) 75 | label_pixel=self.ImageBinarization(label) 76 | return label_pixel 77 | 78 | def ImageBinarization(self,img, threshold=1): 79 | img = np.array(img) 80 | image = np.where(img > threshold, 1, 0) 81 | return image 82 | 83 | def splitData(self,data): 84 | """ 85 | 把数据列表按照类别分开 86 | :param data: 87 | :return: 88 | """ 89 | dict={} 90 | for item in data: 91 | key=int(item[2]) 92 | if key not in dict.keys(): 93 | dict[key]=[] 94 | dict[key].append(item) 95 | return dict 96 | 97 | class DataManager_balance(DataManager): 98 | def __init__(self, dataList,param,shuffle=True): 99 | super(DataManager_balance,self).__init__(dataList,param,shuffle) 100 | self.with_transform=param["with_transform"] 101 | transform_train=[ 102 | my_transforms.GroupRandomHorizontalFlip(), 103 | my_transforms.GroupRandomVerticalFlip(), 104 | #transforms.RandomResizedCrop 105 | # transforms.RandomHorizontalFlip(), 106 | # transforms.RandomVerticalFlip(), 107 | # transforms.RandomResizedCrop(size=[], scale=(0.5, 1.0),ratio=()), 108 | # transforms.RandomCrop() 109 | # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 110 | ] 111 | if "with_rotate" in param.keys() and param["with_rotate"]:transform_train.append(my_transforms.GroupRandomRotation()) 112 | self.transform = {"train":transforms.Compose(transform_train), "val": None} 113 | def set(self): 114 | self.next_batch = self.get_next() 115 | self.number_batch=int(np.floor(len(self.sample_dict[1])))*2//self.batch_size 116 | 117 | def get_next(self): 118 | dataset = tf.data.Dataset.from_generator(self.generator, (tf.float32, tf.int32, tf.int32, tf.string)) 119 | dataset = dataset.batch(self.batch_size) 120 | iterator = dataset.make_one_shot_iterator() 121 | out_batch = iterator.get_next() 122 | return out_batch 123 | 124 | def generator(self): 125 | step=0 126 | while(True): 127 | for cls, sample_list, in self.sample_dict.items(): 128 | sample_num=len(sample_list) 129 | index=step%sample_num 130 | if index==0 and self.shuffle: 131 | shuffle(sample_list) 132 | yield self.get_one_sample(sample_list[index]) 133 | step+=1 134 | #添加transform 135 | def get_one_sample(self,sample): 136 | file_basename_image, file_basename_label, label =sample 137 | image_path = os.path.join(self.data_dir, file_basename_image) 138 | label_path = os.path.join(self.data_dir, file_basename_label) 139 | image = self.read_data(image_path) 140 | label_pixel = self.read_data(label_path) 141 | if self.with_transform: 142 | image,label_pixel=self.transform_sample(image,label_pixel) 143 | label_pixel = self.label_preprocess(label_pixel) 144 | image = utils.transform(image) 145 | if not self.with_RGB: 146 | image = (np.array(image[:, :, np.newaxis])) 147 | label_pixel = (np.array(label_pixel[:, :, np.newaxis])) 148 | return image, label_pixel, int(label), file_basename_image 149 | 150 | def transform_sample(self,image,label): 151 | image = Image.fromarray( np.uint8(image)) 152 | label = Image.fromarray(np.uint8(label)) 153 | ouput=self.transform["train"]([image,label]) 154 | image=np.array(ouput[0]) 155 | label = np.array(ouput[1]) 156 | return image,label 157 | 158 | class DataManager_normal(DataManager): 159 | def __init__(self, dataList,param,shuffle=True): 160 | super(DataManager_normal,self).__init__(dataList,param,shuffle) 161 | self.transform = {"train": 162 | transforms.Compose([ 163 | transforms.RandomHorizontalFlip(), 164 | transforms.RandomVerticalFlip(), 165 | # transforms.RandomResizedCrop(size=[], scale=(0.5, 1.0),ratio=()), 166 | # transforms.RandomCrop() 167 | # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 168 | ]), 169 | "val": None} 170 | def set(self): 171 | self.next_batch = self.get_next() 172 | self.number_batch=int(np.floor(len(self.sample_dict[1]))) 173 | 174 | def get_next(self): 175 | dataset = tf.data.Dataset.from_generator(self.generator, (tf.float32, tf.int32, tf.int32, tf.string)) 176 | dataset = dataset.batch(self.batch_size) 177 | iterator = dataset.make_one_shot_iterator() 178 | out_batch = iterator.get_next() 179 | return out_batch 180 | 181 | def generator(self): 182 | step=0 183 | while(True): 184 | for cls, sample_list, in self.sample_dict.items(): 185 | if cls==0: 186 | sample_num=len(sample_list) 187 | index=step%sample_num 188 | if index==0 and self.shuffle: 189 | shuffle(sample_list) 190 | yield self.get_one_sample(sample_list[index]) 191 | step+=1 192 | #添加transform 193 | def get_one_sample(self,sample): 194 | file_basename_image, file_basename_label, label =sample 195 | image_path = os.path.join(self.data_dir, file_basename_image) 196 | label_path = os.path.join(self.data_dir, file_basename_label) 197 | image = self.read_data(image_path) 198 | 199 | label_pixel = self.read_data(label_path) 200 | label_pixel = self.label_preprocess(label_pixel) 201 | image = (np.array(image[:, :, np.newaxis])) 202 | label_pixel = (np.array(label_pixel[:, :, np.newaxis])) 203 | image,label_pixel=self.transform_sample(image,label_pixel) 204 | image = utils.transform(image) 205 | return image, label_pixel, int(label), file_basename_image 206 | 207 | def transform_sample(self,image,label): 208 | image= np.uint8(image) 209 | label = np.uint8(label) 210 | img=np.concatenate((image,label),2) 211 | img = Image.fromarray(img) 212 | img=self.transform["train"](img) 213 | img=np.array(img) 214 | image=img[:,:,0][:, :, np.newaxis] 215 | label =img[:, :, 1][:, :, np.newaxis] 216 | return image,label 217 | 218 | 219 | class DataManager_faker(DataManager): 220 | def __init__(self, dataList,param,dir_DefectsDir,shuffle=True): 221 | super(DataManager_faker,self).__init__(dataList,param,shuffle) 222 | self.defectGenerator=DefectiveGenerator(dir_DefectsDir,self.image_size,[0,10000]) 223 | def set(self): 224 | self.next_batch = self.get_next() 225 | self.number_batch=len(self.sample_dict[0])//self.batch_size*2 226 | 227 | def get_next(self): 228 | dataset = tf.data.Dataset.from_generator(self.generator, (tf.float32, tf.int32, tf.int32, tf.string)) 229 | dataset = dataset.batch(self.batch_size) 230 | iterator = dataset.make_one_shot_iterator() 231 | out_batch = iterator.get_next() 232 | return out_batch 233 | 234 | def generator(self): 235 | step=0 236 | while(True): 237 | sample_list = self.sample_dict[0] 238 | sample_num = len(sample_list) 239 | index = step % sample_num 240 | if index == 0 and self.shuffle: 241 | shuffle(sample_list) 242 | sample = self.get_one_sample(sample_list[index]) 243 | step += 1 244 | for cls in range(2): 245 | if cls==0: 246 | yield sample 247 | if cls==1: 248 | yield self.draw_one_sample(sample) 249 | 250 | def draw_one_sample(self,sample): 251 | image, label_pixel, label, file_basename_image=sample 252 | image=image.squeeze(2) 253 | image_draw, label_pixel_draw = self.defectGenerator.genDefect(image) 254 | image_draw = (np.array(image_draw[:, :, np.newaxis])) 255 | label_pixel_draw = (np.array(label_pixel_draw[:, :, np.newaxis])) 256 | filename = str(file_basename_image).split(".")[-2]+"_faker."+str(file_basename_image).split(".")[-1] 257 | return image_draw, label_pixel_draw, int(1), filename 258 | 259 | 260 | class DataManager_class(DataManager): 261 | def __init__(self, dataList,param,dir_DefectsDir,shuffle=True): 262 | super(DataManager_class,self).__init__(dataList,param,shuffle) 263 | self.defectGenerator=DefectiveGenerator(dir_DefectsDir,self.image_size,[0,10000]) 264 | def set(self): 265 | # self.number_batch_negative =int(np.floor(len(self.sample_dict[0])/self.batch_size)) 266 | # self.number_batch_positive =int(np.floor(len(self.sample_dict[1])/self.batch_size)) 267 | self.number_batch = int(np.floor(len(self.sample_dict[1]) / self.batch_size)) 268 | self.next_batch_positive=self.get_next_positive() 269 | self.next_batch_negative=self.get_next_negative() 270 | def get_next_positive(self): 271 | dataset = tf.data.Dataset.from_generator(self.generator_positive, (tf.float32, tf.int32,tf.int32, tf.string)) 272 | dataset = dataset.repeat(self.epochs_num*3) 273 | if self.shuffle: 274 | dataset = dataset.shuffle(self.batch_size*3+200) 275 | dataset = dataset.batch(self.batch_size) 276 | iterator = dataset.make_one_shot_iterator() 277 | out_batch = iterator.get_next() 278 | return out_batch 279 | def generator_negative(self): 280 | while True: 281 | data_list=self.sample_dict[0] 282 | for index in range(len(data_list)): 283 | yield self.get_one_sample(data_list[index]) 284 | 285 | def get_next_negative(self): 286 | dataset = tf.data.Dataset.from_generator(self.generator_negative, (tf.float32, tf.int32,tf.int32, tf.string)) 287 | dataset = dataset.repeat(self.epochs_num*3) 288 | if self.shuffle: 289 | dataset = dataset.shuffle(self.batch_size*3+200) 290 | dataset = dataset.batch(self.batch_size) 291 | iterator = dataset.make_one_shot_iterator() 292 | out_batch = iterator.get_next() 293 | return out_batch 294 | 295 | def generator_positive(self): 296 | while True: 297 | data_list=self.sample_dict[0] 298 | for index in range(len(data_list)): 299 | yield self.draw_one_sample(self.get_one_sample(data_list[index])) 300 | 301 | def draw_one_sample(self,sample): 302 | image, label_pixel, label, file_basename_image=sample 303 | image=image.squeeze(2) 304 | image_draw, label_pixel_draw = self.defectGenerator.genDefect(image) 305 | image_draw = (np.array(image_draw[:, :, np.newaxis])) 306 | label_pixel_draw = (np.array(label_pixel_draw[:, :, np.newaxis])) 307 | return image_draw, label_pixel_draw, int(1), file_basename_image 308 | 309 | if __name__=="__main__": 310 | 311 | kolektorSDD_Patch_config="../config/kolektorSDD_config1" 312 | --------------------------------------------------------------------------------