├── __init__.py ├── utils1 ├── __init__.py ├── distributions.py ├── ramps.py ├── ResampleLoss.py ├── losses.py ├── loss.py └── statistic.py ├── dataset ├── __init__.py ├── make_dataset.py ├── LeftAtrium.py └── pancreas.py ├── preprocess ├── __init__.py ├── pancreas_preprocess.py ├── preprocess_utils.py └── io_.py ├── data_lists ├── pancreas │ ├── train_lab.txt │ ├── test.txt │ ├── train_unlab.txt │ └── train_whole.txt └── LA_dataset │ ├── train_lab.txt │ ├── test.txt │ ├── train_unlab.txt │ └── test_whole.list ├── README.md ├── aleatoric.py ├── test_model_panc.py ├── test_model_LA.py ├── test_util.py ├── vnet.py ├── train_panc.py └── train_LA.py /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils1/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /preprocess/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data_lists/pancreas/train_lab.txt: -------------------------------------------------------------------------------- 1 | data0001 2 | data0002 3 | data0003 4 | data0004 5 | data0005 6 | data0006 7 | data0007 8 | data0008 9 | data0009 10 | data0010 11 | data0011 12 | data0012 13 | -------------------------------------------------------------------------------- /data_lists/pancreas/test.txt: -------------------------------------------------------------------------------- 1 | data0064 2 | data0065 3 | data0066 4 | data0067 5 | data0068 6 | data0069 7 | data0071 8 | data0072 9 | data0073 10 | data0074 11 | data0075 12 | data0076 13 | data0077 14 | data0078 15 | data0079 16 | data0080 17 | data0081 18 | data0082 19 | -------------------------------------------------------------------------------- /data_lists/LA_dataset/train_lab.txt: -------------------------------------------------------------------------------- 1 | 06SR5RBREL16DQ6M8LWS 2 | 0RZDK210BSMWAA6467LU 3 | 1D7CUD1955YZPGK8XHJX 4 | 1GU15S0GJ6PFNARO469W 5 | 1MHBF3G6DCPWHSKG7XCP 6 | 23X6SY44VT9KFHR7S7OC 7 | 2XL5HSFSE93RMOJDRGR4 8 | 38CWS74285MFGZZXR09Z 9 | 3C2QTUNI0852XV7ZH4Q1 10 | 3DA0T2V6JJ2NLUAV6FWM 11 |
4498CA6DZWELOXCBRYRF 12 | 45C45I6IXAFGNRO067W9 13 | 4CHFJGF6ZUM7CMZTNFQF 14 | 4EPVTT1HPA8U60CDUKXE 15 | 57SGAJMLCTCH92QUA0EE 16 | 5BHTH9RHH3PQT913I59W -------------------------------------------------------------------------------- /data_lists/LA_dataset/test.txt: -------------------------------------------------------------------------------- 1 | UPT6DX9IQY9JAZ7HJKA7 2 | UTBUJIWZMKP64E3N73YC 3 | ULHWPWKKLTE921LQLH1P 4 | V0MZOWJ6MU3RMRCV9EXR 5 | VDOF02M8ZHEAADFMS6NP 6 | VG4C826RAAKVMV9BQLVD 7 | VIXBEFTNVHZWKAKURJBN 8 | VQ2L3WM8KEVF6L44E6G9 9 | WBG9WYZ1B25WDT5WAT8T 10 | WMDG2EFA6L2SNDZXIRU0 11 | WNPKE0W404QE9AELX1LR 12 | WSJB9P4JCXUVHBOYFVWL 13 | WW8F5CO4S4K5IM5Z7EXX 14 | X18LU5AOBNNDMLTA0JZL 15 | XYDLYJ5CS19FDBVLJIPI 16 | Y7ZU0B2APPF54WG6PDMF 17 | YDKD1HVHSME6NVMA8I39 18 | Z9GMG63CJLL0VW893BB1 19 | ZIJLJAVQV3FJ6JSQOH1E 20 | ZQPMJ4XEC5A4BISD45P1 21 | -------------------------------------------------------------------------------- /data_lists/pancreas/train_unlab.txt: -------------------------------------------------------------------------------- 1 | data0013 2 | data0014 3 | data0015 4 | data0016 5 | data0017 6 | data0018 7 | data0019 8 | data0020 9 | data0021 10 | data0022 11 | data0023 12 | data0024 13 | data0026 14 | data0027 15 | data0028 16 | data0029 17 | data0030 18 | data0031 19 | data0032 20 | data0033 21 | data0034 22 | data0035 23 | data0036 24 | data0037 25 | data0038 26 | data0039 27 | data0040 28 | data0041 29 | data0042 30 | data0043 31 | data0044 32 | data0045 33 | data0046 34 | data0047 35 | data0048 36 | data0049 37 | data0050 38 | data0051 39 | data0052 40 | data0053 41 | data0054 42 | data0055 43 | data0056 44 | data0057 45 | data0058 46 | data0059 47 | data0060 48 | data0061 49 | data0062 50 | data0063 51 | -------------------------------------------------------------------------------- /data_lists/pancreas/train_whole.txt: -------------------------------------------------------------------------------- 1 | data0001 2 | data0002 3 | data0003 4 | data0004 5 | data0005 6 | data0006 7 | data0007 8 | data0008 9 | data0009 10 | data0010 11 | data0011 12 | data0012 13 | data0013 14 | data0014 15 | data0015 16 | data0016 17 | data0017 18 | data0018 19 | data0019 20 | data0020 21 | data0021 22 | data0022 23 | data0023 24 | data0024 25 | data0026 26 | data0027 27 | data0028 28 | data0029 29 | data0030 30 | data0031 31 | data0032 32 | data0033 33 | data0034 34 | data0035 35 | data0036 36 | data0037 37 | data0038 38 | data0039 39 | data0040 40 | data0041 41 | data0042 42 | data0043 43 | data0044 44 | data0045 45 | data0046 46 | data0047 47 | data0048 48 | data0049 49 | data0050 50 | data0051 51 | data0052 52 | data0053 53 | data0054 54 | data0055 55 | data0056 56 | data0057 57 | data0058 58 | data0059 59 | data0060 60 | data0061 61 | data0062 62 | data0063 63 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # README 2 | 3 | This is the code repository for the paper "A Novel Semi-supervised Training Framework Guided by Two Sources of Uncertainties for Medical Image Segmentation". 4 | 5 | Training details can be found in the `train` function of `train_panc.py`. 6 | 7 | Folder `data_lists` specifies the train/test splits for the pancreas and left atrium datasets. 8 | 9 | Folder `preprocess` contains the code used to preprocess the pancreas dataset. If you are using the raw pancreas dataset, you may need to preprocess the data first by running `pancreas_preprocess.py`.
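As a concrete (illustrative) sequence: edit the hard-coded paths at the bottom of `preprocess/pancreas_preprocess.py`, run `python preprocess/pancreas_preprocess.py` to generate the cropped `.h5` volumes, then run `python train_panc.py` to train; `python test_model_panc.py --gpu 0` evaluates a checkpoint (`--gpu` and `--al_weight` are the flags its argument parser exposes).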
10 | 11 | Folder `trained_models` contains the model whose results are presented in our paper. 12 | 13 | Files with the suffix "panc" use the pancreas dataset, while files with the suffix "LA" use the left atrium dataset. 14 | 15 | If you'd like to train the model from scratch, you can run either `train_panc.py` or `train_LA.py`. You may need to uncomment the line invoking the pretrain function to obtain a pretrained model first, prepare the corresponding datasets, and modify the dataset paths in the code. 16 | 17 | Trained models for the pancreas and left atrium datasets are available on [Google Drive](https://drive.google.com/drive/folders/1lwbHrgltbqhMf0WEHt25c8MtmX0ENa2h?usp=sharing) 18 | -------------------------------------------------------------------------------- /dataset/make_dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import h5py, os 3 | import torch, cv2 4 | import numpy as np 5 | from torch.utils.data import Dataset, DataLoader 6 | 7 | from dataset.pancreas import * 8 | 9 | 10 | class make_data_3d(Dataset): 11 | def __init__(self, imgs, plab1, mask1, labs, crop_size = (96, 96, 96)): 12 | self.img = [img.cpu().squeeze().numpy() for img in imgs] 13 | self.plab1 = [np.squeeze(lab.cpu().numpy()) for lab in plab1] 14 | self.mask1 = [np.squeeze(mask.cpu().numpy()) for mask in mask1] 15 | self.lab = [np.squeeze(lab.cpu().numpy()) for lab in labs] 16 | self.num = len(self.img) 17 | self.tr_transform = Compose([ 18 | # RandomRotFlip(), 19 | CenterCrop(crop_size), 20 | # RandomNoise(), 21 | ToTensor() 22 | ]) 23 | 24 | def __getitem__(self, idx): 25 | samples = self.img[idx], self.plab1[idx], self.mask1[idx], self.lab[idx] 26 | samples = self.tr_transform(samples) 27 | imgs, plab1, mask1, labs = samples 28 | return imgs, plab1.long(), mask1.float(), labs.long() 29 | 30 | def __len__(self): 31 | return self.num 32 | -------------------------------------------------------------------------------- /utils1/distributions.py: -------------------------------------------------------------------------------- 1 | import torch.distributions as td 2 | import torch 3 | from typing import Tuple 4 | 5 | 6 | class ReshapedDistribution(td.Distribution): 7 | def __init__(self, base_distribution: td.Distribution, new_event_shape: Tuple[int, ...]): 8 | super().__init__(batch_shape=base_distribution.batch_shape, event_shape=new_event_shape) 9 | self.base_distribution = base_distribution 10 | self.new_shape = base_distribution.batch_shape + new_event_shape 11 | 12 | @property 13 | def support(self): 14 | return self.base_distribution.support 15 | 16 | @property 17 | def arg_constraints(self): 18 | return self.base_distribution.arg_constraints # dict attribute on td.Distribution, not a method; calling it raised TypeError 19 | 20 | @property 21 | def mean(self): 22 | return self.base_distribution.mean.view(self.new_shape) 23 | 24 | @property 25 | def variance(self): 26 | return self.base_distribution.variance.view(self.new_shape) 27 | 28 | def rsample(self, sample_shape=torch.Size()): 29 | return self.base_distribution.rsample(sample_shape).view(sample_shape + self.new_shape) 30 | 31 | def log_prob(self, value): 32 | return self.base_distribution.log_prob(value.view(self.batch_shape + (-1,))) 33 | 34 | def entropy(self): 35 | return self.base_distribution.entropy() 36 | -------------------------------------------------------------------------------- /utils1/ramps.py:
-------------------------------------------------------------------------------- 1 | # Copyright (c) 2018, Curious AI Ltd. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Functions for ramping hyperparameters up or down 9 | 10 | Each function takes the current training step or epoch, and the 11 | ramp length in the same format, and returns a multiplier between 12 | 0 and 1. 13 | """ 14 | 15 | 16 | import numpy as np 17 | 18 | 19 | def sigmoid_rampup(current, rampup_length): 20 | """Exponential rampup from https://arxiv.org/abs/1610.02242""" 21 | if rampup_length == 0: 22 | return 1.0 23 | else: 24 | current = np.clip(current, 0.0, rampup_length) 25 | phase = 1.0 - current / rampup_length 26 | return float(np.exp(-5.0 * phase * phase)) 27 | 28 | 29 | def linear_rampup(current, rampup_length): 30 | """Linear rampup""" 31 | assert current >= 0 and rampup_length >= 0 32 | if current >= rampup_length: 33 | return 1.0 34 | else: 35 | return current / rampup_length 36 | 37 | 38 | def cosine_rampdown(current, rampdown_length): 39 | """Cosine rampdown from https://arxiv.org/abs/1608.03983""" 40 | assert 0 <= current <= rampdown_length 41 | return float(.5 * (np.cos(np.pi * current / rampdown_length) + 1)) 42 | -------------------------------------------------------------------------------- /data_lists/LA_dataset/train_unlab.txt: -------------------------------------------------------------------------------- 1 | 5FKQL4K14KCB72Y8YMC2 2 | 5HH0WPWIY06DLAFOBQ4M 3 | 5QFK2PMHNX7UALK52NNA 4 | 5UB5KFD2PK38Z4LS6W80 5 | 6799D6LEBH3NSRV1KH27 6 | 78NJ5YFQF72BGC8RO51C 7 | 7FUCNXB39F78WTOP5K71 8 | 8GYK8A9MBRC9TV0FVSRA 9 | 8M99G0JLAXG9GLPV0O8G 10 | 8RE90C8H5DKF4V6HO8UU 11 | 8ZG2TRZ81MAWHZPN9KKG 12 | 9DCM2IB45SK6YKQNYUQY 13 | 9DHWWP5Y66VDMPXISZ13 14 | 9DQYTIU00I4JC0OEOKQQ 15 | A11O45O3NAXWM7T2H8CH 16 | A4R1S23KR0KU2WSYHK2X 17 | A5RNNK0A891WUSC2V624 18 | AT5CRO5JUDBWD4RUPXSQ 19 | BNK95S2SJXEGSW7VAKYU 20 | BXJWOUYP2J3EN4U92517 21 | BYSRSI3H4YTWKMM3MADP 22 | BZUFJX66T0W6ZPVTL9DU 23 | CB5P5W7X310NIIVU7UZV 24 | CBIJFVZ5L9BS0LKWE8YL 25 | CCGAKN4EDT72KC8TTJ76 26 | CLXFYOBQDCVXQ9P7YC07 27 | CMPXO4J23G58J53Q98SZ 28 | CZPMV6KWZ4I7IJJP9FOK 29 | DLKXBV73A55ZTSZ0QQI2 30 | DQ5UYBGR5QP6L692QSG6 31 | DYXSCIWHLSUOZIDDSZ40 32 | E2ZMO66WGS74UKXTZPPQ 33 | EJ5V7SPR4961JWD6SS8V 34 | FGM5NIWN3URY4HF4WNUW 35 | GSC9KNY0VEZXFSGWNF25 36 | HVE7DR3CUA2IM3RC6OMA 37 | HZZ4O0BRKF8S0YX3NNF7 38 | I2VZ7N8H9QYNYT7ZZF1Y 39 | IDWWHGWJ5STOQXSDT6GU 40 | IIY6TYJMTJIZRIZLB9YW 41 | IJJY51YW3W4YJJ7DTVTK 42 | IQYKPTWXVV9H0IHB8YXC 43 | JEC6HJ7SQJXBKVREX03F 44 | JGFOLWJF7YCYD8DPHQNH 45 | K32FD6LRSUSSXGS1YUOX 46 | KM5RYAMP4P4ZP6XWP3Q2 47 | KSNYHUBHHUJTYJ14UQZR 48 | LH4FVU3TQDEC87YGN6FL 49 | LJSDNMND9SHKM7Q4IRHJ 50 | MFTDVMBWFNQ3F5KHBRDR 51 | MJHV7F65TB2A76CQLOC3 52 | MVKIPGBKTNSENNP1S4HB 53 | O5TSIKRD4AIB8K84WIR9 54 | OIRDLE32TXZX942FVZMM 55 | P1OTI3IWJUIB5NRLULLH 56 | PVNXUK681N9BY14K4Z86 57 | Q0MEX9ZIKAGJORSPLQ3Y 58 | Q7J0WYM695R9MA285ZW0 59 | QZC1W0FNR19KJFLOCFLH 60 | R8ER97O9UUN77C02VE2J 61 | RSZY41MT2FGDKHWWL5L2 62 | SN4LF8SGBSRQUPTDSX78 63 | TDDI6L3Y0L9VVFP9MNFS 64 | UZUZZT2W9IUSHL6ASOX3 65 | -------------------------------------------------------------------------------- /utils1/ResampleLoss.py: -------------------------------------------------------------------------------- 1 
| import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | import math 5 | 6 | 7 | class CrossEntropyLoss(nn.CrossEntropyLoss): 8 | def __init__(self, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean'): 9 | super().__init__(weight, size_average, ignore_index, reduce, reduction) 10 | 11 | def forward(self, logits: torch.tensor, target: torch.tensor, **kwargs): 12 | return super().forward(logits, target) 13 | 14 | 15 | class ResampleLossMCIntegral(nn.Module): 16 | def __init__(self, num_mc_samples: int = 1): 17 | super().__init__() 18 | self.num_mc_samples = num_mc_samples 19 | 20 | @staticmethod 21 | def fixed_re_parametrization_trick(dist, num_samples): 22 | assert num_samples % 2 == 0 23 | samples = dist.rsample((num_samples // 2,)) 24 | mean = dist.mean.unsqueeze(0) 25 | samples = samples - mean 26 | return torch.cat([samples, -samples]) + mean 27 | 28 | def forward(self, logits, target, distribution, **kwargs): 29 | batch_size = logits.shape[0] 30 | num_classes = logits.shape[1] 31 | assert num_classes >= 2 # not implemented for binary case with implied background 32 | # logit_sample = distribution.rsample((self.num_mc_samples,)) 33 | logit_sample = self.fixed_re_parametrization_trick(distribution, self.num_mc_samples) 34 | target = target.unsqueeze(1) 35 | target = target.expand((self.num_mc_samples,) + target.shape) 36 | 37 | flat_size = self.num_mc_samples * batch_size 38 | logit_sample = logit_sample.view((flat_size, num_classes, -1)) 39 | target = target.reshape((flat_size, -1)) 40 | 41 | #log_prob = F.cross_entropy(logit_sample, target, reduction='none').view((self.num_mc_samples, batch_size, -1)) 42 | #loglikelihood = torch.mean(torch.logsumexp(torch.sum(log_prob, dim=-1), dim=0) - math.log(self.num_mc_samples)) 43 | #loss = -loglikelihood 44 | log_prob = F.cross_entropy(logit_sample, target) 45 | return log_prob -------------------------------------------------------------------------------- /data_lists/LA_dataset/test_whole.list: -------------------------------------------------------------------------------- 1 | 06SR5RBREL16DQ6M8LWS 2 | 0RZDK210BSMWAA6467LU 3 | 1D7CUD1955YZPGK8XHJX 4 | 1GU15S0GJ6PFNARO469W 5 | 1MHBF3G6DCPWHSKG7XCP 6 | 23X6SY44VT9KFHR7S7OC 7 | 2XL5HSFSE93RMOJDRGR4 8 | 38CWS74285MFGZZXR09Z 9 | 3C2QTUNI0852XV7ZH4Q1 10 | 3DA0T2V6JJ2NLUAV6FWM 11 | 4498CA6DZWELOXCBRYRF 12 | 45C45I6IXAFGNRO067W9 13 | 4CHFJGF6ZUM7CMZTNFQF 14 | 4EPVTT1HPA8U60CDUKXE 15 | 57SGAJMLCTCH92QUA0EE 16 | 5BHTH9RHH3PQT913I59W 17 | 5FKQL4K14KCB72Y8YMC2 18 | 5HH0WPWIY06DLAFOBQ4M 19 | 5QFK2PMHNX7UALK52NNA 20 | 5UB5KFD2PK38Z4LS6W80 21 | 6799D6LEBH3NSRV1KH27 22 | 78NJ5YFQF72BGC8RO51C 23 | 7FUCNXB39F78WTOP5K71 24 | 8GYK8A9MBRC9TV0FVSRA 25 | 8M99G0JLAXG9GLPV0O8G 26 | 8RE90C8H5DKF4V6HO8UU 27 | 8ZG2TRZ81MAWHZPN9KKG 28 | 9DCM2IB45SK6YKQNYUQY 29 | 9DHWWP5Y66VDMPXISZ13 30 | 9DQYTIU00I4JC0OEOKQQ 31 | A11O45O3NAXWM7T2H8CH 32 | A4R1S23KR0KU2WSYHK2X 33 | A5RNNK0A891WUSC2V624 34 | AT5CRO5JUDBWD4RUPXSQ 35 | BNK95S2SJXEGSW7VAKYU 36 | BXJWOUYP2J3EN4U92517 37 | BYSRSI3H4YTWKMM3MADP 38 | BZUFJX66T0W6ZPVTL9DU 39 | CB5P5W7X310NIIVU7UZV 40 | CBIJFVZ5L9BS0LKWE8YL 41 | CCGAKN4EDT72KC8TTJ76 42 | CLXFYOBQDCVXQ9P7YC07 43 | CMPXO4J23G58J53Q98SZ 44 | CZPMV6KWZ4I7IJJP9FOK 45 | DLKXBV73A55ZTSZ0QQI2 46 | DQ5UYBGR5QP6L692QSG6 47 | DYXSCIWHLSUOZIDDSZ40 48 | E2ZMO66WGS74UKXTZPPQ 49 | EJ5V7SPR4961JWD6SS8V 50 | FGM5NIWN3URY4HF4WNUW 51 | GSC9KNY0VEZXFSGWNF25 52 | HVE7DR3CUA2IM3RC6OMA 53 | HZZ4O0BRKF8S0YX3NNF7 54 | I2VZ7N8H9QYNYT7ZZF1Y 55 | IDWWHGWJ5STOQXSDT6GU 56 | 
IIY6TYJMTJIZRIZLB9YW 57 | IJJY51YW3W4YJJ7DTVTK 58 | IQYKPTWXVV9H0IHB8YXC 59 | JEC6HJ7SQJXBKVREX03F 60 | JGFOLWJF7YCYD8DPHQNH 61 | K32FD6LRSUSSXGS1YUOX 62 | KM5RYAMP4P4ZP6XWP3Q2 63 | KSNYHUBHHUJTYJ14UQZR 64 | LH4FVU3TQDEC87YGN6FL 65 | LJSDNMND9SHKM7Q4IRHJ 66 | MFTDVMBWFNQ3F5KHBRDR 67 | MJHV7F65TB2A76CQLOC3 68 | MVKIPGBKTNSENNP1S4HB 69 | O5TSIKRD4AIB8K84WIR9 70 | OIRDLE32TXZX942FVZMM 71 | P1OTI3IWJUIB5NRLULLH 72 | PVNXUK681N9BY14K4Z86 73 | Q0MEX9ZIKAGJORSPLQ3Y 74 | Q7J0WYM695R9MA285ZW0 75 | QZC1W0FNR19KJFLOCFLH 76 | R8ER97O9UUN77C02VE2J 77 | RSZY41MT2FGDKHWWL5L2 78 | SN4LF8SGBSRQUPTDSX78 79 | TDDI6L3Y0L9VVFP9MNFS 80 | UZUZZT2W9IUSHL6ASOX3 81 | UPT6DX9IQY9JAZ7HJKA7 82 | UTBUJIWZMKP64E3N73YC 83 | ULHWPWKKLTE921LQLH1P 84 | V0MZOWJ6MU3RMRCV9EXR 85 | VDOF02M8ZHEAADFMS6NP 86 | VG4C826RAAKVMV9BQLVD 87 | VIXBEFTNVHZWKAKURJBN 88 | VQ2L3WM8KEVF6L44E6G9 89 | WBG9WYZ1B25WDT5WAT8T 90 | WMDG2EFA6L2SNDZXIRU0 91 | WNPKE0W404QE9AELX1LR 92 | WSJB9P4JCXUVHBOYFVWL 93 | WW8F5CO4S4K5IM5Z7EXX 94 | X18LU5AOBNNDMLTA0JZL 95 | XYDLYJ5CS19FDBVLJIPI 96 | Y7ZU0B2APPF54WG6PDMF 97 | YDKD1HVHSME6NVMA8I39 98 | Z9GMG63CJLL0VW893BB1 99 | ZIJLJAVQV3FJ6JSQOH1E 100 | ZQPMJ4XEC5A4BISD45P1 101 | -------------------------------------------------------------------------------- /aleatoric.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.distributions as td 4 | import torch.nn.functional as F 5 | from utils1.distributions import ReshapedDistribution 6 | 7 | 8 | 9 | # FEATURE_MAPS = (30, 30, 40, 40, 40, 40, 50, 50) 10 | # f_out is the number of channels for the input 11 | f_out = 16 12 | 13 | 14 | # Check the influence of this feature map setting 15 | 16 | class StochasticDeepMedic(nn.Module): 17 | def __init__(self, 18 | num_classes, 19 | rank: int = 10, 20 | epsilon=1e-5, 21 | diagonal=False, 22 | dim=3): 23 | super().__init__() 24 | self.dim = dim 25 | conv_fn = nn.Conv3d if self.dim == 3 else nn.Conv2d 26 | self.rank = rank 27 | self.num_classes = num_classes 28 | self.epsilon = epsilon 29 | self.diagonal = diagonal # whether to use only the diagonal (independent normals) 30 | self.mean_l = conv_fn(f_out, num_classes, kernel_size=(1, ) * self.dim) 31 | self.log_cov_diag_l = conv_fn(f_out, num_classes, kernel_size=(1, ) * self.dim) 32 | self.cov_factor_l = conv_fn(f_out, num_classes * rank, kernel_size=(1, ) * self.dim) 33 | 34 | def forward(self, input, sampling_mask): 35 | logits = F.relu(input) 36 | batch_size = logits.shape[0] 37 | event_shape = (self.num_classes,) + logits.shape[2:] 38 | 39 | mean = self.mean_l(logits) 40 | cov_diag = self.log_cov_diag_l(logits).exp() + self.epsilon 41 | mean = mean.view((batch_size, -1)) 42 | cov_diag = cov_diag.view((batch_size, -1)) 43 | 44 | cov_factor = self.cov_factor_l(logits) 45 | cov_factor = cov_factor.view((batch_size, self.rank, self.num_classes, -1)) 46 | cov_factor = cov_factor.flatten(2, 3) 47 | cov_factor = cov_factor.transpose(1, 2) 48 | 49 | # covariance in the background tends to blow up to infinity, hence set to 0 outside the ROI 50 | #mask = kwargs['sampling_mask'] 51 | mask = sampling_mask 52 | mask = mask.unsqueeze(1).expand((batch_size, self.num_classes) + mask.shape[1:]).reshape(batch_size, -1) 53 | cov_factor = cov_factor * mask.unsqueeze(-1) 54 | cov_diag = cov_diag * mask + self.epsilon 55 | 56 | if self.diagonal: 57 | base_distribution = td.Independent(td.Normal(loc=mean, scale=torch.sqrt(cov_diag)), 1) 58 | else: 59 | try: 60 | base_distribution = td.LowRankMultivariateNormal(loc=mean, cov_factor=cov_factor, cov_diag=cov_diag) 61 | except Exception: 62 | print('Covariance not invertible; using independent normals for this batch!') 63 | base_distribution = td.Independent(td.Normal(loc=mean, scale=torch.sqrt(cov_diag)), 1) 64 | 65 | distribution = ReshapedDistribution(base_distribution, event_shape) 66 | 67 | shape = (batch_size,) + event_shape 68 | logit_mean = mean.view(shape) 69 | cov_diag_view = cov_diag.view(shape).detach() 70 | cov_factor_view = cov_factor.transpose(2, 1).view((batch_size, self.num_classes * self.rank) + event_shape[1:]).detach() 71 | 72 | output_dict = {'logit_mean': logit_mean.detach(), 73 | 'cov_diag': cov_diag_view, 74 | 'cov_factor': cov_factor_view, 75 | 'distribution': distribution} 76 | 77 | return logit_mean, output_dict 78 | --------------------------------------------------------------------------------
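A minimal usage sketch (shapes and wiring assumed, not taken from the repo) of how this head pairs with `ResampleLossMCIntegral` from `utils1/ResampleLoss.py`: the head maps a 16-channel feature map (`f_out`) to a low-rank Gaussian over voxel logits, and the loss Monte-Carlo-integrates the cross-entropy under that distribution.

import torch
from aleatoric import StochasticDeepMedic
from utils1.ResampleLoss import ResampleLossMCIntegral

head = StochasticDeepMedic(num_classes=2)              # input must have f_out = 16 channels
feats = torch.randn(1, 16, 16, 16, 16)                 # assumed backbone feature map (B, f_out, D, H, W)
roi_mask = torch.ones(1, 16, 16, 16)                   # sampling mask: 1 inside the ROI, 0 outside
target = torch.randint(0, 2, (1, 16, 16, 16))          # assumed voxel labels

logit_mean, out = head(feats, roi_mask)                # out['distribution'] is a ReshapedDistribution
criterion = ResampleLossMCIntegral(num_mc_samples=20)  # must be even for the fixed reparametrization trick
loss = criterion(logit_mean, target, out['distribution'])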
/preprocess/pancreas_preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import h5py 3 | import traceback 4 | from pathlib import Path 5 | import sys 6 | 7 | sys.path.append("..") 8 | from preprocess import io_ 9 | from preprocess.preprocess_utils import * 10 | import numpy as np 11 | 12 | 13 | def split_file_path(root, data_list_path, train_lab_num, train_unlab_num, test_num, shuffle=False): 14 | Path(data_list_path).mkdir(exist_ok=True) 15 | 16 | def save_path(paths, filename): 17 | with open(filename, 'w') as f: 18 | for p in paths: 19 | f.write(p + '\n') 20 | return filename 21 | 22 | root, data_list_path = Path(root), Path(data_list_path) 23 | files = [] 24 | for f in root.iterdir(): 25 | files.append(f.name) 26 | 27 | files.sort(key=lambda x: int(x.split('.')[0])) 28 | 29 | if test_num is not None: 30 | assert len(files) == (train_lab_num + train_unlab_num + test_num), 'Total_files : {}, current files : {}'.format( 31 | len(files), train_lab_num + train_unlab_num + test_num) 32 | else: 33 | test_num = len(files) - (train_lab_num + train_unlab_num) 34 | 35 | if shuffle: 36 | np.random.shuffle(files) 37 | 38 | train_lab_paths = files[:train_lab_num] 39 | train_unlab_paths = files[train_lab_num:train_lab_num + train_unlab_num] 40 | test_paths = files[-test_num:] 41 | print('Generated labeled train {}, unlabeled train {}, test {}'.format(len(train_lab_paths), len(train_unlab_paths), len(test_paths))) 42 | return save_path(train_lab_paths, data_list_path / 'train_lab.txt'), \ 43 | save_path(train_unlab_paths, data_list_path / 'train_unlab.txt'), \ 44 | save_path(test_paths, data_list_path / 'test.txt') 45 | 46 | 47 | def normalize(data): 48 | # normalized_data = (data - data.mean()) / (data.std() + 1e-10) 49 | normalized_data = (data - data.min()) / (data.max() - data.min()) 50 | normalized_data = normalized_data # * 2 - 1 51 | return normalized_data 52 | 53 | 54 | def save_to_h5(img, mask, filename): 55 | hf = h5py.File(filename, 'w') 56 | hf.create_dataset('image', data=img) 57 | hf.create_dataset('label', data=mask) 58 | hf.close() 59 | 60 | 61 | root = 'data/' 62 | save_to = root 63 | DCM_data = True 64 | 65 | 66 | def process_case(case_folders): 67 | try: 68 | print("yes") 69 | for case_folder in case_folders: 70 | print("path: ", case_folder) 71 | if DCM_data: # if downloaded DCM data 72 | if not case_folder.is_dir(): 73 | return 74 | img = [] 75 | for inner_folder in case_folder.iterdir(): 76 | for folder in inner_folder.iterdir(): 77 | folders = list(folder.iterdir()) 78 | folders.sort() 79 | for slice_path in folders: 80 | slice, spacing, affine_pre = io_.read_nii(slice_path) 81 | img.append(slice) 82 | img = np.concatenate(img) # depth x H x W 83 | print(case_folder) 84 | case_idx = str(case_folder)[-4:] 85 | label_path = root / 'Pancreas-CT-Label' / ('label' + case_idx + '.nii.gz') 86 | img = img.swapaxes(2, 1).swapaxes(1, 0).swapaxes(1, 2) # make depth last 87 | mask = io_.read_img(label_path) 88 | else: # if downloaded nii data 89 | img, spacing, affine_pre = io_.read_nii(case_folder) 90 | print(case_folder) 91 | case_idx = str(case_folder).split('.')[0][-4:] 92 | label_path = root / 'label' / ('label' + case_idx + '.nii.gz') 93 | mask, _, _ = io_.read_nii(label_path) 94 | 95 | assert mask.shape == img.shape, "{}, {}".format(mask.shape, img.shape) 96 | 97 | # show_graphs(img[:, :, 100:116].clip(-125, 275), figsize=(20, 20)), show_graphs(mask[100:116], figsize=(20, 20)) 98 | 99 | # resample to [1, 1, 1] 100 | target_spacing = (1, 1, 1) 101 | # change spacing of depth 102 | spacing = (spacing[1], spacing[1], spacing[1]) 103 | affine_pre = io_.make_affine2(spacing) 104 | resampled_img, affine = resample_volume_nib(img, affine_pre, spacing, target_spacing, mask=False) 105 | resampled_mask, affine = resample_volume_nib(mask, affine_pre, spacing, target_spacing, mask=True) 106 | # resampled_img, resampled_mask = img, mask 107 | 108 | # clip to [-125, 275] 109 | min_clip, max_clip = -125, 275 110 | resampled_img = resampled_img.clip(min_clip, max_clip) 111 | resampled_img = normalize(resampled_img) 112 | 113 | # crop image 114 | bbox = get_bbox_3d(resampled_mask) 115 | offset = 25 116 | bbox = expand_bbox(resampled_img, bbox, expand_size=(offset, offset, offset), min_crop_size=(96, 96, 96)) 117 | cropped_img = crop_img(resampled_img, bbox, min_crop_size=(96, 96, 96)) 118 | cropped_mask = crop_img(resampled_mask, bbox, min_crop_size=(96, 96, 96)) 119 | 120 | # show_graphs(cropped_img[100:116], figsize=(10, 10)), show_graphs(cropped_mask[100:116], figsize=(10, 10), filename='mask.png') 121 | save_to_h5(cropped_img, cropped_mask, save_to + case_idx + '.h5') 122 | print('saved : {}, resampled shape : {}, cropped shape : {}'.format(case_idx, resampled_img.shape, cropped_img.shape)) 123 | except Exception as e: 124 | print("No") 125 | print(e) 126 | # traceback.print_tb(e) 127 | traceback.print_exc() 128 | 129 | 130 | def generate_h5_data(original_pancreas_path, save_path): 131 | global root, save_to 132 | root = Path(original_pancreas_path) 133 | path = Path(root) / 'Pancreas-CT' 134 | save_to = save_path 135 | paths = list(path.iterdir()) 136 | paths.sort() 137 | print(paths) 138 | Path(save_path).mkdir(exist_ok=True) 139 | # io_.multiprocess_task(process_case, paths, cpu_num=16) 140 | process_case(paths) 141 | 142 | 143 | if __name__ == '__main__': 144 | path_to_save_generated_data = 'data' 145 | generate_h5_data('../../pancreas', path_to_save_generated_data) 146 | --------------------------------------------------------------------------------
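A quick sanity check for the generated volumes (the path is hypothetical; `save_to_h5` above writes one `image` and one `label` dataset per case):

import h5py

with h5py.File('data/0001.h5', 'r') as f:  # assumed output of generate_h5_data
    img, lab = f['image'][:], f['label'][:]
    # volumes are resampled to 1 mm spacing, clipped to [-125, 275], min-max
    # normalized, and cropped around the label bbox to at least (96, 96, 96) voxels
    assert img.shape == lab.shape and min(img.shape) >= 96
    print(img.shape, float(img.min()), float(img.max()))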
/preprocess/preprocess_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.ndimage import gaussian_filter 3 | import skimage.measure as skmeasure 4 | import scipy.ndimage as ndi 5 | import torch 6 | 7 | 8 | class BBoxException(Exception): 9 | pass 10 | 11 | 12 | def get_non_empty_min_max_idx_along_axis(mask, axis): 13 | """ 14 | Get non zero min and max index along given axis.
15 | :param mask: ndarray or torch.Tensor 16 | :param axis: axis index along which to search 17 | :return: (min_idx, max_idx), with max_idx already incremented by 1 18 | """ 19 | if isinstance(mask, torch.Tensor): 20 | # for torch.Tensor, nonzero() returns an (N, ndim) index matrix; take the given axis column 21 | nonzero_idx = (mask != 0).nonzero() 22 | if len(nonzero_idx) == 0: 23 | min = max = 0 24 | else: 25 | max = nonzero_idx[:, axis].max() 26 | min = nonzero_idx[:, axis].min() 27 | elif isinstance(mask, np.ndarray): 28 | nonzero_idx = (mask != 0).nonzero() 29 | if len(nonzero_idx[axis]) == 0: 30 | min = max = 0 31 | else: 32 | max = nonzero_idx[axis].max() 33 | min = nonzero_idx[axis].min() 34 | else: 35 | raise BBoxException("Wrong type") 36 | max += 1 37 | return min, max 38 | 39 | 40 | def get_bbox_3d(mask): 41 | """ Input : [D, H, W] , output : ((min_x, max_x), (min_y, max_y), (min_z, max_z)) 42 | Return non zero value's min and max index for a mask 43 | If no value exists, an array of all zero returns 44 | :param mask: numpy of [D, H, W] 45 | :return: 46 | """ 47 | assert len(mask.shape) == 3 48 | min_z, max_z = get_non_empty_min_max_idx_along_axis(mask, 0) 49 | min_y, max_y = get_non_empty_min_max_idx_along_axis(mask, 1) 50 | min_x, max_x = get_non_empty_min_max_idx_along_axis(mask, 2) 51 | 52 | return np.array(((min_x, max_x + 1), 53 | (min_y, max_y + 1), 54 | (min_z, max_z + 1))) 55 | 56 | 57 | def pad_bbox(bbox, min_bbox, max_img): 58 | """ 59 | :param bbox: ndarray ((min_x, max_x), (min_y, max_y), (min_z, max_z)) 60 | :param min_bbox: list (d, h, w) 61 | :param max_img: list (d, h, w), image shape 62 | :return: 63 | """ 64 | min_bbox = list(min_bbox) 65 | change_min_bbox = False 66 | for i, (min_x, max_img_x) in enumerate(zip(min_bbox, max_img)): 67 | if min_x > max_img_x: 68 | min_bbox[i] = max_img[i] 69 | change_min_bbox = True 70 | 71 | if change_min_bbox: 72 | print('min box {} is larger than max image size {}'.format(min_bbox, max_img)) 73 | 74 | # z first 75 | bbox = np.array(bbox)[::-1, :] 76 | result_bbox = [] 77 | for (min_x, max_x), min_size, max_size in zip(bbox, min_bbox, max_img): 78 | width = max_x - min_x 79 | if width < min_size: 80 | padding = min_size - width 81 | padding_left = padding // 2 82 | padding_right = padding - padding_left 83 | 84 | # find a best place to pad img 85 | while True: 86 | if (min_x - padding_left) < 0 and (max_x + padding_right) > max_size: 87 | # pad to img size 88 | padding_left = min_x 89 | padding_right = max_size - max_x 90 | break 91 | elif (min_x - padding_left) < 0: 92 | # right shift pad 93 | padding_left -= 1 94 | padding_right += 1 95 | elif (max_x + padding_right) > max_size: 96 | # left shift pad 97 | padding_left += 1 98 | padding_right -= 1 99 | else: 100 | # no operation to pad 101 | break 102 | min_x -= padding_left 103 | max_x += padding_right 104 | result_bbox.append((min_x, max_x)) 105 | # x first 106 | return np.array(result_bbox)[::-1, :] 107 | 108 | 109 | def expand_bbox(img, bbox, expand_size, min_crop_size): 110 | img_z, img_y, img_x = img.shape 111 | 112 | # expand [[154 371 15] [439 499 68]] 113 | bbox[:, 0] -= expand_size[::-1] # min (x, y, z) 114 | bbox[:, 1] += expand_size[::-1] # max (x, y, z) 115 | # prevent out of range 116 | bbox[0, :] = np.clip(bbox[0, :], 0, img_x) 117 | bbox[1, :] = np.clip(bbox[1, :], 0, img_y) 118 | bbox[2, :] = np.clip(bbox[2, :], 0, img_z) 119 | 120 | # expand, then pad 121 | bbox = pad_bbox(bbox, min_crop_size, img.shape) 122 | return bbox 123 | 124 | 125 |
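# A small usage sketch of the helpers above (synthetic data; `crop_img` is defined just below):
#
#   mask = np.zeros((128, 160, 160)); mask[40:80, 50:100, 60:110] = 1
#   bbox = get_bbox_3d(mask)                                    # ((min_x, max_x), (min_y, max_y), (min_z, max_z))
#   bbox = expand_bbox(mask, bbox, (25, 25, 25), (96, 96, 96))  # grow by an offset, clip to the image, pad to min size
#   patch = crop_img(mask, bbox, min_crop_size=(96, 96, 96))    # every side of `patch` is >= 96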
126 | def crop_img(img, bbox, min_crop_size): 127 | """ Crop image with expanded bbox. 128 | :param img: ndarray (D, H, W) 129 | :param bbox: ndarray ((min_x, max_x), (min_y, max_y), (min_z, max_z)) 130 | :param min_crop_size: list (d, h ,w) 131 | :return: 132 | """ 133 | 134 | # extract coords 135 | (min_x, max_x), (min_y, max_y), (min_z, max_z) = bbox 136 | 137 | # crop 138 | cropped_img = img[min_z:max_z, min_y:max_y, min_x:max_x] 139 | 140 | padding = [] 141 | for i, (cropped_width, min_width) in enumerate(zip(cropped_img.shape, min_crop_size)): 142 | if cropped_width < min_width: 143 | padding.append((0, min_width - cropped_width)) 144 | else: 145 | padding.append((0, 0)) 146 | padding = np.array(padding).astype(int) # np.int was removed in NumPy 1.24 147 | cropped_img = np.pad(cropped_img, padding, mode='constant', constant_values=0) 148 | return cropped_img 149 | 150 | 151 | from dipy.align.reslice import reslice 152 | def resample_volume_nib(np_data, affine, spacing_old, spacing_new=(1., 1., 1.), mask=False): 153 | """Resample 3D image(trilinear) and mask(nearest) to (1., 1., 1.) spacing. 154 | It seems to work better than the method above, judging from the generated images. 155 | 156 | :param np_data: ndarray, channel first 157 | :param affine: the affine returned from nibabel 158 | :param spacing_old: current spacing 159 | :param spacing_new: target spacing, default is (1., 1., 1.) 160 | :param mask: if set True, use nearest instead of trilinear interpolation 161 | :return: 162 | resampled data : ndarray 163 | affine : the modified affine. 164 | """ 165 | if not mask: 166 | # trilinear 167 | resampled_data, affine = reslice(np_data, affine, spacing_old, spacing_new, order=1) 168 | else: 169 | # nearest 170 | resampled_data, affine = reslice(np_data, affine, spacing_old, spacing_new, order=0) 171 | return resampled_data, affine -------------------------------------------------------------------------------- /test_model_panc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import time 4 | from pathlib import Path 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.optim as optim 11 | from tensorboardX import SummaryWriter 12 | from torch.utils.data import DataLoader 13 | from tqdm import tqdm 14 | import logging 15 | import utils1.loss 16 | from dataset.make_dataset import make_data_3d 17 | from dataset.pancreas import Pancreas 18 | from test_util import test_calculate_metric 19 | from utils1 import statistic, ramps 20 | from vnet import VNet 21 | from aleatoric import StochasticDeepMedic 22 | import logging 23 | import sys 24 | import argparse 25 | 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument('--al_weight', type=float, default=0.5, help='the weight of aleatoric uncertainty loss') 28 | parser.add_argument('--gpu', type=str, default='1', help='GPU to use') 29 | 30 | args = parser.parse_args() 31 | 32 | al_weight = args.al_weight 33 | 34 | res_dir = 'test_result/' 35 | 36 | if not os.path.exists(res_dir): 37 | os.makedirs(res_dir) 38 | 39 | logging.basicConfig(filename=res_dir + "log.txt", level=logging.INFO, 40 | format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S') 41 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) 42 | logging.info('New Exp :') 43 | 44 | # 2,1 45 | # Because of the aleatoric loss part added later, setting multiple GPUs here causes problems 46 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 47 | # Parameters 48 | num_class = 2 49 | base_dim = 8 50 | 51 | batch_size = 2 52 | lr = 1e-3 53 | beta1, beta2 = 0.5, 0.999 54 | 55 | # log settings & test
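# For reference, get_current_consistency_weight(epoch) below wraps
# utils1.ramps.sigmoid_rampup(epoch, consistency_rampup): it starts near
# exp(-5) ~= 0.0067 at epoch 0 and rises smoothly to 1.0 once
# epoch >= consistency_rampup (40), so consistency terms are phased in gradually.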
56 | pretraining_epochs = 60 57 | self_training_epochs = 301 58 | thres = 0.5 59 | pretrain_save_step = 5 60 | st_save_step = 10 61 | pred_step = 10 62 | 63 | r18 = False 64 | dataset_name = 'pancreas' 65 | data_root = '../pancreas/Pancreas-processed' 66 | 67 | cost_num = 3 68 | 69 | alpha = 0.99 70 | consistency = 1 71 | consistency_rampup = 40 72 | 73 | 74 | class AverageMeter(object): 75 | """Computes and stores the average and current value""" 76 | 77 | def __init__(self): 78 | self.reset() 79 | 80 | def reset(self): 81 | self.val = 0 82 | self.avg = 0 83 | self.sum = 0 84 | self.count = 0 85 | return self 86 | 87 | def update(self, val, n=1): 88 | self.val = val 89 | self.sum += val 90 | self.count += n 91 | self.avg = self.sum / self.count 92 | return self 93 | 94 | 95 | def set_random_seed(seed): 96 | random.seed(seed) 97 | np.random.seed(seed) 98 | torch.manual_seed(seed) 99 | torch.cuda.manual_seed(seed) 100 | 101 | 102 | def get_current_consistency_weight(epoch): 103 | # Consistency ramp-up from https://arxiv.org/abs/1610.02242 104 | return ramps.sigmoid_rampup(epoch, consistency_rampup) 105 | 106 | 107 | def update_ema_variables(model, ema_model, alpha, global_step): 108 | # Use the true average until the exponential average is more correct 109 | alpha = min(1 - 1 / (global_step + 1), alpha) 110 | for ema_param, param in zip(ema_model.parameters(), model.parameters()): 111 | ema_param.data.mul_(alpha).add_((1 - alpha) * param.data) 112 | 113 | 114 | def create_model(ema=False): 115 | net = nn.DataParallel(VNet(n_branches=4)) 116 | model = net.cuda() 117 | if ema: 118 | for param in model.parameters(): 119 | param.detach_() 120 | return model 121 | 122 | 123 | def get_model_and_dataloader(): 124 | """Net & optimizer""" 125 | net = create_model() 126 | ema_net = create_model(ema=True).cuda() 127 | optimizer = optim.Adam(net.parameters(), lr=lr, betas=(beta1, beta2)) 128 | # optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9, nesterov=True, weight_decay=1e-4) 129 | 130 | """Loading Dataset""" 131 | logging.info("loading dataset") 132 | 133 | trainset_lab = Pancreas(data_root, dataset_name, split='train_lab', require_mask=True) 134 | lab_loader = DataLoader(trainset_lab, batch_size=batch_size, shuffle=False, num_workers=0) 135 | 136 | trainset_unlab = Pancreas(data_root, dataset_name, split='train_unlab', no_crop=True) 137 | unlab_loader = DataLoader(trainset_unlab, batch_size=1, shuffle=False, num_workers=0) 138 | 139 | testset = Pancreas(data_root, dataset_name, split='test') 140 | test_loader = DataLoader(testset, batch_size=1, shuffle=False, num_workers=0) 141 | return net, ema_net, optimizer, lab_loader, unlab_loader, test_loader 142 | 143 | 144 | def save_net_opt(net, optimizer, path, epoch): 145 | state = { 146 | 'net': net.state_dict(), 147 | 'opt': optimizer.state_dict(), 148 | 'epoch': epoch, 149 | } 150 | torch.save(state, str(path)) 151 | 152 | 153 | def load_net_opt(net, optimizer, path): 154 | state = torch.load(str(path)) 155 | net.load_state_dict(state['net']) 156 | optimizer.load_state_dict(state['opt']) 157 | logging.info('Loaded from {}'.format(path)) 158 | 159 | 160 | 161 | def count_param(model): 162 | param_count = 0 163 | for param in model.parameters(): 164 | param_count += param.view(-1).size()[0] 165 | return param_count 166 | 167 | 168 | 169 | 170 | @torch.no_grad() 171 | def test(net, val_loader, maxdice=0, save_result=False, test_save_path='./save'): 172 | metrics = test_calculate_metric(net, val_loader.dataset, save_result=save_result, 
test_save_path=test_save_path) 173 | val_dice = metrics[0] 174 | 175 | if val_dice > maxdice: 176 | maxdice = val_dice 177 | max_flag = True 178 | else: 179 | max_flag = False 180 | logging.info('Evaluation : val_dice: %.4f, val_maxdice: %.4f' % (val_dice, maxdice)) 181 | return val_dice, maxdice, max_flag 182 | 183 | 184 | if __name__ == '__main__': 185 | # set_random_seed(1337) 186 | net, ema_net, optimizer, lab_loader, unlab_loader, test_loader = get_model_and_dataloader() 187 | # load model 188 | # net.load_state_dict(torch.load(res_dir + '/model/best.pth')) 189 | model_path = Path('/home/xiangjinyi/semi_supervised/alnet/result/single_loss_ce_12/pretrain_con_5.0') 190 | 191 | load_net_opt(net, optimizer, model_path / '55.pth') 192 | #load_net_opt(ema_net, optimizer, pretrained_path / 'best.pth') 193 | # pretrain(net, ema_net, optimizer, lab_loader, unlab_loader, test_loader, start_epoch=1) 194 | test(net, test_loader) 195 | 196 | #t_train(net, ema_net, optimizer, lab_loader, unlab_loader, test_loader) 197 | 198 | logging.info(count_param(net)) 199 | -------------------------------------------------------------------------------- /test_model_LA.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import time 4 | from pathlib import Path 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.optim as optim 11 | from tensorboardX import SummaryWriter 12 | from torch.utils.data import DataLoader 13 | from tqdm import tqdm 14 | import logging 15 | import utils1.loss 16 | from dataset.make_dataset import make_data_3d 17 | from dataset.LeftAtrium import LAHeart 18 | from test_util import test_calculate_metric_LA 19 | from utils1 import statistic, ramps 20 | from vnet import VNet 21 | from aleatoric import StochasticDeepMedic 22 | import logging 23 | import sys 24 | import argparse 25 | 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument('--al_weight', type=float, default=0.5, help='the weight of aleatoric uncertainty loss') 28 | parser.add_argument('--gpu', type=str, default='1', help='GPU to use') 29 | 30 | args = parser.parse_args() 31 | 32 | al_weight = args.al_weight 33 | 34 | res_dir = 'test_result/' 35 | 36 | if not os.path.exists(res_dir): 37 | os.makedirs(res_dir) 38 | 39 | logging.basicConfig(filename=res_dir + "log.txt", level=logging.INFO, 40 | format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S') 41 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) 42 | logging.info('New Exp :') 43 | 44 | # 2,1 45 | # Because of the aleatoric loss part added later, setting multiple GPUs here causes problems 46 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 47 | # Parameters 48 | num_class = 2 49 | base_dim = 8 50 | 51 | batch_size = 2 52 | lr = 1e-3 53 | beta1, beta2 = 0.5, 0.999 54 | 55 | # log settings & test 56 | pretraining_epochs = 60 57 | self_training_epochs = 301 58 | thres = 0.5 59 | pretrain_save_step = 5 60 | st_save_step = 10 61 | pred_step = 10 62 | 63 | r18 = False 64 | split_name = 'LA_dataset' 65 | data_root = '../LA_dataset' 66 | 67 | cost_num = 3 68 | 69 | alpha = 0.99 70 | consistency = 1 71 | consistency_rampup = 40 72 | 73 | 74 | class AverageMeter(object): 75 | """Computes and stores the average and current value""" 76 | 77 | def __init__(self): 78 | self.reset() 79 | 80 | def reset(self): 81 | self.val = 0 82 | self.avg = 0 83 | self.sum = 0 84 | self.count = 0 85 | return self 86 | 87 | def update(self, val, n=1): 88 | self.val = val 89 |
self.sum += val 90 | self.count += n 91 | self.avg = self.sum / self.count 92 | return self 93 | 94 | 95 | def set_random_seed(seed): 96 | random.seed(seed) 97 | np.random.seed(seed) 98 | torch.manual_seed(seed) 99 | torch.cuda.manual_seed(seed) 100 | 101 | 102 | def get_current_consistency_weight(epoch): 103 | # Consistency ramp-up from https://arxiv.org/abs/1610.02242 104 | return ramps.sigmoid_rampup(epoch, consistency_rampup) 105 | 106 | 107 | def update_ema_variables(model, ema_model, alpha, global_step): 108 | # Use the true average until the exponential average is more correct 109 | alpha = min(1 - 1 / (global_step + 1), alpha) 110 | for ema_param, param in zip(ema_model.parameters(), model.parameters()): 111 | ema_param.data.mul_(alpha).add_((1 - alpha) * param.data) 112 | 113 | 114 | def create_model(ema=False): 115 | net = nn.DataParallel(VNet(n_branches=4)) 116 | model = net.cuda() 117 | if ema: 118 | for param in model.parameters(): 119 | param.detach_() 120 | return model 121 | 122 | 123 | def get_model_and_dataloader(): 124 | """Net & optimizer""" 125 | net = create_model() 126 | ema_net = create_model(ema=True).cuda() 127 | optimizer = optim.Adam(net.parameters(), lr=lr, betas=(beta1, beta2)) 128 | # optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9, nesterov=True, weight_decay=1e-4) 129 | 130 | """Loading Dataset""" 131 | logging.info("loading dataset") 132 | 133 | trainset_lab = LAHeart(data_root, split_name, split='train_lab', require_mask=True) 134 | lab_loader = DataLoader(trainset_lab, batch_size=batch_size, shuffle=False, num_workers=0) 135 | 136 | trainset_unlab = LAHeart(data_root, split_name, split='train_unlab', no_crop=True) 137 | unlab_loader = DataLoader(trainset_unlab, batch_size=1, shuffle=False, num_workers=0) 138 | 139 | testset = LAHeart(data_root, split_name, split='test') 140 | test_loader = DataLoader(testset, batch_size=1, shuffle=False, num_workers=0) 141 | return net, ema_net, optimizer, lab_loader, unlab_loader, test_loader 142 | 143 | 144 | def save_net_opt(net, optimizer, path, epoch): 145 | state = { 146 | 'net': net.state_dict(), 147 | 'opt': optimizer.state_dict(), 148 | 'epoch': epoch, 149 | } 150 | torch.save(state, str(path)) 151 | 152 | 153 | def load_net_opt(net, optimizer, path): 154 | state = torch.load(str(path)) 155 | net.load_state_dict(state['net']) 156 | optimizer.load_state_dict(state['opt']) 157 | logging.info('Loaded from {}'.format(path)) 158 | 159 | 160 | 161 | def count_param(model): 162 | param_count = 0 163 | for param in model.parameters(): 164 | param_count += param.view(-1).size()[0] 165 | return param_count 166 | 167 | 168 | 169 | 170 | @torch.no_grad() 171 | def test(net, val_loader, maxdice=0, save_result=False, test_save_path='./save'): 172 | metrics = test_calculate_metric_LA(net, val_loader.dataset, save_result=save_result, test_save_path=test_save_path) 173 | val_dice = metrics[0] 174 | 175 | if val_dice > maxdice: 176 | maxdice = val_dice 177 | max_flag = True 178 | else: 179 | max_flag = False 180 | logging.info('Evaluation : val_dice: %.4f, val_maxdice: %.4f' % (val_dice, maxdice)) 181 | return val_dice, maxdice, max_flag 182 | 183 | 184 | if __name__ == '__main__': 185 | # set_random_seed(1337) 186 | net, ema_net, optimizer, lab_loader, unlab_loader, test_loader = get_model_and_dataloader() 187 | # load model 188 | # net.load_state_dict(torch.load(res_dir + '/model/best.pth')) 189 | model_path = Path('/home/xiangjinyi/semi_supervised/alnet/LA_result/LA_0.8_al/con_5.0_consistency_1_VNet3') 190 | 191 | 
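# best.pth was written by save_net_opt, so it stores {'net', 'opt', 'epoch'};
# load_net_opt below restores the network and optimizer states from it.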
load_net_opt(net, optimizer, model_path / 'best.pth') 192 | #load_net_opt(ema_net, optimizer, pretrained_path / 'best.pth') 193 | # pretrain(net, ema_net, optimizer, lab_loader, unlab_loader, test_loader, start_epoch=1) 194 | 195 | test(net, test_loader, save_result=True, test_save_path="/home/xiangjinyi/semi_supervised/alnet/LA_image_result_2/") 196 | 197 | #t_train(net, ema_net, optimizer, lab_loader, unlab_loader, test_loader) 198 | 199 | logging.info(count_param(net)) 200 | -------------------------------------------------------------------------------- /utils1/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | import numpy as np 4 | import torch.nn as nn 5 | from torch.autograd import Variable 6 | 7 | 8 | def dice_loss(score, target): 9 | target = target.float() 10 | smooth = 1e-5 11 | intersect = torch.sum(score * target) 12 | y_sum = torch.sum(target * target) 13 | z_sum = torch.sum(score * score) 14 | loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth) 15 | loss = 1 - loss 16 | return loss 17 | 18 | 19 | def dice_loss1(score, target): 20 | target = target.float() 21 | smooth = 1e-5 22 | intersect = torch.sum(score * target) 23 | y_sum = torch.sum(target) 24 | z_sum = torch.sum(score) 25 | loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth) 26 | loss = 1 - loss 27 | return loss 28 | 29 | 30 | def entropy_loss(p, C=2): 31 | # p N*C*W*H*D 32 | y1 = -1*torch.sum(p*torch.log(p+1e-6), dim=1) / \ torch.tensor(np.log(C)).cuda() 33 | ent = torch.mean(y1) 34 | 35 | return ent 36 | 37 | 38 | 39 | def softmax_dice_loss(input_logits, target_logits): 40 | """Takes softmax on both sides and returns the mean Dice loss 41 | 42 | Note: 43 | - Returns the Dice loss averaged over the class channels; no further 44 | division by the batch size is needed. 45 | - Sends gradients to inputs but not the targets. 46 | """ 47 | assert input_logits.size() == target_logits.size() 48 | input_softmax = F.softmax(input_logits, dim=1) 49 | target_softmax = F.softmax(target_logits, dim=1) 50 | n = input_logits.shape[1] 51 | dice = 0 52 | for i in range(0, n): 53 | dice += dice_loss1(input_softmax[:, i], target_softmax[:, i]) 54 | mean_dice = dice / n 55 | 56 | return mean_dice 57 | 58 | 59 | def entropy_loss_map(p, C=2): 60 | ent = -1*torch.sum(p * torch.log(p + 1e-6), dim=1, 61 | keepdim=True)/torch.tensor(np.log(C)).cuda() 62 | return ent 63 | 64 | 65 | def softmax_mse_loss(input_logits, target_logits, sigmoid=False): 66 | """Takes softmax on both sides and returns MSE loss 67 | 68 | Note: 69 | - Returns the elementwise squared error map (no reduction); 70 | reduce it yourself if you want a scalar. 71 | - Sends gradients to inputs but not the targets. 72 | """ 73 | assert input_logits.size() == target_logits.size() 74 | if sigmoid: 75 | input_softmax = torch.sigmoid(input_logits) 76 | target_softmax = torch.sigmoid(target_logits) 77 | else: 78 | input_softmax = F.softmax(input_logits, dim=1) 79 | target_softmax = F.softmax(target_logits, dim=1) 80 | 81 | mse_loss = (input_softmax-target_softmax)**2 82 | return mse_loss 83 | 84 | 85 | def softmax_kl_loss(input_logits, target_logits, sigmoid=False): 86 | """Takes softmax on both sides and returns KL divergence 87 | 88 | Note: 89 | - Returns the mean KL divergence, since reduction='mean' 90 | is used below. 91 | - Sends gradients to inputs but not the targets.
92 | """ 93 | assert input_logits.size() == target_logits.size() 94 | if sigmoid: 95 | input_log_softmax = torch.log(torch.sigmoid(input_logits)) 96 | target_softmax = torch.sigmoid(target_logits) 97 | else: 98 | input_log_softmax = F.log_softmax(input_logits, dim=1) 99 | target_softmax = F.softmax(target_logits, dim=1) 100 | 101 | # return F.kl_div(input_log_softmax, target_softmax) 102 | kl_div = F.kl_div(input_log_softmax, target_softmax, reduction='mean') 103 | # mean_kl_div = torch.mean(0.2*kl_div[:,0,...]+0.8*kl_div[:,1,...]) 104 | return kl_div 105 | 106 | 107 | def symmetric_mse_loss(input1, input2): 108 | """Like F.mse_loss but sends gradients to both directions 109 | 110 | Note: 111 | - Returns the sum over all examples. Divide by the batch size afterwards 112 | if you want the mean. 113 | - Sends gradients to both input1 and input2. 114 | """ 115 | assert input1.size() == input2.size() 116 | return torch.mean((input1 - input2)**2) 117 | 118 | 119 | class FocalLoss(nn.Module): 120 | def __init__(self, gamma=2, alpha=None, size_average=True): 121 | super(FocalLoss, self).__init__() 122 | self.gamma = gamma 123 | self.alpha = alpha 124 | if isinstance(alpha, (float, int)): 125 | self.alpha = torch.Tensor([alpha, 1-alpha]) 126 | if isinstance(alpha, list): 127 | self.alpha = torch.Tensor(alpha) 128 | self.size_average = size_average 129 | 130 | def forward(self, input, target): 131 | if input.dim() > 2: 132 | # N,C,H,W => N,C,H*W 133 | input = input.view(input.size(0), input.size(1), -1) 134 | input = input.transpose(1, 2) # N,C,H*W => N,H*W,C 135 | input = input.contiguous().view(-1, input.size(2)) # N,H*W,C => N*H*W,C 136 | target = target.view(-1, 1) 137 | 138 | logpt = F.log_softmax(input, dim=1) 139 | logpt = logpt.gather(1, target) 140 | logpt = logpt.view(-1) 141 | pt = Variable(logpt.data.exp()) 142 | 143 | if self.alpha is not None: 144 | if self.alpha.type() != input.data.type(): 145 | self.alpha = self.alpha.type_as(input.data) 146 | at = self.alpha.gather(0, target.data.view(-1)) 147 | logpt = logpt * Variable(at) 148 | 149 | loss = -1 * (1-pt)**self.gamma * logpt 150 | if self.size_average: 151 | return loss.mean() 152 | else: 153 | return loss.sum() 154 | 155 | 156 | class DiceLoss(nn.Module): 157 | def __init__(self, n_classes): 158 | super(DiceLoss, self).__init__() 159 | self.n_classes = n_classes 160 | 161 | def _one_hot_encoder(self, input_tensor): 162 | tensor_list = [] 163 | for i in range(self.n_classes): 164 | temp_prob = input_tensor == i * torch.ones_like(input_tensor) 165 | tensor_list.append(temp_prob) 166 | output_tensor = torch.cat(tensor_list, dim=1) 167 | return output_tensor.float() 168 | 169 | def _dice_loss(self, score, target): 170 | target = target.float() 171 | smooth = 1e-5 172 | intersect = torch.sum(score * target) 173 | y_sum = torch.sum(target * target) 174 | z_sum = torch.sum(score * score) 175 | loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth) 176 | loss = 1 - loss 177 | return loss 178 | 179 | def forward(self, inputs, target, weight=None, softmax=False): 180 | if softmax: 181 | inputs = torch.softmax(inputs, dim=1) 182 | target = self._one_hot_encoder(target) 183 | if weight is None: 184 | weight = [1] * self.n_classes 185 | assert inputs.size() == target.size(), 'predict & target shape do not match' 186 | class_wise_dice = [] 187 | loss = 0.0 188 | for i in range(0, self.n_classes): 189 | dice = self._dice_loss(inputs[:, i], target[:, i]) 190 | class_wise_dice.append(1.0 - dice.item()) 191 | loss += dice * weight[i] 192 | 
return loss / self.n_classes 193 | 194 | 195 | def entropy_minmization(p): 196 | y1 = -1*torch.sum(p*torch.log(p+1e-6), dim=1) 197 | ent = torch.mean(y1) 198 | 199 | return ent 200 | 201 | 202 | def entropy_map(p): 203 | ent_map = -1*torch.sum(p * torch.log(p + 1e-6), dim=1, 204 | keepdim=True) 205 | return ent_map 206 | 207 | -------------------------------------------------------------------------------- /test_util.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import math 3 | import nibabel as nib 4 | import numpy as np 5 | from medpy import metric 6 | import torch 7 | import torch.nn.functional as F 8 | from tqdm import tqdm 9 | from pathlib import Path 10 | from dataset.pancreas import Pancreas 11 | import os 12 | 13 | def test_all_case(net, image_list, num_classes, patch_size=(112, 112, 80), stride_xy=18, stride_z=4, save_result=True, test_save_path=None, 14 | preproc_fn=None): 15 | total_metric = 0.0 16 | cnt = 0 17 | for image_path in tqdm(image_list): 18 | 19 | id = image_path.split('/')[-1] 20 | h5f = h5py.File(image_path, 'r') 21 | image = h5f['image'][:] 22 | label = h5f['label'][:] 23 | 24 | if preproc_fn is not None: 25 | image = preproc_fn(image) 26 | prediction, score_map = test_single_case(net, image, stride_xy, stride_z, patch_size, num_classes=num_classes) 27 | 28 | if np.sum(prediction) == 0: 29 | single_metric = (0, 0, 0, 0) 30 | else: 31 | single_metric = calculate_metric_percase(prediction, label[:]) 32 | # print(single_metric) 33 | total_metric += np.asarray(single_metric) 34 | # print(str(cnt) + ", {}, {}, {}. {}".format(single_metric[0], single_metric[1], single_metric[2], single_metric[3])) 35 | 36 | 37 | if save_result: 38 | nib.save(nib.Nifti1Image(prediction.astype(np.float32), np.eye(4)), test_save_path + str(cnt) + "_pred.nii.gz") 39 | nib.save(nib.Nifti1Image(image[:].astype(np.float32), np.eye(4)), test_save_path + str(cnt) + "_img.nii.gz") 40 | nib.save(nib.Nifti1Image(label[:].astype(np.float32), np.eye(4)), test_save_path + str(cnt) + "_gt.nii.gz") 41 | cnt += 1 42 | avg_metric = total_metric / len(image_list) 43 | 44 | print('average metric is {}'.format(avg_metric)) 45 | 46 | return avg_metric 47 | 48 | 49 | def test_single_case(net, image, stride_xy, stride_z, patch_size, num_classes=1): 50 | w, h, d = image.shape 51 | 52 | # if the size of image is less than patch_size, then padding it 53 | add_pad = False 54 | if w < patch_size[0]: 55 | w_pad = patch_size[0] - w 56 | add_pad = True 57 | else: 58 | w_pad = 0 59 | if h < patch_size[1]: 60 | h_pad = patch_size[1] - h 61 | add_pad = True 62 | else: 63 | h_pad = 0 64 | if d < patch_size[2]: 65 | d_pad = patch_size[2] - d 66 | add_pad = True 67 | else: 68 | d_pad = 0 69 | wl_pad, wr_pad = w_pad // 2, w_pad - w_pad // 2 70 | hl_pad, hr_pad = h_pad // 2, h_pad - h_pad // 2 71 | dl_pad, dr_pad = d_pad // 2, d_pad - d_pad // 2 72 | if add_pad: 73 | image = np.pad(image, [(wl_pad, wr_pad), (hl_pad, hr_pad), (dl_pad, dr_pad)], mode='constant', constant_values=0) 74 | ww, hh, dd = image.shape 75 | 76 | sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1 77 | sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1 78 | sz = math.ceil((dd - patch_size[2]) / stride_z) + 1 79 | # print("{}, {}, {}".format(sx, sy, sz)) 80 | score_map = np.zeros((num_classes,) + image.shape).astype(np.float32) 81 | cnt = np.zeros(image.shape).astype(np.float32) 82 | 83 | for x in range(0, sx): 84 | xs = min(stride_xy * x, ww - patch_size[0]) 85 | for y in range(0, sy): 86 | ys = 
min(stride_xy * y, hh - patch_size[1])
87 | for z in range(0, sz):
88 | zs = min(stride_z * z, dd - patch_size[2])
89 | test_patch = image[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]]
90 | test_patch = np.expand_dims(np.expand_dims(test_patch, axis=0), axis=0).astype(np.float32)
91 | test_patch = torch.from_numpy(test_patch).cuda()
92 | y1 = net(test_patch)[0]
93 | y = F.softmax(y1, dim=1)
94 | y = y.cpu().data.numpy()
95 | y = y[0, :, :, :, :]
96 | score_map[:, xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] \
97 | = score_map[:, xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] + y
98 | cnt[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] \
99 | = cnt[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] + 1
100 | score_map = score_map / np.expand_dims(cnt, axis=0)
101 | label_map = np.argmax(score_map, axis=0)
102 | if add_pad:
103 | label_map = label_map[wl_pad:wl_pad + w, hl_pad:hl_pad + h, dl_pad:dl_pad + d]
104 | score_map = score_map[:, wl_pad:wl_pad + w, hl_pad:hl_pad + h, dl_pad:dl_pad + d]
105 | return label_map, score_map
106 |
107 |
108 | def cal_dice(prediction, label, num=2):
109 | total_dice = np.zeros(num - 1)
110 | for i in range(1, num):
111 | prediction_tmp = (prediction == i)
112 | label_tmp = (label == i)
113 | prediction_tmp = prediction_tmp.astype(float)  # np.float was removed in NumPy >= 1.24
114 | label_tmp = label_tmp.astype(float)
115 |
116 | dice = 2 * np.sum(prediction_tmp * label_tmp) / (np.sum(prediction_tmp) + np.sum(label_tmp))
117 | total_dice[i - 1] += dice
118 |
119 | return total_dice
120 |
121 |
122 | def calculate_metric_percase(pred, gt):
123 | dice = metric.binary.dc(pred, gt)
124 | jc = metric.binary.jc(pred, gt)
125 | hd = metric.binary.hd95(pred, gt)
126 | asd = metric.binary.asd(pred, gt)
127 |
128 | return dice, jc, hd, asd
129 |
130 |
131 | def test_calculate_metric(net, test_dataset, num_classes=2, save_result=False, test_save_path='./save'):
132 | net.eval()
133 | image_list = test_dataset.image_list
134 |
135 | if save_result:
136 | test_save_path = Path(test_save_path)
137 | test_save_path.mkdir(exist_ok=True)
138 |
139 | avg_metric = test_all_case(net, image_list, num_classes=num_classes,
140 | patch_size=(96, 96, 96), stride_xy=16, stride_z=4,
141 | save_result=save_result, test_save_path=str(test_save_path) + '/')
142 | return avg_metric
143 |
144 |
145 | def test_calculate_metric_LA(net, test_dataset, num_classes=2, save_result=False, test_save_path='./save'):
146 | net.eval()
147 | # with open("/home/xiangjinyi/semi_supervised/alnet/data_lists_cora/LA_dataset/test_whole.list", 'r') as f:
148 | # image_list = f.readlines()
149 | # image_list = [item.replace('\n', '') for item in image_list]
150 | # image_list = [os.path.join("../LA_dataset", item, "mri_norm2.h5") for item in image_list]
151 |
152 | image_list = test_dataset.image_list
153 | avg_metric = test_all_case(net, image_list, num_classes=num_classes,
154 | patch_size=(112, 112, 80), stride_xy=18, stride_z=4,
155 | save_result=save_result, test_save_path=test_save_path)
156 | return avg_metric
157 |
158 |
159 | if __name__ == '__main__':
160 | import os
161 | from train_panc import get_model_and_dataloader, load_net_opt, test
162 |
163 | os.environ['CUDA_VISIBLE_DEVICES'] = '1'
164 | num_classes = 2
165 | res_dir = './result/pancreas_3d_VNet/st_model'
166 |
167 | split_name = 'pancreas'
168 | data_root = '/data/DataSets/pancreas_pad25'
169 | net, ema_net, optimizer, lab_loader, unlab_loader, test_loader =
get_model_and_dataloader() 170 | dataset = Pancreas(data_root, split_name, split='test') 171 | load_net_opt(net, optimizer, res_dir + '/best.pth') 172 | metric = test_calculate_metric(net, dataset) 173 | print(metric) 174 | -------------------------------------------------------------------------------- /preprocess/io_.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import time 4 | from pathlib import Path 5 | 6 | import numpy as np 7 | from torch import multiprocessing 8 | import nibabel as nib 9 | 10 | """ 11 | IO tools 12 | """ 13 | 14 | 15 | def read_mhd_affine(path): 16 | with Path(path).open('r') as f: 17 | lines = f.readlines() 18 | matrix_str = [float(s) for s in lines[5].strip().split('=')[-1].strip().split(' ')] 19 | offset = [float(s) for s in lines[6].strip().split('=')[-1].strip().split(' ')] 20 | matrix = np.array(matrix_str).reshape(3, 3) 21 | offset = np.array(offset).reshape(3, 1) 22 | res = np.concatenate([np.concatenate([matrix, offset], 1), np.array((0, 0, 0, 1)).reshape(1, 4)], 0) 23 | return res 24 | 25 | 26 | def make_affine(sitkImg): 27 | # get affine transform in LPS 28 | rot = [sitkImg.TransformContinuousIndexToPhysicalPoint(p) 29 | for p in ((1, 0, 0), 30 | (0, 1, 0), 31 | (0, 0, 1), 32 | (0, 0, 0))] 33 | rot = np.array(rot) 34 | affine = np.concatenate([ 35 | np.concatenate([rot[0:3] - rot[3:], rot[3:]], axis=0), 36 | [[0.], [0.], [0.], [1.]] 37 | ], axis=1) 38 | affine = np.transpose(affine) 39 | # convert to RAS to match nibabel 40 | affine = np.matmul(np.diag([-1., -1., 1., 1.]), affine) 41 | return affine 42 | 43 | 44 | def make_affine2(spacing): 45 | affine = np.array(((0, 0, -1, 0), 46 | (0, -1, 0, 0), 47 | (-1, 0, 0, 0), 48 | (0, 0, 0, 1))) 49 | spacing = np.diag(list(spacing) + [1]) 50 | return np.matmul(affine, spacing) 51 | 52 | 53 | def read_nii(path, method='nib'): 54 | """ 55 | Read ".nii.gz" data 56 | :param path: path to image 57 | :param method: method to read data, only support ('nib', 'sitk') 58 | :returns: 59 | data : numpy data [channel, x, y] 60 | spacing : (x_spacing, y_spacing, z_spacing) 61 | 62 | """ 63 | import SimpleITK as sitk 64 | path = str(path) 65 | method = method.lower() 66 | if method == 'nib': 67 | from nibabel.filebasedimages import ImageFileError 68 | try: 69 | img = nib.load(path) 70 | data = img.get_data() 71 | spacing = img.header.get_zooms()[:3] 72 | affine = img.affine 73 | return data, spacing, affine 74 | except ImageFileError as e: 75 | method = 'sitk' 76 | 77 | if method == 'sitk': 78 | img = sitk.ReadImage(path) 79 | data = sitk.GetArrayFromImage(img) 80 | # channel first 81 | spacing = img.GetSpacing()[::-1] 82 | affine = make_affine2(spacing) 83 | return data, spacing, affine 84 | else: 85 | raise Exception("method only supports nib(nibabel) or sitk(SimpleITK)") 86 | 87 | 88 | def read_img(img_path): 89 | img_path = str(img_path) 90 | 91 | import skimage.io as skio 92 | if img_path.endswith('jpg') or img_path.endswith('png') or img_path.endswith('bmp'): 93 | return skio.imread(img_path) 94 | elif img_path.endswith('nii.gz') or img_path.endswith('.dcm'): 95 | return read_nii(img_path)[0] 96 | elif img_path.endswith('npy'): 97 | return np.load(img_path) 98 | elif img_path.endswith('.mhd'): 99 | return read_nii(img_path, method='sitk')[0] 100 | else: 101 | raise Exception("Error file format for {}, only support ['bmp', 'jpg', 'png', 'nii.gz', 'npy', 'mhd']".format(img_path)) 102 | 103 | 104 | def save_nii_with_sitk(np_data, path, origin=None, 
spacing=None):
105 | import SimpleITK as sitk; img = sitk.GetImageFromArray(np_data)  # local import, matching read_nii above
106 | if origin is not None:
107 | img.SetOrigin(origin)  # SimpleITK accessors are capitalized: SetOrigin, not setOrigin
108 | if spacing is not None:
109 | img.SetSpacing(spacing)
110 | sitk.WriteImage(img, path)
111 |
112 |
113 | def save_nii(np_data, affine, path):
114 | path = str(path)
115 | img = nib.Nifti1Image(np_data, affine)
116 | nib.save(img, path)
117 |
118 |
119 | def mkdir(path, level=2, create_self=True):
120 | """ Make directory for this path,
121 | level is how many parent folders should be created.
122 | create_self controls whether the path itself is created (if it is a file, it should not be created)
123 |
124 | e.g. : mkdir('/home/parent1/parent2/folder', level=3, create_self=False),
125 | it will first create parent1, then parent2, then folder.
126 |
127 | :param path: string
128 | :param level: int
129 | :param create_self: True or False
130 | :return:
131 | """
132 | p = Path(path)
133 | if create_self:
134 | paths = [p]
135 | else:
136 | paths = []
137 | level -= 1
138 | while level != 0:
139 | p = p.parent
140 | paths.append(p)
141 | level -= 1
142 |
143 | for p in paths[::-1]:
144 | p.mkdir(exist_ok=True)
145 |
146 |
147 | def move_files(path):
148 | path = Path(path)
149 | paths = list(path.iterdir())
150 | for p in paths:
151 | parent = p.parent
152 | name = p.name
153 | _, cid, number, suffix = name.split('_')
154 | case_dir = parent / cid
155 | case_dir.mkdir(exist_ok=True)
156 | target_p = case_dir / "{}_{}".format(number, suffix)
157 | shutil.move(str(p), str(target_p))
158 | print('finished')
159 |
160 |
161 | def create_symlink(src, dst):
162 | try:
163 | # if exists, unlink first
164 | os.unlink(dst)
165 | except FileNotFoundError as e:
166 | pass
167 |
168 | try:
169 | # create link
170 | os.symlink(src, dst)
171 | except FileExistsError as e:
172 | pass
173 |
174 |
175 | """
176 | Case file and folder tools
177 | """
178 |
179 |
180 | def load_case(root, cid):
181 | img_p = get_nii_case_file(root, cid, 'imaging')
182 | mask_p = get_nii_case_file(root, cid, 'mask')
183 | img, spacing, affine = read_nii(img_p)
184 | mask, _, _ = read_nii(mask_p)
185 | return img, mask, spacing, affine
186 |
187 |
188 | def get_case_folder(root_path, name, cid, create_self=True):
189 | """ Make case folder
190 | :param root_path:
191 | :param cid:
192 | :return:
193 | """
194 | img_folder = Path(root_path) / name / 'case_{:05d}'.format(cid)
195 | mkdir(img_folder, level=3, create_self=create_self)
196 | return img_folder
197 |
198 |
199 | def get_nii_case_file(root, cid, filename):
200 | path = Path(root) / ("case_{:05d}".format(cid)) / ("{}.nii.gz".format(str(filename)))
201 | mkdir(path, level=2, create_self=False)
202 | return path
203 |
204 | """
205 | To process quickly, use multi-cpu to process functions
206 | """
207 |
208 |
209 | def multiprocess_task(func, dynamic_args, static_args=(), split_func=np.array_split, ret=False, cpu_num=None):
210 | """
211 | Process a task with multiple CPUs.
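A minimal usage sketch (the `square_all` helper is hypothetical; `func` must be a top-level function so it can be pickled, and it receives one whole chunk of `dynamic_args` plus the unpacked `static_args`):

    def square_all(xs, power):
        # xs is the list chunk assigned to one worker
        return [x ** power for x in xs]

    res = multiprocess_task(square_all, list(range(100)), static_args=(2,), ret=True, cpu_num=4)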
212 | :param func: task to be processed; func must be a top-level function (so it can be pickled)
213 | :param dynamic_args: args to be split and assigned to the worker processes,
214 | a list by default
215 | :param static_args: args that do not need to be split
216 | :param split_func: function used to split dynamic_args; the default splits
217 | a list into roughly equal chunks
218 | :return:
219 | """
220 | start = time.time()
221 | if cpu_num is None:
222 | cpu_num = multiprocessing.cpu_count() // 2
223 |
224 | if cpu_num <= 1:
225 | single_res = func(dynamic_args, *static_args)  # keep the result separate from the `ret` flag
226 | else:
227 | # split dynamic args with cpu num
228 | dynamic_args_splits = split_func(dynamic_args, cpu_num)
229 | workers = multiprocessing.Pool(processes=cpu_num)
230 | processes = []
231 | for proc_id, chunk in enumerate(dynamic_args_splits):
232 | # do processing: pass the chunk together with the static args
233 | chunk = list(chunk)
234 | p = workers.apply_async(func, (chunk, *static_args))
235 | processes.append(p)
236 | workers.close()
237 | workers.join()
238 |
239 | duration = time.time() - start
240 | print('total time : {} min'.format(duration / 60.))
241 |
242 | if ret:
243 | # collect results
244 | if cpu_num > 1:
245 | res = []
246 | for p in processes:
247 | p = p.get()
248 | res.extend(p)
249 | else:
250 | res = single_res
251 | return res
252 |
253 |
254 | def io_exception(func):
255 | def wrapper(*args, **kwargs):
256 | try:
257 | return func(*args, **kwargs)
258 | except Exception as e:
259 | print(e)
260 | return wrapper
261 |
262 |
263 | if __name__ == '__main__':
264 | pass
265 |
-------------------------------------------------------------------------------- /utils1/loss.py: --------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import numpy as np
3 | import torch
4 | import torch.nn.functional as F
5 | from torch.autograd import Variable
6 | from torch.nn import functional as F
7 |
8 |
9 | def cross_entropy_2d(predict, target):
10 | """
11 | Args:
12 | predict:(n, c, h, w)
13 | target:(n, h, w)
14 | """
15 | assert not target.requires_grad
16 | assert predict.dim() == 4
17 | assert target.dim() == 3
18 | assert predict.size(0) == target.size(0), f"{predict.size(0)} vs {target.size(0)}"
19 | assert predict.size(2) == target.size(1), f"{predict.size(2)} vs {target.size(1)}"
20 | assert predict.size(3) == target.size(2), f"{predict.size(3)} vs {target.size(2)}"
21 | n, c, h, w = predict.size()
22 | target_mask = (target >= 0) * (target != 255)
23 | target = target[target_mask]
24 | if target.numel() == 0:
25 | return torch.zeros(1, device=predict.device)
26 | predict = predict.transpose(1, 2).transpose(2, 3).contiguous()
27 | predict = predict[target_mask.view(n, h, w, 1).repeat(1, 1, 1, c)].view(-1, c)
28 |
29 | loss = F.cross_entropy(predict, target, reduction='mean')  # size_average is deprecated
30 | return loss
31 |
32 |
33 | def entropy_loss(v):
34 | """
35 | Entropy loss for probabilistic prediction vectors
36 | input: batch_size x channels x h x w
37 | output: scalar, mean entropy normalized by log2(channels)
38 | """
39 | assert v.dim() == 4
40 | n, c, h, w = v.size()
41 | return -torch.sum(torch.mul(v, torch.log2(v + 1e-30))) / (n * h * w * np.log2(c))
42 |
43 |
44 | def to_one_hot(tensor, nClasses):
45 | """ Input tensor : Nx1xHxW
46 | :param tensor:
47 | :param nClasses:
48 | :return:
49 | """
50 | assert tensor.max().item() < nClasses, 'one hot tensor.max() = {} < {}'.format(torch.max(tensor), nClasses)
51 | assert tensor.min().item() >= 0, 'one hot tensor.min() = {} < {}'.format(tensor.min(), 0)
52 |
53 | size = list(tensor.size())
54 | assert
size[1] == 1 55 | size[1] = nClasses 56 | one_hot = torch.zeros(*size) 57 | if tensor.is_cuda: 58 | one_hot = one_hot.cuda(tensor.device) 59 | one_hot = one_hot.scatter_(1, tensor, 1) 60 | return one_hot 61 | 62 | 63 | def get_probability(logits): 64 | """ Get probability from logits, if the channel of logits is 1 then use sigmoid else use softmax. 65 | :param logits: [N, C, H, W] or [N, C, D, H, W] 66 | :return: prediction and class num 67 | """ 68 | size = logits.size() 69 | # N x 1 x H x W 70 | if size[1] > 1: 71 | pred = F.softmax(logits, dim=1) 72 | nclass = size[1] 73 | else: 74 | pred = F.sigmoid(logits) 75 | pred = torch.cat([1 - pred, pred], 1) 76 | nclass = 2 77 | return pred, nclass 78 | 79 | 80 | class DiceLoss(nn.Module): 81 | def __init__(self, nclass, class_weights=None, smooth=1e-5): 82 | super(DiceLoss, self).__init__() 83 | self.smooth = smooth 84 | if class_weights is None: 85 | # default weight is all 1 86 | self.class_weights = nn.Parameter(torch.ones((1, nclass)).type(torch.float32), requires_grad=False) 87 | else: 88 | class_weights = np.array(class_weights) 89 | assert nclass == class_weights.shape[0] 90 | self.class_weights = nn.Parameter(torch.tensor(class_weights, dtype=torch.float32), requires_grad=False) 91 | 92 | def prob_forward(self, pred, target, mask=None): 93 | size = pred.size() 94 | N, nclass = size[0], size[1] 95 | # N x C x H x W 96 | pred_one_hot = pred.view(N, nclass, -1) 97 | target = target.view(N, 1, -1) 98 | target_one_hot = to_one_hot(target.type(torch.long), nclass).type(torch.float32) 99 | 100 | # N x C x H x W 101 | inter = pred_one_hot * target_one_hot 102 | union = pred_one_hot + target_one_hot 103 | 104 | if mask is not None: 105 | mask = mask.view(N, 1, -1) 106 | inter = (inter.view(N, nclass, -1)*mask).sum(2) 107 | union = (union.view(N, nclass, -1)*mask).sum(2) 108 | else: 109 | # N x C 110 | inter = inter.view(N, nclass, -1).sum(2) 111 | union = union.view(N, nclass, -1).sum(2) 112 | 113 | # smooth to prevent overfitting 114 | # [https://github.com/pytorch/pytorch/issues/1249] 115 | # NxC 116 | dice = (2 * inter + self.smooth) / (union + self.smooth) 117 | return 1 - dice.mean() 118 | 119 | def forward(self, logits, target, mask=None): 120 | size = logits.size() 121 | N, nclass = size[0], size[1] 122 | 123 | logits = logits.view(N, nclass, -1) 124 | target = target.view(N, 1, -1) 125 | 126 | pred, nclass = get_probability(logits) 127 | 128 | # N x C x H x W 129 | pred_one_hot = pred 130 | target_one_hot = to_one_hot(target.type(torch.long), nclass).type(torch.float32) 131 | 132 | # N x C x H x W 133 | inter = pred_one_hot * target_one_hot 134 | union = pred_one_hot + target_one_hot 135 | 136 | if mask is not None: 137 | mask = mask.view(N, 1, -1) 138 | inter = (inter.view(N, nclass, -1)*mask).sum(2) 139 | union = (union.view(N, nclass, -1)*mask).sum(2) 140 | else: 141 | # N x C 142 | inter = inter.view(N, nclass, -1).sum(2) 143 | union = union.view(N, nclass, -1).sum(2) 144 | 145 | # smooth to prevent overfitting 146 | # [https://github.com/pytorch/pytorch/issues/1249] 147 | # NxC 148 | dice = (2 * inter + self.smooth) / (union + self.smooth) 149 | return 1 - dice.mean() 150 | 151 | 152 | def softmax_mse_loss(input_logits, target_logits): 153 | """Takes softmax on both sides and returns MSE loss 154 | Note: 155 | - Returns the sum over all examples. Divide by the batch size afterwards 156 | if you want the mean. 157 | - Sends gradients to inputs but not the targets. 
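- As written, no reduction or detach happens inside this function; a typical caller does both, e.g. (hypothetical tensor names): consistency_loss = softmax_mse_loss(student_logits, teacher_logits.detach()).mean()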
158 | """ 159 | assert input_logits.size() == target_logits.size() 160 | input_softmax = F.softmax(input_logits, dim=1) 161 | target_softmax = F.softmax(target_logits, dim=1) 162 | 163 | mse_loss = (input_softmax-target_softmax)**2 164 | return mse_loss 165 | 166 | 167 | #针对多分类问题,二分类问题更简单一点 168 | class SoftIoULoss(nn.Module): 169 | def __init__(self, nclass, class_weights=None, smooth=1e-5): 170 | super(SoftIoULoss, self).__init__() 171 | self.smooth = smooth 172 | if class_weights is None: 173 | # default weight is all 1 174 | self.class_weights = nn.Parameter(torch.ones((1, nclass)).type(torch.float32), requires_grad=False) 175 | else: 176 | class_weights = np.array(class_weights) 177 | assert nclass == class_weights.shape[0] 178 | self.class_weights = nn.Parameter(torch.tensor(class_weights, dtype=torch.float32), requires_grad=False) 179 | 180 | def prob_forward(self, pred, target, mask=None): 181 | size = pred.size() 182 | N, nclass = size[0], size[1] 183 | # N x C x H x W 184 | pred_one_hot = pred.view(N, nclass, -1) 185 | target = target.view(N, 1, -1) 186 | target_one_hot = to_one_hot(target.type(torch.long), nclass).type(torch.float32) 187 | 188 | # N x C x H x W 189 | inter = pred_one_hot * target_one_hot 190 | union = pred_one_hot + target_one_hot 191 | 192 | if mask is not None: 193 | mask = mask.view(N, 1, -1) 194 | inter = (inter.view(N, nclass, -1)*mask).sum(2) 195 | union = (union.view(N, nclass, -1)*mask).sum(2) 196 | else: 197 | # N x C 198 | inter = inter.view(N, nclass, -1).sum(2) 199 | union = union.view(N, nclass, -1).sum(2) 200 | 201 | # smooth to prevent overfitting 202 | # [https://github.com/pytorch/pytorch/issues/1249] 203 | # NxC 204 | dice = (2 * inter + self.smooth) / (union + self.smooth) 205 | return 1 - dice.mean() 206 | 207 | def forward(self, logits, target, mask=None): 208 | size = logits.size() 209 | N, nclass = size[0], size[1] 210 | 211 | logits = logits.view(N, nclass, -1) 212 | target = target.view(N, 1, -1) 213 | 214 | pred, nclass = get_probability(logits) 215 | 216 | # N x C x H x W 217 | pred_one_hot = pred 218 | target_one_hot = to_one_hot(target.type(torch.long), nclass).type(torch.float32) 219 | 220 | # N x C x H x W 221 | inter = pred_one_hot * target_one_hot 222 | union = pred_one_hot + target_one_hot - inter 223 | 224 | if mask is not None: 225 | mask = mask.view(N, 1, -1) 226 | inter = (inter.view(N, nclass, -1)*mask).sum(2) 227 | union = (union.view(N, nclass, -1)*mask).sum(2) 228 | else: 229 | # N x C 230 | inter = inter.view(N, nclass, -1).sum(2) 231 | union = union.view(N, nclass, -1).sum(2) 232 | 233 | # smooth to prevent overfitting 234 | # [https://github.com/pytorch/pytorch/issues/1249] 235 | # NxC 236 | dice = (1 * inter + self.smooth) / (union + self.smooth) 237 | return 1 - dice.mean() 238 | -------------------------------------------------------------------------------- /utils1/statistic.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2,torch 3 | from scipy import ndimage 4 | from sklearn.metrics.pairwise import pairwise_distances 5 | 6 | def dice_loss(masks, labels, is_average=True): 7 | """ 8 | dice loss 9 | :param masks: 10 | :param labels: 11 | :return: 12 | """ 13 | num = labels.size(0) 14 | 15 | m1 = masks.view(num, -1) 16 | m2 = labels.view(num, -1) 17 | 18 | intersection = (m1 * m2) 19 | 20 | score = (2 * intersection.sum(1)) / (m1.sum(1) + m2.sum(1)+1.0) 21 | if is_average: 22 | return score.sum() / num 23 | else: 24 | return score 25 | def 
dice_ratio(masks, labels, is_average=True): 26 | """ 27 | dice ratio 28 | :param masks: 29 | :param labels: 30 | :return: 31 | """ 32 | masks = masks.cpu() 33 | labels = labels.cpu() 34 | 35 | m1 = masks.flatten() 36 | m2 = labels.flatten().float() 37 | 38 | intersection = m1 * m2 39 | score = (2 * intersection.sum()) / (m1.sum() + m2.sum()+1e-6) 40 | 41 | pre = intersection.sum() / np.max([m2.sum(), 1]) 42 | rec = intersection.sum() / np.max([m1.sum(), 1]) 43 | 44 | return score#, pre, rec 45 | def dice_mc(masks, labels, classes): 46 | 47 | 48 | num = labels.size(0) 49 | 50 | class_dice = torch.zeros(num) 51 | per_class_dice = torch.zeros(num,classes) 52 | per_class_cnt = torch.zeros(num,classes) 53 | 54 | total_insect = 0.0 55 | total_pred = 0.0 56 | total_labs = 0.0 57 | 58 | for i in range(num): 59 | for n in range(1,classes): 60 | if (labels[i]==n).sum(): 61 | pred = (masks[i]==n) 62 | labs = (labels[i]==n) 63 | insect = pred*labs 64 | per_class_dice[i,n-1] =(2 * insect.sum()).float() / (pred.sum() + labs.sum()).float() 65 | per_class_cnt[i,n-1] +=1 66 | 67 | total_insect += insect.sum() 68 | total_pred += pred.sum() 69 | total_labs += labs.sum() 70 | 71 | class_dice[i] = (2*total_insect).float()/ (total_pred + total_labs).float() 72 | 73 | aver_dice = class_dice.sum()/num 74 | per_class_dice = per_class_dice.sum(0)/(per_class_cnt.sum(0)+1e-5) 75 | return aver_dice,per_class_dice 76 | 77 | def dice_m(masks, labels, classes): 78 | 79 | 80 | num = labels.size(0) 81 | 82 | m1 = masks.view(num, -1) 83 | m2 = labels.view(num, -1) 84 | 85 | 86 | class_dice = torch.zeros(num) 87 | per_class_dice = torch.zeros(num,classes) 88 | m1_cnt = torch.zeros(num,classes) 89 | m2_cnt = torch.zeros(num,classes) 90 | insect_cnt = torch.zeros(num,classes) 91 | 92 | for i in range(num): 93 | for j in range(m1.shape[1]): 94 | if m1[i,j]!=0: 95 | if m1[i,j]==m2[i,j]: 96 | insect_cnt[i,m1[i,j]-1] += 1 97 | m1_cnt[i,m1[i,j]-1] += 1 98 | if m2[i,j]!=0: 99 | m2_cnt[i,m2[i,j]-1] += 1 100 | 101 | 102 | per_class_dice[i] =(2 * insect_cnt[i]) / (m1_cnt[i] + m2_cnt[i]) 103 | 104 | class_dice[i] = (2*insect_cnt[i].sum())/ (m1_cnt[i].sum() + m2_cnt[i].sum()) 105 | class_dice = class_dice.sum()/num 106 | per_class_dice = per_class_dice.sum(0)/num 107 | return class_dice,per_class_dice 108 | 109 | 110 | 111 | def hausdorff_mad_distance(set1, set2, max_ahd=np.inf): 112 | """ 113 | Compute the Averaged Hausdorff Distance function 114 | between two unordered sets of points (the function is symmetric). 115 | Batches are not supported, so squeeze your inputs first! 116 | :param set1: Array/list where each row/element is an N-dimensional point. 117 | :param set2: Array/list where each row/element is an N-dimensional point. 118 | :param max_ahd: Maximum AHD possible to return if any set is empty. Default: inf. 119 | :return: The Hausdorff Distance and Mean Absolute Distance between set1 and set2. 
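Example (hypothetical mask tensors; each row returned by torch.nonzero is one N-dimensional point, as required): hd = hausdorff_mad_distance(torch.nonzero(pred_mask), torch.nonzero(gt_mask))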
120 | """ 121 | 122 | if len(set1) == 0 or len(set2) == 0: 123 | return max_ahd 124 | 125 | set1 = np.array(set1.cpu()) 126 | set2 = np.array(set2.cpu()) 127 | 128 | assert set1.ndim == 2, 'got %s' % set1.ndim 129 | assert set2.ndim == 2, 'got %s' % set2.ndim 130 | 131 | assert set1.shape[1] == set2.shape[1], \ 132 | 'The points in both sets must have the same number of dimensions, got %s and %s.'\ 133 | % (set2.shape[1], set2.shape[1]) 134 | 135 | d2_matrix = pairwise_distances(set1, set2, metric='euclidean') 136 | 137 | d12 = np.min(d2_matrix, axis=0) 138 | d21 = np.min(d2_matrix, axis=1) 139 | #print(d12.size,d21.size) 140 | 141 | hd = np.max([np.max(d12),np.max(d21),0]) 142 | 143 | # sorted_d12 = np.sort(d12) 144 | # sorted_d21 = np.sort(d21) 145 | # num12 = np.int(np.round(np.size(d12) * 0.95)) 146 | # num21 = np.int(np.round(np.size(d21) * 0.95)) 147 | # 148 | # mhd = np.max([sorted_d12[num12], sorted_d21[num21]]) 149 | 150 | # mad = 0.5*(np.average(d12)+np.average(d21)) 151 | 152 | return hd#, mhd#, mad 153 | 154 | def acc(masks, labels): 155 | 156 | # labels = labels.cpu().numpy() 157 | # masks = masks.cpu().numpy() 158 | 159 | m1 = masks.flatten() 160 | m2 = labels.flatten() 161 | 162 | same = (m1 == m2).sum().float() 163 | diff = (m1 != m2).sum().float() 164 | 165 | intersection = m1 * m2 166 | same1 = intersection.sum() 167 | same0 = same - intersection.sum() 168 | acc = same/m2.size(0) 169 | return acc,same,m2.size(0)#,same0,same1,diff 170 | 171 | def acc_test(masks, labels, masks_con): 172 | 173 | masks1 = masks.flatten() 174 | lab1 = labels.flatten() 175 | 176 | masks1 = masks1.cpu().numpy() 177 | loc = np.argwhere(masks1==0) 178 | masks2 = masks_con.flatten()[loc] 179 | # masks3 = masks_rad.flatten()[loc] 180 | 181 | # print(masks2.max(),masks2.min(),masks2.sum()) 182 | # # print(masks3.max(),masks3.min(), masks3.sum()) 183 | lab2 = lab1[loc] 184 | # print(type(masks2), type(lab2)) 185 | m1 = masks2 186 | m2 = lab2 187 | 188 | same = (m1 == m2).sum().float() 189 | intersection = m1 * m2 190 | same1 = intersection.sum()#/len(m2)#same 191 | same0 = (same - intersection.sum())#/len(m2)#same 192 | 193 | acc = same#/len(m2) 194 | dice = 2*intersection.sum().float()/((m1.sum() + m2.sum()+1.0)) 195 | 196 | mis0 = ((m1 != m2) & (m2 == 1)).sum().float()#/len(m2) 197 | mis1 = ((m1 != m2) & (m2 == 0)).sum().float()#/len(m2) 198 | 199 | # #2 200 | # same = 0 201 | # same0 = 0 202 | # same1 = 0 203 | # pred1 = 0 204 | # pred0 = 0 205 | # lab1 = 0 206 | # lab0 = 0 207 | # for i in len(loc): 208 | # if masks_con[i]==lab[i]: 209 | # same +=1 210 | # if masks_con[i] == 1: 211 | # same1 +=1 212 | # else: 213 | # same0 +=1 214 | # elif masks_con[i]==1: 215 | # pred1 +=1 216 | # else: 217 | # lab1 +=1 218 | 219 | # acc = same/len(loc).sum() 220 | 221 | # dice = 2*same1/(pred1+lab1) 222 | return acc,dice,same0,same1,mis0,mis1,len(m1)#,diff 223 | 224 | 225 | def acc_m(masks, labels, masks_con): 226 | masks1 = masks.flatten() 227 | lab1 = labels.flatten() 228 | 229 | masks1 = masks1.cpu().numpy() 230 | loc = np.argwhere(masks1 == 0) 231 | masks2 = masks_con.flatten()[loc] 232 | # masks3 = masks_rad.flatten()[loc] 233 | 234 | # print(masks2.max(),masks2.min(),masks2.sum()) 235 | # # print(masks3.max(),masks3.min(), masks3.sum()) 236 | lab2 = lab1.flatten()[loc] 237 | 238 | m1 = masks2 239 | m2 = lab2 240 | 241 | same = (m1 == m2).sum().float() 242 | intersection = m1 * m2 243 | same1 = intersection.sum()/same 244 | same0 = (same - intersection.sum())/same 245 | 246 | acc = same#/len(m2) 247 | dice 
= 2 * intersection.sum().float() / ((m1.sum() + m2.sum() + 1.0)) 248 | 249 | mis0 = ((m1 != m2) & (m2 == 1)).sum().float() # /len(m2) 250 | mis1 = ((m1 != m2) & (m2 == 0)).sum().float() # /len(m2) 251 | 252 | # #2 253 | # same = 0 254 | # same0 = 0 255 | # same1 = 0 256 | # pred1 = 0 257 | # pred0 = 0 258 | # lab1 = 0 259 | # lab0 = 0 260 | # for i in len(loc): 261 | # if masks_con[i]==lab[i]: 262 | # same +=1 263 | # if masks_con[i] == 1: 264 | # same1 +=1 265 | # else: 266 | # same0 +=1 267 | # elif masks_con[i]==1: 268 | # pred1 +=1 269 | # else: 270 | # lab1 +=1 271 | 272 | # acc = same/len(loc).sum() 273 | 274 | # dice = 2*same1/(pred1+lab1) 275 | return acc, dice, same0, same1#, mis0, mis1, len(m1) # ,diff 276 | 277 | 278 | def pre_rec(masks, labels): 279 | """ 280 | dice ratio 281 | :param masks: 282 | :param labels: 283 | :return: 284 | """ 285 | m1 = masks.flatten() 286 | m2 = labels.flatten().float() 287 | 288 | intersection = m1 * m2 289 | 290 | pre = intersection.sum() / (m1.sum()+1e-6) 291 | rec = intersection.sum() / (m2.sum()+1e-6) 292 | 293 | return pre, rec -------------------------------------------------------------------------------- /dataset/LeftAtrium.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from glob import glob 5 | from torch.utils.data import Dataset 6 | import h5py 7 | import itertools 8 | from torch.utils.data.sampler import Sampler 9 | from torchvision.transforms import Compose 10 | import json 11 | import cv2 12 | import SimpleITK as sitk 13 | 14 | 15 | def get_dataset_path(dataset='LA'): 16 | files = ['train_lab.txt', 'train_unlab.txt', 'test.txt'] 17 | return ['/'.join(['data_lists_cora', dataset, f]) for f in files] 18 | 19 | class LAHeart(Dataset): 20 | """ LA Dataset """ 21 | 22 | def __init__(self, base_dir, dataset_name, split, no_crop=False, require_mask=False): 23 | self._base_dir = base_dir 24 | 25 | self._base_dir = base_dir 26 | self.split = split 27 | self.require_mask = require_mask 28 | 29 | tr_transform = Compose([ 30 | # RandomRotFlip(), 31 | RandomCrop((112, 112, 80)), 32 | # RandomNoise(), 33 | ToTensor() 34 | ]) 35 | if no_crop: 36 | test_transform = Compose([ 37 | # CenterCrop((160, 160, 128)), 38 | CenterCrop((112, 112, 80)), 39 | ToTensor() 40 | ]) 41 | else: 42 | test_transform = Compose([ 43 | CenterCrop((112, 112, 80)), 44 | ToTensor() 45 | ]) 46 | 47 | data_list_paths = get_dataset_path(dataset_name) 48 | 49 | if split == 'train_lab': 50 | data_path = data_list_paths[0] 51 | self.transform = tr_transform 52 | elif split == 'train_unlab': 53 | data_path = data_list_paths[1] 54 | self.transform = test_transform #tr_transform 55 | else: 56 | data_path = data_list_paths[2] 57 | self.transform = test_transform 58 | 59 | with open(data_path, 'r') as f: 60 | self.image_list = f.readlines() 61 | 62 | self.image_list = [item.replace('\n', '') for item in self.image_list] 63 | self.image_list = [os.path.join(self._base_dir, item, "mri_norm2.h5") for item in self.image_list] 64 | 65 | print("total {} samples".format(len(self.image_list))) 66 | 67 | def __len__(self): 68 | if self.split == 'train_lab': 69 | return len(self.image_list) * 5 70 | else: 71 | return len(self.image_list) 72 | 73 | def __getitem__(self, idx): 74 | image_path = self.image_list[idx % len(self.image_list)] 75 | h5f = h5py.File(image_path, 'r') 76 | image, label = h5f['image'][:], h5f['label'][:].astype(np.float32) 77 | 78 | if self.require_mask: 79 | mask = (label > 
0).astype(np.uint8) 80 | samples = image, label, mask 81 | if self.transform: 82 | tr_samples = self.transform(samples) 83 | image_, label_, mask_ = tr_samples 84 | return image_.float(), label_.long(), mask_.long() 85 | else: 86 | samples = image, label 87 | if self.transform: 88 | tr_samples = self.transform(samples) 89 | image_, label_ = tr_samples 90 | return image_.float(), label_.long() 91 | 92 | 93 | 94 | class MaxCenterCrop(object): 95 | def __init__(self, scale=16): 96 | self.output_scale = scale 97 | 98 | def _get_transform(self, label): 99 | max_v = max(label.shape) 100 | n = (max_v // self.output_scale) 101 | output_size = n * self.output_scale 102 | 103 | if label.shape[0] <= output_size[0] or label.shape[1] <= output_size[1] or label.shape[2] <= output_size[2]: 104 | pw = max((output_size[0] - label.shape[0]) // 2 + 1, 0) 105 | ph = max((output_size[1] - label.shape[1]) // 2 + 1, 0) 106 | pd = max((output_size[2] - label.shape[2]) // 2 + 1, 0) 107 | label = np.pad(label, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0) 108 | else: 109 | pw, ph, pd = 0, 0, 0 110 | 111 | (w, h, d) = label.shape 112 | w1 = int(round((w - output_size[0]) / 2.)) 113 | h1 = int(round((h - output_size[1]) / 2.)) 114 | d1 = int(round((d - output_size[2]) / 2.)) 115 | 116 | def do_transform(x): 117 | if x.shape[0] <= output_size[0] or x.shape[1] <= output_size[1] or x.shape[2] <= output_size[2]: 118 | x = np.pad(x, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0) 119 | x = x[w1:w1 + output_size[0], h1:h1 + output_size[1], d1:d1 + output_size[2]] 120 | return x 121 | return do_transform 122 | 123 | def __call__(self, samples): 124 | transform = self._get_transform(samples[0]) 125 | return [transform(s) for s in samples] 126 | 127 | 128 | class CenterCrop(object): 129 | def __init__(self, output_size): 130 | self.output_size = output_size 131 | 132 | def _get_transform(self, label): 133 | if label.shape[0] <= self.output_size[0] or label.shape[1] <= self.output_size[1] or label.shape[2] <= self.output_size[2]: 134 | pw = max((self.output_size[0] - label.shape[0]) // 2 + 1, 0) 135 | ph = max((self.output_size[1] - label.shape[1]) // 2 + 1, 0) 136 | pd = max((self.output_size[2] - label.shape[2]) // 2 + 1, 0) 137 | label = np.pad(label, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0) 138 | else: 139 | pw, ph, pd = 0, 0, 0 140 | 141 | (w, h, d) = label.shape 142 | w1 = int(round((w - self.output_size[0]) / 2.)) 143 | h1 = int(round((h - self.output_size[1]) / 2.)) 144 | d1 = int(round((d - self.output_size[2]) / 2.)) 145 | 146 | def do_transform(x): 147 | if x.shape[0] <= self.output_size[0] or x.shape[1] <= self.output_size[1] or x.shape[2] <= self.output_size[2]: 148 | x = np.pad(x, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0) 149 | x = x[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]] 150 | return x 151 | return do_transform 152 | 153 | def __call__(self, samples): 154 | transform = self._get_transform(samples[0]) 155 | return [transform(s) for s in samples] 156 | 157 | 158 | class RandomCrop(object): 159 | """ 160 | Crop randomly the image in a sample 161 | Args: 162 | output_size (int): Desired output size 163 | """ 164 | 165 | def __init__(self, output_size, with_sdf=False): 166 | self.output_size = output_size 167 | self.with_sdf = with_sdf 168 | 169 | def _get_transform(self, x): 170 | if x.shape[0] <= self.output_size[0] or x.shape[1] <= self.output_size[1] or x.shape[2] <= 
self.output_size[2]: 171 | pw = max((self.output_size[0] - x.shape[0]) // 2 + 1, 0) 172 | ph = max((self.output_size[1] - x.shape[1]) // 2 + 1, 0) 173 | pd = max((self.output_size[2] - x.shape[2]) // 2 + 1, 0) 174 | x = np.pad(x, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0) 175 | else: 176 | pw, ph, pd = 0, 0, 0 177 | 178 | (w, h, d) = x.shape 179 | w1 = np.random.randint(0, w - self.output_size[0]) 180 | h1 = np.random.randint(0, h - self.output_size[1]) 181 | d1 = np.random.randint(0, d - self.output_size[2]) 182 | 183 | def do_transform(image): 184 | if image.shape[0] <= self.output_size[0] or image.shape[1] <= self.output_size[1] or image.shape[2] <= self.output_size[2]: 185 | try: 186 | image = np.pad(image, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0) 187 | except Exception as e: 188 | print(e) 189 | image = image[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]] 190 | return image 191 | return do_transform 192 | 193 | def __call__(self, samples): 194 | transform = self._get_transform(samples[0]) 195 | return [transform(s) for s in samples] 196 | 197 | 198 | class RandomRotFlip(object): 199 | """ 200 | Crop randomly flip the dataset in a sample 201 | Args: 202 | output_size (int): Desired output size 203 | """ 204 | 205 | def _get_transform(self, x): 206 | k = np.random.randint(0, 4) 207 | axis = np.random.randint(0, 2) 208 | def do_transform(image): 209 | image = np.rot90(image, k) 210 | image = np.flip(image, axis=axis).copy() 211 | return image 212 | return do_transform 213 | 214 | def __call__(self, samples): 215 | transform = self._get_transform(samples[0]) 216 | return [transform(s) for s in samples] 217 | 218 | 219 | class RandomNoise(object): 220 | def __init__(self, mu=0, sigma=0.1): 221 | self.mu = mu 222 | self.sigma = sigma 223 | 224 | def _get_transform(self, x): 225 | noise = np.clip(self.sigma * np.random.randn(x.shape[0], x.shape[1], x.shape[2]), -2 * self.sigma, 2 * self.sigma) 226 | noise = noise + self.mu 227 | def do_transform(image): 228 | image = image + noise 229 | return image 230 | return do_transform 231 | 232 | def __call__(self, samples): 233 | transform = self._get_transform(samples[0]) 234 | return [transform(s) if i == 0 else s for i, s in enumerate(samples)] 235 | 236 | 237 | class ToTensor(object): 238 | """Convert ndarrays in sample to Tensors.""" 239 | 240 | def __call__(self, sample): 241 | image = sample[0] 242 | image = image.reshape(1, image.shape[0], image.shape[1], image.shape[2]).astype(np.float32) 243 | sample = [image] + [*sample[1:]] 244 | return [torch.from_numpy(s.astype(np.float32)) for s in sample] 245 | 246 | 247 | if __name__ == '__main__': 248 | pass 249 | -------------------------------------------------------------------------------- /dataset/pancreas.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import torch 5 | import numpy as np 6 | from glob import glob 7 | from torch.utils.data import Dataset 8 | import h5py 9 | import itertools 10 | from torch.utils.data.sampler import Sampler 11 | from torchvision.transforms import Compose 12 | 13 | 14 | 15 | def get_dataset_path(dataset='pancreas'): 16 | files = ['train_lab.txt', 'train_unlab.txt', 'test.txt', 'train_whole.txt'] 17 | return ['/'.join(['data_lists_cora', dataset, f]) for f in files] 18 | 19 | 20 | class Pancreas(Dataset): 21 | """ Pancreas Dataset """ 22 | 23 | def __init__(self, base_dir, name, split, 
no_crop=False, TTA=False, require_mask=False): 24 | self._base_dir = base_dir 25 | self.split = split 26 | self.require_mask = require_mask 27 | 28 | tr_transform = Compose([ 29 | # RandomRotFlip(), 30 | RandomCrop((96, 96, 96)), 31 | # RandomNoise(), 32 | ToTensor() 33 | ]) 34 | if no_crop: 35 | test_transform = Compose([ 36 | # CenterCrop((160, 160, 128)), 37 | CenterCrop((96, 96, 96)), 38 | ToTensor() 39 | ]) 40 | else: 41 | test_transform = Compose([ 42 | CenterCrop((96, 96, 96)), 43 | ToTensor() 44 | ]) 45 | 46 | data_list_paths = get_dataset_path(name) 47 | 48 | if split == 'train_lab': 49 | data_path = data_list_paths[0] 50 | self.transform = tr_transform 51 | elif split == 'train_unlab': 52 | data_path = data_list_paths[1] 53 | self.transform = test_transform#tr_transform# 54 | elif split == 'train_whole': 55 | data_path = data_list_paths[3] 56 | self.transform = tr_transform 57 | else: 58 | data_path = data_list_paths[2] 59 | self.transform = test_transform 60 | 61 | with open(data_path, 'r') as f: 62 | self.image_list = f.readlines() 63 | 64 | self.image_list = [self._base_dir + "/{}".format(item.strip()) + '.h5' for item in self.image_list] 65 | print("Split : {}, total {} samples".format(split, len(self.image_list))) 66 | 67 | def __len__(self): 68 | if self.split == 'train_lab': 69 | return len(self.image_list) * 5 70 | else: 71 | return len(self.image_list) 72 | 73 | def __getitem__(self, idx): 74 | image_path = self.image_list[idx % len(self.image_list)] 75 | h5f = h5py.File(image_path, 'r') 76 | image, label = h5f['image'][:], h5f['label'][:].astype(np.float32) 77 | 78 | if self.require_mask: 79 | mask = (label > 0).astype(np.uint8) 80 | samples = image, label, mask 81 | if self.transform: 82 | tr_samples = self.transform(samples) 83 | image_, label_, mask_ = tr_samples 84 | return image_.float(), label_.long(), mask_.long() 85 | else: 86 | samples = image, label 87 | if self.transform: 88 | tr_samples = self.transform(samples) 89 | image_, label_ = tr_samples 90 | return image_.float(), label_.long() 91 | 92 | 93 | class MaxCenterCrop(object): 94 | def __init__(self, scale=16): 95 | self.output_scale = scale 96 | 97 | def _get_transform(self, label): 98 | max_v = max(label.shape) 99 | n = (max_v // self.output_scale) 100 | output_size = n * self.output_scale 101 | 102 | if label.shape[0] <= output_size[0] or label.shape[1] <= output_size[1] or label.shape[2] <= output_size[2]: 103 | pw = max((output_size[0] - label.shape[0]) // 2 + 1, 0) 104 | ph = max((output_size[1] - label.shape[1]) // 2 + 1, 0) 105 | pd = max((output_size[2] - label.shape[2]) // 2 + 1, 0) 106 | label = np.pad(label, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0) 107 | else: 108 | pw, ph, pd = 0, 0, 0 109 | 110 | (w, h, d) = label.shape 111 | w1 = int(round((w - output_size[0]) / 2.)) 112 | h1 = int(round((h - output_size[1]) / 2.)) 113 | d1 = int(round((d - output_size[2]) / 2.)) 114 | 115 | def do_transform(x): 116 | if x.shape[0] <= output_size[0] or x.shape[1] <= output_size[1] or x.shape[2] <= output_size[2]: 117 | x = np.pad(x, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0) 118 | x = x[w1:w1 + output_size[0], h1:h1 + output_size[1], d1:d1 + output_size[2]] 119 | return x 120 | return do_transform 121 | 122 | def __call__(self, samples): 123 | transform = self._get_transform(samples[0]) 124 | return [transform(s) for s in samples] 125 | 126 | 127 | class CenterCrop(object): 128 | def __init__(self, output_size): 129 | self.output_size = output_size 130 | 
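# Note: _get_transform below derives one crop/pad window from the first element of `samples`
# (the image); __call__ applies that same window to every element, keeping image, label and
# mask spatially aligned.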
131 | def _get_transform(self, label): 132 | if label.shape[0] <= self.output_size[0] or label.shape[1] <= self.output_size[1] or label.shape[2] <= self.output_size[2]: 133 | pw = max((self.output_size[0] - label.shape[0]) // 2 + 1, 0) 134 | ph = max((self.output_size[1] - label.shape[1]) // 2 + 1, 0) 135 | pd = max((self.output_size[2] - label.shape[2]) // 2 + 1, 0) 136 | label = np.pad(label, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0) 137 | else: 138 | pw, ph, pd = 0, 0, 0 139 | 140 | (w, h, d) = label.shape 141 | w1 = int(round((w - self.output_size[0]) / 2.)) 142 | h1 = int(round((h - self.output_size[1]) / 2.)) 143 | d1 = int(round((d - self.output_size[2]) / 2.)) 144 | 145 | def do_transform(x): 146 | if x.shape[0] <= self.output_size[0] or x.shape[1] <= self.output_size[1] or x.shape[2] <= self.output_size[2]: 147 | x = np.pad(x, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0) 148 | x = x[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]] 149 | return x 150 | return do_transform 151 | 152 | def __call__(self, samples): 153 | transform = self._get_transform(samples[0]) 154 | return [transform(s) for s in samples] 155 | 156 | 157 | class RandomCrop(object): 158 | """ 159 | Crop randomly the image in a sample 160 | Args: 161 | output_size (int): Desired output size 162 | """ 163 | 164 | def __init__(self, output_size, with_sdf=False): 165 | self.output_size = output_size 166 | self.with_sdf = with_sdf 167 | 168 | def _get_transform(self, x): 169 | if x.shape[0] <= self.output_size[0] or x.shape[1] <= self.output_size[1] or x.shape[2] <= self.output_size[2]: 170 | pw = max((self.output_size[0] - x.shape[0]) // 2 + 1, 0) 171 | ph = max((self.output_size[1] - x.shape[1]) // 2 + 1, 0) 172 | pd = max((self.output_size[2] - x.shape[2]) // 2 + 1, 0) 173 | x = np.pad(x, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0) 174 | else: 175 | pw, ph, pd = 0, 0, 0 176 | 177 | (w, h, d) = x.shape 178 | w1 = np.random.randint(0, w - self.output_size[0]) 179 | h1 = np.random.randint(0, h - self.output_size[1]) 180 | d1 = np.random.randint(0, d - self.output_size[2]) 181 | 182 | def do_transform(image): 183 | if image.shape[0] <= self.output_size[0] or image.shape[1] <= self.output_size[1] or image.shape[2] <= self.output_size[2]: 184 | try: 185 | image = np.pad(image, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0) 186 | except Exception as e: 187 | print(e) 188 | image = image[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]] 189 | return image 190 | return do_transform 191 | 192 | def __call__(self, samples): 193 | transform = self._get_transform(samples[0]) 194 | return [transform(s) for s in samples] 195 | 196 | 197 | class RandomRotFlip(object): 198 | """ 199 | Crop randomly flip the dataset in a sample 200 | Args: 201 | output_size (int): Desired output size 202 | """ 203 | 204 | def _get_transform(self, x): 205 | k = np.random.randint(0, 4) 206 | axis = np.random.randint(0, 2) 207 | def do_transform(image): 208 | image = np.rot90(image, k) 209 | image = np.flip(image, axis=axis).copy() 210 | return image 211 | return do_transform 212 | 213 | def __call__(self, samples): 214 | transform = self._get_transform(samples[0]) 215 | return [transform(s) for s in samples] 216 | 217 | 218 | class RandomNoise(object): 219 | def __init__(self, mu=0, sigma=0.1): 220 | self.mu = mu 221 | self.sigma = sigma 222 | 223 | def _get_transform(self, x): 224 | 
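# The Gaussian noise field below is sampled once per call; __call__ then adds it only to the
# image (samples[0]) and leaves labels and masks untouched.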
noise = np.clip(self.sigma * np.random.randn(x.shape[0], x.shape[1], x.shape[2]), -2 * self.sigma, 2 * self.sigma) 225 | noise = noise + self.mu 226 | def do_transform(image): 227 | image = image + noise 228 | return image 229 | return do_transform 230 | 231 | def __call__(self, samples): 232 | transform = self._get_transform(samples[0]) 233 | return [transform(s) if i == 0 else s for i, s in enumerate(samples)] 234 | 235 | 236 | class ToTensor(object): 237 | """Convert ndarrays in sample to Tensors.""" 238 | 239 | def __call__(self, sample): 240 | image = sample[0] 241 | image = image.reshape(1, image.shape[0], image.shape[1], image.shape[2]).astype(np.float32) 242 | sample = [image] + [*sample[1:]] 243 | return [torch.from_numpy(s.astype(np.float32)) for s in sample] 244 | 245 | 246 | if __name__ == '__main__': 247 | pass 248 | -------------------------------------------------------------------------------- /vnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class ConvBlock(nn.Module): 6 | def __init__(self, n_stages, n_filters_in, n_filters_out, normalization='none'): 7 | super(ConvBlock, self).__init__() 8 | 9 | ops = [] 10 | for i in range(n_stages): 11 | if i==0: 12 | input_channel = n_filters_in 13 | else: 14 | input_channel = n_filters_out 15 | 16 | ops.append(nn.Conv3d(input_channel, n_filters_out, 3, padding=1)) 17 | if normalization == 'batchnorm': 18 | ops.append(nn.BatchNorm3d(n_filters_out)) 19 | elif normalization == 'groupnorm': 20 | ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out)) 21 | elif normalization == 'instancenorm': 22 | ops.append(nn.InstanceNorm3d(n_filters_out)) 23 | elif normalization != 'none': 24 | assert False 25 | ops.append(nn.ReLU(inplace=True)) 26 | 27 | self.conv = nn.Sequential(*ops) 28 | 29 | def forward(self, x): 30 | x = self.conv(x) 31 | return x 32 | 33 | 34 | class ResidualConvBlock(nn.Module): 35 | def __init__(self, n_stages, n_filters_in, n_filters_out, normalization='none'): 36 | super(ResidualConvBlock, self).__init__() 37 | 38 | ops = [] 39 | for i in range(n_stages): 40 | if i == 0: 41 | input_channel = n_filters_in 42 | else: 43 | input_channel = n_filters_out 44 | 45 | ops.append(nn.Conv3d(input_channel, n_filters_out, 3, padding=1)) 46 | if normalization == 'batchnorm': 47 | ops.append(nn.BatchNorm3d(n_filters_out)) 48 | elif normalization == 'groupnorm': 49 | ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out)) 50 | elif normalization == 'instancenorm': 51 | ops.append(nn.InstanceNorm3d(n_filters_out)) 52 | elif normalization != 'none': 53 | assert False 54 | 55 | if i != n_stages-1: 56 | ops.append(nn.ReLU(inplace=True)) 57 | 58 | self.conv = nn.Sequential(*ops) 59 | self.relu = nn.ReLU(inplace=True) 60 | 61 | def forward(self, x): 62 | x = (self.conv(x) + x) 63 | x = self.relu(x) 64 | return x 65 | 66 | 67 | class DownsamplingConvBlock(nn.Module): 68 | def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'): 69 | super(DownsamplingConvBlock, self).__init__() 70 | 71 | ops = [] 72 | if normalization != 'none': 73 | ops.append(nn.Conv3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride)) 74 | if normalization == 'batchnorm': 75 | ops.append(nn.BatchNorm3d(n_filters_out)) 76 | elif normalization == 'groupnorm': 77 | ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out)) 78 | elif normalization == 'instancenorm': 79 | ops.append(nn.InstanceNorm3d(n_filters_out)) 80 
| else: 81 | assert False 82 | else: 83 | ops.append(nn.Conv3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride)) 84 | 85 | ops.append(nn.ReLU(inplace=True)) 86 | 87 | self.conv = nn.Sequential(*ops) 88 | 89 | def forward(self, x): 90 | x = self.conv(x) 91 | return x 92 | 93 | 94 | class UpsamplingDeconvBlock(nn.Module): 95 | def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'): 96 | super(UpsamplingDeconvBlock, self).__init__() 97 | 98 | ops = [] 99 | if normalization != 'none': 100 | ops.append(nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride)) 101 | if normalization == 'batchnorm': 102 | ops.append(nn.BatchNorm3d(n_filters_out)) 103 | elif normalization == 'groupnorm': 104 | ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out)) 105 | elif normalization == 'instancenorm': 106 | ops.append(nn.InstanceNorm3d(n_filters_out)) 107 | else: 108 | assert False 109 | else: 110 | 111 | ops.append(nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride)) 112 | 113 | ops.append(nn.ReLU(inplace=True)) 114 | 115 | self.conv = nn.Sequential(*ops) 116 | 117 | def forward(self, x): 118 | x = self.conv(x) 119 | return x 120 | 121 | 122 | class Upsampling(nn.Module): 123 | def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'): 124 | super(Upsampling, self).__init__() 125 | 126 | ops = [] 127 | ops.append(nn.Upsample(scale_factor=stride, mode='trilinear',align_corners=False)) 128 | ops.append(nn.Conv3d(n_filters_in, n_filters_out, kernel_size=3, padding=1)) 129 | if normalization == 'batchnorm': 130 | ops.append(nn.BatchNorm3d(n_filters_out)) 131 | elif normalization == 'groupnorm': 132 | ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out)) 133 | elif normalization == 'instancenorm': 134 | ops.append(nn.InstanceNorm3d(n_filters_out)) 135 | elif normalization != 'none': 136 | assert False 137 | ops.append(nn.ReLU(inplace=True)) 138 | 139 | self.conv = nn.Sequential(*ops) 140 | 141 | def forward(self, x): 142 | x = self.conv(x) 143 | return x 144 | 145 | 146 | class VNet(nn.Module): 147 | def __init__(self, n_channels=1, n_classes=2, n_filters=16, normalization='instancenorm', has_dropout=False, n_branches = 3): 148 | super(VNet, self).__init__() 149 | self.has_dropout = has_dropout 150 | 151 | self.block_one = ConvBlock(1, n_channels, n_filters, normalization=normalization) 152 | self.block_one_dw = DownsamplingConvBlock(n_filters, 2 * n_filters, normalization=normalization) 153 | 154 | self.block_two = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization) 155 | self.block_two_dw = DownsamplingConvBlock(n_filters * 2, n_filters * 4, normalization=normalization) 156 | 157 | self.block_three = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization) 158 | self.block_three_dw = DownsamplingConvBlock(n_filters * 4, n_filters * 8, normalization=normalization) 159 | 160 | self.block_four = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization) 161 | self.block_four_dw = DownsamplingConvBlock(n_filters * 8, n_filters * 16, normalization=normalization) 162 | 163 | self.block_five = ConvBlock(3, n_filters * 16, n_filters * 16, normalization=normalization) 164 | self.block_five_up = UpsamplingDeconvBlock(n_filters * 16, n_filters * 8, normalization=normalization) 165 | 166 | self.block_six = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization) 167 | self.block_six_up = UpsamplingDeconvBlock(n_filters 
* 8, n_filters * 4, normalization=normalization) 168 | 169 | self.block_seven = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization) 170 | self.block_seven_up = UpsamplingDeconvBlock(n_filters * 4, n_filters * 2, normalization=normalization) 171 | 172 | self.block_eight = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization) 173 | self.block_eight_up = UpsamplingDeconvBlock(n_filters * 2, n_filters, normalization=normalization) 174 | if has_dropout: 175 | self.dropout = nn.Dropout3d(p=0.5) 176 | self.branchs = nn.ModuleList() 177 | self.al_input = 0 178 | 179 | for i in range(n_branches): 180 | if has_dropout: 181 | seq = nn.Sequential( 182 | ConvBlock(1, n_filters, n_filters, normalization=normalization), 183 | nn.Dropout3d(p=0.5), 184 | nn.Conv3d(n_filters, n_classes, 1, padding=0) 185 | ) 186 | else: 187 | seq = nn.Sequential( 188 | ConvBlock(1, n_filters, n_filters, normalization=normalization), 189 | nn.Conv3d(n_filters, n_classes, 1, padding=0) 190 | ) 191 | self.branchs.append(seq) 192 | # self.block_nine = 193 | # self.out_conv = 194 | 195 | # self.__init_weight() 196 | 197 | def encoder(self, input): 198 | x1 = self.block_one(input) 199 | x1_dw = self.block_one_dw(x1) 200 | 201 | x2 = self.block_two(x1_dw) 202 | x2_dw = self.block_two_dw(x2) 203 | 204 | x3 = self.block_three(x2_dw) 205 | x3_dw = self.block_three_dw(x3) 206 | 207 | x4 = self.block_four(x3_dw) 208 | x4_dw = self.block_four_dw(x4) 209 | 210 | x5 = self.block_five(x4_dw) 211 | 212 | if self.has_dropout: 213 | x5 = self.dropout(x5) 214 | 215 | res = [x1, x2, x3, x4, x5] 216 | 217 | return res 218 | 219 | def decoder(self, features): 220 | x1 = features[0] 221 | x2 = features[1] 222 | x3 = features[2] 223 | x4 = features[3] 224 | x5 = features[4] 225 | 226 | x5_up = self.block_five_up(x5) 227 | x5_up = x5_up + x4 228 | 229 | x6 = self.block_six(x5_up) 230 | x6_up = self.block_six_up(x6) 231 | x6_up = x6_up + x3 232 | 233 | x7 = self.block_seven(x6_up) 234 | x7_up = self.block_seven_up(x7) 235 | x7_up = x7_up + x2 236 | 237 | x8 = self.block_eight(x7_up) 238 | x8_up = self.block_eight_up(x8) 239 | x8_up = x8_up + x1 240 | out = [] 241 | self.al_input = x8_up 242 | for branch in self.branchs: 243 | o = branch(x8_up) 244 | out.append(o) 245 | # x9 = self.block_nine(x8_up) 246 | # # x9 = F.dropout3d(x9, p=0.5, training=True) 247 | # if self.has_dropout: 248 | # x9 = self.dropout(x9) 249 | # out = self.out_conv(x9) 250 | return out 251 | 252 | def forward(self, input, turnoff_drop=False): 253 | if turnoff_drop: 254 | has_dropout = self.has_dropout 255 | self.has_dropout = False 256 | features = self.encoder(input) 257 | out = self.decoder(features) 258 | if turnoff_drop: 259 | self.has_dropout = has_dropout 260 | return out 261 | 262 | 263 | if __name__ == '__main__': 264 | pass 265 | -------------------------------------------------------------------------------- /train_panc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import time 4 | from pathlib import Path 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.optim as optim 11 | from tensorboardX import SummaryWriter 12 | from torch.utils.data import DataLoader 13 | from tqdm import tqdm 14 | import logging 15 | import utils1.loss 16 | from dataset.make_dataset import make_data_3d 17 | from dataset.pancreas import Pancreas 18 | from test_util import test_calculate_metric 19 | from utils1 
import statistic, ramps
20 | from utils1.loss import DiceLoss, SoftIoULoss
21 | from utils1.losses import FocalLoss
22 | from utils1.ResampleLoss import ResampleLossMCIntegral
23 | from vnet import VNet
24 | from aleatoric import StochasticDeepMedic
25 | import logging
26 | import sys
27 | import argparse
28 |
29 | parser = argparse.ArgumentParser()
30 | parser.add_argument('--al_weight', type=float, default=0.1, help='the weight of the aleatoric uncertainty loss')
31 | parser.add_argument('--gpu', type=str, default='1', help='GPU to use')
32 |
33 | args = parser.parse_args()
34 |
35 | al_weight = args.al_weight
36 |
37 | res_dir = 'result/pancreas_VNet_{}_seed/'.format(al_weight)
38 |
39 | if not os.path.exists(res_dir):
40 | os.makedirs(res_dir)
41 |
42 | logging.basicConfig(filename=res_dir + "log.txt", level=logging.INFO,
43 | format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
44 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
45 | logging.info('New Exp :')
46 |
47 | # GPU ids, e.g. '2,1'
48 | # Because of the aleatoric-loss part added later, setting multiple GPUs here can cause problems; use a single GPU.
49 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
50 | # Parameters
51 | num_class = 2
52 | base_dim = 8
53 |
54 | batch_size = 2
55 | lr = 1e-3
56 | beta1, beta2 = 0.5, 0.999
57 |
58 |
59 | # log settings & test
60 | pretraining_epochs = 60
61 | self_training_epochs = 201
62 | thres = 0.5
63 | pretrain_save_step = 10
64 | st_save_step = 10
65 | pred_step = 10
66 |
67 | r18 = False
68 | dataset_name = 'pancreas'
69 | data_root = '../pancreas/Pancreas-processed'
70 | cost_num = 3
71 |
72 | alpha = 0.99
73 | consistency = 1
74 | consistency_rampup = 40
75 |
76 |
77 | class AverageMeter(object):
78 | """Computes and stores the average and current value"""
79 |
80 | def __init__(self):
81 | self.reset()
82 |
83 | def reset(self):
84 | self.val = 0
85 | self.avg = 0
86 | self.sum = 0
87 | self.count = 0
88 | return self
89 |
90 | def update(self, val, n=1):
91 | self.val = val
92 | self.sum += val
93 | self.count += n
94 | self.avg = self.sum / self.count
95 | return self
96 |
97 |
98 | def set_random_seed(seed):
99 | random.seed(seed)
100 | np.random.seed(seed)
101 | torch.manual_seed(seed)
102 | torch.cuda.manual_seed(seed)
103 |
104 |
105 | def get_current_consistency_weight(epoch):
106 | # Consistency ramp-up from https://arxiv.org/abs/1610.02242
107 | return ramps.sigmoid_rampup(epoch, consistency_rampup)
108 |
109 |
110 | def update_ema_variables(model, ema_model, alpha, global_step):
111 | # Use the true average until the exponential average is more correct
112 | alpha = min(1 - 1 / (global_step + 1), alpha)
113 | for ema_param, param in zip(ema_model.parameters(), model.parameters()):
114 | ema_param.data.mul_(alpha).add_((1 - alpha) * param.data)
115 |
116 |
117 | def create_model(ema=False):
118 | net = nn.DataParallel(VNet(n_branches=4))
119 | model = net.cuda()
120 | if ema:
121 | for param in model.parameters():
122 | param.detach_()
123 | return model
124 |
125 |
126 | def get_model_and_dataloader():
127 | """Net & optimizer"""
128 | net = create_model()
129 | ema_net = create_model(ema=True).cuda()
130 | optimizer = optim.Adam(net.parameters(), lr=lr, betas=(beta1, beta2))
131 | # optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9, nesterov=True, weight_decay=1e-4)
132 |
133 | """Loading Dataset"""
134 | logging.info("loading dataset")
135 |
136 | trainset_lab = Pancreas(data_root, dataset_name, split='train_lab', require_mask=True)
137 | lab_loader = DataLoader(trainset_lab, batch_size=batch_size, shuffle=False,
117 | def create_model(ema=False): 118 |     net = nn.DataParallel(VNet(n_branches=4)) 119 |     model = net.cuda() 120 |     if ema: 121 |         for param in model.parameters(): 122 |             param.detach_() 123 |     return model 124 | 125 | 126 | def get_model_and_dataloader(): 127 |     """Net & optimizer""" 128 |     net = create_model() 129 |     ema_net = create_model(ema=True).cuda() 130 |     optimizer = optim.Adam(net.parameters(), lr=lr, betas=(beta1, beta2)) 131 |     # optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9, nesterov=True, weight_decay=1e-4) 132 | 133 |     """Loading Dataset""" 134 |     logging.info("loading dataset") 135 | 136 |     trainset_lab = Pancreas(data_root, dataset_name, split='train_lab', require_mask=True) 137 |     lab_loader = DataLoader(trainset_lab, batch_size=batch_size, shuffle=False, num_workers=0) 138 | 139 |     trainset_unlab = Pancreas(data_root, dataset_name, split='train_unlab', no_crop=True) 140 |     unlab_loader = DataLoader(trainset_unlab, batch_size=1, shuffle=False, num_workers=0) 141 | 142 |     testset = Pancreas(data_root, dataset_name, split='test') 143 |     test_loader = DataLoader(testset, batch_size=1, shuffle=False, num_workers=0) 144 |     return net, ema_net, optimizer, lab_loader, unlab_loader, test_loader 145 | 146 | 147 | def save_net_opt(net, optimizer, path, epoch): 148 |     state = { 149 |         'net': net.state_dict(), 150 |         'opt': optimizer.state_dict(), 151 |         'epoch': epoch, 152 |     } 153 |     torch.save(state, str(path)) 154 | 155 | 156 | def load_net_opt(net, optimizer, path): 157 |     state = torch.load(str(path)) 158 |     net.load_state_dict(state['net']) 159 |     optimizer.load_state_dict(state['opt']) 160 |     logging.info('Loaded from {}'.format(path)) 161 | 162 | 163 | def transform_label(label): 164 |     s = label.shape 165 |     res = torch.zeros(s[0], 2, s[1], s[2], s[3]).cuda() 166 | 167 |     mask = (label == 0).cuda()  # boolean mask over background voxels 168 |     res[:, 0, :, :, :][mask] = 1 169 | 170 |     mask = (label == 1).cuda()  # boolean mask over foreground voxels 171 |     res[:, 1, :, :, :][mask] = 1 172 | 173 |     return res 174 | 175 | 176 | def pretrain(net, ema_net, optimizer, lab_loader, unlab_loader, test_loader, start_epoch=1): 177 |     save_path = Path(res_dir) / 'pretrain' 178 |     save_path.mkdir(exist_ok=True) 179 |     logging.info("Save path : {}".format(save_path)) 180 | 181 |     writer = SummaryWriter(str(save_path), filename_suffix=time.strftime('_%Y-%m-%d_%H-%M-%S')) 182 | 183 |     DICE = DiceLoss(nclass=2) 184 |     Focal = FocalLoss() 185 |     Iou = SoftIoULoss(nclass=2) 186 | 187 |     maxdice1 = 0 188 | 189 |     iter_num = 0 190 | 191 |     for epoch in tqdm(range(start_epoch, pretraining_epochs + 1), ncols=70): 192 |         logging.info('\n') 193 |         """Testing""" 194 |         if epoch % pretrain_save_step == 0: 195 |             # maxdice, _ = test(net, unlab_loader, maxdice, max_flag) 196 |             val_dice, maxdice1, max_flag = test(net, test_loader, maxdice1) 197 | 198 |             writer.add_scalar('pretrain/test_dice', val_dice, epoch) 199 | 200 |             save_net_opt(net, optimizer, save_path / ('%d.pth' % epoch), epoch) 201 |             logging.info('Save model : {}'.format(epoch)) 202 |             if max_flag: 203 |                 save_net_opt(net, optimizer, save_path / 'best.pth', epoch) 204 |                 save_net_opt(ema_net, optimizer, save_path / 'best_ema.pth', epoch) 205 | 206 |         train_loss, train_dice = \ 207 |             AverageMeter(), AverageMeter() 208 |         net.train() 209 |         for step, (img, lab) in enumerate(lab_loader): 210 |             img, lab = img.cuda(), lab.cuda() 211 |             out = net(img) 212 | 213 | 214 |             ce_loss = F.cross_entropy(out[0], lab) 215 |             dice_loss = DICE(out[1], lab) 216 |             focal_loss = Focal(out[2], lab) 217 | 218 |             # backup plan: apply unsqueeze(1) to the label directly 219 |             iou_loss = Iou(out[3], lab) 220 |             loss = (ce_loss + dice_loss + focal_loss + iou_loss) / 4 221 | 222 |             optimizer.zero_grad() 223 |             loss.backward() 224 |             optimizer.step() 225 | 226 |             masks = get_mask(out[0]) 227 |             train_dice.update(statistic.dice_ratio(masks, lab), 1) 228 |             train_loss.update(loss.item(), 1) 229 | 230 |             logging.info('epoch : %d, step : %d, train_loss: %.4f, train_dice: %.4f' % (epoch, step, train_loss.avg, train_dice.avg)) 231 | 232 |             writer.add_scalar('pretrain/loss_all', train_loss.avg, epoch * len(lab_loader) + step) 233 |             writer.add_scalar('pretrain/train_dice', train_dice.avg, epoch * len(lab_loader) + step) 234 |             update_ema_variables(net, ema_net, alpha, step) 235 |         writer.flush() 236 | 237 | 238 | def count_param(model): 239 |     param_count = 0 240 |     for param in model.parameters(): 241 |         param_count += param.view(-1).size()[0] 242 |     return param_count 243 | 244 | 245 | def get_mask(out): 246 |     probs = F.softmax(out, 1) 247 |     masks = (probs >= thres).float() 248 |     masks = masks[:, 1, :, :].contiguous() 249 |     return masks 250 | 251 | 
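# NOTE: a minimal sanity check for get_mask above (illustration only): in the
# 2-class case with thres = 0.5, thresholding the softmax of channel 1 is the
# same as an argmax over the two channels (up to exact 0.5 ties), e.g.
#     out = torch.randn(1, 2, 4, 4, 4)                      # hypothetical logits
#     get_mask(out).long() == F.softmax(out, 1).argmax(1)   # elementwise True
# so the four branch predictions compared in pred_unlabel below are directly
# comparable binary masks.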
252 | def train(net, ema_net, optimizer, lab_loader, unlab_loader, test_loader): 253 |     save_path = Path(res_dir) / 'panc_al_weight_{}'.format(al_weight) 254 |     save_path.mkdir(exist_ok=True) 255 |     logging.info("Save path : {}".format(save_path)) 256 | 257 |     writer = SummaryWriter(str(save_path), filename_suffix=time.strftime('_%Y-%m-%d_%H-%M-%S')) 258 |     pretrained_path = Path(res_dir) / 'pretrain' 259 | 260 |     pretrained_path = Path('/home/xiangjinyi/semi_supervised/MLNet/result/pancreas_VNet_8_cora/pretrain_con_5.0')  # load an already pretrained model (overrides the default path above) 261 |     load_net_opt(net, optimizer, pretrained_path / 'best.pth') 262 |     load_net_opt(ema_net, optimizer, pretrained_path / 'best.pth') 263 | 264 |     AL_module = nn.DataParallel(StochasticDeepMedic(num_classes=2)) 265 |     AL_module = AL_module.cuda() 266 | 267 |     # load_net_opt(net, optimizer, save_path / 'best.pth') 268 |     # load_net_opt(ema_net, optimizer, save_path / 'best.pth') 269 | 270 |     consistency_criterion = utils1.loss.softmax_mse_loss 271 | 272 |     DICE = DiceLoss(nclass=2) 273 |     CE = nn.CrossEntropyLoss() 274 | 275 | 276 |     Focal = FocalLoss() 277 |     Iou = SoftIoULoss(nclass=2) 278 |     SSLoss = ResampleLossMCIntegral(20)  # the original paper used 20 MC samples 279 | 280 |     maxdice = 0 281 |     maxdice1 = 0 282 | 283 |     iter_num = 0 284 |     new_loader, plab_dice = pred_unlabel(net, unlab_loader) 285 |     writer.add_scalar('acc/plab_dice', plab_dice, 0) 286 | 287 |     for epoch in tqdm(range(0, self_training_epochs)): 288 |         logging.info('') 289 |         writer.flush() 290 |         if epoch % pred_step == 0: 291 |             new_loader, plab_dice = pred_unlabel(net, unlab_loader) 292 | 293 |         if epoch % st_save_step == 0: 294 |             """Testing""" 295 |             # val_dice, maxdice, _ = test(net, unlab_loader, maxdice) 296 |             val_dice, maxdice1, max_flag = test(net, test_loader, maxdice1) 297 |             writer.add_scalar('acc/plab_dice', plab_dice, epoch) 298 | 299 |             """Save model""" 300 |             if epoch > 0: 301 |                 save_net_opt(net, optimizer, str(save_path / ('{}.pth'.format(epoch))), epoch) 302 |                 logging.info('Save model : {}'.format(epoch)) 303 | 304 |             if max_flag: 305 |                 save_net_opt(net, optimizer, str(save_path / 'best.pth'), epoch) 306 | 307 | 308 | 309 |         net.train() 310 |         ema_net.train() 311 |         for step, (data1, data2) in enumerate(zip(lab_loader, new_loader)): 312 |             img1, lab1, lab_mask = data1 313 |             img1, lab1, lab_mask = img1.cuda(), lab1.long().cuda(), lab_mask.long().cuda() 314 |             img2, plab1, mask1, lab2 = data2 315 |             img2, plab1, mask1 = img2.cuda(), plab1.long().cuda(), mask1.float().cuda() 316 |             # plab2 = lab2.cuda() 317 | 318 |             '''Supervised Loss''' 319 |             out1 = net(img1) 320 | 321 |             loss_ce1 = CE(out1[0], lab1) 322 |             dice_loss1 = DICE(out1[1], lab1) 323 |             focal_loss1 = Focal(out1[2], lab1) 324 |             iou_loss1 = Iou(out1[3], lab1) 325 | 326 |             logits, state = AL_module(net.module.al_input, lab_mask) 327 |             state.update({'target': lab1}) 328 |             al_loss = SSLoss(logits, **state) 329 | 330 |             # al_loss is computed in logit space, but it is essentially a cross-entropy loss, 331 |             # so it is folded into the same weighted average as the other supervised terms 332 |             supervised_loss = (loss_ce1 + focal_loss1 + iou_loss1 + dice_loss1 + al_loss * al_weight) / (4 + al_weight) 333 | 334 | 
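# NOTE: dividing by (4 + al_weight) keeps supervised_loss a true weighted
# mean; with the default --al_weight 0.1, each of the four segmentation
# losses contributes 1/4.1 (about 0.244) and the aleatoric term only
# 0.1/4.1 (about 0.024) of the total, so adding the extra term does not
# inflate the overall loss scale.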
335 |             # mask = torch.zeros_like(mask).cuda(mask.device).float() 336 | 337 |             '''Certain Areas''' 338 |             out2 = net(img2) 339 |             loss_ce2 = (CE(out2[0], plab1) * mask1).sum() / (mask1.sum() + 1e-16) 340 |             focal_loss2 = (Focal(out2[2], plab1) * mask1).sum() / (mask1.sum() + 1e-16) 341 | 342 |             dice_loss2 = DICE(out2[1], plab1, mask1) 343 |             iou_loss2 = Iou(out2[3], plab1, mask1) 344 | 345 |             certain_loss = (loss_ce2 + dice_loss2 + focal_loss2 + iou_loss2) / 4 346 | 347 | 348 |             '''Uncertain Areas---Mean Teacher''' 349 |             mask1 = (1 - mask1).unsqueeze(1) 350 |             with torch.no_grad(): 351 |                 out_ema = ema_net(img2) 352 |             consistency_weight = consistency * get_current_consistency_weight(epoch) 353 |             consistency_dist1 = consistency_criterion(out2[0], out_ema[0]) 354 |             const_loss1 = consistency_weight * ((consistency_dist1 * mask1).sum() / (mask1.sum() + 1e-16)) 355 |             consistency_dist2 = consistency_criterion(out2[1], out_ema[1]) 356 |             const_loss2 = consistency_weight * ((consistency_dist2 * mask1).sum() / (mask1.sum() + 1e-16)) 357 |             consistency_dist3 = consistency_criterion(out2[2], out_ema[2]) 358 |             const_loss3 = consistency_weight * ((consistency_dist3 * mask1).sum() / (mask1.sum() + 1e-16)) 359 |             consistency_dist4 = consistency_criterion(out2[3], out_ema[3]) 360 |             const_loss4 = consistency_weight * ((consistency_dist4 * mask1).sum() / (mask1.sum() + 1e-16)) 361 |             uncertain_loss = (const_loss1 + const_loss2 + const_loss3 + const_loss4) / 4 362 |             # logging.info(uncertain_loss) 363 | 364 | 365 |             loss = supervised_loss + certain_loss + uncertain_loss  # uncertain_loss * 0.3 #+ certain_loss*0.5 366 | 367 |             optimizer.zero_grad() 368 |             loss.backward() 369 |             optimizer.step() 370 | 371 |             with torch.no_grad(): 372 |                 update_ema_variables(net, ema_net, alpha, iter_num + len(lab_loader) * pretraining_epochs) 373 |             iter_num = iter_num + 1 374 | 375 |         if epoch % st_save_step == 0: 376 |             writer.add_scalar('val_dice', val_dice, epoch) 377 | 378 | 379 | 380 | 381 | @torch.no_grad() 382 | def pred_unlabel(net, pred_loader): 383 |     logging.info('Predicting pseudo labels on the unlabeled set') 384 |     unimg, unlab, unmask, labs = [], [], [], [] 385 |     plab_dice = 0 386 |     for (step, data) in enumerate(pred_loader): 387 |         img, lab = data 388 |         img, lab = img.cuda(), lab.cuda() 389 |         out = net(img) 390 |         plab0 = get_mask(out[0])  # cross entropy prediction 391 |         plab1 = get_mask(out[1])  # dice loss prediction 392 |         plab2 = get_mask(out[2])  # focal loss prediction 393 |         plab3 = get_mask(out[3])  # IoU loss prediction 394 | 395 |         mask = ((plab0 == plab2) & (plab1 == plab3)).long() 396 | 397 |         unimg.append(img) 398 |         unlab.append(plab2)  # assume the focal-loss branch gives the best predictions 399 |         unmask.append(mask) 400 | 401 |         labs.append(lab) 402 | 403 |         plab_dice += statistic.dice_ratio(plab2, lab) 404 |     plab_dice /= len(pred_loader) 405 |     logging.info('Pseudo label dice : {}'.format(plab_dice)) 406 |     new_loader1 = DataLoader(make_data_3d(unimg, unlab, unmask, labs), batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True) 407 |     # new_loader2 = DataLoader(make_data(unimg2, unlab2), batch_size=batch_size, shuffle=True, num_workers=0) 408 |     return new_loader1, plab_dice 409 | 410 | 
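# NOTE: a toy illustration of the agreement mask built in pred_unlabel above
# (hypothetical 1-D tensors, illustration only):
#     plab0 = torch.tensor([1, 1, 0, 0])   # CE branch
#     plab1 = torch.tensor([1, 0, 0, 0])   # dice branch
#     plab2 = torch.tensor([1, 1, 0, 1])   # focal branch
#     plab3 = torch.tensor([1, 0, 0, 1])   # IoU branch
#     ((plab0 == plab2) & (plab1 == plab3)).long()   # -> [1, 1, 1, 0]
# only voxels on which both branch pairs agree are treated as "certain" and
# supervised with the pseudo label; the rest fall to the mean-teacher term.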
411 | @torch.no_grad() 412 | def test(net, val_loader, maxdice=0): 413 |     metrics = test_calculate_metric(net, val_loader.dataset) 414 |     val_dice = metrics[0] 415 | 416 |     if val_dice > maxdice: 417 |         maxdice = val_dice 418 |         max_flag = True 419 |     else: 420 |         max_flag = False 421 |     logging.info('Evaluation : val_dice: %.4f, val_maxdice: %.4f' % (val_dice, maxdice)) 422 |     return val_dice, maxdice, max_flag 423 | 424 | 425 | if __name__ == '__main__': 426 |     set_random_seed(1337) 427 |     net, ema_net, optimizer, lab_loader, unlab_loader, test_loader = get_model_and_dataloader() 428 |     # load model 429 |     # net.load_state_dict(torch.load(res_dir + '/model/best.pth')) 430 |     # pretrained_path = Path(res_dir) / 'pretrain_con_{}_consistency_{}'.format(w_con[1].item(), consistency) 431 | 432 |     # load_net_opt(net, optimizer, pretrained_path / 'best.pth') 433 |     # load_net_opt(ema_net, optimizer, pretrained_path / 'best.pth') 434 |     # pretrain(net, ema_net, optimizer, lab_loader, unlab_loader, test_loader, start_epoch=1) 435 | 436 |     train(net, ema_net, optimizer, lab_loader, unlab_loader, test_loader) 437 | 438 |     logging.info(count_param(net)) 439 | -------------------------------------------------------------------------------- /train_LA.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import time 4 | from pathlib import Path 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.optim as optim 11 | from tensorboardX import SummaryWriter 12 | from torch.utils.data import DataLoader 13 | from tqdm import tqdm 14 | import logging 15 | import utils1.loss 16 | from dataset.make_dataset import make_data_3d 17 | from dataset.LeftAtrium import LAHeart 18 | from test_util import test_calculate_metric_LA 19 | from utils1 import statistic, ramps 20 | from utils1.loss import DiceLoss, SoftIoULoss 21 | from utils1.losses import FocalLoss 22 | from utils1.ResampleLoss import ResampleLossMCIntegral 23 | from vnet import VNet 24 | from aleatoric import StochasticDeepMedic 25 | import logging 26 | import sys 27 | import argparse 28 | 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument('--al_weight', type=float, default=0.8, help='the weight of aleatoric uncertainty loss') 31 | parser.add_argument('--gpu', type=str, default='1', help='GPU to use') 32 | 33 | args = parser.parse_args() 34 | 35 | al_weight = args.al_weight 36 | 37 | res_dir = 'LA_result/LA_{}_al/'.format(al_weight) 38 | 39 | if not os.path.exists(res_dir): 40 |     os.makedirs(res_dir) 41 | 42 | logging.basicConfig(filename=res_dir + "log.txt", level=logging.INFO, 43 |                     format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S') 44 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) 45 | logging.info('New Exp :') 46 | 47 | # 2,1 48 | # Because of the aleatoric loss module added below, setting more than one GPU here causes problems 49 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 50 | # Parameters 51 | num_class = 2 52 | base_dim = 8 53 | 54 | batch_size = 2 55 | lr = 1e-3 56 | beta1, beta2 = 0.5, 0.999 57 | 58 | 59 | # log settings & test 60 | pretraining_epochs = 40 61 | self_training_epochs = 301 62 | thres = 0.5 63 | pretrain_save_step = 5 64 | st_save_step = 10 65 | pred_step = 10 66 | 67 | r18 = False 68 | split_name = 'LA_dataset' 69 | data_root = '../LA_dataset' 70 | cost_num = 3 71 | 72 | alpha = 0.99 73 | consistency = 1 74 | consistency_rampup = 40 75 | 76 | 77 | class AverageMeter(object): 78 |     """Computes and stores the average and current value""" 79 | 80 |     def __init__(self): 81 |         self.reset() 82 | 83 |     def reset(self): 84 |         self.val = 0 85 |         self.avg = 0 86 | 
self.sum = 0 87 |         self.count = 0 88 |         return self 89 | 90 |     def update(self, val, n=1): 91 |         self.val = val 92 |         self.sum += val * n 93 |         self.count += n 94 |         self.avg = self.sum / self.count 95 |         return self 96 | 97 | 98 | def set_random_seed(seed): 99 |     random.seed(seed) 100 |     np.random.seed(seed) 101 |     torch.manual_seed(seed) 102 |     torch.cuda.manual_seed(seed) 103 | 104 | 105 | def get_current_consistency_weight(epoch): 106 |     # Consistency ramp-up from https://arxiv.org/abs/1610.02242 107 |     return ramps.sigmoid_rampup(epoch, consistency_rampup) 108 | 109 | 110 | def update_ema_variables(model, ema_model, alpha, global_step): 111 |     # Use the true average until the exponential average is more correct 112 |     alpha = min(1 - 1 / (global_step + 1), alpha) 113 |     for ema_param, param in zip(ema_model.parameters(), model.parameters()): 114 |         ema_param.data.mul_(alpha).add_((1 - alpha) * param.data) 115 | 116 | 117 | def create_model(ema=False): 118 |     net = nn.DataParallel(VNet(n_branches=4)) 119 |     model = net.cuda() 120 |     if ema: 121 |         for param in model.parameters(): 122 |             param.detach_() 123 |     return model 124 | 125 | 126 | def get_model_and_dataloader(): 127 |     """Net & optimizer""" 128 |     net = create_model() 129 |     ema_net = create_model(ema=True).cuda() 130 |     optimizer = optim.Adam(net.parameters(), lr=lr, betas=(beta1, beta2)) 131 |     # optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9, nesterov=True, weight_decay=1e-4) 132 | 133 |     """Loading Dataset""" 134 |     logging.info("loading dataset") 135 | 136 |     trainset_lab = LAHeart(data_root, split_name, split='train_lab', require_mask=True) 137 |     lab_loader = DataLoader(trainset_lab, batch_size=batch_size, shuffle=False, num_workers=0) 138 | 139 |     trainset_unlab = LAHeart(data_root, split_name, split='train_unlab', no_crop=True) 140 |     unlab_loader = DataLoader(trainset_unlab, batch_size=1, shuffle=False, num_workers=0) 141 | 142 |     testset = LAHeart(data_root, split_name, split='test') 143 |     test_loader = DataLoader(testset, batch_size=1, shuffle=False, num_workers=0) 144 |     return net, ema_net, optimizer, lab_loader, unlab_loader, test_loader 145 | 146 | 147 | def save_net_opt(net, optimizer, path, epoch): 148 |     state = { 149 |         'net': net.state_dict(), 150 |         'opt': optimizer.state_dict(), 151 |         'epoch': epoch, 152 |     } 153 |     torch.save(state, str(path)) 154 | 155 | 156 | def load_net_opt(net, optimizer, path): 157 |     state = torch.load(str(path)) 158 |     net.load_state_dict(state['net']) 159 |     optimizer.load_state_dict(state['opt']) 160 |     logging.info('Loaded from {}'.format(path)) 161 | 162 | 163 | def transform_label(label): 164 |     s = label.shape 165 |     res = torch.zeros(s[0], 2, s[1], s[2], s[3]).cuda() 166 | 167 |     mask = (label == 0).cuda()  # boolean mask over background voxels 168 |     res[:, 0, :, :, :][mask] = 1 169 | 170 |     mask = (label == 1).cuda()  # boolean mask over foreground voxels 171 |     res[:, 1, :, :, :][mask] = 1 172 | 173 |     return res 174 | 175 | 176 | def pretrain(net, ema_net, optimizer, start_epoch=1): 177 | 178 |     trainset_lab = LAHeart(data_root, split_name, split='train_lab', require_mask=False) 179 |     lab_loader = DataLoader(trainset_lab, batch_size=batch_size * 2, shuffle=False, num_workers=0) 180 | 181 |     testset = LAHeart(data_root, split_name, split='test', require_mask=False) 182 |     test_loader = DataLoader(testset, batch_size=1, shuffle=False, num_workers=0) 183 | 184 |     save_path = Path(res_dir) / 'pretrain' 185 |     save_path.mkdir(exist_ok=True) 186 |     logging.info("Save path : {}".format(save_path)) 187 | 188 |     writer = SummaryWriter(str(save_path), 
filename_suffix=time.strftime('_%Y-%m-%d_%H-%M-%S')) 189 | 190 |     DICE = DiceLoss(nclass=2) 191 |     # CE_con = nn.CrossEntropyLoss(weight=w_con.cuda()) 192 |     # CE_rad = nn.CrossEntropyLoss(weight=w_rad.cuda()) 193 |     Focal = FocalLoss() 194 |     Iou = SoftIoULoss(nclass=2) 195 | 196 |     maxdice1 = 0 197 | 198 |     iter_num = 0 199 | 200 |     for epoch in tqdm(range(start_epoch, pretraining_epochs + 1), ncols=70): 201 |         logging.info('\n') 202 |         """Testing""" 203 |         if epoch % pretrain_save_step == 0: 204 |             # maxdice, _ = test(net, unlab_loader, maxdice, max_flag) 205 |             val_dice, maxdice1, max_flag = test(net, test_loader, maxdice1) 206 | 207 |             writer.add_scalar('pretrain/test_dice', val_dice, epoch) 208 | 209 |             save_net_opt(net, optimizer, save_path / ('%d.pth' % epoch), epoch) 210 |             logging.info('Save model : {}'.format(epoch)) 211 |             if max_flag: 212 |                 save_net_opt(net, optimizer, save_path / 'best.pth', epoch) 213 |                 save_net_opt(ema_net, optimizer, save_path / 'best_ema.pth', epoch) 214 | 215 |         train_loss, train_dice = \ 216 |             AverageMeter(), AverageMeter() 217 |         net.train() 218 |         for step, (img, lab) in enumerate(lab_loader): 219 |             img, lab = img.cuda(), lab.cuda() 220 |             out = net(img) 221 | 222 | 223 |             ce_loss = F.cross_entropy(out[0], lab) 224 |             dice_loss = DICE(out[1], lab) 225 |             focal_loss = Focal(out[2], lab) 226 | 227 |             # backup plan: apply unsqueeze(1) to the label directly 228 |             iou_loss = Iou(out[3], lab) 229 |             loss = (ce_loss + dice_loss + focal_loss + iou_loss) / 4 230 | 231 |             optimizer.zero_grad() 232 |             loss.backward() 233 |             optimizer.step() 234 | 235 |             masks = get_mask(out[0]) 236 |             train_dice.update(statistic.dice_ratio(masks, lab), 1) 237 |             train_loss.update(loss.item(), 1) 238 | 239 |             logging.info('epoch : %d, step : %d, train_loss: %.4f, train_dice: %.4f' % (epoch, step, train_loss.avg, train_dice.avg)) 240 | 241 | 242 |             writer.add_scalar('pretrain/loss_all', train_loss.avg, epoch * len(lab_loader) + step) 243 |             writer.add_scalar('pretrain/train_dice', train_dice.avg, epoch * len(lab_loader) + step) 244 |             update_ema_variables(net, ema_net, alpha, step) 245 |         writer.flush() 246 | 247 | 248 | def count_param(model): 249 |     param_count = 0 250 |     for param in model.parameters(): 251 |         param_count += param.view(-1).size()[0] 252 |     return param_count 253 | 254 | 255 | def get_mask(out): 256 |     probs = F.softmax(out, 1) 257 |     masks = (probs >= thres).float() 258 |     masks = masks[:, 1, :, :].contiguous() 259 |     return masks 260 | 261 | 262 | def train(net, ema_net, optimizer, lab_loader, unlab_loader, test_loader): 263 |     save_path = Path(res_dir) / 'LA_al_weight_{}'.format(al_weight) 264 |     save_path.mkdir(exist_ok=True) 265 |     logging.info("Save path : {}".format(save_path)) 266 | 267 |     writer = SummaryWriter(str(save_path), filename_suffix=time.strftime('_%Y-%m-%d_%H-%M-%S')) 268 |     pretrained_path = Path(res_dir) / 'pretrain' 269 | 270 |     # load already pretrained models 271 |     pretrained_path = Path('/home/xiangjinyi/semi_supervised/alnet/LA_result/pancreas_VNet_0.5_al_cora/pretrain_con_5.0') 272 |     load_net_opt(net, optimizer, pretrained_path / 'best.pth') 273 |     load_net_opt(ema_net, optimizer, pretrained_path / 'best.pth') 274 | 275 |     AL_module = nn.DataParallel(StochasticDeepMedic(num_classes=2)) 276 |     AL_module = AL_module.cuda() 277 | 278 |     # load_net_opt(net, optimizer, save_path / 'best.pth') 279 |     # load_net_opt(ema_net, optimizer, save_path / 'best.pth') 280 | 281 |     consistency_criterion = utils1.loss.softmax_mse_loss 282 | 283 |     DICE = DiceLoss(nclass=2) 284 |     CE = nn.CrossEntropyLoss() 285 | 286 |     Focal = FocalLoss() 287 |     Iou = SoftIoULoss(nclass=2) 288 |     SSLoss = ResampleLossMCIntegral(20)  # the original paper used 20 MC samples
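# NOTE: a sketch of what the 20 Monte-Carlo samples above are for, assuming
# the usual MC likelihood integral of stochastic segmentation models (see
# utils1/ResampleLoss.py for the actual implementation): the aleatoric head
# predicts a distribution over logit maps, T = 20 samples are drawn from it,
# and the loss is the negative log of the averaged likelihood,
#     loss = -log( (1/T) * sum_t p(target | logit_sample_t) )
# more samples give a lower-variance estimate at proportionally more compute.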
289 | 290 |     maxdice = 0 291 |     maxdice1 = 0 292 | 293 |     iter_num = 0 294 |     new_loader, plab_dice = pred_unlabel(net, unlab_loader) 295 |     writer.add_scalar('acc/plab_dice', plab_dice, 0) 296 | 297 |     for epoch in tqdm(range(0, self_training_epochs)): 298 |         logging.info('') 299 |         writer.flush() 300 |         if epoch % pred_step == 0: 301 |             new_loader, plab_dice = pred_unlabel(net, unlab_loader) 302 | 303 |         if epoch % st_save_step == 0: 304 |             """Testing""" 305 |             # val_dice, maxdice, _ = test(net, unlab_loader, maxdice) 306 |             val_dice, maxdice1, max_flag = test(net, test_loader, maxdice1) 307 |             writer.add_scalar('acc/plab_dice', plab_dice, epoch) 308 | 309 |             """Save model""" 310 |             if epoch > 0: 311 |                 save_net_opt(net, optimizer, str(save_path / ('{}.pth'.format(epoch))), epoch) 312 |                 logging.info('Save model : {}'.format(epoch)) 313 | 314 |             if max_flag: 315 |                 save_net_opt(net, optimizer, str(save_path / 'best.pth'), epoch) 316 | 317 | 318 |         net.train() 319 |         ema_net.train() 320 |         for step, (data1, data2) in enumerate(zip(lab_loader, new_loader)): 321 |             img1, lab1, lab_mask = data1 322 |             img1, lab1, lab_mask = img1.cuda(), lab1.long().cuda(), lab_mask.long().cuda() 323 |             img2, plab1, mask1, lab2 = data2 324 |             img2, plab1, mask1 = img2.cuda(), plab1.long().cuda(), mask1.float().cuda() 325 |             # plab2 = lab2.cuda() 326 | 327 |             '''Supervised Loss''' 328 |             out1 = net(img1) 329 | 330 |             loss_ce1 = CE(out1[0], lab1) 331 |             dice_loss1 = DICE(out1[1], lab1) 332 |             focal_loss1 = Focal(out1[2], lab1) 333 |             iou_loss1 = Iou(out1[3], lab1) 334 | 335 |             logits, state = AL_module(net.module.al_input, lab_mask) 336 |             state.update({'target': lab1}) 337 |             al_loss = SSLoss(logits, **state) 338 | 339 |             # al_loss is computed in logit space, but it is essentially a cross-entropy loss, 340 |             # so it is folded into the same weighted average as the other supervised terms 341 |             supervised_loss = (loss_ce1 + focal_loss1 + iou_loss1 + dice_loss1 + al_loss * al_weight) / (4 + al_weight) 342 | 343 | 344 |             # mask = torch.zeros_like(mask).cuda(mask.device).float() 345 | 346 |             '''Certain Areas''' 347 |             out2 = net(img2) 348 |             loss_ce2 = (CE(out2[0], plab1) * mask1).sum() / (mask1.sum() + 1e-16) 349 |             focal_loss2 = (Focal(out2[2], plab1) * mask1).sum() / (mask1.sum() + 1e-16) 350 | 351 |             dice_loss2 = DICE(out2[1], plab1, mask1) 352 |             iou_loss2 = Iou(out2[3], plab1, mask1) 353 | 354 |             certain_loss = (loss_ce2 + dice_loss2 + focal_loss2 + iou_loss2) / 4 355 | 356 | 357 |             '''Uncertain Areas---Mean Teacher''' 358 |             mask1 = (1 - mask1).unsqueeze(1) 359 |             with torch.no_grad(): 360 |                 out_ema = ema_net(img2) 361 |             consistency_weight = consistency * get_current_consistency_weight(epoch) 362 |             consistency_dist1 = consistency_criterion(out2[0], out_ema[0]) 363 |             const_loss1 = consistency_weight * ((consistency_dist1 * mask1).sum() / (mask1.sum() + 1e-16)) 364 |             consistency_dist2 = consistency_criterion(out2[1], out_ema[1]) 365 |             const_loss2 = consistency_weight * ((consistency_dist2 * mask1).sum() / (mask1.sum() + 1e-16)) 366 |             consistency_dist3 = consistency_criterion(out2[2], out_ema[2]) 367 |             const_loss3 = consistency_weight * ((consistency_dist3 * mask1).sum() / (mask1.sum() + 1e-16)) 368 |             consistency_dist4 = consistency_criterion(out2[3], out_ema[3]) 369 |             const_loss4 = consistency_weight * ((consistency_dist4 * mask1).sum() / (mask1.sum() + 1e-16)) 370 |             uncertain_loss = (const_loss1 + const_loss2 + const_loss3 + const_loss4) / 4 371 |             # logging.info(uncertain_loss)
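# NOTE: on the masking above (illustration only): after
# mask1 = (1 - mask1).unsqueeze(1), mask1 flags the *uncertain* voxels and
# broadcasts over the class dimension, so each consistency term
#     (dist * mask1).sum() / (mask1.sum() + 1e-16)
# averages the teacher-student discrepancy over exactly those voxels;
# certain voxels are supervised by pseudo labels instead, and the 1e-16
# guard avoids NaNs when no voxel is uncertain.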
372 | 373 | 374 |             loss = supervised_loss + certain_loss + uncertain_loss  # uncertain_loss * 0.3 #+ certain_loss*0.5 375 | 376 |             optimizer.zero_grad() 377 |             loss.backward() 378 |             optimizer.step() 379 | 380 |             with torch.no_grad(): 381 |                 update_ema_variables(net, ema_net, alpha, iter_num + len(lab_loader) * pretraining_epochs) 382 |             iter_num = iter_num + 1 383 | 384 | 385 |         if epoch % st_save_step == 0: 386 |             writer.add_scalar('val_dice', val_dice, epoch) 387 | 388 | 389 | @torch.no_grad() 390 | def pred_unlabel(net, pred_loader): 391 |     logging.info('Predicting pseudo labels on the unlabeled set') 392 |     unimg, unlab, unmask, labs = [], [], [], [] 393 |     plab_dice = 0 394 |     for (step, data) in enumerate(pred_loader): 395 |         img, lab = data 396 |         img, lab = img.cuda(), lab.cuda() 397 |         out = net(img) 398 |         plab0 = get_mask(out[0])  # cross entropy prediction 399 |         plab1 = get_mask(out[1])  # dice loss prediction 400 |         plab2 = get_mask(out[2])  # focal loss prediction 401 |         plab3 = get_mask(out[3])  # IoU loss prediction 402 | 403 |         mask = ((plab0 == plab2) & (plab1 == plab3)).long() 404 | 405 |         unimg.append(img) 406 |         unlab.append(plab2)  # assume the focal-loss branch gives the best predictions 407 |         unmask.append(mask) 408 | 409 |         labs.append(lab) 410 | 411 |         plab_dice += statistic.dice_ratio(plab2, lab) 412 |     plab_dice /= len(pred_loader) 413 |     logging.info('Pseudo label dice : {}'.format(plab_dice)) 414 |     new_loader1 = DataLoader(make_data_3d(unimg, unlab, unmask, labs), batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True) 415 |     # new_loader2 = DataLoader(make_data(unimg2, unlab2), batch_size=batch_size, shuffle=True, num_workers=0) 416 |     return new_loader1, plab_dice 417 | 418 | 419 | @torch.no_grad() 420 | def test(net, val_loader, maxdice=0): 421 |     metrics = test_calculate_metric_LA(net, val_loader.dataset) 422 |     val_dice = 
metrics[0] 423 | 424 | if val_dice > maxdice: 425 | maxdice = val_dice 426 | max_flag = True 427 | else: 428 | max_flag = False 429 | logging.info('Evaluation : val_dice: %.4f, val_maxdice: %.4f' % (val_dice, maxdice)) 430 | return val_dice, maxdice, max_flag 431 | 432 | 433 | if __name__ == '__main__': 434 | # set_random_seed(1337) 435 | net, ema_net, optimizer, lab_loader, unlab_loader, test_loader = get_model_and_dataloader() 436 | # load model 437 | # net.load_state_dict(torch.load(res_dir + '/model/best.pth')) 438 | # pretrained_path = Path(res_dir) / 'pretrain_con_{}_consistency_{}'.format(w_con[1].item(), consistency) 439 | 440 | # load_net_opt(net, optimizer, pretrained_path / 'best.pth') 441 | # load_net_opt(ema_net, optimizer, pretrained_path / 'best.pth') 442 | #pretrain(net, ema_net, optimizer, start_epoch=1) 443 | 444 | train(net, ema_net, optimizer, lab_loader, unlab_loader, test_loader) 445 | 446 | logging.info(count_param(net)) 447 | --------------------------------------------------------------------------------