├── WTNet
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── config.cpython-38.pyc
│   │   ├── __init__.cpython-38.pyc
│   │   ├── dataloader.cpython-38.pyc
│   │   └── train_and_eval.cpython-38.pyc
│   ├── config.py
│   ├── dataloader.py
│   └── train_and_eval.py
├── networks
│   └── WTNet
│       ├── __init__.py
│       ├── model_data
│       │   └── readme.md
│       ├── __pycache__
│       │   ├── vgg.cpython-38.pyc
│       │   ├── WTNet.cpython-38.pyc
│       │   └── __init__.cpython-38.pyc
│       ├── vgg.py
│       └── WTNet.py
├── logo.jpg
├── Application.pdf
├── LICENSE
├── README.md
├── train_WTNet.py
└── modles
    └── WTNet.py
/WTNet/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/networks/WTNet/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/logo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nkicsl/NKUT/HEAD/logo.jpg
--------------------------------------------------------------------------------
/networks/WTNet/model_data/readme.md:
--------------------------------------------------------------------------------
1 | put the pretrained model here
2 |
--------------------------------------------------------------------------------
/Application.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nkicsl/NKUT/HEAD/Application.pdf
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 nkicsl
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/WTNet/config.py:
--------------------------------------------------------------------------------
1 | class config:
2 | queue_length = 300
3 | samples_per_volume = 30
4 | patch_size = 64, 64, 64
5 | epoch = 5
6 | epochs_per_val = 1
7 | input_channel = 1
8 | num_classes = 4
9 | batch_size = 2
10 | learning_rate = 0.001
11 | # crop_or_pad_size = 512, 512, 32
12 | input_train_image_dir = 'D:/TOOTH/Datasets/new/patches/256_256_16/Image'
13 | input_train_label_dir = 'D:/TOOTH/Datasets/new/patches/256_256_16/label'
14 | input_val_image_dir = 'C:/Users/zhouzhenhuan/Desktop/DATA/Val/Image'
15 | input_val_label_dir = 'C:/Users/zhouzhenhuan/Desktop/DATA/Val/Label'
16 | # input_test_image_dir = ''
17 | # input_test_label_dir = ''
18 | output_logs_dir = 'E:/PycharmProjects/NKUT_Tooth/logs'
19 | devices = [0, 1]
20 | step_size = 10
21 | gamma = 0.8
22 | latest_output_dir = 'E:/PycharmProjects/NKUT_Tooth/result/latest_output_dir/latest_result.pt'
23 | latest_checkpoint_file = 'E:/PycharmProjects/NKUT_Tooth/result/latest_checkpoint_dir/latest_checkpoint.pt'
24 | best_model_path = 'E:/PycharmProjects/NKUT_Tooth/result/best_model/best_model.pt'
25 | epochs_per_checkpoint = 10
26 |
--------------------------------------------------------------------------------
/WTNet/dataloader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torch.utils.data import Dataset
4 | import SimpleITK as sitk
5 | import torchio as tio
6 | from WTNet.config import config
7 | from torchio.data import UniformSampler
8 | from torchio.transforms import (
9 | RandomFlip,
10 | RandomAffine,
11 | RandomElasticDeformation,
12 | RandomNoise,
13 | RandomMotion,
14 | RandomBiasField,
15 | RescaleIntensity,
16 | Resample,
17 | ToCanonical,
18 | ZNormalization,
19 | CropOrPad,
20 | HistogramStandardization,
21 | OneOf,
22 | Compose,
23 | OneHot,
24 | Resize
25 | )
26 |
27 |
28 | def Tooth_Dataset(images_dir, labels_dir, train=True):
29 | subjects_list = []
30 |     images_list = sorted(os.listdir(images_dir))  # sort so images and labels pair up by filename
31 |
32 | labels_binary_dir = os.path.join(labels_dir, 'binary')
33 | labels_tooth_dir = os.path.join(labels_dir, 'tooth')
34 | labels_bone_dir = os.path.join(labels_dir, 'bone')
35 |
36 |     labels_binary_list = sorted(os.listdir(labels_binary_dir))
37 |     labels_tooth_list = sorted(os.listdir(labels_tooth_dir))
38 |     labels_bone_list = sorted(os.listdir(labels_bone_dir))
39 |
40 | # queue_length = config.queue_length
41 | # samples_per_volume = config.samples_per_volume
42 | # patch_size = config.patch_size
43 |
44 | training_transform = Compose([
45 | RandomFlip(),
46 | RandomNoise(),
47 | RandomMotion(),
48 | Resize(target_shape=64)
49 | ])
50 | for image, labels_binary, labels_tooth, labels_bone in zip(images_list, labels_binary_list, labels_tooth_list, labels_bone_list):
51 | subject = tio.Subject(
52 | image=tio.ScalarImage(os.path.join(images_dir, image)),
53 | labels_binary=tio.LabelMap(os.path.join(labels_binary_dir, labels_binary)),
54 | labels_tooth=tio.LabelMap(os.path.join(labels_tooth_dir, labels_tooth)),
55 | labels_bone=tio.LabelMap(os.path.join(labels_bone_dir, labels_bone)),
56 | )
57 | subjects_list.append(subject)
58 |
59 | if train:
60 | subject_dataset = tio.SubjectsDataset(subjects_list, transform=training_transform)
61 | # queue_dataset = tio.Queue(
62 | # subject_dataset,
63 | # max_length=queue_length,
64 | # samples_per_volume=samples_per_volume,
65 | # sampler=tio.LabelSampler(patch_size=patch_size, label_name=None,
66 | # label_probabilities={0: 0, 1: 5, 2: 5, 3: 2}),
67 | # )
68 |
69 | return subject_dataset
70 |
71 | else:
72 | subject_dataset = tio.SubjectsDataset(subjects_list, transform=Resize(target_shape=64))
73 | return subject_dataset
74 |
75 |
76 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # NKUT: Dataset and Benchmark for Pediatric Mandibular Wisdom Teeth Segmentation
2 | 
3 |
4 | ## News
5 | * `March 23rd, 2024`: Our paper was accepted by the IEEE Journal of Biomedical and Health Informatics (JBHI), congratulations! 🎉🎉🎉🎉
6 | * `April 8th, 2024`: We released the NKUT dataset. Researchers can now apply to obtain it. 🎉🎉🎉🎉
7 | * `May 15th, 2024`: We released the 2D and 3D WTNet models. 🎉🎉🎉🎉
8 | * `Dec. 26th, 2024`: We released the training code. Happy New Year! 🎉🎉🎉🎉
9 |
10 | ## To Do List
11 | - [X] NKUT Dataset release
12 | - [X] WTNet 2D model code release
13 | - [X] WTNet 3D model code release
14 | - [X] Training code release
15 |
16 | ## Request for NKUT Dataset
17 | ### If you wish to use the NKUT dataset in your own research, you need to complete the following steps:
18 | * 1. Download and fill in the `Application.pdf` file from this repository. Please note that every item in the form must be filled in completely; leaving items blank may affect your access to the dataset.
19 | * 2. Send an email to `aics@nankai.edu.cn` and cc `zzh_nkcs@mail.nankai.edu.cn`. The subject of the email should be "NKUT Dataset Request"; briefly state your name, contact information, and institution or organization in the body, and attach the PDF completed in the previous step.
20 | * 3. We will review your application and notify you by email within two weeks whether it has been approved or whether further materials are required. Please plan your time accordingly.
21 | * 4. For researchers whose applications are approved, we will include a download link in our reply. You will receive about 30 cases of the NKUT dataset with their corresponding pixel-level expert annotations, as well as a doc file recording the details of each case.
22 |
23 | ## Model
24 | ### WTNet_2D Model
25 | The 2D WTNet model is in `./networks/WTNet/WTNet.py`.
26 |
27 | ### WTNet_3D Model
28 | The 3D WTNet model is in `./modles/WTNet.py`.
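Below is a minimal usage sketch (for illustration only; it assumes you run it from the repository root so that `networks` and `modles` are importable, and note that the 2D encoder downloads pretrained VGG16 weights into `./model_data` on first use). The input sizes follow the test snippets at the bottom of each model file.

```python
import torch

from networks.WTNet.WTNet import WTNet as WTNet2D  # 2D model
from modles.WTNet import WTNet as WTNet3D          # 3D model

# 2D model: a 3-channel 256x256 slice -> binary / tooth / bone logits
model_2d = WTNet2D()
binary_out, tooth_out, bone_out = model_2d(torch.rand(1, 3, 256, 256))
print(binary_out.shape, tooth_out.shape, bone_out.shape)

# 3D model: a single-channel 64x64x64 CBCT patch -> binary / tooth / bone logits
model_3d = WTNet3D()
binary_out, tooth_out, bone_out = model_3d(torch.rand(1, 1, 64, 64, 64))
print(binary_out.shape, tooth_out.shape, bone_out.shape)
```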
29 |
30 | ## Training
31 | Adjust the parameters in the final part of `train_WTNet.py` (the `parse_args` defaults and the `fold` variable) to match your setup, then run `train_WTNet.py`. The expected data layout is sketched below.
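The dataloader (`WTNet/dataloader.py`) expects each label directory to contain `binary`, `tooth` and `bone` subfolders whose files pair with the image files by name, and the default paths in `parse_args` follow the layout sketched here (`<fold_root>` is a placeholder; any volume format readable by TorchIO, such as NIfTI, should work):

```
<fold_root>/Train/Image/           # training image patches
<fold_root>/Train/Label/binary/    # binary masks
<fold_root>/Train/Label/tooth/     # tooth masks (WT, SM)
<fold_root>/Train/Label/bone/      # bone masks (AB)
<fold_root>/Val/Image/
<fold_root>/Val/Label/{binary,tooth,bone}/
```

Point `-input_image_dir_train`, `-label_dir_train`, `-input_image_dir_val` and `-label_dir_val` (or their defaults) at your own directories, then start training with `python train_WTNet.py`.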
32 |
33 | ## Citation
34 | If you use NKUT in your research, please give us a star and cite our paper:
35 |
36 | ```
37 | @ARTICLE{10485282,
38 | author={Zhou, Zhenhuan and Chen, Yuzhu and He, Along and Que, Xitao and Wang, Kai and Yao, Rui and Li, Tao},
39 | journal={IEEE Journal of Biomedical and Health Informatics},
40 | title={NKUT: Dataset and Benchmark for Pediatric Mandibular Wisdom Teeth Segmentation},
41 | year={2024},
42 | volume={28},
43 | number={6},
44 | pages={3523-3533},
45 | keywords={Teeth;Dentistry;Image segmentation;Task analysis;Bones;Annotations;Three-dimensional displays;CBCT dataset;pediatric wisdom teeth segmentation;pediatric germectomy;multi-scale feature fusion},
46 | doi={10.1109/JBHI.2024.3383222}}
47 | ```
48 |
49 | ## Acknowledgment
50 | The code can only be used for ACADEMIC PURPOSES. NO COMMERCIAL USE is allowed. Copyright © College of Computer Science, Nankai University. All rights reserved.
51 |
52 |
53 | [](https://star-history.com/#nkicsl/NKUT&Date)
54 |
55 |
--------------------------------------------------------------------------------
/networks/WTNet/vgg.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from torch.hub import load_state_dict_from_url
3 | import torch
4 |
5 |
6 | class VGG(nn.Module):
7 | def __init__(self, features, num_classes=1000):
8 | super(VGG, self).__init__()
9 | self.features = features
10 | self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
11 | self.classifier = nn.Sequential(
12 | nn.Linear(512 * 7 * 7, 4096),
13 | nn.ReLU(True),
14 | nn.Dropout(),
15 | nn.Linear(4096, 4096),
16 | nn.ReLU(True),
17 | nn.Dropout(),
18 | nn.Linear(4096, num_classes),
19 | )
20 | self._initialize_weights()
21 |
22 | def forward(self, x):
23 | # x = self.features(x)
24 | # x = self.avgpool(x)
25 | # x = torch.flatten(x, 1)
26 | # x = self.classifier(x)
27 | feat1 = self.features[:4](x)
28 | feat2 = self.features[4:9](feat1)
29 | feat3 = self.features[9:16](feat2)
30 | feat4 = self.features[16:23](feat3)
31 | feat5 = self.features[23:-1](feat4)
32 |
33 | # print(self.features[:4])
34 | # print(self.features[4:9])
35 | # print(self.features[9:16])
36 | # print(self.features[16:23])
37 | # print(self.features[23:-1])
38 |
39 | return [feat1, feat2, feat3, feat4, feat5]
40 |
41 |
42 | def _initialize_weights(self):
43 | for m in self.modules():
44 | if isinstance(m, nn.Conv2d):
45 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
46 | if m.bias is not None:
47 | nn.init.constant_(m.bias, 0)
48 | elif isinstance(m, nn.BatchNorm2d):
49 | nn.init.constant_(m.weight, 1)
50 | nn.init.constant_(m.bias, 0)
51 | elif isinstance(m, nn.Linear):
52 | nn.init.normal_(m.weight, 0, 0.01)
53 | nn.init.constant_(m.bias, 0)
54 |
55 |
56 | def make_layers(cfg, batch_norm=False, in_channels=3):
57 | layers = []
58 | for v in cfg:
59 | if v == 'M':
60 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
61 | else:
62 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
63 | if batch_norm:
64 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
65 | else:
66 | layers += [conv2d, nn.ReLU(inplace=True)]
67 | in_channels = v
68 | return nn.Sequential(*layers)
69 |
70 |
71 | # 512,512,3 -> 512,512,64 -> 256,256,64 -> 256,256,128 -> 128,128,128 -> 128,128,256 -> 64,64,256
72 | # 64,64,512 -> 32,32,512 -> 32,32,512
73 | cfgs = {
74 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']
75 | }
76 |
77 |
78 | def VGG16(pretrained, in_channels=3, **kwargs):
79 | model = VGG(make_layers(cfgs["D"], batch_norm=False, in_channels=in_channels), **kwargs)
80 | if pretrained:
81 | state_dict = load_state_dict_from_url("https://download.pytorch.org/models/vgg16-397923af.pth",
82 | model_dir="./model_data")
83 | model.load_state_dict(state_dict)
84 |
85 | del model.avgpool
86 | del model.classifier
87 | return model
88 |
89 |
90 | # model = VGG16(pretrained=True, in_channels=3)
91 | # print(model)
92 | # a = torch.rand(size=(1, 3, 256, 256))
93 | # a, b, c, d, e = model(a)
94 | # print(a.shape)
95 | # print(b.shape)
96 | # print(c.shape)
97 | # print(d.shape)
98 | # print(e.shape)
99 |
--------------------------------------------------------------------------------
/WTNet/train_and_eval.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from tqdm import tqdm
4 | import os
5 | import torchio as tio
6 |
7 |
8 | def train_one_epoch(model, diceloss, celoss, optimizer, dataloader, device, epoch, arg):
9 | model.train()
10 | dice_loss = diceloss
11 | ce_loss = celoss
12 |
13 | loss_sum = 0
14 | iteration = 0
15 |
16 | with tqdm(enumerate(dataloader), total=len(dataloader)) as loop:
17 | for i, batch in loop:
18 | data = batch['image'][tio.DATA]
19 | labels_binary = batch['labels_binary'][tio.DATA]
20 | labels_tooth = batch['labels_tooth'][tio.DATA]
21 | labels_bone = batch['labels_bone'][tio.DATA]
22 |
23 | data = data.float()
24 | labels_binary = labels_binary.long()
25 | labels_tooth = labels_tooth.long()
26 | labels_bone = labels_bone.long()
27 |
28 | data = torch.transpose(data, 2, 4)
29 | labels_binary = torch.transpose(labels_binary, 2, 4)
30 | labels_tooth = torch.transpose(labels_tooth, 2, 4)
31 | labels_bone = torch.transpose(labels_bone, 2, 4)
32 |
33 | data = data.to(device)
34 | labels_binary = labels_binary.to(device)
35 | labels_tooth = labels_tooth.to(device)
36 | labels_bone = labels_bone.to(device)
37 |
38 | Binary_out, out_tooth_last, out_bone_last = model(data)
39 | Dice_Binary = dice_loss(Binary_out, labels_binary)
40 | Dice_tooth = dice_loss(out_tooth_last, labels_tooth)
41 | Dice_bone = dice_loss(out_bone_last, labels_bone)
42 |
43 | CE_Binary = ce_loss(Binary_out, labels_binary.squeeze(1))
44 | CE_tooth = ce_loss(out_tooth_last, labels_tooth.squeeze(1))
45 | CE_bone = ce_loss(out_bone_last, labels_bone.squeeze(1))
46 |
47 | loss = Dice_Binary+CE_Binary+Dice_tooth+CE_tooth+Dice_bone+CE_bone
48 |
49 | loss_sum += loss.item()
50 | optimizer.zero_grad()
51 | loss.backward()
52 | optimizer.step()
53 | iteration += 1
54 |
55 | loop.set_description(f'Epoch {epoch}')
56 | loop.set_postfix(lr=optimizer.state_dict()['param_groups'][0]['lr'], total_loss=loss_sum / iteration)
57 |
58 | torch.save(model, arg.latest_output_dir)
59 |
60 | return loss_sum / iteration, model
61 |
62 |
63 | def eval(model_path, dataloader, device, diceloss, celoss):
64 | model = torch.load(model_path)
65 | model.to(device)
66 | model.eval()
67 | iteration = 0
68 |
69 | dice_loss = diceloss.to(device)
70 | ce_loss = celoss.to(device)
71 | val_loss_sum = 0
72 |
73 | with torch.no_grad():
74 | with tqdm(enumerate(dataloader)) as loop_val:
75 | for i, batch in loop_val:
76 | data = batch['image'][tio.DATA]
77 | labels_binary = batch['labels_binary'][tio.DATA]
78 | labels_tooth = batch['labels_tooth'][tio.DATA]
79 | labels_bone = batch['labels_bone'][tio.DATA]
80 |
81 | data = data.float()
82 | labels_binary = labels_binary.long()
83 | labels_tooth = labels_tooth.long()
84 | labels_bone = labels_bone.long()
85 |
86 | data = torch.transpose(data, 2, 4)
87 | labels_binary = torch.transpose(labels_binary, 2, 4)
88 | labels_tooth = torch.transpose(labels_tooth, 2, 4)
89 | labels_bone = torch.transpose(labels_bone, 2, 4)
90 |
91 | data = data.to(device)
92 | labels_binary = labels_binary.to(device)
93 | labels_tooth = labels_tooth.to(device)
94 | labels_bone = labels_bone.to(device)
95 |
96 | Binary_out, out_tooth_last, out_bone_last = model(data)
97 | Dice_Binary = dice_loss(Binary_out, labels_binary)
98 | Dice_tooth = dice_loss(out_tooth_last, labels_tooth)
99 | Dice_bone = dice_loss(out_bone_last, labels_bone)
100 |
101 | CE_Binary = ce_loss(Binary_out, labels_binary.squeeze(1))
102 | CE_tooth = ce_loss(out_tooth_last, labels_tooth.squeeze(1))
103 | CE_bone = ce_loss(out_bone_last, labels_bone.squeeze(1))
104 |
105 | loss = Dice_Binary + CE_Binary + Dice_tooth + CE_tooth + Dice_bone + CE_bone
106 |
107 | val_loss_sum += loss.item()
108 | iteration += 1
109 |
110 | return val_loss_sum / iteration
111 |
112 |
113 | def save_checkpoint(model, optim, scheduler, epoch, save_fre, checkpoint_dir):
114 | if epoch % save_fre == 0:
115 | torch.save(
116 | {
117 | "model": model.state_dict(),
118 | "optim": optim.state_dict(),
119 | "scheduler": scheduler.state_dict(),
120 | "epoch": epoch,
121 | },
122 | os.path.join(checkpoint_dir, 'checkpoint_epoch{}.pth'.format(epoch))
123 | )
124 |
--------------------------------------------------------------------------------
/train_WTNet.py:
--------------------------------------------------------------------------------
1 | import torchio as tio
2 | import os
3 | import argparse
4 | from shutil import copy
5 | import torch
6 | from torch.nn.modules.loss import CrossEntropyLoss
7 | from torch.optim.lr_scheduler import StepLR
8 | from torch.utils.data import DataLoader
9 | from torch.utils.tensorboard import SummaryWriter
10 | from tqdm import tqdm
11 | from modles.WTNet import WTNet
12 | from WTNet.config import config
13 | from WTNet.dataloader import Tooth_Dataset
14 | from torch.nn.functional import softmax
15 | from monai.losses.dice import DiceLoss
16 | from WTNet.train_and_eval import train_one_epoch, eval, save_checkpoint
17 |
18 | def create_model():
19 | model = WTNet()
20 | return model
21 |
22 |
23 | def main(args, fold):
24 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
25 | batch_size = args.batch_size
26 | num_workers = args.num_workers
27 | best_val_loss = 1000
28 | best_epoch = 0
29 | count = 0
30 | fold = fold
31 |
32 | train_dataset = Tooth_Dataset(images_dir=args.input_image_dir_train, labels_dir=args.label_dir_train, train=True)
33 | val_dataset = Tooth_Dataset(images_dir=args.input_image_dir_val, labels_dir=args.label_dir_val, train=False)
34 | NKUT_Train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
35 | NKUT_Val_loader = DataLoader(val_dataset, batch_size=1, shuffle=True, num_workers=0, pin_memory=True)
36 |
37 | if args.resume:
38 | model = create_model()
39 | model.to(device)
40 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
41 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer,
42 | mode='min',
43 | factor=0.8,
44 | patience=3,
45 | verbose=True,
46 | )
47 |
48 | checkpoint = torch.load('/data/dataset/zzh/ckeckpoint/fold{}/checkpoint_epoch70.pth'.format(fold))
49 | model.load_state_dict(checkpoint['model'])
50 | optimizer.load_state_dict(checkpoint['optim'])
51 | ckpt_epoch = checkpoint['epoch']
52 | scheduler.load_state_dict(checkpoint['scheduler'])
53 |
54 | else:
55 | model = create_model()
56 | model.to(device)
57 |
58 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
59 | # scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
60 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer,
61 | mode='min',
62 | factor=0.8,
63 | patience=3,
64 | verbose=True)
65 | ckpt_epoch = args.start_epoch
66 |
67 | writer = SummaryWriter(log_dir=args.log_dir)
68 | diceloss = DiceLoss(to_onehot_y=True, softmax=True)
69 | celoss = CrossEntropyLoss()
70 |
71 | for epoch in range(ckpt_epoch, args.epochs + 1):
72 | loss_result, new_model = train_one_epoch(model=model, diceloss=diceloss, celoss=celoss,
73 | optimizer=optimizer,
74 | dataloader=NKUT_Train_loader, device=device, arg=args, epoch=epoch)
75 | writer.add_scalar('Training Loss', loss_result, epoch)
76 |
77 | save_checkpoint(model=new_model, optim=optimizer, scheduler=scheduler, checkpoint_dir=args.ckpt_dir,
78 | epoch=epoch, save_fre=args.epochs_per_checkpoint)
79 |
80 | val_loss_sum = eval(model_path='./result/WTNet/fold{}/latest_output.pth'.format(fold),
81 | dataloader=NKUT_Val_loader, device=device, diceloss=diceloss,
82 | celoss=celoss)
83 |
84 | scheduler.step(loss_result)
85 | writer.add_scalar('Val Loss', val_loss_sum, epoch)
86 |
87 | if best_val_loss > val_loss_sum:
88 | copy(src=args.latest_output_dir, dst=args.best_model_path)
89 | best_val_loss = val_loss_sum
90 | best_epoch = epoch
91 | count = 0
92 | else:
93 | count += 1
94 |
95 | print('The total val loss is {}, best is {}, in Epoch {}'.format(val_loss_sum, best_val_loss, best_epoch))
96 | model = new_model
97 |
98 | if count == args.early_stop:
99 | print("early stop")
100 | break
101 |
102 | def parse_args(fold):
103 |
104 | parser = argparse.ArgumentParser(description='NKUT Wisdom Tooth Segmentation')
105 | parser.add_argument('-epochs', type=int, default=500, help='Numbers of epochs to train')
106 | parser.add_argument('-batch_size', type=int, default=3, help='batch size')
107 | parser.add_argument('-input_image_dir_train', type=str, default='/data/dataset/zzh/NKUT/patch/64_64_64/fold{}/Train/Image'.format(fold))
108 | parser.add_argument('-label_dir_train', type=str, default='/data/dataset/zzh/NKUT/patch/64_64_64/fold{}/Train/Label'.format(fold))
109 | parser.add_argument('-input_image_dir_val', type=str, default='/data/dataset/zzh/NKUT/patch/64_64_64/fold4/Val/Image')
110 | parser.add_argument('-label_dir_val', type=str, default='/data/dataset/zzh/NKUT/patch/64_64_64/fold4/Val/Label')
111 |
113 | parser.add_argument('-log_dir', '-output_logs_dir', type=str, default='./logs', help='Where to save the train logs')
114 |     parser.add_argument('-lr', '-learning_rate', type=float, default=0.00001, help='learning rate')
115 | parser.add_argument('-latest_output_dir', type=str, default='./result/WTNet/fold{}/latest_output.pth'.format(fold),
116 | help='where to store the latest model')
117 | parser.add_argument('-best_model_path', type=str, default='./result/WTNet/fold{}/best_result.pth'.format(fold),
118 | help='where to save the best val model')
119 | parser.add_argument('-ckpt_dir', type=str, default='/data/dataset/zzh/ckeckpoint/fold{}'.format(fold),
120 | help='where to save the latest checkpoint')
121 | parser.add_argument('-epochs_per_checkpoint', type=int, default=5, help='epoch to store a checkpoint')
122 | parser.add_argument('-resume', action='store_true', help='continue training')
123 | parser.add_argument('-early_stop', type=int, default=200, help='early stop')
124 | parser.add_argument('-num_workers', type=int, default=8, help='num_workers')
125 |     parser.add_argument('-start_epoch', type=int, default=1, help='start epoch')
126 |
127 | args = parser.parse_args()
128 | return args
129 |
130 |
131 | if __name__ == '__main__':
132 | fold = 3
133 | args = parse_args(fold)
134 | main(args, fold)
135 |
--------------------------------------------------------------------------------
/networks/WTNet/WTNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 | from networks.WTNet.vgg import *
8 |
9 |
10 | class Decoder_block(nn.Module):
11 | def __init__(self, in_channel, out_channel, attention=False):
12 | super(Decoder_block, self).__init__()
13 | self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1)
14 | self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1)
15 | self.up = nn.UpsamplingBilinear2d(scale_factor=2)
16 | self.relu = nn.ReLU(inplace=True)
17 | self.eca = ECABlock(channels=out_channel)
18 | self.Spatial = SpatialAttention()
19 | self.attention = attention
20 |
21 | def forward(self, inputs1, inputs2):
22 | if self.attention:
23 | inputs1 = self.eca(inputs1)
24 | Spatial_map = self.Spatial(inputs1)
25 | inputs1 = inputs1 * Spatial_map
26 | outputs = torch.cat([inputs1, self.up(inputs2)], 1)
27 | else:
28 | outputs = torch.cat([inputs1, self.up(inputs2)], 1)
29 | outputs = self.conv1(outputs)
30 | outputs = self.relu(outputs)
31 | outputs = self.conv2(outputs)
32 | outputs = self.relu(outputs)
33 | return outputs
34 |
35 |
36 | class Encoder(nn.Module):
37 | """
38 | for input size of (B, 3, 256, 256)
39 | output size is: feat1, feat2, feat3, feat4, feat5
40 |
41 | torch.Size([1, 64, 256, 256])
42 | torch.Size([1, 128, 128, 128])
43 | torch.Size([1, 256, 64, 64])
44 | torch.Size([1, 512, 32, 32])
45 | torch.Size([1, 512, 16, 16])
46 | """
47 |
48 | def __init__(self, in_channel):
49 | super(Encoder, self).__init__()
50 | self.backbone = VGG16(pretrained=True, in_channels=in_channel)
51 |
52 | def forward(self, x):
53 | feat1, feat2, feat3, feat4, feat5 = self.backbone(x)
54 |
55 | return feat1, feat2, feat3, feat4, feat5
56 |
57 |
58 | class ECABlock(nn.Module):
59 | def __init__(self, channels, gamma=2, bias=1):
60 | super(ECABlock, self).__init__()
61 |
62 |         # pick an adaptive 1D kernel size for the channel-attention conv below
63 | kernel_size = int(abs((math.log(channels, 2) + bias) / gamma))
64 | kernel_size = kernel_size if kernel_size % 2 else kernel_size + 1
65 |         # global average pooling
66 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
67 |         # learn cross-channel interactions with a 1D convolution
68 | self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False)
69 |         # activation function
70 | self.sigmoid = nn.Sigmoid()
71 |
72 | def forward(self, x):
73 |         # first, global average pooling over the spatial dims: [b,c,h,w] ==> [b,c,1,1]
74 | v = self.avg_pool(x)
75 |         # then, model cross-channel interactions with the 1D conv, using the adaptive kernel size chosen above
76 | v = self.conv(v.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
77 |         # finally, apply the sigmoid activation
78 | v = self.sigmoid(v)
79 | return x * v
80 |
81 |
82 | class Tooth_multi_scale(nn.Module):
83 | """
84 | all feature map are sampling to 128*128, then concat in channel dimension
85 | finally, execute channel attention to all channels
86 | """
87 |
88 | def __init__(self):
89 | super(Tooth_multi_scale, self).__init__()
90 | self.input1_down = nn.MaxPool2d(kernel_size=2, stride=2)
91 | self.input2_out = nn.Identity()
92 | self.input3_up = nn.UpsamplingBilinear2d(scale_factor=2)
93 | self.input4_up = nn.UpsamplingBilinear2d(scale_factor=4)
94 | self.channel_atten = ECABlock(channels=960)
95 |
96 | def forward(self, input1, input2, input3, input4):
97 | out1 = self.input1_down(input1)
98 | out2 = self.input2_out(input2)
99 | out3 = self.input3_up(input3)
100 | out4 = self.input4_up(input4)
101 | out = torch.cat([out1, out2, out3, out4], dim=1)
102 | channel_atten_out = self.channel_atten(out)
103 | return channel_atten_out
104 |
105 |
106 | class Bone_multi_scale(nn.Module):
107 | """
108 | all feature map are sampling to 64*64, then concat in channel dimension
109 | finally, execute channel attention to all channels
110 | """
111 |
112 | def __init__(self):
113 | super(Bone_multi_scale, self).__init__()
114 | self.input1_down = nn.MaxPool2d(kernel_size=4, stride=4)
115 | self.input2_down = nn.MaxPool2d(kernel_size=2, stride=2)
116 | self.input3_out = nn.Identity()
117 | self.input4_up = nn.UpsamplingBilinear2d(scale_factor=2)
118 | self.channel_atten = ECABlock(channels=960)
119 |
120 | def forward(self, input1, input2, input3, input4):
121 | out1 = self.input1_down(input1)
122 | out2 = self.input2_down(input2)
123 | out3 = self.input3_out(input3)
124 | out4 = self.input4_up(input4)
125 | out = torch.cat([out1, out2, out3, out4], dim=1)
126 | channel_atten_out = self.channel_atten(out)
127 | return channel_atten_out
128 |
129 |
130 | class SpatialAttention(nn.Module): # Spatial Attention Module
131 | def __init__(self):
132 | super(SpatialAttention, self).__init__()
133 | self.conv1 = nn.Conv2d(2, 1, kernel_size=7, padding=3, bias=False)
134 | self.sigmoid = nn.Sigmoid()
135 |
136 | def forward(self, x):
137 | avg_out = torch.mean(x, dim=1, keepdim=True)
138 | max_out, _ = torch.max(x, dim=1, keepdim=True)
139 | out = torch.cat([avg_out, max_out], dim=1)
140 | out = self.conv1(out)
141 | out = self.sigmoid(out)
142 | return out
143 |
144 |
145 | class multi_scale_feature(nn.Module):
146 | def __init__(self, zoom=None, input_feature_size=None, in_multi_size=None, channel=None):
147 | super(multi_scale_feature, self).__init__()
148 | self.input_feature_size = input_feature_size
149 | self.in_multi_size = in_multi_size
150 | self.zoom = zoom
151 | self.channel = channel
152 | self.conv = nn.Conv2d(in_channels=960, out_channels=channel, kernel_size=1, stride=1)
153 | self.atten = SpatialAttention()
154 |
155 | if self.zoom == 'UP':
156 | self.k = input_feature_size / in_multi_size
157 | self.up = nn.UpsamplingBilinear2d(scale_factor=self.k)
158 | elif self.zoom == 'DOWN':
159 | self.avg = nn.AdaptiveAvgPool2d(self.input_feature_size)
160 | elif self.zoom == "None":
161 | self.none = nn.Identity()
162 |
163 | def forward(self, input_feature, in_multi):
164 | if self.zoom == 'UP':
165 | out_up = self.up(in_multi)
166 | out_adjust_channel = self.conv(out_up)
167 | x_add = torch.add(out_adjust_channel, input_feature)
168 | spatial_attention_map = self.atten(x_add)
169 | out = torch.mul(spatial_attention_map, input_feature)
170 | return out
171 |
172 | if self.zoom == 'DOWN':
173 | out_down = self.avg(in_multi)
174 | out_adjust_channel = self.conv(out_down)
175 | x_add = torch.add(out_adjust_channel, input_feature)
176 | spatial_attention_map = self.atten(x_add)
177 | out = torch.mul(spatial_attention_map, input_feature)
178 | return out
179 | if self.zoom == 'None':
180 | out_none = self.none(in_multi)
181 | out_adjust_channel = self.conv(out_none)
182 | x_add = torch.add(out_adjust_channel, input_feature)
183 | spatial_attention_map = self.atten(x_add)
184 | out = torch.mul(spatial_attention_map, input_feature)
185 | return out
186 |
187 |
188 | class Binary_mask(nn.Module):
189 | def __init__(self, num_classes=2):
190 | super(Binary_mask, self).__init__()
191 | self.num_classes = num_classes
192 | self.encoder = Encoder(in_channel=3)
193 | self.decoder4 = Decoder_block(in_channel=1024, out_channel=512)
194 | self.decoder3 = Decoder_block(in_channel=768, out_channel=256)
195 | self.decoder2 = Decoder_block(in_channel=384, out_channel=128)
196 | self.decoder1 = Decoder_block(in_channel=192, out_channel=64)
197 | self.final = nn.Conv2d(64, self.num_classes, 1)
198 |
199 | def forward(self, x):
200 | feat1, feat2, feat3, feat4, feat5 = self.encoder(x)
201 | out4 = self.decoder4(feat4, feat5)
202 | out3 = self.decoder3(feat3, out4)
203 | out2 = self.decoder2(feat2, out3)
204 | out1 = self.decoder1(feat1, out2)
205 | out_last = self.final(out1)
206 | return out_last
207 |
208 | class input_enhancement(nn.Module):
209 | def __init__(self):
210 | super(input_enhancement, self).__init__()
211 | self.conv = nn.Conv2d(9, 3, kernel_size=1, stride=1, padding=0)
212 | self.relu = nn.ReLU(inplace=True)
213 |
214 | def forward(self, origin, binary_mask):
215 | x1 = torch.mul(origin, binary_mask)
216 | out = torch.add(x1, origin)
217 | out = torch.cat([x1, origin, out], dim=1)
218 | out = self.conv(out)
219 | # out = self.relu(out)
220 | return out
221 |
222 |
223 | class Tooth_bone_separation(nn.Module):
224 | def __init__(self):
225 | super(Tooth_bone_separation, self).__init__()
226 | self.encoder = Encoder(in_channel=3)
227 |
228 | self.Tdecoder = nn.ModuleList(
229 | [Decoder_block(in_channel=1024, out_channel=512),
230 | Decoder_block(in_channel=768, out_channel=256),
231 | Decoder_block(in_channel=384, out_channel=128),
232 | Decoder_block(in_channel=192, out_channel=64)]
233 | )
234 |
235 | self.Bdecoder = nn.ModuleList(
236 | [Decoder_block(in_channel=1024, out_channel=512),
237 | Decoder_block(in_channel=768, out_channel=256),
238 | Decoder_block(in_channel=384, out_channel=128),
239 | Decoder_block(in_channel=192, out_channel=64)]
240 | )
241 |
242 | self.Tmulti = nn.ModuleList(
243 | [
244 | multi_scale_feature(zoom='UP', input_feature_size=256, in_multi_size=128, channel=64),
245 | multi_scale_feature(zoom='None', input_feature_size=128, in_multi_size=128, channel=128),
246 | multi_scale_feature(zoom='DOWN', input_feature_size=64, in_multi_size=128, channel=256),
247 | multi_scale_feature(zoom='DOWN', input_feature_size=32, in_multi_size=128, channel=512)
248 | ]
249 | )
250 |
251 | self.Bmulti = nn.ModuleList(
252 | [
253 | multi_scale_feature(zoom='UP', input_feature_size=256, in_multi_size=64, channel=64),
254 | multi_scale_feature(zoom='UP', input_feature_size=128, in_multi_size=64, channel=128),
255 | multi_scale_feature(zoom='None', input_feature_size=64, in_multi_size=64, channel=256),
256 | multi_scale_feature(zoom='DOWN', input_feature_size=32, in_multi_size=64, channel=512)
257 | ]
258 | )
259 |
260 | self.Tooth_multi_scale = Tooth_multi_scale()
261 | self.Bone_multi_scale = Bone_multi_scale()
262 | self.Tfinal = nn.Conv2d(64, 3, 1) # background, WT, SM
263 | self.Bfinal = nn.Conv2d(64, 2, 1) # background, AB
264 |
265 | def forward(self, x):
266 | feat1, feat2, feat3, feat4, feat5 = self.encoder(x)
267 |
268 | Tooth_multi = self.Tooth_multi_scale(feat1, feat2, feat3, feat4) # (B, 960, 128, 128)
269 | Tooth_feat1 = self.Tmulti[0](input_feature=feat1, in_multi=Tooth_multi)
270 | Tooth_feat2 = self.Tmulti[1](input_feature=feat2, in_multi=Tooth_multi)
271 | Tooth_feat3 = self.Tmulti[2](input_feature=feat3, in_multi=Tooth_multi)
272 | Tooth_feat4 = self.Tmulti[3](input_feature=feat4, in_multi=Tooth_multi)
273 | Tout4 = self.Tdecoder[0](Tooth_feat4, feat5)
274 | Tout3 = self.Tdecoder[1](Tooth_feat3, Tout4)
275 | Tout2 = self.Tdecoder[2](Tooth_feat2, Tout3)
276 | Tout1 = self.Tdecoder[3](Tooth_feat1, Tout2)
277 |
278 | out_tooth_last = self.Tfinal(Tout1)
279 |
280 | Bone_multi = self.Bone_multi_scale(feat1, feat2, feat3, feat4) # (B, 960, 64, 64)
281 | Bone_feat1 = self.Bmulti[0](input_feature=feat1, in_multi=Bone_multi)
282 | Bone_feat2 = self.Bmulti[1](input_feature=feat2, in_multi=Bone_multi)
283 | Bone_feat3 = self.Bmulti[2](input_feature=feat3, in_multi=Bone_multi)
284 | Bone_feat4 = self.Bmulti[3](input_feature=feat4, in_multi=Bone_multi)
285 | Bout4 = self.Bdecoder[0](Bone_feat4, feat5)
286 | Bout3 = self.Bdecoder[1](Bone_feat3, Bout4)
287 | Bout2 = self.Bdecoder[2](Bone_feat2, Bout3)
288 | Bout1 = self.Bdecoder[3](Bone_feat1, Bout2)
289 |
290 | out_bone_last = self.Bfinal(Bout1)
291 |
292 | return out_tooth_last, out_bone_last
293 |
294 |
295 | class WTNet(nn.Module):
296 | def __init__(self):
297 | super(WTNet, self).__init__()
298 | self.Binary = Binary_mask()
299 | self.input_enhancement = input_enhancement()
300 | self.TBS = Tooth_bone_separation()
301 |
302 | def forward(self, x):
303 | Binary_out = self.Binary(x)
304 | Binary_map = torch.nn.functional.softmax(Binary_out, dim=1)
305 | Binary_map = torch.argmax(Binary_map, dim=1, keepdim=True)
306 | enhancement = self.input_enhancement(x, Binary_map)
307 | out_tooth_last, out_bone_last = self.TBS(enhancement)
308 | return Binary_out, out_tooth_last, out_bone_last
309 |
310 |
311 | if __name__ == '__main__':
312 | model = WTNet()
313 | a = torch.rand(size=(1, 3, 256, 256))
314 | b, c, d = model(a)
315 | print(b.shape, c.shape, d.shape)
316 |
317 |
--------------------------------------------------------------------------------
/modles/WTNet.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | import torch
3 | import torch.nn as nn
4 | import math
5 |
6 |
7 | class Decoder_block(nn.Module):
8 | def __init__(self, num_classes=2, init_features=64):
9 | super(Decoder_block, self).__init__()
10 | features = init_features
11 | out_channels = num_classes
12 |
13 | self.upconv4 = nn.ConvTranspose3d(
14 | features * 8, features * 8, kernel_size=2, stride=2
15 | )
16 | self.decoder4 = Decoder_block._block((features * 8) * 2, features * 8, name="dec4")
17 | self.upconv3 = nn.ConvTranspose3d(
18 | features * 8, features * 4, kernel_size=2, stride=2
19 | )
20 | self.decoder3 = Decoder_block._block((features * 4) * 2, features * 4, name="dec3")
21 | self.upconv2 = nn.ConvTranspose3d(
22 | features * 4, features * 2, kernel_size=2, stride=2
23 | )
24 | self.decoder2 = Decoder_block._block((features * 2) * 2, features * 2, name="dec2")
25 | self.upconv1 = nn.ConvTranspose3d(
26 | features * 2, features, kernel_size=2, stride=2
27 | )
28 | self.decoder1 = Decoder_block._block(features * 2, features, name="dec1")
29 |
30 | self.conv = nn.Conv3d(
31 | in_channels=features, out_channels=out_channels, kernel_size=1
32 | )
33 |
34 | def forward(self, fea1, fea2, fea3, fea4, fea5):
35 | dec4 = self.upconv4(fea5)
36 | dec4 = torch.cat((dec4, fea4), dim=1)
37 | dec4 = self.decoder4(dec4)
38 | dec3 = self.upconv3(dec4)
39 | dec3 = torch.cat((dec3, fea3), dim=1)
40 | dec3 = self.decoder3(dec3)
41 | dec2 = self.upconv2(dec3)
42 | dec2 = torch.cat((dec2, fea2), dim=1)
43 | dec2 = self.decoder2(dec2)
44 | dec1 = self.upconv1(dec2)
45 | dec1 = torch.cat((dec1, fea1), dim=1)
46 | dec1 = self.decoder1(dec1)
47 | outputs = self.conv(dec1)
48 | return outputs
49 |
50 | @staticmethod
51 | def _block(in_channels, features, name):
52 | return nn.Sequential(
53 |             OrderedDict(  # ordered dict of named layers
54 | [
55 | (
56 | name + "conv1",
57 | nn.Conv3d(
58 | in_channels=in_channels,
59 | out_channels=features,
60 | kernel_size=3,
61 | padding=1,
62 | bias=True,
63 | ),
64 | ),
65 | (name + "norm1", nn.BatchNorm3d(num_features=features)),
66 | (name + "relu1", nn.ReLU(inplace=True)),
67 | (
68 | name + "conv2",
69 | nn.Conv3d(
70 | in_channels=features,
71 | out_channels=features,
72 | kernel_size=3,
73 | padding=1,
74 | bias=True,
75 | ),
76 | ),
77 | (name + "norm2", nn.BatchNorm3d(num_features=features)),
78 | (name + "relu2", nn.ReLU(inplace=True)),
79 | ]
80 | )
81 | )
82 |
83 |
84 | class Encoder(nn.Module):
85 | """
86 | for input size of (B, 1, 64, 64, 64)
87 | output size is: feat1, feat2, feat3, feat4, feat5
88 |
89 |     torch.Size([1, 64, 64, 64, 64])
90 |     torch.Size([1, 128, 32, 32, 32])
91 |     torch.Size([1, 256, 16, 16, 16])
92 |     torch.Size([1, 512, 8, 8, 8])
93 |     torch.Size([1, 512, 4, 4, 4])
94 | """
95 |
96 | def __init__(self, in_channels=1, init_features=64):
97 | super(Encoder, self).__init__()
98 |
99 | features = init_features
100 | self.encoder1 = Encoder._block(in_channels, features, name="enc1")
101 | self.pool1 = nn.MaxPool3d(kernel_size=2, stride=2)
102 | self.encoder2 = Encoder._block(features, features * 2, name="enc2")
103 | self.pool2 = nn.MaxPool3d(kernel_size=2, stride=2)
104 | self.encoder3 = Encoder._block(features * 2, features * 4, name="enc3")
105 | self.pool3 = nn.MaxPool3d(kernel_size=2, stride=2)
106 | self.encoder4 = Encoder._block(features * 4, features * 8, name="enc4")
107 | self.pool4 = nn.MaxPool3d(kernel_size=2, stride=2)
108 |
109 | self.bottleneck = Encoder._block(features * 8, features * 8, name="bottleneck")
110 |
111 | def forward(self, x):
112 | feat1 = self.encoder1(x)
113 | feat2 = self.encoder2(self.pool1(feat1))
114 | feat3 = self.encoder3(self.pool2(feat2))
115 | feat4 = self.encoder4(self.pool3(feat3))
116 | feat5 = self.bottleneck(self.pool4(feat4))
117 |
118 | return feat1, feat2, feat3, feat4, feat5
119 |
120 | @staticmethod
121 | def _block(in_channels, features, name):
122 | return nn.Sequential(
123 |             OrderedDict(  # ordered dict of named layers
124 | [
125 | (
126 | name + "conv1",
127 | nn.Conv3d(
128 | in_channels=in_channels,
129 | out_channels=features,
130 | kernel_size=3,
131 | padding=1,
132 | bias=True,
133 | ),
134 | ),
135 | (name + "norm1", nn.BatchNorm3d(num_features=features)),
136 | (name + "relu1", nn.ReLU(inplace=True)),
137 | (
138 | name + "conv2",
139 | nn.Conv3d(
140 | in_channels=features,
141 | out_channels=features,
142 | kernel_size=3,
143 | padding=1,
144 | bias=True,
145 | ),
146 | ),
147 | (name + "norm2", nn.BatchNorm3d(num_features=features)),
148 | (name + "relu2", nn.ReLU(inplace=True)),
149 | ]
150 | )
151 | )
152 |
153 |
154 | class ECABlock(nn.Module):
155 | def __init__(self, channels, gamma=2, bias=1):
156 | super(ECABlock, self).__init__()
157 |
158 |         # pick an adaptive 1D kernel size for the channel-attention conv below
159 | kernel_size = int(abs((math.log(channels, 2) + bias) / gamma))
160 | kernel_size = kernel_size if kernel_size % 2 else kernel_size + 1
161 |         # global average pooling
162 | self.avg_pool = nn.AdaptiveAvgPool3d(1)
163 |         # learn cross-channel interactions with a 1D convolution
164 | self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False)
165 |         # activation function
166 | self.sigmoid = nn.Sigmoid()
167 |
168 | def forward(self, x):
169 |         # first, global average pooling over the spatial dims: [b,c,h,w,d] ==> [b,c,1,1,1]
170 | v = self.avg_pool(x)
171 |         # then, model cross-channel interactions with the 1D conv, using the adaptive kernel size chosen above
172 | v = self.conv(v.squeeze(-1).squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1).unsqueeze(-1)
173 |         # finally, apply the sigmoid activation
174 | v = self.sigmoid(v)
175 | return x * v
176 |
177 |
178 | class Tooth_multi_scale(nn.Module):
179 | """
180 | all feature map are sampling to 32*32*32, then concat in channel dimension
181 | finally, execute channel attention to all channels
182 | """
183 |
184 | def __init__(self):
185 | super(Tooth_multi_scale, self).__init__()
186 | self.input1_down = nn.MaxPool3d(kernel_size=2, stride=2)
187 | self.input2_out = nn.Identity()
188 | self.input3_up = nn.Upsample(scale_factor=2)
189 | self.input4_up = nn.Upsample(scale_factor=4)
190 | self.channel_atten = ECABlock(channels=960)
191 |
192 | def forward(self, input1, input2, input3, input4):
193 | out1 = self.input1_down(input1)
194 | out2 = self.input2_out(input2)
195 | out3 = self.input3_up(input3)
196 | out4 = self.input4_up(input4)
197 | out = torch.cat([out1, out2, out3, out4], dim=1)
198 | channel_atten_out = self.channel_atten(out)
199 | return channel_atten_out
200 |
201 |
202 | class Bone_multi_scale(nn.Module):
203 | """
204 | all feature map are sampling to 16*16*16, then concat in channel dimension
205 | finally, execute channel attention to all channels
206 | """
207 |
208 | def __init__(self):
209 | super(Bone_multi_scale, self).__init__()
210 | self.input1_down = nn.MaxPool3d(kernel_size=4, stride=4)
211 | self.input2_down = nn.MaxPool3d(kernel_size=2, stride=2)
212 | self.input3_out = nn.Identity()
213 | self.input4_up = nn.Upsample(scale_factor=2)
214 | self.channel_atten = ECABlock(channels=960)
215 |
216 | def forward(self, input1, input2, input3, input4):
217 | out1 = self.input1_down(input1)
218 | out2 = self.input2_down(input2)
219 | out3 = self.input3_out(input3)
220 | out4 = self.input4_up(input4)
221 | out = torch.cat([out1, out2, out3, out4], dim=1)
222 | channel_atten_out = self.channel_atten(out)
223 | return channel_atten_out
224 |
225 |
226 | class SpatialAttention(nn.Module): # Spatial Attention Module
227 | def __init__(self):
228 | super(SpatialAttention, self).__init__()
229 | self.conv1 = nn.Conv3d(2, 1, kernel_size=7, padding=3, bias=False)
230 | self.sigmoid = nn.Sigmoid()
231 |
232 | def forward(self, x):
233 | avg_out = torch.mean(x, dim=1, keepdim=True)
234 | max_out, _ = torch.max(x, dim=1, keepdim=True)
235 | out = torch.cat([avg_out, max_out], dim=1)
236 | out = self.conv1(out)
237 | out = self.sigmoid(out)
238 | return out
239 |
240 |
241 | class multi_scale_feature(nn.Module):
242 | def __init__(self, zoom=None, input_feature_size=None, in_multi_size=None, channel=None):
243 | super(multi_scale_feature, self).__init__()
244 | self.input_feature_size = input_feature_size
245 | self.in_multi_size = in_multi_size
246 | self.zoom = zoom
247 | self.channel = channel
248 | self.conv = nn.Conv3d(in_channels=960, out_channels=channel, kernel_size=1, stride=1)
249 | self.atten = SpatialAttention()
250 |
251 | if self.zoom == 'UP':
252 | self.k = input_feature_size / in_multi_size
253 | self.up = nn.Upsample(scale_factor=self.k)
254 | elif self.zoom == 'DOWN':
255 | self.avg = nn.AdaptiveAvgPool3d(self.input_feature_size)
256 | elif self.zoom == "None":
257 | self.none = nn.Identity()
258 |
259 | def forward(self, input_feature, in_multi):
260 | if self.zoom == 'UP':
261 | out_up = self.up(in_multi)
262 | out_adjust_channel = self.conv(out_up)
263 | x_add = torch.add(out_adjust_channel, input_feature)
264 | spatial_attention_map = self.atten(x_add)
265 | out = torch.mul(spatial_attention_map, input_feature)
266 | return out
267 |
268 | if self.zoom == 'DOWN':
269 | out_down = self.avg(in_multi)
270 | out_adjust_channel = self.conv(out_down)
271 | x_add = torch.add(out_adjust_channel, input_feature)
272 | spatial_attention_map = self.atten(x_add)
273 | out = torch.mul(spatial_attention_map, input_feature)
274 | return out
275 | if self.zoom == 'None':
276 | out_none = self.none(in_multi)
277 | out_adjust_channel = self.conv(out_none)
278 | x_add = torch.add(out_adjust_channel, input_feature)
279 | spatial_attention_map = self.atten(x_add)
280 | out = torch.mul(spatial_attention_map, input_feature)
281 | return out
282 |
283 |
284 | class Binary_mask(nn.Module):
285 | def __init__(self, num_classes=2):
286 | super(Binary_mask, self).__init__()
287 | self.num_classes = num_classes
288 | self.encoder = Encoder(in_channels=1, init_features=64)
289 | self.decoder = Decoder_block(num_classes=2, init_features=64)
290 |
291 | def forward(self, x):
292 | feat1, feat2, feat3, feat4, feat5 = self.encoder(x)
293 | out = self.decoder(feat1, feat2, feat3, feat4, feat5)
294 | return out
295 |
296 |
297 | class input_enhancement(nn.Module):
298 | def __init__(self):
299 | super(input_enhancement, self).__init__()
300 | self.conv = nn.Conv3d(3, 1, kernel_size=1, stride=1, padding=0)
301 | self.relu = nn.ReLU(inplace=True)
302 |
303 | def forward(self, origin, binary_mask):
304 | x1 = torch.mul(origin, binary_mask)
305 | out = torch.add(x1, origin)
306 | out = torch.cat([x1, origin, out], dim=1)
307 | out = self.conv(out)
308 | # out = self.relu(out)
309 | return out
310 |
311 |
312 | class Tooth_bone_separation(nn.Module):
313 | def __init__(self):
314 | super(Tooth_bone_separation, self).__init__()
315 | self.encoder = Encoder()
316 |
317 | self.Tdecoder = Decoder_block()
318 | self.Bdecoder = Decoder_block()
319 |
320 | self.Tmulti = nn.ModuleList(
321 | [
322 | multi_scale_feature(zoom='UP', input_feature_size=64, in_multi_size=32, channel=64),
323 | multi_scale_feature(zoom='None', input_feature_size=32, in_multi_size=32, channel=128),
324 | multi_scale_feature(zoom='DOWN', input_feature_size=16, in_multi_size=32, channel=256),
325 | multi_scale_feature(zoom='DOWN', input_feature_size=8, in_multi_size=32, channel=512)
326 | ]
327 | )
328 |
329 | self.Bmulti = nn.ModuleList(
330 | [
331 | multi_scale_feature(zoom='UP', input_feature_size=64, in_multi_size=16, channel=64),
332 | multi_scale_feature(zoom='UP', input_feature_size=32, in_multi_size=16, channel=128),
333 | multi_scale_feature(zoom='None', input_feature_size=16, in_multi_size=16, channel=256),
334 | multi_scale_feature(zoom='DOWN', input_feature_size=8, in_multi_size=16, channel=512)
335 | ]
336 | )
337 |
338 | self.Tooth_multi_scale = Tooth_multi_scale()
339 | self.Bone_multi_scale = Bone_multi_scale()
340 | self.Tfinal = nn.Conv3d(2, 3, 1) # background, WT, SM
341 | self.Bfinal = nn.Conv3d(2, 2, 1) # background, AB
342 |
343 | def forward(self, x):
344 | feat1, feat2, feat3, feat4, feat5 = self.encoder(x)
345 | Tooth_multi = self.Tooth_multi_scale(feat1, feat2, feat3, feat4) # (B, 960, 32, 32, 32)
346 | Tooth_feat1 = self.Tmulti[0](input_feature=feat1, in_multi=Tooth_multi)
347 | Tooth_feat2 = self.Tmulti[1](input_feature=feat2, in_multi=Tooth_multi)
348 | Tooth_feat3 = self.Tmulti[2](input_feature=feat3, in_multi=Tooth_multi)
349 | Tooth_feat4 = self.Tmulti[3](input_feature=feat4, in_multi=Tooth_multi)
350 |
351 | Tout1 = self.Tdecoder(Tooth_feat1, Tooth_feat2, Tooth_feat3, Tooth_feat4, feat5)
352 | out_tooth_last = self.Tfinal(Tout1)
353 |
354 | Bone_multi = self.Bone_multi_scale(feat1, feat2, feat3, feat4) # (B, 960, 16, 16, 16)
355 | Bone_feat1 = self.Bmulti[0](input_feature=feat1, in_multi=Bone_multi)
356 | Bone_feat2 = self.Bmulti[1](input_feature=feat2, in_multi=Bone_multi)
357 | Bone_feat3 = self.Bmulti[2](input_feature=feat3, in_multi=Bone_multi)
358 | Bone_feat4 = self.Bmulti[3](input_feature=feat4, in_multi=Bone_multi)
359 |
360 | Bout1 = self.Bdecoder(Bone_feat1, Bone_feat2, Bone_feat3, Bone_feat4, feat5)
361 | out_bone_last = self.Bfinal(Bout1)
362 |
363 | return out_tooth_last, out_bone_last
364 |
365 |
366 | class WTNet(nn.Module):
367 | def __init__(self):
368 | super(WTNet, self).__init__()
369 | self.Binary = Binary_mask()
370 | self.input_enhancement = input_enhancement()
371 | self.TBS = Tooth_bone_separation()
372 |
373 | def forward(self, x):
374 | Binary_out = self.Binary(x)
375 | Binary_map = torch.nn.functional.softmax(Binary_out, dim=1)
376 | Binary_map = torch.argmax(Binary_map, dim=1, keepdim=True)
377 | enhancement = self.input_enhancement(x, Binary_map)
378 | out_tooth_last, out_bone_last = self.TBS(enhancement)
379 | return Binary_out, out_tooth_last, out_bone_last
380 |
381 |
382 | if __name__ == '__main__':
383 | a = torch.rand(size=(2, 1, 64, 64, 64))
384 | model = WTNet()
385 | Binary_out, out_tooth_last, out_bone_last = model(a)
386 | print(Binary_out.shape)
387 | print(out_tooth_last.shape)
388 | print(out_bone_last.shape)
389 | # print(feat4.shape)
390 | # print(feat5.shape)
391 | #
392 | # eca = SpatialAttention()
393 | # out = eca(feat1)
394 | # print(out.shape)
395 |
--------------------------------------------------------------------------------