├── .gitignore
├── .idea
│   ├── .gitignore
│   ├── Pytorch-Iternet.iml
│   ├── inspectionProfiles
│   │   ├── Project_Default.xml
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── vcs.xml
├── .ipynb_checkpoints
│   ├── LICENSE-checkpoint
│   └── main-checkpoint.py
├── LICENSE
├── README.md
├── augmentation
│   ├── __pycache__
│   │   └── transforms.cpython-36.pyc
│   └── transforms.py
├── dataset
│   ├── __pycache__
│   │   └── dataset_retinal.cpython-36.pyc
│   └── dataset_retinal.py
├── main.py
├── model
│   ├── __pycache__
│   │   └── iternet.cpython-36.pyc
│   └── iternet
│       ├── .ipynb_checkpoints
│       │   ├── iternet_model-checkpoint.py
│       │   └── unet_parts-checkpoint.py
│       ├── __pycache__
│       │   ├── iternet_model.cpython-36.pyc
│       │   └── unet_parts.cpython-36.pyc
│       ├── iternet_model.py
│       └── unet_parts.py
├── requirements.txt
├── trainer
│   ├── __pycache__
│   │   └── trainer.cpython-36.pyc
│   └── trainer.py
└── utils
    └── eval.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
data
.ipynb_checkpoints/*
*/.ipynb_checkpoints/*
exp
.idea

--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 amri369

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Pytorch-Iternet

PyTorch implementation of IterNet, based on the paper IterNet: Retinal Image Segmentation Utilizing Structural Redundancy
in Vessel Networks [(Li et al., 2019)](https://arxiv.org/abs/1912.05763) and the accompanying [code](https://github.com/conscienceli/IterNet).

# Training parameters

```bash
python main.py --param param_value
```

The following arguments can be provided. The defaults below are the ones set in `main.py`.

| Argument       | Default                  |
|----------------|--------------------------|
| `--gpus`       | 4,5,6                    |
| `--size`       | 592                      |
| `--crop_size`  | 128                      |
| `--image_dir`  | data/stare/stare-images/ |
| `--mask_dir`   | data/stare/labels-ah/    |
| `--train_csv`  | data/stare/train.csv     |
| `--val_csv`    | data/stare/val.csv       |
| `--lr`         | 0.0001                   |
| `--epochs`     | 2                        |
| `--batch_size` | 32                       |
| `--model_dir`  | exp/                     |
| `--seed`       | 2020123                  |
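
As a concrete example, the following invocation trains on the STARE layout assumed by the defaults (the GPU id and epoch count are illustrative placeholders; adjust them to your machine and budget):

```bash
python main.py \
    --gpus 0 \
    --image_dir data/stare/stare-images/ \
    --mask_dir data/stare/labels-ah/ \
    --train_csv data/stare/train.csv \
    --val_csv data/stare/val.csv \
    --batch_size 32 \
    --epochs 50 \
    --model_dir exp/
```

A checkpoint is written to `--model_dir` after every epoch.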
--------------------------------------------------------------------------------
/augmentation/__pycache__/transforms.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amri369/Pytorch-Iternet/d657c5d7c11dee8fe3e2301f4b8cff872273a77e/augmentation/__pycache__/transforms.cpython-36.pyc

--------------------------------------------------------------------------------
/augmentation/transforms.py:
--------------------------------------------------------------------------------
import torch
import torchvision.transforms.functional as F
from torchvision import transforms
import numbers
import random
import numpy as np


class RandomFlip(torch.nn.Module):
    """Horizontally/vertically flip the given image randomly with a given probability.
    The image can be a PIL Image or a torch Tensor, in which case it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading
    dimensions.

    Args:
        p (float): probability of the image being flipped horizontally or vertically. Default value is 0.5
    """

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, img, mask):
        """
        Args:
            img (PIL Image or Tensor): Image to be flipped.

        Returns:
            PIL Image or Tensor: Randomly flipped image.
        """
        # flip horizontally with probability p, else vertically with
        # probability p, else leave the pair unchanged (the previous version
        # always flipped one way or the other, contradicting the docstring)
        if torch.rand(1) < self.p:
            return F.hflip(img), F.hflip(mask)
        if torch.rand(1) < self.p:
            return F.vflip(img), F.vflip(mask)
        return img, mask

    def __repr__(self):
        return self.__class__.__name__ + '(p={})'.format(self.p)


class RandomRotation(object):
    """Rotate the image and mask by a random angle.

    Args:
        degrees (sequence or float or int): Range of degrees to select from.
            If degrees is a number instead of sequence like (min, max), the range of degrees
            will be (-degrees, +degrees).
        p (float): probability of the rotation being applied. Default value is 0.5
    """

    def __init__(self, degrees, p=0.5):
        if isinstance(degrees, numbers.Number):
            if degrees < 0:
                raise ValueError("If degrees is a single number, it must be positive.")
            self.degrees = (-degrees, degrees)
        else:
            if len(degrees) != 2:
                raise ValueError("If degrees is a sequence, it must be of len 2.")
            self.degrees = degrees

        self.p = p

    @staticmethod
    def get_params(degrees):
        """Get parameters for ``rotate`` for a random rotation.

        Returns:
            sequence: params to be passed to ``rotate`` for random rotation.
        """
        angle = random.uniform(degrees[0], degrees[1])

        return angle

    def __call__(self, img, mask):
        """
        Args:
            img (PIL Image): Image to be rotated.

        Returns:
            PIL Image: Rotated image.
        """
        if torch.rand(1) < self.p:
            angle = self.get_params(self.degrees)
            img = F.rotate(img, angle)
            mask = F.rotate(mask, angle)

        return img, mask

    def __repr__(self):
        return self.__class__.__name__ + '(degrees={}, p={})'.format(self.degrees, self.p)


class RandomAffine(object):
    """Random affine transformation of the image keeping center invariant.

    Args:
        degrees (sequence or float or int): Range of degrees to select from.
            If degrees is a number instead of sequence like (min, max), the range of degrees
            will be (-degrees, +degrees). Set to 0 to deactivate rotations.
        translate (tuple, optional): tuple of maximum absolute fraction for horizontal
            and vertical translations. For example translate=(a, b), then horizontal shift
            is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is
            randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default.
        scale (tuple, optional): scaling factor interval, e.g (a, b), then scale is
            randomly sampled from the range a <= scale <= b. Will keep original scale by default.
        shear (sequence or float or int, optional): Range of degrees to select from.
            If shear is a number, a shear parallel to the x axis in the range (-shear, +shear)
            will be applied. Else if shear is a tuple or list of 2 values a shear parallel to the x axis in the
            range (shear[0], shear[1]) will be applied. Else if shear is a tuple or list of 4 values,
            a x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied.
            Will not apply shear by default.
        resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional):
            An optional resampling filter. See `filters`_ for more information.
            If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST.
        fillcolor (tuple or int): Optional fill color (tuple for RGB Image and int for grayscale) for the area
            outside the transform in the output image (Pillow>=5.0.0).

    .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
    """

    def __init__(self, degrees, translate=None, scale=None, shear=None, resample=False, fillcolor=0):
        if isinstance(degrees, numbers.Number):
            if degrees < 0:
                raise ValueError("If degrees is a single number, it must be positive.")
            self.degrees = (-degrees, degrees)
        else:
            assert isinstance(degrees, (tuple, list)) and len(degrees) == 2, \
                "degrees should be a list or tuple and it must be of length 2."
            self.degrees = degrees

        if translate is not None:
            assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
                "translate should be a list or tuple and it must be of length 2."
            for t in translate:
                if not (0.0 <= t <= 1.0):
                    raise ValueError("translation values should be between 0 and 1")
        # attributes are set even when the argument is None, since
        # get_params checks them against None
        self.translate = translate

        if scale is not None:
            assert isinstance(scale, (tuple, list)) and len(scale) == 2, \
                "scale should be a list or tuple and it must be of length 2."
            for s in scale:
                if s <= 0:
                    raise ValueError("scale values should be positive")
        self.scale = scale

        if shear is not None:
            if isinstance(shear, numbers.Number):
                if shear < 0:
                    raise ValueError("If shear is a single number, it must be positive.")
                self.shear = (-shear, shear)
            else:
                assert isinstance(shear, (tuple, list)) and \
                    (len(shear) == 2 or len(shear) == 4), \
                    "shear should be a list or tuple and it must be of length 2 or 4."
                # X-Axis shear with [min, max]
                if len(shear) == 2:
                    self.shear = [shear[0], shear[1], 0., 0.]
                elif len(shear) == 4:
                    self.shear = [s for s in shear]
        else:
            self.shear = shear

        self.resample = resample
        self.fillcolor = fillcolor

    @staticmethod
    def get_params(degrees, translate, scale_ranges, shears, img_size):
        """Get parameters for affine transformation

        Returns:
            sequence: params to be passed to the affine transformation
        """
        angle = random.uniform(degrees[0], degrees[1])
        if translate is not None:
            max_dx = translate[0] * img_size[0]
            max_dy = translate[1] * img_size[1]
            translations = (np.round(random.uniform(-max_dx, max_dx)),
                            np.round(random.uniform(-max_dy, max_dy)))
        else:
            translations = (0, 0)

        if scale_ranges is not None:
            scale = random.uniform(scale_ranges[0], scale_ranges[1])
        else:
            scale = 1.0

        if shears is not None:
            if len(shears) == 2:
                shear = [random.uniform(shears[0], shears[1]), 0.]
            elif len(shears) == 4:
                shear = [random.uniform(shears[0], shears[1]),
                         random.uniform(shears[2], shears[3])]
        else:
            shear = 0.0

        return angle, translations, scale, shear

    def __call__(self, img, mask):
        """
        Args:
            img (PIL Image): Image to be transformed.

        Returns:
            PIL Image: Affine transformed image.
        """
        # sample one parameter set and apply it to image and mask alike,
        # so the pair stays spatially aligned
        ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img.size)
        img_ = F.affine(img, *ret)
        mask_ = F.affine(mask, *ret)
        return img_, mask_

    def __repr__(self):
        s = '{name}(degrees={degrees}'
        if self.translate is not None:
            s += ', translate={translate}'
        if self.scale is not None:
            s += ', scale={scale}'
        if self.shear is not None:
            s += ', shear={shear}'
        if self.resample > 0:
            s += ', resample={resample}'
        if self.fillcolor != 0:
            s += ', fillcolor={fillcolor}'
        s += ')'
        # the upstream torchvision version converted resample through a
        # private table (_pil_interpolation_to_str) that is not imported
        # here; format the raw attributes instead to avoid a NameError
        return s.format(name=self.__class__.__name__, **dict(self.__dict__))


def _get_image_size(img):
    if F._is_pil_image(img):
        return img.size
    elif isinstance(img, torch.Tensor) and img.dim() > 2:
        return img.shape[-2:][::-1]
    else:
        raise TypeError("Unexpected type {}".format(type(img)))


class RandomCrop(object):
    """Crop the given PIL Image at a random location.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        padding (int or sequence, optional): Optional padding on each border
            of the image. Default is None, i.e. no padding. If a sequence of length
            4 is provided, it is used to pad left, top, right, bottom borders
            respectively. If a sequence of length 2 is provided, it is used to
            pad left/right, top/bottom borders, respectively.
        pad_if_needed (boolean): It will pad the image if smaller than the
            desired size to avoid raising an exception. Since cropping is done
            after padding, the padding seems to be done at a random offset.
        fill: Pixel fill value for constant fill. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively.
            This value is only used when the padding_mode is constant.
        padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant.

            - constant: pads with a constant value, this value is specified with fill

            - edge: pads with the last value on the edge of the image

            - reflect: pads with reflection of image (without repeating the last value on the edge)

                padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
                will result in [3, 2, 1, 2, 3, 4, 3, 2]

            - symmetric: pads with reflection of image (repeating the last value on the edge)

                padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
                will result in [2, 1, 1, 2, 3, 4, 4, 3]
    """

    def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant'):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding
        self.pad_if_needed = pad_if_needed
        self.fill = fill
        self.padding_mode = padding_mode

    @staticmethod
    def get_params(img, output_size):
        """Get parameters for ``crop`` for a random crop.

        Args:
            img (PIL Image): Image to be cropped.
            output_size (tuple): Expected output size of the crop.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
        """
        w, h = _get_image_size(img)
        th, tw = output_size
        if w == tw and h == th:
            return 0, 0, h, w

        i = random.randint(0, h - th)
        j = random.randint(0, w - tw)
        return i, j, th, tw

    def __call__(self, img, mask):
        """
        Args:
            img (PIL Image): Image to be cropped.

        Returns:
            PIL Image: Cropped image.
        """
        if self.padding is not None:
            # pad image and mask identically so they stay aligned
            img = F.pad(img, self.padding, self.fill, self.padding_mode)
            mask = F.pad(mask, self.padding, self.fill, self.padding_mode)

        # pad the width if needed; compute the amount before padding so the
        # mask receives the same padding as the image
        if self.pad_if_needed and img.size[0] < self.size[1]:
            pad_w = self.size[1] - img.size[0]
            img = F.pad(img, (pad_w, 0), self.fill, self.padding_mode)
            mask = F.pad(mask, (pad_w, 0), self.fill, self.padding_mode)
        # pad the height if needed
        if self.pad_if_needed and img.size[1] < self.size[0]:
            pad_h = self.size[0] - img.size[1]
            img = F.pad(img, (0, pad_h), self.fill, self.padding_mode)
            mask = F.pad(mask, (0, pad_h), self.fill, self.padding_mode)

        i, j, h, w = self.get_params(img, self.size)

        img = F.crop(img, i, j, h, w)
        mask = F.crop(mask, i, j, h, w)

        return img, mask

    def __repr__(self):
        return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding)


class TransformImgMask(object):
    def __init__(self,
                 size=None,
                 size_crop=128,
                 rotate_limit=(-180, 180),
                 intensity_range=(-5, 5),
                 translate=(0.1, 0.2),
                 zoom_range=(0.8, 1.2),
                 to_tensor=False):
        self.size = size
        self.random_crop = RandomCrop(size_crop)
        if self.size:
            self.resize = transforms.Resize(size=size)
        self.random_rotate = RandomRotation(degrees=rotate_limit)
        self.random_flip = RandomFlip()
        self.random_affine = RandomAffine(degrees=0, shear=intensity_range, translate=translate, scale=zoom_range)
        self.to_tensor = to_tensor

    def __call__(self, img, mask):
        img, mask = self.random_crop(img, mask)
        if self.size:
            img, mask = self.resize(img), self.resize(mask)
        img, mask = self.random_rotate(img, mask)
        img, mask = self.random_flip(img, mask)
        img, mask = self.random_affine(img, mask)

        if self.to_tensor:
            img, mask = transforms.ToTensor()(img), transforms.ToTensor()(mask)
            # binarize the mask and keep a single channel
            mask = (mask > 0).float()[0]
            mask = mask.unsqueeze(0)
        return img, mask


class TransformImg(object):
    def __init__(self,
                 brightness=(0.2, 1),
                 contrast=(0.2, 1),
                 saturation=(0.2, 1)):
        self.transform = transforms.ColorJitter(brightness=brightness, contrast=contrast, saturation=saturation)

    def __call__(self, img):
        img = self.transform(img)
        return img
--------------------------------------------------------------------------------
/dataset/__pycache__/dataset_retinal.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amri369/Pytorch-Iternet/d657c5d7c11dee8fe3e2301f4b8cff872273a77e/dataset/__pycache__/dataset_retinal.cpython-36.pyc

--------------------------------------------------------------------------------
/dataset/dataset_retinal.py:
--------------------------------------------------------------------------------
from torch.utils.data import Dataset
import torch
import os
import pandas as pd
from PIL import Image


class DatasetRetinal(Dataset):

    def __init__(self, csv_file, image_dir, mask_dir,
                 col_filename='filename', transform_img=None, transform_img_mask=None, batch_size=32):
        """
        Args:
            csv_file (string): Path to the csv file listing the images in the dataset.
            image_dir (string): Directory with all the images.
            mask_dir (string): Directory with all the masks.
            col_filename (string): Column containing the image file names.
            transform_img (callable, optional): Optional transform applied to images only.
            transform_img_mask (callable, optional): Optional transform applied to images and masks simultaneously.
            batch_size (int): Lower bound on the reported dataset length, so a
                dataset smaller than one batch is oversampled to fill it.
        """
        self.filenames = pd.read_csv(csv_file)[col_filename]
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform_img_mask = transform_img_mask
        self.transform_img = transform_img
        self.batch_size = batch_size
        self.data_len = len(self.filenames)

    def __len__(self):
        return max(len(self.filenames), self.batch_size)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        # wrap around so indices beyond the file list reuse existing samples
        idx = idx % self.data_len

        # get image data
        filename = self.filenames.iloc[idx]
        img_name = os.path.join(self.image_dir, filename)
        img = Image.open(img_name).convert('RGB')

        # get mask data
        mask_name = os.path.join(self.mask_dir, filename)
        mask = Image.open(mask_name).convert('RGB')

        if self.transform_img is not None:
            img = self.transform_img(img)

        if self.transform_img_mask is not None:
            img, mask = self.transform_img_mask(img, mask)

        return img, mask
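
For completeness, a sketch of building the dataset and a loader by hand, outside `main.py` (paths are the `main.py` defaults; the CSV is expected to have a `filename` column):

```python
from torch.utils.data import DataLoader
from dataset.dataset_retinal import DatasetRetinal
from augmentation.transforms import TransformImg, TransformImgMask

dataset = DatasetRetinal('data/stare/train.csv',
                         'data/stare/stare-images/',
                         'data/stare/labels-ah/',
                         transform_img=TransformImg(),
                         transform_img_mask=TransformImgMask(
                             size=(592, 592), size_crop=(128, 128), to_tensor=True),
                         batch_size=32)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

imgs, masks = next(iter(loader))
# imgs: [32, 3, 592, 592], masks: [32, 1, 592, 592]
```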
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import torch
from torch.utils.data import DataLoader
from torch import nn
from torch import optim
from dataset.dataset_retinal import DatasetRetinal
from augmentation.transforms import TransformImg, TransformImgMask
from model.iternet.iternet_model import Iternet
from trainer.trainer import Trainer

import argparse


def main(args):
    # set the transform
    transform_img_mask = TransformImgMask(
        size=(args.size, args.size),
        size_crop=(args.crop_size, args.crop_size),
        to_tensor=True
    )

    # set datasets
    csv_dir = {
        'train': args.train_csv,
        'val': args.val_csv
    }
    datasets = {
        x: DatasetRetinal(csv_dir[x],
                          args.image_dir,
                          args.mask_dir,
                          batch_size=args.batch_size,
                          transform_img_mask=transform_img_mask,
                          transform_img=TransformImg()) for x in ['train', 'val']
    }

    # set dataloaders
    dataloaders = {
        x: DataLoader(datasets[x], batch_size=args.batch_size, shuffle=True) for x in ['train', 'val']
    }

    # initialize the model
    model = Iternet(n_channels=3, n_classes=1, out_channels=32, iterations=3)

    # set loss function and optimizer
    criteria = nn.BCEWithLogitsLoss()
    optimizer = optim.RMSprop(
        model.parameters(), lr=args.lr, weight_decay=1e-8, momentum=0.9)
    # the trainer steps the scheduler on the validation loss, which is
    # minimized regardless of the number of classes
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min', patience=2)

    # train the model
    trainer = Trainer(model, criteria, optimizer,
                      scheduler, args.gpus, args.seed)
    trainer(dataloaders, args.epochs, args.model_dir)
    torch.cuda.empty_cache()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Model Training')
    parser.add_argument('--gpus', default='4,5,6',
                        type=str, help='CUDA_VISIBLE_DEVICES')
    parser.add_argument('--size', default=592, type=int,
                        help='Side length images are resized to')
    parser.add_argument('--crop_size', default=128,
                        type=int, help='Side length of the random crop')
    parser.add_argument('--image_dir', default='data/stare/stare-images/',
                        type=str, help='Images folder path')
    parser.add_argument('--mask_dir', default='data/stare/labels-ah/',
                        type=str, help='Masks folder path')
    parser.add_argument('--train_csv', default='data/stare/train.csv',
                        type=str, help='List of training set images')
    parser.add_argument('--val_csv', default='data/stare/val.csv',
                        type=str, help='List of validation set images')
    parser.add_argument('--lr', default=0.0001,
                        type=float, help='Learning rate')
    parser.add_argument('--epochs', default=2,
                        type=int, help='Number of epochs')
    parser.add_argument('--batch_size', default=32,
                        type=int, help='Batch size')
    parser.add_argument('--model_dir', default='exp/',
                        type=str, help='Checkpoints output folder path')
    parser.add_argument('--seed', default=2020123,
                        type=int, help='Random seed')
    args = parser.parse_args()

    main(args)

--------------------------------------------------------------------------------
/model/__pycache__/iternet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amri369/Pytorch-Iternet/d657c5d7c11dee8fe3e2301f4b8cff872273a77e/model/__pycache__/iternet.cpython-36.pyc

--------------------------------------------------------------------------------
/model/iternet/__pycache__/iternet_model.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amri369/Pytorch-Iternet/d657c5d7c11dee8fe3e2301f4b8cff872273a77e/model/iternet/__pycache__/iternet_model.cpython-36.pyc

--------------------------------------------------------------------------------
/model/iternet/__pycache__/unet_parts.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amri369/Pytorch-Iternet/d657c5d7c11dee8fe3e2301f4b8cff872273a77e/model/iternet/__pycache__/unet_parts.cpython-36.pyc

--------------------------------------------------------------------------------
/model/iternet/iternet_model.py:
--------------------------------------------------------------------------------
""" Full assembly of the parts to form the complete network """

import torch.nn.functional as F
from torch.nn import ModuleList
import torch

from .unet_parts import *


class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, out_channels=32):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        bilinear = False

        self.inc = DoubleConv(n_channels, out_channels)
        self.down1 = Down(out_channels, out_channels * 2)
        self.down2 = Down(out_channels * 2, out_channels * 4)
        self.down3 = Down(out_channels * 4, out_channels * 8)
        factor = 2 if bilinear else 1
        self.down4 = Down(out_channels * 8, out_channels * 16 // factor)
        self.up1 = Up(out_channels * 16, out_channels * 8 // factor, bilinear)
        self.up2 = Up(out_channels * 8, out_channels * 4 // factor, bilinear)
        self.up3 = Up(out_channels * 4, out_channels * 2 // factor, bilinear)
        self.up4 = Up(out_channels * 2, out_channels, bilinear)
        self.outc = OutConv(out_channels, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        # the first and last feature maps are returned alongside the logits;
        # Iternet concatenates them as input to the refinement modules
        return x1, x, logits


class MiniUNet(nn.Module):
    def __init__(self, n_channels, n_classes, out_channels=32):
        super(MiniUNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        bilinear = False

        self.inc = DoubleConv(n_channels, out_channels)
        self.down1 = Down(out_channels, out_channels*2)
        self.down2 = Down(out_channels*2, out_channels*4)
        self.down3 = Down(out_channels*4, out_channels*8)
        self.up1 = Up(out_channels*8, out_channels*4, bilinear)
        self.up2 = Up(out_channels*4, out_channels*2, bilinear)
        self.up3 = Up(out_channels*2, out_channels, bilinear)
        self.outc = OutConv(out_channels, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x = self.up1(x4, x3)
        x = self.up2(x, x2)
        x = self.up3(x, x1)
        logits = self.outc(x)
        return x1, x, logits


class Iternet(nn.Module):
    def __init__(self, n_channels, n_classes, out_channels=32, iterations=3):
        super(Iternet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.iterations = iterations

        # define the base UNet
        self.model_unet = UNet(n_channels=n_channels,
                               n_classes=n_classes, out_channels=out_channels)

        # define the refinement MiniUNet modules, one per iteration
        self.model_miniunet = ModuleList(MiniUNet(
            n_channels=out_channels*2, n_classes=n_classes, out_channels=out_channels) for i in range(iterations))

    def forward(self, x):
        # collect the logits of the base UNet and of every refinement
        # iteration so the loss can supervise all of them
        logits = []
        x1, x2, logit = self.model_unet(x)
        logits.append(logit)
        for i in range(self.iterations):
            x = torch.cat([x1, x2], dim=1)
            _, x2, logit = self.model_miniunet[i](x)
            logits.append(logit)

        return logits
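
A quick sanity-check sketch of the model interface: with `iterations=3` the forward pass returns four logit maps (the base UNet plus three refinements), all at input resolution, and the trainer sums the loss over all of them:

```python
import torch
from model.iternet.iternet_model import Iternet

model = Iternet(n_channels=3, n_classes=1, out_channels=32, iterations=3)
x = torch.randn(2, 3, 128, 128)  # dummy batch; spatial size must be divisible by 16
with torch.no_grad():
    logits = model(x)

print(len(logits))       # 4
print(logits[-1].shape)  # torch.Size([2, 1, 128, 128])
```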
--------------------------------------------------------------------------------
/model/iternet/unet_parts.py:
--------------------------------------------------------------------------------
""" Parts of the U-Net model """

import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
matplotlib==3.2.1
numpy==2.1.0
opencv-python==4.1.0.25
pandas==2.2.2
tqdm==4.42.1
torch==2.4.0
torchvision==0.19.0

--------------------------------------------------------------------------------
/trainer/__pycache__/trainer.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amri369/Pytorch-Iternet/d657c5d7c11dee8fe3e2301f4b8cff872273a77e/trainer/__pycache__/trainer.cpython-36.pyc

--------------------------------------------------------------------------------
/trainer/trainer.py:
--------------------------------------------------------------------------------
import os
import torch
import torch.nn as nn
import random
import numpy as np


class Trainer(object):

    def __init__(self, model, criteria, optimizer, scheduler, gpus, seed):
        self.model = model
        self.criteria = criteria
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.gpus = gpus
        # CUDA_VISIBLE_DEVICES must be set before the first CUDA call,
        # otherwise it has no effect on which devices torch sees
        os.environ["CUDA_VISIBLE_DEVICES"] = gpus
        self.is_gpu_available = torch.cuda.is_available()
        Trainer.set_seed(seed)

    def set_devices(self):
        if self.is_gpu_available:
            self.model = self.model.cuda()
            self.model = torch.nn.DataParallel(self.model)
            self.criteria = self.criteria.cuda()
        else:
            self.model = self.model.cpu()
            self.criteria = self.criteria.cpu()

    @staticmethod
    def set_seed(seed):
        torch.manual_seed(seed)

        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

        random.seed(seed)

        np.random.seed(seed)

        torch.backends.cudnn.deterministic = True

    def training_step(self, dataloader):
        # initialize the loss
        epoch_loss = 0.0

        # loop over training set
        self.model.train()
        for x, y in dataloader:
            if self.is_gpu_available:
                x, y = x.cuda(), y.cuda()

            with torch.set_grad_enabled(True):
                # the model returns one logit map per iteration;
                # supervise all of them with the same target
                z = self.model(x)
                loss = self.criteria(z[0], y)
                for b in range(1, len(z)):
                    loss += self.criteria(z[b], y)

                # back propagation
                self.optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_value_(self.model.parameters(), 0.1)
                self.optimizer.step()

            epoch_loss += loss.item() * len(x)

        # average over samples, since each batch contributed loss * batch size
        epoch_loss = epoch_loss / len(dataloader.dataset)
        return epoch_loss

    def validation_step(self, dataloader):
        # initialize the loss
        epoch_loss = 0.0

        # loop over validation set
        self.model.eval()
        for x, y in dataloader:
            if self.is_gpu_available:
                x, y = x.cuda(), y.cuda()
            with torch.set_grad_enabled(False):
                z = self.model(x)
                loss = self.criteria(z[0], y)
                for b in range(1, len(z)):
                    loss += self.criteria(z[b], y)

            epoch_loss += loss.item() * len(x)

        epoch_loss = epoch_loss / len(dataloader.dataset)
        return epoch_loss

    def save_checkpoint(self, epoch, model_dir):
        # create the state dictionary
        state = {
            'epoch': epoch,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict()
        }

        if not os.path.exists(model_dir):
            os.makedirs(model_dir)
        model_out_path = os.path.join(model_dir, "epoch_{}.pth".format(epoch))
        torch.save(state, model_out_path)

    def __call__(self, dataloaders, epochs, model_dir):
        self.set_devices()
        for epoch in range(epochs):
            train_loss = self.training_step(dataloaders['train'])
            val_loss = self.validation_step(dataloaders['val'])
            # step the plateau scheduler on the monitored validation loss
            self.scheduler.step(val_loss)
            self.save_checkpoint(epoch, model_dir)
            print('Epoch {}/{} - train loss: {:.4f} - val loss: {:.4f}'.format(
                epoch + 1, epochs, train_loss, val_loss))

        if self.is_gpu_available:
            torch.cuda.empty_cache()
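
A sketch for restoring a checkpoint written by `save_checkpoint` above. On GPU machines the trainer wraps the model in `DataParallel`, so the saved `state_dict` keys carry a `module.` prefix that has to be stripped before loading into a bare model (the checkpoint path is illustrative):

```python
import torch
from model.iternet.iternet_model import Iternet

model = Iternet(n_channels=3, n_classes=1, out_channels=32, iterations=3)
state = torch.load('exp/epoch_49.pth', map_location='cpu')

# drop the leading 'module.' that DataParallel adds; this is a
# no-op for checkpoints trained on CPU
state_dict = {(k[len('module.'):] if k.startswith('module.') else k): v
              for k, v in state['state_dict'].items()}
model.load_state_dict(state_dict)
model.eval()
```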
model_dir+"_epoch_{}.pth".format(epoch) 101 | torch.save(state, model_out_path) 102 | 103 | def __call__(self, dataloaders, epochs, model_dir): 104 | self.set_devices() 105 | for epoch in range(epochs): 106 | train_loss = self.training_step(dataloaders['train']) 107 | val_loss = self.validation_step(dataloaders['val']) 108 | self.save_checkpoint(epoch, model_dir) 109 | print('------', epoch+1, '/', epochs, train_loss, val_loss) 110 | 111 | if self.is_gpu_available: 112 | torch.cuda.empty_cache() 113 | -------------------------------------------------------------------------------- /utils/eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from tqdm import tqdm 4 | 5 | from dice_loss import dice_coeff 6 | 7 | 8 | def eval_net(net, loader, device): 9 | """Evaluation without the densecrf with the dice coefficient""" 10 | net.eval() 11 | mask_type = torch.float32 if net.n_classes == 1 else torch.long 12 | n_val = len(loader) # the number of batch 13 | tot = 0 14 | 15 | with tqdm(total=n_val, desc='Validation round', unit='batch', leave=False) as pbar: 16 | for batch in loader: 17 | imgs, true_masks = batch['image'], batch['mask'] 18 | imgs = imgs.to(device=device, dtype=torch.float32) 19 | true_masks = true_masks.to(device=device, dtype=mask_type) 20 | 21 | with torch.no_grad(): 22 | mask_pred = net(imgs) 23 | 24 | if net.n_classes > 1: 25 | tot += F.cross_entropy(mask_pred, true_masks).item() 26 | else: 27 | pred = torch.sigmoid(mask_pred) 28 | pred = (pred > 0.5).float() 29 | tot += dice_coeff(pred, true_masks).item() 30 | pbar.update() 31 | 32 | net.train() 33 | return tot / n_val --------------------------------------------------------------------------------