├── WTNet ├── __init__.py ├── __pycache__ │ ├── config.cpython-38.pyc │ ├── __init__.cpython-38.pyc │ ├── dataloader.cpython-38.pyc │ └── train_and_eval.cpython-38.pyc ├── config.py ├── dataloader.py └── train_and_eval.py ├── networks └── WTNet │ ├── __init__.py │ ├── model_data │ └── readme.md │ ├── __pycache__ │ ├── vgg.cpython-38.pyc │ ├── WTNet.cpython-38.pyc │ └── __init__.cpython-38.pyc │ ├── vgg.py │ └── WTNet.py ├── logo.jpg ├── Application.pdf ├── LICENSE ├── README.md ├── train_WTNet.py └── modles └── WTNet.py /WTNet/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /networks/WTNet/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /logo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nkicsl/NKUT/HEAD/logo.jpg -------------------------------------------------------------------------------- /networks/WTNet/model_data/readme.md: -------------------------------------------------------------------------------- 1 | put the pretrained model here 2 | -------------------------------------------------------------------------------- /Application.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nkicsl/NKUT/HEAD/Application.pdf -------------------------------------------------------------------------------- /WTNet/__pycache__/config.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nkicsl/NKUT/HEAD/WTNet/__pycache__/config.cpython-38.pyc -------------------------------------------------------------------------------- /WTNet/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nkicsl/NKUT/HEAD/WTNet/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /WTNet/__pycache__/dataloader.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nkicsl/NKUT/HEAD/WTNet/__pycache__/dataloader.cpython-38.pyc -------------------------------------------------------------------------------- /networks/WTNet/__pycache__/vgg.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nkicsl/NKUT/HEAD/networks/WTNet/__pycache__/vgg.cpython-38.pyc -------------------------------------------------------------------------------- /WTNet/__pycache__/train_and_eval.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nkicsl/NKUT/HEAD/WTNet/__pycache__/train_and_eval.cpython-38.pyc -------------------------------------------------------------------------------- /networks/WTNet/__pycache__/WTNet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nkicsl/NKUT/HEAD/networks/WTNet/__pycache__/WTNet.cpython-38.pyc -------------------------------------------------------------------------------- /networks/WTNet/__pycache__/__init__.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/nkicsl/NKUT/HEAD/networks/WTNet/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 nkicsl 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /WTNet/config.py: -------------------------------------------------------------------------------- 1 | class config: 2 | queue_length = 300 3 | samples_per_volume = 30 4 | patch_size = 64, 64, 64 5 | epoch = 5 6 | epochs_per_val = 1 7 | input_channel = 1 8 | num_classes = 4 9 | batch_size = 2 10 | learning_rate = 0.001 11 | # crop_or_pad_size = 512, 512, 32 12 | input_train_image_dir = 'D:/TOOTH/Datasets/new/patches/256_256_16/Image' 13 | input_train_label_dir = 'D:/TOOTH/Datasets/new/patches/256_256_16/label' 14 | input_val_image_dir = 'C:/Users/zhouzhenhuan/Desktop/DATA/Val/Image' 15 | input_val_label_dir = 'C:/Users/zhouzhenhuan/Desktop/DATA/Val/Label' 16 | # input_test_image_dir = '' 17 | # input_test_label_dir = '' 18 | output_logs_dir = 'E:/PycharmProjects/NKUT_Tooth/logs' 19 | devices = [0, 1] 20 | step_size = 10 21 | gamma = 0.8 22 | latest_output_dir = 'E:/PycharmProjects/NKUT_Tooth/result/latest_output_dir/latest_result.pt' 23 | latest_checkpoint_file = 'E:/PycharmProjects/NKUT_Tooth/result/latest_checkpoint_dir/latest_checkpoint.pt' 24 | best_model_path = 'E:/PycharmProjects/NKUT_Tooth/result/best_model/best_model.pt' 25 | epochs_per_checkpoint = 10 26 | -------------------------------------------------------------------------------- /WTNet/dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.utils.data import Dataset 4 | import SimpleITK as sitk 5 | import torchio as tio 6 | from config import config 7 | from torchio.data import UniformSampler 8 | from torchio.transforms import ( 9 | RandomFlip, 10 | RandomAffine, 11 | RandomElasticDeformation, 12 | RandomNoise, 13 | RandomMotion, 14 | RandomBiasField, 15 | RescaleIntensity, 16 | Resample, 17 | ToCanonical, 18 | ZNormalization, 19 | CropOrPad, 20 | HistogramStandardization, 21 | OneOf, 22 | Compose, 23 | OneHot, 24 | Resize 25 | ) 26 | 27 | 28 | def Tooth_Dataset(images_dir, 
labels_dir, train=True): 29 | subjects_list = [] 30 | images_list = os.listdir(images_dir) 31 | 32 | labels_binary_dir = os.path.join(labels_dir, 'binary') 33 | labels_tooth_dir = os.path.join(labels_dir, 'tooth') 34 | labels_bone_dir = os.path.join(labels_dir, 'bone') 35 | 36 | labels_binary_list = os.listdir(labels_binary_dir) 37 | labels_tooth_list = os.listdir(labels_tooth_dir) 38 | labels_bone_list = os.listdir(labels_bone_dir) 39 | 40 | # queue_length = config.queue_length 41 | # samples_per_volume = config.samples_per_volume 42 | # patch_size = config.patch_size 43 | 44 | training_transform = Compose([ 45 | RandomFlip(), 46 | RandomNoise(), 47 | RandomMotion(), 48 | Resize(target_shape=64) 49 | ]) 50 | for image, labels_binary, labels_tooth, labels_bone in zip(images_list, labels_binary_list, labels_tooth_list, labels_bone_list): 51 | subject = tio.Subject( 52 | image=tio.ScalarImage(os.path.join(images_dir, image)), 53 | labels_binary=tio.LabelMap(os.path.join(labels_binary_dir, labels_binary)), 54 | labels_tooth=tio.LabelMap(os.path.join(labels_tooth_dir, labels_tooth)), 55 | labels_bone=tio.LabelMap(os.path.join(labels_bone_dir, labels_bone)), 56 | ) 57 | subjects_list.append(subject) 58 | 59 | if train: 60 | subject_dataset = tio.SubjectsDataset(subjects_list, transform=training_transform) 61 | # queue_dataset = tio.Queue( 62 | # subject_dataset, 63 | # max_length=queue_length, 64 | # samples_per_volume=samples_per_volume, 65 | # sampler=tio.LabelSampler(patch_size=patch_size, label_name=None, 66 | # label_probabilities={0: 0, 1: 5, 2: 5, 3: 2}), 67 | # ) 68 | 69 | return subject_dataset 70 | 71 | else: 72 | subject_dataset = tio.SubjectsDataset(subjects_list, transform=Resize(target_shape=64)) 73 | return subject_dataset 74 | 75 | 76 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NKUT: Dataset and Benchmark for Pediatric Mandibular Wisdom Teeth Segmentation 2 | ![NKUT_logo](./logo.jpg) 3 | 4 | ## News 5 | * `March. 23th, 2024`: Our paper was accepted by IEEE Journal of Biomedical and Health Informatics (JBHI), congratulations!🎉🎉🎉🎉
6 | * `April. 8th, 2024`: We released the NKUT dataset. Now, researchers can apply to obtain the dataset.🎉🎉🎉🎉
7 | * `May. 15th, 2024`: We released the 2D and 3D WTNet models. 🎉🎉🎉🎉
8 | * `Dec. 26th, 2024`: We released the training code. 🎉🎉🎉🎉
Happy New Year! 9 | 10 | ## To Do List 11 | - [X] NKUT Dataset release 12 | - [X] WTNet 2D model code release 13 | - [X] WTNet 3D model code release 14 | - [X] Training code release 15 | 16 | ## Request for NKUT Dataset 17 | ### If you wish to use the NKUT dataset in your own research, you need to complete the following steps: 18 | * 1. Download and fill in the `Application.pdf` file in the repository. Please note that every item in the file must be completed and none may be left blank; otherwise your access to the dataset may be affected. 19 | * 2. Send an email to `aics@nankai.edu.cn` and Cc `zzh_nkcs@mail.nankai.edu.cn`. The subject of the email should be "NKUT Dataset Request"; briefly state your name, contact information, and institution or organization in the body of the email. Remember to attach the PDF completed in the previous step. 20 | * 3. We will review your application within two weeks and notify you via email whether it has been approved or whether further materials are required. Please plan your schedule accordingly. 21 | * 4. For researchers whose application is approved, we will include a download link for the dataset in our reply email. You will receive about 30 NKUT cases with their corresponding pixel-level expert annotations, along with a doc file recording the details of each case. 22 | 23 | ## Model 24 | ### WTNet_2D Model 25 | The 2D WTNet model is in ./networks/WTNet/WTNet.py
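A minimal usage sketch (not part of the original repository) mirroring the `__main__` shape test at the bottom of `networks/WTNet/WTNet.py`; it assumes the repository root is on `PYTHONPATH`, the file's own imports (e.g. `thop`, `networks.Unet`) resolve, and the pretrained VGG16 weights can be downloaded into `networks/WTNet/model_data`:

```python
import torch
from networks.WTNet.WTNet import WTNet  # 2D variant (VGG16 encoder)

model = WTNet()
x = torch.rand(size=(1, 3, 256, 256))        # (B, C, H, W), as in the file's __main__ test
binary_out, tooth_out, bone_out = model(x)   # binary-mask, tooth (BG/WT/SM) and bone (BG/AB) logits
print(binary_out.shape, tooth_out.shape, bone_out.shape)
```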
26 | 27 | ### WTNet_3D Model 28 | The 3D WTNet model is in ./modles/WTNet.py
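Likewise, `modles/WTNet.py` ends with a small `__main__` shape test on random 64x64x64 patches; a minimal sketch of the same usage (assuming the repository root is on `PYTHONPATH`):

```python
import torch
from modles.WTNet import WTNet  # 3D variant, as imported by train_WTNet.py

model = WTNet()
x = torch.rand(size=(2, 1, 64, 64, 64))      # (B, C, D, H, W) patches, as in the file's __main__ test
binary_out, tooth_out, bone_out = model(x)
print(binary_out.shape, tooth_out.shape, bone_out.shape)
```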
29 | 30 | ## Training 31 | Adjust the parameters in the final part of train_WTNet.py according to your situation and run train_WTNet.py. 32 | 33 | ## Citation 34 | If you used NKUT in your own research, please give us a star and cite our paper below: 35 | 36 | ``` 37 | @ARTICLE{10485282, 38 | author={Zhou, Zhenhuan and Chen, Yuzhu and He, Along and Que, Xitao and Wang, Kai and Yao, Rui and Li, Tao}, 39 | journal={IEEE Journal of Biomedical and Health Informatics}, 40 | title={NKUT: Dataset and Benchmark for Pediatric Mandibular Wisdom Teeth Segmentation}, 41 | year={2024}, 42 | volume={28}, 43 | number={6}, 44 | pages={3523-3533}, 45 | keywords={Teeth;Dentistry;Image segmentation;Task analysis;Bones;Annotations;Three-dimensional displays;CBCT dataset;pediatric wisdom teeth segmentation;pediatric germectomy;multi-scale feature fusion}, 46 | doi={10.1109/JBHI.2024.3383222}} 47 | ``` 48 | 49 | ## Acknowledgment 50 | Code can only be used for ACADEMIC PURPOSES. NO COMERCIAL USE is allowed. Copyright © College of Computer Science, Nankai University. All rights reserved. 51 | 52 | 53 | [![Star History Chart](https://api.star-history.com/svg?repos=nkicsl/NKUT&type=Date)](https://star-history.com/#nkicsl/NKUT&Date) 54 | 55 | -------------------------------------------------------------------------------- /networks/WTNet/vgg.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.hub import load_state_dict_from_url 3 | import torch 4 | 5 | 6 | class VGG(nn.Module): 7 | def __init__(self, features, num_classes=1000): 8 | super(VGG, self).__init__() 9 | self.features = features 10 | self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) 11 | self.classifier = nn.Sequential( 12 | nn.Linear(512 * 7 * 7, 4096), 13 | nn.ReLU(True), 14 | nn.Dropout(), 15 | nn.Linear(4096, 4096), 16 | nn.ReLU(True), 17 | nn.Dropout(), 18 | nn.Linear(4096, num_classes), 19 | ) 20 | self._initialize_weights() 21 | 22 | def forward(self, x): 23 | # x = self.features(x) 24 | # x = self.avgpool(x) 25 | # x = torch.flatten(x, 1) 26 | # x = self.classifier(x) 27 | feat1 = self.features[:4](x) 28 | feat2 = self.features[4:9](feat1) 29 | feat3 = self.features[9:16](feat2) 30 | feat4 = self.features[16:23](feat3) 31 | feat5 = self.features[23:-1](feat4) 32 | 33 | # print(self.features[:4]) 34 | # print(self.features[4:9]) 35 | # print(self.features[9:16]) 36 | # print(self.features[16:23]) 37 | # print(self.features[23:-1]) 38 | 39 | return [feat1, feat2, feat3, feat4, feat5] 40 | 41 | 42 | def _initialize_weights(self): 43 | for m in self.modules(): 44 | if isinstance(m, nn.Conv2d): 45 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 46 | if m.bias is not None: 47 | nn.init.constant_(m.bias, 0) 48 | elif isinstance(m, nn.BatchNorm2d): 49 | nn.init.constant_(m.weight, 1) 50 | nn.init.constant_(m.bias, 0) 51 | elif isinstance(m, nn.Linear): 52 | nn.init.normal_(m.weight, 0, 0.01) 53 | nn.init.constant_(m.bias, 0) 54 | 55 | 56 | def make_layers(cfg, batch_norm=False, in_channels=3): 57 | layers = [] 58 | for v in cfg: 59 | if v == 'M': 60 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 61 | else: 62 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 63 | if batch_norm: 64 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 65 | else: 66 | layers += [conv2d, nn.ReLU(inplace=True)] 67 | in_channels = v 68 | return nn.Sequential(*layers) 69 | 70 | 71 | # 512,512,3 -> 512,512,64 -> 256,256,64 -> 256,256,128 -> 128,128,128 -> 
128,128,256 -> 64,64,256 72 | # 64,64,512 -> 32,32,512 -> 32,32,512 73 | cfgs = { 74 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'] 75 | } 76 | 77 | 78 | def VGG16(pretrained, in_channels=3, **kwargs): 79 | model = VGG(make_layers(cfgs["D"], batch_norm=False, in_channels=in_channels), **kwargs) 80 | if pretrained: 81 | state_dict = load_state_dict_from_url("https://download.pytorch.org/models/vgg16-397923af.pth", 82 | model_dir="./model_data") 83 | model.load_state_dict(state_dict) 84 | 85 | del model.avgpool 86 | del model.classifier 87 | return model 88 | 89 | 90 | # model = VGG16(pretrained=True, in_channels=3) 91 | # print(model) 92 | # a = torch.rand(size=(1, 3, 256, 256)) 93 | # a, b, c, d, e = model(a) 94 | # print(a.shape) 95 | # print(b.shape) 96 | # print(c.shape) 97 | # print(d.shape) 98 | # print(e.shape) 99 | -------------------------------------------------------------------------------- /WTNet/train_and_eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from tqdm import tqdm 4 | import os 5 | import torchio as tio 6 | 7 | 8 | def train_one_epoch(model, diceloss, celoss, optimizer, dataloader, device, epoch, arg): 9 | model.train() 10 | dice_loss = diceloss 11 | ce_loss = celoss 12 | 13 | loss_sum = 0 14 | iteration = 0 15 | 16 | with tqdm(enumerate(dataloader), total=len(dataloader)) as loop: 17 | for i, batch in loop: 18 | data = batch['image'][tio.DATA] 19 | labels_binary = batch['labels_binary'][tio.DATA] 20 | labels_tooth = batch['labels_tooth'][tio.DATA] 21 | labels_bone = batch['labels_bone'][tio.DATA] 22 | 23 | data = data.float() 24 | labels_binary = labels_binary.long() 25 | labels_tooth = labels_tooth.long() 26 | labels_bone = labels_bone.long() 27 | 28 | data = torch.transpose(data, 2, 4) 29 | labels_binary = torch.transpose(labels_binary, 2, 4) 30 | labels_tooth = torch.transpose(labels_tooth, 2, 4) 31 | labels_bone = torch.transpose(labels_bone, 2, 4) 32 | 33 | data = data.to(device) 34 | labels_binary = labels_binary.to(device) 35 | labels_tooth = labels_tooth.to(device) 36 | labels_bone = labels_bone.to(device) 37 | 38 | Binary_out, out_tooth_last, out_bone_last = model(data) 39 | Dice_Binary = dice_loss(Binary_out, labels_binary) 40 | Dice_tooth = dice_loss(out_tooth_last, labels_tooth) 41 | Dice_bone = dice_loss(out_bone_last, labels_bone) 42 | 43 | CE_Binary = ce_loss(Binary_out, labels_binary.squeeze(1)) 44 | CE_tooth = ce_loss(out_tooth_last, labels_tooth.squeeze(1)) 45 | CE_bone = ce_loss(out_bone_last, labels_bone.squeeze(1)) 46 | 47 | loss = Dice_Binary+CE_Binary+Dice_tooth+CE_tooth+Dice_bone+CE_bone 48 | 49 | loss_sum += loss.item() 50 | optimizer.zero_grad() 51 | loss.backward() 52 | optimizer.step() 53 | iteration += 1 54 | 55 | loop.set_description(f'Epoch {epoch}') 56 | loop.set_postfix(lr=optimizer.state_dict()['param_groups'][0]['lr'], total_loss=loss_sum / iteration) 57 | 58 | torch.save(model, arg.latest_output_dir) 59 | 60 | return loss_sum / iteration, model 61 | 62 | 63 | def eval(model_path, dataloader, device, diceloss, celoss): 64 | model = torch.load(model_path) 65 | model.to(device) 66 | model.eval() 67 | iteration = 0 68 | 69 | dice_loss = diceloss.to(device) 70 | ce_loss = celoss.to(device) 71 | val_loss_sum = 0 72 | 73 | with torch.no_grad(): 74 | with tqdm(enumerate(dataloader)) as loop_val: 75 | for i, batch in loop_val: 76 | data = batch['image'][tio.DATA] 77 | labels_binary = 
batch['labels_binary'][tio.DATA] 78 | labels_tooth = batch['labels_tooth'][tio.DATA] 79 | labels_bone = batch['labels_bone'][tio.DATA] 80 | 81 | data = data.float() 82 | labels_binary = labels_binary.long() 83 | labels_tooth = labels_tooth.long() 84 | labels_bone = labels_bone.long() 85 | 86 | data = torch.transpose(data, 2, 4) 87 | labels_binary = torch.transpose(labels_binary, 2, 4) 88 | labels_tooth = torch.transpose(labels_tooth, 2, 4) 89 | labels_bone = torch.transpose(labels_bone, 2, 4) 90 | 91 | data = data.to(device) 92 | labels_binary = labels_binary.to(device) 93 | labels_tooth = labels_tooth.to(device) 94 | labels_bone = labels_bone.to(device) 95 | 96 | Binary_out, out_tooth_last, out_bone_last = model(data) 97 | Dice_Binary = dice_loss(Binary_out, labels_binary) 98 | Dice_tooth = dice_loss(out_tooth_last, labels_tooth) 99 | Dice_bone = dice_loss(out_bone_last, labels_bone) 100 | 101 | CE_Binary = ce_loss(Binary_out, labels_binary.squeeze(1)) 102 | CE_tooth = ce_loss(out_tooth_last, labels_tooth.squeeze(1)) 103 | CE_bone = ce_loss(out_bone_last, labels_bone.squeeze(1)) 104 | 105 | loss = Dice_Binary + CE_Binary + Dice_tooth + CE_tooth + Dice_bone + CE_bone 106 | 107 | val_loss_sum += loss.item() 108 | iteration += 1 109 | 110 | return val_loss_sum / iteration 111 | 112 | 113 | def save_checkpoint(model, optim, scheduler, epoch, save_fre, checkpoint_dir): 114 | if epoch % save_fre == 0: 115 | torch.save( 116 | { 117 | "model": model.state_dict(), 118 | "optim": optim.state_dict(), 119 | "scheduler": scheduler.state_dict(), 120 | "epoch": epoch, 121 | }, 122 | os.path.join(checkpoint_dir, 'checkpoint_epoch{}.pth'.format(epoch)) 123 | ) 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | -------------------------------------------------------------------------------- /train_WTNet.py: -------------------------------------------------------------------------------- 1 | import torchio as tio 2 | import os 3 | import argparse 4 | from shutil import copy 5 | import torch 6 | from torch.nn.modules.loss import CrossEntropyLoss 7 | from torch.optim.lr_scheduler import StepLR 8 | from torch.utils.data import DataLoader 9 | from torch.utils.tensorboard import SummaryWriter 10 | from tqdm import tqdm 11 | from modles.WTNet import WTNet 12 | from WTNet.config import config 13 | from WTNet.dataloader import Tooth_Dataset 14 | from torch.nn.functional import softmax 15 | from monai.losses.dice import DiceLoss 16 | from WTNet.train_and_eval import train_one_epoch, eval, save_checkpoint 17 | 18 | def create_model(): 19 | model = WTNet() 20 | return model 21 | 22 | 23 | def main(args, fold): 24 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 25 | batch_size = args.batch_size 26 | num_workers = args.num_workers 27 | best_val_loss = 1000 28 | best_epoch = 0 29 | count = 0 30 | fold = fold 31 | 32 | train_dataset = Tooth_Dataset(images_dir=args.input_image_dir_train, labels_dir=args.label_dir_train, train=True) 33 | val_dataset = Tooth_Dataset(images_dir=args.input_image_dir_val, labels_dir=args.label_dir_val, train=False) 34 | NKUT_Train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True) 35 | NKUT_Val_loader = DataLoader(val_dataset, batch_size=1, shuffle=True, num_workers=0, pin_memory=True) 36 | 37 | if args.resume: 38 | model = create_model() 39 | model.to(device) 40 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 41 | scheduler = 
torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, 42 | mode='min', 43 | factor=0.8, 44 | patience=3, 45 | verbose=True, 46 | ) 47 | 48 | checkpoint = torch.load('/data/dataset/zzh/ckeckpoint/fold{}/checkpoint_epoch70.pth'.format(fold)) 49 | model.load_state_dict(checkpoint['model']) 50 | optimizer.load_state_dict(checkpoint['optim']) 51 | ckpt_epoch = checkpoint['epoch'] 52 | scheduler.load_state_dict(checkpoint['scheduler']) 53 | 54 | else: 55 | model = create_model() 56 | model.to(device) 57 | 58 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 59 | # scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9) 60 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, 61 | mode='min', 62 | factor=0.8, 63 | patience=3, 64 | verbose=True) 65 | ckpt_epoch = args.start_epoch 66 | 67 | writer = SummaryWriter(log_dir=args.log_dir) 68 | diceloss = DiceLoss(to_onehot_y=True, softmax=True) 69 | celoss = CrossEntropyLoss() 70 | 71 | for epoch in range(ckpt_epoch, args.epochs + 1): 72 | loss_result, new_model = train_one_epoch(model=model, diceloss=diceloss, celoss=celoss, 73 | optimizer=optimizer, 74 | dataloader=NKUT_Train_loader, device=device, arg=args, epoch=epoch) 75 | writer.add_scalar('Training Loss', loss_result, epoch) 76 | 77 | save_checkpoint(model=new_model, optim=optimizer, scheduler=scheduler, checkpoint_dir=args.ckpt_dir, 78 | epoch=epoch, save_fre=args.epochs_per_checkpoint) 79 | 80 | val_loss_sum = eval(model_path='./result/WTNet/fold{}/latest_output.pth'.format(fold), 81 | dataloader=NKUT_Val_loader, device=device, diceloss=diceloss, 82 | celoss=celoss) 83 | 84 | scheduler.step(loss_result) 85 | writer.add_scalar('Val Loss', val_loss_sum, epoch) 86 | 87 | if best_val_loss > val_loss_sum: 88 | copy(src=args.latest_output_dir, dst=args.best_model_path) 89 | best_val_loss = val_loss_sum 90 | best_epoch = epoch 91 | count = 0 92 | else: 93 | count += 1 94 | 95 | print('The total val loss is {}, best is {}, in Epoch {}'.format(val_loss_sum, best_val_loss, best_epoch)) 96 | model = new_model 97 | 98 | if count == args.early_stop: 99 | print("early stop") 100 | break 101 | 102 | def parse_args(fold): 103 | 104 | parser = argparse.ArgumentParser(description='NKUT Wisdom Tooth Segmentation') 105 | parser.add_argument('-epochs', type=int, default=500, help='Numbers of epochs to train') 106 | parser.add_argument('-batch_size', type=int, default=3, help='batch size') 107 | parser.add_argument('-input_image_dir_train', type=str, default='/data/dataset/zzh/NKUT/patch/64_64_64/fold{}/Train/Image'.format(fold)) 108 | parser.add_argument('-label_dir_train', type=str, default='/data/dataset/zzh/NKUT/patch/64_64_64/fold{}/Train/Label'.format(fold)) 109 | parser.add_argument('-input_image_dir_val', type=str, default='/data/dataset/zzh/NKUT/patch/64_64_64/fold4/Val/Image') 110 | parser.add_argument('-label_dir_val', type=str, default='/data/dataset/zzh/NKUT/patch/64_64_64/fold4/Val/Label') 111 | 112 | parser.add_argument('-epochs-per-checkpoint', type=int, default=5, help='Number of epochs per checkpoint') 113 | parser.add_argument('-log_dir', '-output_logs_dir', type=str, default='./logs', help='Where to save the train logs') 114 | parser.add_argument('-lr', '-learning rate', type=float, default=0.00001, help='learning rate') 115 | parser.add_argument('-latest_output_dir', type=str, default='./result/WTNet/fold{}/latest_output.pth'.format(fold), 116 | help='where to store the latest model') 117 | parser.add_argument('-best_model_path', 
type=str, default='./result/WTNet/fold{}/best_result.pth'.format(fold), 118 | help='where to save the best val model') 119 | parser.add_argument('-ckpt_dir', type=str, default='/data/dataset/zzh/ckeckpoint/fold{}'.format(fold), 120 | help='where to save the latest checkpoint') 121 | parser.add_argument('-epochs_per_checkpoint', type=int, default=5, help='epoch to store a checkpoint') 122 | parser.add_argument('-resume', action='store_true', help='continue training') 123 | parser.add_argument('-early_stop', type=int, default=200, help='early stop') 124 | parser.add_argument('-num_workers', type=int, default=8, help='num_workers') 125 | parser.add_argument('-start_epoch', type=int, default=1, help='num_workers') 126 | 127 | args = parser.parse_args() 128 | return args 129 | 130 | 131 | if __name__ == '__main__': 132 | fold = 3 133 | args = parse_args(fold) 134 | main(args, fold) 135 | -------------------------------------------------------------------------------- /networks/WTNet/WTNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch 4 | from networks.WTNet.vgg import * 5 | import math 6 | from networks.Unet.unet import Unet 7 | from thop import profile 8 | 9 | 10 | class Decoder_block(nn.Module): 11 | def __init__(self, in_channel, out_channel, attention=False): 12 | super(Decoder_block, self).__init__() 13 | self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1) 14 | self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1) 15 | self.up = nn.UpsamplingBilinear2d(scale_factor=2) 16 | self.relu = nn.ReLU(inplace=True) 17 | self.eca = ECABlock(channels=out_channel) 18 | self.Spatial = SpatialAttention() 19 | self.attention = attention 20 | 21 | def forward(self, inputs1, inputs2): 22 | if self.attention: 23 | inputs1 = self.eca(inputs1) 24 | Spatial_map = self.Spatial(inputs1) 25 | inputs1 = inputs1 * Spatial_map 26 | outputs = torch.cat([inputs1, self.up(inputs2)], 1) 27 | else: 28 | outputs = torch.cat([inputs1, self.up(inputs2)], 1) 29 | outputs = self.conv1(outputs) 30 | outputs = self.relu(outputs) 31 | outputs = self.conv2(outputs) 32 | outputs = self.relu(outputs) 33 | return outputs 34 | 35 | 36 | class Encoder(nn.Module): 37 | """ 38 | for input size of (B, 3, 256, 256) 39 | output size is: feat1, feat2, feat3, feat4, feat5 40 | 41 | torch.Size([1, 64, 256, 256]) 42 | torch.Size([1, 128, 128, 128]) 43 | torch.Size([1, 256, 64, 64]) 44 | torch.Size([1, 512, 32, 32]) 45 | torch.Size([1, 512, 16, 16]) 46 | """ 47 | 48 | def __init__(self, in_channel): 49 | super(Encoder, self).__init__() 50 | self.backbone = VGG16(pretrained=True, in_channels=in_channel) 51 | 52 | def forward(self, x): 53 | feat1, feat2, feat3, feat4, feat5 = self.backbone(x) 54 | 55 | return feat1, feat2, feat3, feat4, feat5 56 | 57 | 58 | class ECABlock(nn.Module): 59 | def __init__(self, channels, gamma=2, bias=1): 60 | super(ECABlock, self).__init__() 61 | 62 | # 设计自适应卷积核,便于后续做1*1卷积 63 | kernel_size = int(abs((math.log(channels, 2) + bias) / gamma)) 64 | kernel_size = kernel_size if kernel_size % 2 else kernel_size + 1 65 | # 全局平局池化 66 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 67 | # 基于1*1卷积学习通道之间的信息 68 | self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False) 69 | # 激活函数 70 | self.sigmoid = nn.Sigmoid() 71 | 72 | def forward(self, x): 73 | # 首先,空间维度做全局平局池化,[b,c,h,w]==>[b,c,1,1] 74 | v = self.avg_pool(x) 75 | # 
然后,基于1*1卷积学习通道之间的信息;其中,使用前面设计的自适应卷积核 76 | v = self.conv(v.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1) 77 | # 最终,经过sigmoid 激活函数处理 78 | v = self.sigmoid(v) 79 | return x * v 80 | 81 | 82 | class Tooth_multi_scale(nn.Module): 83 | """ 84 | all feature map are sampling to 128*128, then concat in channel dimension 85 | finally, execute channel attention to all channels 86 | """ 87 | 88 | def __init__(self): 89 | super(Tooth_multi_scale, self).__init__() 90 | self.input1_down = nn.MaxPool2d(kernel_size=2, stride=2) 91 | self.input2_out = nn.Identity() 92 | self.input3_up = nn.UpsamplingBilinear2d(scale_factor=2) 93 | self.input4_up = nn.UpsamplingBilinear2d(scale_factor=4) 94 | self.channel_atten = ECABlock(channels=960) 95 | 96 | def forward(self, input1, input2, input3, input4): 97 | out1 = self.input1_down(input1) 98 | out2 = self.input2_out(input2) 99 | out3 = self.input3_up(input3) 100 | out4 = self.input4_up(input4) 101 | out = torch.cat([out1, out2, out3, out4], dim=1) 102 | channel_atten_out = self.channel_atten(out) 103 | return channel_atten_out 104 | 105 | 106 | class Bone_multi_scale(nn.Module): 107 | """ 108 | all feature map are sampling to 64*64, then concat in channel dimension 109 | finally, execute channel attention to all channels 110 | """ 111 | 112 | def __init__(self): 113 | super(Bone_multi_scale, self).__init__() 114 | self.input1_down = nn.MaxPool2d(kernel_size=4, stride=4) 115 | self.input2_down = nn.MaxPool2d(kernel_size=2, stride=2) 116 | self.input3_out = nn.Identity() 117 | self.input4_up = nn.UpsamplingBilinear2d(scale_factor=2) 118 | self.channel_atten = ECABlock(channels=960) 119 | 120 | def forward(self, input1, input2, input3, input4): 121 | out1 = self.input1_down(input1) 122 | out2 = self.input2_down(input2) 123 | out3 = self.input3_out(input3) 124 | out4 = self.input4_up(input4) 125 | out = torch.cat([out1, out2, out3, out4], dim=1) 126 | channel_atten_out = self.channel_atten(out) 127 | return channel_atten_out 128 | 129 | 130 | class SpatialAttention(nn.Module): # Spatial Attention Module 131 | def __init__(self): 132 | super(SpatialAttention, self).__init__() 133 | self.conv1 = nn.Conv2d(2, 1, kernel_size=7, padding=3, bias=False) 134 | self.sigmoid = nn.Sigmoid() 135 | 136 | def forward(self, x): 137 | avg_out = torch.mean(x, dim=1, keepdim=True) 138 | max_out, _ = torch.max(x, dim=1, keepdim=True) 139 | out = torch.cat([avg_out, max_out], dim=1) 140 | out = self.conv1(out) 141 | out = self.sigmoid(out) 142 | return out 143 | 144 | 145 | class multi_scale_feature(nn.Module): 146 | def __init__(self, zoom=None, input_feature_size=None, in_multi_size=None, channel=None): 147 | super(multi_scale_feature, self).__init__() 148 | self.input_feature_size = input_feature_size 149 | self.in_multi_size = in_multi_size 150 | self.zoom = zoom 151 | self.channel = channel 152 | self.conv = nn.Conv2d(in_channels=960, out_channels=channel, kernel_size=1, stride=1) 153 | self.atten = SpatialAttention() 154 | 155 | if self.zoom == 'UP': 156 | self.k = input_feature_size / in_multi_size 157 | self.up = nn.UpsamplingBilinear2d(scale_factor=self.k) 158 | elif self.zoom == 'DOWN': 159 | self.avg = nn.AdaptiveAvgPool2d(self.input_feature_size) 160 | elif self.zoom == "None": 161 | self.none = nn.Identity() 162 | 163 | def forward(self, input_feature, in_multi): 164 | if self.zoom == 'UP': 165 | out_up = self.up(in_multi) 166 | out_adjust_channel = self.conv(out_up) 167 | x_add = torch.add(out_adjust_channel, input_feature) 168 | spatial_attention_map = 
self.atten(x_add) 169 | out = torch.mul(spatial_attention_map, input_feature) 170 | return out 171 | 172 | if self.zoom == 'DOWN': 173 | out_down = self.avg(in_multi) 174 | out_adjust_channel = self.conv(out_down) 175 | x_add = torch.add(out_adjust_channel, input_feature) 176 | spatial_attention_map = self.atten(x_add) 177 | out = torch.mul(spatial_attention_map, input_feature) 178 | return out 179 | if self.zoom == 'None': 180 | out_none = self.none(in_multi) 181 | out_adjust_channel = self.conv(out_none) 182 | x_add = torch.add(out_adjust_channel, input_feature) 183 | spatial_attention_map = self.atten(x_add) 184 | out = torch.mul(spatial_attention_map, input_feature) 185 | return out 186 | 187 | 188 | class Binary_mask(nn.Module): 189 | def __init__(self, num_classes=2): 190 | super(Binary_mask, self).__init__() 191 | self.num_classes = num_classes 192 | self.encoder = Encoder(in_channel=3) 193 | self.decoder4 = Decoder_block(in_channel=1024, out_channel=512) 194 | self.decoder3 = Decoder_block(in_channel=768, out_channel=256) 195 | self.decoder2 = Decoder_block(in_channel=384, out_channel=128) 196 | self.decoder1 = Decoder_block(in_channel=192, out_channel=64) 197 | self.final = nn.Conv2d(64, self.num_classes, 1) 198 | 199 | def forward(self, x): 200 | feat1, feat2, feat3, feat4, feat5 = self.encoder(x) 201 | out4 = self.decoder4(feat4, feat5) 202 | out3 = self.decoder3(feat3, out4) 203 | out2 = self.decoder2(feat2, out3) 204 | out1 = self.decoder1(feat1, out2) 205 | out_last = self.final(out1) 206 | return out_last 207 | 208 | class input_enhancement(nn.Module): 209 | def __init__(self): 210 | super(input_enhancement, self).__init__() 211 | self.conv = nn.Conv2d(9, 3, kernel_size=1, stride=1, padding=0) 212 | self.relu = nn.ReLU(inplace=True) 213 | 214 | def forward(self, origin, binary_mask): 215 | x1 = torch.mul(origin, binary_mask) 216 | out = torch.add(x1, origin) 217 | out = torch.cat([x1, origin, out], dim=1) 218 | out = self.conv(out) 219 | # out = self.relu(out) 220 | return out 221 | 222 | 223 | class Tooth_bone_separation(nn.Module): 224 | def __init__(self): 225 | super(Tooth_bone_separation, self).__init__() 226 | self.encoder = Encoder(in_channel=3) 227 | 228 | self.Tdecoder = nn.ModuleList( 229 | [Decoder_block(in_channel=1024, out_channel=512), 230 | Decoder_block(in_channel=768, out_channel=256), 231 | Decoder_block(in_channel=384, out_channel=128), 232 | Decoder_block(in_channel=192, out_channel=64)] 233 | ) 234 | 235 | self.Bdecoder = nn.ModuleList( 236 | [Decoder_block(in_channel=1024, out_channel=512), 237 | Decoder_block(in_channel=768, out_channel=256), 238 | Decoder_block(in_channel=384, out_channel=128), 239 | Decoder_block(in_channel=192, out_channel=64)] 240 | ) 241 | 242 | self.Tmulti = nn.ModuleList( 243 | [ 244 | multi_scale_feature(zoom='UP', input_feature_size=256, in_multi_size=128, channel=64), 245 | multi_scale_feature(zoom='None', input_feature_size=128, in_multi_size=128, channel=128), 246 | multi_scale_feature(zoom='DOWN', input_feature_size=64, in_multi_size=128, channel=256), 247 | multi_scale_feature(zoom='DOWN', input_feature_size=32, in_multi_size=128, channel=512) 248 | ] 249 | ) 250 | 251 | self.Bmulti = nn.ModuleList( 252 | [ 253 | multi_scale_feature(zoom='UP', input_feature_size=256, in_multi_size=64, channel=64), 254 | multi_scale_feature(zoom='UP', input_feature_size=128, in_multi_size=64, channel=128), 255 | multi_scale_feature(zoom='None', input_feature_size=64, in_multi_size=64, channel=256), 256 | multi_scale_feature(zoom='DOWN', 
input_feature_size=32, in_multi_size=64, channel=512) 257 | ] 258 | ) 259 | 260 | self.Tooth_multi_scale = Tooth_multi_scale() 261 | self.Bone_multi_scale = Bone_multi_scale() 262 | self.Tfinal = nn.Conv2d(64, 3, 1) # background, WT, SM 263 | self.Bfinal = nn.Conv2d(64, 2, 1) # background, AB 264 | 265 | def forward(self, x): 266 | feat1, feat2, feat3, feat4, feat5 = self.encoder(x) 267 | 268 | Tooth_multi = self.Tooth_multi_scale(feat1, feat2, feat3, feat4) # (B, 960, 128, 128) 269 | Tooth_feat1 = self.Tmulti[0](input_feature=feat1, in_multi=Tooth_multi) 270 | Tooth_feat2 = self.Tmulti[1](input_feature=feat2, in_multi=Tooth_multi) 271 | Tooth_feat3 = self.Tmulti[2](input_feature=feat3, in_multi=Tooth_multi) 272 | Tooth_feat4 = self.Tmulti[3](input_feature=feat4, in_multi=Tooth_multi) 273 | Tout4 = self.Tdecoder[0](Tooth_feat4, feat5) 274 | Tout3 = self.Tdecoder[1](Tooth_feat3, Tout4) 275 | Tout2 = self.Tdecoder[2](Tooth_feat2, Tout3) 276 | Tout1 = self.Tdecoder[3](Tooth_feat1, Tout2) 277 | 278 | out_tooth_last = self.Tfinal(Tout1) 279 | 280 | Bone_multi = self.Bone_multi_scale(feat1, feat2, feat3, feat4) # (B, 960, 64, 64) 281 | Bone_feat1 = self.Bmulti[0](input_feature=feat1, in_multi=Bone_multi) 282 | Bone_feat2 = self.Bmulti[1](input_feature=feat2, in_multi=Bone_multi) 283 | Bone_feat3 = self.Bmulti[2](input_feature=feat3, in_multi=Bone_multi) 284 | Bone_feat4 = self.Bmulti[3](input_feature=feat4, in_multi=Bone_multi) 285 | Bout4 = self.Bdecoder[0](Bone_feat4, feat5) 286 | Bout3 = self.Bdecoder[1](Bone_feat3, Bout4) 287 | Bout2 = self.Bdecoder[2](Bone_feat2, Bout3) 288 | Bout1 = self.Bdecoder[3](Bone_feat1, Bout2) 289 | 290 | out_bone_last = self.Bfinal(Bout1) 291 | 292 | return out_tooth_last, out_bone_last 293 | 294 | 295 | class WTNet(nn.Module): 296 | def __init__(self): 297 | super(WTNet, self).__init__() 298 | self.Binary = Binary_mask() 299 | self.input_enhancement = input_enhancement() 300 | self.TBS = Tooth_bone_separation() 301 | 302 | def forward(self, x): 303 | Binary_out = self.Binary(x) 304 | Binary_map = torch.nn.functional.softmax(Binary_out, dim=1) 305 | Binary_map = torch.argmax(Binary_map, dim=1, keepdim=True) 306 | enhancement = self.input_enhancement(x, Binary_map) 307 | out_tooth_last, out_bone_last = self.TBS(enhancement) 308 | return Binary_out, out_tooth_last, out_bone_last 309 | 310 | 311 | if __name__ == '__main__': 312 | model = WTNet() 313 | a = torch.rand(size=(1, 3, 256, 256)) 314 | b, c, d = model(a) 315 | print(b.shape, c.shape, d.shape) 316 | 317 | -------------------------------------------------------------------------------- /modles/WTNet.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import torch 3 | import torch.nn as nn 4 | import math 5 | 6 | 7 | class Decoder_block(nn.Module): 8 | def __init__(self, num_classes=2, init_features=64): 9 | super(Decoder_block, self).__init__() 10 | features = init_features 11 | out_channels = num_classes 12 | 13 | self.upconv4 = nn.ConvTranspose3d( 14 | features * 8, features * 8, kernel_size=2, stride=2 15 | ) 16 | self.decoder4 = Decoder_block._block((features * 8) * 2, features * 8, name="dec4") 17 | self.upconv3 = nn.ConvTranspose3d( 18 | features * 8, features * 4, kernel_size=2, stride=2 19 | ) 20 | self.decoder3 = Decoder_block._block((features * 4) * 2, features * 4, name="dec3") 21 | self.upconv2 = nn.ConvTranspose3d( 22 | features * 4, features * 2, kernel_size=2, stride=2 23 | ) 24 | self.decoder2 = 
Decoder_block._block((features * 2) * 2, features * 2, name="dec2") 25 | self.upconv1 = nn.ConvTranspose3d( 26 | features * 2, features, kernel_size=2, stride=2 27 | ) 28 | self.decoder1 = Decoder_block._block(features * 2, features, name="dec1") 29 | 30 | self.conv = nn.Conv3d( 31 | in_channels=features, out_channels=out_channels, kernel_size=1 32 | ) 33 | 34 | def forward(self, fea1, fea2, fea3, fea4, fea5): 35 | dec4 = self.upconv4(fea5) 36 | dec4 = torch.cat((dec4, fea4), dim=1) 37 | dec4 = self.decoder4(dec4) 38 | dec3 = self.upconv3(dec4) 39 | dec3 = torch.cat((dec3, fea3), dim=1) 40 | dec3 = self.decoder3(dec3) 41 | dec2 = self.upconv2(dec3) 42 | dec2 = torch.cat((dec2, fea2), dim=1) 43 | dec2 = self.decoder2(dec2) 44 | dec1 = self.upconv1(dec2) 45 | dec1 = torch.cat((dec1, fea1), dim=1) 46 | dec1 = self.decoder1(dec1) 47 | outputs = self.conv(dec1) 48 | return outputs 49 | 50 | @staticmethod 51 | def _block(in_channels, features, name): 52 | return nn.Sequential( 53 | OrderedDict( # 有序字典 54 | [ 55 | ( 56 | name + "conv1", 57 | nn.Conv3d( 58 | in_channels=in_channels, 59 | out_channels=features, 60 | kernel_size=3, 61 | padding=1, 62 | bias=True, 63 | ), 64 | ), 65 | (name + "norm1", nn.BatchNorm3d(num_features=features)), 66 | (name + "relu1", nn.ReLU(inplace=True)), 67 | ( 68 | name + "conv2", 69 | nn.Conv3d( 70 | in_channels=features, 71 | out_channels=features, 72 | kernel_size=3, 73 | padding=1, 74 | bias=True, 75 | ), 76 | ), 77 | (name + "norm2", nn.BatchNorm3d(num_features=features)), 78 | (name + "relu2", nn.ReLU(inplace=True)), 79 | ] 80 | ) 81 | ) 82 | 83 | 84 | class Encoder(nn.Module): 85 | """ 86 | for input size of (B, 1, 64, 64, 64) 87 | output size is: feat1, feat2, feat3, feat4, feat5 88 | 89 | torch.Size([1, 64, 256, 256]) 90 | torch.Size([1, 128, 128, 128]) 91 | torch.Size([1, 256, 64, 64]) 92 | torch.Size([1, 512, 32, 32]) 93 | torch.Size([1, 512, 16, 16]) 94 | """ 95 | 96 | def __init__(self, in_channels=1, init_features=64): 97 | super(Encoder, self).__init__() 98 | 99 | features = init_features 100 | self.encoder1 = Encoder._block(in_channels, features, name="enc1") 101 | self.pool1 = nn.MaxPool3d(kernel_size=2, stride=2) 102 | self.encoder2 = Encoder._block(features, features * 2, name="enc2") 103 | self.pool2 = nn.MaxPool3d(kernel_size=2, stride=2) 104 | self.encoder3 = Encoder._block(features * 2, features * 4, name="enc3") 105 | self.pool3 = nn.MaxPool3d(kernel_size=2, stride=2) 106 | self.encoder4 = Encoder._block(features * 4, features * 8, name="enc4") 107 | self.pool4 = nn.MaxPool3d(kernel_size=2, stride=2) 108 | 109 | self.bottleneck = Encoder._block(features * 8, features * 8, name="bottleneck") 110 | 111 | def forward(self, x): 112 | feat1 = self.encoder1(x) 113 | feat2 = self.encoder2(self.pool1(feat1)) 114 | feat3 = self.encoder3(self.pool2(feat2)) 115 | feat4 = self.encoder4(self.pool3(feat3)) 116 | feat5 = self.bottleneck(self.pool4(feat4)) 117 | 118 | return feat1, feat2, feat3, feat4, feat5 119 | 120 | @staticmethod 121 | def _block(in_channels, features, name): 122 | return nn.Sequential( 123 | OrderedDict( # 有序字典 124 | [ 125 | ( 126 | name + "conv1", 127 | nn.Conv3d( 128 | in_channels=in_channels, 129 | out_channels=features, 130 | kernel_size=3, 131 | padding=1, 132 | bias=True, 133 | ), 134 | ), 135 | (name + "norm1", nn.BatchNorm3d(num_features=features)), 136 | (name + "relu1", nn.ReLU(inplace=True)), 137 | ( 138 | name + "conv2", 139 | nn.Conv3d( 140 | in_channels=features, 141 | out_channels=features, 142 | kernel_size=3, 143 | 
padding=1, 144 | bias=True, 145 | ), 146 | ), 147 | (name + "norm2", nn.BatchNorm3d(num_features=features)), 148 | (name + "relu2", nn.ReLU(inplace=True)), 149 | ] 150 | ) 151 | ) 152 | 153 | 154 | class ECABlock(nn.Module): 155 | def __init__(self, channels, gamma=2, bias=1): 156 | super(ECABlock, self).__init__() 157 | 158 | # 设计自适应卷积核,便于后续做1*1卷积 159 | kernel_size = int(abs((math.log(channels, 2) + bias) / gamma)) 160 | kernel_size = kernel_size if kernel_size % 2 else kernel_size + 1 161 | # 全局平局池化 162 | self.avg_pool = nn.AdaptiveAvgPool3d(1) 163 | # 基于1*1卷积学习通道之间的信息 164 | self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False) 165 | # 激活函数 166 | self.sigmoid = nn.Sigmoid() 167 | 168 | def forward(self, x): 169 | # 首先,空间维度做全局平局池化,[b,c,h,w,d]==>[b,c,1,1,1] 170 | v = self.avg_pool(x) 171 | # 然后,基于1*1卷积学习通道之间的信息;其中,使用前面设计的自适应卷积核 172 | v = self.conv(v.squeeze(-1).squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1).unsqueeze(-1) 173 | # 最终,经过sigmoid 激活函数处理 174 | v = self.sigmoid(v) 175 | return x * v 176 | 177 | 178 | class Tooth_multi_scale(nn.Module): 179 | """ 180 | all feature map are sampling to 32*32*32, then concat in channel dimension 181 | finally, execute channel attention to all channels 182 | """ 183 | 184 | def __init__(self): 185 | super(Tooth_multi_scale, self).__init__() 186 | self.input1_down = nn.MaxPool3d(kernel_size=2, stride=2) 187 | self.input2_out = nn.Identity() 188 | self.input3_up = nn.Upsample(scale_factor=2) 189 | self.input4_up = nn.Upsample(scale_factor=4) 190 | self.channel_atten = ECABlock(channels=960) 191 | 192 | def forward(self, input1, input2, input3, input4): 193 | out1 = self.input1_down(input1) 194 | out2 = self.input2_out(input2) 195 | out3 = self.input3_up(input3) 196 | out4 = self.input4_up(input4) 197 | out = torch.cat([out1, out2, out3, out4], dim=1) 198 | channel_atten_out = self.channel_atten(out) 199 | return channel_atten_out 200 | 201 | 202 | class Bone_multi_scale(nn.Module): 203 | """ 204 | all feature map are sampling to 16*16*16, then concat in channel dimension 205 | finally, execute channel attention to all channels 206 | """ 207 | 208 | def __init__(self): 209 | super(Bone_multi_scale, self).__init__() 210 | self.input1_down = nn.MaxPool3d(kernel_size=4, stride=4) 211 | self.input2_down = nn.MaxPool3d(kernel_size=2, stride=2) 212 | self.input3_out = nn.Identity() 213 | self.input4_up = nn.Upsample(scale_factor=2) 214 | self.channel_atten = ECABlock(channels=960) 215 | 216 | def forward(self, input1, input2, input3, input4): 217 | out1 = self.input1_down(input1) 218 | out2 = self.input2_down(input2) 219 | out3 = self.input3_out(input3) 220 | out4 = self.input4_up(input4) 221 | out = torch.cat([out1, out2, out3, out4], dim=1) 222 | channel_atten_out = self.channel_atten(out) 223 | return channel_atten_out 224 | 225 | 226 | class SpatialAttention(nn.Module): # Spatial Attention Module 227 | def __init__(self): 228 | super(SpatialAttention, self).__init__() 229 | self.conv1 = nn.Conv3d(2, 1, kernel_size=7, padding=3, bias=False) 230 | self.sigmoid = nn.Sigmoid() 231 | 232 | def forward(self, x): 233 | avg_out = torch.mean(x, dim=1, keepdim=True) 234 | max_out, _ = torch.max(x, dim=1, keepdim=True) 235 | out = torch.cat([avg_out, max_out], dim=1) 236 | out = self.conv1(out) 237 | out = self.sigmoid(out) 238 | return out 239 | 240 | 241 | class multi_scale_feature(nn.Module): 242 | def __init__(self, zoom=None, input_feature_size=None, in_multi_size=None, channel=None): 243 | 
super(multi_scale_feature, self).__init__() 244 | self.input_feature_size = input_feature_size 245 | self.in_multi_size = in_multi_size 246 | self.zoom = zoom 247 | self.channel = channel 248 | self.conv = nn.Conv3d(in_channels=960, out_channels=channel, kernel_size=1, stride=1) 249 | self.atten = SpatialAttention() 250 | 251 | if self.zoom == 'UP': 252 | self.k = input_feature_size / in_multi_size 253 | self.up = nn.Upsample(scale_factor=self.k) 254 | elif self.zoom == 'DOWN': 255 | self.avg = nn.AdaptiveAvgPool3d(self.input_feature_size) 256 | elif self.zoom == "None": 257 | self.none = nn.Identity() 258 | 259 | def forward(self, input_feature, in_multi): 260 | if self.zoom == 'UP': 261 | out_up = self.up(in_multi) 262 | out_adjust_channel = self.conv(out_up) 263 | x_add = torch.add(out_adjust_channel, input_feature) 264 | spatial_attention_map = self.atten(x_add) 265 | out = torch.mul(spatial_attention_map, input_feature) 266 | return out 267 | 268 | if self.zoom == 'DOWN': 269 | out_down = self.avg(in_multi) 270 | out_adjust_channel = self.conv(out_down) 271 | x_add = torch.add(out_adjust_channel, input_feature) 272 | spatial_attention_map = self.atten(x_add) 273 | out = torch.mul(spatial_attention_map, input_feature) 274 | return out 275 | if self.zoom == 'None': 276 | out_none = self.none(in_multi) 277 | out_adjust_channel = self.conv(out_none) 278 | x_add = torch.add(out_adjust_channel, input_feature) 279 | spatial_attention_map = self.atten(x_add) 280 | out = torch.mul(spatial_attention_map, input_feature) 281 | return out 282 | 283 | 284 | class Binary_mask(nn.Module): 285 | def __init__(self, num_classes=2): 286 | super(Binary_mask, self).__init__() 287 | self.num_classes = num_classes 288 | self.encoder = Encoder(in_channels=1, init_features=64) 289 | self.decoder = Decoder_block(num_classes=2, init_features=64) 290 | 291 | def forward(self, x): 292 | feat1, feat2, feat3, feat4, feat5 = self.encoder(x) 293 | out = self.decoder(feat1, feat2, feat3, feat4, feat5) 294 | return out 295 | 296 | 297 | class input_enhancement(nn.Module): 298 | def __init__(self): 299 | super(input_enhancement, self).__init__() 300 | self.conv = nn.Conv3d(3, 1, kernel_size=1, stride=1, padding=0) 301 | self.relu = nn.ReLU(inplace=True) 302 | 303 | def forward(self, origin, binary_mask): 304 | x1 = torch.mul(origin, binary_mask) 305 | out = torch.add(x1, origin) 306 | out = torch.cat([x1, origin, out], dim=1) 307 | out = self.conv(out) 308 | # out = self.relu(out) 309 | return out 310 | 311 | 312 | class Tooth_bone_separation(nn.Module): 313 | def __init__(self): 314 | super(Tooth_bone_separation, self).__init__() 315 | self.encoder = Encoder() 316 | 317 | self.Tdecoder = Decoder_block() 318 | self.Bdecoder = Decoder_block() 319 | 320 | self.Tmulti = nn.ModuleList( 321 | [ 322 | multi_scale_feature(zoom='UP', input_feature_size=64, in_multi_size=32, channel=64), 323 | multi_scale_feature(zoom='None', input_feature_size=32, in_multi_size=32, channel=128), 324 | multi_scale_feature(zoom='DOWN', input_feature_size=16, in_multi_size=32, channel=256), 325 | multi_scale_feature(zoom='DOWN', input_feature_size=8, in_multi_size=32, channel=512) 326 | ] 327 | ) 328 | 329 | self.Bmulti = nn.ModuleList( 330 | [ 331 | multi_scale_feature(zoom='UP', input_feature_size=64, in_multi_size=16, channel=64), 332 | multi_scale_feature(zoom='UP', input_feature_size=32, in_multi_size=16, channel=128), 333 | multi_scale_feature(zoom='None', input_feature_size=16, in_multi_size=16, channel=256), 334 | 
multi_scale_feature(zoom='DOWN', input_feature_size=8, in_multi_size=16, channel=512) 335 | ] 336 | ) 337 | 338 | self.Tooth_multi_scale = Tooth_multi_scale() 339 | self.Bone_multi_scale = Bone_multi_scale() 340 | self.Tfinal = nn.Conv3d(2, 3, 1) # background, WT, SM 341 | self.Bfinal = nn.Conv3d(2, 2, 1) # background, AB 342 | 343 | def forward(self, x): 344 | feat1, feat2, feat3, feat4, feat5 = self.encoder(x) 345 | Tooth_multi = self.Tooth_multi_scale(feat1, feat2, feat3, feat4) # (B, 960, 32, 32, 32) 346 | Tooth_feat1 = self.Tmulti[0](input_feature=feat1, in_multi=Tooth_multi) 347 | Tooth_feat2 = self.Tmulti[1](input_feature=feat2, in_multi=Tooth_multi) 348 | Tooth_feat3 = self.Tmulti[2](input_feature=feat3, in_multi=Tooth_multi) 349 | Tooth_feat4 = self.Tmulti[3](input_feature=feat4, in_multi=Tooth_multi) 350 | 351 | Tout1 = self.Tdecoder(Tooth_feat1, Tooth_feat2, Tooth_feat3, Tooth_feat4, feat5) 352 | out_tooth_last = self.Tfinal(Tout1) 353 | 354 | Bone_multi = self.Bone_multi_scale(feat1, feat2, feat3, feat4) # (B, 960, 16, 16, 16) 355 | Bone_feat1 = self.Bmulti[0](input_feature=feat1, in_multi=Bone_multi) 356 | Bone_feat2 = self.Bmulti[1](input_feature=feat2, in_multi=Bone_multi) 357 | Bone_feat3 = self.Bmulti[2](input_feature=feat3, in_multi=Bone_multi) 358 | Bone_feat4 = self.Bmulti[3](input_feature=feat4, in_multi=Bone_multi) 359 | 360 | Bout1 = self.Bdecoder(Bone_feat1, Bone_feat2, Bone_feat3, Bone_feat4, feat5) 361 | out_bone_last = self.Bfinal(Bout1) 362 | 363 | return out_tooth_last, out_bone_last 364 | 365 | 366 | class WTNet(nn.Module): 367 | def __init__(self): 368 | super(WTNet, self).__init__() 369 | self.Binary = Binary_mask() 370 | self.input_enhancement = input_enhancement() 371 | self.TBS = Tooth_bone_separation() 372 | 373 | def forward(self, x): 374 | Binary_out = self.Binary(x) 375 | Binary_map = torch.nn.functional.softmax(Binary_out, dim=1) 376 | Binary_map = torch.argmax(Binary_map, dim=1, keepdim=True) 377 | enhancement = self.input_enhancement(x, Binary_map) 378 | out_tooth_last, out_bone_last = self.TBS(enhancement) 379 | return Binary_out, out_tooth_last, out_bone_last 380 | 381 | 382 | if __name__ == '__main__': 383 | a = torch.rand(size=(2, 1, 64, 64, 64)) 384 | model = WTNet() 385 | Binary_out, out_tooth_last, out_bone_last = model(a) 386 | print(Binary_out.shape) 387 | print(out_tooth_last.shape) 388 | print(out_bone_last.shape) 389 | # print(feat4.shape) 390 | # print(feat5.shape) 391 | # 392 | # eca = SpatialAttention() 393 | # out = eca(feat1) 394 | # print(out.shape) 395 | --------------------------------------------------------------------------------
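For reference, the training objective in `WTNet/train_and_eval.py` sums a Dice loss and a cross-entropy loss over each of the three heads (binary mask, tooth, bone). A minimal, self-contained sketch of that combination, with tensor shapes assumed from the 3D heads in `modles/WTNet.py` (not part of the original repository):

```python
import torch
from torch.nn import CrossEntropyLoss
from monai.losses.dice import DiceLoss

dice_loss = DiceLoss(to_onehot_y=True, softmax=True)  # same construction as in train_WTNet.py
ce_loss = CrossEntropyLoss()

# Stand-in logits/labels for the three heads on a 64^3 patch (batch size 2).
binary_out = torch.randn(2, 2, 64, 64, 64)             # background / foreground
tooth_out = torch.randn(2, 3, 64, 64, 64)              # background / WT / SM
bone_out = torch.randn(2, 2, 64, 64, 64)               # background / AB
labels_binary = torch.randint(0, 2, (2, 1, 64, 64, 64))
labels_tooth = torch.randint(0, 3, (2, 1, 64, 64, 64))
labels_bone = torch.randint(0, 2, (2, 1, 64, 64, 64))

loss = (dice_loss(binary_out, labels_binary) + ce_loss(binary_out, labels_binary.squeeze(1))
        + dice_loss(tooth_out, labels_tooth) + ce_loss(tooth_out, labels_tooth.squeeze(1))
        + dice_loss(bone_out, labels_bone) + ce_loss(bone_out, labels_bone.squeeze(1)))
print(loss.item())
```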