├── .gitignore
├── .idea
│   ├── .gitignore
│   ├── Pytorch-Iternet.iml
│   ├── inspectionProfiles
│   │   ├── Project_Default.xml
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── vcs.xml
├── .ipynb_checkpoints
│   ├── LICENSE-checkpoint
│   └── main-checkpoint.py
├── LICENSE
├── README.md
├── augmentation
│   ├── __pycache__
│   │   └── transforms.cpython-36.pyc
│   └── transforms.py
├── dataset
│   ├── __pycache__
│   │   └── dataset_retinal.cpython-36.pyc
│   └── dataset_retinal.py
├── main.py
├── model
│   ├── __pycache__
│   │   └── iternet.cpython-36.pyc
│   └── iternet
│       ├── .ipynb_checkpoints
│       │   ├── iternet_model-checkpoint.py
│       │   └── unet_parts-checkpoint.py
│       ├── __pycache__
│       │   ├── iternet_model.cpython-36.pyc
│       │   └── unet_parts.cpython-36.pyc
│       ├── iternet_model.py
│       └── unet_parts.py
├── requirements.txt
├── trainer
│   ├── __pycache__
│   │   └── trainer.cpython-36.pyc
│   └── trainer.py
└── utils
    └── eval.py
/.gitignore:
--------------------------------------------------------------------------------
1 | data
2 | .ipynb_checkpoints/*
3 | */.ipynb_checkpoints/*
4 | exp
5 | .idea
6 |
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Editor-based HTTP Client requests
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
9 |
--------------------------------------------------------------------------------
/.idea/Pytorch-Iternet.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.ipynb_checkpoints/LICENSE-checkpoint:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 amri369
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/.ipynb_checkpoints/main-checkpoint.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DataLoader
3 | from torch import nn
4 | from torch import optim
5 | from dataset.dataset_retinal import DatasetRetinal
6 | from augmentation.transforms import TransformImg, TransformImgMask
7 | from model.iternet.iternet_model import Iternet
8 | from trainer.trainer import Trainer
9 |
10 | import argparse
11 |
12 |
13 | def main(args):
14 | # set the transform
15 | transform_img_mask = TransformImgMask(
16 | size=(args.size, args.size),
17 | size_crop=(args.crop_size, args.crop_size),
18 | to_tensor=True
19 | )
20 |
21 | # set datasets
22 | csv_dir = {
23 | 'train': args.train_csv,
24 | 'val': args.val_csv
25 | }
26 | datasets = {
27 | x: DatasetRetinal(csv_dir[x],
28 | args.image_dir,
29 | args.mask_dir,
30 | batch_size=args.batch_size,
31 | transform_img_mask=transform_img_mask,
32 | transform_img=TransformImg()) for x in ['train', 'val']
33 | }
34 |
35 | # set dataloaders
36 | dataloaders = {
37 | x: DataLoader(datasets[x], batch_size=args.batch_size, shuffle=True) for x in ['train', 'val']
38 | }
39 |
40 | # initialize the model
41 | model = Iternet(n_channels=3, n_classes=1, out_channels=32, iterations=3)
42 |
43 | # set loss function and optimizer
44 | criteria = nn.BCEWithLogitsLoss()
45 | optimizer = optim.RMSprop(
46 | model.parameters(), lr=args.lr, weight_decay=1e-8, momentum=0.9)
47 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(
48 |         optimizer, 'min', patience=2)  # stepped on validation loss in the trainer, so minimize
49 |
50 | # train the model
51 | trainer = Trainer(model, criteria, optimizer,
52 | scheduler, args.gpus, args.seed)
53 | trainer(dataloaders, args.epochs, args.model_dir)
54 | torch.cuda.empty_cache()
55 |
56 |
57 | if __name__ == '__main__':
58 | parser = argparse.ArgumentParser(description='Model Training')
59 | parser.add_argument('--gpus', default='4,5,6',
60 | type=str, help='CUDA_VISIBLE_DEVICES')
61 |     parser.add_argument('--size', default=592, type=int,
62 |                         help='Resize size: images are resized to size x size')
63 |     parser.add_argument('--crop_size', default=128,
64 |                         type=int, help='Random crop size')
65 | parser.add_argument('--image_dir', default='data/stare/stare-images/',
66 | type=str, help='Images folder path')
67 | parser.add_argument('--mask_dir', default='data/stare/labels-ah/',
68 | type=str, help='Masks folder path')
69 | parser.add_argument('--train_csv', default='data/stare/train.csv',
70 | type=str, help='list of training set')
71 | parser.add_argument('--val_csv', default='data/stare/val.csv',
72 | type=str, help='list of validation set')
73 |     parser.add_argument('--lr', default=0.0001,
74 |                         type=float, help='Learning rate')
75 |     parser.add_argument('--epochs', default=2,
76 |                         type=int, help='Number of epochs')
77 |     parser.add_argument('--batch_size', default=32,
78 |                         type=int, help='Batch size')
79 |     parser.add_argument('--model_dir', default='exp/',
80 |                         type=str, help='Directory for saving model checkpoints')
81 |     parser.add_argument('--seed', default=2020123,
82 |                         type=int, help='Random seed')
83 | args = parser.parse_args()
84 |
85 | main(args)
86 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 amri369
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Pytorch-Iternet
2 |
3 | PyTorch implementation of IterNet, based on the paper IterNet: Retinal Image Segmentation Utilizing Structural Redundancy
4 | in Vessel Networks [(Li et al., 2019)](https://arxiv.org/abs/1912.05763) and the accompanying [code](https://github.com/conscienceli/IterNet).
5 |
6 | # Training parameters
7 |
8 | ```bash
9 | python main.py --param param_value
10 | ```
11 |
12 | The following hyperparameters can also be provided. The defaults correspond to
13 | the smallest model from the paper and are shown for comparison.
14 |
15 | | Argument       | Default                    |
16 | |----------------|----------------------------|
17 | | `--gpus`       | `4,5,6`                    |
18 | | `--size`       | 592                        |
19 | | `--crop_size`  | 128                        |
20 | | `--image_dir`  | `data/stare/stare-images/` |
21 | | `--mask_dir`   | `data/stare/labels-ah/`    |
22 | | `--train_csv`  | `data/stare/train.csv`     |
23 | | `--val_csv`    | `data/stare/val.csv`       |
24 | | `--lr`         | 0.0001                     |
25 | | `--epochs`     | 2                          |
26 | | `--batch_size` | 32                         |
27 | | `--model_dir`  | `exp/`                     |
28 | | `--seed`       | 2020123                    |
29 |
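30 | For example, to train on the STARE layout from the defaults above (the GPU id
31 | and epoch count here are illustrative; adjust them to your setup):
32 |
33 | ```bash
34 | python main.py --gpus 0 --size 592 --crop_size 128 --epochs 50 --model_dir exp/
35 | ```
36 |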
--------------------------------------------------------------------------------
/augmentation/__pycache__/transforms.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amri369/Pytorch-Iternet/d657c5d7c11dee8fe3e2301f4b8cff872273a77e/augmentation/__pycache__/transforms.cpython-36.pyc
--------------------------------------------------------------------------------
/augmentation/transforms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.transforms.functional as F
3 | from torchvision import transforms
4 | import numbers
5 | import random
6 | import numpy as np
7 |
8 | class RandomFlip(torch.nn.Module):
9 |     """Flip the given image and mask together, either horizontally (with
10 |     probability p) or vertically (with probability 1 - p). The image can be a
11 |     PIL Image or a torch Tensor, in which case it is expected to have
12 |     [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
13 |
14 |     Args:
15 |         p (float): probability of choosing the horizontal flip; the vertical flip is applied otherwise. Default value is 0.5
16 |     """
17 |
18 | def __init__(self, p=0.5):
19 | super().__init__()
20 | self.p = p
21 |
22 | def forward(self, img, mask):
23 | """
24 | Args:
25 | img (PIL Image or Tensor): Image to be flipped.
26 |
27 | Returns:
28 | PIL Image or Tensor: Randomly flipped image.
29 | """
30 | if torch.rand(1) < self.p:
31 | return F.hflip(img), F.hflip(mask)
32 | else:
33 | return F.vflip(img), F.vflip(mask)
34 |
35 | def __repr__(self):
36 | return self.__class__.__name__ + '(p={})'.format(self.p)
37 |
38 | class RandomRotation(object):
39 | """Rotate the image by angle.
40 |
41 | Args:
42 | degrees (sequence or float or int): Range of degrees to select from.
43 | If degrees is a number instead of sequence like (min, max), the range of degrees
44 | will be (-degrees, +degrees).
45 | resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional):
46 | An optional resampling filter. See `filters`_ for more information.
47 | If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST.
48 | expand (bool, optional): Optional expansion flag.
49 | If true, expands the output to make it large enough to hold the entire rotated image.
50 | If false or omitted, make the output image the same size as the input image.
51 | Note that the expand flag assumes rotation around the center and no translation.
52 | center (2-tuple, optional): Optional center of rotation.
53 | Origin is the upper left corner.
54 | Default is the center of the image.
55 | fill (n-tuple or int or float): Pixel fill value for area outside the rotated
56 | image. If int or float, the value is used for all bands respectively.
57 | Defaults to 0 for all bands. This option is only available for ``pillow>=5.2.0``.
58 |
59 | .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
60 |
61 | """
62 |
63 | def __init__(self, degrees, p=0.5):
64 | if isinstance(degrees, numbers.Number):
65 | if degrees < 0:
66 | raise ValueError("If degrees is a single number, it must be positive.")
67 | self.degrees = (-degrees, degrees)
68 | else:
69 | if len(degrees) != 2:
70 | raise ValueError("If degrees is a sequence, it must be of len 2.")
71 | self.degrees = degrees
72 |
73 | self.p = p
74 |
75 | @staticmethod
76 | def get_params(degrees):
77 | """Get parameters for ``rotate`` for a random rotation.
78 |
79 | Returns:
80 | sequence: params to be passed to ``rotate`` for random rotation.
81 | """
82 | angle = random.uniform(degrees[0], degrees[1])
83 |
84 | return angle
85 |
86 | def __call__(self, img, mask):
87 | """
88 | Args:
89 | img (PIL Image): Image to be rotated.
90 |
91 | Returns:
92 | PIL Image: Rotated image.
93 | """
94 | if torch.rand(1) < self.p:
95 | angle = self.get_params(self.degrees)
96 | img = F.rotate(img, angle)
97 | mask = F.rotate(mask, angle)
98 |
99 | return img, mask
100 |
101 | def __repr__(self):
102 | format_string = self.__class__.__name__ + '(degrees={0}'.format(self.degrees)
103 |         format_string += ', p={}'.format(self.p)
104 | format_string += ')'
105 | return format_string
106 |
107 | class RandomAffine(object):
108 | """Random affine transformation of the image keeping center invariant
109 |
110 | Args:
111 | degrees (sequence or float or int): Range of degrees to select from.
112 | If degrees is a number instead of sequence like (min, max), the range of degrees
113 | will be (-degrees, +degrees). Set to 0 to deactivate rotations.
114 | translate (tuple, optional): tuple of maximum absolute fraction for horizontal
115 | and vertical translations. For example translate=(a, b), then horizontal shift
116 | is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is
117 | randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default.
118 | scale (tuple, optional): scaling factor interval, e.g (a, b), then scale is
119 | randomly sampled from the range a <= scale <= b. Will keep original scale by default.
120 | shear (sequence or float or int, optional): Range of degrees to select from.
121 | If shear is a number, a shear parallel to the x axis in the range (-shear, +shear)
122 |             will be applied. Else if shear is a tuple or list of 2 values a shear parallel to the x axis in the
123 | range (shear[0], shear[1]) will be applied. Else if shear is a tuple or list of 4 values,
124 | a x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied.
125 | Will not apply shear by default
126 | resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional):
127 | An optional resampling filter. See `filters`_ for more information.
128 | If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST.
129 | fillcolor (tuple or int): Optional fill color (Tuple for RGB Image And int for grayscale) for the area
130 | outside the transform in the output image.(Pillow>=5.0.0)
131 |
132 | .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
133 |
134 | """
135 |
136 | def __init__(self, degrees, translate=None, scale=None, shear=None, resample=False, fillcolor=0):
137 | if isinstance(degrees, numbers.Number):
138 | if degrees < 0:
139 | raise ValueError("If degrees is a single number, it must be positive.")
140 | self.degrees = (-degrees, degrees)
141 | else:
142 | assert isinstance(degrees, (tuple, list)) and len(degrees) == 2, \
143 | "degrees should be a list or tuple and it must be of length 2."
144 | self.degrees = degrees
145 |
146 | if translate is not None:
147 | assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
148 | "translate should be a list or tuple and it must be of length 2."
149 | for t in translate:
150 | if not (0.0 <= t <= 1.0):
151 | raise ValueError("translation values should be between 0 and 1")
152 | self.translate = translate
153 |
154 | if scale is not None:
155 | assert isinstance(scale, (tuple, list)) and len(scale) == 2, \
156 | "scale should be a list or tuple and it must be of length 2."
157 | for s in scale:
158 | if s <= 0:
159 | raise ValueError("scale values should be positive")
160 | self.scale = scale
161 |
162 | if shear is not None:
163 | if isinstance(shear, numbers.Number):
164 | if shear < 0:
165 | raise ValueError("If shear is a single number, it must be positive.")
166 | self.shear = (-shear, shear)
167 | else:
168 | assert isinstance(shear, (tuple, list)) and \
169 | (len(shear) == 2 or len(shear) == 4), \
170 | "shear should be a list or tuple and it must be of length 2 or 4."
171 | # X-Axis shear with [min, max]
172 | if len(shear) == 2:
173 | self.shear = [shear[0], shear[1], 0., 0.]
174 | elif len(shear) == 4:
175 | self.shear = [s for s in shear]
176 | else:
177 | self.shear = shear
178 |
179 | self.resample = resample
180 | self.fillcolor = fillcolor
181 |
182 | @staticmethod
183 | def get_params(degrees, translate, scale_ranges, shears, img_size):
184 | """Get parameters for affine transformation
185 |
186 | Returns:
187 | sequence: params to be passed to the affine transformation
188 | """
189 | angle = random.uniform(degrees[0], degrees[1])
190 | if translate is not None:
191 | max_dx = translate[0] * img_size[0]
192 | max_dy = translate[1] * img_size[1]
193 |             translations = (int(round(random.uniform(-max_dx, max_dx))),
194 |                             int(round(random.uniform(-max_dy, max_dy))))
195 | else:
196 | translations = (0, 0)
197 |
198 | if scale_ranges is not None:
199 | scale = random.uniform(scale_ranges[0], scale_ranges[1])
200 | else:
201 | scale = 1.0
202 |
203 | if shears is not None:
204 | if len(shears) == 2:
205 | shear = [random.uniform(shears[0], shears[1]), 0.]
206 | elif len(shears) == 4:
207 | shear = [random.uniform(shears[0], shears[1]),
208 | random.uniform(shears[2], shears[3])]
209 | else:
210 | shear = 0.0
211 |
212 | return angle, translations, scale, shear
213 |
214 | def __call__(self, img, mask):
215 | """
216 | img (PIL Image): Image to be transformed.
217 |
218 | Returns:
219 | PIL Image: Affine transformed image.
220 | """
221 | ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img.size)
222 | img_ = F.affine(img, *ret)
223 | mask_ = F.affine(mask, *ret)
224 | return img_, mask_
225 |
226 | def __repr__(self):
227 | s = '{name}(degrees={degrees}'
228 | if self.translate is not None:
229 | s += ', translate={translate}'
230 | if self.scale is not None:
231 | s += ', scale={scale}'
232 | if self.shear is not None:
233 | s += ', shear={shear}'
234 | if self.resample > 0:
235 | s += ', resample={resample}'
236 | if self.fillcolor != 0:
237 | s += ', fillcolor={fillcolor}'
238 | s += ')'
239 |         d = dict(self.__dict__)
240 |         # no interpolation-mode lookup table is defined in this module, so format the raw resample value
241 |         return s.format(name=self.__class__.__name__, **d)
242 |
243 | def _get_image_size(img):
244 | if F._is_pil_image(img):
245 | return img.size
246 | elif isinstance(img, torch.Tensor) and img.dim() > 2:
247 | return img.shape[-2:][::-1]
248 | else:
249 | raise TypeError("Unexpected type {}".format(type(img)))
250 |
251 | class RandomCrop(object):
252 | """Crop the given PIL Image at a random location.
253 |
254 | Args:
255 | size (sequence or int): Desired output size of the crop. If size is an
256 | int instead of sequence like (h, w), a square crop (size, size) is
257 | made.
258 | padding (int or sequence, optional): Optional padding on each border
259 | of the image. Default is None, i.e no padding. If a sequence of length
260 | 4 is provided, it is used to pad left, top, right, bottom borders
261 | respectively. If a sequence of length 2 is provided, it is used to
262 | pad left/right, top/bottom borders, respectively.
263 | pad_if_needed (boolean): It will pad the image if smaller than the
264 | desired size to avoid raising an exception. Since cropping is done
265 | after padding, the padding seems to be done at a random offset.
266 | fill: Pixel fill value for constant fill. Default is 0. If a tuple of
267 | length 3, it is used to fill R, G, B channels respectively.
268 | This value is only used when the padding_mode is constant
269 | padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant.
270 |
271 | - constant: pads with a constant value, this value is specified with fill
272 |
273 | - edge: pads with the last value on the edge of the image
274 |
275 | - reflect: pads with reflection of image (without repeating the last value on the edge)
276 |
277 | padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
278 | will result in [3, 2, 1, 2, 3, 4, 3, 2]
279 |
280 | - symmetric: pads with reflection of image (repeating the last value on the edge)
281 |
282 | padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
283 | will result in [2, 1, 1, 2, 3, 4, 4, 3]
284 |
285 | """
286 |
287 | def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant'):
288 | if isinstance(size, numbers.Number):
289 | self.size = (int(size), int(size))
290 | else:
291 | self.size = size
292 | self.padding = padding
293 | self.pad_if_needed = pad_if_needed
294 | self.fill = fill
295 | self.padding_mode = padding_mode
296 |
297 | @staticmethod
298 | def get_params(img, output_size):
299 | """Get parameters for ``crop`` for a random crop.
300 |
301 | Args:
302 | img (PIL Image): Image to be cropped.
303 | output_size (tuple): Expected output size of the crop.
304 |
305 | Returns:
306 | tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
307 | """
308 | w, h = _get_image_size(img)
309 | th, tw = output_size
310 | if w == tw and h == th:
311 | return 0, 0, h, w
312 |
313 | i = random.randint(0, h - th)
314 | j = random.randint(0, w - tw)
315 | return i, j, th, tw
316 |
317 | def __call__(self, img, mask):
318 | """
319 | Args:
320 | img (PIL Image): Image to be cropped.
321 |
322 | Returns:
323 | PIL Image: Cropped image.
324 | """
325 | if self.padding is not None:
326 | img = F.pad(img, self.padding, self.fill, self.padding_mode)
327 |
328 | # pad the width if needed
329 |         if self.pad_if_needed and img.size[0] < self.size[1]:
330 |             mask = F.pad(mask, (self.size[1] - img.size[0], 0), self.fill, self.padding_mode)  # pad the mask first: padding img changes img.size
331 |             img = F.pad(img, (self.size[1] - img.size[0], 0), self.fill, self.padding_mode)
332 |         # pad the height if needed
333 |         if self.pad_if_needed and img.size[1] < self.size[0]:
334 |             mask = F.pad(mask, (0, self.size[0] - img.size[1]), self.fill, self.padding_mode)
335 |             img = F.pad(img, (0, self.size[0] - img.size[1]), self.fill, self.padding_mode)
336 | i, j, h, w = self.get_params(img, self.size)
337 |
338 | img = F.crop(img, i, j, h, w)
339 | mask = F.crop(mask, i, j, h, w)
340 |
341 | return img, mask
342 |
343 | def __repr__(self):
344 | return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding)
345 |
346 | class TransformImgMask(object):
347 | def __init__(self,
348 | size=None,
349 | size_crop=128,
350 | rotate_limit=(-180, 180),
351 | intensity_range=(-5, 5),
352 | translate=(0.1, 0.2),
353 | zoom_range=(0.8, 1.2),
354 | to_tensor=False):
355 | self.size = size
356 | self.random_crop = RandomCrop(size_crop)
357 | if self.size:
358 | self.resize = transforms.Resize(size=size)
359 | self.random_rotate = RandomRotation(degrees=rotate_limit)
360 | self.random_flip = RandomFlip()
361 | self.random_affine = RandomAffine(degrees=0, shear=intensity_range, translate=translate, scale=zoom_range)
362 | self.to_tensor = to_tensor
363 |
364 | def __call__(self, img, mask):
365 | img, mask = self.random_crop(img, mask)
366 | if self.size:
367 | img, mask = self.resize(img), self.resize(mask)
368 | img, mask = self.random_rotate(img, mask)
369 | img, mask = self.random_flip(img, mask)
370 | img, mask = self.random_affine(img, mask)
371 |
372 | if self.to_tensor:
373 | img, mask = transforms.ToTensor()(img), transforms.ToTensor()(mask)
374 | mask = (mask>0).float()[0]
375 | mask = mask.unsqueeze(0)
376 | return img, mask
377 |
378 | class TransformImg(object):
379 | def __init__(self,
380 | brightness=(0.2, 1),
381 | contrast=(0.2, 1),
382 | saturation=(0.2, 1)):
383 | self.transform = transforms.ColorJitter(brightness=brightness, contrast=contrast, saturation=saturation)
384 |
385 | def __call__(self, img):
386 | img = self.transform(img)
387 | return img
388 |
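389 |
390 | if __name__ == '__main__':
391 |     # Smoke test (a sketch, assuming Pillow is available): run the paired
392 |     # transform on a dummy image/mask pair and print the resulting shapes.
393 |     from PIL import Image
394 |     img = Image.new('RGB', (605, 700))
395 |     mask = Image.new('RGB', (605, 700))
396 |     t = TransformImgMask(size=(592, 592), size_crop=(128, 128), to_tensor=True)
397 |     img_t, mask_t = t(img, mask)
398 |     print(img_t.shape, mask_t.shape)  # torch.Size([3, 592, 592]) torch.Size([1, 592, 592])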
--------------------------------------------------------------------------------
/dataset/__pycache__/dataset_retinal.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amri369/Pytorch-Iternet/d657c5d7c11dee8fe3e2301f4b8cff872273a77e/dataset/__pycache__/dataset_retinal.cpython-36.pyc
--------------------------------------------------------------------------------
/dataset/dataset_retinal.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | import torch
3 | import os
4 | import pandas as pd
5 | from PIL import Image
6 |
7 | class DatasetRetinal(Dataset):
8 |
9 | def __init__(self, csv_file, image_dir, mask_dir,
10 | col_filename='filename', transform_img=None, transform_img_mask=None, batch_size=32):
11 | """
12 | Args:
13 |             csv_file (string): Path to the csv file listing the images in the dataset.
14 | image_dir (string): Directory with all the images.
15 | mask_dir (string): Directory with all the masks.
16 | col_filename (string): column name containing images names.
17 | transform_img (callable, optional): Optional transform to be applied on images only.
18 | transform_img_mask (callable, optional): Optional transform to be applied on images and masks simultaneously.
19 | """
20 | self.filenames = pd.read_csv(csv_file)[col_filename]
21 | self.image_dir = image_dir
22 | self.mask_dir = mask_dir
23 | self.transform_img_mask = transform_img_mask
24 | self.transform_img = transform_img
25 | self.batch_size = batch_size
26 | self.data_len = len(self.filenames)
27 |
28 | def __len__(self):
29 | return max(len(self.filenames), self.batch_size)
30 |
31 | def __getitem__(self, idx):
32 | if torch.is_tensor(idx):
33 | idx = idx.tolist()
34 | idx = idx % self.data_len
35 | # get image data
36 | filename = self.filenames.iloc[idx]
37 | img_name = os.path.join(self.image_dir, filename)
38 | img = Image.open(img_name).convert('RGB')
39 |
40 | # get mask_data
41 | mask_name = os.path.join(self.mask_dir, filename)
42 | mask = Image.open(mask_name).convert('RGB')
43 |
44 | if self.transform_img is not None:
45 | img = self.transform_img(img)
46 |
47 | if self.transform_img_mask is not None:
48 | img, mask = self.transform_img_mask(img, mask)
49 |
50 | return img, mask
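51 |
52 |
53 | if __name__ == '__main__':
54 |     # Smoke test (a sketch, assuming the STARE files from main.py's defaults
55 |     # exist): __len__ reports at least batch_size and __getitem__ wraps the
56 |     # index modulo the true length, so tiny datasets can still fill a batch.
57 |     ds = DatasetRetinal('data/stare/train.csv',
58 |                         'data/stare/stare-images/',
59 |                         'data/stare/labels-ah/',
60 |                         batch_size=32)
61 |     img, mask = ds[0]
62 |     print(len(ds), img.size, mask.size)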
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DataLoader
3 | from torch import nn
4 | from torch import optim
5 | from dataset.dataset_retinal import DatasetRetinal
6 | from augmentation.transforms import TransformImg, TransformImgMask
7 | from model.iternet.iternet_model import Iternet
8 | from trainer.trainer import Trainer
9 |
10 | import argparse
11 |
12 |
13 | def main(args):
14 | # set the transform
15 | transform_img_mask = TransformImgMask(
16 | size=(args.size, args.size),
17 | size_crop=(args.crop_size, args.crop_size),
18 | to_tensor=True
19 | )
20 |
21 | # set datasets
22 | csv_dir = {
23 | 'train': args.train_csv,
24 | 'val': args.val_csv
25 | }
26 | datasets = {
27 | x: DatasetRetinal(csv_dir[x],
28 | args.image_dir,
29 | args.mask_dir,
30 | batch_size=args.batch_size,
31 | transform_img_mask=transform_img_mask,
32 | transform_img=TransformImg()) for x in ['train', 'val']
33 | }
34 |
35 | # set dataloaders
36 | dataloaders = {
37 | x: DataLoader(datasets[x], batch_size=args.batch_size, shuffle=True) for x in ['train', 'val']
38 | }
39 |
40 | # initialize the model
41 | model = Iternet(n_channels=3, n_classes=1, out_channels=32, iterations=3)
42 |
43 | # set loss function and optimizer
44 | criteria = nn.BCEWithLogitsLoss()
45 | optimizer = optim.RMSprop(
46 | model.parameters(), lr=args.lr, weight_decay=1e-8, momentum=0.9)
47 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(
48 |         optimizer, 'min', patience=2)  # stepped on validation loss in the trainer, so minimize
49 |
50 | # train the model
51 | trainer = Trainer(model, criteria, optimizer,
52 | scheduler, args.gpus, args.seed)
53 | trainer(dataloaders, args.epochs, args.model_dir)
54 | torch.cuda.empty_cache()
55 |
56 |
57 | if __name__ == '__main__':
58 | parser = argparse.ArgumentParser(description='Model Training')
59 | parser.add_argument('--gpus', default='4,5,6',
60 | type=str, help='CUDA_VISIBLE_DEVICES')
61 |     parser.add_argument('--size', default=592, type=int,
62 |                         help='Resize size: images are resized to size x size')
63 |     parser.add_argument('--crop_size', default=128,
64 |                         type=int, help='Random crop size')
65 | parser.add_argument('--image_dir', default='data/stare/stare-images/',
66 | type=str, help='Images folder path')
67 | parser.add_argument('--mask_dir', default='data/stare/labels-ah/',
68 | type=str, help='Masks folder path')
69 | parser.add_argument('--train_csv', default='data/stare/train.csv',
70 | type=str, help='list of training set')
71 | parser.add_argument('--val_csv', default='data/stare/val.csv',
72 | type=str, help='list of validation set')
73 |     parser.add_argument('--lr', default=0.0001,
74 |                         type=float, help='Learning rate')
75 |     parser.add_argument('--epochs', default=2,
76 |                         type=int, help='Number of epochs')
77 |     parser.add_argument('--batch_size', default=32,
78 |                         type=int, help='Batch size')
79 |     parser.add_argument('--model_dir', default='exp/',
80 |                         type=str, help='Directory for saving model checkpoints')
81 |     parser.add_argument('--seed', default=2020123,
82 |                         type=int, help='Random seed')
83 | args = parser.parse_args()
84 |
85 | main(args)
86 |
--------------------------------------------------------------------------------
/model/__pycache__/iternet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amri369/Pytorch-Iternet/d657c5d7c11dee8fe3e2301f4b8cff872273a77e/model/__pycache__/iternet.cpython-36.pyc
--------------------------------------------------------------------------------
/model/iternet/.ipynb_checkpoints/iternet_model-checkpoint.py:
--------------------------------------------------------------------------------
1 | """ Full assembly of the parts to form the complete network """
2 |
3 | import torch.nn.functional as F
4 | from torch.nn import ModuleList
5 | import torch
6 |
7 | from .unet_parts import *
8 |
9 |
10 | class UNet(nn.Module):
11 | def __init__(self, n_channels, n_classes, out_channels=32):
12 | super(UNet, self).__init__()
13 | self.n_channels = n_channels
14 | self.n_classes = n_classes
15 | bilinear = False
16 |
17 | self.inc = DoubleConv(n_channels, out_channels)
18 | self.down1 = Down(out_channels, out_channels * 2)
19 | self.down2 = Down(out_channels * 2, out_channels * 4)
20 | self.down3 = Down(out_channels * 4, out_channels * 8)
21 | factor = 2 if bilinear else 1
22 | self.down4 = Down(out_channels * 8, out_channels * 16 // factor)
23 | self.up1 = Up(out_channels * 16, out_channels * 8 // factor, bilinear)
24 | self.up2 = Up(out_channels * 8, out_channels * 4 // factor, bilinear)
25 | self.up3 = Up(out_channels * 4, out_channels * 2 // factor, bilinear)
26 | self.up4 = Up(out_channels * 2, out_channels, bilinear)
27 | self.outc = OutConv(out_channels, n_classes)
28 |
29 | def forward(self, x):
30 | x1 = self.inc(x)
31 | x2 = self.down1(x1)
32 | x3 = self.down2(x2)
33 | x4 = self.down3(x3)
34 | x5 = self.down4(x4)
35 | x = self.up1(x5, x4)
36 | x = self.up2(x, x3)
37 | x = self.up3(x, x2)
38 | x = self.up4(x, x1)
39 | logits = self.outc(x)
40 | return x1, x, logits
41 |
42 | class MiniUNet(nn.Module):
43 | def __init__(self, n_channels, n_classes, out_channels=32):
44 | super(MiniUNet, self).__init__()
45 | self.n_channels = n_channels
46 | self.n_classes = n_classes
47 | bilinear = False
48 |
49 | self.inc = DoubleConv(n_channels, out_channels)
50 | self.down1 = Down(out_channels, out_channels*2)
51 | self.down2 = Down(out_channels*2, out_channels*4)
52 | self.down3 = Down(out_channels*4, out_channels*8)
53 | self.up1 = Up(out_channels*8, out_channels*4, bilinear)
54 | self.up2 = Up(out_channels*4, out_channels*2, bilinear)
55 | self.up3 = Up(out_channels*2, out_channels, bilinear)
56 | self.outc = OutConv(out_channels, n_classes)
57 |
58 | def forward(self, x):
59 | x1 = self.inc(x)
60 | x2 = self.down1(x1)
61 | x3 = self.down2(x2)
62 | x4 = self.down3(x3)
63 | x = self.up1(x4, x3)
64 | x = self.up2(x, x2)
65 | x = self.up3(x, x1)
66 | logits = self.outc(x)
67 | return x1, x, logits
68 |
69 | class Iternet(nn.Module):
70 | def __init__(self, n_channels, n_classes, out_channels=32, iterations=3):
71 | super(Iternet, self).__init__()
72 | self.n_channels = n_channels
73 | self.n_classes = n_classes
74 | self.iterations = iterations
75 |
76 | # define the network UNet layer
77 | self.model_unet = UNet(n_channels=n_channels, n_classes=n_classes, out_channels=out_channels)
78 |
79 | # define the network MiniUNet layers
80 | self.model_miniunet = ModuleList(MiniUNet(n_channels=out_channels*2, n_classes=n_classes, out_channels=out_channels) for i in range(iterations))
81 |
82 | def forward(self, x):
83 | x1, x2, logits = self.model_unet(x)
84 | for i in range(self.iterations):
85 | x = torch.cat([x1, x2], dim=1)
86 | _, x2, logits = self.model_miniunet[i](x)
87 |
88 | return logits
--------------------------------------------------------------------------------
/model/iternet/.ipynb_checkpoints/unet_parts-checkpoint.py:
--------------------------------------------------------------------------------
1 | """ Parts of the U-Net model """
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | class DoubleConv(nn.Module):
9 | """(convolution => [BN] => ReLU) * 2"""
10 |
11 | def __init__(self, in_channels, out_channels, mid_channels=None):
12 | super().__init__()
13 | if not mid_channels:
14 | mid_channels = out_channels
15 | self.double_conv = nn.Sequential(
16 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
17 | nn.BatchNorm2d(mid_channels),
18 | nn.ReLU(inplace=True),
19 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
20 | nn.BatchNorm2d(out_channels),
21 | nn.ReLU(inplace=True)
22 | )
23 |
24 | def forward(self, x):
25 | return self.double_conv(x)
26 |
27 |
28 | class Down(nn.Module):
29 | """Downscaling with maxpool then double conv"""
30 |
31 | def __init__(self, in_channels, out_channels):
32 | super().__init__()
33 | self.maxpool_conv = nn.Sequential(
34 | nn.MaxPool2d(2),
35 | DoubleConv(in_channels, out_channels)
36 | )
37 |
38 | def forward(self, x):
39 | return self.maxpool_conv(x)
40 |
41 |
42 | class Up(nn.Module):
43 | """Upscaling then double conv"""
44 |
45 | def __init__(self, in_channels, out_channels, bilinear=True):
46 | super().__init__()
47 |
48 | # if bilinear, use the normal convolutions to reduce the number of channels
49 | if bilinear:
50 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
51 | self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
52 | else:
53 | self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2)
54 | self.conv = DoubleConv(in_channels, out_channels)
55 |
56 |
57 | def forward(self, x1, x2):
58 | x1 = self.up(x1)
59 | # input is CHW
60 | diffY = x2.size()[2] - x1.size()[2]
61 | diffX = x2.size()[3] - x1.size()[3]
62 |
63 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
64 | diffY // 2, diffY - diffY // 2])
65 | # if you have padding issues, see
66 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
67 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
68 | x = torch.cat([x2, x1], dim=1)
69 | return self.conv(x)
70 |
71 |
72 | class OutConv(nn.Module):
73 | def __init__(self, in_channels, out_channels):
74 | super(OutConv, self).__init__()
75 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
76 |
77 | def forward(self, x):
78 | return self.conv(x)
--------------------------------------------------------------------------------
/model/iternet/__pycache__/iternet_model.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amri369/Pytorch-Iternet/d657c5d7c11dee8fe3e2301f4b8cff872273a77e/model/iternet/__pycache__/iternet_model.cpython-36.pyc
--------------------------------------------------------------------------------
/model/iternet/__pycache__/unet_parts.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amri369/Pytorch-Iternet/d657c5d7c11dee8fe3e2301f4b8cff872273a77e/model/iternet/__pycache__/unet_parts.cpython-36.pyc
--------------------------------------------------------------------------------
/model/iternet/iternet_model.py:
--------------------------------------------------------------------------------
1 | """ Full assembly of the parts to form the complete network """
2 |
3 | import torch.nn.functional as F
4 | from torch.nn import ModuleList
5 | import torch
6 |
7 | from .unet_parts import *
8 |
9 |
10 | class UNet(nn.Module):
11 | def __init__(self, n_channels, n_classes, out_channels=32):
12 | super(UNet, self).__init__()
13 | self.n_channels = n_channels
14 | self.n_classes = n_classes
15 | bilinear = False
16 |
17 | self.inc = DoubleConv(n_channels, out_channels)
18 | self.down1 = Down(out_channels, out_channels * 2)
19 | self.down2 = Down(out_channels * 2, out_channels * 4)
20 | self.down3 = Down(out_channels * 4, out_channels * 8)
21 | factor = 2 if bilinear else 1
22 | self.down4 = Down(out_channels * 8, out_channels * 16 // factor)
23 | self.up1 = Up(out_channels * 16, out_channels * 8 // factor, bilinear)
24 | self.up2 = Up(out_channels * 8, out_channels * 4 // factor, bilinear)
25 | self.up3 = Up(out_channels * 4, out_channels * 2 // factor, bilinear)
26 | self.up4 = Up(out_channels * 2, out_channels, bilinear)
27 | self.outc = OutConv(out_channels, n_classes)
28 |
29 | def forward(self, x):
30 | x1 = self.inc(x)
31 | x2 = self.down1(x1)
32 | x3 = self.down2(x2)
33 | x4 = self.down3(x3)
34 | x5 = self.down4(x4)
35 | x = self.up1(x5, x4)
36 | x = self.up2(x, x3)
37 | x = self.up3(x, x2)
38 | x = self.up4(x, x1)
39 | logits = self.outc(x)
40 | return x1, x, logits
41 |
42 |
43 | class MiniUNet(nn.Module):
44 | def __init__(self, n_channels, n_classes, out_channels=32):
45 | super(MiniUNet, self).__init__()
46 | self.n_channels = n_channels
47 | self.n_classes = n_classes
48 | bilinear = False
49 |
50 | self.inc = DoubleConv(n_channels, out_channels)
51 | self.down1 = Down(out_channels, out_channels*2)
52 | self.down2 = Down(out_channels*2, out_channels*4)
53 | self.down3 = Down(out_channels*4, out_channels*8)
54 | self.up1 = Up(out_channels*8, out_channels*4, bilinear)
55 | self.up2 = Up(out_channels*4, out_channels*2, bilinear)
56 | self.up3 = Up(out_channels*2, out_channels, bilinear)
57 | self.outc = OutConv(out_channels, n_classes)
58 |
59 | def forward(self, x):
60 | x1 = self.inc(x)
61 | x2 = self.down1(x1)
62 | x3 = self.down2(x2)
63 | x4 = self.down3(x3)
64 | x = self.up1(x4, x3)
65 | x = self.up2(x, x2)
66 | x = self.up3(x, x1)
67 | logits = self.outc(x)
68 | return x1, x, logits
69 |
70 |
71 | class Iternet(nn.Module):
72 | def __init__(self, n_channels, n_classes, out_channels=32, iterations=3):
73 | super(Iternet, self).__init__()
74 | self.n_channels = n_channels
75 | self.n_classes = n_classes
76 | self.iterations = iterations
77 |
78 | # define the network UNet layer
79 | self.model_unet = UNet(n_channels=n_channels,
80 | n_classes=n_classes, out_channels=out_channels)
81 |
82 | # define the network MiniUNet layers
83 | self.model_miniunet = ModuleList(MiniUNet(
84 | n_channels=out_channels*2, n_classes=n_classes, out_channels=out_channels) for i in range(iterations))
85 |
86 | def forward(self, x):
87 | logits = []
88 | x1, x2, logit = self.model_unet(x)
89 | logits.append(logit)
90 | for i in range(self.iterations):
91 | x = torch.cat([x1, x2], dim=1)
92 | _, x2, logit = self.model_miniunet[i](x)
93 | logits.append(logit)
94 |
95 | return logits
96 |
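97 |
98 | if __name__ == '__main__':
99 |     # Shape check (a sketch): forward returns one logit map per sub-network,
100 |     # i.e. iterations + 1 tensors, each of shape (N, n_classes, H, W).
101 |     model = Iternet(n_channels=3, n_classes=1, out_channels=32, iterations=3)
102 |     x = torch.randn(1, 3, 128, 128)
103 |     outs = model(x)
104 |     print(len(outs), [tuple(o.shape) for o in outs])
105 |     # 4 [(1, 1, 128, 128), (1, 1, 128, 128), (1, 1, 128, 128), (1, 1, 128, 128)]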
--------------------------------------------------------------------------------
/model/iternet/unet_parts.py:
--------------------------------------------------------------------------------
1 | """ Parts of the U-Net model """
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | class DoubleConv(nn.Module):
9 | """(convolution => [BN] => ReLU) * 2"""
10 |
11 | def __init__(self, in_channels, out_channels, mid_channels=None):
12 | super().__init__()
13 | if not mid_channels:
14 | mid_channels = out_channels
15 | self.double_conv = nn.Sequential(
16 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
17 | nn.BatchNorm2d(mid_channels),
18 | nn.ReLU(inplace=True),
19 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
20 | nn.BatchNorm2d(out_channels),
21 | nn.ReLU(inplace=True)
22 | )
23 |
24 | def forward(self, x):
25 | return self.double_conv(x)
26 |
27 |
28 | class Down(nn.Module):
29 | """Downscaling with maxpool then double conv"""
30 |
31 | def __init__(self, in_channels, out_channels):
32 | super().__init__()
33 | self.maxpool_conv = nn.Sequential(
34 | nn.MaxPool2d(2),
35 | DoubleConv(in_channels, out_channels)
36 | )
37 |
38 | def forward(self, x):
39 | return self.maxpool_conv(x)
40 |
41 |
42 | class Up(nn.Module):
43 | """Upscaling then double conv"""
44 |
45 | def __init__(self, in_channels, out_channels, bilinear=True):
46 | super().__init__()
47 |
48 | # if bilinear, use the normal convolutions to reduce the number of channels
49 | if bilinear:
50 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
51 | self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
52 | else:
53 | self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2)
54 | self.conv = DoubleConv(in_channels, out_channels)
55 |
56 |
57 | def forward(self, x1, x2):
58 | x1 = self.up(x1)
59 | # input is CHW
60 | diffY = x2.size()[2] - x1.size()[2]
61 | diffX = x2.size()[3] - x1.size()[3]
62 |
63 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
64 | diffY // 2, diffY - diffY // 2])
65 | # if you have padding issues, see
66 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
67 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
68 | x = torch.cat([x2, x1], dim=1)
69 | return self.conv(x)
70 |
71 |
72 | class OutConv(nn.Module):
73 | def __init__(self, in_channels, out_channels):
74 | super(OutConv, self).__init__()
75 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
76 |
77 | def forward(self, x):
78 | return self.conv(x)
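79 |
80 |
81 | if __name__ == '__main__':
82 |     # Shape check (a sketch): Up upsamples x1, pads it to match the skip
83 |     # connection x2, concatenates on the channel axis, then runs DoubleConv.
84 |     up = Up(in_channels=64, out_channels=32, bilinear=False)
85 |     x1 = torch.randn(1, 64, 16, 16)  # low-resolution decoder features
86 |     x2 = torch.randn(1, 32, 33, 33)  # encoder skip with odd spatial size
87 |     print(tuple(up(x1, x2).shape))   # (1, 32, 33, 33)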
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib==3.2.1
2 | numpy<2  # the pinned matplotlib and opencv-python versions predate numpy 2.x
3 | opencv-python==4.1.0.25
4 | pandas==2.2.2
5 | tqdm==4.42.1
6 | torch==2.4.0
7 | torchvision==0.19.0
8 |
--------------------------------------------------------------------------------
/trainer/__pycache__/trainer.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amri369/Pytorch-Iternet/d657c5d7c11dee8fe3e2301f4b8cff872273a77e/trainer/__pycache__/trainer.cpython-36.pyc
--------------------------------------------------------------------------------
/trainer/trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import random
5 | import numpy as np
6 |
7 |
8 | class Trainer(object):
9 |
10 | def __init__(self, model, criteria, optimizer, scheduler, gpus, seed):
11 | self.model = model
12 | self.criteria = criteria
13 | self.optimizer = optimizer
14 | self.scheduler = scheduler
15 | self.gpus = gpus
16 | self.is_gpu_available = torch.cuda.is_available()
17 | Trainer.set_seed(seed)
18 |
19 | def set_devices(self):
20 | if self.is_gpu_available:
21 | os.environ["CUDA_VISIBLE_DEVICES"] = self.gpus
22 | self.model = self.model.cuda()
23 | self.model = torch.nn.DataParallel(self.model)
24 | self.criteria = self.criteria.cuda()
25 | else:
26 | self.model = self.model.cpu()
27 | self.criteria = self.criteria.cpu()
28 |     @staticmethod
29 |     def set_seed(seed):
30 | torch.manual_seed(seed)
31 |
32 | if torch.cuda.is_available():
33 | torch.cuda.manual_seed_all(seed)
34 |
35 | random.seed(seed)
36 |
37 | np.random.seed(seed)
38 |
39 | torch.backends.cudnn.deterministic = True
40 |
41 | def training_step(self, dataloader):
42 | # initialize the loss
43 | epoch_loss = 0.0
44 |
45 | # loop over training set
46 | self.model.train()
47 | for x, y in dataloader:
48 | if self.is_gpu_available:
49 | x, y = x.cuda(), y.cuda()
50 |
51 | with torch.set_grad_enabled(True):
52 | z = self.model(x)
53 | loss = self.criteria(z[0], y)
54 | _n = len(z)
55 | for b in range(1, _n):
56 | loss += self.criteria(z[b], y)
57 |
58 | # back propagation
59 | self.optimizer.zero_grad()
60 | loss.backward()
61 | nn.utils.clip_grad_value_(self.model.parameters(), 0.1)
62 | self.optimizer.step()
63 |
64 | epoch_loss += loss.item() * len(x)
65 |
66 |         epoch_loss = epoch_loss / len(dataloader.dataset)  # average per sample: the sum above is weighted by batch size
67 | return epoch_loss
68 |
69 | def validation_step(self, dataloader):
70 | # initialize the loss
71 | epoch_loss = 0.0
72 |
73 | # loop over validation set
74 | self.model.eval()
75 | for x, y in dataloader:
76 | if self.is_gpu_available:
77 | x, y = x.cuda(), y.cuda()
78 | with torch.set_grad_enabled(False):
79 | z = self.model(x)
80 | loss = self.criteria(z[0], y)
81 | _n = len(z)
82 | for b in range(1, _n):
83 | loss += self.criteria(z[b], y)
84 |
85 | epoch_loss += loss.item() * len(x)
86 |
87 |         epoch_loss = epoch_loss / len(dataloader.dataset)  # average per sample, matching training_step
88 | return epoch_loss
89 |
90 | def save_checkpoint(self, epoch, model_dir):
91 | # create the state dictionary
92 | state = {
93 | 'epoch': epoch,
94 | 'state_dict': self.model.state_dict(),
95 | 'optimizer': self.optimizer.state_dict()
96 | }
97 |
98 | if not os.path.exists(model_dir):
99 | os.makedirs(model_dir)
100 |         model_out_path = os.path.join(model_dir, "epoch_{}.pth".format(epoch))
101 | torch.save(state, model_out_path)
102 |
103 | def __call__(self, dataloaders, epochs, model_dir):
104 | self.set_devices()
105 | for epoch in range(epochs):
106 | train_loss = self.training_step(dataloaders['train'])
107 |             val_loss = self.validation_step(dataloaders['val'])
108 |             self.scheduler.step(val_loss)  # reduce the learning rate when the validation loss plateaus
109 |             self.save_checkpoint(epoch, model_dir)
110 |             print('------', epoch + 1, '/', epochs, train_loss, val_loss)
111 |
112 |         if self.is_gpu_available:
113 |             torch.cuda.empty_cache()
114 |
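115 |
116 | if __name__ == '__main__':
117 |     # Sketch (assumes a checkpoint written by save_checkpoint above exists):
118 |     # each file bundles the epoch, model weights and optimizer state; weights
119 |     # saved from a DataParallel model carry a 'module.' prefix on their keys.
120 |     state = torch.load('exp/epoch_0.pth', map_location='cpu')
121 |     print(state['epoch'], list(state))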
--------------------------------------------------------------------------------
/utils/eval.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from tqdm import tqdm
4 |
5 | from dice_loss import dice_coeff  # dice_loss is not included in this repository; see the sketch after this file
6 |
7 |
8 | def eval_net(net, loader, device):
9 |     """Evaluate with the Dice coefficient (no dense-CRF post-processing)"""
10 | net.eval()
11 | mask_type = torch.float32 if net.n_classes == 1 else torch.long
12 | n_val = len(loader) # the number of batch
13 | tot = 0
14 |
15 | with tqdm(total=n_val, desc='Validation round', unit='batch', leave=False) as pbar:
16 | for batch in loader:
17 | imgs, true_masks = batch['image'], batch['mask']
18 | imgs = imgs.to(device=device, dtype=torch.float32)
19 | true_masks = true_masks.to(device=device, dtype=mask_type)
20 |
21 | with torch.no_grad():
22 |                 mask_pred = net(imgs)
23 |                 if isinstance(mask_pred, list):
24 |                     mask_pred = mask_pred[-1]  # Iternet returns one logit map per iteration; score the final one
25 |
26 |             if net.n_classes > 1:
27 |                 tot += F.cross_entropy(mask_pred, true_masks).item()
28 |             else:
29 |                 pred = torch.sigmoid(mask_pred)
30 |                 pred = (pred > 0.5).float()
31 |                 tot += dice_coeff(pred, true_masks).item()
32 |             pbar.update()
33 |
34 |     net.train()
35 |     return tot / n_val
--------------------------------------------------------------------------------
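/utils/eval.py imports `dice_coeff` from a `dice_loss` module that is not part of
this repository. A minimal sketch of a compatible implementation is given below;
the module name and call signature come from eval.py, while the soft-Dice
formulation itself is an assumption about what the missing module computed.

```python
import torch


def dice_coeff(pred, target, eps=1e-6):
    """Mean Dice coefficient over a batch of binary masks.

    pred and target are float tensors of the same shape (N, 1, H, W) with
    values in {0, 1}; a scalar tensor is returned, so .item() works as
    eval_net expects.
    """
    pred = pred.contiguous().view(pred.size(0), -1)
    target = target.contiguous().view(target.size(0), -1)
    intersection = (pred * target).sum(dim=1)
    union = pred.sum(dim=1) + target.sum(dim=1)
    return ((2.0 * intersection + eps) / (union + eps)).mean()
```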