├── .gitignore
├── LICENSE
├── README.md
├── dataloader
    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-37.pyc
    │   └── dataloader.cpython-37.pyc
    └── dataloader.py
├── docs
    ├── F3结果.png
    ├── FaultSeg3D.png
    └── 合成数据结果.png
├── main.py
├── models
    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-37.pyc
    │   ├── faultseg3d.cpython-37.pyc
    │   ├── unet_3d_PyramidPool.cpython-37.pyc
    │   ├── unet_3d_PyramidPool_Half.cpython-37.pyc
    │   ├── unet_3d_longpool.cpython-37.pyc
    │   └── unet_3d_longpool_T.cpython-37.pyc
    └── faultseg3d.py
├── requirements.txt
└── utils
    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-37.pyc
    │   ├── dice_loss.cpython-37.pyc
    │   ├── test.cpython-37.pyc
    │   ├── test_new.cpython-37.pyc
    │   ├── tools.cpython-37.pyc
    │   └── train.cpython-37.pyc
    ├── dice_loss.py
    ├── test.py
    ├── tools.py
    └── train.py


/.gitignore:
--------------------------------------------------------------------------------
data_*/
EXP/
.idea/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2024 Ifjmww

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# FaultSeg3D_pytorch
A PyTorch version of FaultSeg3D (a personal reproduction; where it differs from the original FaultSeg3D code, the original code is authoritative).

### [Original paper](http://cig.ustc.edu.cn/_upload/tpl/05/cd/1485/template1485/papers/wu2019FaultSeg3D.pdf)

![FaultSeg3D network architecture](/docs/FaultSeg3D.png "FaultSeg3D")

## Running
### Environment
* [requirements.txt](./requirements.txt)
#### Train (default parameters match those in the paper)
```shell
python main.py --mode train --exp [experiment_name] --train_path [train_dataset_path] --valid_path [valid_dataset_path]
```
#### Valid_Only (requires a pretrained model)
```shell
python main.py --mode valid_only --exp [experiment_name] --valid_path [valid_dataset_path]
```
#### Prediction (requires a pretrained model)
```shell
python main.py --mode pred --exp [experiment_name] --pretrained_model_name [FaultSeg3D_BEST.pth] --pred_data_name [f3 | kerry]
```

### Training, validation, and prediction data
* All of the data below has already been preprocessed: (1) converted from .dat to .npy; (2) standardized (mean subtracted, then divided by the standard deviation);
* Training and validation sets, 200 samples: [Baidu Netdisk link](https://pan.baidu.com/s/10o848E2vMmjmi21xZBFRiw?pwd=i4mo), extraction code: i4mo
* Training and validation sets, 800 samples (with data augmentation): [Baidu Netdisk link](https://pan.baidu.com/s/1PzsmRt9drnZI9J5GFOk9rw?pwd=zwqf), extraction code: zwqf
* Prediction set, F3 data: [Baidu Netdisk link](https://pan.baidu.com/s/1iBnW94Yn2U0GQQF3-3pXOA?pwd=0b2j), extraction code: 0b2j


## Experimental results
### Fault segmentation on synthetic seismic data
![Fault segmentation on synthetic seismic data](/docs/合成数据结果.png "Fault segmentation on synthetic seismic data")
### Fault segmentation on the real Netherlands F3 seismic data
![Fault segmentation on the real Netherlands F3 seismic data](/docs/F3结果.png "Fault segmentation on the real Netherlands F3 seismic data")

## Attribution Statement
If you use or reference code from this project (FaultSeg3D_pytorch) in your project, we require and appreciate an attribution statement in your project documentation or code, such as:
```text
This project uses code from Ifjmww's FaultSeg3D_pytorch project on GitHub, with thanks. Original project link: https://github.com/Ifjmww/FaultSeg3D_pytorch
```
or:
```text
Parts of this code are based on modifications of Ifjmww's FaultSeg3D_pytorch. Original project link: https://github.com/Ifjmww/FaultSeg3D_pytorch
```
We value and encourage mutual respect and learning among members of the open source community. Thank you for your cooperation and support.

--------------------------------------------------------------------------------
/dataloader/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ifjmww/FaultSeg3D_pytorch/c31d612e48bef7e97aa1fdb7af3715ce28b11d03/dataloader/__init__.py
--------------------------------------------------------------------------------
/dataloader/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ifjmww/FaultSeg3D_pytorch/c31d612e48bef7e97aa1fdb7af3715ce28b11d03/dataloader/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/dataloader/__pycache__/dataloader.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ifjmww/FaultSeg3D_pytorch/c31d612e48bef7e97aa1fdb7af3715ce28b11d03/dataloader/__pycache__/dataloader.cpython-37.pyc
--------------------------------------------------------------------------------
/dataloader/dataloader.py:
--------------------------------------------------------------------------------
# Data loading. `transform` is always None here: data augmentation is done separately before training,
# and all training, validation, and prediction data are already standardized as (x - mean) / std.
import numpy as np
import os
import torch
from torch.utils.data import Dataset


class FaultDataset(Dataset):
    """
    Dataset of seismic volumes (path/x/) and fault labels (path/y/) stored as .npy files.
    """

    def __init__(self, path, mode='train', transform=None):
        self.path = path
        self.transform = transform
        self.mode = mode

        self.image_list, self.label_list = self.load_data()

    def __getitem__(self, index):
        image = np.load(self.image_list[index])
        if len(self.label_list) == 0:
            # 'pred' mode has no labels, so return an all-zero placeholder
            label = np.zeros(image.shape)
        else:
            label = np.load(self.label_list[index])

        img = image
        if len(img.shape) == 3:
            # Add a channel axis: (D, H, W) -> (1, D, H, W)
            img = img.reshape((1, img.shape[0], img.shape[1], img.shape[2]))

        x = torch.from_numpy(img)
        y = torch.from_numpy(label)

        data = {'x': x.float(), 'y': y.float()}

        return data

    def __len__(self):
        return len(self.image_list)

    def load_data(self):
        """
        Collect image paths from path/x/ and label paths from path/y/.

        :return: (image path list, label path list); the label list is empty in 'pred' mode
        """
        img_list = []
        label_list = []
        label_pred_list = []
        img_path = os.path.join(self.path, 'x/')
        label_path = os.path.join(self.path, 'y/')
        for item in os.listdir(img_path):
            img_list.append(os.path.join(img_path, item))
            # x and y share the same file names, so both paths are built in one pass
            label_list.append(os.path.join(label_path, item))
        if self.mode != 'pred':
            return img_list, label_list
        else:
            return img_list, label_pred_list

--------------------------------------------------------------------------------
/docs/F3结果.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ifjmww/FaultSeg3D_pytorch/c31d612e48bef7e97aa1fdb7af3715ce28b11d03/docs/F3结果.png
--------------------------------------------------------------------------------
/docs/FaultSeg3D.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ifjmww/FaultSeg3D_pytorch/c31d612e48bef7e97aa1fdb7af3715ce28b11d03/docs/FaultSeg3D.png
--------------------------------------------------------------------------------
/docs/合成数据结果.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ifjmww/FaultSeg3D_pytorch/c31d612e48bef7e97aa1fdb7af3715ce28b11d03/docs/合成数据结果.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
# Fault Segmentation Based on Pytorch
import os
import argparse

from utils.train import train, valid
from utils.test import pred_Gaussian
from utils.tools import save_args_info


def add_args():
    parser = argparse.ArgumentParser(description="FaultSeg3D_pytorch")

    parser.add_argument("--exp", default="test", type=str, help="Name of each run")
    parser.add_argument("--device", default='cuda:0', type=str, help="device to run on, e.g. cuda:0 or cpu")
    parser.add_argument("--mode", default='train', choices=['train', 'valid_only', 'pred'], type=str, help='network run mode')
    parser.add_argument("--batch_size", default=2, type=int, help="batch size for training")
    parser.add_argument("--batch_size_not_train", default=1, type=int, help="batch size for validation and prediction")
    parser.add_argument("--epochs", default=25, type=int, help="max number of training epochs")
    parser.add_argument("--train_path", default="/data/train/", type=str, help="training dataset directory")
    parser.add_argument("--valid_path", default="/data/valid/", type=str, help="validation dataset directory")
    parser.add_argument("--in_channels", default=1, type=int, help="number of input channels")
    parser.add_argument("--out_channels", default=2, type=int, help="number of output channels")
    parser.add_argument("--loss_func", default="cross_with_weight", choices=['dice', 'cross_with_weight'], type=str, help="choose loss function")
    parser.add_argument("--val_every", default=10, type=int, help="validation frequency")
    parser.add_argument("--optim_lr", default=1e-4, type=float, help="optimization learning rate")
    parser.add_argument("--workers", default=0, type=int, help="number of workers")
    parser.add_argument("--pretrained_model_name", default="FaultSeg3D_BEST.pth", type=str, help="pretrained model name")
    parser.add_argument("--pred_data_name", default="f3", choices=['f3', 'kerry'], type=str, help="name of the dataset to predict on")
    parser.add_argument('--overlap', default=0.25, type=float, help='overlap ratio between neighboring blocks during prediction')
    parser.add_argument('--threshold', default=0.5, type=float, help='Classification threshold')
    parser.add_argument('--sigma', default=0.0, type=float, help='Gaussian filter sigma (0 disables smoothing)')

    args = parser.parse_args()

    print()
    print(">>>============= args ====================<<<")
    print()
    print(args)  # print command line args
    print()
    print(">>>=======================================<<<")

    return args


def main(args):
    if args.mode == 'train':
        train(args)
    elif args.mode == 'valid_only':
        valid(args)
    elif args.mode == 'pred':
        pred_Gaussian(args)
    else:
        raise ValueError("Only ['train', 'valid_only', 'pred'] mode is supported.")
    save_args_info(args)


if __name__ == "__main__":
    args = add_args()
    main(args)

--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ifjmww/FaultSeg3D_pytorch/c31d612e48bef7e97aa1fdb7af3715ce28b11d03/models/__init__.py
--------------------------------------------------------------------------------
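The decoder channel counts in `models/faultseg3d.py` (`Up(192, 64)`, `Up(96, 32)`, `Up(48, 16)`) come from concatenating each upsampled feature map with the matching encoder skip connection. As a rough, stdlib-only sketch of that bookkeeping (the helper `unet_channel_plan` is hypothetical and not part of this repository; it assumes a cubic input whose side is divisible by 8, so the padding in `Up.forward` is a no-op, and that each `Up` outputs as many channels as its skip connection, as in this model):

```python
# Hypothetical helper (not part of this repository): reproduces the channel and
# spatial-size bookkeeping of FaultSeg3D for a cubic input of side in_size.
def unet_channel_plan(in_size=128, enc_channels=(16, 32, 64, 128)):
    """Return encoder stages as (channels, size) and decoder Up stages as (in_ch, out_ch, size)."""
    encoder = []
    size = in_size
    for level, ch in enumerate(enc_channels):
        if level > 0:
            size //= 2  # Down: MaxPool3d(2) halves each spatial dimension
        encoder.append((ch, size))

    decoder = []
    prev_ch = encoder[-1][0]
    # Walk back up: each Up doubles the size and concatenates the matching skip connection
    for skip_ch, skip_size in reversed(encoder[:-1]):
        # Up(in_channels = upsampled channels + skip channels, out_channels = skip channels)
        decoder.append((prev_ch + skip_ch, skip_ch, skip_size))
        prev_ch = skip_ch
    return encoder, decoder


encoder, decoder = unet_channel_plan()
for in_ch, out_ch, size in decoder:
    print(f"Up({in_ch}, {out_ch}) -> {out_ch} x {size}^3")
```

For the default 128³ blocks this prints `Up(192, 64) -> 64 x 32^3`, `Up(96, 32) -> 32 x 64^3` and `Up(48, 16) -> 16 x 128^3`, which matches the constructor arguments used in `FaultSeg3D.__init__`.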
/models/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ifjmww/FaultSeg3D_pytorch/c31d612e48bef7e97aa1fdb7af3715ce28b11d03/models/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/faultseg3d.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ifjmww/FaultSeg3D_pytorch/c31d612e48bef7e97aa1fdb7af3715ce28b11d03/models/__pycache__/faultseg3d.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/unet_3d_PyramidPool.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ifjmww/FaultSeg3D_pytorch/c31d612e48bef7e97aa1fdb7af3715ce28b11d03/models/__pycache__/unet_3d_PyramidPool.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/unet_3d_PyramidPool_Half.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ifjmww/FaultSeg3D_pytorch/c31d612e48bef7e97aa1fdb7af3715ce28b11d03/models/__pycache__/unet_3d_PyramidPool_Half.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/unet_3d_longpool.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ifjmww/FaultSeg3D_pytorch/c31d612e48bef7e97aa1fdb7af3715ce28b11d03/models/__pycache__/unet_3d_longpool.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/unet_3d_longpool_T.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ifjmww/FaultSeg3D_pytorch/c31d612e48bef7e97aa1fdb7af3715ce28b11d03/models/__pycache__/unet_3d_longpool_T.cpython-37.pyc
--------------------------------------------------------------------------------
/models/faultseg3d.py:
--------------------------------------------------------------------------------
import torch
from torchsummary import summary
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    # (Conv3d 3x3x3 -> BatchNorm3d -> ReLU) applied twice
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv3d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv3d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    # 2x max pooling followed by DoubleConv
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool3d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    # 2x trilinear upsampling, concatenation with the encoder skip connection, then DoubleConv
    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.up = nn.Upsample(scale_factor=(2, 2, 2), mode='trilinear', align_corners=True)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # Pad x1 so that its spatial size matches the skip connection x2
        diffZ = x2.size()[2] - x1.size()[2]
        diffY = x2.size()[3] - x1.size()[3]
        diffX = x2.size()[4] - x1.size()[4]
        x1 = nn.functional.pad(x1, [diffX // 2, diffX - diffX // 2,
                                    diffY // 2, diffY - diffY // 2,
                                    diffZ // 2, diffZ - diffZ // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


class FaultSeg3D(nn.Module):
    def __init__(self, n_channels, n_classes):
        super(FaultSeg3D, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes

        self.inc = DoubleConv(n_channels, 16)
        self.down1 = Down(16, 32)
        self.down2 = Down(32, 64)
        self.down3 = Down(64, 128)

        # Decoder input channels = upsampled channels + skip-connection channels
        self.up2 = Up(192, 64)  # 128 + 64
        self.up3 = Up(96, 32)  # 64 + 32
        self.up4 = Up(48, 16)  # 32 + 16
        self.outc = OutConv(16, n_classes)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        # encoder
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)

        # decoder
        x = self.up2(x4, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        outputs = self.softmax(logits)
        return outputs


if __name__ == '__main__':
    # Print the network's layers and parameter counts
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = FaultSeg3D(1, 2).to(device)
    # torchsummary defaults to device="cuda", so pass the actual device type
    summary(net, input_size=(1, 128, 128, 128), device=device.type)

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
matplotlib==3.4.2
numpy==1.21.2
openpyxl==3.0.10
pandas==1.3.5
scikit_learn==0.24.2
scipy==1.7.1
torch==1.9.1
torchsummary==1.5.1
tqdm==4.62.2

--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ifjmww/FaultSeg3D_pytorch/c31d612e48bef7e97aa1fdb7af3715ce28b11d03/utils/__init__.py
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ifjmww/FaultSeg3D_pytorch/c31d612e48bef7e97aa1fdb7af3715ce28b11d03/utils/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/dice_loss.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ifjmww/FaultSeg3D_pytorch/c31d612e48bef7e97aa1fdb7af3715ce28b11d03/utils/__pycache__/dice_loss.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/test.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ifjmww/FaultSeg3D_pytorch/c31d612e48bef7e97aa1fdb7af3715ce28b11d03/utils/__pycache__/test.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/test_new.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ifjmww/FaultSeg3D_pytorch/c31d612e48bef7e97aa1fdb7af3715ce28b11d03/utils/__pycache__/test_new.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/tools.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ifjmww/FaultSeg3D_pytorch/c31d612e48bef7e97aa1fdb7af3715ce28b11d03/utils/__pycache__/tools.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/train.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ifjmww/FaultSeg3D_pytorch/c31d612e48bef7e97aa1fdb7af3715ce28b11d03/utils/__pycache__/train.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/dice_loss.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class DiceLoss(nn.Module):
    def __init__(self, epsilon=1e-5):
        super(DiceLoss, self).__init__()
        self.epsilon = epsilon

    def forward(self, pred, target):
        # Keep only the fault-class probability channel: (N, C, D, H, W) -> (N, D, H, W)
        pred = pred[:, 1]
        target = target.float()

        # Numerator (intersection) and denominator of the soft Dice coefficient
        intersection = (pred * target).sum()
        dice_coefficient = (2. * intersection + self.epsilon) / (pred.sum() + target.sum() + self.epsilon)

        # Dice loss
        loss = 1 - dice_coefficient
        return loss

--------------------------------------------------------------------------------
/utils/test.py:
--------------------------------------------------------------------------------
import os
from utils.tools import save_pred_picture, load_pred_data
from models.faultseg3d import FaultSeg3D
import numpy as np
import torch
from tqdm import tqdm
from scipy.ndimage import gaussian_filter


def sliding_window_prediction(input_data, block_size, overlap, model, args):
    # Shape of the input volume
    input_shape = input_data.shape
    # Block size and step between neighboring blocks
    block_shape = np.array(block_size)
    step = (1 - overlap) * block_shape

    # Number of blocks needed along each axis
    num_blocks = np.ceil(input_shape / step).astype(int)

    # Zero-pad the input so that the last block along each axis fits completely
    sliding_shape = np.array(((num_blocks[0] - 1) * step[0] + block_shape[0],
                              (num_blocks[1] - 1) * step[1] + block_shape[1],
                              (num_blocks[2] - 1) * step[2] + block_shape[2])).astype(int)

    sliding_data = np.zeros(sliding_shape)

    sliding_data[0:input_shape[0], 0:input_shape[1], 0:input_shape[2]] = input_data

    # Accumulated predictions and per-voxel coverage counts
    output = np.zeros(sliding_shape)
    weight_map = np.zeros(sliding_shape)

    total_iterations = num_blocks[0] * num_blocks[1] * num_blocks[2]
    progress_bar = tqdm(total=total_iterations, desc='[Pred]', unit='it')

    # Slide the window over the volume and predict each block
    for i in range(num_blocks[0]):
        for j in range(num_blocks[1]):
            for k in range(num_blocks[2]):
                # Start and end indices of the current block
                start = (step * np.array([i, j, k])).astype(int)
                end = (start + block_shape).astype(int)
                # Crop the current block and add batch and channel axes
                block = sliding_data[start[0]:end[0], start[1]:end[1], start[2]:end[2]]
                block = block.reshape((1, 1, block.shape[0], block.shape[1], block.shape[2]))

                # Standardize each block independently
                block_mean = np.mean(block)
                block_std = np.std(block)
                block_normal = (block - block_mean) / block_std

                input_block = torch.from_numpy(block_normal).to(args.device).float()

                # Inference only, so no autograd graph is built
                with torch.no_grad():
                    block_prediction = model(input_block)

                # Keep the fault-class probability
                block_prediction = block_prediction[:, 1, :, :, :]
                # block_prediction = block_prediction.argmax(axis=1)
                block_prediction = block_prediction.detach().cpu().numpy()
                block_prediction = np.squeeze(block_prediction)

                # Count how many blocks cover each voxel
                weight_map[start[0]:end[0], start[1]:end[1], start[2]:end[2]] += 1

                # Accumulate the block prediction into the output
                output[start[0]:end[0], start[1]:end[1], start[2]:end[2]] += block_prediction
                progress_bar.update(1)
    progress_bar.close()

    # Average overlapping predictions using the coverage counts
    output /= weight_map

    # Smooth block boundaries with a Gaussian filter (sigma=0, the default, leaves the output unchanged)
    smoothed_output = gaussian_filter(output, sigma=args.sigma)

    # Crop away the zero padding
    return smoothed_output[0:input_shape[0], 0:input_shape[1], 0:input_shape[2]]


def pred_Gaussian(args):
    print("============================== pred_Gaussian ==============================")
    input_data = load_pred_data(args)  # input volume
    block_size = (128, 128, 128)  # block size
    overlap = args.overlap  # overlap ratio

    # Predict with the trained model
    model = FaultSeg3D(args.in_channels, args.out_channels).to(args.device)
    model_path = './EXP/' + args.exp + '/models/' + args.pretrained_model_name
    # map_location lets a GPU-trained checkpoint load on the requested device (e.g. cpu)
    model.load_state_dict(torch.load(model_path, map_location=args.device))
    print("Loaded model from disk")
    model.eval()
    # Run the sliding-window prediction
    output_data = sliding_window_prediction(input_data, block_size, overlap, model, args)

    threshold = args.threshold
    output_data[output_data > threshold] = 1
    output_data[output_data <= threshold] = 0

    print("---Start Save results ······")
    save_path = './EXP/' + args.exp + '/results/pred/' + args.pred_data_name + '/'
    if not os.path.exists(save_path + 'numpy/'):
        os.makedirs(save_path + 'numpy/')
    if not os.path.exists(save_path + 'picture/'):
        os.makedirs(save_path + 'picture/')
    np.save(save_path + 'numpy/' + args.pred_data_name + '.npy', output_data)

    save_pred_picture(input_data, output_data, save_path + 'picture/', args.pred_data_name)
    print("Finish!!!")

--------------------------------------------------------------------------------
/utils/tools.py:
--------------------------------------------------------------------------------
import os
import torch
from dataloader.dataloader import FaultDataset
from torch.utils.data import DataLoader
import torch.nn as nn
from sklearn.metrics import confusion_matrix
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

from utils.dice_loss import DiceLoss


def save_args_info(args):
    # Save the run arguments to a mode-specific config file under ./EXP/[exp]/
    argsDict = args.__dict__
    result_path = './EXP/' + args.exp + '/'

    if not os.path.exists(result_path):
        os.makedirs(result_path)
    config_names = {'train': 'config.txt', 'valid_only': 'config_valid_only.txt', 'pred': 'config_pred.txt'}
    if args.mode in config_names:
        with open(result_path + config_names[args.mode], 'w') as f:
            f.writelines('------------------ start ------------------' + '\n')
            for eachArg, value in argsDict.items():
                f.writelines(eachArg + ' : ' + str(value) + '\n')
            f.writelines('------------------- end -------------------')


def load_data(args):
    # args.mode is one of ['train', 'valid_only', 'pred']
    if args.mode == 'train':
        # Training set and validation set used during training
        train_dataset = FaultDataset(args.train_path, args.mode, transform=None)
        train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=True)

        valid_dataset = FaultDataset(args.valid_path, args.mode, transform=None)
        valid_dataloader = DataLoader(valid_dataset, batch_size=args.batch_size_not_train, shuffle=True, num_workers=args.workers, drop_last=True)

        print("--- create train dataloader ---")
        print(len(train_dataset), ", train dataset created")
        print(len(train_dataloader), ", train dataloader created")

        print("--- create valid dataloader ---")
        print(len(valid_dataset), ", valid dataset created")
        print(len(valid_dataloader), ", valid dataloaders created")

        return train_dataloader, valid_dataloader

    elif args.mode == 'valid_only':
        dataset = FaultDataset(args.valid_path, args.mode, transform=None)
        dataloader = DataLoader(dataset, batch_size=args.batch_size_not_train, shuffle=True, num_workers=args.workers, drop_last=True)

        print("--- create valid dataloader ---")
        print(len(dataset), ", valid dataset created")
        print(len(dataloader), ", valid dataloaders created")

        return dataloader

    else:  # args.mode == 'pred'
        # Note: main.py defines no --pred_path argument; its 'pred' mode calls utils.test.pred_Gaussian instead
        dataset = FaultDataset(args.pred_path, args.mode, transform=None)
        dataloader = DataLoader(dataset, batch_size=args.batch_size_not_train, shuffle=False, num_workers=args.workers, drop_last=True)
        print("--- create prediction dataloader ---")
        print(len(dataset), ", prediction dataset created")
        print(len(dataloader), ", prediction dataloaders created")
        return dataloader


def compute_loss(outputs, labels, args):
    if args.loss_func == 'dice':
        criterion = DiceLoss().to(args.device)
        loss = criterion(outputs, labels)

        return loss

    elif args.loss_func == 'cross_with_weight':
        # Class-balanced cross entropy: beta is the fraction of non-fault voxels,
        # so the rarer fault class (1) receives the larger weight beta
        neg = (1 - labels).sum()  # number of 0 (non-fault) voxels
        pos = labels.sum()  # number of 1 (fault) voxels
        beta = neg / (neg + pos)

        weight = torch.tensor([1 - beta, beta]).to(args.device)

        # nn.CrossEntropyLoss applies log-softmax internally, while FaultSeg3D.forward
        # already returns softmax probabilities
        loss = nn.CrossEntropyLoss(weight=weight, reduction='mean')(outputs, labels.long())

        return loss
    else:
        raise ValueError("Only ['dice', 'cross_with_weight'] loss functions are supported.")


def con_matrix(outputs, labels, args):
    y_pred = outputs.detach().cpu().numpy()
    y_true = labels.detach().cpu().numpy()

    y_pred = y_pred.argmax(axis=1).flatten()
    y_true = y_true.flatten()

    num_class = args.out_channels
    # Confusion matrix: entry [i, j] counts voxels of true class i predicted as class j
    current = confusion_matrix(y_true, y_pred, labels=range(num_class))

    # Compute the mean IoU and mean Dice over classes
    # Diagonal of the confusion matrix as a 1-D array (correct voxels per class)
    intersection = np.diag(current)
    # Row sums: number of ground-truth voxels per class
    ground_truth_set = current.sum(axis=1)
    # Column sums: number of predicted voxels per class
    predicted_set = current.sum(axis=0)
    union = ground_truth_set + predicted_set - intersection + 1e-7
    IoU = intersection / union.astype(np.float32)
    union_dice = ground_truth_set + predicted_set + 1e-7
    DICE = 2 * intersection / union_dice.astype(np.float32)
    return np.mean(IoU), np.mean(DICE)


def save_train_info(args, train_RESULT, val_RESULT):
    result_path = './EXP/' + args.exp + '/results/train/'
    if not os.path.exists(result_path):
        os.makedirs(result_path)

    data_df = pd.DataFrame(train_RESULT)
    data_df.columns = ['train_loss', 'train_iou', 'train_dice']
    data_df.index = np.arange(0, args.epochs, 1)
    # The context manager saves and closes the workbook
    with pd.ExcelWriter(result_path + 'train_result.xlsx') as writer:
        data_df.to_excel(writer, sheet_name='page_1', float_format='%.5f')

    data_df_val = pd.DataFrame(val_RESULT)
    data_df_val.columns = ['val_loss', 'val_iou', 'val_dice']
    data_df_val.index = np.arange(0, args.epochs, 1)
    with pd.ExcelWriter(result_path + 'val_result.xlsx') as writer_val:
        data_df_val.to_excel(writer_val, sheet_name='page_1', float_format='%.5f')


def save_result(args, segs, inputs, gts, val_loss, val_iou, val_dice):
    result_path = './EXP/' + args.exp + '/results/valid/'
    if not os.path.exists(result_path):
        os.makedirs(result_path)

    with open(result_path + "valid_final_result.txt", 'a+') as f:
        f.write('valid loss:\t' + str(val_loss) + '\n')
        f.write('valid iou:\t' + str(val_iou) + '\n')
        f.write('valid dice:\t' + str(val_dice) + '\n')

    if not os.path.exists(result_path + 'numpy/'):
        os.makedirs(result_path + 'numpy/')
    if not os.path.exists(result_path + 'picture/'):
        os.makedirs(result_path + 'picture/')

    for i in range(len(inputs)):

        seg = segs[i].argmax(axis=1)
        img = inputs[i]
        gt = gts[i]
        seg = np.squeeze(seg)
        img = np.squeeze(img)
        gt = np.squeeze(gt)
        # save output
        np.save(result_path + 'numpy/' + str(i) + '_seg.npy', seg)
        np.save(result_path + 'numpy/' + str(i) + '_img.npy', img)
        np.save(result_path + 'numpy/' + str(i) + '_gt.npy', gt)

        # save pictures: slices 0, 50 and 100 along each axis (assumes 128^3 volumes)
        index = np.arange(0, 128, 50)
        for idx in index:
            slices = [np.s_[idx, :, :], np.s_[:, idx, :], np.s_[:, :, idx]]
            for dim, sl in enumerate(slices):
                panels = [(img, 'Image'), (gt, 'Ground Truth'), (seg, 'Segmentation')]
                for pos, (volume, title) in enumerate(panels, start=1):
                    plt.subplot(1, 3, pos)
                    plt.imshow(volume[sl])
                    plt.axis('off')
                    plt.title(title)

                plt.savefig(result_path + 'picture/No_' + str(i) + '_idx_' + str(idx) + '_dim_' + str(dim) + '.png')
                plt.close()


def load_pred_data(args):
    # 'f3_data_path' and 'kerry_data_path' are placeholders: replace them with the actual .npy file paths
    if args.pred_data_name == 'f3':
        print("Data use f3.")
        data = np.load('f3_data_path')
        return data
    elif args.pred_data_name == 'kerry':
        print("Data use kerry.")
        data = np.load('kerry_data_path')
        return data
    else:
        raise ValueError("Only ['f3', 'kerry'] datasets are supported.")


def save_pred_picture(gx, gy, save_path, pred_data_name):
    # Save one slice per axis: input volume (left) next to the binary prediction (right)
    k1, k2, k3 = 80, 80, 80
    gx1 = gx[k1, :, :]
    gy1 = gy[k1, :, :]
    gx2 = gx[:, k2, :]
    gy2 = gy[:, k2, :]
    gx3 = gx[:, :, k3]
    gy3 = gy[:, :, k3]

    # xline slice
    plt.subplot(1, 2, 1)
    plt.imshow(gx1, cmap='gray')

    plt.subplot(1, 2, 2)
    plt.imshow(gy1, cmap='gray')

    plt.savefig(save_path + pred_data_name + '_dim_0.png', dpi=600)
    plt.close()

    # inline slice
    plt.subplot(1, 2, 1)
    plt.imshow(gx2, cmap='gray')

    plt.subplot(1, 2, 2)
    plt.imshow(gy2, cmap='gray')

    plt.savefig(save_path + pred_data_name + '_dim_1.png', dpi=600)
    plt.close()

    # time slice
    plt.subplot(1, 2, 1)
    plt.imshow(gx3, cmap='gray')

    plt.subplot(1, 2, 2)
    plt.imshow(gy3, cmap='gray')

    plt.savefig(save_path + pred_data_name + '_dim_2.png', dpi=600)
    plt.close()

--------------------------------------------------------------------------------
/utils/train.py:
--------------------------------------------------------------------------------
import os
import torch
from tqdm import tqdm
from utils.tools import load_data, compute_loss, con_matrix, save_train_info, save_result
import torch.optim as optim
from models.faultseg3d import FaultSeg3D
import numpy as np


def train(args):
    # set device
    device = torch.device(args.device)
    print("---")
    print('Device is :', device)
    # Load data
    print("---")
    print("Loading data ... ")
    train_loader, val_loader = load_data(args)
    print('Create model...')
    model = FaultSeg3D(args.in_channels, args.out_channels).to(device)
    # Initialize optimizer
    print("---")
    print("Define optimizer ... ")

    optimizer = optim.Adam(model.parameters(), lr=args.optim_lr)

    # Set model save path: ./EXP/[exp]/models/
    model_path = './EXP/' + args.exp + '/models/'
    print("---")
    print("The model is saved in : ", model_path)

    if not os.path.exists(model_path):
        os.makedirs(model_path)

    # start training
    print("---")
    print("Start training ... ")

    train_RESULT = []
    val_RESULT = []

    best_iou = 0.0

    for epoch in range(args.epochs):

        # training mode
        model.train()
        train_loss = 0.0
        train_iou = 0.0
        train_dice = 0.0

        for step, data in enumerate(tqdm(train_loader, desc='[Train] Epoch ' + str(epoch + 1) + '/' + str(args.epochs))):
            inputs, labels = data['x'].to(device), data['y'].to(device)

            optimizer.zero_grad()

            outputs = model(inputs)
            loss = compute_loss(outputs, labels, args)
            iou, dice = con_matrix(outputs, labels, args)

            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_iou += iou
            train_dice += dice

        model.eval()
        val_loss = 0.0
        val_iou = 0.0
        val_dice = 0.0

        with torch.no_grad():
            for step, data in enumerate(tqdm(val_loader, desc='[VALID] Valid ')):
                inputs = data['x'].to(device)
                labels = data['y'].to(device)
                outputs = model(inputs)
                loss = compute_loss(outputs, labels, args)
                iou, dice = con_matrix(outputs, labels, args)

                val_loss += loss.item()
                val_iou += iou
                val_dice += dice
        print(
            " train loss: {:.4f}".format(train_loss / len(train_loader)),
            " train iou: {:.4f}".format(train_iou / len(train_loader)),
            " train dice:{:.4f}".format(train_dice / len(train_loader)),
            " val loss: {:.4f}".format(val_loss / len(val_loader)),
            " val iou: {:.4f}".format(val_iou / len(val_loader)),
            " val dice:{:.4f}".format(val_dice / len(val_loader))
        )

        train_result = np.append(train_loss
/ len(train_loader), [train_iou / len(train_loader), train_dice / len(train_loader)]) 94 | train_RESULT.append(train_result) 95 | 96 | val_result = np.append(val_loss / len(val_loader), [val_iou / len(val_loader), val_dice / len(val_loader)]) 97 | val_RESULT.append(val_result) 98 | 99 | if (val_iou / len(val_loader)) > best_iou: 100 | print("new best ({:.6f} --> {:.6f}). ".format(best_iou, val_iou / len(val_loader))) 101 | best_iou = val_iou / len(val_loader) 102 | best_model_name = 'FaultSeg3D_BEST.pth'.format(epoch + 1, val_iou / len(val_loader)) 103 | torch.save(model.state_dict(), model_path + best_model_name) 104 | 105 | if (epoch + 1) % args.val_every == 0: 106 | model_name = 'FaultSeg3D_epoch_{}_iou_{:.4f}_CP.pth'.format(epoch + 1, val_iou / len(val_loader)) # CP means checkpoints 107 | torch.save(model.state_dict(), model_path + model_name) 108 | 109 | # Save training information 110 | 111 | print("---") 112 | print("Save training information ... ") 113 | save_train_info(args, train_RESULT, val_RESULT) 114 | print("---") 115 | print("Train Finish ! ") 116 | print("---") 117 | print("---") 118 | print("Last validation ... ") 119 | valid(args, val_loader) 120 | 121 | return 0 122 | 123 | 124 | def valid(args, val_loader=None): 125 | 126 | device = torch.device(args.device) 127 | print("---") 128 | print('Device is :', device) 129 | # Load data 130 | print("---") 131 | print("Loading data ... ") 132 | if args.mode == 'valid_only': 133 | val_loader = load_data(args) 134 | # Load Model 135 | print("---") 136 | print("Loading Model ... ") 137 | model = FaultSeg3D(args.in_channels, args.out_channels).to(args.device) 138 | 139 | model_path = './EXP/' + args.exp + '/models/' + args.pretrained_model_name 140 | 141 | model.load_state_dict(torch.load(model_path)) 142 | 143 | segs = [] 144 | inputs = [] 145 | gts = [] 146 | 147 | print("---") 148 | print("Start validation ... 
") 149 | 150 | val_loss = 0.0 151 | val_iou = 0.0 152 | val_dice = 0.0 153 | 154 | model.eval() 155 | with torch.no_grad(): 156 | for step, data in enumerate(tqdm(val_loader, desc='[Valid] Valid')): 157 | x = data['x'].to(args.device) 158 | y = data['y'].to(args.device) 159 | 160 | outputs = model(x) 161 | loss = compute_loss(outputs, y, args) 162 | iou, dice = con_matrix(outputs, y, args) 163 | 164 | val_loss += loss.item() 165 | val_iou += iou 166 | val_dice += dice 167 | 168 | segs.append(outputs.detach().cpu().numpy()) 169 | inputs.append(x.detach().cpu().numpy()) 170 | gts.append(y.detach().cpu().numpy()) 171 | 172 | print( 173 | " val loss: {:.4f}".format(val_loss / len(val_loader)), 174 | " val iou: {:.4f}".format(val_iou / len(val_loader)), 175 | " val dice:{:.4f}".format(val_dice / len(val_loader)), 176 | ) 177 | 178 | print("---") 179 | print("Save result of validation ... ") 180 | 181 | save_result(args, segs, inputs, gts, val_loss / len(val_loader), val_iou / len(val_loader), val_dice / len(val_loader)) 182 | 183 | print("---") 184 | print("Save Finished ! ") 185 | --------------------------------------------------------------------------------
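Both `train` and `valid` report IoU and Dice through `con_matrix`, whose body lies outside this section; the `np.mean(IoU), np.mean(DICE)` in its return line suggests it averages over several values (possibly per class or per sample), but that is an assumption. As a reference point only, here is a minimal NumPy sketch of the standard binary definitions, IoU = |A∩B| / |A∪B| and Dice = 2|A∩B| / (|A| + |B|). The helper name `binary_iou_dice` is hypothetical and is not the repository's implementation.

```python
import numpy as np


def binary_iou_dice(pred, gt):
    """Hypothetical helper (not part of the repo): IoU and Dice for two binary masks.

    pred, gt: 0/1 or bool arrays of the same shape, e.g. 128^3 fault volumes.
    Returns (iou, dice); both are 1.0 when neither mask contains a fault voxel.
    """
    pred = np.asarray(pred).astype(bool)
    gt = np.asarray(gt).astype(bool)
    if pred.shape != gt.shape:
        raise ValueError("pred and gt must have the same shape")

    inter = np.logical_and(pred, gt).sum()  # |A ∩ B|
    union = np.logical_or(pred, gt).sum()   # |A ∪ B|
    total = pred.sum() + gt.sum()           # |A| + |B|

    if union == 0:  # both masks empty: treat as perfect agreement
        return 1.0, 1.0
    return float(inter / union), float(2.0 * inter / total)
```

For example, `binary_iou_dice([1, 1, 0, 0], [1, 0, 0, 0])` has one overlapping voxel out of two in the union, giving IoU 0.5 and Dice 2/3. The two scores are linked by Dice = 2·IoU / (1 + IoU), so they always rank predictions in the same order.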
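`save_result` and `save_pred_picture` follow the same post-processing pattern: take the channel-wise argmax of the network output, drop the size-1 axes, then cut three axis-aligned 2-D sections through the volume for plotting. The sketch below isolates those two steps under the assumption of a batch size of 1 and an output shaped `(1, C, D, H, W)`; the function names are hypothetical. One deliberate difference: it squeezes only the batch axis, whereas plain `np.squeeze` (as used in `save_result`) would also drop any spatial axis of length 1.

```python
import numpy as np


def logits_to_labels(logits):
    """Hypothetical helper: (1, C, D, H, W) network output -> (D, H, W) label volume."""
    labels = np.asarray(logits).argmax(axis=1)  # class index per voxel -> (1, D, H, W)
    return np.squeeze(labels, axis=0)           # drop only the batch axis -> (D, H, W)


def orthogonal_slices(volume, k):
    """Hypothetical helper: the three 2-D sections through voxel index k = (k0, k1, k2)."""
    k0, k1, k2 = k
    return volume[k0, :, :], volume[:, k1, :], volume[:, :, k2]
```

Each returned section can be passed straight to `plt.imshow`, as the plotting code in `save_result` and `save_pred_picture` does with `img[idx, :, :]`, `img[:, idx, :]` and `img[:, :, idx]`.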
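`train` applies two independent checkpoint rules every epoch: `FaultSeg3D_BEST.pth` is overwritten only when the mean validation IoU strictly exceeds the best value so far (which starts at 0.0), and a `*_CP.pth` file is written every `args.val_every` epochs regardless of quality. The stdlib-only sketch below replays those rules on a list of per-epoch IoUs so the schedule can be checked without training; `checkpoint_schedule` is a hypothetical name, not part of the repository. It also makes one edge case visible: if the validation IoU never rises above 0.0, no BEST file is written, so the final `valid(args, val_loader)` call has nothing to load if `args.pretrained_model_name` names that file.

```python
def checkpoint_schedule(val_ious, val_every):
    """Hypothetical helper: replay the checkpoint rules of `train`.

    val_ious: mean validation IoU for each epoch, in order.
    Returns (best_epochs, periodic_epochs) as 1-based epoch numbers:
      best_epochs     - epochs where FaultSeg3D_BEST.pth would be overwritten
      periodic_epochs - epochs where a *_CP.pth checkpoint would be written
    """
    best_iou = 0.0  # same starting value as in train()
    best_epochs, periodic_epochs = [], []
    for epoch, iou in enumerate(val_ious):
        if iou > best_iou:  # strict comparison: ties keep the earlier model
            best_iou = iou
            best_epochs.append(epoch + 1)
        if (epoch + 1) % val_every == 0:
            periodic_epochs.append(epoch + 1)
    return best_epochs, periodic_epochs
```

For instance, `checkpoint_schedule([0.1, 0.3, 0.3, 0.2, 0.5], 2)` returns `([1, 2, 5], [2, 4])`: the repeated 0.3 at epoch 3 does not replace the BEST file.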