├── .gitignore ├── MeanTeacher.py ├── README.md ├── altrasound.zip ├── dataset.py ├── main.py ├── ramps.py └── unet.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# Model checkpoints written by training (student_weights_<ep>.pth / teacher_weights_<ep>.pth)
*.pth

--------------------------------------------------------------------------------
/MeanTeacher.py:
--------------------------------------------------------------------------------
#!coding:utf-8
import torch
from torch.nn import functional as F

from ramps import exp_rampup


class Trainer:
    """Mean Teacher training of a student network and its EMA teacher.

    Labeled batches are trained with BCE-with-logits plus a weighted MSE
    consistency term between student and teacher outputs; unlabeled batches
    use the consistency term only. After every optimizer step the teacher's
    parameters are moved towards the student's by ``update_ema``.
    """

    def __init__(self, model, ema_model, optimizer, device):
        """
        Args:
            model: student network, the one whose parameters ``optimizer`` steps.
            ema_model: teacher network; its parameters must line up one-to-one
                with ``model``'s (``update_ema`` zips the two parameter lists).
            optimizer: optimizer over the student's parameters.
            device: torch device every input batch is moved to.
        """
        self.model = model
        self.ema_model = ema_model
        self.optimizer = optimizer
        self.ce_loss = torch.nn.BCEWithLogitsLoss()
        self.usp_weight = 30.0          # consistency weight once fully ramped up
        self.ema_decay = 0.97           # upper bound on the EMA coefficient
        self.rampup = exp_rampup(30)    # ramp factor reaches 1.0 at epoch 30
        self.device = device
        self.global_step = 0            # optimizer steps taken so far
        # Epoch index fed to the ramp-up. loop_train() overwrites it every epoch;
        # initialising it here lets train() also be called on its own.
        self.epoch = 0

    def cons_loss(self, logit1, logit2):
        """Return the mean-squared error between two equally shaped output tensors.

        Raises:
            ValueError: if the two tensors differ in shape.
        """
        if logit1.size() != logit2.size():
            raise ValueError("consistency inputs differ in shape: %s vs %s"
                             % (tuple(logit1.size()), tuple(logit2.size())))
        return F.mse_loss(logit1, logit2)

    def _consistency(self, outputs, inputs):
        """Weighted consistency loss between student ``outputs`` and the teacher on ``inputs``."""
        # The teacher is only a target: no autograd graph is built for it.
        with torch.no_grad():
            ema_outputs = self.ema_model(inputs)
        weight = self.rampup(self.epoch) * self.usp_weight
        return self.cons_loss(outputs, ema_outputs) * weight

    def _step(self, loss):
        """Back-propagate ``loss``, step the optimizer, then update the EMA teacher."""
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        # Count every optimizer step (labeled and unlabeled) so the EMA warm-up
        # in update_ema follows the real number of student updates.
        self.global_step += 1
        self.update_ema(self.model, self.ema_model, self.ema_decay, self.global_step)

    def train_iteration(self, data_loader_labeled, data_loader_unlabeled):
        """Run one epoch: one pass over the labeled loader, then one over the unlabeled loader.

        Returns:
            (model, ema_model): the student and the teacher, both updated in place.
        """
        # === training with labels: supervised loss + consistency ===
        for x, y in data_loader_labeled:
            # Student and teacher see the same batch, so it is moved to the device once.
            inputs = x.to(self.device)
            targets = y.to(self.device)

            outputs = self.model(inputs)
            loss = self.ce_loss(outputs, targets)
            print("labeled_loss: %0.3f" % loss.item())

            cons_loss = self._consistency(outputs, inputs)
            print("consistent_loss: %0.3f" % cons_loss.item())

            self._step(loss + cons_loss)

        # === training without labels: consistency only ===
        for x in data_loader_unlabeled:
            inputs = x.to(self.device)
            outputs = self.model(inputs)

            cons_loss = self._consistency(outputs, inputs)
            print("unlabeled_consistent_loss: %0.3f" % cons_loss.item())

            self._step(cons_loss)
        return self.model, self.ema_model

    def train(self, data_loader_labeled, data_loader_unlabeled):
        """Put both networks in training mode and run one epoch of train_iteration()."""
        self.model.train()
        self.ema_model.train()
        with torch.enable_grad():
            return self.train_iteration(data_loader_labeled, data_loader_unlabeled)

    def test(self, model, ema_model, stu_ckpt, t_ckpt, test_data):
        """Load student/teacher checkpoints and print the BCE loss of each on every test batch.

        Both networks are switched to eval mode, so BatchNorm uses its running
        statistics instead of those of a single test image. They are evaluated
        without building autograd graphs.

        Args:
            model: network that ``stu_ckpt`` is loaded into and evaluated.
            ema_model: network that ``t_ckpt`` is loaded into and evaluated.
            stu_ckpt, t_ckpt: state_dict files (e.g. as written by loop_train()).
            test_data: iterable of (image, mask) batches.
        """
        # NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
        model.load_state_dict(torch.load(stu_ckpt, map_location='cpu'))
        ema_model.load_state_dict(torch.load(t_ckpt, map_location='cpu'))
        model.eval()
        ema_model.eval()
        with torch.no_grad():
            for step, (x, y) in enumerate(test_data, start=1):
                print("----- img %d -----" % step)
                inputs = x.to(self.device)
                targets = y.to(self.device)
                student_test_loss = self.ce_loss(model(inputs), targets)
                print("student_test_loss: %0.3f" % student_test_loss.item())
                teacher_test_loss = self.ce_loss(ema_model(inputs), targets)
                print("teacher_test_loss: %0.3f" % teacher_test_loss.item())

    def loop_train(self, epochs, train_data_labeled, train_data_unlabeled, scheduler=None):
        """Train for ``epochs`` epochs, saving both networks after each one.

        Checkpoints go to the working directory as ``student_weights_<ep>.pth``
        and ``teacher_weights_<ep>.pth``. ``scheduler``, if given, is stepped
        once per epoch.
        """
        for ep in range(epochs):
            self.epoch = ep
            print("------ Training epochs: {} ------".format(ep))
            model, ema_model = self.train(train_data_labeled, train_data_unlabeled)
            if scheduler is not None:
                scheduler.step()
            # save model
            torch.save(model.state_dict(), 'student_weights_%d.pth' % ep)
            torch.save(ema_model.state_dict(), 'teacher_weights_%d.pth' % ep)
            print("Model is saved!")

    def update_ema(self, model, ema_model, alpha, global_step):
        """Blend ``model``'s parameters into ``ema_model``'s in place.

        ema = a * ema + (1 - a) * param, with a = min(1 - 1 / (global_step + 1), alpha),
        so early in training the teacher follows the student closely.
        Only parameters are averaged; buffers such as BatchNorm running
        statistics are left untouched.
        """
        alpha = min(1 - 1 / (global_step + 1), alpha)
        with torch.no_grad():
            for ema_param, param in zip(ema_model.parameters(), model.parameters()):
                # Keyword ``alpha=`` form: the positional add_(scalar, tensor)
                # overload is deprecated in PyTorch.
                ema_param.mul_(alpha).add_(param, alpha=1 - alpha)

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# unet-master
U-Net with Mean Teacher semi-supervised training for ultrasound image segmentation.
## data preparation
Project structure:
```
--project
    main.py
    MeanTeacher.py
    ramps.py
    unet.py
    dataset.py
    altrasound.zip
```
Extract `altrasound.zip` so that the directories read by `main.py` exist next to it:
- `altrasound/train_labeled/img/0.jpg … 9.jpg` with masks `altrasound/train_labeled/gt/0.png … 9.png`
- `altrasound/train_unlabeled/img/10.jpg … 49.jpg`
- `altrasound/val/img/0.jpg … 4.jpg` with masks `altrasound/val/gt/0.png … 4.png`

