├── .idea └── vcs.xml ├── LICENSE ├── README.md ├── data ├── train │ ├── 000.png │ └── 000_mask.png └── val │ ├── 000.png │ └── 000_mask.png ├── dataset.py ├── main.py └── unet.py /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 PJ-Javis 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # unet liver 2 | Unet network for liver CT image segmentation 3 | ## data preparation 4 | structure of project 5 | ``` 6 | --project 7 | main.py 8 | --data 9 | --train 10 | --val 11 | ``` 12 | data and trained weight link: https://pan.baidu.com/s/1dgGnsfoSmL1lbOUwyItp6w code: 17yr 13 | 14 | all dataset you can access from: https://competitions.codalab.org/competitions/15595 15 | 16 | ## training 17 | ``` 18 | python main.py train 19 | ``` 20 | 21 | ## testing 22 | load the last saved weight 23 | ``` 24 | python main.py test --ckpt=weights_19.pth 25 | ``` 26 | ---- 27 | 28 | ## 数据准备 29 | 项目文件分布如下 30 | ``` 31 | --project 32 | main.py 33 | --data 34 | --train 35 | --val 36 | ``` 37 | 38 | 数据和权重可以使用百度云下载 链接: 39 | 40 | 链接: https://pan.baidu.com/s/1dgGnsfoSmL1lbOUwyItp6w 提取码: 17yr 41 | 42 | 全部数据集: https://competitions.codalab.org/competitions/15595 43 | 44 | ## 模型训练 45 | ``` 46 | python main.py train 47 | ``` 48 | 49 | ## 测试模型训练 50 | 加载权重,默认保存最后一个权重 51 | ``` 52 | python main.py test --ckpt=weights_19.pth 53 | ``` 54 | ## 多类别 55 | 修改2个地方即可:unet最后一层的通道数设置为类别数;损失函数使用CrossEntropyLoss 56 | ```python 57 | batch_size,img_size,num_classes=2,3,4 58 | #model = Unet(3, num_classes) 59 | criterion = nn.CrossEntropyLoss() 60 | #assume the pred is the output of the model 61 | pred=torch.rand(batch_size,num_classes,img_size,img_size) 62 | target=torch.randint(num_classes,(batch_size,img_size,img_size)) 63 | loss=criterion(pred,target) 64 | ``` 65 | 66 | ## Demo 67 | ![liver](https://img-blog.csdn.net/20180508083935908) 68 | -------------------------------------------------------------------------------- /data/train/000.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/JavisPeng/u_net_liver/a1b9553d8ba8c6e5a3d4c5fabd387e130e60a072/data/train/000.png -------------------------------------------------------------------------------- /data/train/000_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JavisPeng/u_net_liver/a1b9553d8ba8c6e5a3d4c5fabd387e130e60a072/data/train/000_mask.png -------------------------------------------------------------------------------- /data/val/000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JavisPeng/u_net_liver/a1b9553d8ba8c6e5a3d4c5fabd387e130e60a072/data/val/000.png -------------------------------------------------------------------------------- /data/val/000_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JavisPeng/u_net_liver/a1b9553d8ba8c6e5a3d4c5fabd387e130e60a072/data/val/000_mask.png -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import PIL.Image as Image 3 | import os 4 | 5 | 6 | def make_dataset(root): 7 | imgs=[] 8 | n=len(os.listdir(root))//2 9 | for i in range(n): 10 | img=os.path.join(root,"%03d.png"%i) 11 | mask=os.path.join(root,"%03d_mask.png"%i) 12 | imgs.append((img,mask)) 13 | return imgs 14 | 15 | 16 | class LiverDataset(Dataset): 17 | def __init__(self, root, transform=None, target_transform=None): 18 | imgs = make_dataset(root) 19 | self.imgs = imgs 20 | self.transform = transform 21 | self.target_transform = target_transform 22 | 23 | def __getitem__(self, index): 24 | x_path, y_path = self.imgs[index] 25 | img_x = Image.open(x_path) 26 | img_y = Image.open(y_path) 27 | if self.transform is not None: 28 | img_x = self.transform(img_x) 29 | if 
self.target_transform is not None: 30 | img_y = self.target_transform(img_y) 31 | return img_x, img_y 32 | 33 | def __len__(self): 34 | return len(self.imgs) 35 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | from torch.utils.data import DataLoader 4 | from torch import nn, optim 5 | from torchvision.transforms import transforms 6 | from unet import Unet 7 | from dataset import LiverDataset 8 | 9 | 10 | # 是否使用cuda 11 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 12 | 13 | x_transforms = transforms.Compose([ 14 | transforms.ToTensor(), 15 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 16 | ]) 17 | 18 | # mask只需要转换为tensor 19 | y_transforms = transforms.ToTensor() 20 | 21 | def train_model(model, criterion, optimizer, dataload, num_epochs=20): 22 | for epoch in range(num_epochs): 23 | print('Epoch {}/{}'.format(epoch, num_epochs - 1)) 24 | print('-' * 10) 25 | dt_size = len(dataload.dataset) 26 | epoch_loss = 0 27 | step = 0 28 | for x, y in dataload: 29 | step += 1 30 | inputs = x.to(device) 31 | labels = y.to(device) 32 | # zero the parameter gradients 33 | optimizer.zero_grad() 34 | # forward 35 | outputs = model(inputs) 36 | loss = criterion(outputs, labels) 37 | loss.backward() 38 | optimizer.step() 39 | epoch_loss += loss.item() 40 | print("%d/%d,train_loss:%0.3f" % (step, (dt_size - 1) // dataload.batch_size + 1, loss.item())) 41 | print("epoch %d loss:%0.3f" % (epoch, epoch_loss/step)) 42 | torch.save(model.state_dict(), 'weights_%d.pth' % epoch) 43 | return model 44 | 45 | #训练模型 46 | def train(args): 47 | model = Unet(3, 1).to(device) 48 | batch_size = args.batch_size 49 | criterion = nn.BCEWithLogitsLoss() 50 | optimizer = optim.Adam(model.parameters()) 51 | liver_dataset = LiverDataset("data/train",transform=x_transforms,target_transform=y_transforms) 52 | 
dataloaders = DataLoader(liver_dataset, batch_size=batch_size, shuffle=True, num_workers=4) 53 | train_model(model, criterion, optimizer, dataloaders) 54 | 55 | #显示模型的输出结果 56 | def test(args): 57 | model = Unet(3, 1) 58 | model.load_state_dict(torch.load(args.ckpt,map_location='cpu')) 59 | liver_dataset = LiverDataset("data/val", transform=x_transforms,target_transform=y_transforms) 60 | dataloaders = DataLoader(liver_dataset, batch_size=1) 61 | model.eval() 62 | import matplotlib.pyplot as plt 63 | plt.ion() 64 | with torch.no_grad(): 65 | for x, _ in dataloaders: 66 | y=model(x).sigmoid() 67 | img_y=torch.squeeze(y).numpy() 68 | plt.imshow(img_y) 69 | plt.pause(0.01) 70 | plt.show() 71 | 72 | 73 | if __name__ == '__main__': 74 | #参数解析 75 | parse=argparse.ArgumentParser() 76 | parse = argparse.ArgumentParser() 77 | parse.add_argument("action", type=str, help="train or test") 78 | parse.add_argument("--batch_size", type=int, default=8) 79 | parse.add_argument("--ckpt", type=str, help="the path of model weight file") 80 | args = parse.parse_args() 81 | 82 | if args.action=="train": 83 | train(args) 84 | elif args.action=="test": 85 | test(args) 86 | -------------------------------------------------------------------------------- /unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | class DoubleConv(nn.Module): 5 | def __init__(self, in_ch, out_ch): 6 | super(DoubleConv, self).__init__() 7 | self.conv = nn.Sequential( 8 | nn.Conv2d(in_ch, out_ch, 3, padding=1), 9 | nn.BatchNorm2d(out_ch), 10 | nn.ReLU(inplace=True), 11 | nn.Conv2d(out_ch, out_ch, 3, padding=1), 12 | nn.BatchNorm2d(out_ch), 13 | nn.ReLU(inplace=True) 14 | ) 15 | 16 | def forward(self, input): 17 | return self.conv(input) 18 | 19 | 20 | class Unet(nn.Module): 21 | def __init__(self,in_ch,out_ch): 22 | super(Unet, self).__init__() 23 | 24 | self.conv1 = DoubleConv(in_ch, 64) 25 | self.pool1 = nn.MaxPool2d(2) 26 | 
self.conv2 = DoubleConv(64, 128) 27 | self.pool2 = nn.MaxPool2d(2) 28 | self.conv3 = DoubleConv(128, 256) 29 | self.pool3 = nn.MaxPool2d(2) 30 | self.conv4 = DoubleConv(256, 512) 31 | self.pool4 = nn.MaxPool2d(2) 32 | self.conv5 = DoubleConv(512, 1024) 33 | self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2) 34 | self.conv6 = DoubleConv(1024, 512) 35 | self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2) 36 | self.conv7 = DoubleConv(512, 256) 37 | self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2) 38 | self.conv8 = DoubleConv(256, 128) 39 | self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2) 40 | self.conv9 = DoubleConv(128, 64) 41 | self.conv10 = nn.Conv2d(64,out_ch, 1) 42 | 43 | def forward(self,x): 44 | c1=self.conv1(x) 45 | p1=self.pool1(c1) 46 | c2=self.conv2(p1) 47 | p2=self.pool2(c2) 48 | c3=self.conv3(p2) 49 | p3=self.pool3(c3) 50 | c4=self.conv4(p3) 51 | p4=self.pool4(c4) 52 | c5=self.conv5(p4) 53 | up_6= self.up6(c5) 54 | merge6 = torch.cat([up_6, c4], dim=1) 55 | c6=self.conv6(merge6) 56 | up_7=self.up7(c6) 57 | merge7 = torch.cat([up_7, c3], dim=1) 58 | c7=self.conv7(merge7) 59 | up_8=self.up8(c7) 60 | merge8 = torch.cat([up_8, c2], dim=1) 61 | c8=self.conv8(merge8) 62 | up_9=self.up9(c8) 63 | merge9=torch.cat([up_9,c1],dim=1) 64 | c9=self.conv9(merge9) 65 | c10=self.conv10(c9) 66 | #out = nn.Sigmoid()(c10) 67 | return c10 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | --------------------------------------------------------------------------------