The full dataset is available on request by email: 1901684@stu.neu.edu.cn
## training
```
main.py:
if __name__ == '__main__':
    batch_size = 4
    train(batch_size)
```
## testing
```
main.py:
if __name__ == '__main__':
student_ckpt = "student_weight path" 27 | teacher_ckpt = "teacher_weight path" 28 | test(student_ckpt, teacher_ckpt) 29 | 30 | -------------------------------------------------------------------------------- /altrasound.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangChangqi98/Simple-U-net-with-mean-teacher/e32c97cc191c7d738db9b82ba750ca0f49d4419b/altrasound.zip -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import PIL.Image as Image 3 | import os 4 | 5 | 6 | 7 | def make_dataset(root): 8 | imgs=[] 9 | #训练集中的n张图片 10 | n = 10 11 | for i in range(n): 12 | root1 = root + '/img' 13 | root2 = root + '/gt' 14 | img = os.path.join(root1, "%d.jpg" % i) 15 | mask = os.path.join(root2, "%d.png" % i) 16 | imgs.append((img, mask)) 17 | return imgs 18 | 19 | def make_test_dataset(root): 20 | imgs=[] 21 | #训练集中的n张图片 22 | n = 5 23 | for i in range(n): 24 | root1 = root + '/img' 25 | root2 = root + '/gt' 26 | img = os.path.join(root1, "%d.jpg" % i) 27 | mask = os.path.join(root2, "%d.png" % i) 28 | imgs.append((img, mask)) 29 | return imgs 30 | 31 | def make_dataset_unlabeled(root): 32 | imgs = [] 33 | n = 40 34 | for i in range(n): 35 | root1 = root + '/img' 36 | img = os.path.join(root1, "%d.jpg" % (i + 10)) 37 | imgs.append(img) 38 | return imgs 39 | 40 | class MyDataset(Dataset): 41 | def __init__(self, root, transform=None, target_transform=None): 42 | imgs = make_dataset(root) 43 | self.imgs = imgs 44 | self.transform = transform 45 | self.target_transform = target_transform 46 | 47 | def __getitem__(self, index): 48 | x_path, y_path = self.imgs[index] 49 | img_x = Image.open(x_path) 50 | img_y = Image.open(y_path) 51 | if self.transform is not None: 52 | img_x = self.transform(img_x) 53 | if self.target_transform is not None: 54 
| img_y = self.target_transform(img_y) 55 | 56 | return img_x, img_y 57 | 58 | def __len__(self): 59 | return len(self.imgs) 60 | 61 | class MyDataset_unlabeled(Dataset): 62 | def __init__(self, root, transform=None): 63 | imgs = make_dataset_unlabeled(root) 64 | self.imgs = imgs 65 | self.transform = transform 66 | 67 | def __getitem__(self, index): 68 | x_path = self.imgs[index] 69 | img_x = Image.open(x_path) 70 | if self.transform is not None: 71 | img_x = self.transform(img_x) 72 | 73 | return img_x 74 | 75 | def __len__(self): 76 | return len(self.imgs) 77 | 78 | class MyDataset_test(Dataset): 79 | def __init__(self, root, transform=None, target_transform=None): 80 | imgs = make_test_dataset(root) 81 | self.imgs = imgs 82 | self.transform = transform 83 | self.target_transform = target_transform 84 | 85 | def __getitem__(self, index): 86 | x_path, y_path = self.imgs[index] 87 | img_x = Image.open(x_path) 88 | img_y = Image.open(y_path) 89 | if self.transform is not None: 90 | img_x = self.transform(img_x) 91 | if self.target_transform is not None: 92 | img_y = self.target_transform(img_y) 93 | 94 | return img_x, img_y 95 | 96 | def __len__(self): 97 | return len(self.imgs) -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | from torch import nn, optim 3 | from torchvision.transforms import transforms 4 | from torch.optim import lr_scheduler 5 | 6 | from unet import Unet 7 | from dataset import MyDataset 8 | from dataset import MyDataset_unlabeled 9 | from dataset import MyDataset_test 10 | from MeanTeacher import * 11 | 12 | 13 | # 是否使用cuda 14 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 15 | 16 | # 数据转换 17 | # image转换 18 | x_transforms = transforms.Compose([ 19 | transforms.Resize([256,256]), 20 | transforms.ToTensor(), 21 | transforms.Normalize([0.5, 0.5, 0.5], 
[0.5, 0.5, 0.5]) 22 | ]) 23 | 24 | # mask转换 25 | y_transforms = transforms.Compose([ 26 | transforms.Resize([256,256]), 27 | transforms.ToTensor() 28 | ]) 29 | 30 | def train(batch_size): 31 | student_model = Unet(3, 1).to(device) 32 | teacher_model = Unet(3,1).to(device) 33 | epochs = 50 34 | min_lr = 1e-4 35 | batch_size = batch_size 36 | optimizer = optim.Adam(student_model.parameters()) 37 | # 学习率调节 38 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, 39 | T_max=epochs, 40 | eta_min=min_lr) 41 | liver_dataset_labeled = MyDataset("altrasound/train_labeled",transform=x_transforms,target_transform=y_transforms) 42 | liver_dataset_unlabeled = MyDataset_unlabeled("altrasound/train_unlabeled",transform=x_transforms) 43 | dataloaders_labeled = DataLoader(liver_dataset_labeled, batch_size=batch_size, shuffle=True, num_workers=0) 44 | dataloaders_unlabeled = DataLoader(liver_dataset_unlabeled, batch_size=batch_size, shuffle=True, num_workers=0) 45 | trainer = Trainer(student_model, teacher_model, optimizer, device) 46 | trainer.loop_train(epochs, dataloaders_labeled, dataloaders_unlabeled, scheduler) 47 | 48 | def check(stu_ckpt, t_ckpt): 49 | student_model = Unet(3, 1).to(device) 50 | teacher_model = Unet(3,1).to(device) 51 | liver_dataset_test = MyDataset_test("altrasound/val",transform=x_transforms,target_transform=y_transforms) 52 | dataloaders_test = DataLoader(liver_dataset_test, batch_size=1) 53 | optimizer = optim.Adam(student_model.parameters()) 54 | trainer = Trainer(student_model, teacher_model, optimizer, device) 55 | trainer.test(student_model, teacher_model, stu_ckpt, t_ckpt, dataloaders_test) 56 | 57 | if __name__ == '__main__': 58 | # 设置训练时的batch_size 59 | batch_size = 4 60 | train(batch_size) 61 | 62 | # 训练后进行测试 63 | check_times = 50 64 | for t in range(check_times): 65 | stu_ckpt = 'student_weights_%d.pth' % t 66 | t_ckpt = 'teacher_weights_%d.pth' % t 67 | print('##### check iteration %d #####' % (t + 1)) 68 | check(stu_ckpt, t_ckpt) 69 | 
print("") 70 | 71 | 72 | 73 | -------------------------------------------------------------------------------- /ramps.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def pseudo_rampup(T1, T2): 4 | def warpper(epoch): 5 | if epoch > T1: 6 | alpha = (epoch-T1) / (T2-T1) 7 | if epoch > T2: 8 | alpha = 1.0 9 | else: 10 | alpha = 0.0 11 | return alpha 12 | return warpper 13 | 14 | 15 | def exp_rampup(rampup_length): 16 | """Exponential rampup from https://arxiv.org/abs/1610.02242""" 17 | def warpper(epoch): 18 | if epoch < rampup_length: 19 | epoch = np.clip(epoch, 0.0, rampup_length) 20 | phase = 1.0 - epoch / rampup_length 21 | return float(np.exp(-5.0 * phase * phase)) 22 | else: 23 | return 1.0 24 | return warpper 25 | 26 | 27 | def linear_rampup(rampup_length): 28 | """Linear rampup""" 29 | def warpper(epoch): 30 | if epoch < rampup_length: 31 | return epoch / rampup_length 32 | else: 33 | return 1.0 34 | return warpper 35 | 36 | 37 | def exp_rampdown(rampdown_length, num_epochs): 38 | """Exponential rampdown from https://arxiv.org/abs/1610.02242""" 39 | def warpper(epoch): 40 | if epoch >= (num_epochs - rampdown_length): 41 | ep = .5* (epoch - (num_epochs - rampdown_length)) 42 | return float(np.exp(-(ep * ep) / rampdown_length)) 43 | else: 44 | return 1.0 45 | return warpper 46 | 47 | 48 | def cosine_rampdown(rampdown_length, num_epochs): 49 | """Cosine rampdown from https://arxiv.org/abs/1608.03983""" 50 | def warpper(epoch): 51 | if epoch >= (num_epochs - rampdown_length): 52 | ep = .5* (epoch - (num_epochs - rampdown_length)) 53 | return float(.5 * (np.cos(np.pi * ep / rampdown_length) + 1)) 54 | else: 55 | return 1.0 56 | return warpper 57 | 58 | 59 | def exp_warmup(rampup_length, rampdown_length, num_epochs): 60 | rampup = exp_rampup(rampup_length) 61 | rampdown = exp_rampdown(rampdown_length, num_epochs) 62 | def warpper(epoch): 63 | return rampup(epoch)*rampdown(epoch) 64 | return warpper 
def test_warmup():
    """Print exp_warmup(80, 50, 500) for every epoch 0..499 (manual sanity check)."""
    warmup = exp_warmup(80, 50, 500)
    for ep in range(500):
        print(warmup(ep))

--------------------------------------------------------------------------------
/unet.py:
--------------------------------------------------------------------------------
import torch
from torch import nn


class DoubleConv(nn.Module):
    """Two consecutive (3x3 conv -> BatchNorm -> ReLU) layers.

    ``padding=1`` keeps height and width unchanged; only the channel count
    goes from ``in_ch`` to ``out_ch``.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, input):
        return self.conv(input)


class Unet(nn.Module):
    """U-Net with four max-pool down-sampling stages, four transposed-conv
    up-sampling stages with skip connections, and a final 1x1 convolution.

    ``forward`` returns raw logits (no sigmoid). Submodule attribute names
    determine the state_dict keys of saved checkpoints, so keep them stable.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()

        self.conv1 = DoubleConv(in_ch, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        self.conv5 = DoubleConv(512, 1024)
        self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.conv6 = DoubleConv(1024, 512)
        self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv7 = DoubleConv(512, 256)
        self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv8 = DoubleConv(256, 128)
        self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv9 = DoubleConv(128, 64)
        self.conv10 = nn.Conv2d(64, out_ch, 1)

    @staticmethod
    def _pad_to(up, skip):
        """Zero-pad ``up`` so its height and width (dims 2 and 3) equal ``skip``'s.

        MaxPool2d floors odd sizes, so when the input height or width is not
        divisible by 16 an up-sampled map can be one pixel smaller than its skip
        connection, and ``torch.cat`` would fail. For sizes that line up this
        returns ``up`` unchanged.
        """
        dh = skip.size(2) - up.size(2)
        dw = skip.size(3) - up.size(3)
        if dh == 0 and dw == 0:
            return up
        # F.pad order for the last two dims: (left, right, top, bottom).
        return nn.functional.pad(up, [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2])

    def forward(self, x):
        # Encoder
        c1 = self.conv1(x)
        c2 = self.conv2(self.pool1(c1))
        c3 = self.conv3(self.pool2(c2))
        c4 = self.conv4(self.pool3(c3))
        c5 = self.conv5(self.pool4(c4))
        # Decoder: up-sample, concatenate the matching skip connection, convolve.
        c6 = self.conv6(torch.cat([self._pad_to(self.up6(c5), c4), c4], dim=1))
        c7 = self.conv7(torch.cat([self._pad_to(self.up7(c6), c3), c3], dim=1))
        c8 = self.conv8(torch.cat([self._pad_to(self.up8(c7), c2), c2], dim=1))
        c9 = self.conv9(torch.cat([self._pad_to(self.up9(c8), c1), c1], dim=1))
        return self.conv10(c9)

--------------------------------------------------------------------------------